import uuid

import dataiku
from backend.config import get_charts_generation_llm_id, get_config, get_enterprise_agents
from backend.constants import PUBLISHED_ZONE
from backend.models.artifacts import ArtifactsMetadata
from backend.models.events import EventKind
from backend.utils.events_utils import extract_artifacts_preview, get_artifacts_size_mb, has_records
from backend.utils.logging_utils import get_logger
from backend.utils.project_utils import get_visual_agent_in_zone
from backend.utils.sources_utils import format_sources_for_ui, get_records_from_artifacts

logger = get_logger(__name__)

# Stream ids for which answer text is currently being streamed.
# normalise_stream_event adds an id right after emitting ANSWER_STREAM_START
# for that stream and removes it when a non-text "event" chunk arrives.
_stream_text_generating = set()

# Reasoning text accumulated per stream: stream_id -> list of text fragments.
# normalise_stream_event appends to it for REASONING artifacts and deletes the
# entry when the stream's footer arrives or when an "event" chunk interrupts
# text generation.
_stream_reasoning_chunks = {}


def _extract_reasoning_text(artifacts):
    """Concatenate the text parts of every REASONING artifact.

    Only parts whose ``type`` is ``"TEXT"`` and that carry a ``"text"`` key
    contribute; fragments are joined in encounter order without a separator.
    Anything that is not a list yields an empty string.
    """
    if not isinstance(artifacts, list):
        return ""
    return "".join(
        fragment["text"]
        for item in artifacts
        if item.get("type") == "REASONING"
        for fragment in item.get("parts", [])
        if fragment.get("type") == "TEXT" and "text" in fragment
    )


def _flush_reasoning_log(stream_id, outcome):
    """Debug-log and drop the reasoning text accumulated for ``stream_id``.

    ``outcome`` is the word inserted after "Stream" in the log line. Does
    nothing when no reasoning was accumulated for this stream.
    """
    chunks = _stream_reasoning_chunks.pop(stream_id, None)
    if chunks is None:
        return
    accumulated_reasoning = "".join(chunks)
    if accumulated_reasoning:
        logger.debug(
            f"Stream {outcome} (stream_id: {stream_id[:8]}...). "
            f"Total reasoning text accumulated: {len(accumulated_reasoning)} characters. "
            f"Full accumulated reasoning: {accumulated_reasoning}"
        )


def _record_reasoning_chunk(artifacts, stream_id):
    """Accumulate (when a stream id is known) and debug-log reasoning text found in ``artifacts``."""
    reasoning_text = _extract_reasoning_text(artifacts)
    if not reasoning_text:
        return
    if not stream_id:
        # Without a stream id there is nowhere to accumulate: log the chunk alone.
        logger.debug(
            f"Reasoning chunk detected but no stream_id provided. "
            f"Chunk length: {len(reasoning_text)} characters. "
            f"Reasoning text: {reasoning_text}"
        )
        return
    chunks = _stream_reasoning_chunks.setdefault(stream_id, [])
    chunks.append(reasoning_text)
    accumulated_reasoning = "".join(chunks)
    logger.debug(
        f"Reasoning chunk appended (stream_id: {stream_id[:8]}...). "
        f"Current chunk length: {len(reasoning_text)}, "
        f"Accumulated length: {len(accumulated_reasoning)} characters. "
        f"Accumulated reasoning: {accumulated_reasoning}"
    )


def _process_artifacts_chunk(artifacts, pcb, msgs, aid, trace, aname, query, artifacts_meta, stream_id):
    """Register an artifacts payload, publish it and trigger chart generation when applicable."""
    has_reasoning = isinstance(artifacts, list) and any(
        artifact.get("type") == "REASONING" for artifact in artifacts
    )
    if has_reasoning:
        _record_reasoning_chunk(artifacts, stream_id)

    # Persist artifacts
    artifacts_id = str(uuid.uuid4())
    size = get_artifacts_size_mb(artifacts)
    # For local testing, max_size can be forced very low (e.g. 0.0001 MB = 0.1 KB)
    # to exercise the preview path.
    max_size = get_config().get("max_artifacts_size_mb", 2)
    with_records = has_records(artifacts)
    oversized = size > max_size
    meta = ArtifactsMetadata(
        size_mb=size,
        artifacts_id=artifacts_id,
        agentName=aname or "Used tool",
        agentId=aid or "",
        query=query or "",
        has_records=with_records,
        artifacts=artifacts,
        preview=oversized,
    )
    artifacts_meta[artifacts_id] = meta

    # Don't send the Downloads event for REASONING artifacts
    if not has_reasoning:
        pcb(
            {
                "eventKind": "artifacts",
                "eventData": {
                    **meta,
                    # Oversized payloads are replaced by a preview in the published event.
                    "artifacts": extract_artifacts_preview(artifacts) if oversized else artifacts,
                },
            },
            store_event=False,
        )

    if oversized:
        logger.info(
            f"Artifacts size {size} MB exceeds the maximum allowed size, skipping visualization generation."
        )
    elif with_records:
        logger.info(f"Artifacts size: {size} MB")
        handle_artifacts(
            artifacts=artifacts,
            artifacts_id=artifacts_id,
            msgs=msgs,
            pcb=pcb,
            trace=trace,
            aname=aname,
            query=query,
        )


def normalise_stream_event(
    ev, tcb, pcb, msgs, aid=None, trace=None, aname=None, query=None, artifacts_meta=None, stream_id=None
):
    """Dispatch one raw stream event to the text and event callbacks.

    Handled shapes are ``{"chunk": {...}}`` (text, ``type == "event"`` or
    ``artifacts`` chunks) and ``{"footer": {...}}``; any other event is ignored.

    Args:
        ev: Raw stream event mapping.
        tcb: Called with the text of every text chunk.
        pcb: Called with ``{"eventKind": ..., "eventData": ...}`` dicts. The
            artifacts event is additionally passed ``store_event=False``.
        msgs: Forwarded to ``handle_artifacts``.
        aid: Agent id stored in the artifacts metadata and passed to
            ``format_sources_for_ui``.
        trace: Forwarded to ``handle_artifacts``.
        aname: Agent name; injected as ``agentName`` into dict event payloads
            that lack it and used to label artifacts and stream starts.
        query: Stored in the artifacts metadata and forwarded to
            ``handle_artifacts``.
        artifacts_meta: Dict receiving ``artifacts_id -> ArtifactsMetadata``.
            A fresh dict is used for the call when omitted, so metadata is
            never accumulated in a default shared between calls.
        stream_id: Key of the per-stream state. When falsy, no state is
            tracked and ``ANSWER_STREAM_START`` is never emitted.
    """
    if artifacts_meta is None:
        artifacts_meta = {}

    if "chunk" in ev:
        chunk = ev["chunk"]
        if "text" in chunk:
            # emit answer_stream_start when text generation starts (first chunk or after interruption)
            if stream_id and stream_id not in _stream_text_generating:
                pcb({"eventKind": EventKind.ANSWER_STREAM_START, "eventData": {"agentName": aname or "Agent Hub"}})
                _stream_text_generating.add(stream_id)
            tcb(chunk["text"])
        elif chunk.get("type") == "event":
            # Reset text generation state when we encounter non-text events (like tool calls)
            if stream_id and stream_id in _stream_text_generating:
                _stream_text_generating.remove(stream_id)
                _flush_reasoning_log(stream_id, "ended")
            data = chunk.get("eventData") or {}
            # Some events carry a non-dict payload (e.g. a plain error string):
            # those are forwarded untouched instead of failing on item assignment.
            if aname and isinstance(data, dict) and "agentName" not in data:
                data["agentName"] = aname
            pcb({"eventKind": chunk["eventKind"], "eventData": data})
        elif "artifacts" in chunk:
            _process_artifacts_chunk(
                chunk["artifacts"], pcb, msgs, aid, trace, aname, query, artifacts_meta, stream_id
            )
    elif "footer" in ev:
        # Tolerate a null footer / additionalInformation / sources.
        footer = ev["footer"] or {}
        sources = (footer.get("additionalInformation") or {}).get("sources") or []
        if sources:
            sources = format_sources_for_ui(sources, aid)
            pcb({"eventKind": "references", "eventData": {"sources": sources}})
        # Footer indicates end of stream: release all per-stream state so the
        # module-level trackers do not grow with every finished stream.
        if stream_id:
            _stream_text_generating.discard(stream_id)
            _flush_reasoning_log(stream_id, "completed")
    # elif "trace_ready" in ev:
    #     trace_data = ev["trace_ready"].get("trace")
    #     if trace_data:
    #         pcb({"eventKind": "TRACE", "eventData": {"trace": trace_data}})


def handle_artifacts(artifacts, artifacts_id, msgs, pcb, trace=None, aname=None, query=None):
    """Request a chart plan for every record set found in ``artifacts``.

    Nothing is generated unless the configured visualization mode is
    ``"AUTO"``, records were extracted and a chart-generation LLM id is set.
    For each record set a ``GENERATING_CHART`` event is published, followed by
    either a ``CHART_PLAN`` event or a ``CHART_PLAN_ERROR`` event.
    """
    # Imported inside the function, as in the original (presumably to avoid a circular import — confirm).
    from backend.services.visualization_service import VisualizationService

    mode = get_config().get("visualization_generation_mode")
    records_by_artifact = get_records_from_artifacts(artifacts)
    llm_id = get_charts_generation_llm_id()
    row_limit = 50  # TODO make it configurable
    logger.info(f"Visualization mode: {mode}, graph_llm_id: {llm_id}")

    if mode != "AUTO" or not records_by_artifact or not llm_id:
        return

    # One chart is attempted per record set of each artifact.
    for artifact_pos in range(len(artifacts)):
        record_sets = records_by_artifact[artifact_pos] or []
        for record_pos, record in enumerate(record_sets):
            pcb({"eventKind": EventKind.GENERATING_CHART, "eventData": {}})

            chart_plan = None
            failure = None
            try:
                chart_plan = VisualizationService.prepare_chart_plan(
                    llm_id=llm_id,
                    messages=msgs,
                    data=record["data"][:row_limit],
                    columns=record["columns"],
                    trace=trace,
                )
            except Exception as err:
                logger.exception(f"Failed to generate graph {err}")
                failure = err

            if not chart_plan:
                # Covers both a raised exception and an empty plan.
                error_event = {
                    "eventKind": EventKind.CHART_PLAN_ERROR,
                    "eventData": f"Failed to get chart {failure if failure else ''}",
                }
                pcb(error_event)
                continue

            plan_payload = {
                "chart_plan": chart_plan,
                "source_name": aname or "Used tool",
                "artifacts_id": artifacts_id,
                "artifact_index": artifact_pos,
                "records_index": record_pos,
                "query": query,
            }
            pcb({"eventKind": EventKind.CHART_PLAN, "eventData": plan_payload})


def get_selected_agents_as_objs(store, sel_ids) -> list[dict]:
    """Resolve the selected agent ids into agent dicts.

    Enterprise agents whose ``id`` is in ``sel_ids`` are returned as-is, ahead
    of user agents. A selected user agent is kept only when it has a
    ``published_version`` and a visual agent exists in the published zone of
    its project; its ``id`` is rewritten to
    ``"<project_key>:agent:<visual agent id>"``, the original id is kept under
    ``uaid`` and ``name`` comes from the published version.
    """
    from flask import g

    current_user = getattr(g, "authIdentifier", None)
    enterprise_agents = get_enterprise_agents(current_user)
    client = dataiku.api_client()
    user_agents = store.get_all_agents()

    selected = [agent for agent in enterprise_agents if agent.get("id") in sel_ids]

    for user_agent in user_agents:
        agent_id = user_agent["id"]
        if agent_id not in sel_ids:
            continue

        # MUST have published version
        if not user_agent.get("published_version"):
            logger.warning(f"Agent {agent_id} has no published version, skipping")
            continue

        is_owner = user_agent.get("owner") == current_user
        logger.debug(f"Agent {agent_id}: is_owner={is_owner}, use_published={True}")

        # The user agent id is used as the project key.
        project_key = agent_id
        published_agent = get_visual_agent_in_zone(client.get_project(project_key), PUBLISHED_ZONE)
        if not published_agent:
            logger.warning(f"Agent {agent_id} has no visual agent in published zone, skipping")
            continue

        entry = {**user_agent}
        entry["id"] = f"{project_key}:agent:{published_agent.id}"
        # include uaid to keep original id needed later in orchestrator service
        entry["uaid"] = agent_id
        entry["name"] = user_agent["published_version"].get("name", "Unnamed Agent")
        selected.append(entry)

    return selected
