# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
import dataiku
from dataiku.langchain.dku_llm import DKUChatLLM
from typing import Callable
from typing import Literal

from langchain.tools import StructuredTool
from langgraph.graph import MessagesState, START
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage
from langgraph.prebuilt import ToolNode, create_react_agent
from langgraph.graph import StateGraph
from langgraph.checkpoint.memory import MemorySaver

from fictitious_tools import (
    get_customer_id,
    get_details,
    reset_password,
    cancel_appointment,
    schedule_local_intervention,
    schedule_distant_intervention,
    run_diagnostics,
    sign_up_to_option,
    cancel_option
)

# Customer requests to process, read from the "requests" dataset.
df = dataiku.Dataset("requests").get_dataframe()

# The id of the LLM to use comes from the "LLM_id" custom variable.
LLM_ID = dataiku.get_custom_variables()["LLM_id"]
llm = DKUChatLLM(llm_id=LLM_ID, temperature=0)

# True: rely on LangGraph's prebuilt ReAct agent; False: use the hand-built graph.
USE_DEFAULT_REACT_AGENT = True

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Functions exposed to the LLM as tools when the customer is not identified.
# The order matters: get_tools pairs the entries of this list, by position,
# with the lambdas returned by get_partial_functions. get_customer_id has no
# customer-bound counterpart there and therefore stays last.
functions = [
    reset_password,
    cancel_appointment,
    schedule_local_intervention,
    schedule_distant_intervention,
    run_diagnostics,
    sign_up_to_option,
    cancel_option,
    get_customer_id
]

def get_partial_functions(customer_id):
    """
    Build the customer-bound variants of the tool functions.

    Each returned callable has `customer_id` already filled in, so it only
    takes the remaining arguments (none, or `option`). The callables are
    returned in the same order as the leading entries of `functions`.
    """
    # Callables that need nothing beyond the customer id.
    without_argument = [
        lambda: reset_password(customer_id),
        lambda: cancel_appointment(customer_id),
        lambda: schedule_local_intervention(customer_id),
        lambda: schedule_distant_intervention(customer_id),
        lambda: run_diagnostics(customer_id),
    ]
    # Callables that additionally take the option to act on.
    with_option = [
        lambda option: sign_up_to_option(customer_id, option),
        lambda option: cancel_option(customer_id, option),
    ]
    return without_argument + with_option

def get_tools(customer_id=None):
    """
    Return the tools to give to the agent.

    Without a customer id, every entry of `functions` is wrapped as is.
    With a customer id, the customer-bound callables are wrapped instead,
    each one borrowing the name and docstring of the function found at the
    same position in `functions`.
    """
    if customer_id is None:
        return [StructuredTool.from_function(f) for f in functions]

    # zip stops at the shorter sequence, i.e. after the last customer-bound callable.
    return [
        StructuredTool.from_function(
            bound,
            name=original.__name__,
            description=original.__doc__
        )
        for bound, original in zip(get_partial_functions(customer_id), functions)
    ]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
if USE_DEFAULT_REACT_AGENT:

    def build_graph(llm, tools):
        """
        Create the graph with LangGraph's prebuilt ReAct agent.
        """
        react_agent = create_react_agent(llm, tools=tools)
        return react_agent

else:

    def process_assistant_reply(state) -> Literal["answer_human", "action"]:
        """
        Pick the node following the agent: "action" when the latest message
        carries tool calls, "answer_human" otherwise.
        """
        latest = state["messages"][-1]
        return "action" if latest.tool_calls else "answer_human"

    def call_model(llm):
        """
        Wrap an LLM into a node function that appends its answer to the messages.
        """
        def invoke_llm(state):
            return {"messages": [llm.invoke(state["messages"])]}

        return invoke_llm

    def neutral_action(state):
        """
        Placeholder node for answering the user: it adds no message.
        """
        return {"messages": []}

    def build_graph(llm, tools):
        """
        Create the graph by hand: an agent node, a tool node and an answer node.
        """
        builder = StateGraph(MessagesState)

        # Nodes
        builder.add_node("agent", call_model(llm))
        builder.add_node("action", ToolNode(tools))
        builder.add_node("answer_human", neutral_action)

        # Edges: the agent either calls tools (and is called again) or answers.
        builder.add_edge(START, "agent")
        builder.add_conditional_edges("agent", process_assistant_reply)
        builder.add_edge("action", "agent")

        # The run is interrupted before the answer node is executed.
        return builder.compile(
            checkpointer=MemorySaver(),
            interrupt_before=["answer_human"],
        )

def process_request(request, customer_id=None):
    """
    Handle the request of the customer, taking into account his/her customer id if provided.

    The tools are bound to the LLM, the graph is run once on the request and
    the resulting messages are summarized.

    Args:
        request: Text of the customer request.
        customer_id: Identifier of the customer, passed to `get_tools` and
            `get_details`.

    Returns:
        A tuple `(reply, actions, intermediate_steps)` of strings:
        - reply: content of the last AI message that carries no tool call
          ("" when there is none);
        - actions: the tool calls, numbered from 1, one per line;
        - intermediate_steps: the same lines, each followed by " --> " and
          the result returned by the tool.
    """
    import uuid

    tools = get_tools(customer_id=customer_id)
    # Use a distinct local name: assigning to `llm` inside the function would
    # make it a local variable, and reading the module-level `llm` on the same
    # line would then raise UnboundLocalError.
    llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False)
    graph = build_graph(llm_with_tools, tools)
    # NOTE(review): get_details is called even when customer_id is None --
    # confirm that it supports this case.
    identity = get_details(customer_id).split(",")[0]
    request = f"{request} (message received from {identity})"
    # A graph compiled with a checkpointer requires a thread id at invocation;
    # a fresh id per request keeps the requests independent of each other.
    config = {"configurable": {"thread_id": str(uuid.uuid4())}}
    result = graph.invoke({"messages": [request]}, config=config)

    # Default value so that the function still returns when the run ends
    # without any AI message free of tool calls.
    reply = ""
    tool_calls = {}
    # The first message is the request itself, hence the [1:].
    for m in result["messages"][1:]:
        if isinstance(m, AIMessage):
            if "tool_calls" in m.additional_kwargs:
                for tool_call in m.additional_kwargs["tool_calls"]:
                    tool_calls[tool_call["id"]] = {
                        "name": tool_call["function"]["name"],
                        "arguments": tool_call["function"]["arguments"],
                    }
            else:
                reply = m.content
        elif isinstance(m, ToolMessage):
            # Attach the result to the call that triggered it.
            tool_calls[m.tool_call_id]["result"] = m.content

    action_lines, step_lines = [], []
    for number, tool_call in enumerate(tool_calls.values(), start=1):
        action = f"{number}. {tool_call['name']}({tool_call['arguments']})"
        action_lines.append(action)
        step_lines.append(f"{action} --> {tool_call['result']}")

    return reply, "\n".join(action_lines), "\n".join(step_lines)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Process every request and store the three outputs next to it.
for row_number in range(len(df)):
    current_row = df.iloc[row_number]
    outputs = process_request(
        current_row.request,
        customer_id=int(current_row.customer_id)
    )
    # process_request returns (reply, actions, intermediate_steps), in this order.
    for column, value in zip(("draft_reply", "actions", "intermediate_steps"), outputs):
        df.at[row_number, column] = value

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Save the dataframe, including the columns added above, to the output dataset.
output_dataset = dataiku.Dataset("requests_processed_langgraph")
output_dataset.write_with_schema(df)