from typing import List, Optional

from dataiku.llm.evaluation.custom_llm_evaluation_metric import CustomLLMEvaluationMetric

# Sentinel column name meaning "no column selected". LLMEvalRecipeDesc maps it to None
# for the input, output, ground-truth and context column name fields.
NO_COLUMN_SELECTED = "None - no column selected"
# Sentinel LLM id meaning "no model selected". LLMEvalRecipeDesc maps it to None
# for the embedding and completion LLM id fields.
NO_LLM_MODEL_SELECTED_ID = "None"

class LLMEvalRecipeDesc(object):
    """
    Parsed, attribute-style view of an LLM evaluation recipe payload dict.

    This class is to be kept as much as possible in sync with LLMEvaluationRecipePayloadParams.java

    Normalization applied when parsing:
      * the "no selection" sentinels (NO_COLUMN_SELECTED for column names,
        NO_LLM_MODEL_SELECTED_ID for LLM ids) become None;
      * keys with a default ("selection", "customMetrics", "failOnErrors") fall back
        to that default both when missing and when explicitly null;
      * "bertScoreModelType" / "bleuTokenizer" fall back to their defaults when
        missing, null, or blank.
    The original dict is kept untouched and is returned by get_raw().
    """
    input_format: str
    llm_task_type: str
    input_column_name: Optional[str]
    output_column_name: Optional[str]
    ground_truth_column_name: Optional[str]
    context_column_name: Optional[str]
    sampling: dict
    embedding_llm_id: Optional[str]
    embedding_settings: dict
    completion_llm_id: Optional[str]
    completion_settings: dict
    env_selection: dict
    metrics: List[str]
    custom_metrics: List[CustomLLMEvaluationMetric]
    fail_on_errors: bool
    labels: List[dict]
    evaluation_id: str
    evaluation_name: str
    bert_score_model_type: str  # never None: falls back to "bert-base-uncased"
    bleu_tokenizer: str  # never None: falls back to "13a"

    _recipe_desc: dict

    def get_raw(self) -> dict:
        """Return the unmodified payload dict this description was built from."""
        return self._recipe_desc

    @staticmethod
    def _get_unless_sentinel(recipe_desc: dict, key: str, sentinel: str) -> Optional[str]:
        """Return recipe_desc[key], or None when the key is missing or holds `sentinel`."""
        value = recipe_desc.get(key)
        return None if value == sentinel else value

    @staticmethod
    def _get_or_default(recipe_desc: dict, key: str, default):
        """Return recipe_desc[key], or `default` when the key is missing or explicitly None."""
        value = recipe_desc.get(key)
        return default if value is None else value

    @staticmethod
    def _get_non_blank_or_default(recipe_desc: dict, key: str, default: str) -> str:
        """Return recipe_desc[key], or `default` when it is missing, None, or whitespace-only."""
        value = recipe_desc.get(key)
        return value if value is not None and value.strip() else default

    def __init__(self, recipe_desc: dict):
        """
        :param recipe_desc: recipe payload dict with camelCase keys
            (e.g. "inputColumnName", "metrics", "customMetrics").
        """
        self.input_format = recipe_desc.get("inputFormat")
        self.llm_task_type = recipe_desc.get("llmTaskType")
        self.input_column_name = self._get_unless_sentinel(recipe_desc, "inputColumnName", NO_COLUMN_SELECTED)
        self.output_column_name = self._get_unless_sentinel(recipe_desc, "outputColumnName", NO_COLUMN_SELECTED)
        self.ground_truth_column_name = self._get_unless_sentinel(recipe_desc, "groundTruthColumnName", NO_COLUMN_SELECTED)
        self.context_column_name = self._get_unless_sentinel(recipe_desc, "contextColumnName", NO_COLUMN_SELECTED)
        # An explicit null must not bypass the default (it used to yield sampling=None).
        self.sampling = self._get_or_default(recipe_desc, "selection", {"samplingMethod": "FULL"})
        self.embedding_llm_id = self._get_unless_sentinel(recipe_desc, "embeddingLLMId", NO_LLM_MODEL_SELECTED_ID)
        self.embedding_settings = recipe_desc.get("embeddingSettings")
        self.completion_llm_id = self._get_unless_sentinel(recipe_desc, "completionLLMId", NO_LLM_MODEL_SELECTED_ID)
        self.completion_settings = recipe_desc.get("completionSettings")
        self.env_selection = recipe_desc.get("envSelection")
        self.metrics = recipe_desc.get("metrics")
        # An explicit null used to raise TypeError here (iterating over None).
        self.custom_metrics = [
            CustomLLMEvaluationMetric(cm)
            for cm in self._get_or_default(recipe_desc, "customMetrics", [])
        ]
        # Treat an explicit null like a missing key so the default (True) applies
        # instead of a falsy None.
        self.fail_on_errors = self._get_or_default(recipe_desc, "failOnErrors", True)
        self.labels = recipe_desc.get("labels")
        self.evaluation_id = recipe_desc.get("evaluationId")
        self.evaluation_name = recipe_desc.get("evaluationName")
        # Defaults kept in sync with their serialization in the _evaluation.json in the java runner
        self.bert_score_model_type = self._get_non_blank_or_default(recipe_desc, "bertScoreModelType", "bert-base-uncased")
        self.bleu_tokenizer = self._get_non_blank_or_default(recipe_desc, "bleuTokenizer", "13a")
        self._recipe_desc = recipe_desc
