import logging

import enum
import pandas as pd

from dataiku.core.doctor_constants import NUMERIC, TEXT, IMAGE, VECTOR
from dataiku.doctor.exception import DriftException
from dataiku.doctor.preprocessing.multimodal_preprocessings.sentence_embedding_extraction import LLMApiSentenceEmbeddingExtractor
from dataiku.modelevaluation.data_types import cast_as_numeric, cast_as_string

logger = logging.getLogger(__name__)


@enum.unique
class ResolvedColumnHandling(enum.Enum):
    """
    Final handling of a column for drift analysis, once the preprocessing types and
    the per-column drift parameters have been combined.
    """

    NUMERICAL = 1  # column is cast as numeric in both dataframes
    CATEGORICAL = 2  # column is cast as string in both dataframes
    TEXT = 3  # column is turned into sentence embeddings (text drift)
    IGNORED = 4  # column explicitly disabled in the drift column params
    UNSUPPORTED = 5  # column type cannot be used for drift (e.g. image, vector)


class DriftPreparator(object):
    """
    Prepare reference & current dataframes by applying (the same) drift column handling parameters
    => Ensure the two dataframes have *exactly* the same schema after preparation
    """

    def __init__(self, original_ref_me, original_cur_me, data_drift_params, can_text_drift=False, should_text_drift=False, text_drift_params=None):
        """
        :param original_ref_me: reference ME (or ME-like); its ``sample_df`` and ``preprocessing_params`` are used
        :param original_cur_me: current ME (or ME-like); its ``sample_df`` and ``preprocessing_params`` are used
        :param data_drift_params: drift settings; ``data_drift_params.columns`` maps column names to
            optional per-column overrides (``enabled``, ``handling``)
        :param can_text_drift: whether text drift can be computed at all (ie, not possible in interactive mode)
        :param should_text_drift: whether the user asked for text drift
        :param text_drift_params: text drift settings (``embeddingModelId`` is read); may be None
        """
        self.original_ref_df = original_ref_me.sample_df
        self.original_cur_df = original_cur_me.sample_df
        self.ref_preprocessing = original_ref_me.preprocessing_params
        self.cur_preprocessing = original_cur_me.preprocessing_params
        self.data_drift_params = data_drift_params
        self.can_text_drift = can_text_drift  # Are we able to do text drift (ie, not possible in interactive mode)
        self.should_text_drift = should_text_drift  # Did the user ask for text drift
        self.text_drift_params = text_drift_params

    def _infer_column_handling(self, column):
        """
        Determine how a column is handled for drift analysis, from two sources:
        - MEs (or ME-like)'s preprocessings, which give the default handling
        - Drift column params (if they are defined for this column), which may override it

        :param column: a column present in both preprocessings' ``per_feature``
        :return: tuple ``(actual_handling, default_handling)`` of ResolvedColumnHandling members
        """
        ref_type = self.ref_preprocessing["per_feature"][column]["type"]
        cur_type = self.cur_preprocessing["per_feature"][column]["type"]

        if ref_type != cur_type:
            # Preprocessings disagree on the type: fall back to a categorical comparison
            default_handling = ResolvedColumnHandling.CATEGORICAL
        elif ref_type == NUMERIC:
            default_handling = ResolvedColumnHandling.NUMERICAL
        elif ref_type in (IMAGE, VECTOR):
            default_handling = ResolvedColumnHandling.UNSUPPORTED
        elif ref_type == TEXT:
            if self.should_text_drift:
                default_handling = ResolvedColumnHandling.TEXT
            else:
                logger.info("Column %s is detected as Text but the Text Drift is disabled : ignored", column)
                default_handling = ResolvedColumnHandling.UNSUPPORTED
        else:
            default_handling = ResolvedColumnHandling.CATEGORICAL

        drift_col_params = self.data_drift_params.columns.get(column)
        if not drift_col_params:
            return default_handling, default_handling

        if not drift_col_params.get("enabled", False):
            # Disabled by the user; a column that is unsupported anyway stays reported as such
            if default_handling == ResolvedColumnHandling.UNSUPPORTED:
                return ResolvedColumnHandling.UNSUPPORTED, default_handling
            return ResolvedColumnHandling.IGNORED, default_handling

        requested = drift_col_params.get("handling", "AUTO")
        if requested == "AUTO":
            return default_handling, default_handling

        forced_handlings = {
            "NUMERICAL": ResolvedColumnHandling.NUMERICAL,
            "TEXT": ResolvedColumnHandling.TEXT,
            "IMAGE": ResolvedColumnHandling.UNSUPPORTED,
            "VECTOR": ResolvedColumnHandling.UNSUPPORTED,
        }
        # Any other explicit handling value is treated as categorical
        return forced_handlings.get(requested, ResolvedColumnHandling.CATEGORICAL), default_handling

    def prepare(self):
        """
        Build the reference/current inputs of the drift computation.

        Each column of ``list_available_columns()`` is resolved with ``_infer_column_handling``, then:
        - NUMERICAL: cast as numeric in both dataframes; if the cast raises ValueError the column is
          dropped from both and an ``errorMessage`` is added to its settings
        - CATEGORICAL: cast as string in both dataframes
        - TEXT: embedded with the configured embedding model, only when text drift was requested
          (otherwise the column only appears in the per-column settings)
        - IGNORED / UNSUPPORTED: only appears in the per-column settings

        :return: tuple ``(ref_df, cur_df, embedding_cur, embedding_ref, per_column_settings, flag)``.
            Note the embeddings come current first, then reference. When every column is ignored or
            unsupported and no column params were configured, returns
            ``(None, None, None, None, per_column_settings, False)``; otherwise ``flag`` is True.
        :raises DriftException: when a column must be handled as Text, text drift was requested but
            cannot be computed; or when every column is ignored/unsupported while column params exist
        """
        tabular_ref_series = {}
        tabular_cur_series = {}
        embedding_ref = {}
        embedding_cur = {}
        per_column_settings = []

        for column in self.list_available_columns():
            actual_handling, default_handling = self._infer_column_handling(column)
            logger.info(u"Treating %s as %s for drift analysis", column, actual_handling)

            settings = {
                "name": column,
                "actualHandling": actual_handling.name,
                "defaultHandling": default_handling.name
            }

            if actual_handling == ResolvedColumnHandling.NUMERICAL:
                try:
                    tabular_ref_series[column] = cast_as_numeric(self.original_ref_df[column])
                    tabular_cur_series[column] = cast_as_numeric(self.original_cur_df[column])
                except ValueError:
                    msg = u"Failed to cast {} as {} for drift analysis".format(column, actual_handling.name)
                    logger.info(msg)
                    settings["errorMessage"] = msg
                    # Keep both dataframes on the same schema: drop whichever side succeeded
                    tabular_ref_series.pop(column, None)
                    tabular_cur_series.pop(column, None)

            elif actual_handling == ResolvedColumnHandling.CATEGORICAL:
                # TODO: py2/p3 ok?
                tabular_ref_series[column] = cast_as_string(self.original_ref_df[column])
                tabular_cur_series[column] = cast_as_string(self.original_cur_df[column])

            elif actual_handling == ResolvedColumnHandling.TEXT and self.should_text_drift:
                if not self.can_text_drift:
                    raise DriftException("The column %s is to be handled as Text for drift computation but the Text Drift section is disabled. Please review your recipe configuration" % column)
                # text_drift_params defaults to None: treat it like a missing key instead of crashing
                embedding_model_id = (self.text_drift_params or {}).get("embeddingModelId")
                embedding_extractor = LLMApiSentenceEmbeddingExtractor(column, embedding_model_id)
                embedding_cur[column] = embedding_extractor.extract_embeddings(self.original_cur_df[column].values)
                embedding_ref[column] = embedding_extractor.extract_embeddings(self.original_ref_df[column].values)

            per_column_settings.append(settings)

        skipped_handlings = (ResolvedColumnHandling.IGNORED.name, ResolvedColumnHandling.UNSUPPORTED.name)
        if all(col_settings["actualHandling"] in skipped_handlings for col_settings in per_column_settings):
            if len(self.data_drift_params.columns) == 0:
                # The default case, without manual selection by the user : we just don't compute
                logger.warning("All the input features of the model are either ignored or unsupported for this input data drift computation. Skipping input data drift.")
                return None, None, None, None, per_column_settings, False
            else:
                raise DriftException("All the input features of the model are either ignored or unsupported for this input data drift computation. Skipping input data drift.")

        return pd.DataFrame(tabular_ref_series), pd.DataFrame(tabular_cur_series), embedding_cur, embedding_ref, per_column_settings, True

    def list_available_columns(self):
        """
        List the columns usable for drift analysis: present in both the reference and current
        sample dataframes, and with role ``INPUT`` in both preprocessings' ``per_feature``.

        :return: pandas Index of column names
        """
        columns = self.original_ref_df.columns.intersection(self.original_cur_df.columns)
        for per_feature in (self.ref_preprocessing["per_feature"], self.cur_preprocessing["per_feature"]):
            input_features = set(name for name, params in per_feature.items() if params.get("role") == "INPUT")
            columns = columns.intersection(input_features)
        return columns
