import dataiku
from dataiku.langchain.dku_llm import DKUChatLLM
from dataiku.llm.python import BaseLLM
import json
from typing import Literal

from langchain.tools import StructuredTool, tool
from langchain.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_core.messages import AIMessage, ToolMessage, HumanMessage
from langgraph.graph import MessagesState, START
from langgraph.prebuilt import ToolNode, ToolInvocation
from langgraph.graph import StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

# Serializer used to round-trip the LangGraph checkpoint (state values and the
# pending next node) through the client between turns; see update_discussion
# (dumps) and MyLLM.process (loads).
jsonplus_serde = JsonPlusSerializer()

# ID of the Dataiku LLM backing the agent, read from the project's custom
# variables (presumably a dict, so a missing "LLM_id" raises KeyError at import).
LLM_ID = dataiku.get_custom_variables()["LLM_id"]
# NOTE(review): `llms` is not referenced anywhere in the visible code; it looks
# like a leftover helper for discovering valid LLM IDs. It performs an API call
# at import time — confirm nothing else uses it before removing.
llms = dataiku.api_client().get_default_project().list_llms()

# Example tool exposed to the agent. Its name is listed in `sensitive_tools`,
# so the graph stops for user approval before it is executed.
# NOTE: LangChain's @tool uses the function's docstring as the tool description
# sent to the LLM, so editing the docstring changes model-visible behavior.
@tool
def add_integers(x: int, y: int):
    """Add two integer numbers."""
    return x + y

# Tools the agent may call. Calls to tools whose names appear in
# `sensitive_tools` require explicit user approval before they run.
tools = [add_integers, WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())]
sensitive_tools = ["add_integers"]

# Chat model for the configured Dataiku LLM (temperature 0), bound to the tools
# above. Parallel tool calls are disabled because the routing functions only
# look at the first tool call of the last message.
llm = DKUChatLLM(llm_id=LLM_ID, temperature=0).bind_tools(
    tools, parallel_tool_calls=False
)

# Define nodes and conditional edges

def take_action(state) -> Literal["answer_human", "check_sensitivity"]:
    """
    Route the LLM's latest reply.

    Returns "check_sensitivity" when the reply requests a tool call, otherwise
    "answer_human" so the reply is handed back to the user.
    """
    # The Literal return annotation lists the possible destinations; keep it in
    # sync with the returned node names.
    latest = state["messages"][-1]
    return "check_sensitivity" if latest.tool_calls else "answer_human"

def check_tool(state) -> Literal["accept_sensitive_action", "action"]:
    """
    Decide whether the pending tool call may run directly.

    Returns "accept_sensitive_action" (where the graph pauses to solicit the
    user's approval) when the first requested tool is in `sensitive_tools`,
    otherwise "action" to execute it immediately.
    """
    # Only the first tool call is inspected; the LLM is bound with
    # parallel_tool_calls=False.
    requested_tool = state["messages"][-1].tool_calls[0]["name"]
    needs_approval = requested_tool in sensitive_tools
    return "accept_sensitive_action" if needs_approval else "action"

def accept_sensitive_action(state) -> Literal["action", "agent"]:
    """
    Execute the sensitive action if approved. Otherwise, inform the agent of the refusal.

    When the user refuses, update_discussion appends a ToolMessage carrying the
    refusal, so a trailing ToolMessage means "go back to the agent"; otherwise
    the last message is still the AI tool call and it is executed.
    """
    # isinstance rather than an exact type comparison so ToolMessage subclasses
    # are treated the same way.
    if isinstance(state["messages"][-1], ToolMessage):
        return "agent"
    return "action"
    
def call_model(state):
    """
    Send the conversation so far to the tool-bound LLM.

    Returns a state update containing only the LLM's reply, which is appended to
    the conversation.
    """
    history = state["messages"]
    reply = llm.invoke(history)
    return {"messages": [reply]}

def neutral_action(state):
    """
    Placeholder node body: leave the conversation unchanged.

    Used for nodes that exist only as routing or interrupt points. The `state`
    argument is ignored.
    """
    no_new_messages = []
    return {"messages": no_new_messages}

# Build the graph

def build_graph(state, next_node):
    """
    Build the graph defining the behavior of the agent.

    A new graph and a new in-memory checkpointer are created on every call; the
    conversation is carried between calls by re-injecting `state` into the
    fresh checkpointer, not by keeping the checkpointer alive.

    :param state: checkpoint values from the previous turn, or None to start a
        new conversation.
    :param next_node: node the previous run was interrupted before (passed as
        the third positional argument of `update_state`, i.e. LangGraph's
        `as_node`). When it is "accepted", an extra entry node is wired so the
        run starts by executing the pending tool call.
    :returns: the compiled graph, with `state` (if any) loaded into thread "1".
    """
    workflow = StateGraph(MessagesState)

    # Only "agent" (LLM call) and "action" (tool execution) modify the state;
    # the other nodes use neutral_action and serve as routing/interrupt points.
    workflow.add_node("agent", call_model)
    workflow.add_node("check_sensitivity", neutral_action)
    workflow.add_node("action", ToolNode(tools))
    workflow.add_node("accept_sensitive_action", neutral_action)
    workflow.add_node("answer_human", neutral_action)

    # The run is interrupted before these nodes so control returns to the user:
    # either to read the agent's answer or to approve/refuse a sensitive tool.
    stop_nodes = ["answer_human", "accept_sensitive_action"]
    # No path maps are passed; presumably LangGraph infers the destinations
    # from the routing functions' Literal return annotations.
    workflow.add_conditional_edges("agent", take_action)
    workflow.add_conditional_edges("check_sensitivity", check_tool)
    workflow.add_conditional_edges("accept_sensitive_action", accept_sensitive_action)
    workflow.add_edge("action", "agent")
    workflow.add_edge("answer_human", "agent")
    if next_node == "accepted":
        # The user approved a sensitive tool call: enter through a dummy node
        # that leads straight to tool execution instead of calling the LLM.
        workflow.add_node("accepted", neutral_action)
        workflow.add_edge("accepted", "action")
        workflow.add_edge(START, "accepted")
    else:
        workflow.add_edge(START, "agent")

    memory = MemorySaver()
    graph = workflow.compile(checkpointer=memory, interrupt_before=stop_nodes)
    # Single hard-coded thread id; update_discussion uses the same config.
    config = {"configurable": {"thread_id": "1"}}
    if state is not None:
        # Restore the previous turn's values into the fresh checkpointer.
        graph.update_state(config, state, next_node)
    return graph

def update_discussion(state, next_node, new_human_message):
    """
    Update the state of the agent based on a new user message, then run it.

    :param state: checkpoint values from the previous turn, or None when
        starting a new conversation.
    :param next_node: node the previous run was interrupted before, or None
        when starting a new conversation.
    :param new_human_message: the user's new message. When the previous run
        stopped before "accept_sensitive_action", it is the user's answer to the
        approval request: "No" (compared case-insensitively, ignoring
        surrounding whitespace) refuses the pending tool call; any other answer
        approves it.
    :returns: a tuple ``(serialized_checkpoint, messages)``.
        ``serialized_checkpoint`` is the JsonPlus serialization of the new
        checkpoint's ``(values, next)`` pair. ``messages`` lists, for every
        message in the conversation, a label derived from its repr together
        with its ``dict()`` representation.
    """
    # Previously only the exact string "No" counted as a refusal, so replies
    # such as "no" or "No " silently approved the sensitive tool call.
    refused = str(new_human_message).strip().casefold() == "no"
    if next_node == "accept_sensitive_action" and not refused:
        next_node = "accepted"
    graph = build_graph(state, next_node)
    config = {"configurable": {"thread_id": "1"}}
    snapshot = graph.get_state(config)
    if snapshot.next and next_node is not None:
        # Resuming an interrupted run: resend the history (copied so the
        # snapshot is not mutated in place) plus whatever this turn adds.
        messages = list(snapshot.values["messages"])
        if next_node == "accept_sensitive_action":
            # The user refused: answer the pending tool call with a refusal so
            # the agent can react to it.
            tool_call_id = messages[-1].tool_calls[0]["id"]
            messages.append(
                ToolMessage(
                    content="Tool execution refused by the user",
                    tool_call_id=tool_call_id,
                )
            )
        elif next_node != "accepted":
            messages.append(new_human_message)
        # When approved ("accepted"), nothing is added: the rebuilt graph
        # starts at the tool-execution step.
    else:
        # New conversation: the user's message is the whole input.
        messages = new_human_message
    result = graph.invoke({"messages": messages}, config)
    return (
        jsonplus_serde.dumps(graph.get_state(config)[:2]),
        [(m.construct().__repr__()[:-2], m.dict()) for m in result["messages"]],
    )

class MyLLM(BaseLLM):
    """
    Dataiku custom LLM exposing the LangGraph agent defined in this file.

    The content of the first message of each query must be a JSON object with a
    "query" key (the user's message). Follow-up turns also carry a "state" key
    holding the serialized checkpoint returned by the previous call.
    """

    def __init__(self):
        pass

    def process(self, query, settings, trace):
        """
        Run one conversational turn of the agent.

        :param query: Dataiku LLM query; only ``query["messages"][0]["content"]``
            is read, and it is parsed as JSON (see the class docstring).
        :param settings: unused.
        :param trace: unused.
        :returns: ``{"text": ...}``, where the text is a JSON object with
            "state" (the serialized checkpoint to send back on the next turn)
            and "messages" (the conversation as returned by update_discussion).
        """
        inputs = json.loads(query["messages"][0]["content"])
        # Kept distinct from the `query` parameter to avoid shadowing it.
        user_message = inputs["query"]
        if "state" in inputs:
            # NOTE(review): this state comes back from the client and is
            # revived by JsonPlusSerializer, which can rebuild Python objects
            # from the payload. Confirm this is acceptable for untrusted input,
            # or sign/validate the state before loading it.
            state, next_nodes = jsonplus_serde.loads(inputs["state"])
            next_node = next_nodes[0]
        else:
            state, next_node = None, None
        serialized_state, messages = update_discussion(state, next_node, user_message)
        return {
            "text": json.dumps(
                {
                    "state": serialized_state.decode('utf-8'),
                    "messages": messages,
                }
            )
        }
