"""Wrapper around Dataiku-mediated embedding LLMs"""
from typing import Any, Optional

import dataiku
from dataiku.llm.tracing import SpanBuilder
from dataikuapi.dss.langchain.embeddings import DKUEmbeddings as PublicDKUEmbeddings


class DKUEmbeddings(PublicDKUEmbeddings):
    """Wrapper around Dataiku-mediated embedding LLMs.

    Resolves the LLM handle from a DSS project through the Dataiku API client,
    then defers to the public ``dataikuapi`` implementation for everything else.
    """

    def __init__(self, project_key: Optional[str] = None, **data: Any):
        """Resolve ``data["llm_id"]`` in a project and initialise the wrapper.

        :param project_key: key of the DSS project that owns the LLM. When
            ``None``, the API client's default project is used instead.
        :param data: keyword arguments forwarded unchanged to the parent
            constructor (alongside the resolved ``llm_handle``). Must contain
            ``llm_id``.
        :raises KeyError: if ``llm_id`` is missing from ``data``.
        """
        client = dataiku.api_client()
        # Fall back to the default project when no explicit key is given.
        project_handle = (
            client.get_default_project()
            if project_key is None
            else client.get_project(project_key)
        )
        super().__init__(llm_handle=project_handle.get_llm(data["llm_id"]), **data)

class TraceableDKUEmbeddings(DKUEmbeddings):
    """Traceable wrapper around the DKUEmbeddings. Create a new instance for each new span.

    Every trace handed to this instance's callbacks is appended to the
    ``span`` supplied at construction time, so one instance must not be
    shared across spans.
    """

    def __init__(self, span: SpanBuilder, project_key: Optional[str] = None, **data: Any):
        """Initialise the embeddings and attach a trace forwarder for ``span``.

        :param span: span builder that receives every trace via ``append_trace``.
        :param project_key: forwarded to ``DKUEmbeddings``; ``None`` selects the
            default project.
        :param data: forwarded to ``DKUEmbeddings``; must contain ``llm_id``.
        """
        super().__init__(project_key=project_key, **data)

        def _forward_to_span(trace):
            # Look up ``append_trace`` at call time and pass its result back.
            return span.append_trace(trace)

        # NOTE(review): ``_callbacks`` is defined by the dataikuapi parent class;
        # presumably it is invoked with each trace produced by an LLM call —
        # verify against dataikuapi.
        self._callbacks.append(_forward_to_span)
