
import json

from langgraph.graph import StateGraph, END, MessagesState
from langgraph.prebuilt import ToolNode
from langchain_core.tools import tool
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

import dataiku
from dataiku.llm.python import BaseLLM
from dataiku.langchain import LangchainToDKUTracer

# Set this to the name of your OpenAI connection.
OPENAI_CONNECTION_NAME = "REPLACE_BY_YOUR_OPENAI_CONNECTION_NAME"

# LLM identifiers take the form "openai:<connection>:<model>".
# Router model: classifies each query as TOOL or EXPERT.
ROUTER_LLM_ID = ":".join(("openai", OPENAI_CONNECTION_NAME, "gpt-5-chat-latest"))
# Tool-calling model: answers TOOL queries using search_car_database.
TOOLS_LLM_ID = ":".join(("openai", OPENAI_CONNECTION_NAME, "gpt-5-mini"))
# Expert model: answers EXPERT queries.
EXPERT_LLM_ID = ":".join(("openai", OPENAI_CONNECTION_NAME, "gpt-5"))


@tool
def search_car_database(query: str):
    """
    Searches a database for information about car sales, prices, or models.
    Use this tool for any specific questions about car data.
    """
    # NOTE(review): @tool presumably uses the docstring above as the tool
    # description shown to the LLM, so it is kept verbatim.
    print(f"--- Tool Used: search_car_database(query='{query}') ---")
    # Mock lookup: only a query mentioning both "dealer" and "chicago"
    # (case-insensitive) has a canned answer; everything else is a JSON error.
    normalized = query.lower()
    is_chicago_dealer = "dealer" in normalized and "chicago" in normalized
    payload = (
        {"dealer": "Windy City Auto", "total_sales": 150, "top_model": "Sedan"}
        if is_chicago_dealer
        else {"error": "No specific information found for that query."}
    )
    return json.dumps(payload)


# Tools the tool-calling LLM may invoke, and the graph node that executes them.
tools = [search_car_database]
tool_node = ToolNode(tools)

# Build the API client and fetch the default project once, then reuse the
# handle for every model instead of recreating both three times.
_project = dataiku.api_client().get_default_project()
router_llm = _project.get_llm(ROUTER_LLM_ID).as_langchain_chat_model()
expert_llm = _project.get_llm(EXPERT_LLM_ID).as_langchain_chat_model()
tools_llm = _project.get_llm(TOOLS_LLM_ID).as_langchain_chat_model()
# tools_llm bound to `tools`, so its replies may contain tool calls
# (checked by should_call_tool).
llm_with_tools = tools_llm.bind_tools(tools)


def call_router_llm(state):
    """
    First graph node: classify the user's intent.

    Asks ``router_llm`` whether the latest user message can be answered with
    the car-database tool or needs an expert-level answer. The model's reply
    (expected to be the single word 'TOOL' or 'EXPERT') is appended to the
    message history, where ``route_call`` reads it to pick the next node.

    Args:
        state: graph state; ``state["messages"]`` is the conversation so far.

    Returns:
        A state update ``{"messages": [router_reply]}``.
    """
    print("--- Classifying query ---")
    last_human_message = next(
        (
            msg.content or ""
            for msg in reversed(state["messages"])
            if isinstance(msg, HumanMessage)
        ),
        "",
    )

    # The user's text is sent as its own HumanMessage rather than spliced
    # (inside quotes) into the system prompt, so quotes or instructions in the
    # query are less likely to be read as part of the router's instructions.
    router_prompt = [
        SystemMessage(
            content="You are an expert router. Your job is to classify the user's query. "
            "The user's query is given in the next message.\n\n"
            "Based on this query, decide if it can be answered with a simple tool "
            f"that can search a car database (tool description: {search_car_database.description} "
            f"tool arguments: {search_car_database.args}), "
            "or if it requires a complex, expert-level response (e.g., 'what is the future of electric vehicles?').\n\n"
            "Respond with only the single word 'TOOL' or 'EXPERT'."
        ),
    ]
    if last_human_message:
        router_prompt.append(HumanMessage(content=last_human_message))

    response = router_llm.invoke(router_prompt)
    return {"messages": [response]}


def route_call(state):
    """
    Conditional-edge function evaluated after the router node.

    Looks at the most recent message (the router's reply) and returns the name
    of the next node: "call_llm_with_tools" when the reply contains TOOL
    (case-insensitive, surrounding whitespace ignored), else "call_expert_llm".
    """
    print("--- Checking where to route the query ---")
    router_reply = state["messages"][-1]
    verdict = router_reply.content.strip().upper()
    print(f"--- Router decision: {verdict} ---")
    return "call_llm_with_tools" if "TOOL" in verdict else "call_expert_llm"
    
    
def call_expert_llm(state):
    """
    Answer a complex question with the "expert" model.

    Sends ``expert_llm`` an automotive-expert system prompt followed by the
    conversation minus its last message. This node is reached straight from the
    router, so that last message is the router's 'EXPERT' verdict, which is
    routing metadata rather than conversation.

    Returns:
        A state update ``{"messages": [expert_reply]}``.
    """
    print("--- Calling Expert LLM ---")
    history = state["messages"][:-1]
    # Use a SystemMessage object (as call_router_llm does) instead of mixing a
    # raw role dict into a list of LangChain message objects.
    expert_messages = [
        SystemMessage(
            content="You are a world-class automotive industry expert. "
            "Provide a detailed, insightful, and comprehensive answer to the user's query."
        ),
        *history,
    ]

    response = expert_llm.invoke(expert_messages)
    return {"messages": [response]}


def _drop_router_decision(messages):
    """
    Return a copy of *messages* without the router's classification message.

    The router node appends its 'TOOL'/'EXPERT' verdict as an AI message. That
    verdict is routing metadata and must not be sent to the tool-calling LLM.
    Within one graph run it is the most recent AI message that has no tool
    calls: every AI message from call_llm_with_tools that leads to call_tool
    carries tool calls (see should_call_tool), and older assistant turns from
    the incoming conversation precede the router's verdict.
    """
    for index in range(len(messages) - 1, -1, -1):
        message = messages[index]
        if isinstance(message, AIMessage) and not message.tool_calls:
            return messages[:index] + messages[index + 1:]
    return list(messages)


def call_llm_with_tools(state):
    """
    Call the tool-enabled LLM on the conversation so far.

    The router's verdict is removed first. On the first pass the history ends
    with that verdict. On later passes (after call_tool) it ends with the tool
    results, which the LLM must see. Slicing off the last message would drop
    them and leave an unanswered tool-call request.

    The LLM's response will either be a final answer or a tool call.
    """
    print("--- Calling LLM ---")
    messages = _drop_router_decision(state["messages"])
    response = llm_with_tools.invoke(messages)
    return {"messages": [response]}


def should_call_tool(state):
    """
    Conditional-edge function evaluated after call_llm_with_tools.

    Returns "call_tool" when the latest message requests tool calls,
    otherwise END to finish the run.
    """
    print("--- Checking for tool call ---")
    last_message = state["messages"][-1]
    # getattr: only AI messages have `tool_calls`; anything else means "done".
    if getattr(last_message, "tool_calls", None):
        print("--- Decision: Call tool ---")
        return "call_tool"
    print("--- Decision: End session ---")
    return END


class ToolAgentLLM(BaseLLM):
    """
    Custom LLM (subclass of dataiku's BaseLLM) backed by a LangGraph agent.

    Graph wiring:
        router_llm --TOOL-->   call_llm_with_tools <--> call_tool
        router_llm --EXPERT--> call_expert_llm --> END
        call_llm_with_tools --(no tool calls)--> END
    """

    def __init__(self):
        graph = StateGraph(MessagesState)
        graph.add_node("router_llm", call_router_llm)
        graph.add_node("call_llm_with_tools", call_llm_with_tools)
        graph.add_node("call_tool", tool_node)
        graph.add_node("call_expert_llm", call_expert_llm)

        graph.set_entry_point("router_llm")
        # Explicit path maps declare every target each routing function can return.
        graph.add_conditional_edges(
            "router_llm", route_call, ["call_llm_with_tools", "call_expert_llm"]
        )
        graph.add_conditional_edges(
            "call_llm_with_tools", should_call_tool, ["call_tool", END]
        )
        # Feed tool results back to the LLM so it can turn them into an answer
        # (or request more tool calls). Without this edge the run stops at
        # call_tool and the raw tool output becomes the final reply.
        graph.add_edge("call_tool", "call_llm_with_tools")
        graph.add_edge("call_expert_llm", END)
        self.graph = graph.compile()

    def process(self, query, settings, trace):
        """
        Run the graph on ``query["messages"]`` and return the last message's text.

        ``settings`` is not used. ``trace`` is wrapped in a LangchainToDKUTracer
        and passed as a callback, presumably so the run is recorded on it.

        Returns:
            ``{"text": <content of the final message>}``.
        """
        tracer = LangchainToDKUTracer(dku_trace=trace)
        initial_state = {"messages": query["messages"]}
        result = self.graph.invoke(
            initial_state,
            config={"callbacks": [tracer]},
        )
        final_message = result["messages"][-1]
        return {"text": final_message.content}
    