"""Wrapper around Dataiku-mediated LLM"""
from typing import Any, Optional

import dataiku
from dataikuapi.dss.langchain.llm import DKULLM as PublicDKULLM, DKUChatModel


class DKULLM(PublicDKULLM):
    """Dataiku-mediated LLM that resolves its LLM handle from a project.

    Extends the public ``DKULLM`` wrapper with a constructor that looks up
    the LLM handle itself, so callers only supply an ``llm_id`` and,
    optionally, the key of the project to look it up in.
    """

    def __init__(self, project_key: Optional[str] = None, **data: Any):
        """Look up the LLM handle and initialise the parent wrapper with it.

        Args:
            project_key: Key of the project to fetch the LLM from. When
                ``None``, the API client's default project is used.
            **data: Keyword arguments forwarded unchanged to the parent
                constructor. Must contain ``"llm_id"``, which is used to
                fetch the LLM handle from the project.

        Raises:
            KeyError: If ``"llm_id"`` is not present in ``data``.
        """
        client = dataiku.api_client()
        project = (
            client.get_default_project()
            if project_key is None
            else client.get_project(project_key)
        )
        # The project is resolved before "llm_id" is read, so a missing id
        # only surfaces after the project lookup has happened.
        handle = project.get_llm(data["llm_id"])
        super().__init__(llm_handle=handle, **data)


class DKUChatLLM(DKUChatModel):
    """Dataiku-mediated chat LLM that resolves its LLM handle from a project.

    Extends ``DKUChatModel`` with a constructor that looks up the LLM handle
    itself, so callers only supply an ``llm_id`` and, optionally, the key of
    the project to look it up in.
    """

    def __init__(self, project_key: Optional[str] = None, **data: Any):
        """Look up the LLM handle and initialise the parent chat model with it.

        Args:
            project_key: Key of the project to fetch the LLM from. When
                ``None``, the API client's default project is used.
            **data: Keyword arguments forwarded unchanged to the parent
                constructor. Must contain ``"llm_id"``, which is used to
                fetch the LLM handle from the project.

        Raises:
            KeyError: If ``"llm_id"`` is not present in ``data``.
        """
        client = dataiku.api_client()
        if project_key is not None:
            project = client.get_project(project_key)
        else:
            # No explicit project requested: use the client's default one.
            project = client.get_default_project()
        handle = project.get_llm(data["llm_id"])
        super().__init__(llm_handle=handle, **data)
