import pickle
import re

from backend.models.artifacts import Artifact, ArtifactsMetadata
from backend.models.events import EventKind
from backend.utils.logging_utils import get_logger

# Module-level logger for this file.
logger = get_logger(__name__)

# Keys used to read an event dict's payload and its kind, respectively.
EVENT_DATA = "eventData"
EVENT_KIND = "eventKind"


def get_references(events: list[dict]) -> list:
    """Return the payloads of every ``references`` event, in event order.

    The event kind may be either a plain string or an ``EventKind`` member;
    members are unwrapped to their ``.value`` before comparing, matching how
    ``get_used_agent_ids`` reads event kinds.

    Args:
        events: Event dicts; ``None`` is treated as empty.

    Returns:
        The ``eventData`` of each matching event.

    Raises:
        KeyError: if a ``references`` event has no ``eventData`` key.
    """
    references = []
    for event in events or []:
        kind = event.get(EVENT_KIND)
        # Enum members don't necessarily compare equal to their string value.
        if isinstance(kind, EventKind):
            kind = kind.value
        if kind == "references":
            references.append(event[EVENT_DATA])
    return references


def get_used_tables(reference: dict) -> list[str]:
    """Collect the table names announced in a single references payload.

    Walks ``reference["sources"][*]["items"]`` and, for every ``INFO`` item
    whose ``textSnippet`` starts with ``"Decided to use ["``, extracts each
    ``'table_name': '<name>'`` pair (single or double quotes) from the text.

    Args:
        reference: One references payload. Missing or ``None`` ``sources``
            and ``items`` are treated as empty.

    Returns:
        Unique table names, in the order they were first seen.
    """
    table_name_pattern = re.compile(r"""['"]table_name['"]\s*:\s*['"]([^'"]+)['"]""")
    # A dict is used as an insertion-ordered set: a plain set's iteration order
    # depends on string hash randomization, so the result would vary per process.
    tables: dict[str, None] = {}
    for source in reference.get("sources") or []:
        for item in source.get("items") or []:
            snippet = item.get("textSnippet")
            if (
                item.get("type") == "INFO"
                and isinstance(snippet, str)
                and snippet.startswith("Decided to use [")
            ):
                for match in table_name_pattern.finditer(snippet):
                    tables.setdefault(match.group(1), None)
    return list(tables)


def _normalize_event_kind(event: dict):
    """Return ``event``'s kind, unwrapping ``EventKind`` members to their ``.value``."""
    kind = event.get(EVENT_KIND)
    if isinstance(kind, EventKind):
        kind = kind.value
    return kind


def get_used_agent_ids(events: list[dict]) -> list[str]:
    """
    Extract the IDs of agents involved in a conversation from its events.

    Sources of agent IDs:
    - AGENT_SELECTION: ``eventData["selection"]`` entries supply the
      ``agentName`` -> ``agentId`` mapping. Names may be duplicated, so a name
      maps to every ID registered under it.
    - agent_started: ``eventData["agentId"]`` is taken directly.
    - AGENT_CALLING_AGENT: ``eventData["agentAsToolName"]`` is resolved to all
      agent IDs registered under that name.

    If the last AGENT_SELECTION event selected exactly one agent whose
    ``agentId`` is a string, only that ID is returned.

    Args:
        events: Event dicts; ``None`` is treated as empty.

    Returns:
        Unique agent IDs in first-seen event order.
    """
    # Pass 1: build the name -> IDs map from every selection event, so that a
    # call event can be resolved even if its selection event comes later.
    name_to_ids: dict[str, list[str]] = {}
    selected_agents: list = []
    for ev in events or []:
        if _normalize_event_kind(ev) != "AGENT_SELECTION":
            continue
        # `or []` also covers an explicit `"selection": None`.
        selection = (ev.get(EVENT_DATA) or {}).get("selection") or []
        selected_agents = selection  # the last selection event wins
        for agent in selection:
            agent_name = agent.get("agentName")
            agent_id = agent.get("agentId")
            if isinstance(agent_name, str) and isinstance(agent_id, str):
                name_to_ids.setdefault(agent_name, []).append(agent_id)

    # A single selected agent short-circuits before the other events are read.
    if len(selected_agents) == 1:
        only_id = selected_agents[0].get("agentId")
        if isinstance(only_id, str):
            return [only_id]

    # Pass 2: collect started agents and resolve agent-to-agent calls. A dict
    # keeps IDs unique while preserving a deterministic first-seen order.
    used: dict[str, None] = {}
    for ev in events or []:
        kind = _normalize_event_kind(ev)
        if kind == "agent_started":
            agent_id = (ev.get(EVENT_DATA) or {}).get("agentId")
            if isinstance(agent_id, str):
                used.setdefault(agent_id, None)
        elif kind == "AGENT_CALLING_AGENT":
            agent_name = (ev.get(EVENT_DATA) or {}).get("agentAsToolName")
            if isinstance(agent_name, str):
                for agent_id in name_to_ids.get(agent_name, []):
                    used.setdefault(agent_id, None)
    return list(used)


def get_chart_plans(events: list[dict]) -> list[dict]:
    """Return the payloads of ``chart_plan`` events that actually carry a plan.

    An event is kept when its kind is ``"chart_plan"`` (a plain string or an
    ``EventKind`` member with that value) and its ``eventData`` contains a
    ``"chart_plan"`` key. Events with no kind or no payload are skipped
    instead of raising.

    Args:
        events: Event dicts; ``None`` is treated as empty.

    Returns:
        The matching ``eventData`` objects, in event order.
    """
    plans = []
    for event in events or []:
        kind = event.get(EVENT_KIND)
        if isinstance(kind, EventKind):
            kind = kind.value
        if kind != "chart_plan":
            continue
        data = event.get(EVENT_DATA)
        # Guard against a missing/None payload before the membership test.
        if data and "chart_plan" in data:
            plans.append(data)
    return plans


def extract_artifacts_preview(artifacts: list[Artifact], max_rows: int = 50) -> list[dict]:
    """Build a lightweight preview of ``artifacts``.

    Each artifact becomes a dict holding its ``name``, ``type``,
    ``description`` and ``"preview": True``. Parts that are not ``RECORDS``
    are passed through unchanged (same objects, not copies). ``RECORDS`` parts
    are reduced to their ``columns`` and the first ``max_rows`` rows of ``data``.

    Args:
        artifacts: Artifacts to preview. They are read via ``.get``, so
            dict-like objects are expected. ``None`` is treated as empty.
        max_rows: Maximum number of record rows kept per ``RECORDS`` part.
            The default, 50, is the previous hard-coded limit.

    Returns:
        One preview dict per input artifact, in input order.
    """
    preview = []
    for artifact in artifacts or []:
        parts = []
        for item in artifact.get("parts") or []:
            if item.get("type") != "RECORDS":
                parts.append(item)
                continue
            # Tolerate a RECORDS part with missing records/columns/data.
            records = item.get("records") or {}
            parts.append(
                {
                    "type": "RECORDS",
                    "records": {
                        "columns": records.get("columns", []),
                        "data": (records.get("data") or [])[:max_rows],
                    },
                }
            )
        preview.append(
            {
                "name": artifact.get("name"),
                "type": artifact.get("type"),
                "parts": parts,
                "description": artifact.get("description"),
                "preview": True,
            }
        )
    return preview


def get_artifacts_metadata(artifacts: dict, max_size_mb: float) -> dict:
    """Wrap each artifacts entry in ``ArtifactsMetadata``, previewing oversized ones.

    Each value of ``artifacts`` is a dict that may hold ``size_mb``,
    ``artifacts_id``, ``agentName``, ``agentId``, ``query``, ``has_records``
    and ``artifacts``.

    The entry's size is taken from ``size_mb``. When that is missing or
    ``None``, it is computed with ``get_artifacts_size_mb``. Entries larger
    than ``max_size_mb`` keep only ``extract_artifacts_preview`` of their
    artifacts and get ``preview=True``.

    Args:
        artifacts: Mapping of key -> entry dict.
        max_size_mb: Largest size (in MB) kept in full.

    Returns:
        Mapping of the same keys to ``ArtifactsMetadata``. Returns ``{}`` for
        empty input.
    """
    if not artifacts:
        return {}
    logger.info("Getting artifacts metadata for artifacts, max size %s MB", max_size_mb)
    meta = {}
    for key, entry in artifacts.items():
        entry_artifacts = entry.get("artifacts")
        # Only pickle when no size is recorded. Passing the computation as a
        # .get() default would evaluate it eagerly for every entry.
        size = entry.get("size_mb")
        if size is None:
            size = get_artifacts_size_mb(entry_artifacts)
        too_large = size > max_size_mb
        meta[key] = ArtifactsMetadata(
            # Report the same size that drove the preview decision.
            size_mb=size,
            artifacts_id=entry.get("artifacts_id", ""),
            agentName=entry.get("agentName", ""),
            agentId=entry.get("agentId", ""),
            query=entry.get("query", ""),
            has_records=entry.get("has_records", False),
            artifacts=extract_artifacts_preview(entry_artifacts) if too_large else entry_artifacts,
            preview=too_large,
        )
    return meta


def get_selected_agents(events: list[dict]) -> list[dict]:
    """Return the ``selection`` list of the first ``AGENT_SELECTION`` event.

    The event kind may be a plain string or an ``EventKind`` member. Members
    are unwrapped to their ``.value``, matching ``get_used_agent_ids``.

    Args:
        events: Event dicts; ``None`` or empty yields ``[]``.

    Returns:
        The selection entries. Returns ``[]`` when there is no selection event,
        or when it has no payload or no ``selection``.
    """
    for event in events or []:
        kind = event.get(EVENT_KIND)
        if isinstance(kind, EventKind):
            kind = kind.value
        if kind == "AGENT_SELECTION":
            # `or {}` / `or []` also cover keys that are present but hold None.
            return (event.get(EVENT_DATA) or {}).get("selection") or []
    return []


def get_artifacts_size_mb(artifacts: dict) -> float:
    """Estimate the size of ``artifacts`` in MiB from its pickled byte length.

    This measures the serialized size, not live memory usage.
    """
    bytes_per_mb = 1024 ** 2
    serialized = pickle.dumps(artifacts)
    return len(serialized) / bytes_per_mb


def has_records(artifacts: dict) -> bool:
    """Report whether any artifact has a part whose ``type`` is ``RECORDS``.

    ``None`` artifacts and ``None``/missing ``parts`` are treated as empty.
    """
    return any(
        part.get("type") == "RECORDS"
        for artifact in artifacts or []
        for part in artifact.get("parts") or []
    )
