from datetime import datetime
import json
import dataiku

from langchain.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory

import dash
from dash import dcc, html, ctx
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc

WEBAPP_NAME = "chatbot_webapp"  # Name of the app (logged when a conversation is flagged)
VERSION = "1.0"  # Version of the app (logged when a conversation is flagged)

# ID of the LLM connection, read from the "LLM_id" project variable.
# A missing variable is treated like an empty one, so the web app still starts
# and can display ERROR_MESSAGE_MISSING_KEY instead of failing with a KeyError
# while the module is being loaded.
LLM_ID = dataiku.get_custom_variables().get("LLM_id") or ""

# `llm` is always defined; it stays None when no LLM connection is configured.
# Callbacks check `llm_provided` before calling the model.
llm = None
llm_provided = False
if LLM_ID:
    # Imported lazily: only needed when an LLM connection is configured
    from dataiku.langchain.dku_llm import DKUChatLLM
    llm = DKUChatLLM(llm_id=LLM_ID, temperature=0)
    llm_provided = True

# Folder to log conversations flagged by users
folder = dataiku.Folder("zhYE6r2h")

# Message shown in the conversation when no LLM connection is configured
ERROR_MESSAGE_MISSING_KEY = """
LLM connection missing. You need to add it as a project variable. Cf. this project's wiki.

Please note that this web app is not live on Dataiku’s public project gallery but you can test it by downloading the project and providing an LLM connection.
"""

def answer_question(history_questions, history_answers):
    """
    Answer the last question given all previous answers and questions

    Args:
        history_questions: list of the user's questions, oldest first. The last
            item is the question to answer.
        history_answers: list of the previous answers, aligned by index with
            `history_questions`.

    Returns:
        The "text" entry of the chain's result for the last question.

    Raises:
        IndexError: if `history_questions` is empty.
    """
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

    # Replay every earlier turn into the memory. A question without a matching
    # answer is still recorded, it just has no AI message after it.
    answer_count = len(history_answers)
    for position, past_question in enumerate(history_questions[:-1]):
        memory.chat_memory.add_user_message(past_question)
        if position < answer_count:
            memory.chat_memory.add_ai_message(history_answers[position])

    # System prompt, then the replayed history, then the new question
    template_messages = [
        SystemMessagePromptTemplate.from_template("You are a helpful assistant."),
        MessagesPlaceholder(variable_name="chat_history"),
        HumanMessagePromptTemplate.from_template("{question}"),
    ]
    conversation_chain = LLMChain(
        llm=llm,
        prompt=ChatPromptTemplate(messages=template_messages),
        verbose=True,
        memory=memory,
    )

    latest_question = history_questions[-1]
    return conversation_chain({"question": latest_question})["text"]

# Layout

# Scrollable area holding the messages; "column-reverse" keeps the view
# anchored at the bottom of the conversation.
STYLE_CONVERSATION = {
    "overflow-y": "auto",
    "display": "flex",
    "height": "calc(90vh - 50px)",
    "flex-direction": "column-reverse",
    "width": "100%",
}

# Base style of a single message card (the margins are set per sender in textbox)
STYLE_MESSAGE = {
    "max-width": "80%",
    "width": "max-content",
    "padding": "5px 10px",
    "border-radius": 10,
    "margin-bottom": 10,
}

# Row of centered buttons below the input field
STYLE_BUTTON_BAR = {
    "display": "flex",
    "justify-content": "center",
    "gap": "10px",
    "margin-top": "10px",
}

# Icons
reset_icon = html.Span(
    children=html.I(className="bi bi-trash3"),
    style={"paddingRight": "5px"},
)
flag_icon = html.Span(
    children=html.I(className="bi bi-flag"),
    style={"paddingRight": "5px"},
)
send_icon = html.Span(children=html.I(className="bi bi-send"))

# "Reset" and "Flag" buttons
_reset_button = dbc.Button(
    children=html.Span([reset_icon, 'Reset conversation']),
    id='reset-btn',
    title='Delete all previous messages',
)
_flag_button = dbc.Button(
    children=html.Span([flag_icon, 'Flag conversation']),
    id='flag-btn',
    title='Flag the conversation, e.g. in case of inappropriate or erroneous replies',
)
button_bar = html.Div(children=[_reset_button, _flag_button], style=STYLE_BUTTON_BAR)

# Text input with its "send" button, followed by the button bar
_query_input = dbc.Input(id='query', value='', type='text', minLength=0)
_send_button = dbc.Button(children=send_icon, id='send-btn', title='Get an answer')
question_bar = html.Div(
    children=[
        dbc.InputGroup(children=[_query_input, _send_button]),
        button_bar,
    ]
)

# Container filled with message cards by the update_display callback
conversation = html.Div(
    children=html.Div(id="conversation"),
    style=STYLE_CONVERSATION,
)

# Spinner wrapped around an (empty) Markdown component that is an output of
# the get_answer callback
spinning_wheel = dbc.Spinner(
    children=dcc.Markdown(
        id='spinning_wheel',
        style={"height": "20px", "margin-bottom": "20px"},
    ),
    color="primary",
)

app.title = "Chatbot"
app.config.external_stylesheets = [dbc.themes.ZEPHYR, dbc.icons.BOOTSTRAP]

# Stores holding the state of the conversation and the path of the last log file
_stores = [
    dcc.Store(id="logged"),
    dcc.Store(id='history_questions', storage_type='memory', data=[]),
    dcc.Store(id='history_answers', storage_type='memory', data=[]),
]

app.layout = html.Div(
    children=[conversation, spinning_wheel, question_bar, *_stores],
    style={
        "margin": "auto",
        "text-align": "left",
        "max-width": "800px",
    },
)

# Callbacks

def textbox(text, box="AI"):
    """
    Create the text box corresponding to a message

    Args:
        text: content of the message.
        box: "user" for a message written by the user; any other value is
            rendered as a message from the AI.

    Returns:
        A `dbc.Card` pushed to the right in the primary color for the user,
        or pushed to the left in the light color otherwise.
    """
    from_user = box == "user"
    # Copy the shared base style and add the side-specific margins
    card_style = {
        **STYLE_MESSAGE,
        "margin-left": "auto" if from_user else 0,
        "margin-right": 0 if from_user else "auto",
    }
    return dbc.Card(
        text,
        style=card_style,
        body=True,
        color="primary" if from_user else "light",
        inverse=from_user,
    )

@app.callback(
    Output("conversation", "children"),
    Input("history_questions", "data"),
    Input("history_answers", "data"),
    prevent_initial_call=True
)
def update_display(history_questions, history_answers):
    """
    Display the messages of the conversation

    Args:
        history_questions: list of the user's questions, oldest first.
        history_answers: list of the answers, aligned by index with the questions.

    Returns:
        The list of message cards: each question followed by its answer, if any.
    """
    cards = []
    answer_count = len(history_answers)
    for position, question in enumerate(history_questions):
        cards.append(textbox(question, box="user"))
        # The last question may not have been answered yet
        if position < answer_count:
            cards.append(textbox(history_answers[position], box="AI"))
    return cards

@app.callback(
    Output('history_questions', 'data'),
    Output('query', 'value'),
    Input('reset-btn', 'n_clicks'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    State('query', 'value'),
    State('history_questions', 'data'),
    State('history_answers', 'data'),
    prevent_initial_call=True
)
def receive_query(reset, n_clicks, n_submit, query, history_questions, history_answers):
    """
    Receive the new query from the user

    Args:
        reset, n_clicks, n_submit: trigger values of the reset button, the send
            button and the input field. Only used to fire the callback.
        query: current content of the input field.
        history_questions: questions asked so far.
        history_answers: answers received so far.

    Returns:
        A tuple (new list of questions, new content of the input field).
    """
    # Without an LLM, the error message is shown in place of a question
    if not llm_provided:
        return [ERROR_MESSAGE_MISSING_KEY], ""

    # Reset: drop all questions and clear the input field
    if ctx.triggered_id == "reset-btn":
        return [], ""

    # Nothing to send, or the previous question is still waiting for its
    # answer: leave both the history and the input field untouched
    if len(query) == 0 or len(history_questions) != len(history_answers):
        return history_questions, query

    # Record the new question and clear the input field
    history_questions.append(query)
    return history_questions, ""

@app.callback(
    Output('history_answers', 'data'),
    Output('spinning_wheel', 'children'),
    Input('history_questions', 'data'),
    State('history_answers', 'data'),
    prevent_initial_call=True
)
def get_answer(history_questions, history_answers):
    """
    Receive the answer from the LLM

    Args:
        history_questions: questions asked so far, the last one possibly unanswered.
        history_answers: answers received so far.

    Returns:
        A tuple (new list of answers, content of the spinner placeholder).
        The second item is always an empty string.
    """
    # No LLM or no question (e.g. after a reset): the answers are cleared
    if not llm_provided or len(history_questions) == 0:
        return [], ""

    answers = list(history_answers)
    # Query the LLM only if a question is still waiting for its answer
    if len(history_questions) > len(answers):
        answers.append(answer_question(history_questions, history_answers))
    return answers, ""

@app.callback(
    Output('logged', 'data'),
    Input('flag-btn', 'n_clicks'),
    Input('history_questions', 'data'),
    Input('history_answers', 'data'),
)
def log_conversation(n_clicks, history_questions, history_answers):
    """
    Log the current conversation

    The conversation is written as a JSON file in `folder` when the callback is
    triggered by the flag button and at least one answer exists. Any other
    trigger (a change of the questions or answers) writes nothing.

    Args:
        n_clicks: click count of the flag button. Only used to fire the callback.
        history_questions: questions asked so far.
        history_answers: answers received so far.

    Returns:
        The path of the file written in the folder, or an empty string when
        nothing was logged.
    """
    # Local import: only this callback needs it
    import hashlib

    if len(history_answers) > 0 and ctx.triggered_id == "flag-btn":
        # The file name is derived from the content of the conversation.
        # A SHA-256 digest is used rather than the built-in hash(), whose value
        # for strings changes from one Python process to another: flagging the
        # same conversation again therefore always targets the same file.
        content = json.dumps([history_questions, history_answers])
        digest = hashlib.sha256(content.encode("utf-8")).hexdigest()
        path = f"/{digest}.json"

        record = {
            "history_questions": history_questions,
            "history_answers": history_answers,
            "timestamp": str(datetime.now()),
            "webapp": WEBAPP_NAME,
            "version": VERSION
        }
        with folder.get_writer(path) as w:
            w.write(json.dumps(record).encode("utf-8"))
        return path
    return ""

@app.callback(
    Output('reset-btn', 'disabled'),
    Output('flag-btn', 'disabled'),
    Output('send-btn', 'disabled'),
    Input('history_questions', 'data'),
    Input('history_answers', 'data'),
    Input('logged', 'data'),
)
def disable_buttons(history_questions, history_answers, flagged):
    """
    Disable buttons when appropriate

    Args:
        history_questions: questions asked so far.
        history_answers: answers received so far.
        flagged: content of the "logged" store; anything other than an empty
            string counts as "already flagged".

    Returns:
        A tuple of booleans (disable reset, disable flag, disable send).
    """
    no_answer_yet = len(history_answers) == 0
    # Resetting only makes sense once there is something to delete
    disable_reset = no_answer_yet
    # Flagging needs at least one answer and is allowed once per state of the conversation
    disable_flag = no_answer_yet or flagged != ""
    # Sending is blocked while a question is waiting for its answer
    disable_send = len(history_questions) > len(history_answers)
    return disable_reset, disable_flag, disable_send