from __future__ import unicode_literals

import logging
import os
from abc import ABCMeta
from abc import abstractmethod

import numpy as np
import pandas as pd
from six import add_metaclass

from dataiku import default_project_key
from dataiku import jek_or_backend_json_call
from dataiku.base.utils import safe_unicode_str
from dataiku.core import dkujson
from dataiku.core.dataframe_preparation import prepare_dataframe
from dataiku.doctor.prediction.decisions_and_cuts import DecisionsAndCuts
from dataiku.doctor.prediction.overrides.ml_overrides_results import DEFAULT_OVERRIDE_VALUE
from dataiku.doctor.prediction.overrides.ml_overrides_results import OverriddenClassificationPredictionResult
from dataiku.doctor.prediction.overrides.ml_overrides_results import OverriddenPredictionResults
from dataikuscoring.utils.prediction_result import AbstractPredictionResult
from dataikuscoring.utils.prediction_result import ClassificationPredictionResult
from dataikuscoring.utils.prediction_result import PREDICTION
from dataikuscoring.utils.prediction_result import PREDICTION_INTERVAL_LOWER
from dataikuscoring.utils.prediction_result import PREDICTION_INTERVAL_UPPER
from dataikuscoring.utils.prediction_result import PredictionResult

logger = logging.getLogger(__name__)

# Temporary column written by the backend-generated prepare steps: for each row, the name of the override that applies
APPLICABLE_OVERRIDE_COL = "__DKU__APPLICABLE__OVERRIDE__"

# Extra columns computed from the model output so that override filters can reference them
UNCERTAINTY_COL = "prediction_uncertainty"  # classification: 1 - highest class probability
PREDICTION_INTERVAL_SIZE_COL = "prediction_interval_size"  # regression: upper bound - lower bound
PREDICTION_INTERVAL_RELATIVE_SIZE_COL = "prediction_interval_relative_size"  # regression: interval size / prediction


@add_metaclass(ABCMeta)
class MlOverridesApplier(object):
    """
    Base class applying model overrides on top of a prediction result.

    For each row, the first override (in declaration order) whose filter matches is selected, and its outcome is applied
    to the predictions of the rows it selected. Subclasses decide which model outputs the filters can see and how the
    overridden predictions are wrapped into an overridden result.
    """

    def __init__(self, overrides_params):
        """
        :type overrides_params: MlOverridesParams
        """
        self._overrides_params = overrides_params
        self._overrides_map = {override.name: override for override in self._overrides_params.overrides}
        self._overrides_names = [o.name for o in self._overrides_params.overrides]

    def apply(self, input_df, prediction_result):
        """
        Apply the overrides to `prediction_result`, evaluating the override filters on `input_df` enriched with the
        model outputs.

        :type input_df: pd.DataFrame
        :type prediction_result: AbstractPredictionResult
        :rtype: dataiku.doctor.prediction.overrides.ml_overrides_results.OverriddenPredictionResults
        """
        if prediction_result.is_empty():
            return self._get_empty_overridden_result(prediction_result)
        prepared_df = self._prepare_df_for_overrides(input_df, prediction_result)
        override_flags_series = self._compute_overrides_flags(prepared_df)
        overridden_prediction_result = self._apply_overrides(prediction_result, override_flags_series)
        return overridden_prediction_result

    def _apply_single_outcome(self, preds_for_single_override, override_name):
        """
        Apply the override called `override_name` on the `preds_for_single_override`.
        Rows flagged with DEFAULT_OVERRIDE_VALUE (no override applies) are returned unchanged.

        :type preds_for_single_override: np.ndarray
        :type override_name: str
        :rtype: np.ndarray or str
        """
        if override_name == DEFAULT_OVERRIDE_VALUE:
            return preds_for_single_override
        return self._overrides_map[override_name].outcome.apply(preds_for_single_override)

    def _override_preds(self, prediction_result, override_flags_series):
        """
        Apply, per group of rows sharing the same override flag, the outcome of that override on the predictions.

        :type prediction_result: AbstractPredictionResult
        :type override_flags_series: pd.Series
        :return: Tuple with the overridden predictions and a mask that is True where the overridden prediction is NA
                 (i.e. declined)
        :rtype: (np.ndarray, np.ndarray)
        """
        preds_series = pd.Series(prediction_result.preds, index=override_flags_series.index)
        # Inside transform, `series.name` is the group key, i.e. the name of the applicable override
        overridden_preds = (preds_series.groupby(override_flags_series).
                            transform(lambda series: self._apply_single_outcome(series.values, series.name)))
        return overridden_preds.values, pd.isna(overridden_preds.values)

    def _compute_overrides_flags(self, prepared_df):
        """
        Compute, for each row of `prepared_df`, the name of the override that applies to it.

        The filters are turned into prepare steps by the backend, then run locally on `prepared_df`. Rows matched by no
        filter get DEFAULT_OVERRIDE_VALUE (sent as `defaultOverrideValue` in the request).

        :type prepared_df: pd.DataFrame
        :rtype: pd.Series
        """
        # If we have two overrides A and B, we want A to be considered over B if they both apply to the same row.
        # Applying multiple filters and outputting on the same column will give priority to the LAST filter that
        # applied per row. Therefore, we want to reverse the overrides we pass so A will have priority over B
        request = {
            "outputColumn": APPLICABLE_OVERRIDE_COL,
            "defaultOverrideValue": DEFAULT_OVERRIDE_VALUE,
            "filters": [{
                "name": c.name,
                "filter": c.filter
            } for c in reversed(self._overrides_params.overrides)]
        }

        script_data = jek_or_backend_json_call("ml/generate-steps-from-filters/", data={
            "request": dkujson.dumps(request)
        })
        steps = script_data["steps"]
        output_schema = script_data["outputSchema"]

        context_project_key = _get_default_project_key_or_fail_apinode()
        overrides_flags = prepare_dataframe(prepared_df, steps, output_schema,
                                            context_project_key, infer_with_pandas=False)
        # Applying the same index to the newly prepared dataframe
        overrides_flags.set_index(prepared_df.index, inplace=True)
        return overrides_flags[APPLICABLE_OVERRIDE_COL]

    @staticmethod
    def _get_empty_overrides_flags():
        """
        :rtype: pd.Series
        """
        return pd.Series(dtype=str)

    @staticmethod
    def _concat_input_with_extra_df(input_df, extra_input_df):
        """
        Concatenate column-wise the input data with the extra columns computed from the model output, so that the
        override filters can reference both.

        Columns of `input_df` sharing a name with a column of `extra_input_df` are dropped (with a warning), so the
        model-computed values take precedence. Both dataframes are expected to share the same index (concat aligns
        on it).

        :type input_df: pd.DataFrame
        :type extra_input_df: pd.DataFrame
        :rtype: pd.DataFrame
        """
        # Ordered list (rather than a set) so that the warning and the dropped columns are deterministic
        reserved_columns_in_input_df = [column for column in extra_input_df.columns.unique()
                                        if column in input_df.columns]
        if reserved_columns_in_input_df:
            logger.warning(u"The columns %s are found in the input dataset, but will be overwritten by the ones "
                           u"computed by the model to evaluate the overrides rules", reserved_columns_in_input_df)
            # This will make a copy (and preserve the original input_df)
            input_df = input_df.drop(reserved_columns_in_input_df, axis=1)
        return pd.concat([input_df, extra_input_df], axis=1)

    def _prepare_df_for_overrides(self, input_df, prediction_result):
        """
        Build the dataframe on which the override filters are evaluated: input columns plus prediction columns.

        :type input_df: pd.DataFrame
        :type prediction_result: AbstractPredictionResult
        :rtype: pd.DataFrame
        """
        preds_df = prediction_result.as_dataframe()
        preds_df.index = input_df.index
        return self._concat_input_with_extra_df(input_df, preds_df)

    @abstractmethod
    def _apply_overrides(self, prediction_result, override_flags_series):
        """
        Build the overridden result from the per-row override flags.

        :type prediction_result: AbstractPredictionResult
        :type override_flags_series: pd.Series
        :rtype: dataiku.doctor.prediction.overrides.ml_overrides_results.AbstractOverridesResults
        """

    @abstractmethod
    def _get_empty_overridden_result(self, prediction_result):
        """
        Build the overridden result for an empty prediction result.

        :type prediction_result: AbstractPredictionResult
        :rtype: dataiku.doctor.prediction.overrides.ml_overrides_results.AbstractOverridesResults
        """


class ProbabilisticClassificationMlOverridesApplier(MlOverridesApplier):
    """
    Overrides applier for classification models that output class probabilities.

    Override filters can use the predicted class, one `proba_<class>` column per class and the prediction uncertainty
    (1 - highest probability). For overridden rows, probabilities are replaced by a one-hot vector on the new predicted
    class, or by NaN everywhere when the override declines the prediction.
    """

    def __init__(self, overrides_params, target_map, classes):
        """
        :type overrides_params: MlOverridesParams
        :param target_map: maps each class label to its column index in the probabilities array
        :type target_map: dict[str, int]
        :type classes: list[str]
        """
        super(ProbabilisticClassificationMlOverridesApplier, self).__init__(overrides_params)
        self._target_map = target_map
        self._classes = classes

    @staticmethod
    def _build_pred_and_proba_df(prediction_result, index):
        """
        Dataframe of the prediction result, with flattened probability column names and an uncertainty column.

        :type prediction_result: ClassificationPredictionResult
        :param index: pandas index to set on the returned dataframe
        :rtype: pd.DataFrame
        """
        df = prediction_result.as_dataframe()
        # df comes with columns as ("probabilities", className) for probabilities,
        # need to remap them to "proba_className"
        # we also need the uncertainty column for future overrides computations
        proba_columns = [column for column in df.columns if isinstance(column, tuple) and column[0] == "probabilities"]
        df[UNCERTAINTY_COL] = 1 - df[proba_columns].max(axis=1)
        df.rename(columns={column: u"proba_{}".format(safe_unicode_str(column[1])) for column in proba_columns},
                  inplace=True)
        df.index = index
        return df

    def _prepare_df_for_overrides(self, input_df, prediction_result):
        """
        Note: adding prediction is useless for binary classification, as it cannot rely on prediction, but it is
        harmless.

        :type input_df: pd.DataFrame
        :type prediction_result: ClassificationPredictionResult
        :rtype: pd.DataFrame
        """
        pred_and_proba_df = self._build_pred_and_proba_df(prediction_result, input_df.index)
        return self._concat_input_with_extra_df(input_df, pred_and_proba_df)

    def _get_empty_overridden_result(self, prediction_result):
        """
        :type prediction_result: ClassificationPredictionResult
        :rtype: dataiku.doctor.prediction.overrides.ml_overrides_results.OverriddenClassificationPredictionResult
        """
        return OverriddenClassificationPredictionResult(prediction_result, self._get_empty_overrides_flags(),
                                                        self._overrides_names, prediction_result.preds,
                                                        probas=prediction_result.probas)

    def _apply_overrides(self, prediction_result, override_flags_series):
        """
        Override predictions, then rewrite the probabilities of every overridden row.

        :type prediction_result: ClassificationPredictionResult
        :type override_flags_series: pd.Series
        :rtype: OverriddenClassificationPredictionResult
        :raises ValueError: if a (non-declined) overridden prediction is not a class of the model
        """
        overridden_preds, declined_mask = self._override_preds(prediction_result, override_flags_series)
        overridden_probas = prediction_result.probas.copy()

        # Plain numpy bool mask, so numpy arrays are indexed positionally without relying on pandas coercion
        to_override = np.asarray(override_flags_series != DEFAULT_OVERRIDE_VALUE, dtype=bool)
        if to_override.any():
            overridden_probas[to_override, :] = self._one_hot_probas(overridden_preds[to_override],
                                                                    declined_mask[to_override],
                                                                    overridden_probas.shape[1])

        return OverriddenClassificationPredictionResult(prediction_result, override_flags_series, self._overrides_names,
                                                        overridden_preds, probas=overridden_probas,
                                                        declined_mask=declined_mask)

    def _one_hot_probas(self, preds, declined_mask, n_classes):
        """
        Build the probabilities of overridden rows: 1.0 on the predicted class column and 0.0 elsewhere, or NaN on
        every column for declined rows.

        :param preds: overridden predictions (class labels) of the overridden rows
        :type preds: np.ndarray
        :param declined_mask: True where the override declined the prediction
        :type declined_mask: np.ndarray
        :param n_classes: number of probability columns
        :type n_classes: int
        :rtype: np.ndarray
        :raises ValueError: if a non-declined prediction is not a key of the target map
        """
        probas = np.zeros((preds.shape[0], n_classes))
        kept_rows = np.flatnonzero(~declined_mask)
        # Only map non-declined predictions: declined ones are NA and would map to NaN, which cannot be cast to int
        class_indices = pd.Series(preds[kept_rows]).map(self._target_map)
        unknown_rows = class_indices.isna().values
        if unknown_rows.any():
            unknown_classes = sorted(set(safe_unicode_str(pred) for pred in preds[kept_rows][unknown_rows]))
            raise ValueError(u"Overridden predictions are not classes of the model: {}".format(
                u", ".join(unknown_classes)))
        probas[kept_rows, class_indices.values.astype(int)] = 1.
        probas[declined_mask] = np.nan
        return probas


def _get_default_project_key_or_fail_apinode():
    """
    Return the current project key, refusing to do so on an API node.

    API nodes have no concept of project, so there is no meaningful "current project" there.

    :raises ValueError: when DKU_NODE_TYPE is "api" or DKU_LAMBDA_DEVSERVER is set
    """
    # TODO @hippocrates: This condition is pretty hacky, you should have DKU_NODE_TYPE == 'api' for the dev
    #   server as well, but not working at the moment. Anyway, this code is meant to go away soon
    #   (please don't be there in 2023)
    running_on_api_node = os.environ.get("DKU_NODE_TYPE") == "api"
    running_on_lambda_devserver = os.environ.get("DKU_LAMBDA_DEVSERVER") is not None
    if not (running_on_api_node or running_on_lambda_devserver):
        return default_project_key()
    raise ValueError("Using Model overrides with Python is not supported on the API node")


class RegressionMlOverridesApplier(MlOverridesApplier):
    """
    Overrides applier for regression models.

    When the model provides prediction intervals, override filters can also use the interval size and relative size,
    and the intervals of declined rows are blanked out (NaN).
    """

    def _prepare_df_for_overrides(self, input_df, prediction_result):
        """
        Build the dataframe on which the override filters are evaluated: input columns plus prediction columns,
        enriched with interval size columns when prediction intervals are available.

        :type input_df: pd.DataFrame
        :type prediction_result: PredictionResult
        :rtype: pd.DataFrame
        """
        predictions_df = prediction_result.as_dataframe()
        if prediction_result.has_prediction_intervals():
            self._add_interval_size_columns(predictions_df)
        predictions_df.index = input_df.index
        return self._concat_input_with_extra_df(input_df, predictions_df)

    @staticmethod
    def _add_interval_size_columns(predictions_df):
        """
        Add, in place, the interval size (upper - lower) and the relative interval size (size / prediction, whose sign
        follows the prediction; NaN where the prediction is 0).

        :type predictions_df: pd.DataFrame
        """
        width = predictions_df[PREDICTION_INTERVAL_UPPER] - predictions_df[PREDICTION_INTERVAL_LOWER]
        point_prediction = predictions_df[PREDICTION]
        predictions_df[PREDICTION_INTERVAL_SIZE_COL] = width
        predictions_df[PREDICTION_INTERVAL_RELATIVE_SIZE_COL] = np.where(point_prediction == 0, np.nan,
                                                                          width / point_prediction)

    def _apply_overrides(self, prediction_result, override_flags_series):
        """
        :type prediction_result: PredictionResult
        :type override_flags_series: pd.Series
        :rtype: OverriddenPredictionResults
        """
        preds, declined = self._override_preds(prediction_result, override_flags_series)
        intervals = self._nullify_declined_values(declined, prediction_result.prediction_intervals)
        return OverriddenPredictionResults(prediction_result, override_flags_series, self._overrides_names,
                                           preds, declined, intervals)

    def _get_empty_overridden_result(self, prediction_result):
        """
        :type prediction_result: PredictionResult
        :rtype: dataiku.doctor.prediction.overrides.ml_overrides_results.OverriddenPredictionResults
        """
        # NOTE(review): prediction intervals are not forwarded here, unlike in _apply_overrides — confirm intended
        empty_flags = self._get_empty_overrides_flags()
        return OverriddenPredictionResults(prediction_result, empty_flags, self._overrides_names,
                                           prediction_result.preds, declined_mask=None)

    @staticmethod
    def _nullify_declined_values(declined_mask, prediction_intervals):
        """
        Set the prediction intervals of declined rows to NaN. Intervals are returned untouched when there is no
        mask, and None when there are no intervals.

        :type declined_mask: np.ndarray or None
        :type prediction_intervals: np.ndarray or None
        :rtype: np.ndarray or None
        """
        if prediction_intervals is None or declined_mask is None:
            return prediction_intervals
        # Broadcast the per-row mask across the interval columns
        return np.where(declined_mask[:, np.newaxis], np.nan, prediction_intervals)


class NonProbabilisticClassificationMlOverridesApplier(MlOverridesApplier):
    """
    Overrides applier for classification models without probabilities: only the predicted class is overridden.
    """

    def _apply_overrides(self, prediction_result, override_flags_series):
        """
        :type prediction_result: ClassificationPredictionResult
        :type override_flags_series: pd.Series
        :rtype: OverriddenClassificationPredictionResult
        """
        preds, declined = self._override_preds(prediction_result, override_flags_series)
        return OverriddenClassificationPredictionResult(
            prediction_result,
            override_flags_series,
            self._overrides_names,
            preds,
            declined_mask=declined,
        )

    def _get_empty_overridden_result(self, prediction_result):
        """
        :type prediction_result: ClassificationPredictionResult
        :rtype: dataiku.doctor.prediction.overrides.ml_overrides_results.OverriddenClassificationPredictionResult
        """
        empty_flags = self._get_empty_overrides_flags()
        return OverriddenClassificationPredictionResult(prediction_result, empty_flags,
                                                        self._overrides_names, prediction_result.preds)


class BinaryProbabilisticClassificationMlOverridesApplier(ProbabilisticClassificationMlOverridesApplier):
    """
    Probabilistic overrides applier that can also override each of the prediction results held by a DecisionsAndCuts.
    """

    def apply_on_decisions_and_cuts(self, input_df, decisions_and_cuts):
        """
        Override every prediction result of `decisions_and_cuts`, keeping its cuts.

        Override flags are computed once, from the first prediction result, and reused for all of them.

        :type input_df: pd.DataFrame
        :type decisions_and_cuts: DecisionsAndCuts
        :rtype: DecisionsAndCuts
        :raises ValueError: if `decisions_and_cuts` is empty
        """
        if len(decisions_and_cuts) < 1:
            raise ValueError("Cannot override empty decision and cuts")

        prediction_results = decisions_and_cuts.get_prediction_results()
        reference_result = prediction_results[0]

        if reference_result.is_empty():
            # Assuming that they all will be empty then
            empty_results = [self._get_empty_overridden_result(result) for result in prediction_results]
            return DecisionsAndCuts(decisions_and_cuts.get_cuts(), empty_results)

        # Arbitrarily computing flags with the first prediction result, any would yield the same flags
        flags = self._compute_overrides_flags(self._prepare_df_for_overrides(input_df, reference_result))

        # Each result gets its own copy of the flags, just to be extra safe
        overridden_results = [self._apply_overrides(result, flags.copy()) for result in prediction_results]
        return DecisionsAndCuts(decisions_and_cuts.get_cuts(), overridden_results)
