import dataiku
from dataiku.llm.python import BaseLLM
from dataikuapi.dss.llm import DSSLLMStreamedCompletionChunk, DSSLLMStreamedCompletionFooter
import os
from typing import TypedDict, Literal
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END


class AgentState(TypedDict):
    """Shared state flowing through the support graph.

    The caller seeds ``query`` when invoking the compiled graph; every other
    key is written by one of the nodes below as the graph runs.
    """

    query: str  # user question, supplied by the graph's caller
    category: Literal["technical", "billing"]  # written by router_node
    draft_answer: str  # written by tech_agent_node / billing_agent_node
    final_response: str  # written by final_qa_agent_node


def router_node(state: AgentState):
    """
    Classifies the user query to determine which sub-agent to use.

    Asks the module-level ``llm`` chat model to label ``state["query"]`` and
    returns a partial state update ``{"category": "billing"}`` or
    ``{"category": "technical"}``.
    """
    print("--- ROUTER: Classifying Query ---")
    query = state["query"]
    
    prompt = f"""
    Analyze the following user query and categorize it as either 'technical' or 'billing'.
    Return ONLY the word 'technical' or 'billing'.
    
    Query: {query}
    """
    response = llm.invoke([HumanMessage(content=prompt)])
    category = response.content.strip().lower()
    
    # Fallback logic if LLM hallucinates: anything that does not mention
    # "billing" (including extra words or punctuation) is routed to technical.
    if "billing" in category:
        return {"category": "billing"}
    else:
        return {"category": "technical"}

def tech_agent_node(state: AgentState):
    """Draft a troubleshooting answer for a query routed as technical.

    Reads ``state["query"]``, sends a technical-support prompt to the
    module-level ``llm`` chat model (which must be bound before the graph
    runs) and returns the model's raw text as ``{"draft_answer": ...}``.
    """
    print("--- SUB-AGENT: Tech Support ---")
    user_query = state["query"]

    prompt = f"""
    You are a technical support expert. Provide a troubleshooting solution for:
    {user_query}
    """
    conversation = [HumanMessage(content=prompt)]
    reply = llm.invoke(conversation)
    draft = reply.content
    return {"draft_answer": draft}

def billing_agent_node(state: AgentState):
    """Draft an answer for a query routed as a billing/refund inquiry.

    Reads ``state["query"]``, sends a billing-specialist prompt to the
    module-level ``llm`` chat model (which must be bound before the graph
    runs) and returns the model's raw text as ``{"draft_answer": ...}``.
    """
    print("--- SUB-AGENT: Billing Support ---")
    user_query = state["query"]

    prompt = f"""
    You are a billing specialist. Address this payment/invoice inquiry:
    {user_query}
    """
    conversation = [HumanMessage(content=prompt)]
    reply = llm.invoke(conversation)
    draft = reply.content
    return {"draft_answer": draft}

def final_qa_agent_node(state: AgentState):
    """Polish the sub-agent's draft into the final customer-facing reply.

    Whichever sub-agent ran, its output is read from ``state["draft_answer"]``,
    rewritten by the module-level ``llm`` chat model (which must be bound
    before the graph runs) and returned as ``{"final_response": ...}``.
    """
    print("--- FINAL AGENT: QA & Formatting ---")
    draft_text = state["draft_answer"]
    
    prompt = f"""
    You are a Quality Assurance manager. 
    Review the following draft answer. Rewrite it to be extremely polite, 
    professional, and ensure it ends with "Thank you for choosing OmniCorp."
    
    Draft Answer: {draft_text}
    """
    conversation = [HumanMessage(content=prompt)]
    reply = llm.invoke(conversation)
    polished = reply.content
    return {"final_response": polished}

# --- 4. Define the Routing Logic ---

def decide_next_node(state: AgentState) -> Literal["tech_agent", "billing_agent"]:
    """Map the router's category onto the name of the next graph node.

    A category of exactly ``"billing"`` selects ``"billing_agent"``; any other
    value selects ``"tech_agent"``.
    """
    return "billing_agent" if state["category"] == "billing" else "tech_agent"

class MyLLM(BaseLLM):
    """Dataiku custom LLM that answers support questions with a LangGraph workflow.

    Each query is classified by a router node, handled by either a technical
    or a billing sub-agent, and then polished by a common final QA agent.
    """

    def __init__(self):
        # The target LLM id comes from the project's custom variables; the
        # fallback is a placeholder that must be replaced in real use.
        llm_id = dataiku.get_custom_variables().get("LLM_ID", "your_llm_id")
        client = dataiku.api_client()
        project = client.get_default_project()

        # The graph's node functions are module-level and look up ``llm`` as a
        # global, so the chat model must be bound at module scope; a plain
        # local here would leave every node raising NameError when invoked.
        # NOTE(review): all instances share this one global binding.
        global llm
        llm = project.get_llm(llm_id).as_langchain_chat_model()
        self.llm = llm

        self.app = self._build_graph()

    @staticmethod
    def _build_graph():
        """Wire router -> (tech | billing) sub-agent -> final QA agent and compile."""
        workflow = StateGraph(AgentState)

        # Add Nodes
        workflow.add_node("router", router_node)
        workflow.add_node("tech_agent", tech_agent_node)
        workflow.add_node("billing_agent", billing_agent_node)
        workflow.add_node("final_agent", final_qa_agent_node)

        # Set Entry Point
        workflow.set_entry_point("router")

        # Add Conditional Edges (The Split)
        workflow.add_conditional_edges(
            "router",             # From this node
            decide_next_node,     # Run this logic
            {                     # Map logic output to actual nodes
                "tech_agent": "tech_agent",
                "billing_agent": "billing_agent",
            },
        )

        # Add Normal Edges (The Merge)
        # Both sub-agents feed into the same final agent
        workflow.add_edge("tech_agent", "final_agent")
        workflow.add_edge("billing_agent", "final_agent")

        # End the graph
        workflow.add_edge("final_agent", END)

        return workflow.compile()

    def process(self, query, settings, trace):
        """Answer the last message of ``query`` by running the support graph.

        :param query: completion query; only ``query["messages"][-1]["content"]``
            is used.
        :param settings: unused here.
        :param trace: unused here.
        :returns: ``{"text": <final polished answer>}``.
        """
        prompt = query["messages"][-1]["content"]

        # The graph state is AgentState: seed it with ``query`` and read back
        # ``final_response``, which the final QA agent writes.
        result = self.app.invoke({"query": prompt})
        return {"text": str(result["final_response"])}
    