import asyncio
import base64
import datetime
import json
import re
import threading
import time
import uuid
from contextlib import closing
from typing import Any, Dict, List

import dataiku
from flask import g, request

from backend.agents.user_agent import UserAgent
from backend.config import (
    agents_as_tools,
    get_conversation_vision_llm,
    get_enable_upload_documents,
    get_enterprise_agents,
    get_extraction_mode,
    get_global_system_prompt,
    get_quota_images_per_conversation,
    get_uploads_managedfolder_id,
)
from backend.constants import ExtractionMode
from backend.models.events import EventKind
from backend.services.derived_documents_service import (
    count_conversation_images,
    get_structured_documents,
    predict_new_image_count,
    process_derived_documents,
)
from backend.services.orchestrator_service import OrchestratorService
from backend.utils.conv_utils import normalise_stream_event
from backend.utils.events_utils import (
    get_chart_plans,
    get_references,
    get_selected_agents,
    get_used_agent_ids,
    get_used_tables,
)
from backend.utils.logging_utils import extract_error_message, get_logger
from backend.utils.user_utils import get_agent_context
from backend.utils.utils import (
    build_agent_connect,
    call_dss_agent_full_conversation,
    get_user_and_ent_agents,
    get_user_base_llm,
    select_agents,
)

# Module logger, named after this module and obtained through the project's logging helper.
logger = get_logger(__name__)


class ConversationError(Exception):
    """Signal that a conversation request was rejected.

    Raised when payload validation fails or when a business rule is violated.
    """


def build_actions(events: list[dict]) -> dict:
    """Derive the UI "actions" attached to an assistant message from its event log.

    Args:
        events: Event dicts collected while streaming the assistant reply.

    Returns:
        An empty dict when ``events`` is empty. Otherwise a dict that may contain:
          * ``"chart_plans"``: chart plans extracted from the events, when there are any.
          * ``"stories"``: ``{"enable_stories": False}`` unless exactly one agent was
            selected, exactly one reference was produced, and that agent is an
            enterprise agent with a ``stories_workspace``. In that case it carries the
            workspace, the tables used, and the agent's example queries.
    """
    actions: dict = {}
    if not events:
        return actions

    chart_plans = get_chart_plans(events)
    if chart_plans:
        actions["chart_plans"] = chart_plans

    references = get_references(events)
    sel_agents = get_selected_agents(events)
    logger.info("Building message actions, references: %s", len(references))
    logger.info("Building message actions, selected agents: %s", len(sel_agents))

    if len(references) != 1 or len(sel_agents) != 1:
        # Only enable stories for one agent and one tool with references
        actions["stories"] = {"enable_stories": False}
        return actions

    tables = get_used_tables(references[0])
    logger.info("Building message actions, tables: %s", tables)
    agent_id = sel_agents[0]["agentId"]
    eas = get_enterprise_agents(g.get("authIdentifier", None))
    agent = next((ea for ea in eas if ea.get("id") == agent_id), None)
    logger.info("Building message actions, agent: %s", agent)
    if agent and agent.get("stories_workspace"):
        actions["stories"] = {
            "enable_stories": True,
            "tables": tables,
            "agent_id": agent_id,
            "stories_workspace": agent.get("stories_workspace"),
            "agent_short_example_queries": agent.get("agent_short_example_queries", []),
            "agent_example_queries": agent.get("agent_example_queries", []),
        }
    else:
        # Not an enterprise agent, or no stories workspace configured. Report stories as
        # disabled explicitly, like the branch above, instead of omitting the key.
        actions["stories"] = {"enable_stories": False}
    return actions


class ConversationService:
    def __init__(self, store, draft_mode: bool = False):
        """Bind the service to a persistence store and cache the user-agent catalogue.

        Args:
            store: Persistence backend. It must expose ``get_all_agents()`` plus the
                conversation/message methods used by this class.
            draft_mode: Stored on the instance; per the original note, True only in test contexts.
        """
        self.store = store
        self.draft_mode = draft_mode  # True only in test contexts
        self.project = dataiku.api_client().get_default_project()
        all_agents = self.store.get_all_agents()
        self._user_agents_list = all_agents
        # Index user agents by id for O(1) lookups.
        self._user_agents_dict = {agent["id"]: agent for agent in all_agents}

    def _missing_agents(self, agent_ids: List[str], user: str) -> List[str]:
        """Return the ids in ``agent_ids`` that match neither an enterprise agent nor a user agent.

        The order of ``agent_ids``, and any duplicates in it, are preserved in the result.
        """
        enterprise_ids = {ea["id"] for ea in get_enterprise_agents(user)}
        known_ids = enterprise_ids.union(self._user_agents_dict)
        return list(filter(lambda aid: aid not in known_ids, agent_ids))

    @staticmethod
    def _blank_conversation(conv_id: str, agent_ids: List[str]) -> dict:
        return {"id": conv_id, "title": None, "messages": [], "agentIds": agent_ids}

    def _generate_title(self, messages: list[dict], trace=None) -> str:
        """Ask the user's base LLM for a short title summarising ``messages``.

        The prompt is built in this order:
          1. A system instruction with today's date and the constrained title format.
          2. The conversation turns.
          3. A final user request for the title.
        The raw LLM text is then normalised with ``self._clean_title``.

        Args:
            messages: Conversation messages; each must have ``"content"`` and ``"role"``.
            trace: Optional parent trace. When given, a ``generate_title`` subspan records
                the inputs, the LLM trace, and the raw and cleaned titles.

        Returns:
            The cleaned title.

        Raises:
            Whatever the LLM client raises. The subspan is still closed in that case.
        """
        # Create a subspan if trace is provided
        title_span = trace.subspan("generate_title") if trace else None
        if title_span:
            title_span.begin(int(time.time() * 1000))

        try:
            today = datetime.date.today().strftime("%Y-%m-%d")
            llm_id = get_user_base_llm(self.store)

            # Add inputs to trace
            if title_span:
                title_span.inputs["llm_id"] = llm_id
                title_span.inputs["message_count"] = len(messages)
                title_span.inputs["date"] = today

            system_msg = (
                "You are a titling assistant. "
                f"Today is {today}. "
                "Return a single short title (≤7 words) that summarizes the conversation. "
                "Plain text ONLY: no prefixes (e.g., 'Title:'), no quotes, no emojis, "
                "no code fences, no extra lines, and no trailing punctuation."
            )
            user_msg = "Generate a short title for this conversation. Output only the title text."

            comp = self.project.get_llm(llm_id).new_completion()
            comp = comp.with_message(system_msg, role="system")
            for m in messages:
                comp = comp.with_message(m["content"], role=m["role"])
            # Put the request last. The final turn is then the instruction rather than the
            # previous assistant reply, which a model may simply continue. It also avoids two
            # consecutive user turns (instruction + first user message).
            comp = comp.with_message(user_msg, role="user")

            logger.info(
                "Generating conversation title:\nLLM id=[%s]\nCompletion Query=%s\nCompletion Settings=%s\n",
                llm_id,
                json.dumps(comp.cq, indent=2, sort_keys=True),
                json.dumps(comp.settings, indent=2, sort_keys=True),
            )

            resp = comp.execute()
            raw_title = resp.text or ""
            logger.info(
                "Conversation title LLM response:\nLLM id=[%s]\nCompletion Response=%s",
                llm_id,
                raw_title,
            )

            if title_span and resp.trace:
                title_span.append_trace(resp.trace)

            cleaned_title = self._clean_title(raw_title)

            # Add outputs to trace
            if title_span:
                title_span.outputs["raw_title"] = raw_title
                title_span.outputs["cleaned_title"] = cleaned_title
            return cleaned_title
        finally:
            # Always close the subspan, even when the LLM call fails, so the trace stays well-formed.
            if title_span:
                title_span.end(int(time.time() * 1000))

    def _build_history(
        self,
        agent_ids: list[str],
        conv_messages: list[dict],
        *,
        documents_map: dict[str, dict] | None = None,
    ) -> list[dict]:
        """Convert stored conversation messages into the prompt history sent to the LLM or agents.

        Each message becomes ``{"role", "content"}``. When a message has attachments whose
        ``document_path`` is present in ``documents_map``, the extracted documents are added
        under ``"__documents__"`` as ``{"name", "snapshots", "text"}`` entries. Attachments
        with no path, or with no matching document, are skipped.

        A system message is prepended as follows:
          * Exactly one agent, and it is an enterprise agent: today's date plus the agent's
            ``agent_system_instructions``. If the agent has no instructions, no system
            message is added.
          * Otherwise: today's date plus the global system prompt.

        Args:
            agent_ids: Agents active for the conversation.
            conv_messages: Stored messages (``role``, ``content``, optional ``attachments``).
            documents_map: Extracted documents keyed by document path.

        Returns:
            A new list of prompt messages.
        """
        history: list[dict] = []
        docs_map = documents_map or {}
        for msg in conv_messages:
            entry = {
                "role": msg.get("role"),
                "content": msg.get("content", ""),
            }
            doc_entries = []
            for att in msg.get("attachments") or []:
                path = att.get("document_path")
                if not path:
                    continue
                doc_info = docs_map.get(path)
                if not doc_info:
                    continue
                doc_entries.append(
                    {
                        "name": doc_info.get("name") or att.get("name") or path.rsplit("/", 1)[-1],
                        "snapshots": doc_info.get("snapshots") or [],
                        "text": doc_info.get("text") or "",
                    }
                )
            if doc_entries:
                entry["__documents__"] = doc_entries
            history.append(entry)

        today = datetime.date.today().strftime("%Y-%m-%d")
        if len(agent_ids) == 1:
            agent_id = agent_ids[0]
            # Try to get enterprise agent first
            enterprise_agents = get_enterprise_agents(g.get("authIdentifier", None))
            enterprise_agent = next((ea for ea in enterprise_agents if ea.get("id") == agent_id), None)
            if enterprise_agent:
                # Enterprise agent: use agent_system_instructions
                agent_system_instructions = enterprise_agent.get("agent_system_instructions", "")
                if not agent_system_instructions:
                    return history  # no system instructions
                sys_msg = {"role": "system", "content": f"Today is {today}. {agent_system_instructions}"}
                return [sys_msg] + history

        # Guard against a missing prompt, which would otherwise render as the literal "None".
        global_prompt = get_global_system_prompt() or ""
        # Put a space between the date sentence and the prompt, unless the prompt already starts
        # with whitespace. Previously they were glued together as "Today is ....You are ...".
        separator = " " if global_prompt and not global_prompt[0].isspace() else ""
        sys_msg = {"role": "system", "content": f"Today is {today}.{separator}{global_prompt}"}
        return [sys_msg] + history

    # ------------------------------------------------------------------ #
    #  Streaming send  (WebSocket)
    # ------------------------------------------------------------------ #
    def stream_message(self, payload: dict, emit, cancel_event: threading.Event):
        """Handle one user turn over the WebSocket and stream the assistant reply.

        Non-draft flow:
          1. Validate the payload.
          2. Load or create the conversation.
          3. Record the user message and its attachments.
          4. Extract document context. Extraction switches to text-only when the predicted
             image count would exceed the per-conversation quota.
          5. Select agents and stream the reply.
          6. Persist both messages.
          7. Regenerate the title and conversation metadata when needed.
        Draft payloads are delegated to ``_stream_draft``.

        Args:
            payload: Client payload (``conversationId``, ``userMessage``, ``agentIds``,
                ``attachments``, ``modeAgents``, ``selectedLLM``, ``draftMode``).
            emit: Callback ``emit(event_name, data)`` used to push ``chat_error``,
                ``chat_event``, ``chat_token`` and ``chat_end`` to the client.
            cancel_event: Set by the caller to stop streaming. When it is set, no ``chat_end``
                is sent.

        Unexpected exceptions are not caught here. The ``finally`` block still:
          * closes the trace;
          * attaches the trace to the stored assistant message, if one was created;
          * for non-draft requests that carried a conversation id, emits ``chat_end`` so the
            client stops waiting.
        """
        logger.info("stream_message called from %s with Payload: %r", request.remote_addr, payload)
        from dataiku.llm.tracing import new_trace

        trace = new_trace("DKU_AGENT_HUB_QUERY")
        trace.begin(int(time.time() * 1000))

        assistant_msg_dict = None
        conv = {}
        events: list[dict] = []
        agent_ids = None
        draft_mode = False
        # Pre-bind every name the ``finally`` block reads. If parsing the payload below raises
        # (e.g. a non-string ``userMessage``), the ``finally`` must not fail with an
        # UnboundLocalError that hides the original exception.
        conv_id = None
        agents_enabled = None
        selected_llm = ""
        try:
            draft_mode = bool(payload.get("draftMode"))
            conv_id = payload.get("conversationId")
            user_msg = (payload.get("userMessage") or "").strip()
            agents_enabled = payload.get("modeAgents", True)
            selected_llm = payload.get("selectedLLM", "")
            payload_attachments = payload.get("attachments") or []
            agent_ids = self._filter_active_agents(payload.get("agentIds") or [], allow_draft=draft_mode)
            logger.info(f"_filter_active_agents called, {agent_ids}")
            # -- validation -------------------------------------------------
            # Allow empty user message if attachments are provided
            if not conv_id or (not user_msg and not payload_attachments):
                emit("chat_error", {"error": "Missing conversationId or userMessage"})
                return
            miss = self._missing_agents(agent_ids, g.get("authIdentifier"))
            if miss:
                emit("chat_error", {"error": f"Agent(s) not found: {', '.join(miss)}"})
                return

            # -- draft-mode path (legacy full-rewrite) ----------------------
            if draft_mode:
                self._stream_draft(conv_id, agent_ids, user_msg, emit, cancel_event, trace)
                return

            # -------- non-draft incremental path ---------------------------
            conv = self.store.get_conversation(conv_id) or {}
            if not conv:
                self.store.ensure_conversation_exists(
                    conv_id, agent_ids, agents_enabled, selected_llm
                )
                conv = self._blank_conversation(conv_id, agent_ids)
            old_agent_ids = list(conv.get("agentIds", []))

            # user msg
            user_msg_dict = {
                "id": str(uuid.uuid4()),
                "role": "user",
                "content": user_msg,
                "eventLog": [],
                "modeAgents": agents_enabled,
                "selectedLLM": selected_llm,
            }
            # attachments provided by the client (fresh upload metadata),
            if payload_attachments:
                user_msg_dict["attachments"] = payload_attachments
                # Store attachments in message_attachments table (extraction_mode will be set after processing)
                self.store.insert_or_update_message_attachments(user_msg_dict["id"], json.dumps(payload_attachments))
                # Mark these attachments as 'attached' in derived_documents (if relevant)
                docs = self.store.get_derived_documents(conv_id)
                payload_doc_names = {a["document_name"] for a in payload_attachments if "document_name" in a}
                for doc in docs:
                    if (
                        doc.get("document_name") in payload_doc_names
                        and doc.get("metadata", {}).get("status") == "uploaded"
                    ):
                        doc["metadata"]["status"] = "processed"
                        self.store.upsert_derived_document(
                            conv_id, doc["document_name"], doc["document_path"], doc["metadata"]
                        )

            conv["messages"].append(user_msg_dict)

            # -------- Document Extraction and Quota Tracking --------
            # Shared events buffer (persisted in assistant message eventLog)
            # Get configuration parameters
            enable_document_upload = get_enable_upload_documents()
            image_quota = get_quota_images_per_conversation()
            extraction_mode = get_extraction_mode()

            # Only process documents if document upload is enabled
            if payload_attachments and enable_document_upload:
                # Count existing images in the conversation
                current_image_count = count_conversation_images(self.store, conv_id)

                # Initialize quota_exceeded flag (will be set to True if quota is exceeded)
                quota_exceeded = False

                # Predict how many new images would be generated from attachments (only if screenshots mode)
                predicted_new_images = 0
                if extraction_mode == ExtractionMode.PAGES_SCREENSHOTS.value:
                    folder_id = get_uploads_managedfolder_id()
                    predicted_new_images = predict_new_image_count(payload_attachments, folder_id)
                    logger.info(
                        f"Predicted {predicted_new_images} new images from {len(payload_attachments)} attachments"
                    )

                    # Calculate total predicted image count
                    predicted_total = current_image_count + predicted_new_images

                    # Check if quota would be exceeded with new documents
                    if predicted_total > image_quota:
                        extraction_mode = ExtractionMode.PAGES_TEXT.value
                        quota_exceeded = True
                        logger.warning(
                            f"Image quota WOULD BE EXCEEDED for conversation {conv_id}. "
                            f"Predicted total: {predicted_total} > Quota: {image_quota}. "
                            f"Switching to TEXT-ONLY mode for all documents."
                        )
                        # Emit event to notify frontend that extraction mode changed to text-only
                        evt_mode_change = {
                            "eventKind": EventKind.EXTRACTION_MODE_CHANGED,
                            "eventData": {
                                "extractionMode": ExtractionMode.PAGES_TEXT.value,
                                "reason": "quota_exceeded",
                            },
                        }
                        emit("chat_event", {**evt_mode_change, "conversationId": conv_id})
                        events.append(evt_mode_change)
                    else:
                        logger.info(
                            f"Document extraction mode: {extraction_mode} for conversation {conv_id}. "
                            f"Current images: {current_image_count}/{image_quota}, "
                            f"Predicted new: {predicted_new_images}, "
                            f"Predicted total: {predicted_total}/{image_quota}"
                        )

                # Generate derived context (screenshots or text only) based on extraction mode
                # Emit document analysis event (live) and persist it in events buffer
                evt_da = {
                    "eventKind": EventKind.DOCUMENT_ANALYSIS,
                    "eventData": {"count": len(payload_attachments)},
                }
                emit("chat_event", {**evt_da, "conversationId": conv_id})
                events.append(evt_da)
                try:
                    process_derived_documents(
                        self.store,
                        conv_id,
                        g.get("authIdentifier") or "",
                        payload_attachments,
                        extraction_mode=extraction_mode,
                    )
                    # Store extraction_mode as JSON with quota_exceeded flag
                    self.store.insert_or_update_message_attachments(
                        user_msg_dict["id"],
                        json.dumps(payload_attachments),
                        extraction_mode=extraction_mode,
                        quota_exceeded=quota_exceeded
                        if extraction_mode == ExtractionMode.PAGES_TEXT.value
                        else False,
                    )
                    # Mark all attachments as ready after successful processing
                    for att in payload_attachments:
                        att["uploadStatus"] = "ready"

                    # Log the actual new image count
                    new_image_count = count_conversation_images(self.store, conv_id)
                    images_added = new_image_count - current_image_count

                    if extraction_mode == ExtractionMode.PAGES_SCREENSHOTS.value:
                        logger.info(
                            f"Added {images_added} images to conversation {conv_id}. "
                            f"Total: {new_image_count}/{image_quota} "
                            f"(predicted: {predicted_new_images}, actual: {images_added})"
                        )
                    else:
                        logger.info(
                            f"Processed {len(payload_attachments)} documents in TEXT-ONLY mode. "
                            f"Total images unchanged: {new_image_count}/{image_quota}"
                        )

                    evt_da = {
                        "eventKind": EventKind.DOCUMENT_ANALYSIS_COMPLETED,
                        "eventData": {"conv_id": conv_id, "documents": payload_attachments},
                    }
                    emit("chat_event", {**evt_da, "conversationId": conv_id})
                    events.append(evt_da)
                except Exception:
                    logger.exception("Failed to process derived documents for conversation %s", conv_id)

            # Select based on settings
            sel_agents, justification = select_agents(self.store, agent_ids, conv["messages"], trace=trace)
            sel_ids = [a["agentId"] for a in sel_agents]
            if justification and not agents_as_tools():
                evt = {
                    "eventKind": EventKind.AGENT_SELECTION,
                    "eventData": {"justification": justification, "selection": sel_agents},
                }
                # Tag the live event with the conversation id like every other chat_event.
                emit("chat_event", {**evt, "conversationId": conv_id})
                events.append(evt)
            elif len(sel_agents):
                evt = {
                    "eventKind": EventKind.AGENT_SELECTION,
                    "eventData": {"justification": "User selected", "selection": sel_agents},
                }
                # Persisted only; user-made selections are not emitted live.
                events.append(evt)
            user_login = g.get("authIdentifier") or ""
            documents_map: dict[str, dict] = {}
            try:
                documents_map = get_structured_documents(self.store, conv_id, user_login)
            except Exception:
                logger.exception("Failed to build structured prompt for conversation %s", conv_id)

            # prompt
            history = self._build_history(
                agent_ids,
                conv["messages"],
                documents_map=documents_map,
            )

            # assistant streaming
            final_reply, artifacts_meta = self._stream_reply(
                conv_id, sel_agents, history, emit, events, cancel_event, trace
            )
            # TODO maybe could optimize and return these directly instead of recomputing them
            actions = build_actions(events)
            used_ids = get_used_agent_ids(events)
            assistant_msg_dict = {
                "id": str(uuid.uuid4()),
                "role": "assistant",
                "content": final_reply,
                "eventLog": events,
                "actions": actions,
                "artifactsMetadata": artifacts_meta,
                "selectedAgentIds": sel_ids,
                "usedAgentIds": used_ids,
                "modeAgents": agents_enabled,
                "selectedLLM": selected_llm,
            }
            conv["messages"].append(assistant_msg_dict)
            # atomic write
            self.store.append_messages(conv_id, [user_msg_dict, assistant_msg_dict])
            # meta update if needed
            title_generated = False
            if conv["title"] is None or (conv["title"].strip() == "Untitled"):
                try:
                    conv["title"] = self._generate_title(conv["messages"], trace=trace)
                    title_generated = True
                except Exception as e:
                    # If title generation fails (e.g., broken LLM connection), keep existing title
                    logger.exception(f"Failed to generate title for conversation {conv_id}: {e}")
                    # Keep existing title (None or "Untitled"), don't set title_generated
                    title_generated = False
            if (
                title_generated
                or old_agent_ids != agent_ids
                or conv.get("agents_enabled") != agents_enabled
                or conv.get("llm_id") != selected_llm
            ):
                self.store.update_conversation_meta(
                    conv_id,
                    title=conv["title"],
                    agent_ids=agent_ids,
                    agents_enabled=agents_enabled,
                    llm_id=selected_llm,
                )
        finally:
            # End tracing
            trace.end(int(time.time() * 1000))

            # If we have an assistant message, update it with the trace info
            if assistant_msg_dict is not None:
                # Add the trace data to the assistant message dictionary.
                assistant_msg_dict["trace"] = trace.to_dict() if trace else {}
                # Update the message in the store using the update_message function.
                self.store.update_message(assistant_msg_dict["id"], {"trace": assistant_msg_dict["trace"]})

            # Emit chat_end so the frontend stops waiting, even if an error occurred. It is sent
            # whenever the payload carried a conversation id, including after the validation
            # errors above. It is not sent for cancelled requests, or for draft requests
            # (_stream_draft emits its own chat_end).
            if conv_id and not cancel_event.is_set() and not draft_mode:
                # Reload conversation from database to get latest messages with extraction_mode (including quota_exceeded)
                # This ensures extraction_mode JSON is parsed and included in messages sent to frontend
                fresh_conv = self.store.get_conversation(conv_id) or conv
                title = fresh_conv.get("title") if fresh_conv else conv.get("title") if conv else 'New Conversation'
                messages = fresh_conv.get("messages", []) if fresh_conv else conv.get("messages", []) if conv else []

                emit(
                    "chat_end",
                    {
                        "agentIds": agent_ids or [],
                        "conversationId": conv_id,
                        "title": title,
                        "messages": messages,
                        "modeAgents": agents_enabled if agents_enabled is not None else True,
                        "selectedLLM": selected_llm or "",
                        "hasEventLog": bool(events),
                    },
                )

    # ------------------------------------------------------------------ #
    #  Draft-mode helper (legacy logic, no optimisation)
    # ------------------------------------------------------------------ #
    def _stream_draft(
        self, conv_id: str, agent_ids: list[str], user_msg: str, emit, cancel_event: threading.Event, trace
    ):
        """
        Run one turn of a draft (agent-test) chat.

        Draft chats keep the legacy behaviour. The whole conversation object, keyed by the first
        agent id, is loaded (or created blank), extended with the user and assistant messages,
        and written back in full on every turn. Every id in ``agent_ids`` is queried directly,
        with no agent-selection step. ``chat_end`` is emitted unless ``cancel_event`` is set.
        """
        logger.debug("_stream_draft: conv=%s agents=%s user_msg=%r", conv_id, agent_ids, user_msg)
        draft_key = agent_ids[0]
        conv = self.store.get_draft_conversation(draft_key) or self._blank_conversation(conv_id, agent_ids)
        turn_messages = conv["messages"]
        turn_messages.append({"id": str(uuid.uuid4()), "role": "user", "content": user_msg, "hasEventLog": False})

        history = self._build_history(agent_ids, turn_messages)

        # Every draft agent is queried with the raw user message; names come from the user-agent cache.
        targets = []
        for aid in agent_ids:
            agent_name = self._user_agents_dict.get(str(aid), {}).get("name")
            targets.append({"agentId": aid, "query": user_msg, "agentName": agent_name})

        events: list[dict] = []
        final_reply, artifacts_meta = self._stream_reply(
            conv_id, targets, history, emit, events, cancel_event, trace
        )
        # TODO maybe could optimize and return these directly instead of recomputing them
        actions = build_actions(events)
        used_ids = get_used_agent_ids(events)
        assistant_entry = {
            "id": str(uuid.uuid4()),
            "role": "assistant",
            "content": final_reply,
            "eventLog": events,
            "actions": actions,
            "artifactsMetadata": artifacts_meta,
            "selectedAgentIds": agent_ids,
            "usedAgentIds": used_ids,
            "trace": trace.to_dict() if trace else {},
        }
        turn_messages.append(assistant_entry)
        conv["lastUpdated"] = datetime.datetime.utcnow().isoformat()
        self.store.upsert_draft_conversation(draft_key, conv)

        if cancel_event.is_set():
            return
        emit(
            "chat_end",
            {
                "agentIds": agent_ids,
                "conversationId": conv_id,
                "title": conv["title"],
                "messages": conv["messages"],
                "hasEventLog": bool(events),
            },
        )

    # ---------- streaming low-level -----------------------------------
    def _stream_reply(self, convId, sel_agents, history, emit, events, cancel_event: threading.Event, trace):
        """Stream an assistant reply through the backend that matches the agent selection.

        Backend choice:
          * no agent selected: the plain base LLM (``_stream_base_llm``);
          * one agent: ``_stream_single``;
          * several agents: ``_stream_agent_connect`` when agents are exposed as tools,
            otherwise ``OrchestratorService.stream_multiple_agents``.

        Tokens are forwarded to the client as ``chat_token``. Events are tagged with ``convId``,
        emitted as ``chat_event`` and, by default, appended to ``events``.

        Returns:
            ``(final_text, artifacts_meta)``. ``final_text`` is the value returned by the chosen
            backend; ``artifacts_meta`` is the dict the backend may populate.
        """
        final = ""
        artifacts_meta = {}

        def token_cb(tok):
            nonlocal final
            final += tok
            emit("chat_token", {"token": tok, "conversationId": convId})

        def push_event(ev_dict: dict, store_event: bool = True):
            ev_dict["conversationId"] = convId
            if store_event:
                events.append(ev_dict)
            emit("chat_event", ev_dict)

        sel_ids = [agent.get("agentId") for agent in sel_agents]
        logger.info(f"_stream_reply  called conv: {convId}, sel_ids: {sel_ids}")

        if not sel_ids:
            final = self._stream_base_llm(history, token_cb, push_event, cancel_event)
        elif len(sel_ids) == 1:
            final = self._stream_single(
                sel_agents[0], history, token_cb, push_event, artifacts_meta, cancel_event, convId, trace
            )
        elif agents_as_tools():
            final = self._stream_agent_connect(
                sel_ids, history, token_cb, push_event, artifacts_meta, cancel_event, convId, trace
            )
        else:
            agent_context = get_agent_context(g.get("authIdentifier"), convId)
            final = OrchestratorService.stream_multiple_agents(
                llm_id=get_user_base_llm(self.store),
                sel_agents=sel_agents,
                messages=history,
                context=agent_context,
                tcb=token_cb,
                pcb=push_event,
                artifacts_meta=artifacts_meta,
                cancel_event=cancel_event,
                store=self.store,
            )
        return final, artifacts_meta

    # ---- concrete streamers -----------------------------------------
    def _stream_base_llm(self, msgs, tcb, pcb, cancel_event: threading.Event, emit=None, conv_id=None):
        generic_error = "Error during llm call"
        final = ""
        try:
            # Check if we should use Vision LLM for screenshots
            # Determine if any messages have documents with snapshots
            has_snapshots = False
            for message in msgs:
                docs = message.get("__documents__") or []
                for doc in docs:
                    snapshots = doc.get("snapshots") or []
                    if snapshots:
                        has_snapshots = True
                        break
                if has_snapshots:
                    break
            
            # Use Vision LLM if extraction mode is pagesScreenshots and there are snapshots
            extraction_mode = get_extraction_mode()
            if extraction_mode == ExtractionMode.PAGES_SCREENSHOTS and has_snapshots:
                llm_id = get_conversation_vision_llm()
            else:
                llm_id = get_user_base_llm(self.store)
            
            comp = self.project.get_llm(llm_id).new_completion()
            folder = dataiku.Folder(get_uploads_managedfolder_id())
            inline_cache: Dict[str, str] = {}

            def append_documents(documents: list[dict]) -> None:
                nonlocal comp
                for doc in documents:
                    doc_name = doc.get("name") or "Document"
                    text = (doc.get("text") or "").strip()
                    snapshots = doc.get("snapshots") or []
                    # If text content exists, prefer a single system message with the content
                    if text:
                        comp = comp.with_message(f"[Document: {doc_name}]\n{text}", role="user")
                        continue
                    # If a single image snapshot, add filename then the image once
                    if len(snapshots) == 1:
                        snap = snapshots[0]
                        path = snap.get("screenshot_path")
                        inline_data = None
                        if path:
                            inline_data = inline_cache.get(path)
                            if inline_data is None:
                                try:
                                    with folder.get_download_stream(path) as stream:
                                        image_bytes = stream.read()
                                    inline_data = base64.b64encode(image_bytes).decode("utf-8")
                                    inline_cache[path] = inline_data
                                except Exception as err:  # pragma: no cover
                                    logger.warning("Failed to retrieve snapshot %s: %s", path, err)
                        msg = comp.new_multipart_message()
                        msg.with_text(f"[Image: {doc_name}]")
                        if inline_data:
                            msg.with_inline_image(inline_data)
                        msg.add()
                        continue
                    # Default: multi-page (pdf/docx/pptx) → per-page captions and images
                    if snapshots:
                        msg = comp.new_multipart_message()
                        has_parts = False
                        for idx, snap in enumerate(snapshots, start=1):
                            caption = f"Document {doc_name}, page {snap.get('page') or idx}"
                            msg.with_text(caption)
                            path = snap.get("screenshot_path")
                            inline_data = None
                            if path:
                                inline_data = inline_cache.get(path)
                                if inline_data is None:
                                    try:
                                        with folder.get_download_stream(path) as stream:
                                            image_bytes = stream.read()
                                        inline_data = base64.b64encode(image_bytes).decode("utf-8")
                                        inline_cache[path] = inline_data
                                    except Exception as err:  # pragma: no cover
                                        logger.warning("Failed to retrieve snapshot %s: %s", path, err)
                            if inline_data:
                                msg.with_inline_image(inline_data)
                            has_parts = True
                        if has_parts:
                            msg.add()

            for message in msgs:
                comp = comp.with_message(message.get("content", ""), role=message.get("role"))
                docs = message.get("__documents__") or []
                if docs:
                    append_documents(docs)

            from backend.utils.logging_utils import sanitize_messages_for_log

            logger.info("Streaming using plain LLM - prompt messages: %s", sanitize_messages_for_log(msgs))
            # Emit thinking event before starting stream (this will be stored in events list via pcb)
            pcb({"eventKind": EventKind.AGENT_THINKING, "eventData": {"agentName": "Agent Hub"}})
            # Track if we've emitted the responding event (only on first text token)
            responding_emitted = False

            # wrap in closing(...) so that .close() (and thus HTTP teardown) happens on break
            with closing(comp.execute_streamed()) as stream:
                for chunk in stream:
                    if cancel_event.is_set():
                        break
                    data = chunk.data
                    if "text" in data:
                        # Emit responding event on first text token (replaces thinking)
                        # This must happen BEFORE token callback to ensure event is processed first
                        if not responding_emitted:
                            # Emit ANSWER_STREAM_START event (will be stored via pcb)
                            pcb({"eventKind": EventKind.ANSWER_STREAM_START, "eventData": {"agentName": "Agent Hub"}})
                            responding_emitted = True
                            logger.info("Emitted ANSWER_STREAM_START event for Agent Hub")
                        tcb(data["text"])
                    elif data.get("type") == "event":
                        pcb({"eventKind": data["eventKind"], "eventData": data.get("eventData", {})})
                    if "text" in data:
                        final += data["text"]
        except Exception as e:
            # Catch any exceptions during streaming and emit error message as text and error event
            error_message = f"{generic_error}: {extract_error_message(str(e))}"
            logger.exception(f"Exception in _stream_base_llm: {error_message}")
            # Emit error message as text so it appears in the assistant response
            tcb(generic_error)
            # Emit BASE_LLM_ERROR event for the event log (similar to how user_agent emits AGENT_ERROR)
            pcb(
                {
                    "eventKind": EventKind.BASE_LLM_ERROR,
                    "eventData": {
                        "message": error_message,
                        "agentName": "Agent Hub",
                    },
                }
            )
            final = generic_error
        return final

    def _stream_single(self, agent_info, msgs, tcb, pcb, artifacts_meta, cancel_event: threading.Event, conv_id, trace):
        """Stream the answer of exactly one selected agent.

        The agent id is looked up in ``self._user_agents_dict`` (keyed by ``str(agentId)``).
        When found it is run as a ``UserAgent``; otherwise it is called as an enterprise
        agent through ``call_dss_agent_full_conversation``. Every streamed event is passed
        to ``normalise_stream_event`` together with ``tcb``/``pcb``, and consumption stops
        as soon as ``cancel_event`` is set.

        Args:
            agent_info: Selected-agent dict; ``agentId``, ``agentName`` and ``query`` are read.
            msgs: Conversation messages sent to the agent.
            tcb: Callback receiving answer text.
            pcb: Callback receiving ``{"eventKind": ..., "eventData": ...}`` event dicts.
            artifacts_meta: Forwarded to ``normalise_stream_event``.
            cancel_event: When set, the agent stream is no longer consumed.
            conv_id: Conversation id, forwarded to the agent context / enterprise call.
            trace: Optional trace; for user agents a new one is created from the agent id
                when this is falsy.

        Returns:
            The concatenation of every ``ev["chunk"]["text"]`` received, or the
            "not published" notice when a user agent must run published but has no
            published version.
        """
        final = ""
        aid = agent_info.get("agentId")
        aname = agent_info.get("agentName")
        ua = self._user_agents_dict.get(str(aid))
        user = g.get("authIdentifier")
        agent_query = agent_info.get("query", "")
        if ua:
            logger.info(f"Calling user agent {aid} {aname}")
            pcb({"eventKind": EventKind.CALLING_AGENT, "eventData": {"agentId": aid, "agentName": aname}})
            is_owner = ua.get("owner") == user

            # Owners in draft mode may run the agent with use_published=False; everybody
            # else must use the published version, so refuse when there is none.
            if self.draft_mode and is_owner:
                use_published = False
            else:
                if not ua.get("published_version"):
                    tcb("This agent has not been published yet.")
                    return "This agent has not been published yet."
                use_published = True

            ua_obj = UserAgent(ua, use_published=use_published)
            from dataiku.llm.tracing import new_trace

            async def _run():
                nonlocal final
                stream_id = str(uuid.uuid4())
                # NOTE(review): aprocess_stream receives a fresh trace when `trace` is falsy,
                # while normalise_stream_event still receives the original `trace` (possibly
                # None) — unlike the enterprise branch below. Confirm this is intended.
                async for ev in ua_obj.aprocess_stream(
                    query={"messages": msgs, "context": get_agent_context(user, conv_id)},
                    settings={},
                    trace=trace or new_trace(aid),
                ):
                    if cancel_event.is_set():
                        break
                    normalise_stream_event(
                        ev=ev,
                        tcb=tcb,
                        pcb=pcb,
                        msgs=msgs,
                        aid=aid,
                        trace=trace,
                        aname=aname,
                        query=agent_query,
                        artifacts_meta=artifacts_meta,
                        stream_id=stream_id,
                    )
                    if "chunk" in ev and "text" in ev["chunk"]:
                        final += ev["chunk"]["text"]

            asyncio.run(_run())
        else:
            logger.info(f"Calling Enterprise Agent: {aid}")
            pcb({"eventKind": EventKind.CALLING_AGENT, "eventData": {"agentId": aid, "agentName": aname}})
            stream_id = str(uuid.uuid4())
            events = call_dss_agent_full_conversation(aid, msgs, user, True, conv_id=conv_id, trace=trace)
            try:
                for ev in events:
                    if cancel_event.is_set():
                        break
                    normalise_stream_event(
                        ev=ev,
                        tcb=tcb,
                        pcb=pcb,
                        msgs=msgs,
                        aid=aid,
                        trace=trace,
                        aname=aname,
                        query=agent_query,
                        artifacts_meta=artifacts_meta,
                        stream_id=stream_id,
                    )
                    if "chunk" in ev and "text" in ev["chunk"]:
                        final += ev["chunk"]["text"]
            finally:
                # Close the stream explicitly so that breaking out on cancellation tears it
                # down right away instead of waiting for garbage collection (same reason the
                # base-LLM stream is wrapped in closing(...)). Only done when the returned
                # object actually supports close(), e.g. a generator.
                close = getattr(events, "close", None)
                if callable(close):
                    close()
        return final

    def _stream_agent_connect(
        self, aids, msgs, tcb, pcb, artifacts_meta, cancel_event: threading.Event, conv_id, trace
    ):
        """Stream a multi-agent answer through an Agent Connect orchestrator.

        The orchestrator for ``aids`` is built with ``build_agent_connect`` and its async
        event stream is consumed on a fresh event loop via ``asyncio.run``. Each event is
        handed to ``normalise_stream_event``; consumption stops once ``cancel_event`` is set.

        When ``trace`` is falsy, a ``DKU_AGENT_HUB_QUERY`` trace is created and begun here.
        Whichever trace is in use is ended (millisecond timestamp) when streaming finishes,
        is cancelled, or fails.

        Returns:
            The concatenation of every ``ev["chunk"]["text"]`` received.
        """
        from dataiku.llm.tracing import new_trace

        connector = build_agent_connect(
            self.store, aids, user_agents=self._user_agents_dict, draft_mode=self.draft_mode, conv_id=conv_id
        )

        async def _consume():
            requester = g.get("authIdentifier")
            active_trace = trace
            if not active_trace:
                active_trace = new_trace("DKU_AGENT_HUB_QUERY")
                active_trace.begin(int(time.time() * 1000))
            # NOTE(review): a caller-supplied trace is also ended below although it was not
            # begun here — confirm callers expect this method to close it.
            stream_id = str(uuid.uuid4())
            collected = ""
            try:
                events = connector.aprocess_stream(
                    query={"messages": msgs, "context": get_agent_context(requester, conv_id)},
                    settings={},
                    artifacts_meta=artifacts_meta,
                    trace=active_trace,
                    pcb=pcb,
                )
                async for ev in events:
                    if cancel_event.is_set():
                        break
                    normalise_stream_event(ev=ev, tcb=tcb, pcb=pcb, msgs=msgs, stream_id=stream_id)
                    has_text = "chunk" in ev and "text" in ev["chunk"]
                    if has_text:
                        collected += ev["chunk"]["text"]
            finally:
                active_trace.end(int(time.time() * 1000))
            return collected

        return asyncio.run(_consume())

    # ------------------------------------------------------------------ #
    #  Utility helpers
    # ------------------------------------------------------------------ #
    def _filter_active_agents(self, agent_ids: list[str], allow_draft: bool = False) -> list[str]:
        """Return the ids from ``agent_ids`` the current user may use, preserving order.

        - Enterprise agents returned by ``get_enterprise_agents`` for the current user are
          always kept.
        - User agents (``self._user_agents_dict``) are kept when ``allow_draft`` is set and
          the current user is their owner, or when they have a ``published_version``.
        - Any other id is dropped.
        """
        requester = g.get("authIdentifier")
        known_user_agents = self._user_agents_dict
        enterprise = {entry["id"] for entry in get_enterprise_agents(requester)}

        def _usable(agent_id: str) -> bool:
            if agent_id in enterprise:
                return True
            if agent_id not in known_user_agents:
                return False
            definition = known_user_agents[agent_id]
            owned = definition.get("owner") == requester
            # Draft access for the owner first; otherwise a published version is required.
            return bool((allow_draft and owned) or definition.get("published_version"))

        return [agent_id for agent_id in agent_ids if _usable(agent_id)]

    @staticmethod
    def _clean_title(raw: str) -> str:
        s = (raw or "").strip()
        s = re.sub(r"```[\w\-]*\s*([\s\S]*?)\s*```", r"\1", s).strip()

        for line in s.splitlines():
            if line.strip():
                s = line.strip()
                break
        else:
            s = ""

        s = re.sub(r"^\s*#{1,6}\s*", "", s)  # '#', '##', etc.
        s = re.sub(r"^\s*>+\s*", "", s)  # '>' or '>>'

        s = re.sub(
            r"^\s*(?:title|subject|chat|conversation)\s*[:：\-–]\s*",
            "",
            s,
            flags=re.IGNORECASE,
        )

        s = s.strip(" \"'`*_")

        s = re.sub(r"\s+", " ", s)

        s = re.sub(r"[.:;!?，。；！]+$", "", s)

        return s or "Untitled"
