from abc import ABCMeta
from abc import abstractmethod

import numpy as np
import pandas as pd
from six import add_metaclass

from dataiku.core import doctor_constants
from dataiku.doctor.prediction.decisions_and_cuts import DecisionsAndCuts
from dataiku.doctor.prediction.overrides.ml_overrides_applier import BinaryProbabilisticClassificationMlOverridesApplier
from dataiku.doctor.prediction.overrides.ml_overrides_applier import MlOverridesApplier
from dataiku.doctor.prediction.overrides.ml_overrides_applier import NonProbabilisticClassificationMlOverridesApplier
from dataiku.doctor.prediction.overrides.ml_overrides_applier import ProbabilisticClassificationMlOverridesApplier
from dataiku.doctor.prediction.overrides.ml_overrides_applier import RegressionMlOverridesApplier
from dataiku.doctor.prediction.overrides.ml_overrides_params import MlOverridesParams
from dataikuscoring.utils.prediction_result import AbstractPredictionResult
from dataikuscoring.utils.prediction_result import ClassificationPredictionResult
from dataikuscoring.utils.prediction_result import PredictionResult


def is_proba_aware(clf, algorithm):
    """
    Tell whether a model can output class probabilities.

    For the 'SCIKIT_MODEL', 'CUSTOM_PLUGIN' and 'EVALUATED' algorithms, the estimator itself is inspected for a
    callable `predict_proba`; every other algorithm is assumed to support probabilities.

    :param clf: the fitted estimator
    :param str algorithm: algorithm identifier
    :rtype: bool
    """
    if algorithm in {'SCIKIT_MODEL', 'CUSTOM_PLUGIN', 'EVALUATED'}:
        return hasattr(clf, "predict_proba") and callable(clf.predict_proba)
    return True


@add_metaclass(ABCMeta)
class SerializableMixin(object):
    """Stores the fitted estimator together with the identifier of the algorithm that produced it."""

    def __init__(self, clf, algorithm):
        """
        :param sklearn.base.BaseEstimator clf: the fitted estimator
        :param str algorithm: algorithm identifier
        """
        # Attributes are assigned in the same order as before, on purpose.
        self.clf = clf
        self.algorithm = algorithm


@add_metaclass(ABCMeta)
class ScorableModel(SerializableMixin):
    """
    Base class for models that turn model input into prediction results.

    Concrete subclasses implement `compute_predictions` and `_get_overrides_applier`. When model overrides are
    configured, `_apply_post_predict` applies them to the raw prediction result.
    """

    def __init__(self, clf, algorithm, overrides_params=None):
        """
        :type clf: sklearn.base.BaseEstimator
        :type algorithm: str
        :type overrides_params: MlOverridesParams or None
        """
        super(ScorableModel, self).__init__(clf, algorithm)
        self.overrides_params = overrides_params

    def has_overrides(self):
        """
        :return: whether override params are set and contain at least one override
        :rtype: bool
        """
        return self.overrides_params is not None and len(self.overrides_params.overrides) > 0

    @abstractmethod
    def compute_predictions(self, X, unprocessed_df):
        """
        Compute the prediction result for `X`, with model overrides applied if any.

        :param np.ndarray X: model input (presumably the preprocessed counterpart of `unprocessed_df`)
        :param pd.DataFrame unprocessed_df: for model overrides
        :rtype: AbstractPredictionResult
        """
        pass

    @abstractmethod
    def _get_overrides_applier(self):
        """
        :return: the applier used to post-process predictions; `_apply_post_predict` only calls this when
                 `has_overrides()` is True
        :rtype: MlOverridesApplier
        """
        pass

    @staticmethod
    def build(clf, model_type, prediction_type, algorithm, preprocessing_params=None, overrides_params=None,
              prediction_interval_model=None):
        """
        Instantiate the scorable model subclass matching the model type, prediction type and algorithm.

        :param sklearn.base.BaseEstimator clf: classifier with predict method, and predict_proba for classification tasks
        :param str model_type: model type, such as PREDICTION or CLUSTERING
        :param str prediction_type:
        :param str algorithm:
        :param dict or None preprocessing_params: required for classification tasks
        :param MlOverridesParams or None overrides_params:
        :param dataiku.doctor.prediction.prediction_interval_model.PredictionIntervalsModel prediction_interval_model:
        classifier with predict method, for regression task only
        :return: instance of a scorable model either for regression or classification tasks
        :raises ValueError: if the prediction type is not supported
        """
        if model_type == "CLUSTERING":
            return ScorableModelClustering(clf, algorithm)
        if algorithm == "KERAS_CODE":
            # NOTE(review): overrides_params is not forwarded to Keras models; presumably overrides are not
            # supported for them. Confirm.
            if prediction_type == doctor_constants.REGRESSION:
                return KerasScorableModelRegression(clf, algorithm)
            elif prediction_type == doctor_constants.MULTICLASS:
                return KerasScorableModelMulticlass(clf, algorithm, preprocessing_params)
            elif prediction_type == doctor_constants.BINARY_CLASSIFICATION:
                return KerasScorableModelBinary(clf, algorithm, preprocessing_params)

        if prediction_type == doctor_constants.REGRESSION:
            if prediction_interval_model is not None:
                return ScorableModelRegressionWithPredictionIntervals(clf, algorithm, prediction_interval_model, overrides_params)
            return ScorableModelRegression(clf, algorithm, overrides_params)
        if prediction_type == doctor_constants.MULTICLASS:
            return ScorableModelMulticlass(clf, algorithm, preprocessing_params, overrides_params)
        if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
            if is_proba_aware(clf, algorithm):
                return ProbabilisticScorableModelBinary(clf, algorithm, preprocessing_params, overrides_params)
            return NonProbabilisticScorableModelBinary(clf, algorithm, preprocessing_params, overrides_params)
        # Use format() rather than string concatenation so that a non-string prediction type (e.g. None) still
        # raises the intended ValueError instead of a TypeError.
        raise ValueError("Prediction type not supported: {}".format(prediction_type))

    def _apply_post_predict(self, prediction_result, unprocessed_df):
        """
        Apply model overrides to a raw prediction result. The result is returned unchanged when there are none.

        :type prediction_result: AbstractPredictionResult
        :param pd.DataFrame unprocessed_df: for model overrides
        :rtype: AbstractPredictionResult
        """
        if not self.has_overrides():
            return prediction_result
        return self._get_overrides_applier().apply(unprocessed_df, prediction_result)

    def is_proba_aware(self):
        """
        This method is also used during explanations computation for regression models
        :rtype: bool
        """
        return is_proba_aware(self.clf, self.algorithm)

    def requires_unprocessed_df_for_prediction(self):
        """
        Returns whether unprocessed_df is actually used for prediction or not. For now, only overridden models require
        it.
        :rtype: bool
        """
        return self.has_overrides()

    @staticmethod
    @abstractmethod
    def is_classification():
        """
        :return: whether the model is a classification model
        :rtype: bool
        """
        pass


class ScorableModelRegression(ScorableModel):
    """Regression scorable model. Raw predictions are the output of `clf.predict`, post-processed by overrides."""

    @staticmethod
    def is_classification():
        return False

    @staticmethod
    def _get_empty_preds():
        """:return: an empty float64 array of shape (0,), used when there is no row to score"""
        return np.empty(0, dtype=float)

    def _get_overrides_applier(self):
        return RegressionMlOverridesApplier(self.overrides_params)

    def compute_predictions(self, X, unprocessed_df):
        """
        Predict `X`, skipping the estimator entirely when `X` has no row.

        :param np.ndarray X:
        :param pd.DataFrame unprocessed_df: for model overrides
        :rtype: PredictionResult
        """
        has_rows = X.shape[0] > 0
        predictions = self.clf.predict(X) if has_rows else self._get_empty_preds()
        return self._apply_post_predict(PredictionResult(predictions), unprocessed_df)


class ScorableModelRegressionWithPredictionIntervals(ScorableModelRegression):
    """Regression scorable model that also returns a symmetric prediction interval around each prediction."""

    def __init__(self, clf, algorithm, prediction_intervals_model, overrides_params=None):
        """
        :type clf: sklearn.base.BaseEstimator
        :type algorithm: str
        :type prediction_intervals_model: PredictionIntervalsModel
        :type overrides_params: dataiku.doctor.prediction.overrides.ml_overrides_params.MlOverridesParams
        """
        super(ScorableModelRegressionWithPredictionIntervals, self).__init__(clf, algorithm, overrides_params)
        self.prediction_intervals_model = prediction_intervals_model

    def compute_predictions(self, X, unprocessed_df):
        """
        Predict `X` together with intervals [pred - size, pred + size]. The per-row `size` is returned by the
        prediction intervals model.

        :type X: np.ndarray
        :param pd.DataFrame unprocessed_df: for model overrides
        :return: a result whose prediction intervals have shape (n_rows, 2) (lower bound, upper bound)
        :rtype: PredictionResult
        """
        if X.shape[0] <= 0:
            raw_preds = self._get_empty_preds()
            # Like clf.predict, the intervals model is not queried on an empty input.
            prediction_intervals = np.empty((0, 2))
        else:
            raw_preds = self.clf.predict(X)
            intervals_sizes = self.prediction_intervals_model.predict(X)
            prediction_intervals = np.column_stack((raw_preds - intervals_sizes, raw_preds + intervals_sizes))
        prediction_result_with_intervals = PredictionResult(raw_preds, prediction_intervals)
        return self._apply_post_predict(prediction_result_with_intervals, unprocessed_df)


class ScorableModelClustering(ScorableModel):
    """
    Clustering scorable model. Predictions are whatever `clf.predict` returns, or an empty int array for an empty
    input. Overrides are not supported.
    """

    @staticmethod
    def is_classification():
        return False

    def _get_overrides_applier(self):
        raise NotImplementedError("Overrides are not supported for Clustering model")

    def compute_predictions(self, X, unprocessed_df):
        """
        :param np.ndarray X:
        :param pd.DataFrame unprocessed_df: for model overrides
        :rtype: PredictionResult
        """
        if X.shape[0] != 0:
            labels = self.clf.predict(X)
        else:
            labels = np.empty(0, dtype=int)
        return self._apply_post_predict(PredictionResult(labels), unprocessed_df)


@add_metaclass(ABCMeta)
class ScorableModelClassification(ScorableModel):
    """
    Base class for classification scorable models.

    The wrapped estimator works on remapped integer targets. The mapping from the original target values is read
    from `preprocessing_params["target_remapping"]`.
    """

    def __init__(self, clf, algorithm, preprocessing_params, override_params=None):
        """
        :type clf: sklearn.base.BaseEstimator
        :type algorithm: str
        :param dict preprocessing_params: must contain a "target_remapping" list of entries with "sourceValue" and
                                          "mappedValue" keys
        :param MlOverridesParams or None override_params: model overrides (this parameter name is kept for
                                                          backward compatibility)
        :raises ValueError: if preprocessing_params is None
        """
        super(ScorableModelClassification, self).__init__(clf, algorithm, override_params)
        if preprocessing_params is None:
            raise ValueError("Scorable models need `preprocessing_params` for classification tasks")
        # original target value -> remapped value
        self.target_map = {tv["sourceValue"]: tv["mappedValue"] for tv in preprocessing_params["target_remapping"]}
        # remapped value -> original target value
        self._inv_map = {v: k for k, v in self.target_map.items()}
        # original target values, ordered by their remapped value
        self._classes = [label for (_, label) in sorted(self._inv_map.items(), key=lambda t: t[0])]

    @staticmethod
    def is_classification():
        return True

    def get_classes(self):
        """:return: original target values, ordered by their remapped value"""
        return self._classes

    def compute_predictions(self, X, unprocessed_df):
        """
        :param np.ndarray X:
        :param pd.DataFrame unprocessed_df: for model overrides
        :rtype: ClassificationPredictionResult
        """
        raw_prediction_result = self._compute_raw_prediction_result(X)
        return self._apply_post_predict(raw_prediction_result, unprocessed_df)

    def _get_overrides_applier(self):
        if self.is_proba_aware():
            return ProbabilisticClassificationMlOverridesApplier(self.overrides_params, self.target_map, self._classes)
        else:
            return NonProbabilisticClassificationMlOverridesApplier(self.overrides_params)

    def _compute_raw_prediction_result(self, X):
        """
        :type X: np.ndarray
        :rtype: ClassificationPredictionResult
        """
        unmapped_preds = self._compute_raw_predictions(X)
        probas = self._compute_raw_probas(X)
        return ClassificationPredictionResult(self.target_map, probas=probas, unmapped_preds=unmapped_preds)

    def _get_empty_probas(self):
        """:return: a (0, number of classes) array, used when there is no row to score"""
        return np.empty((0, len(self._classes)))

    @staticmethod
    def _get_empty_predictions():
        """:return: an empty int array, used when there is no row to score"""
        return np.empty((0,)).astype(int)

    def _compute_raw_predictions(self, X):
        """
        :type X: np.ndarray
        :return: the remapped (integer) predicted class of each row
        :rtype: np.ndarray
        """
        if X.shape[0] <= 0:
            return self._get_empty_predictions()
        # Since targets have been remapped to integers before, the `clf.predict` method should
        # return integers. But since custom models may store these integers in an array with
        # dtype=float, we add the `astype(int)`.
        return self.clf.predict(X).astype(int)

    def _compute_raw_probas(self, X):
        """
        :type X: np.ndarray
        :return: None when the model is not probability-aware. Otherwise a numpy array of shape
                 (X.shape[0], len(self.target_map)) where column k holds the probability of remapped class k, and
                 classes absent from `clf.classes_` get a probability of 0.
        :rtype: np.ndarray or None
        """
        if not self.is_proba_aware():
            return None

        if X.shape[0] <= 0:
            return self._get_empty_probas()

        raw_probas = self.clf.predict_proba(X)

        (nb_rows, nb_present_classes) = raw_probas.shape
        probas = np.zeros((nb_rows, len(self.target_map)))
        for j in range(nb_present_classes):
            # Column j of predict_proba holds the probability of clf.classes_[j]. As with predictions, custom models
            # may store the remapped class ids as floats, which numpy refuses as an index, hence the int() cast.
            actual_class_id = int(self.clf.classes_[j])
            probas[:, actual_class_id] = raw_probas[:, j]
        return probas


class ScorableModelMulticlass(ScorableModelClassification):
    """Multiclass classification scorable model. All behavior is inherited from ScorableModelClassification."""


@add_metaclass(ABCMeta)
class ScorableModelBinary(ScorableModelClassification):
    """Abstract base for binary classification scorable models that carry a decision threshold."""

    def __init__(self, clf, algorithm, preprocessing_params, overrides_params=None, threshold=0.5):
        super(ScorableModelBinary, self).__init__(clf, algorithm, preprocessing_params, overrides_params)
        self.threshold = threshold

    def get_threshold(self):
        """:return: the current decision threshold"""
        return self.threshold

    def set_threshold(self, threshold):
        """Replace the decision threshold used by subsequent predictions."""
        self.threshold = threshold

    @abstractmethod
    def compute_decisions_and_cuts(self, X, unprocessed_df):
        """
        Compute the (unmapped) predicted binary classes for each cut.

        Implementations need not handle empty data. Decisions and cuts only serve to measure the model's
        performance on actual data, so there is always at least one row.

        :type X: np.ndarray
        :param pd.DataFrame unprocessed_df: for model overrides
        :rtype: DecisionsAndCuts
        """


class NonProbabilisticScorableModelBinary(ScorableModelBinary):
    """Binary classification model without probabilities. Its predictions do not depend on any cut."""

    def compute_decisions_and_cuts(self, X, unprocessed_df):
        default_cuts = DecisionsAndCuts.get_default_cuts()
        # Predictions are computed once. The very same result object is deliberately reused for every cut
        # (no copy), so mutating one entry would affect all of them.
        shared_result = self.compute_predictions(X, unprocessed_df)
        return DecisionsAndCuts(default_cuts, [shared_result for _cut in default_cuts])


class ProbabilisticScorableModelBinary(ScorableModelBinary):
    """
    Binary classification model that predicts class 1 when the probability of column 1 is strictly greater than
    the threshold.
    """

    def _get_overrides_applier(self):
        """
        :rtype: BinaryProbabilisticClassificationMlOverridesApplier
        """
        return BinaryProbabilisticClassificationMlOverridesApplier(self.overrides_params, self.target_map,
                                                                   self._classes)

    def _compute_raw_prediction_result_from_raw_probas(self, raw_probas):
        """
        Threshold the probabilities with the current `self.threshold`.

        :type raw_probas: np.ndarray
        :rtype: ClassificationPredictionResult
        """
        if raw_probas.shape[0] == 0:
            raw_unmapped_preds = self._get_empty_predictions()
        else:
            raw_unmapped_preds = (raw_probas[:, 1] > self.threshold).astype(int)
        return ClassificationPredictionResult(self.target_map, probas=raw_probas, unmapped_preds=raw_unmapped_preds)

    def compute_predictions(self, X, unprocessed_df):
        """
        :param np.ndarray X:
        :param pd.DataFrame unprocessed_df: for model overrides
        :rtype: ClassificationPredictionResult
        """
        raw_probas = self._compute_raw_probas(X)
        raw_prediction_result = self._compute_raw_prediction_result_from_raw_probas(raw_probas)
        return self._apply_post_predict(raw_prediction_result, unprocessed_df)

    def compute_decisions_and_cuts(self, X, unprocessed_df):
        """
        Threshold the probabilities at every default cut, then apply model overrides if any.

        NOTE: `self.threshold` is temporarily overwritten while cuts are computed, so concurrent calls on the same
        instance would race on it.

        :type X: np.ndarray
        :param pd.DataFrame unprocessed_df: for model overrides
        :rtype: DecisionsAndCuts
        """
        raw_probas = self._compute_raw_probas(X)
        original_threshold = self.threshold
        raw_prediction_results = []
        cuts = DecisionsAndCuts.get_default_cuts()
        try:
            for cut in cuts:
                # _compute_raw_prediction_result_from_raw_probas reads self.threshold
                self.threshold = cut
                raw_prediction_results.append(self._compute_raw_prediction_result_from_raw_probas(raw_probas))
        finally:
            # Always restore the configured threshold, even if computing one of the cuts fails.
            self.threshold = original_threshold
        raw_decisions_and_cuts = DecisionsAndCuts(cuts, raw_prediction_results)

        if not self.has_overrides():
            return raw_decisions_and_cuts
        else:
            return self._get_overrides_applier().apply_on_decisions_and_cuts(unprocessed_df, raw_decisions_and_cuts)


def is_empty_keras_model_input(X):
    """
    Tell whether a Keras model input contains no rows.

    :param X: named inputs of the Keras model
    :type X: dict[str, np.ndarray]
    :return: True when the first input array has zero rows
    :rtype: bool
    """
    # Only the first input is inspected; presumably all named inputs share the same number of rows (TODO confirm).
    first_input = list(X.values())[0]
    # Compare the row count, not the shape tuple: `shape == 0` compares a tuple with an int and is always False.
    return first_input.shape[0] == 0


class KerasScorableModelRegression(ScorableModelRegression):
    """Regression scorable model backed by a Keras network fed with a dict of named inputs."""

    def compute_predictions(self, X, unprocessed_df):
        """
        :param X: for Keras model, `X` represents the data sent to the model, i.e. a dict of named inputs. This does
                   not respect the API of the ScorableModel, and we should probably change that at some point
        :type X: dict[str, np.ndarray]
        :param pd.DataFrame unprocessed_df: for model overrides
        :rtype: PredictionResult
        """
        if not is_empty_keras_model_input(X):
            # The network output is expected to have a single column (squeeze on axis 1 requires it);
            # drop that axis to get one prediction per row.
            predictions = np.squeeze(self.clf.predict(X), axis=1)
        else:
            predictions = self._get_empty_preds()
        return self._apply_post_predict(PredictionResult(predictions), unprocessed_df)


class KerasScorableModelMulticlass(ScorableModelMulticlass):
    """Multiclass scorable model backed by a Keras network whose output is used directly as class probabilities."""

    def _compute_raw_prediction_result(self, X):
        """
        :type X: dict[str, np.ndarray]
        :param X: for Keras model, `X` represents the data sent to the model, i.e. a dict of named inputs. This does
                   not respect the API of the ScorableModel, and we should probably change that at some point
        :rtype: ClassificationPredictionResult
        """
        if not is_empty_keras_model_input(X):
            class_probas = self.clf.predict(X)
            # The predicted (remapped) class of each row is the column with the highest probability.
            predicted_ids = np.argmax(class_probas, axis=1)
        else:
            class_probas = self._get_empty_probas()
            predicted_ids = self._get_empty_predictions()
        return ClassificationPredictionResult(self.target_map, probas=class_probas, unmapped_preds=predicted_ids)


class KerasScorableModelBinary(ProbabilisticScorableModelBinary):
    """Binary scorable model backed by a Keras network with either one or two output columns."""

    def _compute_raw_probas(self, X):
        """
        :type X: dict[str, np.ndarray]
        :param X: for Keras model, `X` represents the data sent to the model, i.e. a dict of named inputs. This does
                   not respect the API of the ScorableModel, and we should probably change that at some point
        :return: the network output as-is when it does not have exactly one column. Otherwise a float64
                 (n_rows, 2) array [1 - p, p], where p is that single column.
        :rtype: np.ndarray
        """
        if is_empty_keras_model_input(X):
            return self._get_empty_probas()
        network_output = self.clf.predict(X)
        if network_output.shape[1] != 1:
            return network_output
        # One-dimensional output: build the missing class-0 column from the class-1 probability.
        positive = np.squeeze(network_output, axis=1)
        two_columns = np.zeros((positive.shape[0], 2))
        two_columns[:, 1] = positive
        two_columns[:, 0] = 1 - positive
        return two_columns
