# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import json
import io
import colorsys
from PIL import Image
from skimage.transform import resize
import numpy as np
import torch
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor
from project_utils import load_image, save_array

# Managed folders: labelled example images are read from `folder`,
# the computed class embeddings are written to `output_folder`.
folder = dataiku.Folder("PIBaZG9p")
output_folder = dataiku.Folder("3gKc7IXi")

# The CLIPSeg checkpoint name is configured through the project's custom variables.
project_variables = dataiku.get_custom_variables()
model_name = project_variables["clipseg_model_name"]

# Inference only: put the model in eval mode right after loading it.
model = CLIPSegForImageSegmentation.from_pretrained(model_name)
model.eval()
processor = CLIPSegProcessor.from_pretrained(model_name)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def get_image_embedding(patch):
    """Return CLIP image embeddings for ``patch`` from the CLIPSeg backbone.

    ``patch`` is anything ``processor(images=...)`` accepts. Callers in this
    notebook pass a PIL image or a list of PIL images.

    Returns a tensor whose first dimension indexes the input images
    (presumably ``(n_images, projection_dim)`` -- confirm against the model
    config). Gradients are disabled because this is inference only.
    """
    # Only the image branch of the processor is used. Text-tokenizer options
    # such as ``padding`` do not apply to pixel inputs, so none are passed.
    pixel_values = processor(images=patch, return_tensors="pt")["pixel_values"]
    with torch.no_grad():
        return model.clip.get_image_features(pixel_values)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# A label is the file name without its leading "/" and without its last
# "_<suffix>" chunk: "/red_car_3.png" -> "red_car". A name with no "_" maps to "".
labels = sorted({
    path[1:].rpartition("_")[0]
    for path in folder.list_paths_in_partition()
})

# One (initially empty) slot per label, filled by the embedding pass.
class_embedding = dict.fromkeys(labels)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Rotation angles (degrees, counter-clockwise in PIL) used as augmentation:
# every example image contributes one embedding per angle.
ROTATION_ANGLES = (0, 90, 180, 270)

# Gather the embeddings in per-label Python lists and concatenate each list
# once at the end. Growing a tensor with torch.cat on every iteration
# re-copies all previous rows each time, which is quadratic.
embeddings_per_label = {label: [] for label in labels}

# For each example image...
for path in folder.list_paths_in_partition():
    # assumes load_image returns a PIL image (it is rotated with .rotate) -- confirm
    image = load_image(folder, path)
    # The label is the file name minus the leading "/" and the trailing "_<suffix>".
    label = path[1:].rpartition("_")[0]

    # ... rotate it by each angle. expand=True enlarges the canvas, so a
    # 90/270-degree rotation of a non-square image keeps its full content.
    # Without it, PIL crops the result to the original size and pads it with black.
    rotated = [image.rotate(angle, expand=True) for angle in ROTATION_ANGLES]
    # ... embed all rotations in one forward pass and group them by label.
    # Raises KeyError if the folder gained new labels since they were listed.
    embeddings_per_label[label].append(get_image_embedding(rotated))

# Average the embeddings of each label into a single (1, dim) row.
for label in labels:
    stacked = torch.cat(embeddings_per_label[label], dim=0)
    class_embedding[label] = torch.mean(stacked, dim=0, keepdim=True)

# Row i of class_embeddings is the mean embedding of labels[i].
class_embeddings = torch.cat([class_embedding[k] for k in labels], dim=0)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Store the per-label mean embeddings next to the label names.
# Row i of `class_embeddings` corresponds to entry i of `labels`.
label_array = np.array(labels)
save_array(
    output_folder,
    "class_embeddings.npz",
    class_embeddings=class_embeddings,
    labels=label_array,
)