import asyncio
import json
import threading
import uuid
from copy import deepcopy
from typing import Any, Dict, List, Optional

import dataiku
from dataikuapi.dss.llm import DSSLLMStreamedCompletionFooter

from backend.models.events import EventKind
from backend.utils.conv_utils import get_selected_agents_as_objs, normalise_stream_event
from backend.utils.llm_utils import add_history_to_completion
from backend.utils.logging_utils import extract_error_message, get_logger

# Module-level logger used by the orchestration code in this module.
logger = get_logger(__name__)


class AgentResult:
    """Outcome of running one agent: its accumulated answer, sources, artifacts and error state."""

    def __init__(self, agent_id: str, agent_query: Dict[str, Any], agent_name: str):
        # Which agent was called and the sub-query it was given.
        self.agent_id = agent_id
        self.agent_name = agent_name
        self.agent_query = agent_query
        # Accumulators filled while the agent's stream is consumed.
        self.answer_text = ""
        self.sources: List[Dict[str, Any]] = []
        self.artifacts = []
        # Error state: success flips to False on timeout/error; error_message is then
        # what to_generated_answer reports instead of the (partial) answer text.
        self.success = True
        self.error_message: Optional[str] = None

    def to_generated_answer(self) -> Dict[str, Any]:
        """Return the agent's contribution as a plain dict for the synthesis prompt.

        On failure the answer is replaced by ``error_message``, or by a generic message
        when no (non-empty) error message was recorded.
        """
        if self.success:
            answer = self.answer_text
        else:
            answer = self.error_message or "Error while processing your request"
        return dict(
            agent_id=self.agent_id,
            agent_name=self.agent_name,
            agent_query=self.agent_query,
            agent_answer=answer,
        )


class OrchestratorService:
    """Run several Dataiku agents concurrently and synthesize their answers into one reply.

    Every method is a ``@staticmethod``; the class only groups the orchestration steps.
    Progress is reported through caller-supplied callbacks: ``pcb`` receives UI event dicts
    (``{"eventKind": ..., "eventData": ...}``) and ``tcb`` is only used for the final
    synthesis stream.
    """

    # ---------- Utility helpers ----------
    @staticmethod
    def create_agent_completion(agent_id, query, messages):
        """Create a completion query for one agent, pre-filled with the conversation history.

        Args:
            agent_id: Colon-separated identifier. Segment 0 is used as the project key and
                segment 2 as the agent name (LLM id ``"agent:<segment 2>"``); an id with
                fewer than three segments raises ``IndexError``.
            query: Currently unused. NOTE(review): the agent only sees ``messages``; confirm
                whether the per-agent sub-query is meant to be injected here.
            messages: Conversation history, replayed via ``add_history_to_completion``.

        Returns:
            The completion object returned by ``add_history_to_completion``.
        """
        id_parts = agent_id.split(":")
        client = dataiku.api_client()
        project = client.get_project(id_parts[0])
        llm = project.get_llm("agent:" + id_parts[2])
        completion = llm.new_completion()
        return add_history_to_completion(completion, messages)

    # ---------- Single agent streaming ----------
    @staticmethod
    async def _stream_single_agent(
        *,
        query: Dict[str, Any],
        agent_id: str,
        agent_name: str,
        context: Dict[str, Any],
        messages: List[Dict[str, Any]],
        pcb,  # event callback: pcb(payload: Dict[str, Any]) -> None
        artifacts_meta: Optional[Dict[str, Any]] = None,
        cancel_event: threading.Event,
        stream_timeout_s: Optional[float] = None,
    ) -> AgentResult:
        """Stream one agent to completion, emitting structured events for the frontend.

        Emits ``AGENT_STARTED`` first and, unless the task itself is cancelled, one
        ``AGENT_FINISHED`` at the end with status ``"ok"``, ``"cancelled"``, ``"timeout"`` or
        ``"Failed to get an answer"``; on an error an ``AGENT_ERROR`` event precedes it.
        The Dataiku stream is opened and consumed in a worker thread (``asyncio.to_thread``)
        so the event loop stays free.

        Args:
            query: Sub-query assigned to this agent; written into the last message of the
                history copy handed to ``normalise_stream_event``.
            agent_id: Agent identifier (see ``create_agent_completion``).
            agent_name: Display name used in events.
            context: Passed to ``completion.with_context``.
            messages: Conversation history; not mutated (a deep copy is edited).
            pcb: Event callback. It is also invoked from the worker thread, so it must be
                thread-safe.
            artifacts_meta: Forwarded to ``normalise_stream_event`` for chunk events; a fresh
                dict is used when omitted.
            cancel_event: When set, streaming stops at the next chunk and the agent finishes
                with status ``"cancelled"``.
            stream_timeout_s: Optional overall timeout in seconds; falsy means no timeout.

        Returns:
            An ``AgentResult``; ``success`` is False after a timeout or an error. Exceptions
            raised by the agent call are reported through events and logged, not re-raised.
        """
        if artifacts_meta is None:
            # Fresh dict per call: a shared mutable default would leak state across agents.
            artifacts_meta = {}
        result = AgentResult(agent_id=agent_id, agent_query=query, agent_name=agent_name)
        # asyncio cannot interrupt a thread started by to_thread; this flag tells the worker
        # to stop consuming the stream once this coroutine gives up (timeout, error, cancel),
        # so it no longer calls pcb or mutates `result` afterwards.
        stop_streaming = threading.Event()

        # Tell frontend this agent started (so it can create a dedicated panel/box)
        pcb(
            {
                "eventKind": EventKind.AGENT_STARTED,
                "eventData": {"agentId": agent_id, "agentName": agent_name, "query": query},
            }
        )
        # Rebuild the agent's message history around the query sent to it. An empty history
        # is tolerated instead of raising IndexError outside the error handling below.
        msgs = deepcopy(messages)
        if msgs:
            msgs[-1]["content"] = query

        def should_stop() -> bool:
            """True once the user cancelled or this coroutine stopped waiting for the stream."""
            return cancel_event.is_set() or stop_streaming.is_set()

        def handle_data_event(data: Dict[str, Any]) -> None:
            """Forward one chunk to the event normaliser and accumulate its text/artifacts."""
            normalise_stream_event(
                ev={"chunk": data},
                tcb=lambda x: None,
                pcb=pcb,
                msgs=msgs,
                aid=agent_id,
                aname=agent_name,
                query=query,
                artifacts_meta=artifacts_meta,
            )
            # Per-chunk AGENT_CHUNK events are currently not emitted; the accumulated text is
            # sent with AGENT_FINISHED.
            if "text" in data and isinstance(data["text"], str) and data["text"]:
                result.answer_text += data["text"]
            if "artifacts" in data:
                result.artifacts = data["artifacts"]

        def consume_stream(completion_obj) -> None:
            """Open the agent stream and drain it. Runs in a worker thread."""
            for chunk in completion_obj.execute_streamed():
                if should_stop():
                    break
                data = getattr(chunk, "data", None) or chunk
                if isinstance(chunk, DSSLLMStreamedCompletionFooter):
                    normalise_stream_event(
                        ev={"footer": {"additionalInformation": chunk.data.get("additionalInformation", {})}},
                        tcb=lambda x: None,
                        pcb=pcb,
                        msgs=msgs,
                        aid=agent_id,
                        aname=agent_name,
                        query=query,
                    )
                    maybe_sources = data.get("additionalInformation", {}).get("sources", [])
                    if maybe_sources:
                        result.sources = maybe_sources
                else:
                    handle_data_event(data)

        try:
            completion = OrchestratorService.create_agent_completion(
                agent_id=agent_id,
                query=query,
                messages=messages,
            )
            completion.with_context(context)

            logger.info(
                "Agent Hub, direct agent call, id=[%s], name=[%s], cq=%s, settings=%s",
                agent_id,
                agent_name,
                getattr(completion, "cq", "unknown"),
                getattr(completion, "settings", "unknown"),
            )

            # execute_streamed() is called inside the worker too, so any blocking work it does
            # stays off the event loop and counts toward the timeout.
            streaming = asyncio.to_thread(consume_stream, completion)
            if stream_timeout_s:
                await asyncio.wait_for(streaming, timeout=stream_timeout_s)
            else:
                await streaming

            pcb(
                {
                    "eventKind": EventKind.AGENT_FINISHED,
                    "eventData": {
                        "agentId": agent_id,
                        "agentName": agent_name,
                        "status": "cancelled" if cancel_event.is_set() else "ok",
                        "answer": result.answer_text,
                        "sources": result.sources,
                        "artifacts": result.artifacts,
                    },
                }
            )

        except asyncio.TimeoutError:
            result.success = False
            result.error_message = "Timeout while streaming the agent response"
            pcb(
                {
                    "eventKind": EventKind.AGENT_FINISHED,
                    "eventData": {"agentId": agent_id, "agentName": agent_name, "status": "timeout"},
                }
            )

        except Exception as e:
            logger.exception("Agent Hub, agent call failed, id=[%s], name=[%s]", agent_id, agent_name)
            result.success = False
            result.error_message = "Error while processing your request"
            pcb(
                {
                    "eventKind": EventKind.AGENT_ERROR,
                    "eventData": {"agentId": agent_id, "agentName": agent_name, "message": extract_error_message(str(e))},
                }
            )
            pcb(
                {
                    "eventKind": EventKind.AGENT_FINISHED,
                    "eventData": {"agentId": agent_id, "agentName": agent_name, "status": "Failed to get an answer"},
                }
            )

        finally:
            # Whatever happened, the worker thread must not keep streaming on its own.
            stop_streaming.set()

        return result

    # ---------- Multi-agent orchestration ----------

    @staticmethod
    async def stream_multiple_agents_async(
        llm_id: str,
        agents_queries: Dict[str, Any],
        agents: list[dict],
        context: Dict[str, Any],
        messages: List[Dict[str, Any]],
        tcb,  # text callback for FINAL synthesis only (unchanged)
        pcb,  # event callback for UI
        artifacts_meta,  # type: Dict[str, Any]
        cancel_event: threading.Event,
        *,
        stream_timeout_s: Optional[float] = None,
    ) -> str:
        """Run the selected agents concurrently, then stream a synthesized final answer.

        Only agents whose ``id`` (or ``uaid``) is a key of ``agents_queries`` are run; each one
        reports through its own ``pcb`` event stream. Once all have finished,
        ``generate_final_answer`` is streamed, each event being passed with ``tcb``/``pcb`` to
        ``normalise_stream_event``.

        Returns:
            The synthesized answer text (only the part produced before an error or a
            cancellation), or ``""`` if ``cancel_event`` is set before synthesis starts.
            Synthesis errors are logged, not raised.
        """
        agent_tasks = [
            asyncio.ensure_future(
                OrchestratorService._stream_single_agent(
                    query=agents_queries.get(agent.get("id")) or agents_queries.get(agent.get("uaid")),
                    agent_id=agent.get("id"),
                    agent_name=agent.get("name"),
                    messages=messages,
                    context=context,
                    pcb=pcb,
                    artifacts_meta=artifacts_meta,
                    cancel_event=cancel_event,
                    stream_timeout_s=stream_timeout_s,
                )
            )
            for agent in agents
            if agent.get("id") in agents_queries or agent.get("uaid", "None") in agents_queries
        ]

        results: List[AgentResult] = []
        try:
            for next_finished in asyncio.as_completed(agent_tasks):
                if cancel_event.is_set():
                    break
                results.append(await next_finished)
        finally:
            # Do not leave agent tasks pending when we stop early (cancellation or error):
            # cancel them and let the cancellations settle before the loop moves on.
            unfinished = [task for task in agent_tasks if not task.done()]
            for task in unfinished:
                task.cancel()
            if unfinished:
                await asyncio.gather(*unfinished, return_exceptions=True)

        if cancel_event.is_set():
            return ""

        generated_answers = [r.to_generated_answer() for r in results]

        # --- Synthesis phase ---
        pcb({"eventKind": EventKind.SYNTHESIZING_STARTED, "eventData": {}})

        final_parts: List[str] = []
        stream_id = str(uuid.uuid4())
        answer_stream = OrchestratorService.generate_final_answer(
            llm_id=llm_id,
            agents=agents,
            agents_queries_answers=generated_answers,
            messages=messages,
        )
        try:
            for resp in answer_stream:
                if cancel_event.is_set():
                    break
                normalise_stream_event(ev=resp, tcb=tcb, pcb=pcb, msgs=messages, aname="Agent Hub", stream_id=stream_id)
                chunk = resp.get("chunk") if isinstance(resp, dict) else None
                if isinstance(chunk, dict) and isinstance(chunk.get("text"), str):
                    final_parts.append(chunk["text"])
        except Exception:
            # Callers get the partial answer rather than an exception; log so it is not silent.
            logger.exception("Agent Hub, failed to synthesize the final answer")
        finally:
            # Release the synthesis stream if we stopped before it was exhausted.
            answer_stream.close()

        return "".join(final_parts)

    # ---------- Sync wrapper (safe in any context) ----------

    @staticmethod
    def stream_multiple_agents(
        llm_id,
        sel_agents,
        messages,
        context,
        tcb,
        pcb,
        artifacts_meta,
        cancel_event: threading.Event,
        store=None,
        *,
        stream_timeout_s: Optional[float] = None,
    ) -> str:
        """Synchronous entrypoint that runs the asyncio pipeline.

        ``sel_agents`` entries are dicts; those with an ``"agentId"`` key are resolved via
        ``get_selected_agents_as_objs(store, ...)`` and their optional ``"query"`` is used as
        the agent's sub-query.

        If already inside an event loop (e.g., FastAPI request handler), a fresh loop is run
        in a background thread to avoid 'event loop is running' errors. In both paths the
        result of ``stream_multiple_agents_async`` is returned and its exceptions propagate.
        """
        sel_ids = [a["agentId"] for a in sel_agents if "agentId" in a]
        agents_obj = get_selected_agents_as_objs(store, sel_ids)
        agents_queries = {a["agentId"]: a.get("query", {}) for a in sel_agents if "agentId" in a}

        async def _runner():
            return await OrchestratorService.stream_multiple_agents_async(
                llm_id=llm_id,
                agents_queries=agents_queries,
                agents=agents_obj,
                messages=messages,
                context=context,
                tcb=tcb,
                pcb=pcb,
                artifacts_meta=artifacts_meta,
                cancel_event=cancel_event,
                stream_timeout_s=stream_timeout_s,
            )

        # If no loop is running in this thread, use asyncio.run
        try:
            asyncio.get_running_loop()
            loop_running_here = True
        except RuntimeError:
            loop_running_here = False

        if not loop_running_here:
            return asyncio.run(_runner())

        # A loop is already running in this thread -> start a new loop in a worker thread.
        # The outcome (value or exception) is handed back so errors are not lost in the thread.
        outcome: Dict[str, Any] = {"value": "", "error": None}

        def _thread_target():
            new_loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(new_loop)
                outcome["value"] = new_loop.run_until_complete(_runner())
            except BaseException as exc:  # re-raised in the calling thread below
                outcome["error"] = exc
            finally:
                asyncio.set_event_loop(None)
                new_loop.close()

        worker = threading.Thread(target=_thread_target, daemon=True)
        worker.start()
        worker.join()
        if outcome["error"] is not None:
            raise outcome["error"]
        return outcome["value"]

    # ---------- Final synthesis (streams sync) ----------

    @staticmethod
    def generate_final_answer(llm_id, agents, agents_queries_answers, messages):
        """Stream the synthesis LLM's answer built from the agents' answers.

        The completion uses LLM ``llm_id`` of the client's default project and contains a
        system message (synthesis instructions, agent metadata, and the agents' query/answer
        pairs as a JSON array), the conversation ``messages`` added via
        ``add_history_to_completion``, and a final user instruction.

        Yields:
            ``{"chunk": data}`` for each regular chunk; for the footer chunk,
            ``{"footer": {"additionalInformation": ...}}`` followed by
            ``{"trace_ready": {"trace": ...}}`` (``{}`` when the footer carries no trace).
        """
        client = dataiku.api_client()
        project = client.get_default_project()
        comp = project.get_llm(llm_id).new_completion()

        agents_metadata_json = json.dumps(
            [{"id": a.get("id"), "description": a.get("tool_agent_description")} for a in agents],
            indent=2,
            ensure_ascii=False,
        )

        system_prompt = f"""
        # Role and Guidelines
        You are an assistant that synthesizes information from different agents to provide a final answer to the user query.
        Your role is to read the initial user query and any answers generated by different agents to answer parts of the query or all and provide a final answer.

        Your responsibilities:
        - Read the user query and read answers from agents (based on sub-queries).
        - Understand which agent provided which answer (via `agent_id`).
        - Each agent has its own scope and expertise. Use the answers provided by them as they are and rely on their answers to provide user with full answer. Do not make up your own.
        - Synthesize the information from the answers provided by the different agents to provide a final answer to the user query.
        - Do NOT alter the answers provided by external agents.
        - If you need to provide a different answer, make an additional one and clearly mention it's your own.

        Use the agent metadata to help you understand each agents's scope and its answers.
        Given the initial user query, any possible generated answers, and the context of the conversation, provide a final answer to the initial user query.
        Don't change the answer generated by the agents. If you need to provide a different answer, provide it in a new answer and mention that you are providing a different answer.
        Here is additional metadata that might help you about the available external agents:
        {agents_metadata_json}
        """

        # The answers are inserted with str.format, which never re-interprets braces inside
        # substituted values, so the JSON must NOT be brace-escaped (escaping would send
        # literal "{{" / "}}" to the LLM). ensure_ascii=False keeps non-English text readable.
        gen_queries_answers_json = json.dumps(
            [
                {
                    "agent_query": qa.get("agent_query", "no query provided"),
                    "agent_answer": qa.get("agent_answer", "no answer provided"),
                    "agent_id": qa.get("agent_id", ""),
                }
                for qa in agents_queries_answers or []
            ],
            indent=2,
            ensure_ascii=False,
        )

        final_prompt = r"""
         {system_prompt}
         Generated sub queries and answers by agents based on user query:
         {agents_queries_answers}
         """.format(
            system_prompt=system_prompt,
            agents_queries_answers=gen_queries_answers_json,
        )

        comp.with_message(final_prompt, role="system")
        comp = add_history_to_completion(completion=comp, messages=messages)
        user_prompt = (
            "Based on the system instructions and user query, provide the final answer to the initial user query."
        )
        comp.with_message(user_prompt, role="user")
        logger.info("Final prompt for synthesizing agents answers: %s", final_prompt)

        # Dataiku's execute_streamed() is a synchronous iterator
        for chunk in comp.execute_streamed():
            if isinstance(chunk, DSSLLMStreamedCompletionFooter):
                yield {"footer": {"additionalInformation": chunk.data.get("additionalInformation", {})}}
                trace_data = getattr(chunk, "trace", None)
                yield {"trace_ready": {"trace": trace_data if trace_data else {}}}
            else:
                yield {"chunk": chunk.data}
