import os
from datetime import datetime
import json
import dataiku
import base64
from PIL import Image
from io import BytesIO
from dash.exceptions import PreventUpdate
from lens import Lens, LensProcessor
import torch

from openai import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory

import dash
from dash import dcc, html, ctx
from dash.dependencies import Input, Output, State
import dash_bootstrap_components as dbc

from chatbot.app_style import *
from chatbot.app_functions import *

# Name of the app and its version, read from the project's custom variables
# (both are written to the log file when a conversation is flagged)
WEBAPP_NAME = "chatbot_webapp"
VERSION = dataiku.get_custom_variables()[WEBAPP_NAME]

# LLM ID (an empty value means that no LLM connection is configured)
LLM_ID = dataiku.get_custom_variables()["LLM_id"]

# Back-ends passed to `answer_question` by the `get_answer` callback.
# They are bound to None up front so that the names always exist: each of them
# is only set when the corresponding back-end is configured, and reading an
# unbound name from the callback would raise a NameError.
llm = None
client = None
openai_api_key = None

# Check if LLM is provided
llm_provided = False
if len(LLM_ID) > 0:
    from dataiku.langchain.dku_llm import DKUChatLLM
    llm = DKUChatLLM(llm_id=LLM_ID, temperature=0)
    llm_provided = True

# OpenAI credentials (to be added as a user secret named "openai_key")
auth_info = dataiku.api_client().get_auth_info(with_secrets=True)
for secret in auth_info["secrets"]:
    if secret["key"] == "openai_key":
        llm_provided = True
        openai_api_key = secret["value"]
        client = OpenAI(api_key=openai_api_key)

# Load LENS model and processor (only needed when a back-end is available)
if llm_provided:
    lens = Lens()
    processor = LensProcessor()

# Folder to log conversations flagged by users
folder = dataiku.Folder("FyDceqZA")

# Error message used when neither an OpenAI API key nor an LLM connection is
# configured; in that case it seeds the conversation as its first message
ERROR_MESSAGE_MISSING_KEY = """
OpenAI API key or LLM connection missing. Cf. this project's wiki.

Please note that this web app is not live on Dataiku’s public project gallery but you can test it by downloading the project and providing an OpenAI API key and/or an LLM connection.
"""

# Layout

# Icons (Bootstrap icon classes). The two icons used in the button bar carry
# some right padding because they are followed by a text label.
reset_icon = html.Span(
    html.I(className="bi bi-trash3"),
    style={"paddingRight": "5px"},
)
flag_icon = html.Span(
    html.I(className="bi bi-flag"),
    style={"paddingRight": "5px"},
)
image_icon = html.Span(html.I(className="bi bi-card-image"))
send_icon = html.Span(html.I(className="bi bi-send"))

# Button bar: "reset" and "flag" actions displayed below the question input
button_bar = html.Div(
    children=[
        dbc.Button(
            children=html.Span(children=[reset_icon, 'Reset conversation']),
            id='reset-btn',
            title='Delete all previous messages',
        ),
        dbc.Button(
            children=html.Span(children=[flag_icon, 'Flag conversation']),
            id='flag-btn',
            title='Flag the conversation, e.g. in case of inappropriate or erroneous replies',
        ),
    ],
    style=STYLE_BUTTON_BAR,
)

# Question bar: image upload button, text input and send button, followed by
# the client-side stores holding the uploaded images and by the button bar
question_bar = html.Div(
    children=[
        dbc.InputGroup(
            children=[
                dcc.Upload(
                    id='upload-img',
                    children=[
                        dbc.Button(
                            children=image_icon,
                            id='upload-btn',
                            title='Upload File',
                            style={"border-radius": "5px 0px 0px 5px"},
                        )
                    ],
                ),
                dbc.Input(id='query', value='', type='text', minLength=0),
                dbc.Button(children=send_icon, id='send-btn', title='Get an answer'),
            ],
        ),
        # Per-image data returned by `get_captions`, keyed by the uploaded content
        dcc.Store(id='store-img', storage_type='memory', data={}),
        # Second value returned by `get_captions` for each uploaded image
        dcc.Store(id='b64-images', storage_type='memory', data=[]),
        button_bar,
    ],
    style={"margin-top": "20px"},
)

# Conversation display: the inner Div is filled by the `update_display` callback
conversation = html.Div(
    children=html.Div(id="conversation"),
    style=STYLE_CONVERSATION,
)

# Spinning wheel wrapping the output of the `get_answer` callback
spinning_wheel = dbc.Spinner(
    children=dcc.Markdown(id='spinning_wheel'),
    color="primary",
)

# Spinning wheel wrapping the output of the `receive_query` callback
# (covers the processing of an uploaded image)
spinning_wheel_loading = dbc.Spinner(
    children=dcc.Markdown(id='spinning_wheel_load'),
    color="primary",
)

# Model selection dropdown (label and selector on the same row)
select = html.Div(
    children=[
        html.P(
            children="Choose your model:",
            style={"margin-right": "10px", "margin-top": "8px"},
        ),
        dbc.Select(
            id="select",
            value="lens",  # model selected by default
            options=[
                {"label": "LENS", "value": "lens"},
                {"label": "GPT-4V", "value": "gpt4v"},
            ],
            style={"width": "120px"},
        ),
    ],
    style={"margin-bottom": "10px", "display": "flex"},
)

# App setup (`app` is not defined in this file; it is expected to be provided
# by the environment running the web app)
app.title = "Chatbot"
app.config.external_stylesheets = [
    "https://fonts.googleapis.com/css2?family=Outfit:wght@100;200;300;400;500;600;700;800;900&display=swap",
    dbc.themes.ZEPHYR,
    dbc.icons.BOOTSTRAP,
]
font_family = "Outfit"

# Without any configured model, the conversation starts with the error message
if llm_provided:
    initial_questions = []
else:
    initial_questions = [ERROR_MESSAGE_MISSING_KEY]

# App layout: title, model selector, conversation, spinners, question bar and
# the client-side stores holding the state of the conversation
app.layout = html.Div(
    children=[
        html.H2(children='Multimodal Chatbot', style={"font-size": 25}),
        html.H4(
            children="Have a conversation with a multimodal AI system.",
            style={"color": "#ABABAB", "font-weight": "300", "font-size": 16},
        ),
        select,
        conversation,
        spinning_wheel,
        spinning_wheel_loading,
        question_bar,
        # Path of the log file once the conversation has been flagged, "" otherwise
        dcc.Store(id="logged"),
        # Messages of the user and replies of the model, in chronological order
        dcc.Store(id='history_questions', storage_type='memory', data=initial_questions),
        dcc.Store(id='history_answers', storage_type='memory', data=[]),
    ],
    style={
        "margin": "auto",
        "text-align": "left",
        "max-width": "700px",
        "font-family": font_family,
        "justify-content": "center",  # Center horizontally
        "align-items": "center",      # Center vertically
        "height": "100vh",
        "padding": "20px",
    },
)

# Callbacks

# Helper building the box of a single message (used by the display callback)
def textbox(text, box="AI"):
    """
    Create the text box corresponding to a message

    Args:
        text: content of the message. A string for a text message; any other
            value (e.g. an uploaded image) gets a borderless, shadowless box
            when it is a user message.
        box: "user" for a message of the user; any other value is rendered as
            a message of the AI.

    Returns:
        A `dbc.Card` wrapping the message.
    """
    # Work on a copy so that the shared style is never modified
    style = dict(STYLE_MESSAGE)
    if box == "user":
        # An automatic left margin pushes the user's messages to the right
        style["margin-left"] = "auto"
        style["margin-right"] = 10
        color, inverse = "primary", True
        if not isinstance(text, str):
            # Non-text content: drop the card decoration around it
            color = "secondary"
            style["border"] = "none"
            style["box-shadow"] = "none"
    else:
        style["margin-left"] = 10
        style["margin-right"] = "auto"
        color, inverse = "light", False
    return dbc.Card(text, style=style, body=False, color=color, inverse=inverse)

# Callback to update conversation display
@app.callback(
    Output("conversation", "children"),
    Input("history_questions", "data"),
    Input("history_answers", "data"),
)
def update_display(history_questions, history_answers):
    """
    Build the list of message boxes of the conversation.

    Each user message is followed by the reply stored at the same position,
    when there is one; a message still waiting for its reply is shown alone.
    """
    messages = []
    for position, question in enumerate(history_questions):
        messages.append(textbox(question, box="user"))
        if position < len(history_answers):
            messages.append(textbox(history_answers[position], box="AI"))
    return messages

# Callback to process the user query
@app.callback(
    Output('history_questions', 'data'),
    Output('query', 'value'),
    Output('store-img', 'data'),
    Output('spinning_wheel_load', 'children'),
    Output('b64-images', 'data'),
    Output('upload-img', 'contents'),
    Input('reset-btn', 'n_clicks'),
    Input('send-btn', 'n_clicks'),
    Input('query', 'n_submit'),
    Input('upload-img', 'contents'),
    Input('select', 'value'),
    State('query', 'value'),
    State('history_questions', 'data'),
    State('history_answers', 'data'),
    State('store-img', 'data'),
    State('b64-images', 'data'),
    prevent_initial_call=True
)
def receive_query(reset, n_clicks, n_submit, content, select, query, history_questions, history_answers, store_img, b64_images):
    """
    Receive the new query from the user

    Depending on the component that triggered the callback:
    - reset button or model selector: start a new, empty conversation;
    - image upload: process the image with `get_captions`, store the result
      and add the image to the conversation;
    - send button or Enter key: add the typed query to the conversation.

    Returns:
        A tuple (history_questions, query value, store-img data, spinner
        content, b64-images data, upload contents).

    Raises:
        PreventUpdate: when there is nothing to do (empty query, previous
            message still waiting for its answer, or upload without content).
    """
    if not llm_provided:
        return history_questions, "", {}, "", [], None

    trigger = ctx.triggered_id
    if trigger in ("reset-btn", "select"):
        # Resetting or switching model starts a new conversation from scratch
        return [], "", {}, "", [], None

    if trigger == "upload-img":
        if content is None:
            # Nothing was uploaded: there is no image to process
            raise PreventUpdate
        store_img[content], b64_image = get_captions(lens, processor, content, select)
        b64_images.append(b64_image)
        history_questions.append(html.Img(src=content, style=STYLE_IMAGE))
        return history_questions, query, store_img, "", b64_images, content

    waiting_for_answer = len(history_questions) != len(history_answers)
    if not query or waiting_for_answer:
        # Nothing changes. Returning the unchanged history would still update
        # 'history_questions' and re-trigger the callbacks listening to it,
        # in particular `get_answer`, which would query the model again.
        raise PreventUpdate

    history_questions.append(query)
    return history_questions, "", store_img, "", b64_images, content

# Callback to get AI answer
@app.callback(
    Output('history_answers', 'data'),
    Output('spinning_wheel', 'children'),
    Input('history_questions', 'data'),
    State('history_answers', 'data'),
    State('store-img', 'data'),
    State('b64-images', 'data'),
    State('select', 'value'),
    prevent_initial_call=True
)
def get_answer(history_questions, history_answers, store_img, b64_images, model):
    """
    Receive the answer from the LLM

    A reply is produced only when the last message has no answer yet: text
    messages are sent to `answer_question`, other messages (uploaded images)
    get a fixed acknowledgement.

    Returns:
        A tuple (updated list of answers, spinner content).
    """
    # Function-scope import: `logging` is not imported at the top of the file
    import logging

    if not llm_provided or len(history_questions) == 0:
        return [], ""
    if len(history_questions) > len(history_answers):
        if isinstance(history_questions[-1], str):
            try:
                answer = answer_question(
                    history_questions,
                    history_answers,
                    store_img,
                    b64_images,
                    model,
                    client,
                    openai_api_key,
                    llm
                )
            except Exception:
                # Without a reply the last question would stay unanswered and
                # the send button would remain disabled: log the error and
                # answer with an error message instead of failing.
                logging.getLogger(__name__).exception("Failed to get an answer from the model")
                answer = "Sorry, an error occurred while generating the answer."
        else:
            answer = "Thank you for this image"
        history_answers.append(answer)
    return list(history_answers), ''

# Callback to log conversation
@app.callback(
    Output('logged', 'data'),
    Input('flag-btn', 'n_clicks'),
    Input('history_questions', 'data'),
    Input('history_answers', 'data'),
)
def log_conversation(n_clicks, history_questions, history_answers):
    """
    Log the current conversation

    The conversation is written as a JSON file to the log folder when the
    flag button is clicked and at least one answer exists.

    Returns:
        The path of the written file, or "" when nothing was logged (this
        includes any change of the conversation, which triggers the callback
        without a click on the flag button).
    """
    # Function-scope import: `hashlib` is not imported at the top of the file
    import hashlib

    if len(history_answers) == 0 or ctx.triggered_id != "flag-btn":
        return ""

    # The file name is a digest of the conversation. The built-in `hash` cannot
    # be used: it fails on messages that are not hashable (non-text messages
    # such as images), and its value for strings differs between processes.
    serialized = json.dumps([history_questions, history_answers], sort_keys=True)
    digest = hashlib.sha256(serialized.encode("utf-8")).hexdigest()
    path = f"/{digest}.json"
    with folder.get_writer(path) as w:
        w.write(bytes(json.dumps({
            "history_questions": history_questions,
            "history_answers": history_answers,
            "timestamp": str(datetime.now()),
            "webapp": WEBAPP_NAME,
            "version": VERSION
        }), "utf-8"))
    return path

# Callback to disable buttons
@app.callback(
    Output('reset-btn', 'disabled'),
    Output('flag-btn', 'disabled'),
    Output('send-btn', 'disabled'),
    Input('history_questions', 'data'),
    Input('history_answers', 'data'),
    Input('logged', 'data'),
)
def disable_buttons(history_questions, history_answers, flagged):
    """
    Compute the disabled state of the reset, flag and send buttons.

    - reset: disabled as long as there is no answer;
    - flag: disabled as long as there is no answer, and once the conversation
      has been flagged;
    - send: disabled while a message is waiting for its answer.
    """
    no_answer_yet = len(history_answers) == 0
    disable_reset = no_answer_yet
    disable_flag = no_answer_yet or flagged != ""
    disable_send = len(history_questions) > len(history_answers)
    return disable_reset, disable_flag, disable_send