import logging
import math
from abc import ABCMeta
from abc import abstractmethod

import pandas as pd
from six import add_metaclass

from dataiku.base.folder_context import FmiReadonlyFolderContexts
from dataiku.base.folder_context import build_folder_context
from dataiku.base.utils import RaiseWithTraceback
from dataiku.core import doctor_constants
from dataiku.core.doctor_constants import PREPROC_ONEVALUE
from dataiku.core.intercom import backend_void_call
from dataiku.core.saved_model import build_predictor
from dataiku.doctor.prediction.classification_scoring import BinaryClassificationModelScorer
from dataiku.doctor.prediction.decisions_and_cuts import DecisionsAndCuts
from dataiku.doctor.prediction.overrides.ml_overrides_params import ml_overrides_params_from_model_folder
from dataikuscoring.utils.prediction_result import AbstractPredictionResult
from dataiku.doctor.prediction.regression_scoring import RegressionModelScorer
from dataiku.doctor.preprocessing.dataframe_preprocessing import DkuDroppedMultiframeException
from dataiku.doctor.preprocessing_handler import PredictionPreprocessingHandler
from dataiku.doctor.utils.split import df_from_split_desc
from dataikuscoring.utils.scoring_data import ScoringData

logger = logging.getLogger(__name__)


@add_metaclass(ABCMeta)
class ModelInformationHandlerBase(object):
    """
    Common interface for objects exposing a trained model's parameters, data and scoring helpers.

    Besides the abstract methods declared below, the concrete helpers of this class call
    ``get_preprocessing_params()`` and ``get_collector_data()`` and read a ``_predictor`` attribute. None of
    them is declared here, so subclasses are expected to provide them.
    """

    @abstractmethod
    def use_full_df(self):
        """Tell whether the full dataframe should be used (the criterion is defined by each implementation)."""
        pass

    @abstractmethod
    def get_sample_weight_variable(self):
        """Return the sample weight variable of the model (implementation-specific lookup)."""
        pass

    @abstractmethod
    def get_output_folder_context(self):
        """
        Return the folder context in which outputs are to be written.

        :rtype: dataiku.base.folder_context.FolderContext
        """
        pass

    @abstractmethod
    def get_model_folder_context(self):
        """
        Return the folder context of the model. ``compute_global_explanations`` writes its JSON results there.

        :rtype: dataiku.base.folder_context.FolderContext
        """
        pass

    @abstractmethod
    def get_schema(self):
        """Return the schema of the data the model works on."""
        pass

    @abstractmethod
    def get_target_variable(self):
        """Return the target variable of the model."""
        pass

    @abstractmethod
    def get_prediction_type(self):
        """Return the prediction type of the model."""
        pass

    @abstractmethod
    def get_type_of_column(self, col_name):
        """
        Return the type of the column ``col_name``.

        :param col_name: the name of the column
        """
        pass

    @abstractmethod
    def run_binary_scoring(self, df):
        """
        Score the model as a binary classifier on ``df``.

        :param df: the data to score
        """
        pass

    @abstractmethod
    def run_regression_scoring(self, df):
        """
        Score the model as a regressor on ``df``.

        :param df: the data to score
        """
        pass

    @abstractmethod
    def get_full_df(self):
        """Return the full data of the model."""
        pass

    def get_per_feature(self):
        """Return the ``"per_feature"`` entry of the preprocessing params (None if the key is absent)."""
        return self.get_preprocessing_params().get("per_feature")

    def get_per_feature_col(self, col_name):
        """
        Return the per-feature preprocessing params of a single column.

        :param col_name: the name of the column
        :raises ValueError: if ``col_name`` has no entry in the per-feature params
        """
        per_feature = self.get_per_feature()
        if col_name not in per_feature:
            raise ValueError("Column '{}' not found".format(col_name))
        return per_feature[col_name]

    def is_column_dummified(self, col_name):
        """
        Tell whether the ``"category_handling"`` of the column is ``"DUMMIFY"``.

        :param col_name: the name of the column (a ``KeyError`` is raised if it is not in the per-feature params)
        """
        return self.get_per_feature()[col_name].get("category_handling") == "DUMMIFY"

    def category_possible_values(self, col_name):
        """
        Get the list of modalities which are dummified by the preprocessing for the given column
        :param col_name: the name of the column
        :return: None if the column is not dummified else it returns the list of modalities that are dummified.
                 The returned list is a fresh copy: callers may modify it freely.
        """
        if not self.is_column_dummified(col_name):
            return None
        collected_values = self.get_collector_data()["per_feature"][col_name].get("category_possible_values")
        missing_handling = self.get_per_feature_col(col_name).get("missing_handling")
        # Work on a copy: the list belongs to the collector data (which an implementation may cache and share),
        # so appending to it in place would add one more FILL_NA_VALUE entry on every call.
        possible_values = None if collected_values is None else list(collected_values)
        if missing_handling == "NONE":  # Treat as regular value
            possible_values = (possible_values or []) + [doctor_constants.FILL_NA_VALUE]
        return possible_values

    def compute_global_explanations(
            self,
            df,
            n_explanations=4,
            progress=None,
            mc_steps=100,
            num_samples=150
    ):
        """
        Compute global explanations on a sample of ``df`` and write them as JSON in the model folder.

        Three files are written through the model folder context:
        ``global_explanations_absolute_importance.json``, ``global_explanations_observations.json`` and
        ``global_explanations_facts.json``.

        If the explainer cannot be made ready, a warning is logged and nothing is computed or written.

        :param df: the data to explain; at most ``num_samples`` rows are sampled from it (fixed random seed)
        :param n_explanations: forwarded to the explainer's ``explain_global``
        :param progress: forwarded to the explainer's ``explain_global``
        :param mc_steps: forwarded to ``explain_global`` as ``shapley_background_size``
        :param num_samples: maximum number of rows of ``df`` that are explained
        :return: None
        """
        if not self._predictor._individual_explainer.is_ready():
            self._predictor.ready_explainer()
        if not self._predictor._individual_explainer.is_ready():
            logger.warning("Could not prepare explainer")
            return

        # Fixed seed so that the same rows are picked from one run to the next
        df = df.sample(min(num_samples, len(df)), random_state=1337)

        abs_avg_explanations, explanations_dict = self._predictor._individual_explainer.explain_global(
            df,
            n_explanations,
            progress=progress,
            shapley_background_size=mc_steps,
        )

        observations = df.fillna("").to_dict(orient="list")
        model_folder_context = self.get_model_folder_context()

        absolute_importance = {
            "absoluteImportance": abs_avg_explanations
        }

        explanations_and_observations = {
            "explanations": explanations_dict,
            "observations": observations,
        }

        model_folder_context.write_json("global_explanations_absolute_importance.json", absolute_importance)
        model_folder_context.write_json("global_explanations_observations.json", explanations_and_observations)

        per_class_facts = self._predictor._individual_explainer.get_model_per_class_facts(explanations_dict,
                                                                                          observations)

        facts = {
            "perClassFacts": per_class_facts
        }

        model_folder_context.write_json("global_explanations_facts.json", facts)


@add_metaclass(ABCMeta)
class PredictionModelInformationHandlerBase(ModelInformationHandlerBase):
    """Abstract contract adding explainer access, test data access and prediction to the base handler."""

    @abstractmethod
    def get_explainer(self):
        """Return the explainer object attached to the model."""

    @abstractmethod
    def get_test_df(self):
        """Return the test data of the model."""

    @abstractmethod
    def predict(self, df, output_probas=True):
        """
        Compute predictions for ``df``.

        :param df: the data to predict on
        :param output_probas: flag controlling probability outputs (True by default); its exact effect is
                              defined by each implementation
        """


# WARNING: this class is used for plugin development (trained and saved models views).
# Beware not to make breaking changes
class PredictionModelInformationHandler(PredictionModelInformationHandlerBase):
    """
    Handler giving access to the parameters, data, predictor and scoring of a trained prediction model.

    Parameters are read from JSON files at construction time; dataframes, collector data, preprocessing
    handler and pipeline are loaded lazily and cached on the instance.
    """

    def __init__(self, split_desc, core_params, preprocessing_folder_context, model_folder_context,
                 split_folder_context, postcompute_folder_context=None, train_split_desc=None,
                 train_split_folder_context=None, fmi=None):
        """
        :param split_desc: description of the split (dict-like, read with ``["params"]`` and ``.get``)
        :param core_params: core params of the model (dict-like)
        :param preprocessing_folder_context: folder context from which ``rpreprocessing_params.json`` and
                                             ``collector_data.json`` are read
        :param model_folder_context: folder context from which ``rmodeling_params.json`` is read
        :param split_folder_context: folder context used to load the split dataframes
        :param postcompute_folder_context: output folder context; defaults to the ``posttrain`` subfolder of the
                                           model folder when falsy
        :param train_split_desc: description of the train split; defaults to ``split_desc`` when falsy
        :param train_split_folder_context: forwarded to ``build_predictor``
        :param fmi: forwarded to ``build_predictor``
        """
        self._split_desc = split_desc
        self._train_split_desc = train_split_desc if train_split_desc else split_desc
        self._core_params = core_params
        self._preprocessing_folder_context = preprocessing_folder_context
        self._model_folder_context = model_folder_context
        self._split_folder_context = split_folder_context
        self._preprocessing_params = self._preprocessing_folder_context.read_json("rpreprocessing_params.json")
        self._modeling_params = self._model_folder_context.read_json("rmodeling_params.json")
        self._ml_overrides_params = ml_overrides_params_from_model_folder(self._model_folder_context)
        self._keras_scoring_batch_size = 100

        self._predictor = build_predictor(
            "PREDICTION",
            self._model_folder_context,
            self._preprocessing_folder_context,
            self._split_folder_context,
            [],  # no need for conditional outputs in this case
            self._core_params,
            self._split_desc,
            self._train_split_desc,
            train_split_folder_context,
            fmi=fmi
        )

        # Lazily loaded / built, see the corresponding getters
        self._collector_data = None
        self._preproc_handler = None
        self._pipeline = None
        self._train_df = None
        self._test_df = None
        self._full_df = None
        if not postcompute_folder_context:
            self._postcompute_folder_context = self._model_folder_context.get_subfolder_context("posttrain")
        else:
            self._postcompute_folder_context = postcompute_folder_context

    @staticmethod
    def from_full_model_id(fmi):
        """
        Build a handler from a full model id.

        Calls the backend endpoint ``savedmodels/grant-fs-read-acls`` for the model, then reads the split
        description and ``core_params.json`` through read-only folder contexts built from ``fmi``.

        :param fmi: the full model id
        :rtype: PredictionModelInformationHandler
        """
        backend_void_call("savedmodels/grant-fs-read-acls", data={"fmi": fmi})
        fmi_folder_contexts = FmiReadonlyFolderContexts.build(fmi)
        session_folder_context = fmi_folder_contexts.session_folder_context
        model_folder_context = fmi_folder_contexts.model_folder_context
        preprocessing_folder_context = fmi_folder_contexts.preprocessing_folder_context
        split_folder_context = fmi_folder_contexts.split_folder_context
        split_desc = split_folder_context.read_json(fmi_folder_contexts.split_desc_filename)
        core_params = session_folder_context.read_json("core_params.json")

        return PredictionModelInformationHandler(
            split_desc, core_params, preprocessing_folder_context, model_folder_context, split_folder_context, fmi=fmi
        )

    def get_predictor(self):
        """Return the predictor built at construction time."""
        return self._predictor

    def get_explainer(self):
        """Return the individual explainer of the predictor."""
        return self._predictor._individual_explainer

    def get_algorithm(self):
        """Return the ``"algorithm"`` entry of the modeling params (None if absent)."""
        return self._modeling_params.get("algorithm", None)

    def is_ensemble(self):
        """Tell whether the algorithm is ``"PYTHON_ENSEMBLE"``."""
        return self.get_algorithm() == "PYTHON_ENSEMBLE"

    def is_kfolding(self):
        """Return the ``"kfold"`` entry of the split params (False if absent)."""
        return self._split_desc["params"].get("kfold", False)

    def use_full_df(self):
        """Tell whether the full dataframe should be used, i.e. whether the split is k-folded."""
        return self.is_kfolding()

    def is_keras_backend(self):
        """Tell whether the algorithm is ``"KERAS_CODE"``."""
        return self.get_algorithm() == "KERAS_CODE"

    def get_weight_method(self):
        """Return the ``"weightMethod"`` of the core params' ``"weight"`` section (None if absent)."""
        return self._core_params.get("weight", {}).get("weightMethod", None)

    def get_sample_weight_variable(self):
        """Return the ``"sampleWeightVariable"`` of the core params' ``"weight"`` section (None if absent)."""
        return self._core_params.get("weight", {}).get("sampleWeightVariable", None)

    def with_sample_weights(self):
        """Tell whether the weight method is ``"SAMPLE_WEIGHT"`` or ``"CLASS_AND_SAMPLE_WEIGHT"``."""
        return self.get_weight_method() in {"SAMPLE_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"}

    def set_keras_scoring_batch_size(self, new_value):
        """
        Set the number of rows per batch used by ``prepare_for_scoring_full`` for Keras models (100 by default).

        :param new_value: the new batch size
        """
        self._keras_scoring_batch_size = new_value

    def supports_exploration(self):
        """Tell whether the ``"backendType"`` of the core params is ``"PY_MEMORY"``."""
        return self._core_params["backendType"] == "PY_MEMORY"

    def get_clf(self):  # This method must be kept because it is used by some plugins
        """Return the ``_clf`` attribute of the predictor."""
        return self._predictor._clf

    def get_model(self):
        """Return the ``_model`` attribute of the predictor."""
        return self._predictor._model

    def get_output_folder_context(self):
        """
        Return the post-compute folder context (the ``posttrain`` subfolder of the model folder by default).

        :rtype: dataiku.base.folder_context.FolderContext
        """
        return self._postcompute_folder_context

    def get_model_folder_context(self):
        """
        Return the folder context of the model.

        :rtype: dataiku.base.folder_context.FolderContext
        """
        return self._model_folder_context

    def get_preprocessing_params(self):
        """Return the content of ``rpreprocessing_params.json`` read at construction time."""
        return self._preprocessing_params

    def get_modeling_params(self):
        """Return the content of ``rmodeling_params.json`` read at construction time."""
        return self._modeling_params

    def get_split_desc(self):
        """Return the split description given at construction time."""
        return self._split_desc

    def get_schema(self):
        """Return the ``"schema"`` entry of the split description (empty dict if absent)."""
        return self._split_desc.get("schema", {})

    def get_target_variable(self):
        """Return the ``"target_variable"`` entry of the core params."""
        return self._core_params["target_variable"]

    def get_model_type(self):
        """Return the ``"taskType"`` entry of the core params."""
        return self._core_params["taskType"]

    def get_prediction_type(self):
        """Return the ``"prediction_type"`` entry of the core params."""
        return self._core_params["prediction_type"]

    def get_target_map(self):
        """Return the ``target_map`` of the preprocessing handler."""
        return self.get_preproc_handler().target_map

    def get_inv_map(self):
        """Return the target map inverted, with the original values cast to ``int`` and used as keys."""
        return {int(v): k for (k, v) in self.get_target_map().items()}

    def get_per_feature(self):
        """Return the ``"per_feature"`` entry of the preprocessing params (None if the key is absent)."""
        return self.get_preprocessing_params().get("per_feature")

    def get_per_feature_col(self, col_name):
        """
        Return the per-feature preprocessing params of a single column.

        :param col_name: the name of the column
        :raises ValueError: if ``col_name`` has no entry in the per-feature params
        """
        per_feature = self.get_per_feature()
        if col_name not in per_feature:
            raise ValueError("Column '{}' not found".format(col_name))
        return per_feature[col_name]

    def get_type_of_column(self, col_name):
        """Return the ``"type"`` of the column in the per-feature params (ValueError if the column is unknown)."""
        return self.get_per_feature_col(col_name)["type"]

    def get_role_of_column(self, col_name):
        """Return the ``"role"`` of the column in the per-feature params (ValueError if the column is unknown)."""
        return self.get_per_feature_col(col_name)["role"]

    def predict_and_concatenate(self, df, output_probas=True):
        """
        Predict on ``df`` and return the input columns and the prediction columns side by side.

        :param df: the data to predict on
        :param output_probas: forwarded to ``predict``
        :return: column-wise concatenation of a copy of ``df`` (taken before predicting) and of the predictions
        """
        # Copy first: the prediction may alter the input dataframe
        orig_df = df.copy()
        pred_df = self.predict(df, output_probas)
        return pd.concat([orig_df, pred_df], axis=1)

    def predict(self, df, output_probas=True):
        """
        Behaves as scoring recipe, i.e. returns:
         - For binary : ["prediction", "proba_{class1}", "proba_{class2}"] with "proba_..." only if "output_probas"
         - For multiclass : ["prediction", "proba_{class1}", "proba_{class2}", ..., "proba_{classN}"]  with
           "proba_..." only if "output_probas"
         - For regression: ["prediction"]

        Note that the prediction may alter the input dataframe

        :param df: the data to predict on
        :param output_probas: whether the "proba_..." columns are requested
        """
        if self.is_ensemble():
            return self._predictor.get_prediction_dataframe(df, True, output_probas, False, False)
        else:
            return self._predictor._get_prediction_dataframe(df, True, output_probas, False, False)

    def prepare_for_scoring(self, df):
        """
        Preprocess ``df`` and run the model on it, gathering everything needed to score the model.

        :param df: the data to score, including the target
        :return: a ``ScoringData``. For non-ensemble models it is flagged empty, with a reason, when the
                 preprocessing raised ``DkuDroppedMultiframeException`` or reported that it dropped all rows.
                 ``decisions_and_cuts`` is only set for binary classification.
        """
        # Preprocess data

        # Ensemble models embeds the preprocessing inside the model so we don't need to preprocess them before calling
        # predict. However we need to preprocess them to get valid_y
        if self.is_ensemble():
            transform = self.get_pipeline().process(df)
            valid_y = transform["target"]
            valid_unprocessed_df = transform["UNPROCESSED"]
            transformed_X = valid_unprocessed_df.copy()

            if self.with_sample_weights():
                valid_sample_weights = transform["weight"]
            else:
                valid_sample_weights = None

        else:

            try:
                transformed_X, _, is_empty, valid_y, valid_sample_weights, valid_unprocessed_df, transform \
                    = self._predictor.preprocessing.preprocess(df, with_target=True, with_sample_weights=True,
                                                               with_unprocessed=True, with_transformed=True)
            except DkuDroppedMultiframeException:
                # preprocessing failed because all targets (or weights) are NaN
                return ScoringData(is_empty=True, reason=doctor_constants.PREPROC_NOTARGET)

            # No need to predict if all the rows are dropped
            if is_empty:
                return ScoringData(is_empty=True, reason=doctor_constants.PREPROC_DROPPED)

        # Run prediction
        model = self.get_model()
        decisions_and_cuts = None  # should only be defined for binary classif.
        if self.get_prediction_type() == doctor_constants.BINARY_CLASSIFICATION:
            decisions_and_cuts = model.compute_decisions_and_cuts(transformed_X, valid_unprocessed_df)
            prediction_result = decisions_and_cuts.get_prediction_result_for_nearest_cut(model.get_threshold())
        elif self.get_prediction_type() == doctor_constants.MULTICLASS:
            prediction_result = model.compute_predictions(transformed_X, valid_unprocessed_df)
        else:  # regression case
            prediction_result = model.compute_predictions(transformed_X, valid_unprocessed_df)

        return ScoringData(prediction_result=prediction_result, valid_y=valid_y,
                           valid_sample_weights=valid_sample_weights,
                           decisions_and_cuts=decisions_and_cuts, valid_unprocessed=valid_unprocessed_df)

    def prepare_for_scoring_full(self, df):
        """
        Same as ``prepare_for_scoring``, but processes ``df`` by batches for Keras models.

        For KERAS algorithm, cannot preprocess full data directly, must work with batches. The batch size is
        ``_keras_scoring_batch_size`` rows (see ``set_keras_scoring_batch_size``); empty batches are skipped
        and the others are concatenated.

        :param df: the data to score, including the target
        :return: a ``ScoringData`` (flagged empty if every batch was empty)
        """
        if not self.is_keras_backend():
            return self.prepare_for_scoring(df)

        scoring_data_batches = ScoringDataConcatenator()
        batch_size = self._keras_scoring_batch_size
        num_rows = df.shape[0]
        nb_batches = int(math.ceil(num_rows * 1.0 / batch_size))

        for num_batch in range(nb_batches):
            input_df_batch = df.iloc[num_batch * batch_size: (num_batch + 1) * batch_size, :]
            scoring_data = self.prepare_for_scoring(input_df_batch)
            scoring_data_batches.add_scoring_data(scoring_data)

        return scoring_data_batches.get_concatenated_scoring_data()

    def run_binary_scoring(self, df):
        """
        Score a binary classification model on ``df``.

        :param df: the data to score, including the target
        :return: a ``(success, reason, perf)`` tuple: ``(True, None, perf)`` on success, else
                 ``(False, reason, None)`` when the scoring data is empty or the scorer cannot score
        :raises ValueError: if the model is not a binary classification model
        """
        if self.get_prediction_type() != doctor_constants.BINARY_CLASSIFICATION:
            raise ValueError("Cannot compute binary scoring on '{}' model".format(self.get_prediction_type().lower()))

        # warning: scoring_data.prediction_result.preds are computed with the default threshold, but are
        # not used anyway
        scoring_data = self.prepare_for_scoring_full(df)
        if scoring_data.is_empty:
            return False, scoring_data.reason, None

        binary_classif_scorer = BinaryClassificationModelScorer(
            self._modeling_params,
            None,
            scoring_data.decisions_and_cuts,
            scoring_data.valid_y,
            self.get_target_map(),
            test_unprocessed=scoring_data.valid_unprocessed,
            test_X=None,  # Not dumping on disk predicted_df
            test_df_index=None,  # Not dumping on disk predicted_df
            test_sample_weight=scoring_data.valid_sample_weights)

        can_score, reason = binary_classif_scorer.can_score()
        if not can_score:
            return False, reason, None

        perf = binary_classif_scorer.score()
        return True, None, perf

    def run_regression_scoring(self, df):
        """
        Score a regression model on ``df``.

        :param df: the data to score, including the target
        :return: a ``(success, reason, perf)`` tuple: ``(True, None, perf)`` on success, else
                 ``(False, reason, None)`` when the scoring data is empty or ``df`` has a single row
        :raises ValueError: if the model is not a regression model
        """
        if self.get_prediction_type() != doctor_constants.REGRESSION:
            raise ValueError(
                "Cannot compute regression scoring on '{}' model".format(self.get_prediction_type().lower()))

        scoring_data = self.prepare_for_scoring_full(df)
        if scoring_data.is_empty:
            return False, scoring_data.reason, None

        # The check is done on the input dataframe, not on the rows kept by the preprocessing
        if df.shape[0] == 1:
            return False, PREPROC_ONEVALUE, None

        regression_scorer = RegressionModelScorer(self._modeling_params,
                                                  scoring_data.prediction_result,
                                                  scoring_data.valid_y,
                                                  None,
                                                  test_unprocessed=scoring_data.valid_unprocessed,
                                                  test_X=None,  # Not dumping on disk predicted_df
                                                  test_df_index=None,  # Not dumping on disk predicted_df
                                                  test_sample_weight=scoring_data.valid_sample_weights)
        perf = regression_scorer.score()
        return True, None, perf

    def get_collector_data(self):
        """Return the content of ``collector_data.json`` (read on first call, then cached on the instance)."""
        if self._collector_data is None:
            self._collector_data = self._preprocessing_folder_context.read_json("collector_data.json")
        return self._collector_data

    def get_preproc_handler(self):
        """Return the preprocessing handler, fed with the collector data (built on first call, then cached)."""
        if self._preproc_handler is None:
            from dataiku.core.saved_model import get_source_dss_version
            from dataiku.doctor.prediction.common import PredictionAlgorithmNaNSupport
            nan_support = PredictionAlgorithmNaNSupport(
                self._modeling_params, self._preprocessing_params,
                source_dss_version=get_source_dss_version(self._model_folder_context))
            self._preproc_handler = PredictionPreprocessingHandler.build(self._core_params, self._preprocessing_params,
                                                                         self._preprocessing_folder_context,
                                                                         nan_support=nan_support)
            self._preproc_handler.collector_data = self.get_collector_data()
        return self._preproc_handler

    def get_pipeline(self, with_target=True):
        """
        Return the preprocessing pipeline (built on first call, then cached).

        :param with_target: forwarded to ``build_preprocessing_pipeline`` when the pipeline is built
        """
        # NOTE(review): the cache does not depend on `with_target`: only the value given on the first call is
        # taken into account, later calls return the same pipeline whatever they pass - confirm this is intended
        if self._pipeline is None:
            self._pipeline = self.get_preproc_handler().build_preprocessing_pipeline(with_target=with_target)
        return self._pipeline

    def category_possible_values(self, col_name):
        """
        Get the list of modalities which are dummified by the preprocessing for the given column
        :param col_name: the name of the column
        :return: None if the column is not dummified else it returns the list of modalities that are dummified.
                 The returned list is a fresh copy: callers may modify it freely.
        """
        if not self.is_column_dummified(col_name):
            return None
        collected_values = self.get_collector_data()["per_feature"][col_name].get("category_possible_values")
        missing_handling = self.get_per_feature_col(col_name).get("missing_handling")
        # Work on a copy: the list belongs to the collector data cached on this instance (and handed over to
        # the preprocessing handler), so appending to it in place would add one more FILL_NA_VALUE entry on
        # every call.
        possible_values = None if collected_values is None else list(collected_values)
        if missing_handling == "NONE":  # Treat as regular value
            possible_values = (possible_values or []) + [doctor_constants.FILL_NA_VALUE]
        return possible_values

    def is_column_dummified(self, col_name):
        """
        Tell whether the ``"category_handling"`` of the column is ``"DUMMIFY"``.

        :param col_name: the name of the column (a ``KeyError`` is raised if it is not in the per-feature params)
        """
        return self.get_per_feature()[col_name].get("category_handling") == "DUMMIFY"

    def _get_df(self, split):
        """
        Load the dataframe of a split.

        :param split: name of the split; this class uses ``"train"``, ``"test"`` and ``"full"``
        :return: a ``(df, row_count_ok)`` tuple. ``row_count_ok`` is True when the split description has no
                 (or a falsy) ``"<split>Rows"`` entry, or when that entry equals the number of loaded rows.
        """
        with RaiseWithTraceback("Failed to properly open the {} dataset. "
                                "Was it modified during a clean-up routine ?".format(split)):
            df = df_from_split_desc(self._split_desc,
                                    split,
                                    self._split_folder_context,
                                    self._preprocessing_params['per_feature'],
                                    self._core_params["prediction_type"])

        expected_df_length = self._split_desc.get("{}Rows".format(split), None)
        return df, (not expected_df_length) or (expected_df_length == df.shape[0])

    def get_train_df(self):
        """Return the ``(df, row_count_ok)`` tuple of the train split (loaded on first call, then cached)."""
        if self._train_df is None:
            self._train_df = self._get_df("train")
        return self._train_df

    def get_test_df(self):
        """Return the ``(df, row_count_ok)`` tuple of the test split (loaded on first call, then cached)."""
        if self._test_df is None:
            self._test_df = self._get_df("test")
        return self._test_df

    def get_full_df(self):
        """Return the ``(df, row_count_ok)`` tuple of the full split (loaded on first call, then cached)."""
        if self._full_df is None:
            self._full_df = self._get_df("full")
        return self._full_df

    def compute_global_explanations_on_non_droppable_data(self, preprocessed_df,
                                                          n_explanations=4,
                                                          progress=None,
                                                          mc_steps=100,
                                                          num_samples=150):
        """
        Run the parent class' ``compute_global_explanations`` directly on ``preprocessed_df``, i.e. without
        the filtering pass done by this class' ``compute_global_explanations``.

        :param preprocessed_df: the data to explain
        :param n_explanations: forwarded to the parent implementation
        :param progress: forwarded to the parent implementation
        :param mc_steps: forwarded to the parent implementation
        :param num_samples: forwarded to the parent implementation
        """
        return super(PredictionModelInformationHandler, self). \
            compute_global_explanations(preprocessed_df, n_explanations, progress, mc_steps, num_samples)

    def compute_global_explanations(
            self,
            testset,
            n_explanations=4,
            progress=None,
            mc_steps=100,
            num_samples=150,
            output_folder=None,
    ):
        """
        Compute global explanations on the rows of ``testset`` that are actually scored by the model.

        :param testset: the data to explain, including the target
        :param n_explanations: forwarded to ``compute_global_explanations_on_non_droppable_data``
        :param progress: forwarded to ``compute_global_explanations_on_non_droppable_data``
        :param mc_steps: forwarded to ``compute_global_explanations_on_non_droppable_data``
        :param num_samples: forwarded to ``compute_global_explanations_on_non_droppable_data``
        :param output_folder: not used by this implementation
        """
        # Process the testset and compute prediction result in order to make sure we only use rows not dropped by
        # the preprocessing or declined by the model overrides in our montecarlo process for shap computations.
        # This means we'll reprocess the whole testset, which can be arbitrarily slow.
        scoring_data = self.prepare_for_scoring(testset)
        # NOTE(review): `scoring_data.is_empty` is not checked here; presumably `prediction_result` is unset
        # when the preprocessing dropped every row, in which case the next line fails - confirm and handle
        aligned_preprocessed = scoring_data.prediction_result.align_with_not_declined(scoring_data.valid_unprocessed)
        return self.compute_global_explanations_on_non_droppable_data(aligned_preprocessed,
                                                                      n_explanations,
                                                                      progress,
                                                                      mc_steps,
                                                                      num_samples)


class ScoringDataConcatenator:
    """
    Accumulates the fields of successive scoring data objects, then merges them into a single ``ScoringData``.

    Scoring data flagged as empty are ignored.
    """

    def __init__(self):
        # Always filled for a non-empty scoring data
        self.prediction_result_list = []
        self.valid_y_list = []
        # Only filled when the corresponding field is not None
        self.valid_sample_weights_list = []
        self.decisions_and_cuts_list = []
        self.preds_df_list = []
        self.probas_df_list = []
        self.valid_unprocessed_list = []

    def add_scoring_data(self, scoring_data):
        """
        Record the fields of ``scoring_data``, unless it is flagged as empty.

        :param scoring_data: object exposing ``is_empty`` and, when not empty, the scoring fields
        """
        if scoring_data.is_empty:
            return

        self.prediction_result_list.append(scoring_data.prediction_result)
        self.valid_y_list.append(scoring_data.valid_y)

        optional_fields = (
            ("valid_sample_weights", self.valid_sample_weights_list),
            ("decisions_and_cuts", self.decisions_and_cuts_list),
            ("preds_df", self.preds_df_list),  # never actually used as of Nov. 2022
            ("probas_df", self.probas_df_list),  # never actually used as of Nov. 2022
            ("valid_unprocessed", self.valid_unprocessed_list),
        )
        for field_name, accumulator in optional_fields:
            field_value = getattr(scoring_data, field_name)
            if field_value is not None:
                accumulator.append(field_value)

    @staticmethod
    def _pd_concat_or_none(parts):
        # pd.concat of the recorded parts, or None when nothing was recorded
        return pd.concat(parts) if parts else None

    def get_concatenated_scoring_data(self):
        """
        Merge everything recorded so far.

        :return: an empty ``ScoringData`` when no non-empty scoring data was added, else a ``ScoringData``
                 whose fields are the concatenation of the recorded ones (None for fields never recorded)
        """
        if not self.prediction_result_list:
            return ScoringData(is_empty=True)

        merged_y = pd.concat(self.valid_y_list)
        merged_weights = self._pd_concat_or_none(self.valid_sample_weights_list)
        merged_preds = self._pd_concat_or_none(self.preds_df_list)
        merged_probas = self._pd_concat_or_none(self.probas_df_list)

        merged_decisions = None
        if self.decisions_and_cuts_list:
            merged_decisions = DecisionsAndCuts.concat(self.decisions_and_cuts_list)

        # The guard clause above ensures that there is at least one prediction result at this point
        merged_prediction_result = AbstractPredictionResult.concat(self.prediction_result_list)
        merged_unprocessed = self._pd_concat_or_none(self.valid_unprocessed_list)

        return ScoringData(prediction_result=merged_prediction_result, valid_y=merged_y,
                           valid_sample_weights=merged_weights, preds_df=merged_preds, probas_df=merged_probas,
                           decisions_and_cuts=merged_decisions, valid_unprocessed=merged_unprocessed)


def build_model_handler(split_desc, core_params, preprocessing_folder, model_folder,
                        split_folder, fmi, postcompute_folder=None, train_split_desc=None,
                        train_split_folder=None):
    """
    Build the model information handler matching the kind of model found in ``model_folder``.

    :param split_desc: forwarded to ``PredictionModelInformationHandler`` (not used for MLflow models)
    :param core_params: forwarded to ``PredictionModelInformationHandler`` (not used for MLflow models)
    :param preprocessing_folder: preprocessing folder, turned into a folder context (not used for MLflow models)
    :param model_folder: model folder, turned into a folder context
    :param split_folder: split folder, turned into a folder context (not used for MLflow models)
    :param fmi: forwarded to ``PredictionModelInformationHandler`` (not used for MLflow models)
    :param postcompute_folder: optional post-compute folder, turned into a folder context when truthy
    :param train_split_desc: forwarded to ``PredictionModelInformationHandler`` (not used for MLflow models)
    :param train_split_folder: optional train split folder, turned into a folder context when not None
                               (not used for MLflow models)
    :return: a ``MLflowModelInformationHandler`` when ``is_mlflow_model`` is true for the model folder,
             else a ``PredictionModelInformationHandler``
    """
    from dataiku.external_ml import is_mlflow_model

    model_folder_context = build_folder_context(model_folder)
    if postcompute_folder:
        postcompute_folder_context = build_folder_context(postcompute_folder)
    else:
        postcompute_folder_context = None

    if is_mlflow_model(model_folder_context):
        from dataiku.external_ml.mlflow.model_information_handler import MLflowModelInformationHandler
        return MLflowModelInformationHandler(model_folder_context, postcompute_folder_context)

    # Regular (non MLflow) model: the preprocessing and split folders are needed as well
    preprocessing_folder_context = build_folder_context(preprocessing_folder)
    split_folder_context = build_folder_context(split_folder)
    train_split_folder_context = None
    if train_split_folder is not None:
        train_split_folder_context = build_folder_context(train_split_folder)

    return PredictionModelInformationHandler(split_desc, core_params, preprocessing_folder_context,
                                             model_folder_context, split_folder_context,
                                             postcompute_folder_context, train_split_desc,
                                             train_split_folder_context, fmi=fmi)
