# -*- coding: utf-8 -*-
import dataiku
import transformers
import io
import pickle
from utils import load_image

# Managed folder that receives the serialized, processed dataset
output_folder = dataiku.Folder("BM1RzKhd")

# Identifier of the pre-trained UDOP checkpoint to build the processor from
model_name = "microsoft/udop-large-512-300k"

# Instantiate the processor for that checkpoint using local files only.
# NOTE(review): OCR is not configured explicitly here, so it depends on the
# processor's default setting — confirm it is enabled as intended.
processor = transformers.AutoProcessor.from_pretrained(model_name, local_files_only=True)

# Read the input Dataiku folder containing images and their ground truth labels
folder_classification = dataiku.Folder("cKQJveH0")

# Every path in the folder except the label file is treated as an image.
# Filtering (instead of list.remove) avoids an unrelated ValueError when the
# label file is absent; read_json below then reports the real problem.
files = [
    path
    for path in folder_classification.list_paths_in_partition()
    if path != "/label.json"
]
if not files:
    raise ValueError("No image files found in the input folder")

# Read the ground truth labels from the "label.json" file
ground_truth = folder_classification.read_json("label.json")

# Listed paths carry a leading "/" while label.json is keyed without it.
# Validate up front so a single error names every unlabeled file, and so the
# failure happens before any image has been loaded and resized.
missing_labels = [file for file in files if file[1:] not in ground_truth]
if missing_labels:
    raise KeyError(f"No ground-truth label in label.json for: {missing_labels}")

# Load every image resized to 800x800, and collect its label in the same order
# NOTE(review): load_image comes from utils; presumably it returns a PIL image — confirm.
images = [load_image(folder_classification, file).resize((800, 800)) for file in files]
labels = [ground_truth[file[1:]] for file in files]

# Define a classification prompt (same prompt for all images)
prompt = ["document classification."] * len(labels)

# Encode the text prompts and images together using the UDOP processor.
# Sequences are padded/truncated to a fixed length of 768 tokens.
data_encoding = processor(
    text=prompt,
    images=images,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=768,
)

# Tokenize the ground truth labels as targets, with the same fixed length
label_ids = processor(
    text_target=labels,
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=768,
).input_ids

# Replace padding positions with -100 so they are ignored during training.
# The padding id is read from the tokenizer instead of being hard-coded, so the
# masking stays correct if the checkpoint (and its vocabulary) changes.
pad_token_id = processor.tokenizer.pad_token_id
if pad_token_id is None:
    # Comparing against None would silently mask nothing.
    raise ValueError("The processor's tokenizer does not define a padding token")
label_ids[label_ids == pad_token_id] = -100

# Store the masked targets under the 'labels' field of the encoding
data_encoding["labels"] = label_ids

# Serialize the encoded batch with pickle and upload it to the output folder.
# pickle.dumps yields the same bytes as dumping into an in-memory buffer and
# reading it back, without the intermediate buffer handling.
# NOTE: despite the ".pt" extension, this file is a plain pickle stream, not one
# written by torch.save; presumably consumers read it with pickle.load — confirm.
# Only unpickle this file from a trusted location.
binary_data = pickle.dumps(data_encoding)
output_folder.upload_data("data_OCR_1024.pt", binary_data)
