import dataiku
import json
import io
from functools import lru_cache
import numpy as np
from PIL import Image
from skimage.transform import resize
import torch
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor

# When True, the download button only shows a notification and nothing is
# written to the output folder (see save_annotation).
# Change this to False if you want users to be able to save predictions.
SAVE_DISABLED = True

import plotly.express as px
import plotly.graph_objs as go
from dash import dcc, html, Input, Output, State
import dash_bootstrap_components as dbc

from project_utils import (
    get_color_map,
    compute_activations,
    save_image
)

# Input folder with the images to segment, and output folder that
# save_annotation writes masks/colour maps to (Dataiku folder identifiers).
folder = dataiku.Folder("Tmf77vDr")
output_folder = dataiku.Folder("1HRv3kle")
# Sorted so the pagination order (page n -> paths[n - 1]) is deterministic.
paths = sorted(folder.list_paths_in_partition())

# The CLIPSeg checkpoint is configured through the project variable
# "clipseg_model_name".
model_name = dataiku.get_custom_variables()["clipseg_model_name"]
# .eval() switches the model to inference mode; it is only used for prediction.
model = CLIPSegForImageSegmentation.from_pretrained(model_name).eval()
processor = CLIPSegProcessor.from_pretrained(model_name)

@lru_cache()
def get_image(image_idx, width=800):
    """
    Load the image at ``paths[image_idx]`` from the input folder and resize it.

    The image is scaled to ``width`` pixels wide (800 by default), keeping its
    aspect ratio. Images whose mode is neither RGB nor RGBA (palette,
    grayscale, CMYK, ...) are converted to RGB first: ``update_image`` pastes
    an RGBA colour mask onto this image, and Pillow converts the pasted image
    to the target's mode, which would strip the label colours from e.g. a
    grayscale image.

    Results are memoised by ``lru_cache`` (default ``maxsize=128``).
    """
    path = paths[image_idx]
    with folder.get_download_stream(path) as stream:
        # Buffer the whole file in memory so the (lazily decoded) image stays
        # readable after the download stream is closed.
        image = Image.open(io.BytesIO(stream.read()))
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")
    w, h = image.size
    return image.resize((width, h * width // w))

def get_figure(image):
    """
    Wrap ``image`` in a plotly figure stripped of all chrome: no margins,
    hidden axes, no hover labels and no drag interaction.
    """
    no_margin = dict(l=0, r=0, t=0, b=0)
    # Each plotly ``update_*`` call returns the figure itself, so they chain.
    return (
        px.imshow(image)
        .update_layout(dragmode=False, margin=no_margin)
        .update_traces(hoverinfo='none', hovertemplate=None)
        .update_xaxes(visible=False)
        .update_yaxes(visible=False)
    )

@lru_cache(maxsize=None)
def get_activations(image_idx):
    """
    Return CLIPSeg's image activations for ``paths[image_idx]``, as computed by
    ``project_utils.compute_activations`` on the resized image.

    The cache is deliberately unbounded. Its keys are image indices, so in
    practice it holds at most one entry per image in ``paths``. With
    ``lru_cache``'s default ``maxsize=128``, the warm-up loop below evicted
    the first images (including image 0, displayed on load) as soon as the
    folder held more than 128 images, so those images were recomputed on the
    next search. Memory use therefore grows with the number of images.
    """
    return compute_activations(model, processor, get_image(image_idx))


# Warm the cache for every image at start-up, so a search only has to run the
# text conditioning and the decoder (see compute_logits).
for _image_idx in range(len(paths)):
    get_activations(_image_idx)

def compute_logits(image_idx, prompts):
    """
    Decode CLIPSeg logits for one image and a list of text prompts.

    The image activations come from the cached ``get_activations``. Only the
    text conditioning and the decoder run here, without gradient tracking.

    Returns the first output of ``model.decoder`` (a torch tensor) holding
    one logit map per prompt. NOTE(review): ``segment_image`` adds a leading
    axis when there is a single prompt, which suggests the result is 2-D in
    that case; confirm against the installed transformers version.
    """
    text_inputs = processor(text=prompts, padding=True, return_tensors="pt")
    with torch.no_grad():
        text_conditioning = model.get_conditional_embeddings(
            batch_size=len(prompts), **text_inputs
        )
        image_activations = get_activations(image_idx)
        decoded = model.decoder(image_activations, text_conditioning)
    return decoded[0]

def compute_mask(logits, labels, label2color, threshold, size=None):
    """
    Turn CLIPSeg logits into a coloured RGBA segmentation mask.

    Parameters
    ----------
    logits : numpy.ndarray
        Logit maps stacked along the first axis, one per label
        (``logits[i]`` belongs to ``labels[i]``).
    labels : list of str
        Label of each logit map.
    label2color : dict
        Maps each label to the colour written into the mask's 4 (RGBA)
        channels.
    threshold : float
        Pixels whose highest sigmoid probability is below this value stay
        fully transparent.
    size : tuple of int, optional
        ``(height, width)`` to resize the logit maps to before thresholding.
        When ``None``, the logits' own resolution is kept.

    Returns
    -------
    PIL.Image.Image
        RGBA image in which each pixel has the colour of its most probable
        label, or ``(0, 0, 0, 0)`` where no label reaches ``threshold``.
    """
    if size is not None:
        # skimage's resize acts on every axis, so keep the label axis at its
        # current length and resize only the spatial ones.
        logits = resize(logits, (logits.shape[0], *size))
    # np.exp overflows to inf for very negative logits. 1 / (1 + inf) == 0.0
    # is the correct limit, so the RuntimeWarning is spurious.
    with np.errstate(over="ignore"):
        probabilities = 1 / (1 + np.exp(-logits))

    best_probability = np.max(probabilities, axis=0)
    best_label = np.argmax(probabilities, axis=0)
    # -1 marks "no label": it indexes the last, all-zero palette row below.
    segmentation_map = np.where(best_probability < threshold, -1, best_label)

    # One palette row per label, plus a trailing transparent row. A single
    # gather replaces one full-image boolean pass per label.
    n_rows = max(len(labels), probabilities.shape[0]) + 1
    palette = np.zeros((n_rows, 4), dtype=np.uint8)
    for row, label in enumerate(labels):
        palette[row] = label2color[label]
    return Image.fromarray(palette[segmentation_map])

def get_legend(colors):
    """
    Build the legend as a list of spans: a coloured square icon followed by
    the label, one per entry of ``colors``.

    ``colors`` maps label -> colour sequence. The last component is dropped
    (presumably alpha) and the rest is rendered as a CSS ``rgb(...)`` value.
    """
    def legend_entry(label, color):
        css_color = "rgb" + str(tuple(color[:-1]))
        icon = html.I(className="bi bi-square-fill", style={"color": css_color})
        name = html.Span(label, style={"margin": "0px 10px 0px 5px"})
        return html.Span([icon, name])

    return [legend_entry(label, color) for label, color in colors.items()]

## CSS styles

# Centred column shared by the search bar, the threshold slider and the legend.
STYLE_COMPONENT = {"margin": "10px auto 0px auto", "max-width": "650px"}

# Vertical spacing between groups of components.
STYLE_GROUP = {"margin-top": "20px"}

# Same box as STYLE_COMPONENT, but hidden until update_image reveals it.
STYLE_DOWNLOAD_BUTTON = {**STYLE_COMPONENT, "display": "none"}

STYLE_PAGINATION = {
    "display": "flex",
    "justify-content": "space-evenly",
    "margin": "10px auto 0px auto",
}

# Notification toast floating at the top centre of the page.
STYLE_NOTIFICATION = {
    "position": "absolute",
    "top": 10,
    "left": "50%",
    "transform": "translateX(-50%)",
    "width": 250,
    "z-index": "10",
}

# Centred page container, at most 800px wide.
STYLE_PAGE = {"margin": "auto", "max-width": "800px", "text-align": "center"}

## Input bar

search_icon = html.Span(html.I(className="bi bi-search"))

# Query box; submitting it (Enter) or clicking the button triggers segment_image.
search_bar = dbc.InputGroup(
    children=[
        dbc.Input(id="query", value="", type="text", minLength=0),
        dbc.Button(search_icon, id="search-btn", title="Detect object"),
    ],
    style=STYLE_COMPONENT,
)

# Labelled ticks every 0.1 from 0 to 1; the slider itself moves in 0.01 steps.
_threshold_marks = {
    tick / 100: {"label": str(tick / 100)} for tick in range(0, 110, 10)
}
threshold_slider = html.Div(
    children=[
        html.Label("Threshold"),
        dcc.Slider(0, 1.0, 0.01, value=0.5, marks=_threshold_marks, id="threshold"),
    ],
    style=STYLE_COMPONENT,
)

# One page per image; page n shows paths[n - 1].
_image_pagination = dbc.Pagination(
    id="image-idx",
    active_page=1,
    max_value=len(paths),
    first_last=True,
    previous_next=True,
    fully_expanded=False,
)

inputs = html.Div(
    children=[
        html.Div(_image_pagination, style=STYLE_PAGINATION),
        html.Div([search_bar, threshold_slider], style=STYLE_GROUP),
    ]
)

## Outputs

download_icon = html.Span(html.I(className="bi bi-cloud-arrow-down-fill"))

# The first image is displayed until a callback replaces the figure.
_picture = dcc.Graph(
    id="graph-picture",
    figure=get_figure(get_image(0)),
    config={"displayModeBar": False, "displaylogo": False},
    style=STYLE_GROUP,
)

_legend_and_download = html.Div(
    children=[
        html.Div(id="legend", style=STYLE_COMPONENT),
        dbc.Button(download_icon, id="download-btn", style=STYLE_DOWNLOAD_BUTTON),
    ],
    style=STYLE_GROUP,
)

# Header-less toast used by save_annotation; duration is in milliseconds.
_notification = dbc.Toast(
    id="notification",
    header_style={"display": "none"},
    duration=1000,
    style=STYLE_NOTIFICATION,
    is_open=False,
)

outputs = html.Div(
    children=[
        # Holds the JSON produced by segment_image for the other callbacks.
        dcc.Store(id="segmented", storage_type="memory"),
        _picture,
        _legend_and_download,
        _notification,
    ],
)

## Overall layout

# NOTE: ``app`` is not defined in this file; presumably it is the Dash app
# provided by the hosting webapp runtime.
app.title = "Few-shot object detection"
# Bootstrap theme plus Bootstrap Icons (needed for the "bi bi-*" icon classes).
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]
app.layout = html.Div(children=[inputs, outputs], style=STYLE_PAGE)

## Image segmentation


def _parse_query(query):
    """
    Split a raw query into ``(prompts, labels)``.

    Classes are separated by "," and an optional display label may follow the
    prompt after ":", e.g. ``"laptop:computer,desktop:computer,mouse"``.
    Prompts and labels are stripped of surrounding whitespace, blank entries
    are skipped, and an entry without a label (or with an empty one) uses its
    prompt as the label.
    """
    prompts, labels = [], []
    for entry in query.split(","):
        prompt, _, label = entry.partition(":")
        prompt, label = prompt.strip(), label.strip()
        if not prompt:
            # e.g. the trailing entry of "cat,". An empty prompt would still be
            # sent to CLIPSeg and get its own (empty) legend entry.
            continue
        prompts.append(prompt)
        labels.append(label or prompt)
    return prompts, labels


@app.callback(
    Output('segmented', 'data'),
    Input('search-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    Input('image-idx', 'active_page'),
    State('query', 'value'),
    prevent_initial_call=True
)
def segment_image(n_clicks, n_submit, page, query):
    """
    Segment the image of the current page with the prompts of the query.

    Runs on search, on Enter in the query box, and on page change. Returns a
    JSON string with the keys ``predictions`` (nested lists of logits, one map
    per prompt), ``color`` (label -> colour from ``get_color_map``),
    ``labels`` (one per prompt) and ``idx`` (the image index). An empty or
    blank query yields empty predictions, so only the image is displayed.
    """
    image_idx = page - 1
    output = {
        'predictions': {},
        'color': {},
        'labels': [],
        'idx': image_idx
    }

    # Treat a missing (None) query value like an empty one.
    prompts, classes = _parse_query(query or "")
    if not prompts:
        return json.dumps(output)

    output["labels"] = classes
    output["color"] = get_color_map(classes)

    logits = compute_logits(image_idx, prompts).numpy()
    # With a single prompt the logits come back without the prompt axis;
    # restore it so predictions always stack one map per prompt.
    if logits.ndim == 2:
        logits = logits[np.newaxis]
    output["predictions"] = logits.tolist()

    return json.dumps(output)

@app.callback(
    Output('graph-picture', 'figure'),
    Output('legend', 'children'),
    Output('download-btn', 'style'),
    Input('segmented', 'data'),
    Input('threshold', 'value'),
    prevent_initial_call=True
)
def update_image(segmented, threshold):
    """
    Redraw the current image with its segmentation mask overlaid.

    ``segmented`` is the JSON string stored by ``segment_image``. Returns the
    new figure, the legend children, and the download button style. The
    button is shown only when there are predictions.
    """
    # The threshold slider is also an Input, so moving it before any search
    # fires this callback while the store is still empty (None), and
    # json.loads(None) would raise a TypeError. Keep the current view instead.
    if segmented is None:
        from dash.exceptions import PreventUpdate
        raise PreventUpdate

    segmented = json.loads(segmented)
    image = get_image(segmented["idx"])
    style = STYLE_DOWNLOAD_BUTTON.copy()

    if not segmented["predictions"]:
        return get_figure(image), "", style

    style["display"] = "block"

    mask = compute_mask(
        np.array(segmented["predictions"]),
        segmented["labels"],
        segmented["color"],
        threshold,
        # PIL sizes are (width, height); compute_mask expects (height, width).
        size=image.size[::-1]
    )
    overlay = image.copy()
    # The mask is also its own paste mask, so pixels with no label (alpha 0)
    # leave the original image untouched.
    overlay.paste(mask, (0, 0), mask)

    return get_figure(overlay), get_legend(segmented["color"]), style

@app.callback(
    Output('notification', 'is_open'),
    Output('notification', 'children'),
    Input('download-btn', 'n_clicks'),
    State('segmented', 'data'),
    State('threshold', 'value'),
    prevent_initial_call=True
)
def save_annotation(n_clicks, segmented, threshold):
    """
    Save the current image's segmentation to the output folder.

    The colour mask is saved via ``save_image`` as ``<path>.png``, and the
    label -> colour map is written as ``<path>.json``, where ``<path>`` is the
    input image's path. Returns ``(is_open, message)`` for the notification
    toast. When SAVE_DISABLED is set, only a notification is shown.

    NOTE(review): unlike ``update_image``, ``compute_mask`` is called here
    without ``size``, so the saved mask has the resolution of the stored
    logits rather than that of the image. Confirm this is intended.
    """
    if SAVE_DISABLED:
        return True, "Save button disabled. Change SAVE_DISABLED to enable it"

    annotation = json.loads(segmented)
    if not annotation["predictions"]:
        return False, ""
    image_path = paths[annotation["idx"]]

    mask = compute_mask(
        np.array(annotation["predictions"]),
        annotation["labels"],
        annotation["color"],
        threshold,
    )
    save_image(output_folder, image_path + ".png", mask, image_format="PNG")

    with output_folder.get_writer(image_path + ".json") as writer:
        writer.write(json.dumps(annotation["color"]).encode("utf-8"))

    return True, "Segmentation mask saved"