import dataiku
import json
import io
from functools import lru_cache
import numpy as np
from PIL import Image, ImageDraw
from skimage.transform import resize
import torch
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor

import plotly.express as px
import plotly.graph_objs as go
from dash import dcc, html, Input, Output
import dash_bootstrap_components as dbc

from project_utils import (
    get_color_map,
    compute_activations,
    load_array,
    convert_box
)

# Test dataset (rows are addressed by image index) and the two managed
# folders: one with the class embeddings, one the images are read from.
df = dataiku.Dataset("test").get_dataframe()
embeddings_folder = dataiku.Folder("3gKc7IXi")
folder = dataiku.Folder("aKlMxTsk")

# Class embeddings and their labels (underscores in labels become spaces),
# plus the color associated with each label.
with load_array(embeddings_folder, "class_embeddings.npz") as data:
    class_embeddings = data["class_embeddings"]
    all_labels = [name.replace("_", " ") for name in data["labels"]]
label2color = get_color_map(all_labels)

# CLIPSeg model (put in eval mode) and its processor; the checkpoint name
# comes from the "clipseg_model_name" custom variable.
model_name = dataiku.get_custom_variables()["clipseg_model_name"]
model = CLIPSegForImageSegmentation.from_pretrained(model_name)
model.eval()
processor = CLIPSegProcessor.from_pretrained(model_name)

@lru_cache()
def get_image(image_idx):
    """
    Return the annotated, resized image stored at row `image_idx` of `df`.

    The picture is read from the input folder, every bounding box found in
    the row's "label" column is outlined with the color of its category, and
    the result is scaled to a width of 800 pixels (the height is scaled by
    the same ratio). Results are memoized by `lru_cache`, so callers must
    not modify the returned image in place.
    """
    annotations = json.loads(df.loc[image_idx, "label"])
    record_path = df.loc[image_idx, "record_id"]
    with folder.get_download_stream(record_path) as stream:
        # Read everything into memory so the image stays usable after the
        # download stream is closed.
        image = Image.open(io.BytesIO(stream.read()))

    painter = ImageDraw.Draw(image)
    for box in annotations:
        # Categories use underscores, the color map uses spaces.
        outline_color = label2color[box["category"].replace("_", " ")]
        painter.rectangle(convert_box(box["bbox"]), outline=outline_color, width=3)

    width, height = image.size
    return image.resize((800, height*800//width))

def get_figure(image):
    """
    Build a plotly figure showing `image` with all decorations removed.

    Both axes are hidden, hover information is disabled, drag interactions
    are switched off and the margins are set to zero.
    """
    figure = px.imshow(image)
    figure.update_xaxes(visible=False)
    figure.update_yaxes(visible=False)
    figure.update_traces(hovertemplate=None, hoverinfo='none')
    no_margin = dict(l=0, r=0, t=0, b=0)
    figure.update_layout(margin=no_margin, dragmode=False)
    return figure

@lru_cache(maxsize=None)
def get_activations(x):
    """
    Return the activations computed by `compute_activations` for test image `x`.

    `x` is the row index of the image in `df`. The cache is deliberately
    unbounded: with the default `maxsize` of 128, warming the cache for a
    dataset of more than 128 images would evict the earliest entries again.
    The number of entries stays bounded by `len(df)` because `get_image`
    fails for any other index.
    """
    return compute_activations(model, processor, get_image(x))


# Cache the activations for each test image to accelerate the UI (at the cost of a longer startup)
for i in range(len(df)):
    get_activations(i)


def compute_logits(image_idx, value):
    """
    Run the model's decoder on one image for the selected classes.

    `value` selects the rows of `class_embeddings` passed to the decoder
    together with the cached activations of image `image_idx`. The first
    element of the decoder output is returned; no gradients are tracked.
    """
    with torch.no_grad():
        activations = get_activations(image_idx)
        selected_embeddings = torch.Tensor(class_embeddings[value, :])
        decoder_output = model.decoder(activations, selected_embeddings)
        return decoder_output[0]

def compute_mask(logits, labels, label2color, threshold, size=None):
    """
    Turn per-label logits into a colored overlay image.

    `logits` holds one 2D map per entry of `labels` along its first axis.
    The maps are optionally resized to `size`, squashed with a sigmoid, and
    every pixel receives the color of its highest-scoring label. Pixels whose
    best score is below `threshold` are left at zero in all four channels.
    The result is returned as a PIL image built from a uint8 array with four
    channels (colors in `label2color` are presumably RGBA - confirm against
    `get_color_map`).
    """
    if size is not None:
        logits = resize(logits, (logits.shape[0], *size))
    probabilities = 1/(1 + np.exp(-logits))

    best_score = np.max(probabilities, axis=0)
    best_label = np.argmax(probabilities, axis=0)
    # -1 marks pixels where no label reaches the threshold.
    segmentation_map = np.where(best_score < threshold, -1, best_label)

    overlay = np.zeros((*segmentation_map.shape, 4), dtype=np.uint8)
    for position, label in enumerate(labels):
        overlay[segmentation_map == position] = label2color[label]
    return Image.fromarray(overlay)

def get_legend(colors):
    """
    Build the legend: one HTML span per label of `colors`.

    Each entry is a square icon tinted with the label's color (the last
    component of the color is dropped) followed by the label text.
    """
    def legend_entry(label):
        css_color = "rgb" + str(tuple(colors[label][:-1]))
        icon = html.I(className="bi bi-square-fill", style={"color": css_color})
        caption = html.Span(label, style={"margin": "0px 10px 0px 5px"})
        return html.Span([icon, caption])

    return [legend_entry(label) for label in colors]

## CSS styles

# Margin shared by the centered input components and the pagination bar.
_CENTERED_MARGIN = "10px auto 0px auto"

# Individual input component (dropdown, slider).
STYLE_COMPONENT = {
    "margin": _CENTERED_MARGIN,
    "max-width": "650px",
    "text-align": "center",
}

# Vertical spacing between groups of components.
STYLE_GROUP = {"margin-top": "20px"}

# Container of the pagination bar.
STYLE_PAGINATION = {
    "display": "flex",
    "justify-content": "space-evenly",
    "margin": _CENTERED_MARGIN,
}

# Whole page.
STYLE_PAGE = {
    "margin": "auto",
    "max-width": "800px",
    "text-align": "center",
}

## Input bar

# Threshold slider: 0 to 1 in steps of 0.01, with a labelled mark every 0.1.
threshold_slider = html.Div(
    children=[
        html.Label("Threshold"),
        dcc.Slider(
            min=0,
            max=1.0,
            step=0.01,
            value=0.5,
            marks={
                pct/100: {"label": str(pct/100)} for pct in range(0, 110, 10)
            },
            id='threshold',
        ),
    ],
    style=STYLE_COMPONENT,
)

# Multi-select dropdown of the classes; option values are positions in
# `all_labels`, and every class is selected initially.
class_select = dcc.Dropdown(
    id="select",
    options=[
        {"label": label, "value": position}
        for position, label in enumerate(all_labels)
    ],
    value=list(range(len(all_labels))),
    multi=True,
    style=STYLE_COMPONENT,
)

# Input area: a pagination bar with one page per row of `df`, followed by
# the class dropdown and the threshold slider.
inputs = html.Div(
    children=[
        html.Div(
            children=dbc.Pagination(
                id="image-idx",
                max_value=len(df),
                active_page=1,
                fully_expanded=False,
                first_last=True,
                previous_next=True,
            ),
            style=STYLE_PAGINATION,
        ),
        html.Div(
            children=[class_select, threshold_slider],
            style=STYLE_GROUP,
        ),
    ]
)

## Outputs

# Output area: an in-memory store, the picture (initially the first image
# without any mask) and the color legend.
outputs = html.Div(
    children=[
        dcc.Store(id='segmented', storage_type='memory'),
        dcc.Graph(
            id="graph-picture",
            figure=get_figure(get_image(0)),
            style=STYLE_GROUP,
            config={"displayModeBar": False, "displaylogo": False},
        ),
        html.Div(children=get_legend(label2color), style=STYLE_GROUP),
    ]
)

## Overall layout

# Page title, stylesheets (theme + icon font used by the legend) and the
# root layout: inputs on top, outputs below.
app.title = "Few-shot object detection"
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]
app.layout = html.Div(children=[inputs, outputs], style=STYLE_PAGE)

## Image segmentation

@app.callback(
    Output('graph-picture', 'figure'),
    Input('image-idx', 'active_page'),
    Input('select', 'value'),
    Input('threshold', 'value')
)
def segment_image(page, value, threshold):
    """
    Redraw the picture with the segmentation mask of the selected classes.

    Args:
        page: 1-based page number of the pagination bar.
        value: positions in `all_labels` of the classes selected in the
            dropdown; an empty selection (or `None`) means no mask.
        threshold: minimum score for a pixel to be assigned to a class.

    Returns:
        A plotly figure of the image, with the mask pasted on top of it when
        at least one class is selected.
    """
    image_idx = page - 1
    image = get_image(image_idx)

    # Case of an empty query; `not value` also covers a dropdown that
    # reports None instead of an empty list.
    if not value:
        return get_figure(image)

    # Image segmentation
    logits = compute_logits(image_idx, value).numpy()
    # compute_mask expects one map per label along the first axis. Restore
    # that axis only when the decoder output actually lacks it, instead of
    # assuming that a single prompt always yields a 2D result.
    if logits.ndim == 2:
        logits = np.expand_dims(logits, axis=0)

    labels = [all_labels[i] for i in value]
    mask = compute_mask(
        # float64 copy that no longer shares memory with the tensor
        logits.astype(np.float64),
        labels,
        {label: label2color[label] for label in labels},
        threshold,
        # PIL reports (width, height); the mask is sized (height, width)
        size=image.size[::-1]
    )
    # get_image results are cached: paste on a copy, never on the original.
    overlaid = image.copy()
    overlaid.paste(mask, (0, 0), mask)
    return get_figure(overlaid)