import pandas as pd
from typing import List, Optional, Union

from dataiku.llm.evaluation.llm_eval_recipe_desc import LLMEvalRecipeDesc
from dataikuapi.dss.utils import Enum


class LLMMetricInputRole(Enum):
    """
    Roles a column of the evaluated dataframe can play for an LLM metric.

    Each member's value is a string label. ``LLMMetricInput.get`` uses these
    members to select the matching series.
    """
    INPUT = "Input"
    OUTPUT = "Output"
    GROUND_TRUTH = "Ground truth"
    CONTEXT = "Context"


class LLMMetricInput(object):
    """
    Holds the columns from the input dataframe as Panda Series
    """
    input: pd.Series
    output: pd.Series
    ground_truth: pd.Series
    context: pd.Series # could also be called the source

    def __init__(self, input_series: pd.Series, output_series: pd.Series, ground_truth_series: pd.Series, context_series: pd.Series):
        self.input = input_series
        self.output = output_series
        self.ground_truth = ground_truth_series
        self.context = context_series

    @staticmethod
    def from_series(input_series: pd.Series, output_series: pd.Series, ground_truth_series: pd.Series, context_series: pd.Series):
        return LLMMetricInput(input_series, output_series, ground_truth_series, context_series)

    @staticmethod
    def from_df(input_df: pd.DataFrame, recipe_desc: LLMEvalRecipeDesc):
        return LLMMetricInput(
            input_df[recipe_desc.input_column_name],
            input_df.get(recipe_desc.output_column_name),
            input_df.get(recipe_desc.ground_truth_column_name),
            input_df.get(recipe_desc.context_column_name))

    @staticmethod
    def from_single_entry(input: str, output: str, ground_truth: Optional[str], context: Union[str,List[str]]):
        return LLMMetricInput(
            pd.Series([input]),
            pd.Series([output]),
            pd.Series([ground_truth]),
            pd.Series([context]))

    def get(self, input_role: LLMMetricInputRole):
        if input_role == LLMMetricInputRole.INPUT:
            return self.input
        elif input_role == LLMMetricInputRole.OUTPUT:
            return self.output
        elif input_role == LLMMetricInputRole.GROUND_TRUTH:
            return self.ground_truth
        elif input_role == LLMMetricInputRole.CONTEXT:
            return self.context
