import logging

import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold

from dataiku.core import intercom
from dataiku.core.doctor_constants import PREDICTION_TYPE
from dataiku.core.doctor_constants import TARGET_VARIABLE
from dataiku.doctor import step_constants
from dataiku.doctor import utils
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.diagnostics.diagnostics import DiagnosticsScoringResults
from dataiku.doctor.posttraining.model_information_handler import PredictionModelInformationHandler
from dataiku.doctor.prediction.classification_fit import classification_fit
from dataiku.doctor.prediction.classification_fit import classification_fit_ensemble
from dataiku.doctor.prediction.classification_scoring import BinaryClassificationModelScorer
from dataiku.doctor.prediction.classification_scoring import CVBinaryClassificationModelScorer
from dataiku.doctor.prediction.classification_scoring import CVMulticlassModelScorer
from dataiku.doctor.prediction.classification_scoring import ClassificationModelIntrinsicScorer
from dataiku.doctor.prediction.classification_scoring import MulticlassModelScorer
from dataiku.doctor.prediction.classification_scoring import binary_classification_scorer_with_valid
from dataiku.doctor.prediction.classification_scoring import multiclass_scorer_with_valid
from dataiku.doctor.prediction.classification_scoring import save_classification_statistics
from dataiku.doctor.prediction.common import get_initial_intrinsic_perf_data
from dataiku.doctor.prediction.common import get_monotonic_cst
from dataiku.doctor.prediction.common import needs_hyperparameter_search
from dataiku.doctor.prediction.common import prepare_multiframe
from dataiku.doctor.prediction.common import PredictionAlgorithmNaNSupport
from dataiku.doctor.prediction.decisions_and_cuts import DecisionsAndCuts
from dataiku.doctor.prediction.prediction_model_serialization import ModelSerializer
from dataiku.doctor.prediction.regression_fit import regression_fit_ensemble
from dataiku.doctor.prediction.regression_fit import regression_fit_single
from dataiku.doctor.prediction.regression_scoring import CVRegressionModelScorer
from dataiku.doctor.prediction.regression_scoring import RegressionModelIntrinsicScorer
from dataiku.doctor.prediction.regression_scoring import RegressionModelScorer
from dataiku.doctor.prediction.regression_scoring import regression_scorer_with_valid
from dataiku.doctor.prediction.regression_scoring import save_regression_statistics
from dataiku.doctor.prediction.prediction_interval_model import train_prediction_interval_or_none
from dataiku.doctor.prediction.scorable_model import ScorableModel
from dataiku.doctor.prediction.scoring_base import PERF_FILENAME
from dataiku.doctor.prediction.scoring_base import PERF_WITHOUT_OVERRIDES_FILENAME
from dataiku.doctor.prediction.scoring_base import PREDICTED_FILENAME
from dataiku.doctor.preprocessing_collector import PredictionPreprocessingDataCollector
from dataiku.doctor.preprocessing_handler import PredictionPreprocessingHandler
from dataiku.doctor.utils import doctor_constants
from dataiku.doctor.utils.gpu_execution import get_gpu_config_from_core_params
from dataiku.doctor.utils.model_io import dump_model_to_folder
from dataiku.doctor.utils.skcompat import instantiate_stratified_group_kfold
from dataiku.doctor.utils.split import input_columns
from dataikuscoring.utils.prediction_result import ClassificationPredictionResult
from dataikuscoring.utils.prediction_result import PredictionResult

logger = logging.getLogger(__name__)

# The functions in this module are used both by the recipes and by the analyses.
# The function below handles non-ensemble models when k-fold cross-test is disabled.
def prediction_train_score_save(transformed_train,
                                transformed_test,
                                test_df_index,
                                core_params,
                                split_desc,
                                modeling_params,
                                model_folder_context,
                                preprocessing_folder_context,
                                split_folder_context,
                                listener,
                                target_map,
                                pipeline,
                                preprocessing_params,
                                ml_overrides_params):
    """
    Fit a single (non-ensemble) prediction model, save it and score it.

    The work is reported to ``listener`` as successive steps:
      - fitting (starting with a hyperparameter search step when the modeling params need one),
      - saving: the model is serialized and ``actual_params.json`` is written,
      - scoring: intrinsic scores are computed, then the test set is scored and the result saved,
      - post-training: preliminary computations for the global explanations. Any exception
        raised there is logged and swallowed, so it never fails the training.

    :param transformed_train: preprocessed train data; read through its "TRAIN", "target" and
                              (with sample weights) "weight" entries
    :param transformed_test: preprocessed test data; its "weight" entry is read when sample
                             weights are enabled
    :param test_df_index: index of the test dataframe, forwarded to the test-set scorer
    :param core_params: dict of core params; "prediction_type" and "taskType" are required,
                        "weight" and "calibration" are optional
    :param split_desc: split description, only forwarded to the explanations computation
    :param modeling_params: dict of modeling params ("algorithm" is required)
    :param model_folder_context: folder context where the model and its reports are written
    :param preprocessing_folder_context: only forwarded to the explanations computation
    :param split_folder_context: only forwarded to the explanations computation
    :param listener: progress listener used to push/pop the processing steps
    :param target_map: mapping of the target classes; must hold at least 2 entries for
                       classification, unused for regression
    :param pipeline: preprocessing pipeline, forwarded to the intrinsic scorers
    :param preprocessing_params: dict of preprocessing params
    :param ml_overrides_params: overrides params forwarded to the scorable model
    :raises ValueError: when the prediction type is neither a classification nor a regression
    :raises AssertionError: on non-positive sample weights or invalid ``target_map``
    """
    prediction_type = core_params["prediction_type"]
    model_type = core_params["taskType"]
    train_X = transformed_train["TRAIN"]
    train_y = transformed_train["target"]

    if needs_hyperparameter_search(modeling_params):
        previous_search_time = utils.get_hyperparams_search_time_traininfo(model_folder_context)
        initial_state = step_constants.ProcessingStep.STEP_HYPERPARAMETER_SEARCHING

        def gridsearch_done_fn():
            # The search is over: record its duration and switch the listener to the fitting step
            step = listener.pop_step()
            utils.write_hyperparam_search_time_traininfo(model_folder_context, step["time"])
            listener.push_step(step_constants.ProcessingStep.STEP_FITTING)
            listener.save_status()
    else:
        initial_state = step_constants.ProcessingStep.STEP_FITTING
        previous_search_time = None

        def gridsearch_done_fn():
            pass

    weight_method = core_params.get("weight", {}).get("weightMethod", None)
    with_sample_weight = weight_method in {"SAMPLE_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"}
    if with_sample_weight:
        assert transformed_train["weight"].values.min() > 0, "Sample weights must be positive"
        assert transformed_test["weight"].values.min() > 0, "Sample weights must be positive"

    monotonic_cst = get_monotonic_cst(preprocessing_params, train_X)

    if prediction_type in (doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS):
        # Identity check against None, consistently with prediction_train_score_save_ensemble
        assert target_map is not None
        assert len(target_map) >= 2

        with_class_weight = weight_method in {"CLASS_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"}
        calibration_params = core_params.get("calibration", {})
        calibration_method = calibration_params.get("calibrationMethod", None)
        calibrate_proba = calibration_method in [doctor_constants.SIGMOID, doctor_constants.ISOTONIC]
        calibrate_on_test = calibration_params.get("calibrateOnTestSet", True)
        calibration_ratio = calibration_params.get("calibrationDataRatio", doctor_constants.DEFAULT_CALIBRATION_DATA_RATIO)
        with listener.push_step(initial_state, previous_duration=previous_search_time):
            (clf, actual_params, prepared_X, iipd) = classification_fit(modeling_params, core_params, transformed_train,
                                                                        model_folder_context=model_folder_context,
                                                                        gridsearch_done_fn=gridsearch_done_fn,
                                                                        transformed_test=transformed_test,
                                                                        target_map=target_map,
                                                                        with_sample_weight=with_sample_weight,
                                                                        with_class_weight=with_class_weight,
                                                                        calibration_method=calibration_method,
                                                                        calibration_ratio=calibration_ratio,
                                                                        calibrate_on_test=calibrate_on_test,
                                                                        monotonic_cst=monotonic_cst)

            diagnostics.on_fitting_end(features=transformed_train["TRAIN"].columns(), clf=clf, prediction_type=prediction_type, train_target=transformed_train["target"])

            model = ScorableModel.build(clf, model_type, prediction_type, modeling_params['algorithm'],
                                        preprocessing_params, ml_overrides_params)
        with listener.push_step(step_constants.ProcessingStep.STEP_SAVING):
            ModelSerializer.build(model, model_folder_context, train_X.columns(), calibrate_proba).serialize()
            model_folder_context.write_json("actual_params.json", actual_params)

        with listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            ClassificationModelIntrinsicScorer(modeling_params, clf, train_X, train_y, target_map, pipeline,
                                               model_folder_context, prepared_X, iipd, with_sample_weight, calibrate_proba).score()

            if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
                scorer = binary_classification_scorer_with_valid(modeling_params, model,
                                                                 transformed_test, model_folder_context, test_df_index, target_map=target_map, with_sample_weight=with_sample_weight)
            else:
                scorer = multiclass_scorer_with_valid(modeling_params, model,
                                                      transformed_test, model_folder_context, test_df_index, target_map=target_map, with_sample_weight=with_sample_weight)
            scorer.score()
            scorer.save()
            diagnostics.on_scoring_end(scoring_results=DiagnosticsScoringResults.build_from_scorer(prediction_type, scorer),
                                       transformed_test=transformed_test, transformed_train=transformed_train, with_sample_weight=with_sample_weight)

    elif prediction_type == doctor_constants.REGRESSION:
        with listener.push_step(initial_state, previous_duration=previous_search_time):
            (clf, actual_params, prepared_X, iipd) = regression_fit_single(modeling_params, core_params, transformed_train,
                                                                           model_folder_context=model_folder_context,
                                                                           gridsearch_done_fn=gridsearch_done_fn,
                                                                           with_sample_weight=with_sample_weight,
                                                                           monotonic_cst=monotonic_cst)
            diagnostics.on_fitting_end(features=transformed_train["TRAIN"].columns(), clf=clf, prediction_type=prediction_type, train_target=transformed_train["target"])

            pred_interval_model = train_prediction_interval_or_none(clf, core_params, modeling_params, transformed_test)
            model = ScorableModel.build(clf, model_type, prediction_type, modeling_params["algorithm"],
                                        overrides_params=ml_overrides_params, prediction_interval_model=pred_interval_model)
        with listener.push_step(step_constants.ProcessingStep.STEP_SAVING):
            ModelSerializer.build(model, model_folder_context, train_X.columns()).serialize()
            model_folder_context.write_json("actual_params.json", actual_params)

        with listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            RegressionModelIntrinsicScorer(modeling_params, clf, train_X, train_y, pipeline, model_folder_context, prepared_X, iipd, with_sample_weight).score()
            scorer = regression_scorer_with_valid(modeling_params, model, transformed_test, model_folder_context, test_df_index, with_sample_weight)
            scorer.score()
            scorer.save()
            diagnostics.on_scoring_end(scoring_results=DiagnosticsScoringResults.build_from_scorer(prediction_type, scorer),
                                       transformed_test=transformed_test, transformed_train=transformed_train, with_sample_weight=with_sample_weight)

    else:
        raise ValueError("Prediction type %s is not valid" % prediction_type)

    with listener.push_step(step_constants.ProcessingStep.STEP_POSTTRAINING):
        # Best effort: a failure of the explanations computation must not fail the training
        try:
            if modeling_params.get("skipExpensiveReports"):
                logger.info("Skipping background rows drawing, feature and column importance computation")
            elif prediction_type in {doctor_constants.MULTICLASS, doctor_constants.BINARY_CLASSIFICATION} and not scorer.use_probas:
                logger.info("Cannot draw background rows, compute column importance: model is not probabilistic")
            else:
                preliminary_compute_for_explanations(model_folder_context, split_desc, transformed_test,
                                                     preprocessing_folder_context, split_folder_context, core_params,
                                                     modeling_params, preprocessing_params,
                                                     scorer.test_prediction_result)
        except Exception as e:  # Catch all: see https://app.shortcut.com/dataiku/story/140973
            logger.exception("Exception running the post training global explanations: {}".format(e))


def prediction_train_score_save_ensemble(train,
                                         test,
                                         core_params,
                                         modeling_params,
                                         model_folder_context,
                                         listener, target_map, pipeline, with_sample_weight):
    """
        Fit an ensemble model on `train`, score it on `test` and save the scores,
        then dump the fitted ensemble along with its params ("actual_params.json")
        and a minimal intrinsic performance summary ("iperf.json").
    """
    prediction_type = core_params["prediction_type"]
    processed_train = pipeline.process(train)
    train_y = processed_train["target"]
    sample_weight = processed_train["weight"] if with_sample_weight else None

    is_classification = prediction_type in (doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS)
    if not is_classification and prediction_type != doctor_constants.REGRESSION:
        raise ValueError("Prediction type %s is not valid" % prediction_type)

    def test_set_kwargs():
        # Keyword arguments describing the test set, common to all the scorers.
        # Evaluated lazily, at the point where each scorer is instantiated.
        return {
            "test_unprocessed": transformed_test["UNPROCESSED"],
            "test_X": transformed_test["TRAIN"],
            "test_df_index": test.index.copy(),
            "test_sample_weight": valid_sample_weight,
        }

    if is_classification:
        assert target_map is not None
        assert len(target_map) >= 2
        with listener.push_step(step_constants.ProcessingStep.STEP_FITTING):
            clf, _, _, _ = classification_fit_ensemble(modeling_params, core_params,
                                                       train, train_y, sample_weight)

        with listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            # Metrics can only be computed when the CLF is in "pipelines with target" mode
            clf.set_with_target_pipelines_mode(True)

            predicted_classes = clf.predict(test)
            probas = clf.predict_proba(test)

            transformed_test = pipeline.process(test)
            test_y = transformed_test["target"]
            valid_sample_weight = transformed_test["weight"] if with_sample_weight else None

            if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
                scorer = BinaryClassificationModelScorer(modeling_params, model_folder_context,
                                                         DecisionsAndCuts.from_probas(probas, target_map),
                                                         test_y, target_map,
                                                         **test_set_kwargs())
            else:
                scorer = MulticlassModelScorer(modeling_params, model_folder_context,
                                               ClassificationPredictionResult(target_map, probas=probas,
                                                                              unmapped_preds=predicted_classes),
                                               test_y.astype(int), target_map,
                                               **test_set_kwargs())
            scorer.score()
            scorer.save()
    else:
        with listener.push_step(step_constants.ProcessingStep.STEP_FITTING):
            clf, actual_params, _, _ = regression_fit_ensemble(modeling_params, core_params,
                                                               train, train_y, sample_weight)

        with listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            # Metrics can only be computed when the CLF is in "pipelines with target" mode
            clf.set_with_target_pipelines_mode(True)

            predictions = clf.predict(test)
            transformed_test = pipeline.process(test)
            test_y = transformed_test["target"]
            valid_sample_weight = transformed_test["weight"] if with_sample_weight else None

            scorer = RegressionModelScorer(modeling_params, PredictionResult(predictions), test_y,
                                           model_folder_context,
                                           **test_set_kwargs())
            scorer.score()
            scorer.save()
            model_folder_context.write_json("actual_params.json", actual_params)

    # The CLF must be back in "scoring pipelines" mode before it gets saved
    clf.set_with_target_pipelines_mode(False)

    with listener.push_step(step_constants.ProcessingStep.STEP_SAVING):
        dump_model_to_folder(clf, model_folder_context)
        model_folder_context.write_json("actual_params.json", {"resolved": modeling_params})
        model_folder_context.write_json("iperf.json", {
            "modelInputNRows": train.shape[0],  # todo : not the right count as may have dropped ...
            "modelInputNCols": -1,  # makes no sense for an ensemble as may have different preprocessings
            "modelInputIsSparse": False,
        })


# Do the second part of KFold: params are already resolved
def prediction_train_model_kfold(full_df_clean,
                                 core_params, split_desc, preprocessing_params, modeling_params, optimized_params,
                                 preprocessing_folder_context, model_folder_context, split_folder_context,
                                 listener, with_sample_weight,
                                 with_class_weight, transformed_full,
                                 assertions_metrics=None, ml_overrides_params=None, overrides_metrics=None,
                                 monotonic_cst=None):
    """
    Run the k-fold cross-test of a prediction model whose hyperparameters are already resolved.

    For each fold, a fresh collector and preprocessing pipeline are built and fitted on the
    fold's train rows, then a model is fitted with ``optimized_params`` and scored on the fold's
    test rows. Once all folds are done:
      - the per-fold predictions are concatenated and written to ``PREDICTED_FILENAME``,
      - the per-fold scorers are aggregated by a CV scorer, whose result (enriched with the
        assertions/overrides metrics and the processed feature names) is written to
        ``PERF_FILENAME`` (and ``PERF_WITHOUT_OVERRIDES_FILENAME`` when available),
      - the prediction statistics are saved,
      - the preliminary computations for global explanations are run on ``transformed_full``.
        Any exception raised there is logged and swallowed.

    :param full_df_clean: pandas DataFrame holding all the rows to split into folds; must contain
                          the target column and, for grouped k-fold, the group column
    :param core_params: dict of core params ("prediction_type", "taskType", target variable,
                        optional "calibration")
    :param split_desc: split description; its "params" entry provides "ssdSeed", "ssdStratified",
                       "ssdGrouped", "nFolds" and, for grouped k-fold, "ssdGroupColumnName"
    :param preprocessing_params: dict of preprocessing params
    :param modeling_params: dict of modeling params ("algorithm" is required)
    :param optimized_params: resolved modeling params used to fit and score each fold
    :param preprocessing_folder_context: folder context of the preprocessing; one "fold_<i>"
                                         subfolder context is derived per fold
    :param model_folder_context: folder context of the model; one "fold_<i>" subfolder is created per fold
    :param split_folder_context: only forwarded to the explanations computation
    :param listener: progress listener used to push the processing steps
    :param with_sample_weight: whether sample weights are used
    :param with_class_weight: whether class weights are used (classification only)
    :param transformed_full: preprocessed full dataset; read through its "TRAIN" and
                             (with sample weights) "weight" entries
    :param assertions_metrics: optional; an iterable of per-cut metrics for binary classification,
                               a single metrics object otherwise (each exposing ``to_dict()``)
    :param ml_overrides_params: optional overrides params forwarded to the scorable model
    :param overrides_metrics: optional; same shape convention as ``assertions_metrics``
    :param monotonic_cst: optional monotonic constraints forwarded to the fit functions
    :raises ValueError: on empty target values with stratified k-fold, on a missing or fully
                        empty numerical group column, or when there are fewer groups than folds
    """

    split_params = split_desc["params"]
    if split_params is not None and split_params["ssdSeed"] is not None:
        seed = int(split_params["ssdSeed"])
    else:
        seed = 1337
    nan_support = PredictionAlgorithmNaNSupport(modeling_params, preprocessing_params)

    # Check error case for stratified k-fold, falling back to simpler splitter if necessary
    use_stratified = split_params["ssdStratified"]
    is_classification = core_params[PREDICTION_TYPE] in {doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS}
    if use_stratified and not is_classification:
        logger.warning("Stratified k-fold can only be used with classification models. Falling back to non-stratified k-fold.")
        use_stratified = False

    if use_stratified and full_df_clean[core_params[TARGET_VARIABLE]].hasnans:
        raise ValueError("Empty target values not supported for stratified k-fold train/test split")

    # Create the k-fold iterator
    if use_stratified and split_params["ssdGrouped"]:
        logger.info("Using stratified group k-fold train/test split")
        # NB: Don't use shuffle=True with StratifiedGroupKFold, as there is a bug in scikit-learn. See sc-131099 for details.
        logger.info('Setting shuffle=False for StratifiedGroupKFold splitter, ignoring random seed')
        kf = instantiate_stratified_group_kfold(n_splits=split_params["nFolds"], shuffle=False)
    elif use_stratified:
        logger.info('Using stratified k-fold train/test split')
        kf = StratifiedKFold(n_splits=split_params["nFolds"], shuffle=True, random_state=seed)
    elif split_params["ssdGrouped"]:
        logger.info('Using group k-fold train/test split')
        kf = GroupKFold(n_splits=split_params["nFolds"])
    else:
        logger.info('Using k-fold train/test split')
        kf = KFold(n_splits=split_params["nFolds"], shuffle=True, random_state=seed)

    if split_params["ssdGrouped"]:
        if "ssdGroupColumnName" not in split_params:
            raise ValueError("Group k-fold train/test split requires a group column to be set")

        group_labels = full_df_clean[split_params["ssdGroupColumnName"]]

        if group_labels.hasnans:
            if pd.api.types.is_numeric_dtype(group_labels):
                # When the group column is numerical, fill empty/NaN cells with a value that is not already in the column, i.e. max(column values) + 1
                group_labels_without_na = group_labels.dropna()
                if group_labels_without_na.empty:
                    raise ValueError("Group k-fold train/test split column contains no values")
                na_group_label = group_labels_without_na.max() + 1
            else:
                na_group_label = doctor_constants.FILL_NA_VALUE
            group_labels = group_labels.fillna(na_group_label)
            logger.info("Empty values found in group column for group k-fold train/test split, replacing with '{new_group_label}'".format(new_group_label=na_group_label))
        n_groups = group_labels.nunique()
        if n_groups < split_params["nFolds"]:
            raise ValueError("Cannot have more folds ({numFolds}) than groups ({numGroups}) for group k-fold train/test split".format(
                numFolds=split_params["nFolds"], numGroups=n_groups))
    else:
        group_labels = None

    folds = []
    prediction_type = core_params["prediction_type"]
    model_type = core_params["taskType"]

    target_map = None
    with listener.push_step(step_constants.ProcessingStep.KFOLD_STEP_PROCESSING_FOLD):
        for split, (train_idx, test_idx) in enumerate(kf.split(X=full_df_clean,  # training data
                                                               y=full_df_clean[core_params[TARGET_VARIABLE]],  # target variable; ignored for non-stratified splits
                                                               groups=group_labels)):  # group labels; ignored when not group k-fold
            with listener.push_step("[%s/%s]" % (split+1, split_params["nFolds"])):
                logger.info("Processing a fold")

                fold_ppfolder_context = preprocessing_folder_context.get_subfolder_context("fold_%s" % split)
                fold_mfolder_context = model_folder_context.get_subfolder_context("fold_%s" % split)
                # DO NOT CREATE IT. It should not be required.
                #os.makedirs(fold_ppfolder)
                fold_mfolder_context.create_if_not_exist()

                # The splitter yields *positional* row indices, hence the selection with .iloc
                # (.loc is label-based and only matches when the index is exactly 0..n-1).
                # drop=True prevents reset_index from adding the former index as a spurious
                # "index" column to the fold dataframes.
                train_df = full_df_clean.iloc[train_idx].reset_index(drop=True).copy()
                test_df = full_df_clean.iloc[test_idx].reset_index(drop=True).copy()

                # We rebuild the collector and preprocessing handler for each fold
                with listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING):
                    collector = PredictionPreprocessingDataCollector(train_df, preprocessing_params)
                    collector_data = collector.build()

                preproc_handler = PredictionPreprocessingHandler.build(core_params, preprocessing_params, fold_ppfolder_context, nan_support=nan_support)
                preproc_handler.collector_data = collector_data

                pipeline = preproc_handler.build_preprocessing_pipeline(with_target=True)

                with listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_TRAIN):
                    transformed_train = pipeline.fit_and_process(train_df)

                with listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_TEST):
                    test_df_index = test_df.index.copy()
                    transformed_test = pipeline.process(test_df)
                    logger.info("Transformed valid: %s" % transformed_test["TRAIN"].stats())

                if prediction_type in (doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS):
                    target_map = preproc_handler.target_map
                    calibration_method = core_params.get("calibration", {}).get("calibrationMethod")
                    calibrate_on_test = core_params.get("calibration", {}).get("calibrateOnTestSet", True)
                    calibration_ratio = core_params.get("calibration", {}).get("calibrationDataRatio", doctor_constants.DEFAULT_CALIBRATION_DATA_RATIO)
                    with listener.push_step(step_constants.ProcessingStep.STEP_FITTING):
                        (clf, actual_params, prepared_X, iipd) = classification_fit(optimized_params,
                                                                                    core_params,
                                                                                    transformed_train,
                                                                                    model_folder_context=None,
                                                                                    transformed_test=transformed_test,
                                                                                    target_map=target_map,
                                                                                    with_sample_weight=with_sample_weight,
                                                                                    with_class_weight=with_class_weight,
                                                                                    calibration_method=calibration_method,
                                                                                    calibration_ratio=calibration_ratio,
                                                                                    calibrate_on_test=calibrate_on_test,
                                                                                    monotonic_cst=monotonic_cst)

                    model = ScorableModel.build(clf, model_type, prediction_type, modeling_params["algorithm"],
                                                preprocessing_params, ml_overrides_params)

                    if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
                        scorer = binary_classification_scorer_with_valid(optimized_params, model,
                                                                         transformed_test, fold_mfolder_context, test_df_index, target_map=target_map, with_sample_weight=with_sample_weight)
                    else:
                        scorer = multiclass_scorer_with_valid(optimized_params, model,
                                                              transformed_test, fold_mfolder_context, test_df_index, target_map=target_map, with_sample_weight=with_sample_weight)
                else:
                    with listener.push_step(step_constants.ProcessingStep.STEP_FITTING):
                        (clf, actual_params, prepared_X, iipd) = regression_fit_single(optimized_params, core_params, transformed_train,
                                                                                       model_folder_context=None,
                                                                                       with_sample_weight=with_sample_weight,
                                                                                       monotonic_cst=monotonic_cst)

                    model = ScorableModel.build(clf, model_type, prediction_type, modeling_params["algorithm"],
                                                preprocessing_params, ml_overrides_params)
                    scorer = regression_scorer_with_valid(optimized_params, model, transformed_test, fold_mfolder_context, test_df_index,
                                                          with_sample_weight)

                scorer.score(with_assertions=False)
                scorer.save(dump_predicted=False)
                folds.append({
                    "test_idx": test_idx,
                    "scorer": scorer,
                    "transformed_train": transformed_train,
                    "transformed_test": transformed_test,
                    "fold_id": split,
                })

        # The diagnostics only need the scorer and the transformed data of each fold
        diagnostics_folds = [
            {k: fold[k] for k in ("scorer", "transformed_train", "transformed_test")}
            for fold in folds
        ]

        logger.info("Folds done")
        arr = []
        preds = []
        if prediction_type in {doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS}:
            classes = folds[0]["scorer"].classes

        for fold in folds:
            predicted_df = fold["scorer"].predicted_df
            # Re-index the fold's predictions on their position in the full dataframe
            predicted_df.index = fold["test_idx"]
            predicted_df["fold_id"] = fold["fold_id"]
            arr.append(predicted_df)
            preds.append(pd.Series(fold["scorer"].test_predictions))

        global_predicted_df = pd.concat(arr, axis=0).sort_index()
        with model_folder_context.get_file_path_to_write(PREDICTED_FILENAME) as pred_csv_file_path:
            global_predicted_df.to_csv(pred_csv_file_path, sep="\t", header=True, index=False, encoding='utf-8')

        # Aggregate the per-fold scorers into a single cross-validated scorer
        scorers = [f["scorer"] for f in folds]
        if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
            scorer = CVBinaryClassificationModelScorer(scorers)
        elif prediction_type == doctor_constants.MULTICLASS:
            scorer = CVMulticlassModelScorer(scorers)
        elif prediction_type == doctor_constants.REGRESSION:
            scorer = CVRegressionModelScorer(scorers)
        gperf = scorer.score()

        gperf_without_overrides = scorer.score_without_overrides()
        if gperf_without_overrides is not None:
            model_folder_context.write_json(PERF_WITHOUT_OVERRIDES_FILENAME, gperf_without_overrides)

        # Save scorer predictions
        if prediction_type in {doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS}:
            predicted_class_series = pd.concat(preds, axis=0).sort_index().map(lambda x: classes[x])
            save_classification_statistics(predicted_class_series,
                                           model_folder_context,
                                           probas=(global_predicted_df.values if scorer.use_probas else None),
                                           sample_weight=transformed_full["weight"] if with_sample_weight else None,
                                           target_map=target_map)
        elif prediction_type == doctor_constants.REGRESSION:
            predicted_class_series = pd.concat(preds, axis=0).sort_index()
            save_regression_statistics(predicted_class_series, model_folder_context)

        if assertions_metrics is not None:
            if core_params["prediction_type"] == doctor_constants.BINARY_CLASSIFICATION:
                gperf["perCutData"]["assertionsMetrics"] = [assertion_metrics.to_dict() for assertion_metrics in assertions_metrics]
            elif core_params["prediction_type"] in {doctor_constants.MULTICLASS, doctor_constants.REGRESSION}:
                gperf["metrics"]["assertionsMetrics"] = assertions_metrics.to_dict()

        if overrides_metrics:
            if core_params["prediction_type"] == doctor_constants.BINARY_CLASSIFICATION:
                assert len(overrides_metrics) == len(gperf["perCutData"]["cut"]), \
                    "We should have the same number of override metrics and cuts"
                gperf["perCutData"]["overridesMetrics"] = [om.to_dict() for om in overrides_metrics]
            elif core_params["prediction_type"] in {doctor_constants.MULTICLASS, doctor_constants.REGRESSION}:
                gperf["metrics"]["overridesMetrics"] = overrides_metrics.to_dict()

        gperf["processed_feature_names"] = transformed_full["TRAIN"].columns()

        diagnostics.on_processing_all_kfold_end(prediction_type=prediction_type, folds=diagnostics_folds, with_sample_weight=with_sample_weight, perf_data=gperf)
        logger.info("gperf %s" % gperf)
        model_folder_context.write_json(PERF_FILENAME, gperf)

    with listener.push_step(step_constants.ProcessingStep.STEP_POSTTRAINING):
        try:
            # We don't want to fail the whole training if the explanations computation fails.
            if optimized_params.get("skipExpensiveReports"):
                logger.info("Skipping background rows drawing, feature and column importance computation")
            elif prediction_type in {doctor_constants.MULTICLASS, doctor_constants.BINARY_CLASSIFICATION} and not scorer.use_probas:
                logger.info("Cannot draw background rows, compute feature and column importance: model is not probabilistic")
            else:
                # for kfold cross-test models, the final model is trained on the whole data (transformed_full) and doesn't have a proper testset
                # (only surrogate models trained on each fold have test sets),
                # we will compute the explanations on the trainingset (transformed_full):
                preliminary_compute_for_explanations(model_folder_context, split_desc, transformed_full,
                                                     preprocessing_folder_context, split_folder_context, core_params,
                                                     modeling_params, preprocessing_params,
                                                     scorer.test_prediction_result)
        except Exception as e:
            logger.exception("Exception running the post training global explanations: {}".format(e))


def preliminary_compute_for_explanations(model_folder_context, split_desc, transformed_test,
                                         preprocessing_folder_context, split_folder_context, core_params,
                                         modeling_params, preprocessing_params, prediction_result):
    """
        Make the model explainer ready (and save it), then compute the global explanations
        unless at least one reason for skipping them is found.

        The computation is skipped, and the reason(s) logged, when:
          - the algorithm is one of KNN, SVC_CLASSIFICATION, SVM_REGRESSION
          - some input TEXT features are handled with SENTENCE_EMBEDDING
          - some input IMAGE features are handled with EMBEDDING_EXTRACTION
          - the explainer reports that the column importance computation has failed and there are
            more input columns than the maximum returned by the backend

        Otherwise the explanations are computed on transformed_test["UNPROCESSED"], aligned
        through prediction_result.align_with_not_declined.
    """
    logger.info("Starting to compute global explanations")

    logger.info("Building model handler")
    handler = PredictionModelInformationHandler(
        split_desc, core_params, preprocessing_folder_context, model_folder_context, split_folder_context
    )
    handler.get_explainer().make_ready(save=True)

    per_feature_params = preprocessing_params.get("per_feature", {})

    def input_features_handled_as(feature_type, handling_key, handling_value):
        # Names of the INPUT features of the given type whose `handling_key` equals `handling_value`
        matching_names = []
        for (name, params) in per_feature_params.items():
            if params.get('type') != feature_type:
                continue
            if params.get(handling_key) != handling_value:
                continue
            if params.get('role') != 'INPUT':
                continue
            matching_names.append(name)
        return matching_names

    skip_reasons = []

    algorithm = modeling_params.get("algorithm")
    if algorithm in {"KNN", "SVC_CLASSIFICATION", "SVM_REGRESSION"}:
        # NOTE: the trailing space of this message is part of the existing log output, kept as is
        skip_reasons.append(u"long computation time for {} algorithm ".format(algorithm))

    text_embedded = input_features_handled_as('TEXT', 'text_handling', 'SENTENCE_EMBEDDING')
    if text_embedded:
        skip_reasons.append(
            u"text embedding preprocessing of the following features: {}".format(", ".join(text_embedded)))

    image_embedded = input_features_handled_as('IMAGE', 'image_handling', 'EMBEDDING_EXTRACTION')
    if image_embedded:
        skip_reasons.append(
            u"image embedding preprocessing of the following features: {}".format(", ".join(image_embedded)))

    column_limit = intercom.jek_or_backend_get_call("ml/prediction/get-max-columns-to-explain")
    input_feature_count = len(input_columns(per_feature_params))
    # The column limit only matters when the explainer flags the column importance computation as failed
    if handler.get_explainer().column_importance_compute_has_failed:
        if input_feature_count > column_limit:
            skip_reasons.append(
                u"we couldn't compute surrogate feature importance and the training dataset has more than {} columns to explain ({})".format(
                    column_limit, input_feature_count))

    if skip_reasons:
        if len(skip_reasons) == 1:
            message = u"Skipped feature importance computations due to {}. They can be done post-training.".format(
                skip_reasons[0])
        else:
            message = u"Skipped feature importance computations due to :\n- {}\n  They can be done post-training.".format(
                "\n- ".join(skip_reasons))
        logger.info(message)
        return

    unprocessed_data = transformed_test["UNPROCESSED"]
    aligned_data = prediction_result.align_with_not_declined(unprocessed_data)
    handler.compute_global_explanations_on_non_droppable_data(aligned_data)


def prediction_train_model_keras(transformed_normal, train_df, test_df, pipeline, modeling_params, core_params,
                                 per_feature, model_folder_context, listener, target_map, generated_features_mapping):
    """
        Fit a Keras model, write its resolved parameters and, when a test set is provided,
        compute its intrinsic scores, score it on the validation data and save the resulting performance.

        The work is reported on the listener through a fitting step, a saving step and, only when
        test_df is not empty, a scoring step.

        :param transformed_normal: transformed "normal" features; its "TRAIN" and "target" entries are read
        :param train_df: forwarded to get_best_model
        :param test_df: forwarded to get_best_model; scoring is skipped when it is empty
        :param pipeline: forwarded to get_best_model and to the intrinsic scorers
        :param modeling_params: forwarded to the model building and scoring code, and written as the
                                "resolved" entry of actual_params.json
        :param core_params: "prediction_type" is read from it and the GPU config is derived from it
        :param per_feature: per-feature params; each value is expected to hold an 'isSpecialFeature' entry
        :param model_folder_context: used to write actual_params.json, and passed to get_best_model and the scorers
        :param listener: listener on which the processing steps are pushed
        :param target_map: forwarded to get_best_model and to the classification scorers
        :param generated_features_mapping: forwarded to get_best_model
    """
    # Imported locally rather than at module level
    # NOTE(review): presumably to avoid loading the Keras support code unless needed - confirm
    from dataiku.doctor.deep_learning.keras_support import build_scored_validation_data, get_best_model

    prediction_type = core_params["prediction_type"]

    # Building necessary vars to be used in model
    # For the "normal" features, the preprocessing was performed on a subsample (that can be 100%) of the data, so
    # we can retrieve the shape of each normal input, but the data will be processed again on each batch, in order
    # to also preprocess special features
    train_normal_X = transformed_normal["TRAIN"]
    train_normal_y = transformed_normal["target"]

    gpu_config = get_gpu_config_from_core_params(core_params)

    # Execute user-written code
    with listener.push_step(step_constants.ProcessingStep.STEP_FITTING):
        keras_model, validation_sequence = get_best_model(train_normal_X, train_df, pipeline, test_df, per_feature,
                                                          modeling_params, model_folder_context, prediction_type, target_map,
                                                          generated_features_mapping, gpu_config)
        # When there is no "normal" column at all, fall back to an empty (0, 0) non-sparse array
        prepared_X, is_sparse = prepare_multiframe(train_normal_X,modeling_params) \
            if len(train_normal_X.columns()) > 0 else (np.empty((0,0)), False)

        iipd = get_initial_intrinsic_perf_data(prepared_X, is_sparse)

        # Number of features flagged as special in the per-feature params
        iipd['modelInputNSpecialFeatures'] = len([f for f in per_feature.items() if f[1]['isSpecialFeature']])
        diagnostics.on_fitting_end(prediction_type=prediction_type, clf=keras_model, train_target=train_normal_y, features=train_normal_X.columns())

    with listener.push_step(step_constants.ProcessingStep.STEP_SAVING):
        # No need to save model here, already done in callbacks
        model_folder_context.write_json("actual_params.json", {"resolved": modeling_params})

    # Scoring only happens when there is a non-empty test set
    if len(test_df) > 0:
        with listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            # The 4th returned value is not used here
            preds, probas, valid_y, _ = build_scored_validation_data(keras_model, prediction_type, modeling_params,
                                                                     validation_sequence)
            # Then using PY_MEMORY scorers to compute the score, depending on prediction_type
            # Not providing:
            #  - sample_weights because not supported
            #  - valid because not displaying predicted data in the model
            if prediction_type == doctor_constants.REGRESSION:
                prediction_result = PredictionResult(preds)
                # Intrinsic scores are computed right away; the test scorer is only built here
                # and run after the branches
                RegressionModelIntrinsicScorer(modeling_params, keras_model, train_normal_X, train_normal_y, pipeline,
                                               model_folder_context,
                                               prepared_X,
                                               iipd, False).score()
                scorer = RegressionModelScorer(modeling_params, prediction_result, valid_y, model_folder_context, test_unprocessed=None,
                                               test_X=None, test_df_index=None, test_sample_weight=None)
            elif prediction_type in (doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS):
                # The intrinsic scorer is shared by both classification types
                ClassificationModelIntrinsicScorer(modeling_params, keras_model,
                                                   train_normal_X, train_normal_y, target_map,
                                                   pipeline, model_folder_context, prepared_X,
                                                   iipd, False, False).score()
                if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
                    # Binary classification is scored from the probas only, through DecisionsAndCuts
                    decisions_and_cuts = DecisionsAndCuts.from_probas(probas, target_map)
                    scorer = BinaryClassificationModelScorer(modeling_params, model_folder_context, decisions_and_cuts, valid_y, target_map, test_unprocessed=None,
                                                             test_X=None, test_df_index=None, test_sample_weight=None)
                else:
                    # Multiclass is scored from both the probas and the (unmapped) predictions
                    prediction_result = ClassificationPredictionResult(target_map, probas=probas, unmapped_preds=preds)
                    scorer = MulticlassModelScorer(modeling_params, model_folder_context, prediction_result, valid_y, target_map, test_unprocessed=None,
                                                   test_X=None, test_df_index=None, test_sample_weight=None)
            # NOTE(review): `scorer` is only bound for REGRESSION, BINARY_CLASSIFICATION and MULTICLASS;
            # any other prediction type would fail here with an unbound local variable - confirm callers
            # never pass another type
            scorer.score()
            scorer.save()
            diagnostics.on_scoring_end(scoring_results=DiagnosticsScoringResults.build_from_scorer(prediction_type, scorer))
