# -*- coding: utf-8 -*-

# Import necessary libraries and modules
import dataiku
import pandas as pd
from PIL import Image
import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor, BitsAndBytesConfig
import os
import re

# Hugging Face Hub identifier of the pre-trained IDEFICS checkpoint
checkpoint = "HuggingFaceM4/idefics-9b-instruct"

# Handle on the input managed folder, and the paths of the files it contains
images = dataiku.Folder("rwNI6KPi")
list_images = images.list_paths_in_partition()

# Processor paired with the checkpoint; it turns prompts into model inputs
# and decodes generated ids back into text
processor = AutoProcessor.from_pretrained(checkpoint)

# 4-bit quantization (float16 compute) to reduce the memory footprint
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

# Device the processed inputs are moved to: GPU when available, CPU otherwise
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Quantized model; layer placement is delegated to device_map="auto"
model = IdeficsForVisionText2Text.from_pretrained(
    checkpoint,
    device_map="auto",
    quantization_config=quantization_config,
)

# Captures the text following "\nAssistant:" up to the next " \nUser" turn or
# the end of the text (DOTALL so that the answer may span several lines)
pattern = re.compile(r'\nAssistant:(.*?)(?: \nUser|$)', re.DOTALL)

# Accumulator for the generated captions, one entry per image
captions = []

# Instruction sent to the model along with every image
prompt = "You are an expert in art. Please provide a descriptive caption for this image."

# Token ids the model is forbidden to generate (the image placeholder tokens).
# They do not depend on the image, so they are computed once, before the loop.
bad_words_ids = processor.tokenizer(
    ["<image>", "<fake_token_around_image>"], add_special_tokens=False
).input_ids

# Caption every image of the folder, one at a time
for image_path in list_images:

    # Image.open() is lazy: force the pixel data to be read while the download
    # stream is still open, so that later decoding never relies on a stream
    # that has already been closed.
    with images.get_download_stream(path=image_path) as stream:
        raw_image = Image.open(stream)
        raw_image.load()

    # Single-conversation batch: user instruction, the image, then the
    # assistant turn left open for the model to complete
    prompts = [["User:" + prompt, raw_image, "\nAssistant:"]]

    # Turn the prompt into tensors and move them to the inference device
    inputs = processor(prompts, return_tensors="pt").to(device)

    # Generate at most 200 new tokens, never emitting the image placeholders
    generated_ids = model.generate(**inputs, bad_words_ids=bad_words_ids, max_new_tokens=200)

    # Decode the generated ids and keep only the first assistant answer
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    match = pattern.search(generated_text[0])

    # If no assistant answer can be located, store an empty caption instead of
    # raising: one unparsable generation must not abort the whole run, and
    # captions has to keep exactly one entry per image.
    caption = match.group(1) if match else ""

    # Append the answer to the list of answers
    captions.append(caption.lstrip())

# Assemble the results: one row per image path with its generated caption
captions_df = pd.DataFrame(
    data={
        'images': list_images,
        'captions': captions,
    }
)

# Write the dataframe to the recipe's output dataset, setting its schema
# from the dataframe
image_captions_IDEFICS = dataiku.Dataset("image_captions_idefics")
image_captions_IDEFICS.write_with_schema(captions_df)