from abc import ABCMeta
from abc import abstractmethod
import logging

import pandas as pd
from six import add_metaclass

from dataiku.core import doctor_constants
from dataiku.core.dku_logging import LogLevelContext
from dataiku.doctor.exploration.emu.generators import ActiveSphereCFGenerator
from dataiku.doctor.exploration.emu.generators import BaseGenerator
from dataiku.doctor.exploration.emu.generators import DiverseEvolutionaryOOGenerator
from dataiku.doctor.exploration.emu.generators import SpecialTarget
from dataiku.doctor.exploration.emu.feature_domains import FrozenFeatureDomain
from dataiku.doctor.exploration.emu.feature_domains import CategoricalFeatureDomain
from dataiku.doctor.exploration.emu.feature_domains import FeatureDomains
from dataiku.doctor.exploration.emu.feature_domains import FeatureType
from dataiku.doctor.exploration.emu.feature_domains import NumericalFeatureDomain
from dataiku.doctor.exploration.emu.metrics import PlausibilityScorer
from dataiku.doctor.exploration.predictor_adapter_for_emu import ClassifierAdapterForEmu
from dataiku.doctor.exploration.predictor_adapter_for_emu import PredictorAdapterForEmu
from dataiku.doctor.exploration.predictor_adapter_for_emu import RegressorAdapterForEmu

logger = logging.getLogger(__name__)


@add_metaclass(ABCMeta)
class EmuWrapper(object):
    """
    Models trained in the doctor have a `model_handler` that can be used to extract the
    info needed by EMU, but MLflow models don't have `model_handler`.

    This abstract class is a skeleton that wraps EMU, and it's model_handler-agnostic so
    that it can be used for both MLflow and other DSS models.

    The abstract methods are the parts of the algorithm for which the implementation
    depends on whether model_handler is available or not.

    Expected call order: `set_constraints`, then `update_generator`, then `compute`
    (`update_generator` and `compute` both read `self.feature_domains`, and `compute`
    reads `self.generator`).
    """
    def __init__(self, computer):
        """
        :param computer: strategy object providing `get_generator`, `prepare_target` and
            `compute` (and `get_model` for subclasses that use it)
        """
        self.computer = computer
        self.model = None
        self.feature_domains = None
        self.generator = None
        self.scorer = None
        # NB: this hook runs inside the constructor, so subclasses must set any state
        # it relies on *before* calling this __init__.
        self.integer_features_map = self._get_integer_features_map()
        self.valid_X = None  # NOSONAR
        self.valid_df = None
        self.valid_y = None
        self._initial_computations_done = False

    def set_constraints(self, constraints, reference):
        """
        Build and store the feature domains that bound the exploration.

        :param constraints: sized iterable of dicts; each has a "type" and a "name", plus
            "categories" (categorical) or "minValue"/"maxValue" (numerical)
        :param reference: reference point, indexable by feature name (each entry must
            support `.iloc[0]`); only read for features that are neither categorical
            nor numerical
        """
        self.feature_domains = self._init_feature_domains(reference, constraints)

    def update_generator(self, target, compute_plausibility=True):
        """
        Set the target on the generator and (re)fit it on the valid data.

        The one-off initial computations (model, generator, dataset, scorer) are run on
        the first call only, so `compute_plausibility` is ignored on subsequent calls.

        :param target: raw target, converted by the strategy's `prepare_target`
        :param compute_plausibility: truthy/falsy, or the string "ONLY_IF_FULLY_NUMERIC"
        """
        if not self._initial_computations_done:
            self._initial_computations(compute_plausibility)
        self.generator.target = self.computer.prepare_target(target)
        self.generator.fit(self._filter_columns(self.valid_df), self.valid_y, self.feature_domains)

    def _initial_computations(self, compute_plausibility):
        """Initialize model, generator, valid data and (optionally) the plausibility scorer."""
        self._init_model()
        self.generator = self.computer.get_generator(self.model)
        dataset = self._init_dataset()
        self.valid_X, self.valid_df, self.valid_y = self._get_valid_and_preprocessed_data(dataset)
        # If requested, disable computation of plausibility if any feature is non-numeric (For MLflow cases)
        if compute_plausibility == "ONLY_IF_FULLY_NUMERIC":
            compute_plausibility = self._is_valid_X_fully_numeric()
            # Lazy %-style arguments: the message is only formatted if DEBUG is enabled
            logger.debug("only_if_fully_numeric, final compute_plausibility: %s", compute_plausibility)

        if compute_plausibility:
            self.scorer = EmuWrapper._init_scorer(self.valid_X)
        self._initial_computations_done = True

    def _is_valid_X_fully_numeric(self):  # NOSONAR
        """
        :return: True when `self.valid_X` holds numeric data only
        :rtype: bool
        """
        if not isinstance(self.valid_X, pd.DataFrame):
            # For non-mlflow models in DSS, valid_X will be either a sparse
            # matrix or a numpy array which was preprocessed by DSS. So, in
            # this case, it's safe to assume that valid_X is fully numeric.
            return True
        # Test each dtype for "numeric" instead of only rejecting `object`: category,
        # datetime and dedicated string dtypes are non-numeric as well. Iterating over
        # the dtypes also avoids a per-label lookup that is ambiguous when column
        # names are duplicated.
        return all(pd.api.types.is_numeric_dtype(dtype) for dtype in self.valid_X.dtypes)

    def _init_feature_domains(self, reference, constraints):
        """
        Translate the constraints into a `FeatureDomains` collection.

        Categorical and numerical constraints become the matching domain; any other
        feature type is frozen to the value it has in `reference`.

        :raises ValueError: if a constraint "type" is not a valid `FeatureType` value
            (assuming `FeatureType` is a standard enum -- confirm)
        """
        feature_domains = FeatureDomains(len(constraints))
        for feature in constraints:
            feature_type = FeatureType(feature["type"])
            feature_name = feature["name"]
            if feature_type is FeatureType.CATEGORICAL:
                categories = set(feature["categories"])
                feature_domains.append(CategoricalFeatureDomain(feature_name, categories))
            elif feature_type is FeatureType.NUMERICAL:
                min_value, max_value = feature["minValue"], feature["maxValue"]
                # No map at all, or a feature missing from the map, both mean
                # "not known to be an integer" rather than an error.
                is_integer = bool(
                    self.integer_features_map is not None
                    and self.integer_features_map.get(feature_name, False)
                )
                feature_domains.append(NumericalFeatureDomain(feature_name, min_value, max_value, is_integer))
            else:
                feature_domains.append(FrozenFeatureDomain(feature_name, reference[feature_name].iloc[0]))
        return feature_domains

    @staticmethod
    def _init_scorer(valid_X):  # NOSONAR
        """Create a `PlausibilityScorer` fitted on the preprocessed valid data."""
        scorer = PlausibilityScorer()
        scorer.fit(valid_X)
        return scorer

    def _filter_columns(self, df):
        """
        Keep only the columns of `df` that have a feature domain, in `df`'s column order.

        Requires `set_constraints` to have been called (reads `self.feature_domains`).
        """
        # A set gives constant-time membership tests instead of scanning a list per column
        selected_columns = {feature_domain.feature_name for feature_domain in self.feature_domains}
        return df[[column for column in df.columns if column in selected_columns]]

    def compute(self, ref):
        """
        Run the strategy's computation around the reference point.

        :param ref: reference point; filtered down to the constrained features
        :return: dict with "syntheticData" and "syntheticMetadata", each a list of
            records (both empty when no point was found)
        :rtype: dict
        """
        ref = self._filter_columns(ref)
        # Log level of the preprocessing-related loggers is overridden for the duration
        # of the computation (presumably to silence per-call preprocessing noise)
        with LogLevelContext(logging.CRITICAL, doctor_constants.PREPROCESSING_RELATED_LOGGER_NAMES):
            points_df = self.computer.compute(ref, self.generator)
        if not points_df.shape[0]:
            logger.warning("Did not find any counterfactual")
            return {"syntheticData": [], "syntheticMetadata": []}
        metadata_df = self._compute_metadata(points_df)
        return {"syntheticData": points_df.to_dict("records"), "syntheticMetadata": metadata_df.to_dict("records")}

    def _compute_metadata(self, df):
        """
        Score the generated points and attach a "plausibility" column.

        The plausibility is None for every row when no scorer was initialized.
        """
        transformed_df = self.model.preprocess(df)
        if self.scorer is not None:
            plausibilities = self.scorer.compute_plausibility(transformed_df)
        else:
            plausibilities = [None for _ in range(transformed_df.shape[0])]
        metadata_df = self._get_scoring_data_df(df)
        metadata_df["plausibility"] = plausibilities
        return metadata_df

    @abstractmethod
    def _get_integer_features_map(self):
        """
        Called from the constructor.

        :return: mapping from feature name to a truthy value when the feature is
            integer-valued, or None when that information is not available
        """

    @abstractmethod
    def _init_model(self):
        """Set `self.model`; it must expose `preprocess(df)` and is passed to the generator."""

    @abstractmethod
    def _init_dataset(self):
        """
        :return: the dataset handed to `_get_valid_and_preprocessed_data`
        """

    @abstractmethod
    def _get_valid_and_preprocessed_data(self, df):
        """
        :param df: dataset returned by `_init_dataset`
        :return: triple `(valid_X, valid_df, valid_y)`: preprocessed features (used to
            fit the plausibility scorer), and the data and target used to fit the generator
        """

    @abstractmethod
    def _get_scoring_data_df(self, df):
        """
        :param df: generated points
        :return: per-row metadata; must support column assignment and `to_dict("records")`
        """

    @abstractmethod
    def _get_target_col(self):
        """Not called within this class; presumably returns the target column name -- confirm."""


class DoctorEmuWrapper(EmuWrapper):
    """EmuWrapper backed by a doctor `model_handler` - i.e. for every model that is not an MLflow model"""
    def __init__(self, model_handler, computer):
        # The handler has to be stored first: the parent constructor calls
        # `_get_integer_features_map`, which reads it.
        self.model_handler = model_handler
        super(DoctorEmuWrapper, self).__init__(computer)

    def _get_integer_features_map(self):
        """Map each schema column name to whether its declared type is an integer type."""
        integer_types = ("tinyint", "smallint", "int", "bigint")
        integer_features = {}
        for column in self.model_handler.get_schema()["columns"]:
            integer_features[column["name"]] = column["type"] in integer_types
        return integer_features

    def _init_model(self):
        """Ask the strategy for the EMU model adapter built on the model handler."""
        self.model = self.computer.get_model(self.model_handler)

    def _init_dataset(self):
        """Return a private copy of the full dataset (k-fold models) or of the test set."""
        handler = self.model_handler
        if handler.is_kfolding():
            cached = handler.get_full_df()
        else:
            cached = handler.get_test_df()
        # The model_handler caches this dataset and the `process` function modifies
        # its input, hence the copy.
        return cached[0].copy()

    def _get_valid_and_preprocessed_data(self, df):
        """
        Run the predictor preprocessing on a copy of `df`.

        :return: (preprocessed features, rows of `df` kept by the preprocessing, target)
        :raises ValueError: if the preprocessing reports that no row is left
        """
        # The preprocessing corrupts its input, so it gets a throwaway copy
        scratch_df = df.copy()
        preprocess = self.model_handler.get_predictor().preprocessing.preprocess
        with LogLevelContext(logging.CRITICAL, doctor_constants.PREPROCESSING_RELATED_LOGGER_NAMES):
            features, kept_index, nothing_left, target = preprocess(scratch_df, with_target=True)
        if nothing_left:
            raise ValueError("All dataset dropped by preprocessing")
        kept_rows = df.loc[kept_index, :]
        return features, kept_rows, target

    def _get_scoring_data_df(self, df):
        """Return the predictor's output for `df`."""
        predictor = self.model_handler.get_predictor()
        return predictor.predict(df)

    def _get_target_col(self):
        """Return the target variable reported by the model handler."""
        return self.model_handler.get_target_variable()


@add_metaclass(ABCMeta)
class EmuComputationStrategy(object):
    """
    Strategy interface hiding whether EMU is asked for counterfactual examples (CF)
    or for near-optima (OO).

    EmuWrapper only talks to this interface, so it can be specialized (doctor vs
    MLflow) without caring about which of the two computations is performed.

    A strategy object is used rather than inheritance, to avoid diamond inheritance
    issues with the MLflow variants.
    """
    @abstractmethod
    def get_model(self, model_handler):
        """
        Build the EMU model adapter for the given model handler.

        :rtype: PredictorAdapterForEmu
        """

    @abstractmethod
    def get_generator(self, model):
        """
        Build the EMU generator for the given model adapter.

        :rtype: BaseGenerator
        """

    @abstractmethod
    def prepare_target(self, target):
        """
        Convert a raw target into the value that is assigned to the generator's `target`.
        """

    @staticmethod
    @abstractmethod
    def compute(ref, generator):
        """
        Run the computation around the reference point `ref` with a fitted `generator`.

        EmuWrapper.compute uses the result as a DataFrame (`shape`, `to_dict("records")`).
        """


class CounterfactualsStrategy(EmuComputationStrategy):
    """Strategy searching for counterfactual examples of a classifier."""
    # Passed as second argument to `generate_counterfactuals`
    MAX_NB_COUNTERFACTUALS = 20

    def __init__(self, target_map):
        """
        :param target_map: mapping indexed by the raw target, giving the class expected by emu
        """
        self.target_map = target_map

    def get_model(self, model_handler):
        """
        :rtype: ClassifierAdapterForEmu
        """
        return ClassifierAdapterForEmu(model_handler)

    def get_generator(self, model):
        """
        :rtype: ActiveSphereCFGenerator
        """
        return ActiveSphereCFGenerator(model)

    def prepare_target(self, target):
        """Translate `target` through the target_map; None is passed through untouched."""
        if target is None:
            return None
        # emu engine requires the actual classes from the model, i.e. classes mapped to indices through the target_map
        return self.target_map[target]

    @staticmethod
    def compute(ref, generator):
        """
        Compute counterfactuals around the reference point

        :param pd.DataFrame ref: reference point (df with one single row)
        :param ActiveSphereCFGenerator generator: emu generator
        :return: whatever `generator.generate_counterfactuals` returns; EmuWrapper.compute
            uses it as a DataFrame of points
        """
        max_points = CounterfactualsStrategy.MAX_NB_COUNTERFACTUALS
        return generator.generate_counterfactuals(ref, max_points)


class OutcomeOptimizationStrategy(EmuComputationStrategy):
    """Strategy searching for near-optimal points of a regressor."""
    def get_model(self, model_handler):
        """
        :rtype: RegressorAdapterForEmu
        """
        return RegressorAdapterForEmu(model_handler)

    def get_generator(self, model):
        """
        :rtype: DiverseEvolutionaryOOGenerator
        """
        return DiverseEvolutionaryOOGenerator(model)

    def prepare_target(self, target):
        """
        Turn `target` into a `SpecialTarget` when it matches the name of one of its
        members; any other value is returned unchanged.
        """
        special_names = [member.name for member in SpecialTarget]
        if target in special_names:
            # Looked up by name, built by value: assumes member names equal their values -- TODO confirm
            return SpecialTarget(target)
        return target

    @staticmethod
    def compute(ref, generator):
        """
        Find near-optimal points around the reference point

        :param pd.DataFrame ref: reference point
        :param DiverseEvolutionaryOOGenerator generator: emu generator
        :return: whatever `generator.optimize` returns; EmuWrapper.compute uses it as a
            DataFrame of points
        """
        optimized = generator.optimize(ref)
        return optimized
