# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import json
import os
from datetime import datetime
import mlflow

from project_utils import YouSearch, YouSearchNews, BraveSearch, clean_html

from langchain.tools import WikipediaQueryRun
from langchain.utilities import WikipediaAPIWrapper
from langchain.agents import initialize_agent, AgentType
from langchain.callbacks import MlflowCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.schema import SystemMessage

# Input dataset; its "question" column is read row by row further down.
df = dataiku.Dataset("questions").get_dataframe()

# A single API client is reused for the project handle and the secrets lookup.
client = dataiku.api_client()
project = client.get_default_project()

# Configure MLflow through the project, using managed folder "8SUKHUwt".
mlflow_handle = project.setup_mlflow(project.get_managed_folder("8SUKHUwt"))
EXPERIMENT_NAME = "agent"
mlflow.set_experiment(EXPERIMENT_NAME)

# Bound up front so that a missing "openai_key" secret does not surface later
# as a NameError inside get_answer_agent.
# NOTE(review): with None, ChatOpenAI is expected to fall back to its own key
# lookup (e.g. the OPENAI_API_KEY environment variable) - confirm against the
# installed langchain version.
openai_api_key = None

# The search API keys are exported as environment variables (presumably read
# by the search tools from project_utils - confirm); the OpenAI key is kept
# in a module-level variable instead.
auth_info = client.get_auth_info(with_secrets=True)
for secret in auth_info["secrets"]:
    key = secret["key"]
    if key in ("BRAVE_API_KEY", "YDC_API_KEY"):
        os.environ[key] = secret["value"]
    elif key == "openai_key":
        openai_api_key = secret["value"]

SEARCH_ENGINE = "You" # "You" or "Brave"
LLM = "gpt-3.5-turbo-16k"
NUM_SEARCH_RESULTS = 10

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def get_answer_agent(question, tools, callbacks=None):
    """
    Answer a question with an agent equipped with retrieval tools.

    Args:
        question: Question passed to the agent.
        tools: Tools the agent is allowed to call.
        callbacks: Optional callback handlers attached to both the chat model
            and the agent. Defaults to no callbacks.

    Returns:
        The agent's output, followed by a "Sources:" section built from the
        last tool call when that tool is Wikipedia or one of the search
        tools. The result is stripped of surrounding whitespace.
    """
    # None instead of a mutable default list, so calls never share one list.
    if callbacks is None:
        callbacks = []

    agent = initialize_agent(
        tools,
        ChatOpenAI(
            temperature=0,
            model=LLM,
            openai_api_key=openai_api_key,
            callbacks=callbacks,
        ),
        agent=AgentType.OPENAI_FUNCTIONS,
        return_intermediate_steps=True,
        callbacks=callbacks,
    )
    # Replace the first prompt message with a system message carrying the
    # current date and time.
    agent.agent.prompt.messages[0] = SystemMessage(
        content=f"You are a helpful AI assistant. The current date and time are: {str(datetime.now())}"
    )
    result = agent(question)

    sources = ""
    steps = result["intermediate_steps"]
    if steps:
        # Only the last tool call is reported as the source of the answer.
        action, observation = steps[-1][0], steps[-1][1]
        sources = _format_sources(action.tool, observation)

    return (result["output"] + "\n\n" + sources).strip()


def _format_sources(tool, observation):
    """Render the output of one tool call as a "Sources:" section ("" if none)."""
    entries = []
    if tool == "Wikipedia":
        for page in observation.split("Page: "):
            parts = page.split("\nSummary: ")
            if len(parts) < 2:
                # Chunk without a summary (e.g. the text before the first page).
                continue
            title, content = parts[0], parts[1]
            entries.append(f"Wikipedia: {title}" + "\n" + clean_html(content).strip())
    elif tool in ("brave_search", "you_search", "you_search_news"):
        # The search tools' observation is parsed as a JSON list of results.
        for search_result in json.loads(observation):
            title = search_result["title"]
            link = search_result["link"]
            snippet = clean_html(search_result["snippet"])
            entries.append(f"[{title}]({link}): {snippet}")

    if not entries:
        # Unknown tool or nothing parseable: no dangling "Sources:" header.
        return ""
    # A blank line between entries keeps consecutive sources from running together.
    return "Sources:\n\n" + "\n\n".join(entries)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Wikipedia is always offered to the agent; the web search tools depend on
# the configured SEARCH_ENGINE (any other value leaves Wikipedia alone).
wikipedia_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
if SEARCH_ENGINE == "Brave":
    search_tools = [BraveSearch.create(NUM_SEARCH_RESULTS)]
elif SEARCH_ENGINE == "You":
    search_tools = [
        YouSearch.create(NUM_SEARCH_RESULTS),
        YouSearchNews.create(NUM_SEARCH_RESULTS),
    ]
else:
    search_tools = []
tools = [wikipedia_tool, *search_tools]

# Answer each question in turn, with a fresh MLflow callback handler per row,
# and store the result in the "answer" column of the same row.
for row_label in df.index:
    mlflow_callback = MlflowCallbackHandler(
        experiment=EXPERIMENT_NAME,
        tags={"model": LLM},
    )
    row_question = df.loc[row_label, "question"]
    row_answer = get_answer_agent(row_question, tools, callbacks=[mlflow_callback])
    df.at[row_label, "answer"] = row_answer

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write the dataframe, now including the "answer" column, to the output dataset.
answers_dataset = dataiku.Dataset("answers_agent")
answers_dataset.write_with_schema(df)