import dataiku
import uuid
import json
from dataiku.llm.python import BaseLLM
from typing import Annotated, List, Literal, TypedDict, Union

# LangGraph Imports
from langchain_core.tools import tool
from langchain_core.messages import ToolMessage, BaseMessage, AIMessage, HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import InMemorySaver

# --- 1. Define Real Tools ---
@tool
def check_order_status(order_id: str):
    """Checks the status of a specific order ID."""
    # Demo stub: every order id gets the same fixed delivery status.
    # The docstring above doubles as the tool description given to the
    # model by @tool, so it is intentionally left unchanged.
    delivery_note = "Expected arrival: Today by 5 PM."
    return f"Order {order_id} is currently Out for Delivery. {delivery_note}"

@tool
def upgrade_subscription(tier: str):
    """Upgrades the user's plan to the specified tier."""
    # Demo stub: no subscription is actually modified; only a confirmation
    # string mentioning the requested tier is returned.
    confirmation = f"Success! The plan has been upgraded to the {tier} Tier."
    return confirmation

# --- 2. Define Handoff Logic (Using Special Markers) ---
# Instead of a complex Command, we use simple "Transfer Tools" 
# that act as signals for our Router.

@tool
def transfer_to_shipping():
    """Transfer to the Shipping Agent."""
    # Hand-off signal: the graph routes on the name of the transfer tool the
    # model calls, so the return value is only a fixed acknowledgement.
    acknowledgement = "TRANSFER_TRIGGERED"
    return acknowledgement

@tool
def transfer_to_billing():
    """Transfer to the Billing Agent."""
    # Hand-off signal: the graph routes on the name of the transfer tool the
    # model calls, so the return value is only a fixed acknowledgement.
    acknowledgement = "TRANSFER_TRIGGERED"
    return acknowledgement

@tool
def transfer_to_closer():
    """Transfer to the Closer Agent."""
    # Hand-off signal: the graph routes on the name of the transfer tool the
    # model calls, so the return value is only a fixed acknowledgement.
    acknowledgement = "TRANSFER_TRIGGERED"
    return acknowledgement

# --- 3. The Main Class ---

class MyLLM(BaseLLM):
    """Custom LLM that answers through a LangGraph multi-agent workflow.

    A Triage agent hands the request to a Shipping or Billing agent. Those
    agents may call their domain tools and hand off to each other or to a
    Closer agent, which writes the final reply. Hand-offs are expressed as
    ``transfer_to_*`` tool calls. Every tool call (hand-off or real) is
    executed by a single shared ``ToolNode``, so each assistant ``tool_call``
    is answered by a matching ``ToolMessage`` before the next model call.
    """

    # Hand-off tool name -> graph node that takes over the conversation.
    _TRANSFER_TARGETS = {
        "transfer_to_shipping": "Shipping_Agent",
        "transfer_to_billing": "Billing_Agent",
        "transfer_to_closer": "Closer_Agent",
    }

    def __init__(self):
        # The LLM is chosen via the "LLM_ID" project variable; the fallback
        # "your_llm_id" looks like a placeholder to be replaced.
        llm_id = dataiku.get_custom_variables().get("LLM_ID", "your_llm_id")
        client = dataiku.api_client()
        project = client.get_default_project()
        self.llm = project.get_llm(llm_id).as_langchain_chat_model()

        # Tool set per agent. The transfer_* tools define which agents each
        # agent is allowed to hand the conversation to.
        self.triage_tools = [transfer_to_shipping, transfer_to_billing]
        self.shipping_tools = [check_order_status, transfer_to_billing, transfer_to_closer]
        self.billing_tools = [upgrade_subscription, transfer_to_closer]

        self.app = self.build_graph()

    def build_graph(self):
        """Build and compile the agent graph.

        Flow: START -> Triage_Agent. Triage, Shipping and Billing go to the
        ``tools`` node whenever they emit tool calls; otherwise the run ends.
        After ``tools``, the run continues at the agent named by the first
        hand-off call. If there was no hand-off, it returns to the agent that
        requested the tools. Closer_Agent always ends the run.

        Returns:
            The compiled graph. No checkpointer is attached: ``process``
            starts a new conversation on every call, so an in-memory
            checkpointer would only keep every run's state forever.
        """
        # SystemMessage is not imported at module level.
        from langchain_core.messages import SystemMessage

        class State(TypedDict):
            # Conversation history; add_messages appends each node's output.
            messages: Annotated[List[BaseMessage], add_messages]
            # Agent node that produced the latest AI turn, so the results of
            # its real (non-hand-off) tools are routed back to it.
            active_agent: str

        transfer_targets = self._TRANSFER_TARGETS

        def make_agent(node_name, system_prompt, tools=None):
            """Return a node running the LLM with *system_prompt* and optional *tools*."""
            # Bind once at build time instead of on every node execution.
            model = self.llm.bind_tools(tools) if tools else self.llm

            def agent_node(state):
                # The persona is sent as a system message on each call but is
                # not stored in the shared history.
                messages = [SystemMessage(content=system_prompt)] + state["messages"]
                response = model.invoke(messages)
                return {"messages": [response], "active_agent": node_name}

            return agent_node

        def route_after_agent(state) -> Literal["tools", "__end__"]:
            """Send any tool call (hand-offs included) to the ToolNode, else end."""
            last_message = state["messages"][-1]
            if getattr(last_message, "tool_calls", None):
                return "tools"
            return END

        def route_after_tools(
            state,
        ) -> Literal["Triage_Agent", "Shipping_Agent", "Billing_Agent", "Closer_Agent"]:
            """Choose the agent that continues once the ToolNode has run.

            The first hand-off call in the AI message whose calls were just
            executed decides the target. Without a hand-off, control returns
            to the agent that requested the tools.
            """
            # NOTE: every return value must appear in the Literal annotation
            # above; LangGraph uses it as the set of allowed destinations.
            caller = next(
                (msg for msg in reversed(state["messages"]) if isinstance(msg, AIMessage)),
                None,
            )
            if caller is not None:
                for call in caller.tool_calls:
                    target = transfer_targets.get(call["name"])
                    if target is not None:
                        return target
            return state.get("active_agent", "Triage_Agent")

        builder = StateGraph(state_schema=State)

        builder.add_node(
            "Triage_Agent",
            make_agent(
                "Triage_Agent",
                "You are the Triage Router. "
                "Route the user to Shipping (orders) or Billing (plans). "
                "If unsure or both, prioritize Shipping.",
                self.triage_tools,
            ),
        )
        builder.add_node(
            "Shipping_Agent",
            make_agent(
                "Shipping_Agent",
                "You are the Shipping Agent. Check order status if needed. "
                "If the user also wants to upgrade, call transfer_to_billing. "
                "Otherwise, call transfer_to_closer.",
                self.shipping_tools,
            ),
        )
        builder.add_node(
            "Billing_Agent",
            make_agent(
                "Billing_Agent",
                "You are the Billing Agent. Upgrade plans. Then transfer to Closer.",
                self.billing_tools,
            ),
        )
        builder.add_node(
            "Closer_Agent",
            make_agent("Closer_Agent", "You are the Closer. Summarize and say goodbye."),
        )

        # One executor for all tools, hand-offs included, so every tool call
        # gets a ToolMessage. Keyed by name because some hand-off tools
        # appear in several agents' tool sets.
        all_tools = list(
            {t.name: t for t in self.triage_tools + self.shipping_tools + self.billing_tools}.values()
        )
        builder.add_node("tools", ToolNode(all_tools))

        builder.add_edge(START, "Triage_Agent")
        for agent_name in ("Triage_Agent", "Shipping_Agent", "Billing_Agent"):
            builder.add_conditional_edges(agent_name, route_after_agent)
        builder.add_conditional_edges("tools", route_after_tools)
        # The Closer has no tools and always produces the final answer.
        builder.add_edge("Closer_Agent", END)

        return builder.compile()

    @staticmethod
    def _content_to_text(content):
        """Flatten LangChain message content (str or list of parts) to plain text."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    parts.append(str(part.get("text", "")))
            return "".join(parts)
        return str(content)

    def process(self, query, settings, trace):
        """Answer one completion request by running the agent graph.

        Args:
            query: Completion query dict; only the ``content`` of the last
                entry of ``query["messages"]`` is used.
            settings: Completion settings (unused here).
            trace: Trace handle supplied by the caller (unused here).

        Returns:
            ``{"text": ...}`` holding the text of the graph's final message.
        """
        # NOTE: earlier turns in query["messages"] are not forwarded, so each
        # call is an independent conversation.
        prompt = query["messages"][-1]["content"]
        result = self.app.invoke({"messages": [{"role": "user", "content": prompt}]})
        return {"text": self._content_to_text(result["messages"][-1].content)}