import logging
import uuid
from typing import TypedDict, Annotated

import dataiku
from dataiku.llm.python import BaseLLM
from dataikuapi.dss.llm import DSSLLMStreamedCompletionChunk, DSSLLMStreamedCompletionFooter
from langchain_core.tools import StructuredTool
from langgraph.graph.message import add_messages
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage
from langgraph.prebuilt import ToolNode
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition
from dataiku.langchain.dku_llm import DKULLM, DKUChatLLM
from langchain.tools import Tool

from langgraph.checkpoint.memory import InMemorySaver
from langgraph.prebuilt import create_react_agent
from langgraph_swarm import create_handoff_tool, create_swarm


def add(a: int, b: int) -> int:
    """Add two numbers"""
    # The docstring above is intentionally left verbatim: this function is
    # handed to create_react_agent as a tool, and LangChain exposes the
    # docstring to the model as the tool's description.
    total = a + b
    return total


class MyLLM(BaseLLM):
    """Dataiku custom LLM that answers through a LangGraph swarm of two agents.

    Agents:

    * ``Alice`` - the default active agent; owns the ``add`` tool and can hand
      off to the shipping agent.
    * ``Shipping`` - can hand the conversation back to Alice.

    Requests that carry a Dataiku conversation id are run on a checkpointed
    graph keyed by that id, so earlier turns are remembered between calls.
    The checkpointer is an in-process ``InMemorySaver``: its contents are lost
    on restart and nothing in this class prunes it. Requests without a
    conversation id run on a checkpoint-free graph and rely solely on the
    messages supplied in the query.
    """

    # Debug output goes through logging rather than print() so user content
    # is not written to stdout on every request.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Create the underlying chat model, both agents and the compiled swarm.

        Raises:
            KeyError: if the ``LLM_ID`` project custom variable is not defined.
        """
        llm_id = dataiku.get_custom_variables()["LLM_ID"]
        project = dataiku.api_client().get_default_project()
        llm = project.get_llm(llm_id).as_langchain_chat_model()

        alice = create_react_agent(
            llm,
            [
                add,
                create_handoff_tool(
                    agent_name="Shipping",
                    description="Transfer to the shipping specialist for order tracking and delivery issues.",
                ),
            ],
            prompt="You are Alice, an addition expert.",
            name="Alice",
        )

        # NOTE(review): the prompt tells this agent to "use tools to check order
        # status", but its only tool is the hand-off back to Alice.
        shipping_agent = create_react_agent(
            llm,
            [create_handoff_tool(agent_name="Alice", description="Transfer to Alice, she can help with math")],
            prompt="You are the Shipping Agent. Use tools to check order status.",
            name="Shipping",
        )

        workflow = create_swarm([alice, shipping_agent], default_active_agent="Alice")
        # Stateful graph: history (and the swarm's active agent) is persisted
        # per thread_id between process() calls.
        self.app = workflow.compile(checkpointer=InMemorySaver())
        # Stateless graph for requests with no conversation id, so they neither
        # share a thread nor leave orphaned checkpoints behind.
        self._stateless_app = workflow.compile()

    @staticmethod
    def _text_turns(messages):
        """Return the plain-text user/assistant turns of a Dataiku message list."""
        return [
            {"role": m["role"], "content": m["content"]}
            for m in messages
            if m.get("role") in ("user", "assistant") and isinstance(m.get("content"), str)
        ]

    def process(self, query, settings, trace):
        """Answer one completion request with the agent swarm.

        Args:
            query: Dataiku completion query. Reads ``query["messages"]``
                (dicts with ``"role"`` / ``"content"``) and, when present,
                ``query["context"]["dku_conversation_id"]``.
            settings: unused.
            trace: unused.

        Returns:
            ``{"text": ...}`` with the content of the swarm's final message.

        Raises:
            ValueError: if ``query["messages"]`` is empty.
        """
        messages = query["messages"]
        if not messages:
            raise ValueError("query['messages'] is empty; nothing to answer")
        self._logger.debug("QUERY %s", query)

        conv_id = (query.get("context") or {}).get("dku_conversation_id")
        latest = {"role": "user", "content": messages[-1]["content"]}

        if conv_id is None:
            # No conversation to attach to: replay the caller-supplied turns on
            # the stateless graph instead of a shared, hard-coded thread.
            turns = self._text_turns(messages) or [latest]
            response = self._stateless_app.invoke({"messages": turns})
        else:
            # The checkpointer already holds this conversation's earlier turns,
            # so only the newest message is sent.
            config = {"configurable": {"thread_id": str(conv_id)}}
            response = self.app.invoke({"messages": [latest]}, config)

        return {"text": str(response["messages"][-1].content)}
    