"""A Langchain Tracer implementation that writes into a Dataiku LLM Mesh trace"""

import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set
from uuid import UUID

from langchain_core.env import get_runtime_environment
from langchain_core.load import dumpd
from langchain_core.messages import BaseMessage
from langchain_core.outputs import LLMResult
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run

from dataiku.llm.tracing import SpanBuilder

logger = logging.getLogger(__name__)

# Transforms the run object into a plain dict.
# This handles serializing inputs and outputs, while properly
# transforming pydantic objects.
#
# The problem is that:
#   - in pydantic v1, pydantic.BaseModel.dict() properly recurses, so you get a pure Python dict that is serializable
#   - in pydantic v2, pydantic.BaseModel.dict() also properly recurses, so you get a pure Python dict that is serializable
#   - but pydantic v2 includes a compatibility layer, pydantic.v1.BaseModel, that is buggy: pydantic.v1.BaseModel.dict()
#     does not recurse and leaves nested BaseModels in place, so the result is not JSON-serializable
#   - and LangChain prefers to use pydantic.v1.BaseModel rather than pydantic.BaseModel
# So we must work around the bug by recursing ourselves to re-serialize each sub-object.
def _run_to_dict(run: Run) -> dict:
    """Convert a LangChain run into a dict made of plain Python containers.

    ``run.dict()`` may leave pydantic models nested inside ``inputs`` and
    ``outputs``. Those models are replaced by their own ``.dict()`` form,
    walking recursively through dicts and lists.

    Args:
        run: The run to convert. Only ``run.dict()``, ``run.inputs`` and
            ``run.outputs`` are used.

    Returns:
        The dict produced by ``run.dict()`` where:
          * ``inputs`` / ``outputs`` no longer contain pydantic models,
          * the ``inputs`` / ``outputs`` keys are removed when the run has none,
          * a non-empty ``outputs["output"]["response_metadata"]`` is removed.
    """
    import pydantic

    model_types = [pydantic.BaseModel]
    try:
        from pydantic.v1 import BaseModel as BaseModelV1

        model_types.append(BaseModelV1)
    except Exception:
        # Best effort, as before: if the pydantic.v1 compatibility layer cannot
        # be loaded, there are simply no such objects to convert.
        pass
    model_types = tuple(model_types)

    def to_plain(value: Any) -> Any:
        # Always returns the converted value, so that callers can safely store
        # the result back into the parent container.
        if isinstance(value, model_types):
            # The dict of a model may itself still contain models: keep going.
            return to_plain(value.dict())
        if isinstance(value, dict):
            # Only existing keys are reassigned, so iterating while assigning is safe.
            for key in value:
                value[key] = to_plain(value[key])
            return value
        if isinstance(value, list):
            for index, item in enumerate(value):
                value[index] = to_plain(item)
            return value
        return value

    run_dict = run.dict()

    for field in ("inputs", "outputs"):
        if isinstance(run_dict.get(field), dict):
            run_dict[field] = to_plain(run_dict[field])

    if run.inputs is None:
        run_dict.pop("inputs", None)
    if run.outputs is None:
        run_dict.pop("outputs", None)

    outputs = run_dict.get("outputs")
    if isinstance(outputs, dict):
        output = outputs.get("output", {})
        if isinstance(output, dict) and output.get("response_metadata", False):
            del output["response_metadata"]

    return run_dict

def dku_span_builder_for_callbacks(callbacks: Any, ignore_missing: bool = False) -> SpanBuilder:
    """
    Returns a DKU Trace span builder corresponding to a LangChain Callbacks object

    This is useful to use in tool methods:

    @tool
    async def where_cat_is_hiding(cat_name: str, callbacks: Callbacks = None) -> str:
        cat_location = None
        with dku_span_builder_for_callbacks(callbacks).subspan("Searching for cat...") as s:
            cat_location = random.choice(["under the bed", "on the shelf"])

        with dku_span_builder_for_callbacks(callbacks).subspan("Verifying cat location...") as s:
            verify_location(cat_location)

        return cat_location


    For this to work, the Callbacks must have been instantiated with a LangchainToDKUTracer

    Args:
        callbacks: The LangChain callbacks object received by the tool.
        ignore_missing: When True, return a standalone ``SpanBuilder("ignored")``
            instead of raising whenever no span builder can be found.

    Returns:
        The span builder registered by the first LangchainToDKUTracer handler for
        ``callbacks.parent_run_id``, or a ``SpanBuilder("ignored")`` placeholder.

    Raises:
        Exception: if ``ignore_missing`` is False and the callbacks have no
            handlers, no LangchainToDKUTracer handler, or no ``parent_run_id``.
        KeyError: if ``ignore_missing`` is False and the tracer has no span
            builder registered for the parent run.
    """

    def not_found(template: str) -> SpanBuilder:
        if ignore_missing:
            return SpanBuilder("ignored")
        # Wrap in a 1-tuple so that a tuple-typed `callbacks` formats correctly
        raise Exception(template % (callbacks,))

    if not hasattr(callbacks, "handlers"):
        return not_found("Callbacks %s don't have handlers")

    for handler in callbacks.handlers:
        if not isinstance(handler, LangchainToDKUTracer):
            continue

        # Only the first LangchainToDKUTracer handler is considered
        if not hasattr(callbacks, "parent_run_id"):
            return not_found("Callbacks %s don't have a parent_run_id")

        parent_key = str(callbacks.parent_run_id)
        if parent_key in handler.run_id_to_span_map:
            return handler.run_id_to_span_map[parent_key]

        # The tracer has not registered a span for the parent run
        if ignore_missing:
            return SpanBuilder("ignored")
        raise KeyError(parent_key)

    return not_found("Callbacks %s don't have a LangchainToDKUTracer handler")


# This is strongly inspired / copied from the LangChain BaseTracer itself
class LangchainToDKUTracer(BaseTracer):
    """

    A LangChain-compatible tracer that logs to the Dataiku LLM Mesh trace object

    Typical usage in a Dataiku Code Agent:


    Sync version:

    def process(self, query, settings, trace):
        tracer = LangchainToDKUTracer(dku_trace=trace)

        agent_executor.invoke({"input": prompt}, config={ "callbacks": [tracer] })


    Async version:

    async def aprocess_stream(self, query, settings, trace):
        tracer = LangchainToDKUTracer(dku_trace=trace)

        async for event in self.agent_executor.astream_events({"input": prompt}, config={ "callbacks": [tracer] }):
            # process events
    """

    def __init__(
        self,
        dku_trace: SpanBuilder,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Create a tracer writing under the given root span builder.

        Args:
            dku_trace: Root span builder; runs without a parent are attached to it.
            tags: Extra tags merged with each run's own tags by ``_get_tags``.
            **kwargs: Forwarded to ``BaseTracer.__init__``.
        """
        super().__init__(**kwargs)

        self.dku_trace = dku_trace
        # Number of span builders created so far
        self.known_runs = 0
        self.tags = tags or []
        self.latest_run: Optional[Run] = None
        # str(run id) -> span builder of that run
        self.run_id_to_span_map: Dict[str, SpanBuilder] = {}
        # str(run id) of the runs whose span is already linked under a parent span
        self.attached_run_ids: Set[str] = set()

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        tags: Optional[List[str]] = None,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[Dict[str, Any]] = None,
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> Run:
        """Build an "llm" run for a chat model call, register it and persist it.

        Returns:
            The newly created run.
        """
        start_time = datetime.now(timezone.utc)
        if metadata:
            kwargs["metadata"] = metadata

        serialized_messages = [[dumpd(msg) for msg in batch] for batch in messages]
        chat_model_run = Run(
            id=run_id,
            parent_run_id=parent_run_id,
            serialized=serialized,
            inputs={"messages": serialized_messages},
            extra=kwargs,
            events=[{"name": "start", "time": start_time}],
            start_time=start_time,
            run_type="llm",
            tags=tags,
            name=name,  # type: ignore[arg-type]
        )
        self._start_trace(chat_model_run)
        self._on_chat_model_start(chat_model_run)
        return chat_model_run

    def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> Run:
        """Complete an LLM run and graft traces carried by the generations.

        Any ``generation_info["trace"]`` found in the response is appended as a
        child of the span of this run.

        Returns:
            The completed run.
        """
        llm_run = self._complete_llm_run(
            response=response,
            run_id=run_id,
        )

        llm_run.run_type = "llm"

        for batch in response.generations:
            for generation in batch:
                info = generation.generation_info
                if info is not None and "trace" in info:
                    span_builder = self._get_span_builder_for_run(llm_run)
                    span_builder.span["children"].append(info["trace"])

        self._end_trace(llm_run)
        self._on_llm_end(llm_run)
        return llm_run

    def _persist_run(self, run: Run) -> None:
        """Keep a copy of the given run as ``latest_run``."""
        # Not too sure what this does, but required by the LangChain API
        self.latest_run = run.copy()

    def _get_tags(self, run: Run) -> List[str]:
        """Get combined tags for a run."""
        combined = set(run.tags or [])
        combined.update(self.tags or [])
        return list(combined)

    def _get_span_builder_for_run(self, run: Run) -> SpanBuilder:
        """Get the DKU trace span builder corresponding to a LangChain run.

        The span builder is created if needed. This method also handles
        attaching it to the hierarchy: under the root trace when the run has no
        parent, else under the span of its parent run when that one is known.
        A run whose parent is not known yet stays detached until a later call.
        """
        run_key = str(run.id)

        if run_key not in self.run_id_to_span_map:
            logger.debug("Creating span builder for run %s (name=%s)", run.id, run.name)
            self.run_id_to_span_map[run_key] = SpanBuilder(name=run.name)
            self.known_runs += 1

        run_span_builder = self.run_id_to_span_map[run_key]
        assert run_span_builder is not None

        if run_key in self.attached_run_ids:
            return run_span_builder

        if run.parent_run_id is None:
            logger.debug("No parent, adding to root trace: %s", self.dku_trace)
            self.dku_trace.span["children"].append(run_span_builder.span)
            self.attached_run_ids.add(run_key)
            return run_span_builder

        parent_key = str(run.parent_run_id)
        if parent_key in self.run_id_to_span_map:
            parent_span_builder = self.run_id_to_span_map[parent_key]
            logger.debug("  Attaching to parent span builder: %s", parent_span_builder.span["name"])
            parent_span_builder.span["children"].append(run_span_builder.span)
            self.attached_run_ids.add(run_key)
        else:
            logger.info("  Parent not yet known, can't attach")

        return run_span_builder

    def _sync_span_from_run(self, run: Run) -> SpanBuilder:
        """Copy timing, inputs and outputs of a run onto its span builder.

        Returns:
            The span builder of the run.
        """
        run_dict = _run_to_dict(run)
        span_builder = self._get_span_builder_for_run(run)

        # Timing is only reported once the run has both ends
        if run.start_time is not None and run.end_time is not None:
            # timestamp() is in seconds, hence the conversion to milliseconds
            span_builder.begin(run.start_time.timestamp() * 1000)
            span_builder.end(run.end_time.timestamp() * 1000)

        if run.inputs is not None:
            span_builder.span["inputs"] = run_dict["inputs"]
        if run.outputs is not None:
            span_builder.span["outputs"] = run_dict["outputs"]

        return span_builder

    def _persist_run_single(self, run: Run) -> None:
        """Persist a run."""
        try:
            logger.debug(
                "Creating Langchain run run_id=%s parent=%s name=%s", run.id, run.parent_run_id, run.name
            )
            span_builder = self._sync_span_from_run(run)
            span_builder.span["attributes"]["ls_run_id"] = str(run.id)
        except Exception:
            # Errors are swallowed by the thread executor so we need to log them here
            logger.exception("Failed persisting LangChain run")
            raise

    def _update_run_single(self, run: Run) -> None:
        """Update a run."""
        try:
            logger.debug(
                "Updating Langchain run run_id=%s parent=%s name=%s", run.id, run.parent_run_id, run.name
            )
            span_builder = self._sync_span_from_run(run)
            if run.error is not None:
                # A failed run reports its error in place of its outputs
                span_builder.span["outputs"] = {"error": run.error}
        except Exception:
            # Errors are swallowed by the thread executor so we need to log them here
            logger.exception("Failed updating LangChain run")
            raise

    def _on_llm_start(self, run: Run) -> None:
        """Persist an LLM run."""
        self._persist_run_single(run)

    def _on_chat_model_start(self, run: Run) -> None:
        """Persist a chat model run."""
        self._persist_run_single(run)

    def _on_llm_end(self, run: Run) -> None:
        """Process the LLM Run."""
        self._update_run_single(run)

    def _on_llm_error(self, run: Run) -> None:
        """Process the LLM Run upon error."""
        self._update_run_single(run)

    def _on_chain_start(self, run: Run) -> None:
        """Process the Chain Run upon start."""
        self._persist_run_single(run)

    def _on_chain_end(self, run: Run) -> None:
        """Process the Chain Run."""
        self._update_run_single(run)

    def _on_chain_error(self, run: Run) -> None:
        """Process the Chain Run upon error."""
        self._update_run_single(run)

    def _on_tool_start(self, run: Run) -> None:
        """Process the Tool Run upon start."""
        self._persist_run_single(run)

    def _on_tool_end(self, run: Run) -> None:
        """Process the Tool Run."""
        self._update_run_single(run)

    def _on_tool_error(self, run: Run) -> None:
        """Process the Tool Run upon error."""
        self._update_run_single(run)

    def _on_retriever_start(self, run: Run) -> None:
        """Process the Retriever Run upon start."""
        self._persist_run_single(run)

    def _on_retriever_end(self, run: Run) -> None:
        """Process the Retriever Run."""
        self._update_run_single(run)

    def _on_retriever_error(self, run: Run) -> None:
        """Process the Retriever Run upon error."""
        self._update_run_single(run)