import json
from collections import Counter

import dataiku

import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModel

import elasticsearch

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State, ALL
import dash_bootstrap_components as dbc
import plotly.graph_objs as go

from project_utils import normalize, compute_embeddings

# Maximum number of documents requested from Elasticsearch for each search
N_RESULTS = 20
# Maximum number of bars displayed in each bar chart (see create_fig)
MAX_BAR_CHART_ROWS = 5
# When True, the filters and the keyword ("term") clause are used as the inner
# query of the script_score query built in query_search_engine
EXACT_SEARCH = True

# Name of the document field holding the text, read from the Dataiku custom variables
text_label = dataiku.get_custom_variables()["text_label"]

# Load model and tokenizer
# The model name is read from the Dataiku custom variables; both objects are
# passed to compute_embeddings to embed the text typed by the user
model_name = dataiku.get_custom_variables()["model_name"]
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Create Elasticsearch client
# The credentials are read from the secrets returned by
# get_auth_info(with_secrets=True), under the keys "elastic_user",
# "elastic_password" and "cloud_id".
client = dataiku.api_client()
auth_info = client.get_auth_info(with_secrets=True)
# If a key appears several times, the last occurrence wins
user_secrets = {secret["key"]: secret["value"] for secret in auth_info["secrets"]}

elastic_user = user_secrets.get("elastic_user")
elastic_password = user_secrets.get("elastic_password")
cloud_id = user_secrets.get("cloud_id")

# Explicit check rather than an assert: asserts are stripped when Python runs
# with -O, and the message tells which secret has to be defined
missing_secrets = [
    key
    for key in ("elastic_user", "elastic_password", "cloud_id")
    if user_secrets.get(key) is None
]
if missing_secrets:
    raise RuntimeError(
        "Missing secret(s) required to connect to Elasticsearch: "
        + ", ".join(missing_secrets)
    )

es = elasticsearch.Elasticsearch(
    cloud_id=cloud_id,
    basic_auth=(elastic_user, elastic_password),
    request_timeout=60,
    max_retries=5,
    retry_on_timeout=True,
)
# Name of the Elasticsearch index queried by the webapp
index = 'semantic-search'

# CSS styles

# Style of the top-level container of the page (see app.layout)
PAGE_STYLE = {
    'max-width': '1200px',
    'margin': 'auto',
    'text-align': 'center'
}

# Style of the sidebar holding the filters, the download button and the charts:
# fixed on the left-hand side of the window, over its full height
SIDEBAR_STYLE = {
    "position": "fixed",
    "top": 0,
    "left": 0,
    "bottom": 0,
    "width": "24rem",
    "padding": "2rem 1rem",
    "background-color": "#f8f9fa",
}

# Style of the html.Div wrapping each filter control
FILTER_COMPONENT_STYLE = {
    'margin-bottom': '10px',
}

# Style of the main section (search bar and results); the left margin is
# larger than the sidebar width (24rem) so that both do not overlap
CONTENT_STYLE = {
    "margin-left": "26rem",
    "padding": "2rem 1rem"
}

# Style of the block containing the title and the search input
SEARCH_BAR_STYLE = {
    'margin': 'auto',
    'text-align': 'center'
}

# Style of the block containing the list of results
RESULTS_STYLE = {
    'text-align': 'justify',
    'margin-top': '10px'
}

#### Edit below to adjust filters
# Each entry is an html.Div wrapping a single filter control. The order of the
# entries is the order in which the filter values are received by
# build_filters and no_results_displayed (year range, categories, organizations).
filter_components = []

# Year filter: the bounds of the slider are derived from the smallest and
# largest values of the "date" field in the index
aggs = {
    "min_date": {"min": {"field": "date", "format": "yyyy-MM-dd"}},
    "max_date": {"max": {"field": "date", "format": "yyyy-MM-dd"}}
}
# size=0: only the aggregations are requested, no document is returned
result = es.search(index=index, aggs=aggs, size=0)
# With the yyyy-MM-dd format, the first 4 characters are the year
min_year = int(result["aggregations"]["min_date"]["value_as_string"][:4])
max_year = int(result["aggregations"]["max_date"]["value_as_string"][:4])

filter_components.append(
    html.Div(dcc.RangeSlider(
        id='year_range',
        min=min_year,
        max=max_year,
        step=1,
        marks={i:str(i) for i in range(min_year, max_year + 1)},
        value=[min_year, max_year],
    ), style=FILTER_COMPONENT_STYLE)
)

# Category filter: list the values of the "category" field through a terms
# aggregation (at most 100000 buckets are requested)
aggs = {
    "list" : {
        "terms" : {
            "field" : "category", 
            "size": 100000
        }
    }
}

result = es.search(index=index, aggs=aggs, size=0)
categories = sorted([x["key"] for x in result["aggregations"]["list"]["buckets"]])
# Mapping from the "code" column to the "category" column of the "categories"
# Dataiku dataset; used to display a label instead of a category code
categories_df = dataiku.Dataset("categories").get_dataframe()
labels = {
    categories_df['code'].iloc[i]: categories_df['category'].iloc[i]
    for i in range(len(categories_df))
}

def shorten(string):
    """
    Return the label stripped of its surrounding whitespace and, when it is longer
    than 36 characters, cut to its first 33 characters followed by '...'
    (used for dropdown options, chart labels and result captions)
    """
    stripped = string.strip()
    if len(stripped) > 36:
        return stripped[:33] + '...'
    return stripped

# Category dropdown: the option values are the category codes, the displayed
# labels are the shortened category names
filter_components.append(
    html.Div(dcc.Dropdown(
        id='categories_selected',
        options=[{'label': shorten(labels[x]), 'value': x} for x in categories],
        value=[],
        placeholder="Select one or several categories",
        multi=True,
    ), style=FILTER_COMPONENT_STYLE)
)

# Organization filter: list the values of the "organization" field through a
# terms aggregation (at most 100000 buckets are requested)
aggs = {
    "list" : {
        "terms" : {
            "field" : "organization", 
            "size": 100000
        }
    }
}

# size=0: only the aggregation is requested, no document is returned
result = es.search(index=index, aggs=aggs, size=0)
organizations = sorted([x["key"] for x in result["aggregations"]["list"]["buckets"]])

filter_components.append(
    html.Div(dcc.Dropdown(
        id='orgs_selected',
        options=[{'label': shorten(x), 'value': x} for x in organizations],
        value=[],
        placeholder="Select one or several organizations",
        multi=True,
    ), style=FILTER_COMPONENT_STYLE)
)

# One callback input per filter: each filter component is an html.Div whose
# only child is the control, so the id of the control is x.children.id
filter_inputs = [Input(x.children.id,'value') for x in filter_components]

def build_filters(*args):
    """
    Turn the values of the filter components into the list of filter clauses
    of the Elasticsearch query

    args holds, in this order, the selected year range, the selected categories
    and the selected organizations. A single match_all clause is returned when
    no filter is active.
    """
    year_range, categories_selected, orgs_selected = args
    clauses = []

    # One "terms" clause per dropdown with at least one selected value
    for field, selected in (
        ("organization", orgs_selected),
        ("category", categories_selected),
    ):
        if len(selected) > 0:
            clauses.append({"terms": {field: selected}})

    # The date clause is only needed when the slider does not cover the whole range
    covers_all_years = year_range[0] <= min_year and year_range[1] >= max_year
    if not covers_all_years:
        clauses.append({
            "range": {
                "date": {
                    "gte": f"{year_range[0]}-01-01",
                    "lte": f"{year_range[1]}-12-31",
                }
            }
        })

    if clauses:
        return clauses
    return [{"match_all": {}}]

def no_results_displayed(*args):
    """
    Tell whether an empty query should display no results at all

    args holds the year range, the selected categories and the selected
    organizations. Results are hidden when neither a category nor an
    organization is selected; the year range is not taken into account.
    """
    _, categories_selected, orgs_selected = args
    has_selection = len(categories_selected) != 0 or len(orgs_selected) != 0
    return not has_selection

#### Edit below to adjust the display of charts
def create_fig(l):
    """
    Build a horizontal bar chart with the number of occurrences of the most
    frequent elements of the input list (at most MAX_BAR_CHART_ROWS bars)
    """
    top_items = Counter(l).most_common()[:MAX_BAR_CHART_ROWS]
    # Horizontal bars are listed in reverse order of frequency
    top_items.reverse()
    bar_labels = [shorten(str(item)) + ' ' for item, _ in top_items]
    bar_values = [occurrences for _, occurrences in top_items]

    bars = go.Bar(
        y=bar_labels,
        x=bar_values,
        text=bar_values,
        orientation='h',
        marker_color='#3459e6',
    )
    fig = go.Figure([bars])
    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        margin={'t': 60, 'b': 0, 'r': 10, 'l': 10},
        # 30 pixels per bar, plus 60 pixels matching the top margin
        height=30*len(bar_values) + 60,
        xaxis={
            'visible': False,
            'tickmode': 'linear',
            'tick0': 0,
            'dtick': 1,
        },
    )
    return fig

# Number of bar charts: used to create the dcc.Graph components and the
# corresponding callback outputs
num_figs = 2 # Should be the same as the length of the output of create_figs below

def create_figs(results):
    """
    Build the bar charts from the results dataframe

    Returns a list of two figures: the most frequent categories (shortened
    labels) and the most frequent years (first 4 characters of "date").
    """
    # A row may hold several category codes separated by ", "
    category_labels = [
        shorten(labels[code.strip()])
        for _, row in results.iterrows()
        for code in row["category"].split(", ")
    ]
    years = [row["date"][:4] for _, row in results.iterrows()]
    return [create_fig(category_labels), create_fig(years)]

# One empty graph per figure, with ids "bar-0", "bar-1", ...; the figures are
# filled in by the search callback
fig_components = [
    dcc.Graph(
        id=f"bar-{i}",
        figure={},
        config={'displayModeBar': False})
    for i in range(num_figs)
]

# Callback outputs targeting the "figure" property of the graphs above
fig_outputs = [Output(f"bar-{i}", "figure") for i in range(num_figs)]

#### Edit below to adjust the display of results
def format_results(results):
    """
    Build one paragraph per document of the results dataframe

    Each paragraph shows the rank and title, the organization badge (omitted
    when the organization is the string 'nan'), the year and text, the category
    labels and a "Similar results" link whose id carries the index label of the row.
    """
    paragraphs = []
    for rank, (doc_id, row) in enumerate(results.iterrows(), start=1):
        organization = row['organization']
        if organization != 'nan':
            badge = dbc.Badge(organization, color="primary", className="ml-1", style={'font-size': '16px'})
        else:
            badge = ''
        category_labels = [shorten(labels[code.strip()]) for code in row['category'].split(',')]
        paragraphs.append(html.P([
            html.Span(html.B(f"{rank}. {row['title']} ")),
            badge,
            html.Span(f" ({row['date'][:4]}): {row[text_label]}"),
            html.Br(),
            html.Span(html.I(' / '.join(category_labels) + ' ')),
            # Pattern-matching id read back by the search callback
            html.A('Similar results', id={'type': 'similar', 'index': str(doc_id)}, href='#')
        ]))
    return paragraphs

#### No need to edit the code below

# Layout

# Download button (cloud icon) and the dcc.Download component that receives
# the CSV produced by the download_results callback
download_icon = html.Span(html.I(className="bi bi-cloud-arrow-down-fill"))
download_button = [html.Span([
    dbc.Button(download_icon, id="download-button"),
    dcc.Download(id="download-csv")
])]

# Sidebar: the filters, followed by a container (hidden at start) holding the
# download button and the bar charts; its style is updated by the search callback
sidebar = html.Div(
    filter_components + [
        html.Div(download_button + fig_components, id='graph-container', style={"display": "none"})
    ],
    style=SIDEBAR_STYLE,
)

# Search input and its submit button (magnifier icon)
search_icon = html.Span(html.I(className="bi bi-search"))
input_group = dbc.InputGroup(
    [
        dbc.Input(id='my-input', value='', type='text'),
        dbc.Button(search_icon, id='submit-val')
    ]
)

search_bar = html.Div(
    [
        html.H4('Semantic search'),
        input_group,
        html.Br()
    ],
    style=SEARCH_BAR_STYLE
)

# Container filled with the formatted results by the search callback
results_section = html.Div(id='my-output', style=RESULTS_STYLE)
# Store holding the JSON-serialized results, read by the download_results callback
memory = dcc.Store(id='memory', storage_type='memory')

main_section = html.Div(
    [
        search_bar,
        results_section,
        memory
    ],
    style=CONTENT_STYLE
)

# NOTE(review): `app` is not defined in this file; presumably it is provided by
# the environment hosting the webapp (Dataiku Dash webapp backend) - confirm
app.config.external_stylesheets =[dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]
app.title = "Semantic search"
app.layout = html.Div(
    [
        sidebar,
        main_section
    ],
    style=PAGE_STYLE
)

# Search functions

def convert_dict_to_row(result):
    """
    Convert an Elasticsearch hit into a row of the results dataframe

    The row holds the title, the text and the date as stored in the document,
    followed by the organization and category values joined with ', '.
    """
    source = result["_source"]
    plain_values = [source[field] for field in ("title", text_label, "date")]
    joined_values = [', '.join(source[field]) for field in ("organization", "category")]
    return plain_values + joined_values

def query_search_engine(text, query_embedding, *args):
    """
    Query the search engine (keyword search and vector similarity search)

    Args:
        text: text used in the keyword ("term") clause on the text_label field;
            empty when searching for documents similar to a given one
        query_embedding: 2-dimensional array whose first row is the query vector
        *args: values of the filter components, forwarded to build_filters

    Returns:
        A dataframe indexed by document id, with the columns title, text_label,
        date, organization and category, holding at most N_RESULTS rows.
    """
    # The filters always apply. The keyword clause is only added when
    # EXACT_SEARCH is enabled: it does not restrict the results but feeds _score.
    bool_query = {"filter": build_filters(*args)}
    if EXACT_SEARCH:
        bool_query["should"] = {"term": {text_label: text}}

    knn_query = {
        "script_score": {
            # Always a valid query object, so that the filters are kept even
            # when the keyword clause is disabled
            "query": {"bool": bool_query},
            "script": {
                # Keyword score squashed into [0, 1) plus cosine similarity
                # rescaled from [-1, 1] to [0, 1]
                "source": "_score/(1.0 + _score) + (1.0 + cosineSimilarity(params.vector, 'embedding'))/2",
                "params": {
                    # Plain Python floats so that the request body can be
                    # serialized whatever the type of the array elements
                    "vector": [float(x) for x in query_embedding[0, :]]
                }
            }
        }
    }
    response = es.search(
        index=index,
        query=knn_query,
        # The stored vectors are not needed to display the results
        source_excludes=["embedding"],
        size=N_RESULTS,
        from_=0
    )
    hits = response["hits"]["hits"]
    return pd.DataFrame.from_dict(
        {hit["_id"]: convert_dict_to_row(hit) for hit in hits},
        orient="index",
        columns=["title", text_label, "date", "organization", "category"]
    )

def get_vector(idx):
    """
    Fetch the stored embedding of the document whose _id is idx

    Returns an array with a single row holding the embedding. An IndexError is
    raised when the search returns no hit.
    """
    response = es.search(
        index=index,
        query={"term": {"_id": idx}},
        size=1,
        # Only the vector is needed
        source_includes=["embedding"],
    )
    first_hit = response["hits"]["hits"][0]
    return np.array([first_hit["_source"]["embedding"]])

# Search callback
@app.callback(
    Output('my-output', 'children'),
    *fig_outputs,
    Output("graph-container", "style"),
    Output('my-input', 'value'),
    Output('memory', 'data'),
    Input('submit-val', 'n_clicks'),
    Input('my-input', 'n_submit'),
    *filter_inputs,
    Input({'type': 'similar', 'index': ALL}, 'n_clicks_timestamp'),
    State('my-input', 'value'),
)
def search(n_clicks, n_submit, *args):
    """
    Compute the search results

    Args:
        n_clicks: number of clicks on the search button (only used as a trigger)
        n_submit: number of submissions of the search input (only used as a trigger)
        *args: the values of the filter components, then the list of the
            n_clicks_timestamp values of the "Similar results" links, then the
            content of the search input

    Returns:
        A list with the content of the results section, one figure per bar
        chart, the style of the charts container, the new content of the search
        input and the results serialized as JSON for the download callback.
    """
    values = args[-2]
    filter_values = args[:-2]
    new_query, query_text = args[-1], args[-1]
    query_id = None

    for i, timestamp in enumerate(values):
        # Case of a click on a "similar results" link
        if timestamp is not None:
            # The id is kept as the string set by format_results: document ids
            # are not necessarily numeric and a cast to int would drop leading zeros
            query_id = dash.callback_context.inputs_list[-1][i]['id']['index']
            new_query = f'similar:{query_id}'
            query_text = ""
            break
    else:
        # Case of an empty query, without filters applied
        if len(query_text) == 0 and no_results_displayed(*filter_values):
            return [''] + [{}]*num_figs + [{'display':'none'}, '', "{}"]
        # Case of the search of results similar to a given line
        if 'similar:' in query_text:
            query_id = query_text.split('similar:')[1].strip()
            query_text = ""

    if query_id is not None:
        try:
            query_embeddings = get_vector(query_id)
        # Case of an unknown id: get_vector finds no hit (IndexError) or a
        # document without stored embedding (KeyError)
        except (IndexError, KeyError, ValueError):
            return [[html.P('No results')]] + [{}]*num_figs + [{'display':'none'}, new_query, "{}"]
    else:
        # General case of a non-empty query (or of an empty query with filters)
        query_embeddings = compute_embeddings(model, tokenizer, [str(query_text)])
        query_embeddings = normalize(query_embeddings)

    # Elasticsearch query
    results = query_search_engine(query_text, query_embeddings, *filter_values)

    # Display of results
    if len(results) > 0:
        output = format_results(results)
        visibility = {'display':'block'}
    else:
        output = [html.P('No results')]
        visibility = {'display': 'none'}
    return [output] + create_figs(results) + [visibility, new_query, results.to_json(orient='columns')]

@app.callback(
    Output("download-csv", "data"),
    Input("download-button", "n_clicks"),
    State("memory", "data"),
    prevent_initial_call=True,
)
def download_results(n_clicks, memory):
    """
    Build the CSV file sent to the user when the download button is clicked

    n_clicks is only used as a trigger; memory is the JSON string stored by the
    search callback. The file is named "data.csv".
    """
    stored_results = json.loads(memory)
    output_df = pd.DataFrame.from_dict(stored_results)
    csv_writer = output_df.to_csv
    return dcc.send_data_frame(csv_writer, "data.csv")