from dataclasses import dataclass

from dataiku.llm.types import BelowThresholdHandling, FaithfulnessSettings, RelevancySettings


@dataclass
class Guardrail:
    """Threshold-based guardrail configuration for a single answer metric.

    Subclasses fill these fields from their respective settings objects.
    """

    # Identifier of the metric this guardrail checks (e.g. "faithfulness").
    metric_key_name: str
    # Minimum acceptable metric value ("expected at least" in the error message).
    threshold: float
    # Base message describing the failure; detail may be appended by get_error_message.
    error_message: str
    # Configured reaction when the metric falls below threshold
    # (semantics defined by BelowThresholdHandling).
    below_threshold_handling: BelowThresholdHandling
    # Replacement answer text; presumably used when the handling mode overwrites
    # the answer — confirm against callers.
    answer_overwrite: str
    # When True, get_error_message omits the numeric value/threshold detail.
    is_multimodal: bool

    def get_error_message(self, metric_value: float) -> str:
        """Build the message reported when ``metric_value`` fails this guardrail.

        Args:
            metric_value: The observed metric value.

        Returns:
            For multimodal guardrails, ``error_message`` unchanged. Otherwise,
            ``error_message`` followed by a space and the observed value and
            threshold.
        """
        if self.is_multimodal:
            return self.error_message
        # Separate the detail from the base sentence with a space; previously the
        # two were concatenated directly ("...threshold.Got '...'").
        return f"{self.error_message} Got '{metric_value}', expected at least '{self.threshold}'."

class Faithfulness(Guardrail):
    """Guardrail on the "faithfulness" metric, configured from FaithfulnessSettings."""

    def __init__(self, faithfulness_settings: FaithfulnessSettings):
        # Read settings in a fixed order: threshold, handling, answerOverwrite.
        min_value = faithfulness_settings["threshold"]
        on_failure = faithfulness_settings["handling"]
        replacement = faithfulness_settings["answerOverwrite"]
        super().__init__(
            metric_key_name="faithfulness",
            threshold=min_value,
            error_message="The answer does not meet the required faithfulness threshold.",
            below_threshold_handling=on_failure,
            answer_overwrite=replacement,
            is_multimodal=False,
        )

class AnswerRelevancy(Guardrail):
    """Guardrail on the "answerRelevancy" metric, configured from RelevancySettings."""

    def __init__(self, relevancy_settings: RelevancySettings):
        # Read settings in a fixed order: threshold, handling, answerOverwrite.
        min_value = relevancy_settings["threshold"]
        on_failure = relevancy_settings["handling"]
        replacement = relevancy_settings["answerOverwrite"]
        super().__init__(
            metric_key_name="answerRelevancy",
            threshold=min_value,
            error_message="The answer does not meet the required relevancy threshold.",
            below_threshold_handling=on_failure,
            answer_overwrite=replacement,
            is_multimodal=False,
        )

class MultimodalFaithfulness(Guardrail):
    """Guardrail on the "multimodalFaithfulness" metric, configured from FaithfulnessSettings.

    Marked multimodal, so its error message carries no numeric detail.
    """

    def __init__(self, faithfulness_settings: FaithfulnessSettings):
        # Read settings in a fixed order: threshold, handling, answerOverwrite.
        min_value = faithfulness_settings["threshold"]
        on_failure = faithfulness_settings["handling"]
        replacement = faithfulness_settings["answerOverwrite"]
        super().__init__(
            metric_key_name="multimodalFaithfulness",
            threshold=min_value,
            error_message="The answer is not faithful to the retrieved source(s).",
            below_threshold_handling=on_failure,
            answer_overwrite=replacement,
            is_multimodal=True,
        )

class MultimodalAnswerRelevancy(Guardrail):
    """Guardrail on the "multimodalRelevancy" metric, configured from RelevancySettings.

    Marked multimodal, so its error message carries no numeric detail.
    """

    def __init__(self, relevancy_settings: RelevancySettings):
        # Read settings in a fixed order: threshold, handling, answerOverwrite.
        min_value = relevancy_settings["threshold"]
        on_failure = relevancy_settings["handling"]
        replacement = relevancy_settings["answerOverwrite"]
        super().__init__(
            metric_key_name="multimodalRelevancy",
            threshold=min_value,
            error_message="The answer is not relevant to the question and retrieved source(s).",
            below_threshold_handling=on_failure,
            answer_overwrite=replacement,
            is_multimodal=True,
        )
