import dataiku
import colorsys

from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc

# Model identifiers; each one names a scored test dataset "test_scored_<model>".
models = [
    # Supervised baselines
    "tf-idf",
    "sentence-embeddings",
    "finetuning",
    # Zero-shot approaches
    "zero-shot_nli",
    "zero-shot_chatgpt",
    # Few-shot approaches
    "few-shot_setfit",
    "few-shot_chatgpt",
]
# model name -> DataFrame of that model's misclassified test rows.
dfs = dict()

def get_color_map(labels, saturation=0.8, lightness=0.4):
    """
    Get a category-to-color mapping for several categories.

    Hues are spread evenly around the color wheel (one step of 1/n per
    label, in iteration order) at a fixed lightness and saturation. Each
    color is rendered as a CSS ``"rgb(r, g, b)"`` string with integer
    channels in 0..255.

    :param labels: iterable of hashable category names (list, tuple, set,
        pandas Series, ...). If a label repeats, its last occurrence wins.
    :param saturation: HLS saturation passed to ``colorsys.hls_to_rgb``.
    :param lightness: HLS lightness passed to ``colorsys.hls_to_rgb``.
    :return: dict mapping each label to its ``"rgb(r, g, b)"`` string.
    """
    # Materialize once so any iterable works and its length is known.
    labels = list(labels)
    n = len(labels)
    color_map = {}
    for k, label in enumerate(labels):
        rgb = colorsys.hls_to_rgb(k / n, lightness, saturation)
        # str() of a tuple yields "(r, g, b)", giving "rgb(r, g, b)".
        color_map[label] = f"rgb{tuple(int(255 * x) for x in rgb)}"
    return color_map

# Load each model's scored test set, keep only its misclassified rows, and
# collect the ground-truth labels seen across ALL models' full test sets.
# The label set must not come from a single model's errors: a label that
# model never misclassifies would otherwise vanish from the dropdown and
# from the color map.
ground_truth = set()
for model in models:
    scored = dataiku.Dataset(f"test_scored_{model}").get_dataframe()
    # dropna() keeps sorted() below from comparing str with float NaN.
    ground_truth.update(scored["label_text"].dropna())
    df = scored[scored["prediction"] != scored["label_text"]]
    dfs[model] = df
labels = sorted(ground_truth)
label2color = get_color_map(labels)

# Layout

# Top margin shared by the form and the examples container.
STYLE_CONTAINER = {"margin-top": "20px"}

# Browser tab title and external stylesheets (Zephyr theme + Bootstrap icons).
app.title = "Error analysis"
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]

# Dropdown options: every entry displays and submits the same string.
options_models = [dict(label=name, value=name) for name in models]
options_labels = [dict(label=name, value=name) for name in labels]

def _labeled_select(text, select_id, options, value):
    """Return the [Label, Col(Select)] pair making up one inline form field."""
    return [
        dbc.Label(text, width="auto"),
        dbc.Col(
            dbc.Select(id=select_id, options=options, value=value),
            className="me-3",
        ),
    ]


# Page: centered title, a one-row form with the model and ground-truth
# selectors, and an empty container the callback fills with examples.
app.layout = html.Div(
    [
        html.H4("Error analysis", style={"margin-top": "20px", "text-align": "center"}),
        dbc.Form(
            dbc.Row(
                _labeled_select("Model", 'model_selected', options_models, models[0])
                + _labeled_select("Ground truth label", 'label_selected', options_labels, labels[0]),
                className="g-2",
            ),
            style=STYLE_CONTAINER,
        ),
        html.Div(id='examples', style=STYLE_CONTAINER),
    ],
    style={"margin": "auto", "text-align": "left", "max-width": "800px"},
)

# Callback

@app.callback(
    Output("examples", "children"),
    Input("model_selected", "value"),
    Input("label_selected", "value"),
)
def display_errors(model, label):
    """
    Display the errors for a given model and a given ground truth label.

    :param model: key into ``dfs`` (value of the "model_selected" dropdown).
    :param label: ground-truth label (value of the "label_selected" dropdown).
    :return: list of Dash components. The first is a heading with the number
        of matching misclassified rows. Then comes one paragraph per row: its
        rank, the incorrect prediction as a colored badge, and the example text.
    """
    errors = dfs[model]
    errors = errors[errors["label_text"] == label]
    output = [html.H6(f"{len(errors)} misclassified test examples (with the incorrect predictions)")]
    rows = zip(errors["prediction"], errors["text"])
    for rank, (prediction, text) in enumerate(rows, start=1):
        output.append(
            html.P(
                [
                    html.Span(f"{rank}. "),
                    dbc.Badge(
                        prediction,
                        # Predictions outside the known label set have no
                        # assigned color; use the default badge color rather
                        # than raising KeyError.
                        color=label2color.get(prediction, "secondary"),
                        # Bootstrap 5 margin-start utility; the Bootstrap 4
                        # "ml-1" has no effect with the BS5 classes used here.
                        className="ms-1",
                        style={'font-size': '16px'},
                    ),
                    html.Span(f" {text}"),
                ]
            )
        )
    return output