##############################################################################
# UserAgent – local implementation of a black-box agent callable like a DSS
#             agent but executed locally.
##############################################################################
from __future__ import annotations

import datetime
import json
from typing import Optional

from backend.config import get_default_embedding_llm, get_default_llm_id
from backend.constants import DRAFT_ZONE, PUBLISHED_ZONE
from backend.models.events import EventKind
from backend.utils.logging_utils import extract_error_message, get_logger
from backend.utils.project_utils import get_augmented_llm_in_zone, get_ua_project, get_visual_agent_in_zone
from dataiku.llm.python import BaseLLM
from dataikuapi.dss.llm import (
    DSSLLMStreamedCompletionFooter,
)

logger = get_logger(__name__)


class UserAgent(BaseLLM):
    """
    Wraps a *user* agent (created in the UI) and offers the same process /
    aprocess_stream interface as a DSS LLM agent.

    • Every completion is sent through the visual agent resolved from the
      agent's project and zone (see ``prepare_completion``).

    • NOTE(review): an agent with *no* documents and *no* tools presumably
      behaves as a plain base LLM; that configuration lives in the visual
      agent itself and is not visible from this class.

    Zone-aware: Uses draft zone for owner, published zone for shared users.
    """

    # Text shown to the end user whenever execution fails; the detailed
    # message only goes to the logs / event payload.
    _GENERIC_ERROR = "Error during agent execution"

    # ------------------------------------------------------------------ #
    # Construction
    # ------------------------------------------------------------------ #
    def __init__(self, ua_record: dict, use_published: bool = False):
        """
        Initialize a UserAgent.

        Args:
            ua_record: Agent dictionary from database (shallow-copied).
            use_published: If True, use published zone (for shared users)

        Raises:
            KeyError: if the draft name is requested and ``ua_record`` has no
                ``"name"`` key.
        """
        self.ua = dict(ua_record)
        self.use_published = use_published
        self.base_llm_id: str = self.ua.get("llmid") or get_default_llm_id()
        self.embedding_llm_id: str = get_default_embedding_llm()
        # Both are resolved lazily on first use.
        self._augmented_llm_id: Optional[str] = None
        self._vis_agent = None

        if use_published:
            # A record may be flagged for published use without carrying a
            # published version (``_ensure_vis_agent`` tolerates that too),
            # so fall back to the draft name instead of raising.
            published = self.ua.get("published_version") or {}
            self.name = published.get("name") or self.ua.get("name")
        else:
            self.name = self.ua["name"]

        logger.info(
            "Initialized UserAgent '%s' (base LLM: %s, embedding LLM: %s, zone: %s)",
            self.name,
            self.base_llm_id,
            self.embedding_llm_id,
            PUBLISHED_ZONE if use_published else DRAFT_ZONE,
        )

    # ------------------------------------------------------------------ #
    # Helper – get the appropriate augmented LLM id
    # ------------------------------------------------------------------ #
    def _get_augmented_llm_id(self) -> str:
        """
        Return the augmented LLM id for the active zone, caching the result.

        Raises:
            Exception: if no augmented LLM is found in the zone.
        """
        if self._augmented_llm_id is None:
            project = get_ua_project(self.ua)
            # NOTE(review): the zone here depends on ``use_published`` only,
            # whereas ``_ensure_vis_agent`` also requires a published version
            # to exist - confirm which rule is intended.
            zone = PUBLISHED_ZONE if self.use_published else DRAFT_ZONE

            llm_id = get_augmented_llm_in_zone(project, zone)
            if not llm_id:
                raise Exception(f"Could not find augmented LLM in zone {zone}")

            self._augmented_llm_id = llm_id
            logger.debug("_get_augmented_llm_id: %s", self._augmented_llm_id)

        return self._augmented_llm_id

    # ------------------------------------------------------------------ #
    # Helper – build the inner tool graph (lazy)
    # ------------------------------------------------------------------ #
    def _ensure_vis_agent(self) -> None:
        """
        Resolve the visual agent backing this user agent and cache it.

        Nothing is returned: the result is stored on ``self._vis_agent``.
        The lookup is skipped when a visual agent is already cached.
        """
        if self._vis_agent is not None:
            return

        # Use published version if applicable
        if self.use_published and self.ua.get("published_version"):
            logger.debug("Using published version of user agent")
            zone = PUBLISHED_ZONE
        else:
            logger.debug("Using draft version of user agent")
            zone = DRAFT_ZONE

        project = get_ua_project(self.ua)
        self._vis_agent = get_visual_agent_in_zone(project=project, zone_name=zone)

    def prepare_completion(self, query):
        """
        Build a completion on the visual agent from ``query``.

        A system message carrying today's date is sent first, followed by the
        entries of ``query["messages"]`` (each must provide ``"content"`` and
        ``"role"``). ``query["context"]`` is attached as completion context.

        Args:
            query: Mapping with optional ``"messages"`` and ``"context"`` keys;
                a missing or ``None`` value is treated as empty.

        Returns:
            The completion object, ready to be executed.

        Raises:
            Exception: if no visual agent could be resolved.
        """
        self._ensure_vis_agent()
        if not self._vis_agent:
            raise Exception("No visual agent found for the quick agent. Cannot process query")

        # ``or`` also covers keys that are present but explicitly None.
        user_messages = query.get("messages") or []
        today = datetime.date.today().strftime("%Y-%m-%d")

        # --- Visual agent path ------------------------------
        comp = self._vis_agent.as_llm().new_completion()
        comp = comp.with_message(f"Today is {today}", role="system")
        for m in user_messages:
            comp = comp.with_message(m["content"], role=m["role"])
        comp.with_context(query.get("context") or {})
        return comp

    # ------------------------------------------------------------------ #
    #  synchronous helper (used by StructuredTool wrapper & Conversation)
    # ------------------------------------------------------------------ #
    def process(self, query, settings=None, trace=None):
        """
        Run the agent synchronously.

        Args:
            query: Mapping understood by ``prepare_completion``.
            settings: Unused by this implementation.
            trace: Unused by this implementation.

        Returns:
            A 2-tuple ``(text, {"output": ..., "sources": ...})``. On failure
            ``text`` is a generic error string and ``"output"`` holds the
            detailed error message; no exception is propagated.
        """
        try:
            comp = self.prepare_completion(query)
            result = comp.execute()
            raw = result._raw
            if "additionalInformation" in raw:
                # The key can be present with a None value; a missing sources
                # section must not turn a successful answer into an error.
                sources = (raw.get("additionalInformation") or {}).get("sources", {})
            else:
                sources = raw.get("sources", {})
            return (
                result.text,
                {
                    "output": result.text,
                    "sources": sources,
                },
            )
        except Exception as e:
            # Other execution errors - return error message as text
            error_message = f"{self._GENERIC_ERROR}: {extract_error_message(str(e))}"
            logger.exception(
                "Exception in process() for user agent %s: %s",
                self.ua.get("id", "unknown"),
                error_message,
            )
            return (
                self._GENERIC_ERROR,
                {
                    "output": error_message,
                    "sources": {},
                },
            )

    # ------------------------------------------------------------------ #
    #  async streaming interface
    # ------------------------------------------------------------------ #
    async def aprocess_stream(self, query, settings, trace):
        """
        Streaming variant emitting the same event schema AgentConnect expects.

        Yields ``{"chunk": ...}`` items for streamed data and a final
        ``{"footer": {"additionalInformation": ...}}`` item. Any failure,
        including one while preparing the completion, is reported in-band as
        a generic text chunk, an ``AGENT_ERROR`` event chunk and an empty
        footer rather than raised.

        Args:
            query: Mapping understood by ``prepare_completion``.
            settings: Unused by this implementation.
            trace: Object receiving the footer trace through ``append_trace``;
                may be ``None``, in which case the trace is dropped.
        """
        # NOTE(review): the stream is consumed with a plain ``for`` loop, so
        # if ``execute_streamed`` blocks while waiting for data the event loop
        # is blocked as well - confirm this is acceptable for callers.
        try:
            comp = self.prepare_completion(query)
            # Activate the output trajectory for tracing
            comp.settings["outputTrajectory"] = True

            logger.info(
                "Calling user agent\nproject_key=[%s]\nvisual agent id=[%s]\nname=[%s]\ncompletion_query=%s\nsettings=%s",
                self.ua.get("id", "unknown"),
                self._vis_agent.id,
                self.name,
                json.dumps(comp.cq, indent=2, sort_keys=True),
                json.dumps(comp.settings, indent=2, sort_keys=True),
            )

            # ------------------ Visual agent streaming --------------------
            for chunk in comp.execute_streamed():
                if not isinstance(chunk, DSSLLMStreamedCompletionFooter):
                    yield {"chunk": chunk.data}
                    continue

                additional_info = chunk.data.get("additionalInformation") or {}

                # Bookkeeping happens *before* the footer is yielded: code
                # placed after a yield only runs if the consumer resumes the
                # generator, which it may not do once it has the footer.
                # Todo : Add a function call to persist the output trajectory if needed
                output_trajectory = additional_info.get("trajectory")
                if output_trajectory:
                    logger.debug(
                        "Agent %s execution trajectory: %s",
                        self._vis_agent.id,
                        output_trajectory,
                    )
                trace_data = chunk.trace
                if trace_data and trace is not None:
                    trace.append_trace(trace_data)

                # Extract the trace data from the footer
                yield {"footer": {"additionalInformation": additional_info}}
        except Exception as e:
            # Catch any exceptions during streaming and emit error message as text and error event
            error_message = f"{self._GENERIC_ERROR}: {extract_error_message(str(e))}"
            logger.exception(
                "Exception in streaming for user agent %s: %s",
                self.ua.get("id", "unknown"),
                error_message,
            )
            # Yield error message as text so it appears in the assistant response
            yield {"chunk": {"text": self._GENERIC_ERROR}}
            # Yield error event for the event log
            yield {
                "chunk": {
                    "type": "event",
                    "eventKind": EventKind.AGENT_ERROR,
                    "eventData": {
                        "message": error_message,
                        "agentId": self.ua.get("id", "unknown"),
                        "agentName": self.name,
                    },
                }
            }
            yield {"footer": {"additionalInformation": {}}}
