import io
import re
import colorsys
from collections import Counter

import dataiku

import numpy as np
from sentence_transformers import SentenceTransformer

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc
import plotly.graph_objs as go

# Maximum number of label categories drawn in the sidebar bar chart
MAX_BAR_CHART_ROWS = 5
# Multiplier applied to the exact-match score before it is combined with
# the semantic score
WEIGHT_EXACT_SCORE = 10

# Corpus to search, the managed folder holding the stored embeddings and the
# sentence encoder used to embed queries
df = dataiku.Dataset("train").get_dataframe()
folder = dataiku.Folder("GuB6FLVF")
model = SentenceTransformer('sentence-transformers/paraphrase-mpnet-base-v2')

# The download stream is read fully into memory and wrapped in a BytesIO
# before being handed to np.load.
# NOTE(review): presumably one embedding row per row of df - confirm.
with folder.get_download_stream("embeddings.npy") as stream:
    embeddings_payload = stream.read()
corpus_embeddings = np.load(io.BytesIO(embeddings_payload))

# Inverted index: lower-cased word -> list of positions (0-based row numbers
# in df) of the documents whose text contains that word. The words of a
# document are de-duplicated first, so a position appears at most once in the
# list of any given word.
_WORD_PATTERN = re.compile(r"\b\w+\b")
index_words = {}
# Iterating over the text column directly avoids materialising a row Series
# per document, which is what positional row access through iloc does.
for i, text in enumerate(df["text"]):
    for word in {token.lower() for token in _WORD_PATTERN.findall(text)}:
        index_words.setdefault(word, []).append(i)
    
def get_color_map(labels, saturation=0.8, lightness=0.4):
    """
    Assign one CSS color to each label by spreading hues evenly.

    The label at position k out of n gets the hue k/n; all labels share the
    same saturation and lightness. Colors are returned as "rgb(r, g, b)"
    strings with integer channels obtained by truncating 255 * channel.

    Args:
        labels: sequence of labels (used as dictionary keys).
        saturation: HLS saturation shared by all colors.
        lightness: HLS lightness shared by all colors.

    Returns:
        A dict mapping each label to its color string; empty when there are
        no labels.
    """
    total = len(labels)
    palette = {}
    for position, label in enumerate(labels):
        red, green, blue = colorsys.hls_to_rgb(position/total, lightness, saturation)
        channels = (int(255*red), int(255*green), int(255*blue))
        # Formatting the tuple directly yields the "(r, g, b)" part
        palette[label] = f"rgb{channels}"
    return palette

# CSS styles (inline style dictionaries passed to the layout components)

# Outer page wrapper
PAGE_STYLE = {
    "max-width": "1200px",
    "margin": "auto",
    "text-align": "center",
}

# Sidebar pinned to the left edge of the window
SIDEBAR_STYLE = {
    "position": "fixed",
    "top": 0,
    "left": 0,
    "bottom": 0,
    "width": "24rem",
    "padding": "2rem 1rem",
    "background-color": "#f8f9fa",
}

# Spacing below each filter control in the sidebar
FILTER_COMPONENT_STYLE = {
    "margin-bottom": "10px",
}

# Main content area; its left margin is wider than the sidebar width
CONTENT_STYLE = {
    "margin-left": "26rem",
    "padding": "2rem 1rem",
}

# Search input group
SEARCH_BAR_STYLE = {
    "margin": "auto",
    "margin-top": "20px",
    "text-align": "center",
}

# Container of the search results
RESULTS_STYLE = {
    "text-align": "justify",
    "margin-top": "20px",
}

def shorten(string):
    """
    Return a label trimmed of surrounding whitespace and capped at 40 characters.

    Labels longer than 40 characters after trimming are cut to their first
    37 characters followed by '...', so the result is never longer than 40.
    """
    trimmed = string.strip()
    if len(trimmed) > 40:
        return trimmed[:37] + '...'
    return trimmed

# Labels are shortened once in the dataframe itself, so the badges, the chart
# and the dropdown filter all use the same shortened text.
df["label_text"] = df["label_text"].apply(shorten)

# Distinct labels in sorted order; the order fixes which hue each label gets
_distinct_labels = sorted({value.strip() for value in df["label_text"]})
label2color = get_color_map(_distinct_labels)
# Mapping of each label to itself, in sorted order
labels = dict(zip(_distinct_labels, _distinct_labels))

def create_fig(results):
    """
    Build the horizontal bar chart of label frequencies for a result set.

    Only the MAX_BAR_CHART_ROWS most frequent labels are drawn, each bar
    colored with the label's entry in label2color and annotated with its
    count.

    Args:
        results: dataframe of search results with a "label_text" column.

    Returns:
        A plotly Figure; it contains an empty bar trace when results is empty.
    """
    label_counts = Counter(results["label_text"])
    top_labels = label_counts.most_common()[:MAX_BAR_CHART_ROWS]
    # Least frequent first, most frequent last - presumably so the most
    # frequent label ends up at the top of the horizontal chart.
    top_labels.reverse()

    bar_colors = [label2color[label] for label, _ in top_labels]
    names = [str(label) + ' ' for label, _ in top_labels]
    counts = [count for _, count in top_labels]

    bars = go.Bar(
        y=names,
        x=counts,
        text=counts,
        orientation='h',
        marker=dict(color=bar_colors),
        textfont=dict(color="white"),
    )
    fig = go.Figure([bars])
    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        margin=dict(t=60, b=0, r=10, l=10),
        # 30 px per bar plus the 60 px top margin
        height=30*len(counts) + 60,
        xaxis=dict(
            visible=False,
            tickmode='linear',
            tick0=0,
            dtick=1,
        ),
    )
    return fig

def format_results(results):
    """
    Render each matching document as a numbered paragraph.

    Every paragraph holds the 1-based rank, a badge with the document label
    (colored through label2color) and the document text.

    Args:
        results: dataframe of search results with "label_text" and "text"
            columns, already in display order.

    Returns:
        A list of html.P components, one per row of results.
    """
    paragraphs = []
    rows = zip(results['label_text'], results['text'])
    for rank, (label, text) in enumerate(rows, start=1):
        badge = dbc.Badge(
            label,
            color=label2color[label],
            className="ml-1",
            style={'font-size': '16px'},
        )
        paragraphs.append(
            html.P([html.Span(f"{rank}. "), badge, html.Span(f" {text}")])
        )
    return paragraphs

# Layout

# Sidebar controls. The component ids below are the ones the search callback
# reads from and writes to.
_num_results_slider = html.Div(
    dcc.Slider(5, 50, 5, value=20, id='num-results'),
    style=FILTER_COMPONENT_STYLE,
)

_search_type_dropdown = dcc.Dropdown(
    id='search-type',
    options=[
        {'label': "Exact and semantic search", 'value': 1},
        {'label': "Exact search only", 'value': 0},
    ],
    value=1,
    clearable=False,
    style=FILTER_COMPONENT_STYLE,
)

_label_options = [{'label': shorten(labels[key]), 'value': key} for key in labels]
_labels_dropdown = dcc.Dropdown(
    id='labels-selected',
    options=_label_options,
    value=[],
    placeholder="Select one or several labels",
    multi=True,
    style=FILTER_COMPONENT_STYLE,
)

# The chart container starts hidden; its style is one of the callback outputs
_bar_chart_container = html.Div(
    dcc.Graph(id="bar-chart", figure={}, config={'displayModeBar': False}),
    id='graph-container',
    style={"display": "none"},
)

sidebar = html.Div(
    [
        html.Label("Maximum number of results"),
        _num_results_slider,
        _search_type_dropdown,
        _labels_dropdown,
        _bar_chart_container,
    ],
    style=SIDEBAR_STYLE,
)

# Search input with its submit button
search_icon = html.Span(html.I(className="bi bi-search"))
search_bar = dbc.InputGroup(
    [
        dbc.Input(id='my-input', value='', type='text', placeholder="Enter search terms"),
        dbc.Button(search_icon, id='submit-val'),
    ],
    style=SEARCH_BAR_STYLE,
)

# Main area: title, search bar and the container receiving the results
_page_title = html.H4(
    "Interactive search",
    style={"margin-top": "20px", "text-align": "center"},
)
main_section = html.Div(
    [
        _page_title,
        search_bar,
        html.Div(id='my-output', style=RESULTS_STYLE),
    ],
    style=CONTENT_STYLE,
)

# NOTE(review): `app` is not defined in this file - presumably provided by
# the hosting webapp runtime; confirm.
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]
app.title = "Semantic search"
app.layout = html.Div([sidebar, main_section], style=PAGE_STYLE)

# Search functions

def compute_exact_search_score(query):
    """
    Compute the exact search score of every document for a query.

    The score of a document is the fraction of the distinct query words
    (compared case-insensitively) that appear in it according to the
    index_words inverted index. A document containing every query word
    scores 1, a document containing none scores 0.

    Args:
        query: raw query string.

    Returns:
        A float numpy array with one entry per row of df, each in [0, 1].
        All entries are 0 when the query contains no word.
    """
    scores = np.zeros(len(df))
    terms = {word.lower() for word in re.findall(r"\b\w+\b", query)}
    if not terms:
        return scores
    for term in terms:
        positions = index_words.get(term)
        if positions:
            # Accumulate one point per matched term (instead of overwriting
            # with 1) so documents matching more query words score higher.
            # A term's position list holds each document at most once, so
            # the indexed in-place addition counts every match exactly once.
            scores[positions] += 1
    # Normalise by the number of distinct terms, i.e. the maximum attainable
    # count, so repeated query words do not lower the score.
    return scores / len(terms)

def compute_similarity_search_score(query_embeddings):
    """
    Compute the semantic search score of every document for a query.

    The score is the dot product between the first row of query_embeddings
    and each row of corpus_embeddings.

    Args:
        query_embeddings: 2-D array of query embeddings; only its first row
            is used (presumably a single encoded query - confirm).

    Returns:
        A 1-D array with one score per row of corpus_embeddings.
    """
    similarity_matrix = query_embeddings @ corpus_embeddings.T
    return similarity_matrix[0, :]

def compute_rank(query_embeddings, query, search_type):
    """
    Order document positions from the most to the least relevant.

    The relevance is the exact search score weighted by WEIGHT_EXACT_SCORE,
    plus the semantic search score when search_type is 1.

    Args:
        query_embeddings: embeddings of the query; only read when
            search_type is 1.
        query: raw query string.
        search_type: 1 for exact and semantic search; any other value means
            exact search only.

    Returns:
        An array of document positions sorted by decreasing score. With
        search_type 1 every document is included; otherwise only documents
        with a strictly positive exact score are kept.
    """
    scores = WEIGHT_EXACT_SCORE*compute_exact_search_score(query)
    if search_type != 1:
        # Exact search only: drop the documents that match no query word
        ascending = np.argsort(scores)
        matching = ascending[scores[ascending] > 0]
        return matching[::-1]
    scores += compute_similarity_search_score(query_embeddings)
    return np.argsort(scores)[::-1]

# Search callback
@app.callback(
    Output('my-output', 'children'),
    Output('bar-chart', 'figure'),
    Output('graph-container', 'style'),
    Input('submit-val', 'n_clicks'),
    Input('my-input', 'n_submit'),
    Input('labels-selected','value'),
    Input('num-results','value'),
    Input('search-type','value'),
    State('my-input', 'value'),
)
def search(n_clicks, n_submit, labels_selected, num_results, search_type, query_text):
    """
    Compute the search results and the associated label chart.

    Args:
        n_clicks: click count of the submit button (only used as a trigger).
        n_submit: submit count of the text input (only used as a trigger).
        labels_selected: labels the results are restricted to; empty or None
            means no restriction.
        num_results: maximum number of results to display.
        search_type: 1 for exact and semantic search, otherwise exact only.
        query_text: current content of the search input.

    Returns:
        A tuple (children of the results container, bar chart figure, style
        of the chart container). The chart container is hidden when there is
        nothing to show.
    """
    # A cleared input or dropdown may be delivered as None rather than as an
    # empty value; normalise so the checks below never call len() on None.
    query_text = query_text or ''
    labels_selected = labels_selected or []
    if not query_text and not labels_selected:
        return '', {}, {'display':'none'}

    # The query embedding is only consumed by the combined exact + semantic
    # ranking, so the encoding is skipped in exact-only mode.
    if search_type == 1:
        query_embeddings = model.encode([str(query_text)])
    else:
        query_embeddings = None

    # Computation of the relevance score and reordering of the dataset.
    # The ranks are row positions, so the selection is positional (iloc);
    # a label-based reindex would only be right for a default range index.
    ranks = compute_rank(query_embeddings, query_text, search_type)
    results = df.iloc[ranks]
    if labels_selected:
        results = results[results["label_text"].isin(labels_selected)]
    results = results.iloc[:num_results]

    # Display of results
    if len(results) > 0:
        output = format_results(results)
        visibility = {'display':'block'}
    else:
        output = [html.P('No results')]
        visibility = {'display': 'none'}
    return output, create_fig(results), visibility