try:
    from langchain_core.callbacks.manager import CallbackManagerForLLMRun
except ModuleNotFoundError:
    from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import BaseMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from typing import List, Optional, Any, Iterator

from dataiku.langchain.dku_llm import DKUChatLLM

# Ragas has a [fallback](https://github.com/explodinggradients/ragas/blob/e97886ac976465efb60e5949c5d69baf30cc811d/src/ragas/llms/base.py#L63) that applies when [temperature is None](https://github.com/explodinggradients/ragas/blob/e97886ac976465efb60e5949c5d69baf30cc811d/src/ragas/llms/base.py#L103).
# We want to avoid it, since some models (at least o1 and o3-mini) don't support setting a temperature.
# This class is a thin wrapper around DKUChatLLM: just before each completion, it resets the temperature to None if it was None at construction time.
# It also records any errors raised by the completion, so they can still be inspected when fail_on_row_level_errors is in use.
class RagasCompatibleLLM(DKUChatLLM):
    """DKUChatLLM wrapper that keeps an unset temperature unset and records call errors.

    If no temperature was configured when the instance was created, every
    completion first puts ``temperature`` back to ``None``. This undoes any
    value assigned afterwards, e.g. by Ragas. Any exception raised by a
    completion has its ``args`` appended to ``errors`` before being re-raised.
    """

    # Whether a temperature was configured at construction time (set in __init__).
    has_temperature: bool = True
    # Accumulated ``args`` of exceptions raised by _generate / _stream.
    # NOTE(review): if DKUChatLLM is a pydantic model (LangChain chat models
    # usually are), this default list is copied per instance; otherwise it
    # would be shared by every instance — confirm.
    errors: List[str] = []

    def __init__(self, project_key: Optional[str] = None, **data: Any) -> None:
        super().__init__(project_key=project_key, **data)
        # Remember whether the caller configured a temperature before anyone
        # (e.g. Ragas) gets a chance to overwrite it.
        self.has_temperature = self.temperature is not None

    def _restore_unset_temperature(self) -> None:
        """Reset ``temperature`` to ``None`` if none was configured at construction."""
        if not self.has_temperature:
            self.temperature = None

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield chunks from DKUChatLLM's stream, recording any error it raises.

        This method is itself a generator (``yield from``), so errors raised
        while the underlying stream is being consumed are also caught and
        recorded. Merely returning the inner iterator would let them escape
        the ``try``.

        Raises:
            Exception: whatever the underlying stream raises, re-raised
                unchanged after its ``args`` are appended to ``errors``.
        """
        self._restore_unset_temperature()
        try:
            yield from super()._stream(messages, stop, run_manager, **kwargs)
        except Exception as e:
            self.errors.extend(e.args)
            raise

    def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> ChatResult:
        """Run a DKUChatLLM completion, recording any error it raises.

        Raises:
            Exception: whatever the underlying completion raises, re-raised
                unchanged after its ``args`` are appended to ``errors``.
        """
        self._restore_unset_temperature()
        try:
            return super()._generate(messages, stop, run_manager, **kwargs)
        except Exception as e:
            self.errors.extend(e.args)
            raise