import copy
import logging
import time

from typing import Any

from langsmith.run_trees import RunTree

from dataiku.llm.tracing import SpanBuilder, SpanReader


def _dku_span_to_ls(sr: SpanReader, ls_run_tree: RunTree, d: int = 0) -> None:
    ls_run_tree.end(end_time=sr.end_ts)
    
    for (attr_k, attr_v) in sr.attributes.items():
        ls_run_tree.add_metadata({attr_k:attr_v})

    if ls_run_tree.inputs is None:
        ls_run_tree.inputs = {}
    ls_run_tree.inputs.update(sr.inputs)

    if ls_run_tree.outputs is not None:
        ls_run_tree.add_outputs(sr.outputs)
    
    if "usageMetadata" in sr.span_data:
        um_dict = {
            "input_tokens" : sr.span_data["usageMetadata"].get("promptTokens", None),
            "output_tokens" : sr.span_data["usageMetadata"].get("completionTokens", None),
            "total_tokens" : sr.span_data["usageMetadata"].get("totalTokens", None),
            "estimated_cost": sr.span_data["usageMetadata"].get("estimatedCost", None)
        }
        ls_run_tree.add_outputs({"usage_metadata": um_dict})
        ls_run_tree.run_type = "llm" # Need for LS to display the tokens

    if sr.span_data["name"] == "DKU_LLM_MESH_LLM_CALL":
        ls_run_tree.run_type = "llm"

    if "llmProvider" in sr.attributes:
        ls_run_tree.add_metadata({"ls_provider": sr.attributes["llmProvider"]})
    if "llmModel" in sr.attributes:
        ls_run_tree.add_metadata({"ls_model_name": sr.attributes["llmModel"]})

    for child_sr in sr.children:
        if child_sr.span_data["type"] == "event":
            continue
        child_run_tree = ls_run_tree.create_child(name=child_sr.name, start_time=child_sr.begin_ts, end_time=child_sr.end_ts)
        
        logging.debug("%sCreating child n=%s s=%s e=%s" % (" " * (d*2), child_run_tree.name, child_run_tree.start_time, child_run_tree.end_time))
        
        _dku_span_to_ls(child_sr, child_run_tree, d+1)
        
        child_run_tree.post()
        
def post_dku_trace_to_langsmith(trace: Any, name: str = "dss_trace") -> None:
    """Posts a DKU trace to the LangSmith service.

    :param trace: the trace to send. A still-open ``SpanBuilder`` (no "end"
        recorded yet) is deep-copied and the copy is ended at the current time
        in epoch milliseconds. The caller's builder is left untouched.
    :param name: name given to the root LangSmith run.
    """
    still_open = isinstance(trace, SpanBuilder) and trace.span.get("end") is None
    if still_open:
        # Close a private copy rather than the live builder the caller may keep using.
        trace = copy.deepcopy(trace)
        now_ms = int(time.time() * 1000)
        trace.end(now_ms)

    reader = SpanReader(trace)

    root_run = RunTree(name=name, start_time=reader.begin_ts, end_time=reader.end_ts)
    _dku_span_to_ls(reader, root_run)

    root_run.end(end_time=reader.end_ts)

    # Children were already posted during conversion; now send the root run.
    root_run.post()