import dataiku
import time

import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc

NUM_SEARCH_RESULTS = 10
KB_ID = "iXhIIXV5"

# Check whether the embedding LLM used by knowledge bank KB_ID is among the
# text-embedding LLMs listed for this project. The search callback shows an
# error message instead of results when this flag is False.
project = dataiku.api_client().get_default_project()
kb_info = next(
    (item for item in project.list_knowledge_banks() if item["id"] == KB_ID),
    None,
)
CONNECTION_AVAILABLE = kb_info is not None and any(
    llm["id"] == kb_info["embeddingLLMId"]
    for llm in project.list_llms(purpose='TEXT_EMBEDDING_EXTRACTION')
)

ERROR_MESSAGE_MISSING_CONNECTION = dcc.Markdown("""
LLM connection missing. Cf. this project's [wiki](https://gallery.dataiku.com/projects/EX_LLM_STARTER_KIT/wiki/1/Project%20description).

**Please note that this web app is not live on Dataiku’s public project gallery but you can test it by downloading the project and providing an LLM Connection.**
""")


kb = project.get_knowledge_bank(KB_ID).as_core_knowledge_bank()
# Only the "filename" column is used (to build the source filter), so avoid
# loading the full extracted documents into memory.
df = dataiku.Dataset("documents_extracted_prepared").get_dataframe(columns=["filename"])
# Distinct filenames, sorted so the dropdown lists sources in a stable,
# alphabetical order (iterating a set gives an arbitrary order that can change
# between restarts). Missing filenames are dropped: they cannot be selected or
# filtered on meaningfully.
sources = sorted(df["filename"].dropna().unique())

# Layout

send_icon = html.Span(html.I(className="bi bi-send"))

# Free-text query field with the send button attached to it.
question_bar = dbc.InputGroup(
    children=[
        dbc.Input(id='query', type='text', value='', minLength=0),
        dbc.Button(send_icon, id='send-btn', title='Get search results'),
    ],
    style={"margin-top": "20px", "margin-bottom": "10px"},
)

# Multi-select restricting the search to chosen source files
# (the callback treats "none selected" the same as "all selected").
sources_filter = html.Div(
    children=dcc.Dropdown(
        id='sources_selected',
        options=[dict(label=name, value=name) for name in sources],
        value=[],
        multi=True,
        placeholder="Select one or several sources",
    )
)

# Results area: a spinner placeholder above the div that receives the results.
answer = html.Div(
    children=[
        dbc.Spinner(html.Div(id='spinner', style={"height": "40px"}), color="primary"),
        html.Div(id='answer'),
    ]
)

# NOTE(review): `app` is not defined in this file; presumably it is provided
# by the Dataiku webapp runtime.
app.title = "Semantic search"
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]
app.layout = html.Div(
    children=[
        html.H4("Semantic search", style={"margin-top": "40px", "text-align": "center"}),
        question_bar,
        sources_filter,
        answer,
    ],
    style={"margin": "auto", "text-align": "left", "max-width": "800px"},
)

# Callback


def retrieve(search_query, search_filter=None, k=NUM_SEARCH_RESULTS):
    """
    Retrieve chunks most semantically similar to the query.

    :param search_query: free-text query sent to the knowledge bank retriever.
    :param search_filter: optional metadata filter passed to the retriever as
        ``filter`` (e.g. ``{"filename": [...]}``). ``None`` or an empty dict
        means no filtering.
    :param k: number of chunks requested from the retriever.
    :return: the documents returned by the retriever's ``get_relevant_documents``.
    """
    # A None default avoids sharing a single mutable dict across calls.
    if search_filter is None:
        search_filter = {}
    # With a filter, ask for 10x more candidates. The assumption is that the
    # vector store filters after fetching fetch_k candidates, so enough must
    # survive to fill k results (TODO confirm for this store).
    fetch_k = 10 * k if search_filter else k
    retriever = kb.as_langchain_retriever(
        search_kwargs={"k": k, "fetch_k": fetch_k, "filter": search_filter}
    )
    return retriever.get_relevant_documents(search_query)

def escape_markdown(text):
    """
    Escape markdown syntax.

    Only the characters * ` _ ~ > [ ] ( ) are handled. Any of them already
    preceded by a backslash is unescaped first, so that pre-escaped input is
    not escaped a second time. Then every occurrence of them gets a backslash
    prefix. Other characters (including backslashes themselves) are left as is.
    """
    special_chars = '*`_~>[]()'
    for char in special_chars:
        text = text.replace('\\' + char, char)
    escape_table = {ord(char): '\\' + char for char in special_chars}
    return text.translate(escape_table)

def format_docs(docs):
    """
    Format the search results snippets.

    Each document becomes one dcc.Markdown component. It starts with a bold,
    1-based numbered title built from the document's ``unique_id`` metadata,
    followed by a blank line and the chunk text. Both parts go through
    ``escape_markdown``.

    :param docs: documents returned by ``retrieve`` (objects exposing
        ``metadata`` and ``page_content``).
    :return: a new list of dcc.Markdown components, in the same order as docs.
        The caller appends to this list.
    """
    result = []
    for rank, doc in enumerate(docs, start=1):
        title = escape_markdown(f"{rank}. {doc.metadata['unique_id']}")
        content = "**" + title + "**\n\n" + escape_markdown(doc.page_content)
        result.append(dcc.Markdown(content))
    return result

@app.callback(
    Output('answer', 'children'),
    Output('spinner', 'children'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    Input('sources_selected', 'value'),
    State('query', 'value'),
)
def similarity_search(n_clicks, n_submit, sources_selected, query):
    """
    Display the semantic search results.

    Triggered by the send button, by submitting the query input, or by changing
    the source selection. The query text is read as State, so typing alone does
    not trigger a search.

    :param n_clicks: click count of the send button (trigger only, unused).
    :param n_submit: submit count of the query input (trigger only, unused).
    :param sources_selected: filenames chosen in the dropdown; tolerated as None.
    :param query: current text of the query input; tolerated as None.
    :return: tuple of (children for the 'answer' div, children for the 'spinner' div).
    """
    # A missing, empty or whitespace-only query means there is nothing to search.
    if not query or not query.strip():
        return "", ""
    if not CONNECTION_AVAILABLE:
        return ERROR_MESSAGE_MISSING_CONNECTION, ""
    selected = sources_selected or []
    start = time.perf_counter()
    # Selecting no source or every source both mean "search everything",
    # so no metadata filter is applied in either case.
    if len(selected) in (0, len(sources)):
        results = format_docs(retrieve(query))
    else:
        results = format_docs(retrieve(query, search_filter={"filename": selected}))
    results.append(html.P(f"Query delay: {(time.perf_counter() - start):.1f} seconds"))
    # 'spinner' is an output so the wrapping dbc.Spinner (presumably) shows
    # while this callback runs; it is always reset to an empty value.
    return results, ""