# === Third-party Libraries ===
import dataiku

# === Dataiku LLM Integration ===
from dataiku.llm.python import BaseLLM

# === Autogen AgentChat Framework ===
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import HandoffTermination, TextMentionTermination
from autogen_agentchat.messages import HandoffMessage
from autogen_agentchat.teams import Swarm
from autogen_agentchat.ui import Console

# === Autogen Model Clients ===
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_core.models import ModelFamily

# === Custom Project Tools ===
from tools import check_order_status, upgrade_subscription

# === Agent Memory Integration ===
from dku_agent_memory import DatasetMemory


class MyLLM(BaseLLM):
    """Custom Dataiku LLM that answers each chat turn with an Autogen swarm.

    Two ``AssistantAgent``s (``shipping_agent`` and ``billing_agent``) share a
    single OpenAI-compatible model client and can hand the conversation to
    each other or back to the user. Between turns, the name of the agent that
    last replied is persisted with ``DatasetMemory`` so that the next user
    message is routed straight back to that agent.
    """

    # OpenAI-compatible LLM endpoint exposed by the MULTIAGENTSINDSS project.
    BASE_URL = "https://design.ds-platform.ondku.net/public/api/projects/MULTIAGENTSINDSS/llms/openai/v1/"
    # Dataset backing DatasetMemory, where the swarm routing state is stored.
    MEMORY_DATASET_NAME = "memory_dataset"
    # NOTE(review): a single fixed key means every conversation reads and
    # writes the same "last_agent" state, so concurrent conversations will
    # clobber each other. Switch to a per-conversation id if `query` or the
    # runtime exposes one.
    CONVERSATION_ID = "CONVERSATION_ID"

    def __init__(self):
        """Build the shared model client, the two agents and the swarm.

        Raises:
            KeyError: if the ``LLM_ID`` project variable is not defined.
            RuntimeError: if no ``api_key`` secret is available (see
                ``_get_api_key``).
        """
        # The model id comes from the Dataiku project variables.
        llm_id = dataiku.get_custom_variables()["LLM_ID"]

        self.model_client = OpenAIChatCompletionClient(
            base_url=self.BASE_URL,
            model=llm_id,
            api_key=self._get_api_key(),
            model_info={
                "family": ModelFamily.GPT_4,
                "function_calling": True,
                "json_output": True,
                "multiple_system_messages": True,
                "structured_output": True,
                "vision": True,
            },
            # Disable parallel tool calls. Handoffs are tool calls
            # (e.g. transfer_to_user), so the model must not issue several
            # transfers/tool calls in the same response.
            parallel_tool_calls=False,
        )

        # Answers order-status questions with check_order_status; can hand
        # off to billing_agent or back to the user.
        shipping_agent = AssistantAgent(
            "shipping_agent",
            model_client=self.model_client,
            handoffs=["billing_agent", "user"],
            tools=[check_order_status],
            system_message="""You are the Shipping Agent. Use tools to check order status.
            Once you have answered the shipping question, REVIEW the user's original request.
            If the user also asked about subscriptions/billing, transfer to billing_agent.
            Answer general message otherwise.
            IMPORTANT: You must always reply to the user with text (e.g., "How can I help you?", "Let me check that") AND THEN you use the transfer_to_user tool.
            """,
        )

        # Handles plan upgrades with upgrade_subscription; can hand off to
        # shipping_agent or back to the user.
        billing_agent = AssistantAgent(
            "billing_agent",
            model_client=self.model_client,
            handoffs=["shipping_agent", "user"],
            tools=[upgrade_subscription],
            system_message="""You are the Billing Agent in charge if subscription and billing requests. 
            Help the user upgrade their plan.
            Our Gold Tier is currently 20% off. Propose it.
            If the user also asked about shipping related questions, transfer to shipping_agent.
            Answer general message otherwise. 
            IMPORTANT: You must always reply to the user with text AND THEN you use the transfer_to_user tool.
            """,
        )

        # Swarm participants. A persisted "last_agent" is only used to
        # resume a conversation when it names one of these.
        self._agent_names = {shipping_agent.name, billing_agent.name}

        # End a run when an agent hands off to the user, or when the text
        # "TERMINATE" is mentioned.
        termination = HandoffTermination(target="user") | TextMentionTermination("TERMINATE")
        self.team = Swarm([shipping_agent, billing_agent], termination_condition=termination)

    @staticmethod
    def _get_api_key():
        """Return the value of the ``api_key`` secret from the Dataiku auth info.

        Raises:
            RuntimeError: if ``get_auth_info(with_secrets=True)`` returns no
                secret whose key is ``"api_key"``.
        """
        auth_info = dataiku.api_client().get_auth_info(with_secrets=True)
        for secret in auth_info["secrets"]:
            if secret["key"] == "api_key":
                return secret["value"]
        raise RuntimeError(
            "No secret named 'api_key' found in the Dataiku auth info; "
            "define it in your user secrets."
        )

    @staticmethod
    def _last_agent_text_message(messages):
        """Return the last ``TextMessage`` in *messages* not authored by the user.

        When the task is a plain string, Autogen records it as a
        ``TextMessage`` from ``"user"``. Skipping it prevents echoing the
        prompt back and storing ``"user"`` as the agent to resume with.

        Raises:
            RuntimeError: if no agent produced a text reply during the run.
        """
        for message in reversed(messages):
            if message.type == "TextMessage" and message.source != "user":
                return message
        raise RuntimeError("The swarm run produced no text reply from an agent.")

    async def aprocess(self, query, settings, trace):
        """Run one chat turn through the swarm and return the agent's reply.

        On the first turn (a single message), the prompt is given to the
        swarm as-is. On later turns, the prompt is sent as a ``HandoffMessage``
        to the agent stored in memory. If no usable stored agent exists, the
        turn falls back to a plain prompt.

        Args:
            query: Completion query. ``query["messages"]`` is the chat history,
                and the last entry's ``content`` is the new user prompt.
            settings: Completion settings (unused).
            trace: Trace object (unused).

        Returns:
            dict: ``{"text": <content of the last agent text message>}``.

        Raises:
            RuntimeError: if no agent replied with text during the run.
        """
        messages = query["messages"]
        prompt = messages[-1]["content"]

        dataset_memory = DatasetMemory(self.MEMORY_DATASET_NAME, self.CONVERSATION_ID)

        last_agent = None
        if len(messages) > 1:
            # Follow-up turn: resume with the agent that answered last.
            state = dataset_memory.get_state() or {}
            last_agent = state.get("last_agent")

        if last_agent in self._agent_names:
            task = HandoffMessage(source="user", target=last_agent, content=prompt)
        else:
            task = prompt

        # NOTE(review): self.team keeps its own state between runs and is
        # never reset here, so a new conversation may inherit context from a
        # previous one. Confirm whether `await self.team.reset()` is wanted
        # on the first turn.
        task_result = await Console(self.team.run_stream(task=task))

        last_message = self._last_agent_text_message(task_result.messages)
        dataset_memory.create_or_update_state({"last_agent": str(last_message.source)})

        return {"text": last_message.content}
    