import logging
import math
from collections import Counter
from collections import defaultdict
from functools import reduce

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.metrics import auc
from sklearn.metrics import average_precision_score
from sklearn.metrics import hamming_loss
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import learning_curve
from sklearn.neighbors import KernelDensity

from dataiku.base.utils import safe_unicode_str
from dataiku.core import doctor_constants
from dataiku.doctor.prediction.common import compute_cost_matrix_score
from dataiku.doctor.prediction.common import get_threshold_optim_function
from dataiku.doctor.prediction.common import get_multiclass_metrics_averaging_method
from dataiku.doctor.prediction.common import make_lift_score
from dataiku.doctor.prediction.common import prepare_multiframe
from dataiku.doctor.prediction.common import weighted_quantile
from dataiku.doctor.prediction.custom_scoring import aggregate_custom_metrics_for_cross_val_model
from dataiku.doctor.prediction.custom_scoring import build_cv_per_cut_custom_metrics
from dataiku.doctor.prediction.custom_scoring import calculate_overall_classification_custom_metrics
from dataiku.doctor.prediction.custom_scoring import compute_custom_metrics_for_cut
from dataiku.doctor.prediction.custom_scoring import execute_parsed_custom_metric_function
from dataiku.doctor.prediction.custom_scoring import get_custom_evaluation_metric
from dataiku.doctor.prediction.custom_scoring import get_custom_metric_functions_binary_classif
from dataiku.doctor.prediction.custom_scoring import get_custom_score_from_custom_metrics_results
from dataiku.doctor.prediction.decisions_and_cuts import DecisionsAndCuts
from dataiku.doctor.prediction.linear_coefficients_computation import compute_coefs_if_available
from dataiku.doctor.prediction.overrides.ml_overrides_params import OVERRIDE_INFO_COL
from dataiku.doctor.prediction.overrides.ml_overrides_results import OverriddenClassificationPredictionResult
from dataiku.doctor.prediction.overrides.ml_overrides_results import OverridesResultsMixin
from dataiku.doctor.prediction.regression_scoring import GradientBoostingSummaryBuilder
from dataiku.doctor.prediction.regression_scoring import RandomForestSummaryBuilder
from dataiku.doctor.prediction.regression_scoring import TreeSummaryBuilder
from dataiku.doctor.prediction.scorable_model import ScorableModel
from dataiku.doctor.prediction.scorable_model import is_proba_aware
from dataiku.doctor.prediction.scoring_base import BaseCVModelScorer
from dataiku.doctor.prediction.scoring_base import ClassicalPredictionModelScorer
from dataiku.doctor.prediction.scoring_base import PredictionModelIntrinsicScorer
from dataiku.doctor.prediction.scoring_base import PredictionModelScorer
from dataiku.doctor.prediction.scoring_base import align_assertions_masks_with_not_declined
from dataiku.doctor.prediction.scoring_base import build_partial_dependence_plot
from dataiku.doctor.prediction.scoring_base import trim_curve
from dataiku.doctor.preprocessing.assertions import MLAssertionMetrics
from dataiku.doctor.preprocessing.assertions import MLAssertionsMetrics
from dataiku.doctor.preprocessing.dataframe_preprocessing import RescalingProcessor2
from dataiku.doctor.utils import dku_isnan
from dataiku.doctor.utils import dku_nonan
from dataiku.doctor.utils import remove_all_nan
from dataiku.doctor.utils.calibration import dku_calibration_curve
from dataiku.doctor.utils.calibration import dku_calibration_loss
from dataiku.doctor.utils.lift_curve import LiftBuilder
from dataiku.doctor.utils.metrics import check_test_set_ok_for_classification
from dataiku.doctor.utils.metrics import confusion_matrix_derived_metrics
from dataiku.doctor.utils.metrics import handle_failure
from dataiku.doctor.utils.metrics import log_loss
from dataiku.doctor.utils.metrics import m_average_precision_score
from dataiku.doctor.utils.metrics import mroc_auc_score
from dataiku.doctor.utils.skcompat import get_base_estimator
from dataiku.doctor.utils.skcompat import roc_curve
from dataikuscoring.utils.prediction_result import ClassificationPredictionResult
from dataikuscoring.utils.scoring_data import ScoringData

# Module-level logger, named after this module
logger = logging.getLogger(__name__)


class ClassificationModelIntrinsicScorer(PredictionModelIntrinsicScorer):
    """
    Intrinsic scorer for classification models.

    `score` enriches the initial intrinsic performance data with reports computed from the fitted classifier
    and the train set, writes the tree / coefficient path side files when relevant, and dumps everything
    in "iperf.json" through the output folder context.
    """

    def __init__(self, modeling_params, clf, train_X, train_y, target_map, pipeline, out_folder_context, prepared_X, initial_intrinsic_perf_data, with_sample_weight, calibrate_proba):
        super(ClassificationModelIntrinsicScorer, self).__init__(modeling_params, clf, train_X, train_y, out_folder_context, prepared_X, with_sample_weight)
        self.target_map = target_map
        self.pipeline = pipeline
        self.calibrate_proba = calibrate_proba
        self.initial_intrinsic_perf_data = initial_intrinsic_perf_data

    def _extract_rescalers(self):
        """Return the list of pipeline steps that are RescalingProcessor2 instances."""
        return [step for step in self.pipeline.steps if isinstance(step, RescalingProcessor2)]

    def _get_classes(self):
        """Return the class labels, in the order of the class ids listed in `self.clf.classes_`."""
        label_by_class_id = dict((class_id, label) for label, class_id in self.target_map.items())
        return [label_by_class_id[class_id] for class_id in self.clf.classes_]

    def score(self):
        """Compute the intrinsic reports and write them (iperf.json and algorithm-specific files)."""
        perf = self.initial_intrinsic_perf_data
        logger.info("Intrinsic scoring")

        # With calibrated probabilities, reports are computed on the estimator wrapped by the calibrated one
        base_clf = get_base_estimator(self.clf) if self.calibrate_proba else self.clf

        self.add_raw_feature_importance_if_exists(base_clf, perf)

        # Linear coefficients (binary only)
        compute_coefs_if_available(base_clf, self.train_X, self.prepared_X, self.train_y, self._extract_rescalers(), perf, True)

        if not self.modeling_params.get("skipExpensiveReports"):
            logger.info("Extracting rescalers")
            rescalers = self._extract_rescalers()
            algorithm = self.modeling_params['algorithm']

            if algorithm == 'DECISION_TREE_CLASSIFICATION':
                logger.info("Creating decision tree summary")
                builder = TreeSummaryBuilder(base_clf, self.train_X.columns(), rescalers, False,
                                             self.with_sample_weight, classes=self._get_classes())
                self.out_folder_context.write_json("tree.json", builder.build())

            elif algorithm == 'GBT_CLASSIFICATION':
                logger.info("Creating gradient boosting trees summary")
                builder = GradientBoostingSummaryBuilder(base_clf, self.train_X.columns(), rescalers, False,
                                                         self.modeling_params["max_ensemble_nodes_serialized"],
                                                         self.with_sample_weight, classes=self._get_classes())
                self.out_folder_context.write_json("trees.json", builder.build())
                logger.info("Creating GBT PDP")
                perf["partialDependencies"] = build_partial_dependence_plot(base_clf, self.train_X, self.train_y, rescalers)

            elif algorithm == 'RANDOM_FOREST_CLASSIFICATION':
                logger.info("Creating random forest trees summary")
                builder = RandomForestSummaryBuilder(base_clf, self.train_X.columns(), rescalers, False,
                                                     self.modeling_params["max_ensemble_nodes_serialized"],
                                                     self.with_sample_weight, classes=self._get_classes())
                self.out_folder_context.write_json("trees.json", builder.build())
        else:
            logger.info("Skipping potentially expensive reports")  # tree(s) summary, PDP

        if self.modeling_params['algorithm'] == 'LARS':
            # Coefficient path converted to nested plain lists before being written
            coef_path = [[list(row) for row in step] for step in base_clf.coef_path_]
            self.out_folder_context.write_json("coef_path.json", {
                "path": coef_path,
                "features": self.train_X.columns(),
                "currentIndex": base_clf.current_index
            })

        # Learning curve if requested
        if self.modeling_params["computeLearningCurves"]:
            logger.info("Computing learning curves")
            prepared_train_X, _ = prepare_multiframe(self.train_X, self.modeling_params)
            int_train_y = self.train_y.astype(int)

            sizes, scores_on_train, scores_on_valid = learning_curve(base_clf, prepared_train_X, int_train_y)
            perf["learningCurve"] = {
                "samples": sizes,
                "trainScoreMean": np.mean(scores_on_train, axis=1),
                "trainScoreStd": np.std(scores_on_train, axis=1),
                "cvScoreMean": np.mean(scores_on_valid, axis=1),
                "cvScoreStd": np.std(scores_on_valid, axis=1)
            }
        perf["probaAware"] = is_proba_aware(base_clf, self.modeling_params['algorithm'])

        # Dump the perf
        self.out_folder_context.write_json("iperf.json", perf)


def binary_classification_predict_ensemble(clf, target_map, threshold, data, output_probas=True, has_target=False,
                                           for_all_cuts=True):
    """
    Predict with an ensemble binary classifier.

    The positive class (id 1) is predicted when the value of the second probability column is strictly
    greater than `threshold`.

    :param output_probas: whether the probabilities dataframe is included in the returned object
    :param has_target: when True, `clf.set_with_target_pipelines_mode(True)` is called before predicting
    :param for_all_cuts: whether decisions for all cuts are built from the probabilities
    :rtype: ScoringData
    """
    if has_target:
        clf.set_with_target_pipelines_mode(True)

    logger.info("Start actual predict")
    proba_df = clf.predict_proba_as_dataframe(data)
    logger.info("Done actual predict")

    positive_probas = proba_df.values[:, 1]
    unmapped_preds = (positive_probas > threshold).astype(int)
    probas = proba_df.values
    prediction_result = ClassificationPredictionResult(target_map, probas=probas, unmapped_preds=unmapped_preds)

    pred_df = pd.DataFrame({"prediction": prediction_result.preds})
    pred_df.index = proba_df.index

    decisions_and_cuts = None
    if for_all_cuts:
        decisions_and_cuts = DecisionsAndCuts.from_probas(probas, target_map)

    returned_proba_df = proba_df if output_probas else None
    return ScoringData(prediction_result=prediction_result, preds_df=pred_df,
                       probas_df=returned_proba_df, decisions_and_cuts=decisions_and_cuts)


def binary_classification_predict_single(model, pipeline, modeling_params, data, output_probas=True, for_all_cuts=True):
    """
    Process `data` through `pipeline`, prepare the "TRAIN" output and predict with a single (non ensemble) model.

    Predictions are indexed like the "TRAIN" output of the pipeline.

    :rtype: ScoringData
    """
    logger.info("Prepare to predict ...")
    processed = pipeline.process(data)
    train_multiframe = processed["TRAIN"]
    unprocessed_df = processed["UNPROCESSED"]
    prepared_X, _ = prepare_multiframe(train_multiframe, modeling_params)
    return binary_classification_predict_single_from_prepared(model, prepared_X, unprocessed_df,
                                                              train_multiframe.index, output_probas,
                                                              for_all_cuts=for_all_cuts)


def binary_classification_predict_single_from_prepared(model, X, unprocessed_df, orig_index, output_probas, for_all_cuts=True):
    """
    Predict with a single model on already prepared features.

    When `for_all_cuts` is True, decisions are computed for all cuts and the prediction is the one of the cut
    nearest to the model threshold; otherwise the model predictions are computed directly.

    :param orig_index: index set on the returned predictions and probabilities dataframes
    :param output_probas: whether a probabilities dataframe is built (only if the prediction has probas)
    :rtype: ScoringData
    """
    decisions_and_cuts = None
    if for_all_cuts:
        decisions_and_cuts = model.compute_decisions_and_cuts(X, unprocessed_df)
        prediction_result = decisions_and_cuts.get_prediction_result_for_nearest_cut(model.get_threshold())
    else:
        prediction_result = model.compute_predictions(X, unprocessed_df)

    pred_df = pd.DataFrame({"prediction": prediction_result.preds})
    pred_df.index = orig_index

    proba_df = None
    if prediction_result.has_probas() and output_probas:
        probas = prediction_result.probas
        proba_columns = ["proba_%s" % class_label for class_label in model.get_classes()]
        proba_df = pd.DataFrame(probas, columns=proba_columns)
        proba_df.index = orig_index

    return ScoringData(prediction_result=prediction_result, preds_df=pred_df, probas_df=proba_df,
                       decisions_and_cuts=decisions_and_cuts)


def binary_classification_predict(model, pipeline, modeling_params, target_map, threshold, data, output_probas=True,
                                  ensemble_has_target=False, for_all_cuts=True):
    """
    Binary classification prediction entry point: dispatches on the algorithm.

    "PYTHON_ENSEMBLE" models are scored through `model.clf`; for any other algorithm, `threshold` is first set
    on `model`, which is then scored through `pipeline`.

    :rtype: ScoringData
    """
    if modeling_params["algorithm"] != "PYTHON_ENSEMBLE":
        model.set_threshold(threshold)
        return binary_classification_predict_single(model, pipeline, modeling_params, data, output_probas,
                                                    for_all_cuts=for_all_cuts)

    return binary_classification_predict_ensemble(model.clf, target_map, threshold, data, output_probas,
                                                  has_target=ensemble_has_target, for_all_cuts=for_all_cuts)


def binary_classif_scoring_add_percentile_and_cond_outputs(pred_df, recipe_desc, model_folder_context, cond_outputs,
                                                           target_map):
    """
    Add the probability percentile and the conditional outputs to a scored binary classification dataframe.

    `pred_df` is modified in place and returned.

    Probability percentile: 1 + the number of values of "probaPercentiles" (read from the model's perf.json)
    that are lower than or equal to the probability of the class mapped to id 1.

    Conditional outputs: for each conditional output, rules are applied in order on the rows whose input is
    not null; a row takes the output of the first rule it satisfies, and the "defaultOutput" (empty string if
    not set) when it satisfies none. Rows with a null input are not assigned.

    Columns that were only needed to evaluate conditional outputs (percentile / probabilities not requested
    by the recipe) are dropped before returning.

    :param pandas.DataFrame pred_df: scored data, holding every column used as "input" by a conditional output
    :param dict recipe_desc: must hold the "outputProbabilities" and "outputProbaPercentiles" flags
    :param model_folder_context: used to read "perf.json" when percentiles are needed
    :param list cond_outputs: conditional outputs, each with "input", "name", "rules" and optional "defaultOutput"
    :param dict target_map: map of class label to class id
    :raises Exception: if percentiles are needed but missing (or empty) in perf.json
    :raises ValueError: if a rule operation is not one of GT, GE, LT, LE
    :rtype: pandas.DataFrame
    """
    inv_map = {
        int(class_id): label
        for label, class_id in target_map.items()
    }
    classes = [class_label for (_, class_label) in sorted(inv_map.items())]
    proba_cols = [u"proba_{}".format(safe_unicode_str(c)) for c in classes]
    has_probas = recipe_desc["outputProbabilities"] or (cond_outputs and
                                                        len([co for co in cond_outputs
                                                             if co["input"] in proba_cols]))
    has_percentiles = recipe_desc["outputProbaPercentiles"] or (cond_outputs and
                                                                len([co for co in cond_outputs if
                                                                     co["input"] == "proba_percentile"]))
    if has_percentiles:
        model_perf = model_folder_context.read_json("perf.json")
        if "probaPercentiles" in model_perf and model_perf["probaPercentiles"]:
            percentile = pd.Series(model_perf["probaPercentiles"])
            proba_1 = u"proba_{}".format(safe_unicode_str(inv_map[1]))
            pred_df["proba_percentile"] = pred_df[proba_1].apply(
                lambda p: percentile.where(percentile <= p).count() + 1)
        else:
            raise Exception("Probability percentiles are missing from model.")
    if cond_outputs:
        for co in cond_outputs:
            inp = pred_df[co["input"]]
            acc = inp.notnull()  # condition accumulator: rows not matched by any previous rule
            for r in co["rules"]:
                operation = r["operation"]
                if operation == 'GT':
                    cond = inp > r["operand"]
                elif operation == 'GE':
                    cond = inp >= r["operand"]
                elif operation == 'LT':
                    cond = inp < r["operand"]
                elif operation == 'LE':
                    cond = inp <= r["operand"]
                else:
                    # Without this, `cond` would be unbound (first rule) or silently reused from the previous rule
                    raise ValueError("Unsupported operation %r in conditional output rule "
                                     "(expected one of GT, GE, LT, LE)" % (operation,))
                pred_df.loc[acc & cond, co["name"]] = r["output"]
                acc = acc & (~cond)
            pred_df.loc[acc, co["name"]] = co.get("defaultOutput", "")
    if has_percentiles and not recipe_desc["outputProbaPercentiles"]:  # was only for conditional outputs
        pred_df.drop("proba_percentile", axis=1, inplace=True)
    if has_probas and not recipe_desc["outputProbabilities"]:  # was only for conditional outputs
        pred_df.drop(proba_cols, axis=1, inplace=True)

    return pred_df


def multiclass_predict_ensemble(clf, target_map, data, output_probas, has_target=False):
    """
    Predict with an ensemble multiclass classifier.

    Predicted class ids are mapped back to their labels using `target_map` (label -> class id).
    Probabilities are only computed when `output_probas` is True.

    :param has_target: when True, `clf.set_with_target_pipelines_mode(True)` is called before predicting
    :rtype: ScoringData
    """
    if has_target:
        clf.set_with_target_pipelines_mode(True)
    unmapped_preds_df = clf.predict_as_dataframe(data).astype(int)
    unmapped_preds = unmapped_preds_df["prediction"].values

    # Predictions whose id is not in target_map keep the initial value (0)
    labels = np.zeros(unmapped_preds.shape, dtype="object")
    for label, class_id in target_map.items():
        labels[unmapped_preds == int(class_id)] = label

    pred_df = pd.DataFrame({"prediction": labels})
    pred_df.index = unmapped_preds_df.index

    proba_df, probas = None, None
    if output_probas:
        proba_df = clf.predict_proba_as_dataframe(data)
        probas = proba_df.values

    prediction_result = ClassificationPredictionResult(target_map, probas=probas, unmapped_preds=unmapped_preds)
    return ScoringData(prediction_result=prediction_result, preds_df=pred_df, probas_df=proba_df)


def multiclass_predict_single(model, pipeline, modeling_params, data, output_probas):
    """
    Process `data` through `pipeline`, prepare the "TRAIN" output and predict with a single (non ensemble) model.

    Predictions are indexed like the "TRAIN" output of the pipeline.

    :rtype: ScoringData
    """
    logger.info("Prepare to predict ...")

    processed = pipeline.process(data)
    train_multiframe = processed["TRAIN"]
    prepared_X, _ = prepare_multiframe(train_multiframe, modeling_params)
    unprocessed_df = processed["UNPROCESSED"]
    return multiclass_predict_single_from_prepared(model, prepared_X, unprocessed_df, train_multiframe.index,
                                                   output_probas)


def multiclass_predict_single_from_prepared(model, X, unprocessed_df, orig_index, output_probas):
    """
    Predict with a single multiclass model on already prepared features.

    :param orig_index: index set on the returned predictions and probabilities dataframes
    :param output_probas: whether a probabilities dataframe is built (only if the prediction has probas)
    :rtype: ScoringData
    """
    prediction_result = model.compute_predictions(X, unprocessed_df)

    if output_probas and prediction_result.has_probas():
        probas = prediction_result.probas
        proba_columns = ["proba_%s" % class_label for class_label in model.get_classes()]
        proba_df = pd.DataFrame(probas, columns=proba_columns)
        proba_df.index = orig_index
    else:
        proba_df = None

    pred_df = pd.DataFrame({"prediction": prediction_result.preds})
    pred_df.index = orig_index

    return ScoringData(prediction_result=prediction_result, preds_df=pred_df, probas_df=proba_df)


def multiclass_predict(model, pipeline, modeling_params, target_map, data, output_probas=True, ensemble_has_target=False):
    """
    Multiclass prediction entry point: "PYTHON_ENSEMBLE" models are scored through `model.clf`,
    any other algorithm is scored with `model` through `pipeline`.

    :rtype: ScoringData
    """
    if modeling_params["algorithm"] != "PYTHON_ENSEMBLE":
        return multiclass_predict_single(model, pipeline, modeling_params, data, output_probas)
    return multiclass_predict_ensemble(model.clf, target_map, data, output_probas, has_target=ensemble_has_target)


def format_proba_density(data, sample_weight=None):
    """
    Estimate the density of `data` with a gaussian kernel and evaluate it on 100 evenly spaced points of [0, 1].

    The data is first quantized into a 10000-bin histogram over [0, 1], and the kernel density is fitted on
    the midpoints of the non-empty bins, weighted by the bin weights.

    The global NumPy random state is left untouched: the bins are shuffled with a local, fixed-seed generator.

    :param data: 1-dimensional array-like of values
    :param sample_weight: optional array-like of weights, passed to the histogram; ignored when all the values
                          of `data` are identical
    :return: list of 100 density values (NaN replaced by 0), or an empty list if `data` is empty
    :rtype: list
    """
    if len(data) == 0:
        return []
    # Popular bandwidth estimation by Scott (1992), also used in scikit-learn 1.0
    h = 1.06 * np.std(data) * math.pow(len(data), -.2)
    if h <= 0:
        # Null standard deviation: fall back on a fixed bandwidth
        h = 0.06
    if sample_weight is not None and len(np.unique(data)) == 1:
        sample_weight = None

    # Quantize data by computing a histogram => O(N)
    n_bins = 10000
    histogram_weights, histogram_bins = np.histogram(data, range=(0, 1), bins=n_bins, weights=sample_weight)
    histogram_midpoints = histogram_bins[0:-1] + 0.5/n_bins
    # Remove empty bins (sample_weight must be > 0 in KernelDensity)
    positive_weight_mask = histogram_weights > 0
    positive_weight_index = np.arange(0, n_bins).astype(int)[positive_weight_mask]
    # Shuffle data to prevent sorted-data worst-case in scikit-learn pre-1.0.
    # A local RandomState gives the same permutation as seeding the global generator with the same seed,
    # without resetting the process-wide random state as a side effect.
    np.random.RandomState(12345).shuffle(positive_weight_index)
    histogram_weights = histogram_weights[positive_weight_index]
    histogram_midpoints = histogram_midpoints[positive_weight_index]
    # Compute the KDTree on histogram => build in O(n_bins log(n_bins))
    kde = KernelDensity(kernel='gaussian', bandwidth=h).fit(histogram_midpoints.reshape(-1, 1), sample_weight=histogram_weights)
    X_plot = np.linspace(0.0, 1.0, num=100)[:, np.newaxis]
    # Binary tree is built on n_bins weighted samples => query in O(log(n_bins))
    Y_plot = np.exp(kde.score_samples(X_plot))
    return [v if not dku_isnan(v) else 0 for v in Y_plot]


def format_all_proba_densities(probas, target_map, sample_weight=None):
    """
    Build, for each class label of `target_map`, the density and median of the predicted probabilities
    of that class (column of `probas` at the class id).

    :rtype: dict
    """
    return {
        class_label: _format_proba_densities_for_class(probas, int(class_id), sample_weight)
        for class_label, class_id in target_map.items()
    }


def _format_proba_densities_for_class(probas, idx, sample_weight=None):
    """
    Compute the density and the median of the non-NaN probabilities of one class.

    :param probas: 2-dimensional array of probabilities, one column per class id
    :param int idx: class id, i.e. column of `probas` to use
    :param sample_weight: optional weights exposing `.values`, one per row of `probas`
    :return: dict with keys 'density' and 'median' (median is None when there is no non-NaN probability)
    :rtype: dict
    """
    proba_array = probas[:, idx]
    not_nan_mask = ~pd.isna(proba_array)
    proba_array_no_nan = proba_array[not_nan_mask]

    if sample_weight is None:
        return {
            'density': format_proba_density(proba_array_no_nan),
            'median': get_np_median_or_none(proba_array_no_nan),
        }

    # Weights must be filtered with the same mask as the probabilities to stay aligned with them
    weights_no_nan = sample_weight.values[not_nan_mask]
    pdd_for_class = {'density': format_proba_density(proba_array_no_nan, weights_no_nan)}
    if len(proba_array_no_nan) == 0:
        # Same convention as the unweighted median
        pdd_for_class['median'] = None
    else:
        # Sort values and weights with the same permutation so that each weight stays attached to its value
        order = np.argsort(proba_array_no_nan)
        pdd_for_class['median'] = weighted_quantile(proba_array_no_nan[order], weights_no_nan[order], 0.5)
    return pdd_for_class


def _sorted_weighted_median_or_none(values, weights):
    """
    Weighted median of `values`, or None when `values` is empty.

    Values and weights are reordered with the same permutation, so that each weight stays attached to its value.
    """
    if len(values) == 0:
        return None
    order = np.argsort(values)
    return weighted_quantile(values[order], weights[order], 0.5)


def format_all_conditional_proba_densities(classes, target_map, probas, valid_y, sample_weight=None):
    """
    The result of this function represents, for each possible class c:
        - the distribution of the predicted probabilities of being class c, conditionally on being actually class c (double[] actualIsThisClass)
        - the distribution of the predicted probabilities of being class c,
        conditionally on being actually a class different from c (double[] actualIsNotThisClass)
        - the (weighted, if `sample_weight` is given) medians of both sets of probabilities
        (actualIsThisClassMedian / actualIsNotThisClassMedian), None when the set is empty

    :param classes: iterable of class labels, all present in `target_map`
    :param dict target_map: map of class label to class id (column of `probas`)
    :param probas: 2-dimensional array of probabilities, one column per class id
    :param valid_y: actual class ids, exposing `.values`, one per row of `probas`
    :param sample_weight: optional weights, one per row of `probas`
    :rtype: dict
    """

    ret = {}
    for class_actual in classes:
        class_actual_id = int(target_map[class_actual])
        logger.info("Density for %s (id %s)" % (class_actual, class_actual_id))
        logger.info("valid_y shape = %s" % str(valid_y.shape))

        class_proba = probas[:, class_actual_id]
        logger.info("CP: %s (%s)" % (class_proba.__class__, str(class_proba.shape)))
        actual_is_this_class_mask = (valid_y.values == class_actual_id)
        actual_is_not_this_class_mask = (valid_y.values != class_actual_id)

        logger.info("Actual shape = %s" % str(actual_is_this_class_mask.shape))
        logger.info("MASK is %s " % actual_is_this_class_mask.__class__)

        class_proba_isactual = class_proba[actual_is_this_class_mask]
        class_proba_isnotactual = class_proba[actual_is_not_this_class_mask]
        logger.info("Class proba shape %s" % str(class_proba_isactual.shape))

        if sample_weight is None:
            ret[class_actual] = {
                "actualIsThisClass": format_proba_density(class_proba_isactual),
                "actualIsNotThisClass": format_proba_density(class_proba_isnotactual),
                "actualIsThisClassMedian": get_np_median_or_none(class_proba_isactual),
                "actualIsNotThisClassMedian": get_np_median_or_none(class_proba_isnotactual)
            }
        else:
            weight_isactual = sample_weight[actual_is_this_class_mask].values
            weight_isnotactual = sample_weight[actual_is_not_this_class_mask].values
            ret[class_actual] = {
                "actualIsThisClass": format_proba_density(class_proba_isactual, weight_isactual),
                "actualIsNotThisClass": format_proba_density(class_proba_isnotactual, weight_isnotactual),
                # Probabilities and weights are sorted together: sorting the probabilities alone
                # would pair each probability with the weight of another row
                "actualIsThisClassMedian": _sorted_weighted_median_or_none(class_proba_isactual, weight_isactual),
                "actualIsNotThisClassMedian": _sorted_weighted_median_or_none(class_proba_isnotactual, weight_isnotactual)
            }
    return ret


def get_np_median_or_none(series):
    """Return the median of `series` (as computed by `np.median`), or None when `series` is empty."""
    return np.median(series) if len(series) != 0 else None


def compute_proba_distribution(probas, valid_y, sample_weights):
    """
    Histogram, for each actual class id, of the probabilities found in the second column of `probas`.

    Probabilities are binned on 10 bins of width 0.1 over [0, 1]; one row of counts (or of summed weights
    when `sample_weights` is given) is produced per column of `probas`.

    :param probas: 2-dimensional array of probabilities
    :param valid_y: actual class ids, exposing `.values`, one per row of `probas`
    :param sample_weights: optional weights, one per row of `probas`, or None
    :return: dict with keys "bins" (the 11 bin edges) and "probaDistribs" (list of lists of floats)
    :rtype: dict
    """
    edges = [step / 10.0 for step in range(11)]
    n_classes = probas.shape[1]
    distribs = np.zeros((n_classes, len(edges) - 1))

    for actual_class_id in range(n_classes):
        is_actual_class = valid_y.values == actual_class_id
        if sample_weights is None:
            weights = None
        else:
            weights = sample_weights[is_actual_class]
        counts, _ = np.histogram(probas[is_actual_class, 1], bins=edges, weights=weights)
        distribs[actual_class_id, :] = counts

    return {
        "bins": edges,
        "probaDistribs": distribs.tolist()
    }


def binary_classification_scorer_with_valid(modeling_params, model, valid, out_folder_context, test_df_index, target_map, with_sample_weight=False):
    """
    Build a BinaryClassificationModelScorer from a processed validation set.

    Decisions are computed for all cuts with `model` on the prepared "TRAIN" data of `valid`.

    :param valid: processed validation set, holding "target", "TRAIN", "UNPROCESSED", "weight" (read only if
                  `with_sample_weight`) and optionally "assertions"
    :rtype: BinaryClassificationModelScorer
    """
    valid_y = valid["target"].astype(int)
    valid_w = valid["weight"] if with_sample_weight else None

    check_test_set_ok_for_classification(valid_y)
    valid_X = valid["TRAIN"]
    prepared_valid_X, _ = prepare_multiframe(valid_X, modeling_params)

    test_unprocessed = valid["UNPROCESSED"]
    decisions_and_cuts = model.compute_decisions_and_cuts(prepared_valid_X, test_unprocessed)
    decisions_and_cuts.assert_not_all_declined()

    return BinaryClassificationModelScorer(modeling_params, out_folder_context, decisions_and_cuts, valid_y, target_map,
                                           test_unprocessed=test_unprocessed, test_X=valid_X,
                                           test_df_index=test_df_index, test_sample_weight=valid_w,
                                           assertions=valid.get("assertions"))


def search_optimized_threshold(scores, cuts, greater_is_better):
    """
    Return the cut that has the best score.

    :param scores: sequence of scores, one per cut
    :param cuts: sequence of cuts, same length as `scores`
    :param bool greater_is_better: whether the best score is the highest one (True) or the lowest one (False)
    :return: the element of `cuts` at the position of the best score (first position in case of ties)
    :raises AssertionError: if `scores` and `cuts` do not have the same length
    """
    logger.info("Starting optimal threshold search")

    # Each length is reported next to its own name in the message
    assert len(cuts) == len(scores), "The length of scores (length = %d) and cuts (length = %d) must be the same" % (len(scores), len(cuts))

    np_scores = np.array(scores)
    if greater_is_better:
        best_cut_index = np.argmax(np_scores)
    else:
        best_cut_index = np.argmin(np_scores)

    best_cut = cuts[best_cut_index]
    logger.info("Found threshold %s " % best_cut)
    return best_cut


def compute_optimized_threshold(valid_y, decisions_and_cuts, metric_params, sample_weight=None):
    """
    Score the decisions of every cut with the threshold optimization metric and return the best cut.

    :param valid_y: ground truth target, cast to int before scoring
    :param decisions_and_cuts: iterable of (prediction result, cut), also exposing `get_cuts()`
    :param metric_params: parameters used to select the metric function and its direction
    :param sample_weight: optional sample weights forwarded to the metric function
    :return: the selected cut
    """
    logger.info("Starting threshold optim")
    score_func, greater_is_better = get_threshold_optim_function(metric_params)

    def _score_at_cut(prediction_result, cut):
        # Scores the not-declined decisions taken at this cut
        decision = prediction_result.unmapped_preds_not_declined
        cut_score = score_func(valid_y.astype(int), decision, sample_weight=sample_weight)
        logger.info("AT CUT %f score %f (pred_true=%d)" % (cut, cut_score, np.count_nonzero(decision)))
        return cut_score

    scores_per_cut = [_score_at_cut(prediction_result, cut) for prediction_result, cut in decisions_and_cuts]

    selected_cut = search_optimized_threshold(scores_per_cut, decisions_and_cuts.get_cuts(), greater_is_better)

    logger.info("Selected threshold %s " % selected_cut)
    return selected_cut


def compute_assertions_for_decision(decision, assertions, target_map):
    """
    Compute the metrics of each assertion for one array of decisions.

    For an assertion with at least one row in its mask, the valid ratio is the share of masked rows whose
    decision is the expected class, and the result tells whether it reaches the expected valid ratio.
    Both are None when the mask selects no row.

    :param decision: array of predicted class ids, aligned with the assertions masks
    :param assertions: iterable of assertions (each with `mask`, `params` and `nb_initial_rows`)
    :param dict target_map: map of class label to class id
    :rtype: MLAssertionsMetrics
    """
    all_metrics = MLAssertionsMetrics()
    for assertion in assertions:
        row_mask = assertion.mask.values  # mask and decision are aligned, we can work with np arrays
        condition = assertion.params["assertionCondition"]
        nb_rows_in_mask = np.sum(row_mask)
        nb_dropped_rows = assertion.nb_initial_rows - nb_rows_in_mask

        valid_ratio, result = None, None
        if nb_rows_in_mask > 0:
            expected_class_index = target_map[condition["expectedClass"]]
            nb_rows_as_expected = np.sum(row_mask & (decision == expected_class_index))
            valid_ratio = nb_rows_as_expected / (1.0 * nb_rows_in_mask)
            result = bool(valid_ratio >= condition["expectedValidRatio"])

        all_metrics.add_assertion_metrics(
            MLAssertionMetrics(result, assertion.nb_initial_rows, nb_dropped_rows,
                               valid_ratio, assertion.params["name"]))
    return all_metrics


def compute_assertions_for_binary_classification(decision_candidates, assertions, target_map):
    """
    Compute the assertions metrics for each candidate decision array (one per cut).

    :return: list of assertions metrics, in the order of `decision_candidates`
    :rtype: list
    """
    logger.info(u"Computing assertions metrics for assertions {}".format(assertions.printable_names()))
    metrics_per_decision = []
    for decision in decision_candidates:
        metrics_per_decision.append(compute_assertions_for_decision(decision, assertions, target_map))
    logger.info("Finished computing assertions metrics")
    return metrics_per_decision


def compute_assertions_for_multiclass_classification(preds, assertions, target_map):
    """
    Compute the assertions metrics for multiclass predictions.

    :param preds: array of predicted class ids
    :return: the metrics returned by `compute_assertions_for_decision` for `preds`
    """
    logger.info(u"Computing assertions metrics for assertions {}".format(assertions.printable_names()))
    metrics = compute_assertions_for_decision(preds, assertions, target_map)
    logger.info("Finished computing assertions metrics")
    return metrics


def compute_assertions_and_overrides_for_classification_from_clf(clf, model_type, modeling_params,
                                                                 prediction_type, preprocessing_params,
                                                                 target_map, transformed, overrides_params):
    """
    Compute the assertions metrics and the overrides metrics of a classifier on processed data.

    For binary classification, both are computed for every cut (lists); otherwise they are computed once
    on the model predictions. Assertions metrics are None when `transformed` has no "assertions",
    overrides metrics are None when the model has no overrides.

    :param transformed: processed data, holding "TRAIN", "UNPROCESSED" and optionally "assertions"
    :return: (assertions_metrics, overrides_metrics)
    :rtype: tuple
    """
    logger.info("Computing assertions and overrides from model")
    model = ScorableModel.build(clf, model_type, prediction_type,
                                modeling_params['algorithm'],
                                preprocessing_params, overrides_params)
    has_assertions = "assertions" in transformed
    assertions_metrics = None
    overrides_metrics = None

    # Same data preparation whatever the prediction type
    valid_X, _ = prepare_multiframe(transformed["TRAIN"], modeling_params)

    if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
        # no need to set threshold in model, because assertions will be computed for all cuts
        decisions_and_cuts = model.compute_decisions_and_cuts(valid_X, transformed["UNPROCESSED"])
        decisions_candidates = decisions_and_cuts.get_unmapped_preds_not_declined_list()

        if model.has_overrides():
            overrides_metrics = []
            for result_at_cut, _cut in decisions_and_cuts:
                assert isinstance(result_at_cut, OverridesResultsMixin)
                overrides_metrics.append(result_at_cut.compute_and_return_overrides_metrics())

        if has_assertions:
            accepted_assertions = align_assertions_masks_with_not_declined(transformed["assertions"],
                                                                           decisions_and_cuts.align_with_not_declined)
            assertions_metrics = compute_assertions_for_binary_classification(decisions_candidates,
                                                                              accepted_assertions,
                                                                              target_map)
        return assertions_metrics, overrides_metrics

    prediction_result = model.compute_predictions(valid_X, transformed["UNPROCESSED"])
    if has_assertions:
        assertions_metrics = compute_assertions_for_multiclass_classification(
            prediction_result.unmapped_preds_not_declined, transformed["assertions"], target_map)
    if model.has_overrides():
        assert isinstance(prediction_result, OverridesResultsMixin)
        overrides_metrics = prediction_result.compute_and_return_overrides_metrics()
    return assertions_metrics, overrides_metrics


class BinaryClassificationModelScorer(ClassicalPredictionModelScorer):

    def __init__(self, modeling_params, out_folder_context, decisions_and_cuts, test_y, target_map, test_unprocessed=None, test_X=None, test_df_index=None,
                 test_sample_weight=None, assertions=None):
        """
        Set up the scoring of a binary classification model from the predictions computed for every cut.

        :param dict modeling_params: modeling choices of the current ML task (see PredictionModelingParams.java in backend)
        :param dataiku.base.folder_context.FolderContext | None out_folder_context: directory where predicted data and perf.json will be written
        :param DecisionsAndCuts decisions_and_cuts: preds and probas for all cuts
        :param pandas.Series test_y: 1-dimensional array representing the ground truth target on the test set
        :param dict target_map: map of named class (label) to class id in range(len(target_map))
        :param pandas.DataFrame | None test_unprocessed: The "UNPROCESSED" value returned from processing the test dataset via pipeline.process().
        Required for the custom metric x_valid parameter.
        :param dataiku.doctor.multiframe.MultiFrame | None test_X: The "TRAIN" value returned from processing the test dataset via pipeline.process().
        If None, no data will be written on disk (e.g. training recipes).
        :param test_df_index: Pandas index of the input dataframe of the original test set, prior to any processing
        :param Series test_sample_weight: 1-dimensional array representing sample weights on the test set
        :param MLAssertions assertions: collection of assertions based on ML performance metrics
        """
        super(BinaryClassificationModelScorer, self).__init__(
            modeling_params, out_folder_context, decisions_and_cuts.align_with_not_declined, test_y,
            test_unprocessed, test_X, test_df_index, test_sample_weight, assertions)

        # Index of test_X as received, i.e. before aligning with the accepted (not declined) predictions.
        if test_X is None:
            self.raw_test_X_index = None
        else:
            self.raw_test_X_index = test_X.index

        self.target_map = target_map
        # Reverse mapping: integer class id -> class label
        id_to_label = {}
        for label, class_id in self.target_map.items():
            id_to_label[int(class_id)] = label
        self.inv_map = id_to_label
        # Class labels ordered by increasing class id
        self.classes = [self.inv_map[class_id] for class_id in sorted(self.inv_map)]
        self.decisions_and_cuts = decisions_and_cuts

        self.test_predictions = None
        self.use_probas = decisions_and_cuts.has_probas()
        # Companion scorer on the raw prediction results; None unless every prediction result is an
        # OverriddenClassificationPredictionResult (see _instantiate_scorer_without_overrides).
        self.scorer_without_overrides = self._instantiate_scorer_without_overrides(
            modeling_params, out_folder_context, decisions_and_cuts, test_y, target_map, test_unprocessed, test_X,
            test_df_index, test_sample_weight, assertions)

    @staticmethod
    def _instantiate_scorer_without_overrides(modeling_params, out_folder_context, decisions_and_cuts, test_y,
                                              target_map, test_unprocessed, test_X, test_df_index, test_sample_weight,
                                              assertions):
        """
        Build a scorer fed with the `raw_prediction_result` of each prediction result of `decisions_and_cuts`.

        All parameters are forwarded as-is to the BinaryClassificationModelScorer constructor, except
        `decisions_and_cuts` which is replaced by a DecisionsAndCuts built from the same cuts and the raw
        prediction results.

        :return: the scorer, or None as soon as one prediction result is not an
                 OverriddenClassificationPredictionResult
        """
        overridden_results = decisions_and_cuts.get_prediction_results()
        for result in overridden_results:
            if not isinstance(result, OverriddenClassificationPredictionResult):
                return None

        # A raw_prediction_result is never an OverriddenClassificationPredictionResult, so the scorer built
        # below stops at the check above and scorers are not instantiated in an endless loop.
        cuts = decisions_and_cuts.get_cuts()
        raw_results = [result.raw_prediction_result for result in overridden_results]
        return BinaryClassificationModelScorer(
            modeling_params, out_folder_context, DecisionsAndCuts(cuts, raw_results), test_y, target_map,
            test_unprocessed, test_X, test_df_index, test_sample_weight, assertions)

    def can_score(self):
        """
        Check that the ground truth holds at least two distinct values.

        :return: (True, None) when scoring is possible, (False, doctor_constants.PREPROC_ONECLASS) otherwise
        """
        distinct_targets = np.unique(self.test_y)
        if len(distinct_targets) >= 2:
            return True, None
        logger.error("Both classes must be present")
        return False, doctor_constants.PREPROC_ONECLASS

    def save(self, dump_predicted=True):
        """
        Delegate to PredictionModelScorer.save, then dump classification statistics on the predicted classes.

        When probas are available and a threshold has been recorded in `self.ret["usedThreshold"]`, the probas
        of the nearest cut and the predicted classes of every cut are passed along as well.

        :param bool dump_predicted: forwarded to PredictionModelScorer.save
        """
        PredictionModelScorer.save(self, dump_predicted)

        def to_label(class_id):
            return self.classes[class_id]

        # Dump the preds
        predicted_class_series = pd.Series(self.test_predictions).map(to_label)

        test_probas = None
        predicted_class_series_per_cut = None
        cuts = None
        if self.decisions_and_cuts.has_probas() and "usedThreshold" in self.ret:
            used_result = self.decisions_and_cuts.get_prediction_result_for_nearest_cut(self.ret["usedThreshold"])
            test_probas = used_result.probas
            predicted_class_series_per_cut = []
            cuts = []
            for prediction_result, cut in self.decisions_and_cuts:
                per_cut_series = pd.Series(prediction_result.unmapped_preds_not_declined).map(to_label)
                predicted_class_series_per_cut.append(per_cut_series)
                cuts.append(cut)

        save_classification_statistics(predicted_class_series,
                                       self.out_folder_context,
                                       test_probas,
                                       self.test_sample_weight,
                                       self.target_map,
                                       predicted_class_series_per_cut=predicted_class_series_per_cut,
                                       cuts=cuts)

    @staticmethod
    def _get_detailed_prediction_results(decisions_and_cuts, threshold):
        """
        Fetch the prediction result of the cut nearest to `threshold` and unpack what scoring needs from it.

        :type decisions_and_cuts: DecisionsAndCuts
        :type threshold: float
        :rtype: (np.ndarray, np.ndarray | None, np.ndarray | None, ClassificationPredictionResult)
        :return: (unmapped not-declined predictions, not-declined probas, raw probas, prediction result).
                 Both probas entries are None when the prediction result has no probas. The raw probas are the
                 probas of `raw_prediction_result` for an OverriddenClassificationPredictionResult, and the
                 not-declined probas otherwise.
        """
        nearest_result = decisions_and_cuts.get_prediction_result_for_nearest_cut(threshold)
        preds = nearest_result.unmapped_preds_not_declined

        if nearest_result.has_probas():
            probas = nearest_result.probas_not_declined
            if isinstance(nearest_result, OverriddenClassificationPredictionResult):
                return preds, probas, nearest_result.raw_prediction_result.probas, nearest_result
            return preds, probas, probas, nearest_result

        return preds, None, None, nearest_result

    def _do_score(self, with_assertions, treat_metrics_failure_as_error=True):
        """
        Compute the performance data of the model on the test set and store it in `self.ret`.

        The following is filled, in this order:
          - "usedThreshold" (and "optimalThreshold" when the threshold is auto-optimized)
          - "perCutData" (see `compute_per_cut_data`)
          - when probas are available: "probaPercentiles", "tiMetrics" (threshold-independent metrics),
            "rocVizData", "liftVizData", "prVizData", "densityData", "calibrationData", "probaDistribData"
          - custom threshold-independent metrics results, under "tiMetrics"
          - "globalMetrics"
        `self.predicted_df` is also built when `self.test_X_index` is not None.

        :param bool with_assertions: forwarded to `compute_per_cut_data`
        :param bool treat_metrics_failure_as_error: forwarded to `handle_failure` for each threshold-independent metric
        :return: the performance data, also available as `self.ret` and `self.perf_data`
        """
        custom_metrics_threshold_dependent, custom_metrics_threshold_independent = get_custom_metric_functions_binary_classif(
            self.modeling_params['metrics'], self.test_unprocessed)
        if self.use_probas:
            if self.modeling_params["autoOptimizeThreshold"]:
                best_cut = compute_optimized_threshold(self.test_y, self.decisions_and_cuts, self.modeling_params["metrics"],
                                                       self.test_sample_weight)
                self.ret["optimalThreshold"] = best_cut
                self.ret["usedThreshold"] = best_cut
            else:
                self.ret["usedThreshold"] = self.modeling_params["forcedClassifierThreshold"]
        else:
            # the threshold doesn't matter, really, but you'll see it in the UI
            self.ret["usedThreshold"] = self.modeling_params.get("forcedClassifierThreshold", 0.5)

        # Set predictions as the ones from the used threshold.
        # Note, we retrieve the probas of given threshold, but it would be the same value for all
        # thresholds, as it is threshold independent.
        test_predictions, test_probas, raw_probas, prediction_result = (
            self._get_detailed_prediction_results(self.decisions_and_cuts, self.ret["usedThreshold"]))

        pcd = self.compute_per_cut_data(self.decisions_and_cuts,
                                        with_assertions,
                                        custom_metrics_threshold_dependent)

        self.ret["perCutData"] = pcd

        if self.use_probas:
            self.ret["tiMetrics"] = {}
            # np.sort shouldn't be necessary but works around a microbug leading to non-monotonous percentiles.
            # See https://github.com/numpy/numpy/issues/10373
            # Percentiles could include [..., a, b, a, ...] with b < a at the 15 or 16th decimal place,
            # which could lead to different probaPercentile results at prediction time.
            self.ret["probaPercentiles"] = np.sort(np.quantile(test_probas[:, 1], [float(x + 1) / 100 for x in range(99)]))

        # Predictions mapped back to their class labels; only used for the logs below
        mapped_preds = np.zeros(test_predictions.shape, object)

        logger.info("preds %s", test_predictions)
        logger.info("MAPPED SHAPE %s", mapped_preds.shape)

        for k, v in self.target_map.items():
            v = int(v)
            logger.info("k=%s v=%s", k, v)
            mask = test_predictions == v
            logger.info("Mask data %s", mask)
            logger.info("mapped pred %s", mapped_preds.__class__)
            mapped_preds[mask] = k

        logger.info("MAPPED PREDS %s", mapped_preds)

        if custom_metrics_threshold_independent:
            if "tiMetrics" not in self.ret:
                # calc of all metrics here will fail, as probabilities are not available
                # the tiMetrics dict itself is only initialised above if probabilities are available
                # in this case, we init tiMetrics, so that it is obvious in the frontend that the ti specific metrics have failed
                # this also saves us from adjusting the frontend metrics display code for this specific case
                self.ret["tiMetrics"] = {}

            custom_metrics_results_ti = []

            for custom_metric_function_container in custom_metrics_threshold_independent:
                metric = custom_metric_function_container["metric"]
                has_failure = custom_metric_function_container['hasFailure']

                custom_metric_result = {
                    'metric': metric
                }

                if has_failure:
                    custom_metric_result['didSucceed'] = False
                    custom_metric_result['error'] = custom_metric_function_container['error']
                elif not self.use_probas:
                    logger.warning("Cannot compute custom metric with probabilities on model without probabilities")
                    custom_metric_result['didSucceed'] = False
                    custom_metric_result['error'] = "Cannot compute custom metric with probabilities on model without probabilities"
                else:
                    custom_metric_function = custom_metric_function_container["function"]
                    custom_metric_result = execute_parsed_custom_metric_function(custom_metric_function=custom_metric_function,
                                                                                 custom_metric_result=custom_metric_result,
                                                                                 valid_y=self.test_y,
                                                                                 preds_or_probas=test_probas,
                                                                                 sample_weight=self.test_sample_weight)
                custom_metrics_results_ti.append(custom_metric_result)

            self.ret["tiMetrics"]['customMetricsResults'] = custom_metrics_results_ti
            # "customScore" is deprecated in favour of custom metrics. We set it for backwards compatibility reasons
            if self.modeling_params["metrics"]["evaluationMetric"] == "CUSTOM" and self.use_probas and get_custom_evaluation_metric(self.modeling_params["metrics"])["needsProbability"]:
                self.ret["tiMetrics"]["customScore"] = get_custom_score_from_custom_metrics_results(
                    self.ret["tiMetrics"]["customMetricsResults"],
                    self.modeling_params["metrics"]["customEvaluationMetricName"]
                )

        if self.use_probas:
            # Threshold-independent metrics
            self.ret["tiMetrics"]["auc"] = handle_failure(lambda: mroc_auc_score(self.test_y, test_probas, sample_weight=self.test_sample_weight),
                                                           treat_metrics_failure_as_error)
            self.ret["tiMetrics"]["logLoss"] = handle_failure(lambda: log_loss(self.test_y, test_probas, sample_weight=self.test_sample_weight),
                                                               treat_metrics_failure_as_error)
            self.ret["tiMetrics"]["lift"] = handle_failure(lambda: make_lift_score(self.modeling_params["metrics"])(self.test_y, test_probas, sample_weight=self.test_sample_weight),
                                                            treat_metrics_failure_as_error)
            self.ret["tiMetrics"]["averagePrecision"] = handle_failure(lambda: m_average_precision_score(self.test_y, test_probas, sample_weight=self.test_sample_weight),
                                                                        treat_metrics_failure_as_error)

            # ROC and Lift for proba-aware classifiers
            false_positive_rates, true_positive_rates, thresholds = roc_curve(self.test_y, test_probas[:, 1],
                                                                              sample_weight=self.test_sample_weight)
            # full roc curve data
            roc_data = zip(false_positive_rates, true_positive_rates, thresholds)
            # trim the data as we don't need all points for visualization
            # in a single-element array for k-fold compatibility
            self.ret["rocVizData"] = [[{"x": x, "y": y, "p": p} for (x, y, p) in trim_curve(roc_data)]]

            predicted = pd.Series(data=test_probas[:, 1], name='predicted')
            with_weight = self.test_sample_weight is not None
            if with_weight:
                results = pd.DataFrame({"__target__": self.test_y, "sample_weight": self.test_sample_weight}).reset_index(drop=True).join(predicted)
            else:
                results = pd.DataFrame({"__target__": self.test_y}).reset_index(drop=True).join(predicted)

            lb = LiftBuilder(results, '__target__', 'predicted', with_weight)
            # The lift and precision-recall curves are best effort: a failure is logged and the key is left out.
            # Only `Exception` is caught so that KeyboardInterrupt / SystemExit still propagate.
            try:
                self.ret["liftVizData"] = lb.build()
            except Exception:
                logger.exception("Cannot compute Lift curve")

            try:
                self.ret["prVizData"] = [self.compute_pr_curve_data(self.test_y, test_probas[:, 1], sample_weight=self.test_sample_weight)]
            except Exception:
                logger.exception("Cannot compute precision recall curve")

            # Probability density per actual class
            self.ret["densityData"] = format_all_conditional_proba_densities(self.classes, self.target_map, test_probas, self.test_y,
                                                                             self.test_sample_weight)

            freqs, avg_preds, weights = dku_calibration_curve(self.test_y.values, test_probas[:, 1],
                                                              sample_weight=self.test_sample_weight, n_bins=10)
            # Drop the bins for which any of the three values is NaN
            zipped = [(t, p, n) for (t, p, n) in zip(freqs, avg_preds, weights) if not np.isnan(t + p + n)]
            self.ret["calibrationData"] = [{"y": 0, "x": 0, "n": 0}] + [{"y": t, "x": p, "n": n} for (t, p, n) in zipped] + [{"y": 1, "x": 1, "n": 0}]
            self.ret["tiMetrics"]["calibrationLoss"] = dku_nonan(handle_failure(lambda: dku_calibration_loss([x[0] for x in zipped], [x[1] for x in zipped], [x[2] for x in zipped]),
                                                                                 treat_metrics_failure_as_error))

            # Proba distribution (only for subpopulation feature for the time being)
            self.ret["probaDistribData"] = compute_proba_distribution(test_probas, self.test_y, self.test_sample_weight)

        # Compute the predicted set
        if self.test_X_index is not None:
            if self.use_probas:
                # In probabilistic binary classification, prediction depends on the threshold, so needs to be recomputed
                # later on. Besides, probas might be impacted by the overrides. To make things simpler, we save here
                # the raw probas (before any override) and recompute everything (prediction + overrides if needed) when
                # actually needing it, depending on the threshold.
                # So, in this case, the `proba_df` is made from the raw probas,
                # and therefore we need to use the raw_test_X_index to realign.
                proba_df = pd.DataFrame(raw_probas, columns=["proba_%s" % x for x in self.classes])
                # Realign
                proba_df.index = self.raw_test_X_index
                full = pd.DataFrame(index=self.test_df_index)
                proba_df = full.join(proba_df, how="left")
                # We don't add the overrides info column here since for Probabilistic Binary Classification,
                # it is done on the Java side at the `PredictedDataService.java` where overrides are computed
                self.predicted_df = proba_df
            else:
                preds_remapped = np.zeros(test_predictions.shape, dtype="object")
                for (mapped_value, original_value) in self.inv_map.items():
                    idx = (test_predictions == mapped_value)
                    preds_remapped[idx] = original_value
                pred_df = pd.DataFrame({"prediction": preds_remapped})
                # Realign
                pred_df.index = self.test_X_index
                full = pd.DataFrame(index=self.test_df_index)
                pred_df = full.join(pred_df, how="left")
                if isinstance(prediction_result, OverridesResultsMixin):
                    pred_df[OVERRIDE_INFO_COL] = prediction_result.compute_and_return_info_column()

                self.predicted_df = pred_df

        # Global metrics
        global_metrics = {"testSize": self.test_y.shape[0]}
        if self.test_sample_weight is not None:
            # Weighted mean, and weighted std as sqrt(E[x^2] - E[x]^2) floored at 0 to absorb rounding errors
            test_weight = self.test_sample_weight.sum()
            target_avg = np.dot(self.test_y.values, self.test_sample_weight) / test_weight
            pred_avg = np.dot(test_predictions, self.test_sample_weight) / test_weight
            global_metrics["testWeight"] = test_weight
            global_metrics["targetAvg"] = [target_avg]
            global_metrics["targetStd"] = [np.sqrt(max(np.dot(self.test_y.values ** 2, self.test_sample_weight) / test_weight - target_avg ** 2, 0.))]
            global_metrics["predictionAvg"] = [pred_avg]
            global_metrics["predictionStd"] = [np.sqrt(max(np.dot(test_predictions ** 2, self.test_sample_weight) / test_weight - pred_avg ** 2, 0.))]

        else:
            global_metrics["testWeight"] = self.test_y.shape[0]
            global_metrics["targetAvg"] = [self.test_y.mean()]
            global_metrics["targetStd"] = [self.test_y.std() if self.test_y.shape[0] > 1 else 0]
            global_metrics["predictionAvg"] = [test_predictions.mean()]
            global_metrics["predictionStd"] = [test_predictions.std() if test_predictions.shape[0] > 1 else 0]
        self.ret["globalMetrics"] = global_metrics

        self.ret = remove_all_nan(self.ret)
        self.perf_data = self.ret

        self.test_predictions = test_predictions
        self.test_prediction_result = prediction_result
        return self.ret

    @staticmethod
    def compute_pr_curve_data(y_true, test_probas, pos_label=None, sample_weight=None):
        """
        Build the precision-recall curve data: the trimmed curve points and the rate of positive rows.

        :param y_true: ground truth, forwarded to `precision_recall_curve`
        :param test_probas: scores, forwarded to `precision_recall_curve`
        :param pos_label: label of the positive class; when None, the value 1 is used to count positive rows
        :param sample_weight: forwarded to `precision_recall_curve`; not used for the positive rate
        :return: dict with "bins", a list of {"x": recall, "y": precision, "p": threshold}, and "positiveRate"
        """
        precisions, recalls, thresholds = precision_recall_curve(y_true, test_probas, pos_label=pos_label,
                                                                 sample_weight=sample_weight)
        bins = []
        for recall, precision, threshold in trim_curve(zip(recalls, precisions, thresholds)):
            bins.append({"x": recall, "y": precision, "p": threshold})

        if pos_label is None:
            positive_value = 1
        else:
            positive_value = pos_label
        positives_count = len(y_true[y_true == positive_value])

        return {
            "bins": bins,
            "positiveRate": float(positives_count) / float(len(y_true))
        }

    def compute_per_cut_data(self, decisions_and_cuts, with_assertions, custom_metric_list):
        """
        Compute, for every cut, the confusion-matrix derived metrics plus the optional overrides, assertions and
        custom metrics.

        :param DecisionsAndCuts decisions_and_cuts: iterated as (prediction result, cut) pairs
        :param bool with_assertions: assertions metrics are only computed when this is set and `self.assertions` is truthy
        :param custom_metric_list: threshold-dependent custom metric containers; each one is passed to
                                   `compute_custom_metrics_for_cut` for every cut, then read back
                                   ("hasFailure", "metric", "error" / "values") to fill "customMetricsResults"
        :return: dict of lists indexed like "cut", plus "customMetricsResults" / "customScore" when applicable
        """
        pcd = {"cut": [], "tp": [], "tn": [], "fp": [], "fn": [],
               "precision": [], "recall": [], "accuracy": [], "f1": [], "mcc": [], "hammingLoss": [],
               "assertionsMetrics": [], "overridesMetrics": []}
        # (key in pcd, key in the output of confusion_matrix_derived_metrics)
        derived_keys = (("tp", "tp"), ("tn", "tn"), ("fp", "fp"), ("fn", "fn"),
                        ("precision", "precision"), ("recall", "recall"), ("f1", "f1"),
                        ("accuracy", "accuracy"), ("mcc", "mcc"), ("hammingLoss", "hamming"))

        for prediction_result, cut in decisions_and_cuts:
            decision = prediction_result.unmapped_preds_not_declined
            pcd["cut"].append(cut)
            derived_metrics = confusion_matrix_derived_metrics(self.test_y, decision, self.test_sample_weight)
            for pcd_key, derived_key in derived_keys:
                pcd[pcd_key].append(derived_metrics[derived_key])

            if custom_metric_list:
                # The custom metric code receives a copy of the decisions
                compute_custom_metrics_for_cut(custom_metric_list, self.test_y, decision.copy(), self.test_sample_weight, cut)

            if isinstance(prediction_result, OverridesResultsMixin):
                overrides_metrics = prediction_result.compute_and_return_overrides_metrics()
                pcd["overridesMetrics"].append(overrides_metrics.to_dict())

        if with_assertions and self.assertions:
            assertions_results = compute_assertions_for_binary_classification(
                decisions_and_cuts.get_unmapped_preds_not_declined_list(),
                self.assertions,
                self.target_map)
            pcd["assertionsMetrics"] = [result.to_dict() for result in assertions_results]

        if custom_metric_list:
            custom_results = []
            for metric_container in custom_metric_list:
                failed = metric_container['hasFailure']
                custom_result = {
                    "metric": metric_container["metric"],
                    "didSucceed": not failed
                }
                if failed:
                    custom_result["error"] = metric_container["error"]
                else:
                    custom_result["values"] = metric_container["values"]
                custom_results.append(custom_result)
            pcd["customMetricsResults"] = custom_results

        metrics_params = self.modeling_params["metrics"]
        if metrics_params["evaluationMetric"] == "CUSTOM" and not get_custom_evaluation_metric(self.modeling_params["metrics"])["needsProbability"]:
            pcd["customScore"] = get_custom_score_from_custom_metrics_results(
                pcd["customMetricsResults"],
                self.modeling_params["metrics"]["customEvaluationMetricName"],
                True
            )
        return pcd


class CVBinaryClassificationModelScorer(BaseCVModelScorer):
    def __init__(self, scorers):
        """
        :param scorers: scorers to aggregate (presumably one per fold); `modeling_params` and `use_probas` are
                        taken from the first one
        """
        super(CVBinaryClassificationModelScorer, self).__init__(scorers)
        reference_scorer = self.scorers[0]
        self.modeling_params = reference_scorer.modeling_params
        first_given_scorer = scorers[0]
        self.use_probas = first_given_scorer.use_probas

    def cost_matrix_averaged_scores(self, cuts, cost_matrix_weights):
        """
        Compute, for each cut, the cost matrix score averaged over `self.perfdatas` (NaN scores are ignored).

        :param cuts: the cuts; only their positions are used, to index the per-cut tp / tn / fp / fn lists
        :param cost_matrix_weights: forwarded to `compute_cost_matrix_score`
        :return: list of averaged scores, one per cut
        """
        averaged_scores = []
        for cut_index, _ in enumerate(cuts):
            fold_scores = [
                compute_cost_matrix_score(fold_pcd["tp"][cut_index], fold_pcd["tn"][cut_index],
                                          fold_pcd["fp"][cut_index], fold_pcd["fn"][cut_index],
                                          cost_matrix_weights)
                for fold_pcd in (perfdata["perCutData"] for perfdata in self.perfdatas)
            ]
            averaged_scores.append(np.nanmean(np.array(fold_scores)))
        return averaged_scores

    def optimal_threshold_on_averaged_metrics(self, average_metrics=None):
        """
        Search the optimal threshold for the configured threshold optimization metric, on averaged scores.

        :param average_metrics: mapping holding the averaged per-cut "f1" and "accuracy" scores; it is read for
                                the F1 and ACCURACY metrics only (COST_MATRIX scores are recomputed from
                                `self.perfdatas`)
        :return: the result of `search_optimized_threshold` on the scores and the cuts of the first perfdata
        :raises KeyError: when the threshold optimization metric is none of F1, ACCURACY, COST_MATRIX
        """
        metrics = self.modeling_params["metrics"]
        _, greater_is_better = get_threshold_optim_function(metrics)

        optimization_metric = metrics["thresholdOptimizationMetric"]
        cuts = self.perfdatas[0]["perCutData"]["cut"]

        # Metrics whose averaged scores are directly available in `average_metrics`
        averaged_score_keys = {"F1": "f1", "ACCURACY": "accuracy"}
        if optimization_metric in averaged_score_keys:
            scores = average_metrics.get(averaged_score_keys[optimization_metric])
        elif optimization_metric == "COST_MATRIX":
            scores = self.cost_matrix_averaged_scores(cuts, metrics["costMatrixWeights"])
        else:
            raise KeyError("The metric {} doesn't exists".format(optimization_metric))

        return search_optimized_threshold(scores, cuts, greater_is_better)


    def score(self):
        """
        Aggregate the performance data of all folds (`self.perfdatas`) into `self.ret`.

        Nothing is added when the first perfdata has no "perCutData". Otherwise:
          - per-cut f1 / precision / accuracy / recall / mcc / hammingLoss are averaged over folds (NaN ignored),
            with their standard deviation stored under the "<key>std" key
          - per-cut fp / tp / fn / tn, "densityData", "liftVizData", "probaPercentiles" and "calibrationData"
            are taken from the first fold
          - the optimal threshold is searched on the averaged metrics and is also set as the used threshold
          - with probas: ROC and precision-recall curves of every fold are collected, lift data gets the
            per-fold curves, and threshold-independent metrics are averaged over folds

        The precision-recall and lift curves are computed on a best-effort basis for each fold, so they are only
        aggregated when every fold has them.

        :return: `self.ret`
        """
        super(CVBinaryClassificationModelScorer, self).score()

        self.r1 = self.perfdatas[0]
        if "perCutData" not in self.r1:
            return self.ret

        out = {"cut": self.r1["perCutData"]["cut"]}

        # Mean and standard deviation over folds, cut by cut
        for key in ["f1", "precision", "accuracy", "recall", "mcc", "hammingLoss"]:
            out[key] = []
            out[key + "std"] = []
            tozip = [x["perCutData"][key] for x in self.perfdatas]
            logger.info("  for key: %s tozip=%s", key, tozip)
            for vals in zip(*tozip):
                logger.info("  for key: %s Vals=%s", key, vals)
                data = np.array(vals)
                out[key].append(np.nanmean(data))
                out[key + "std"].append(np.nanstd(data))

        if "customMetricsResults" in self.r1["perCutData"]:
            out["customMetricsResults"] = build_cv_per_cut_custom_metrics(self.r1, self.perfdatas)

        # Confusion matrix counts are not aggregated: the ones of the first fold are kept
        for key in ["fp", "tp", "fn", "tn"]:
            out[key] = self.r1["perCutData"][key]
        self.ret["perCutData"] = out

        for key in ["densityData", "liftVizData", "probaPercentiles"]:
            if key in self.r1:
                self.ret[key] = self.r1[key]

        optimal_threshold = self.optimal_threshold_on_averaged_metrics(out)
        logger.info("Optimal Threshold for metrics averaged over all folds: %s", optimal_threshold)

        self.ret["optimalThreshold"] = optimal_threshold
        self.ret["usedThreshold"] = optimal_threshold

        if not self.use_probas:
            return self.ret

        self.ret["rocVizData"] = [x["rocVizData"][0] for x in self.perfdatas]
        # "prVizData" is missing from a fold whose precision-recall curve could not be computed
        if all("prVizData" in x for x in self.perfdatas):
            self.ret["prVizData"] = [x["prVizData"][0] for x in self.perfdatas]
        else:
            logger.warning("Precision recall curve is missing for at least one fold, not aggregating it")
        # TODO: average ? => fill holes...
        self.ret["calibrationData"] = self.perfdatas[0]["calibrationData"]

        if all("liftVizData" in x for x in self.perfdatas):
            self.ret["liftVizData"]["folds"] = [[{"cum_size": y["cum_size"], "cum_lift": y["cum_lift"]}
                                                 for y in x["liftVizData"]["bins"]] for x in self.perfdatas]
            # cheat by making the steepest possible wizard
            self.ret["liftVizData"]["wizard"] = {
                "positives": min([x["liftVizData"]["wizard"]["positives"] for x in self.perfdatas]),
                "total": max([x["liftVizData"]["wizard"]["total"] for x in self.perfdatas])}

        self.ret["tiMetrics"] = {}
        for metric in self.r1["tiMetrics"].keys():
            if metric == "customMetricsResults":
                custom_metric_data_per_fold = [x["tiMetrics"]["customMetricsResults"] for x in self.perfdatas]
                self.ret["tiMetrics"]["customMetricsResults"] = aggregate_custom_metrics_for_cross_val_model(custom_metric_data_per_fold)
            else:
                data = np.array([x["tiMetrics"][metric] for x in self.perfdatas])
                self.ret["tiMetrics"][metric] = data.mean()
                self.ret["tiMetrics"][metric + "std"] = data.std()
        return self.ret


def multiclass_scorer_with_valid(modeling_params, model, valid, out_folder_context, test_df_index, target_map=None, with_sample_weight=False):
    """
    Run `model` on a processed validation set and wrap the predictions into a MulticlassModelScorer.

    :param dict modeling_params: forwarded to `prepare_multiframe` and to the scorer
    :param model: object exposing `compute_predictions(prepared_X, unprocessed)`
    :param valid: processed validation data; the "target", "TRAIN" and "UNPROCESSED" entries are read, as well as
                  "weight" when `with_sample_weight` is set and "assertions" when present
    :param out_folder_context: forwarded to the scorer
    :param test_df_index: forwarded to the scorer
    :param target_map: forwarded to the scorer
    :param bool with_sample_weight: whether to pass `valid["weight"]` as sample weights to the scorer
    :return: the MulticlassModelScorer; `assert_not_all_declined()` is called on the prediction result first
    """
    valid_y = valid["target"].astype(int)
    valid_w = valid["weight"] if with_sample_weight else None
    assertions = valid.get("assertions")

    train_frame = valid["TRAIN"]
    prepared_frame, _ = prepare_multiframe(train_frame, modeling_params)
    prediction_result = model.compute_predictions(prepared_frame, valid["UNPROCESSED"])
    prediction_result.assert_not_all_declined()

    scorer_options = dict(test_unprocessed=valid["UNPROCESSED"],
                          test_X=train_frame,
                          test_df_index=test_df_index,
                          test_sample_weight=valid_w,
                          assertions=assertions)
    return MulticlassModelScorer(modeling_params, out_folder_context, prediction_result, valid_y, target_map,
                                 **scorer_options)


class MulticlassModelScorer(ClassicalPredictionModelScorer):
    def __init__(self, modeling_params, out_folder_context, test_prediction_result, test_y, target_map=None,
                 test_unprocessed=None, test_X=None,
                 test_df_index=None, test_sample_weight=None, assertions=None):
        """
        :param dict modeling_params: modeling choices of the current ML task (see PredictionModelingParams.java in backend)
        :param dataiku.base.folder_context.FolderContext out_folder_context: directory where predicted data and perf.json will be written
        :param ClassificationPredictionResult test_prediction_result: output of the model
        :param Series test_y: 1-dimensional array representing the ground truth target on the test set
        :param dict target_map: map of named class (label) to class id in range(len(target_map))
        :param pandas.DataFrame | None test_unprocessed: The "UNPROCESSED" value returned from processing the test dataset via pipeline.process().
        Required for the custom metric x_valid parameter.
        :param dataiku.doctor.multiframe.MultiFrame | None test_X: The "TRAIN" value returned from processing the test dataset via pipeline.process().
        If None, no data will be written on disk (e.g. training recipes).
        :param test_df_index: Pandas index of the input dataframe of the original test set, prior to any processing
        :param Series test_sample_weight: 1-dimensional array representing sample weights on the test set
        :param MLAssertions assertions: collection of assertions based on ML performance metrics
        """
        super(MulticlassModelScorer, self).__init__(
            modeling_params, out_folder_context, test_prediction_result.align_with_not_declined, test_y,
            test_unprocessed, test_X, test_df_index, test_sample_weight, assertions)

        self.target_map = target_map
        self.test_prediction_result = test_prediction_result
        self.test_predictions = test_prediction_result.unmapped_preds_not_declined
        self.use_probas = test_prediction_result.has_probas()
        self.test_probas = test_prediction_result.probas_not_declined if self.use_probas else None

        # class id -> class label, and the labels ordered by increasing class id
        self.inv_map = dict((int(class_id), label) for label, class_id in self.target_map.items())
        self.classes = [self.inv_map[class_id] for class_id in sorted(self.inv_map)]

        if self.use_probas:
            # One probability column per class, in class id order
            self.columns = ["proba_%s" % class_label for class_label in self.classes]

        # Twin scorer built on the raw (not overridden) predictions; None when there are no overrides
        self.scorer_without_overrides = self._instantiate_scorer_without_overrides(
            modeling_params, out_folder_context, test_prediction_result, test_y, target_map, test_unprocessed, test_X,
            test_df_index, test_sample_weight, assertions)

    @staticmethod
    def _instantiate_scorer_without_overrides(modeling_params, out_folder_context, test_prediction_result, test_y,
                                              target_map, test_unprocessed, test_X, test_df_index,  test_sample_weight,
                                              assertions):
        """
        Build a scorer on the raw predictions wrapped by an overridden prediction result.

        :return: a MulticlassModelScorer built on test_prediction_result.raw_prediction_result when
                 test_prediction_result is an OverriddenClassificationPredictionResult, None otherwise
        """
        if isinstance(test_prediction_result, OverriddenClassificationPredictionResult):
            # The raw prediction result is never itself an OverriddenClassificationPredictionResult, so building
            # the nested scorer cannot recurse further.
            return MulticlassModelScorer(
                modeling_params, out_folder_context, test_prediction_result.raw_prediction_result, test_y,
                target_map, test_unprocessed, test_X, test_df_index, test_sample_weight, assertions)
        return None

    def save(self, dump_predicted=True):
        """
        Save the scorer output through PredictionModelScorer.save, then write the classification statistics.

        :param bool dump_predicted: forwarded to PredictionModelScorer.save
        """
        PredictionModelScorer.save(self, dump_predicted)

        # Translate the predicted class ids into their labels before computing the statistics
        def to_label(class_id):
            return self.classes[class_id]

        predicted_labels = pd.Series(self.test_predictions).map(to_label)
        save_classification_statistics(predicted_labels, self.out_folder_context, self.test_probas,
                                       self.test_sample_weight, self.target_map)

    def can_score(self):
        """
        Tell whether the test target holds enough distinct classes to be scored.

        :return: (True, None) when at least two distinct classes are present in the test target,
                 (False, doctor_constants.PREPROC_ONECLASS) otherwise
        :rtype: tuple
        """
        distinct_classes = np.unique(self.test_y)
        if len(distinct_classes) >= 2:
            return True, None
        logger.error("At least two classes must be present")
        return False, doctor_constants.PREPROC_ONECLASS

    def _do_score(self, with_assertions, treat_metrics_failure_as_error=True):
        """
        Compute the performance data of the multiclass model on the test set and build the predicted dataframe.

        Fills self.ret with the class list, the confusion matrix, the aggregated metrics, the one-vs-all curves and
        density data (only when probas are available) and the global metrics. The result is passed through
        remove_all_nan and also stored in self.perf_data.

        :param bool with_assertions: if True and assertions were provided, also compute the assertions metrics
        :param bool treat_metrics_failure_as_error: forwarded to the test set check and to the metrics computation
        :return: the performance data
        :rtype: dict
        """
        logger.info("Will use probas : %s" % self.use_probas)

        check_test_set_ok_for_classification(self.test_y, treat_failure_as_error=treat_metrics_failure_as_error)

        # NOTE: classes predicted by the model but absent from the test set are deliberately not rejected here.

        # Map the predicted class ids back to their labels (ids without a label in the target map stay at 0)
        mapped_preds = np.zeros(self.test_predictions.shape, object)
        for k, v in self.target_map.items():
            mapped_preds[self.test_predictions == v] = k

        # Confusion matrix
        self.ret["classes"] = self.classes
        self.ret["confusion"] = self.get_multiclass_confusion_matrix(self.test_y, self.test_predictions, self.inv_map, self.test_sample_weight)

        # Aggregated metrics + 1-vs-all ROC & Calibration for proba-aware classifiers
        metrics = self.compute_multiclass_metrics(self.test_y, self.test_predictions, self.target_map, self.test_probas,
                                                  self.test_sample_weight, self.test_unprocessed, self.modeling_params["metrics"], treat_metrics_failure_as_error)
        self.ret["metrics"] = metrics["metrics"]

        if self.use_probas:
            self.ret["oneVsAllCalibrationLoss"] = metrics["oneVsAllCalibrationLoss"]
            self.ret["oneVsAllCalibrationCurves"] = metrics["oneVsAllCalibrationCurves"]

            self.ret["oneVsAllRocAUC"], self.ret["oneVsAllRocCurves"] = \
                MulticlassModelScorer.get_roc_metrics_and_curves(self.test_y, self.test_probas, self.target_map, self.test_sample_weight)

            self.ret["oneVsAllAveragePrecision"], self.ret["oneVsAllPrCurves"] = \
                MulticlassModelScorer.get_pr_metrics_and_curves(self.test_y, self.test_probas, self.target_map, self.test_sample_weight)

            self.ret["densityData"] = format_all_conditional_proba_densities(self.classes, self.target_map, self.test_probas, self.test_y,
                                                                             self.test_sample_weight)

        if with_assertions and self.assertions:
            self.ret["metrics"]["assertionsMetrics"] = \
                compute_assertions_for_multiclass_classification(self.test_predictions, self.assertions, self.target_map).to_dict()

        if isinstance(self.test_prediction_result, OverridesResultsMixin):
            self.ret["metrics"]["overridesMetrics"] = self.test_prediction_result.compute_and_return_overrides_metrics().to_dict()

        # Global metrics: one value per class, in the order of self.classes
        class_ids = [int(self.target_map[c]) for c in self.classes]
        global_metrics = {"testSize": self.test_y.shape[0]}
        if self.test_sample_weight is not None:
            sample_weight = self.test_sample_weight
            test_weight = sample_weight.sum()
            y_values = self.test_y.values
            target_avg = [np.dot(y_values == class_id, sample_weight) / test_weight for class_id in class_ids]
            global_metrics["testWeight"] = test_weight
            global_metrics["targetAvg"] = target_avg
            # Weighted variance is E[x^2] - E[x]^2; clamp it at 0 (as for predictionStd below) so that float
            # rounding can never feed a slightly negative value to sqrt and produce a NaN.
            global_metrics["targetStd"] = [
                np.sqrt(max(np.dot((y_values == class_id) ** 2, sample_weight) / test_weight - avg ** 2, 0.))
                for class_id, avg in zip(class_ids, target_avg)
            ]
            if self.use_probas:
                pred_avg = [np.dot(self.test_probas[:, class_id], sample_weight) / test_weight for class_id in class_ids]
                global_metrics["predictionAvg"] = pred_avg
                global_metrics["predictionStd"] = [
                    np.sqrt(max(np.dot(self.test_probas[:, class_id] ** 2, sample_weight) / test_weight - avg ** 2, 0.))
                    for class_id, avg in zip(class_ids, pred_avg)
                ]

        else:
            global_metrics["testWeight"] = self.test_y.shape[0]
            global_metrics["targetAvg"] = [(self.test_y == class_id).mean() for class_id in class_ids]
            global_metrics["targetStd"] = [(self.test_y == class_id).std()
                                           if self.test_y.shape[0] > 1 else 0 for class_id in class_ids]
            if self.use_probas:
                global_metrics["predictionAvg"] = [(self.test_probas[:, class_id]).mean() for class_id in class_ids]
                global_metrics["predictionStd"] = [(self.test_probas[:, class_id]).std()
                                                   if self.test_probas.shape[0] > 1 else 0 for class_id in class_ids]

        self.ret["globalMetrics"] = global_metrics

        # Compute the predicted set
        if self.test_X_index is not None:
            out_df = pd.DataFrame({"prediction": mapped_preds})
            if self.use_probas:
                proba_df = pd.DataFrame(self.test_probas, columns=self.columns)
                out_df = pd.concat([proba_df, out_df], axis=1)
            # Realign on the index of the original test set: rows absent from test_X_index end up empty
            out_df.index = self.test_X_index
            full = pd.DataFrame(index=self.test_df_index)
            self.predicted_df = full.join(out_df, how="left")

            if isinstance(self.test_prediction_result, OverridesResultsMixin):
                self.predicted_df[OVERRIDE_INFO_COL] = self.test_prediction_result.compute_and_return_info_column()
        self.ret = remove_all_nan(self.ret)
        self.perf_data = self.ret
        return self.ret

    @staticmethod
    def get_roc_metrics_and_curves(valid_y, probas, target_map, sample_weight=None):
        """
            Compute "1 versus all" AUC scores and "1 versus all" ROC curves points for all classes of the target map

        A class whose curve cannot be computed is logged and left out of both returned dicts.

        :param pd.Series valid_y: Pandas Series of size nb_rows containing ground truth class indexes
        :param numpy.ndarray probas: Numpy array of shape (nb_rows, nb_classes) containing model's probas for each classes from target_map
        :param dict target_map: {class_label : class_id for all classes in the guess sample}
        :param numpy.ndarray or pd.Series or list sample_weight: array-like of shape (nb_rows,), default=None. used to weight rows in metrics.

        :return: tuple of dicts, one for the AUC scores per class, one for the curves points per class
        """
        one_vs_all_auc = {}
        one_vs_all_roc_curves = {}
        for class_name, class_id in target_map.items():
            try:
                fp_rates, tp_rates, thresholds = roc_curve(valid_y, probas[:, class_id], pos_label=class_id, sample_weight=sample_weight)

                one_vs_all_roc_curves[class_name] = [{"x": x, "y": y, "p": p}
                                                     for (x, y, p) in trim_curve(zip(fp_rates, tp_rates, thresholds))]
                one_vs_all_auc[class_name] = auc(fp_rates, tp_rates)
            except Exception:
                # Best effort per class; `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through
                logger.exception("Failed to compute ROC curve for class '%s'" % class_name)
        return one_vs_all_auc, one_vs_all_roc_curves

    @staticmethod
    def get_pr_metrics_and_curves(valid_y, probas, target_map, sample_weight=None):
        """
            Compute "1 versus all" Average Precision scores and "1 versus all" Precision-Recall curves points for all classes of the target map

        A class whose score or curve cannot be computed is logged and skipped.

        :param pd.Series valid_y: Pandas Series of size nb_rows containing ground truth class indexes
        :param numpy.ndarray probas: Numpy array of shape (nb_rows, nb_classes) containing model's probas for each classes from target_map
        :param dict target_map: {class_label : class_id for all classes in the guess sample}
        :param numpy.ndarray or pd.Series or list sample_weight: array-like of shape (nb_rows,), default=None. used to weight rows in metrics.

        :return: tuple of dicts, one for the average precision scores per class, one for the curves points per class
        """

        one_vs_all_average_precisions = {}
        one_vs_all_pr_curves = {}
        for class_name, class_id in target_map.items():
            try:
                class_probas = probas[:, class_id]
                one_vs_all_average_precisions[class_name] = average_precision_score(valid_y == class_id, class_probas, sample_weight=sample_weight)
                one_vs_all_pr_curves[class_name] = BinaryClassificationModelScorer.compute_pr_curve_data(valid_y, class_probas, class_id, sample_weight)
            except Exception:
                # Best effort per class; `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through
                logger.exception("Failed to compute precision recall curve for class  '%s'" % class_name)
        return one_vs_all_average_precisions, one_vs_all_pr_curves

    @staticmethod
    def get_calibration_metrics_and_curves(valid_y, probas, target_map, sample_weight=None, average="macro"):
        """
            Compute for each class of the target map :
                the calibration losses and curves points of the "1 versus all" binary classifications.
            then aggregate these losses in a mean calibration loss metric used to assess the multiclass performance as
            a whole.

        The per-class computation is best effort: a class whose calibration cannot be computed is logged and left out
        of the returned dicts and of the aggregated loss.

        :param pd.Series valid_y: Pandas Series of size nb_rows containing ground truth class indexes
        :param numpy.ndarray probas: Numpy array of shape (nb_rows, nb_classes) containing model's probas for each classes from target_map
        :param dict target_map: {class_label : class_id for all classes in the guess sample}
        :param numpy.ndarray or pd.Series or list sample_weight: array-like of shape (nb_rows,), default=None. used to weight rows in metrics.
               weights must be strictly positive floats
        :param average: "macro" or "weighted". If set to "macro", the unweighted mean of the losses of all classes of
        the target map is returned. If set to "weighted", the mean over the classes present in valid_y is returned,
        weighted by support (number of instances for each class, weighted by their weight if sample_weight is not None).
        :return: (mean calibration loss over all "1 vs all" losses, one vs all losses (dict), one vs all curves point).
                 With "weighted" averaging the mean is NaN when no loss is available to average.
        """
        one_vs_all_curves = {}
        one_vs_all_losses = {}

        for class_name, class_id in target_map.items():
            try:
                y_bin = (valid_y.values == class_id).astype(int)
                freqs, avg_preds, weights = dku_calibration_curve(y_bin, probas[:, class_id], n_bins=10,
                                                                  sample_weight=sample_weight)

                # remove Nan points & unzip back:
                zipped = [(t, p, n) for (t, p, n) in zip(freqs, avg_preds, weights) if not np.isnan(t + p + n)]
                freqs, avg_preds, weights = [x[0] for x in zipped], [x[1] for x in zipped], [x[2] for x in zipped]

                one_vs_all_losses[class_name] = dku_calibration_loss(freqs, avg_preds, weights)

                # List & format curve points + add beginning & end points
                curve = [{"y": 0, "x": 0, "n": 0}]
                curve.extend([{"y": t, "x": p, "n": n} for (t, p, n) in zipped])
                curve.append({"y": 1, "x": 1, "n": 0})
                one_vs_all_curves[class_name] = curve

            except Exception:
                # Best effort per class; `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through
                logger.exception("Failed to compute calibration curve for class '%s'" % class_name)

        if average != "weighted":
            return np.mean(list(one_vs_all_losses.values())), one_vs_all_losses, one_vs_all_curves

        # Support of each class present in the ground truth: weighted row count when usable weights were given,
        # plain row count otherwise.
        unique_values, counts = np.unique(valid_y, return_counts=True)
        if sample_weight is None or np.sum(sample_weight) == 0:
            supports = counts
        else:
            # Go through arrays so that list weights work too and the selection is positional
            weight_values = np.asarray(sample_weight)
            true_values = np.asarray(valid_y)
            supports = [np.sum(weight_values[true_values == value]) for value in unique_values]

        # Only average the losses that were actually computed: looking up a class whose calibration failed above
        # would otherwise raise a KeyError and defeat the best-effort handling.
        inv_target_map = dict((v, k) for k, v in target_map.items())
        calibration_losses = []
        weights_per_classes = []
        for value, support in zip(unique_values, supports):
            class_name = inv_target_map.get(value)
            if class_name in one_vs_all_losses:
                calibration_losses.append(one_vs_all_losses[class_name])
                weights_per_classes.append(support)

        if not calibration_losses or np.sum(weights_per_classes) == 0:
            # Nothing to average (np.average would raise on an empty or zero-weight input)
            return np.nan, one_vs_all_losses, one_vs_all_curves
        return np.average(calibration_losses, weights=weights_per_classes), one_vs_all_losses, one_vs_all_curves


    @staticmethod
    def compute_multiclass_metrics(valid_y, preds, target_map, probas=None, sample_weight=None, unprocessed=None, metric_params=None, treat_failure_as_error=True):
        """
        Compute multiclass metrics

        Note that at this point valid_y, preds contains only classes indexes that belong to the target_map:
            - Preprocessing should have removed any row containing a class not belonging to the target_map.
            - if any class was missing in model, the classes indexes (& corresponding probas columns) were remapped to
            match the target_map order.

        :param pd.Series valid_y: contains ground truth class indexes (Series of size nb_rows)
        :param numpy.ndarray preds: contains model's prediction for each row (classes indexes only). array of shape (nb_rows,)
        :param dict target_map: {class_label : class_id for all classes in the guess sample}
        :param numpy.ndarray probas: Numpy array of shape (nb_rows, nb_classes) containing model's probas for each classes from target_map
                                     default=None. Note: Some metrics won't be computed without the probas
        :param numpy.ndarray or pd.Series or list sample_weight: array-like of shape (nb_rows,), default=None. used to weight rows in metrics.
               weights must be strictly positive floats
        :param pd.Dataframe unprocessed: realigned input dataframe (not preprocessed but rows excluded by preprocessing were still removed)
        :param dict metric_params: metrics parameters from modeling params used for custom metric computations
        :param boolean treat_failure_as_error: if true, raise error on metrics failures else logs the failure & return None for this metric
        :return: metrics for multiclass cases + 1-vs-all binary calibration curves (computed during the mean-calibration loss computation)
        :rtype: dict
        """
        results = {"metrics": {}}

        if metric_params is not None:
            if "customMetrics" in metric_params:
                results["metrics"]["customMetricsResults"] = calculate_overall_classification_custom_metrics(metric_params,
                                                                                                             preds,
                                                                                                             probas,
                                                                                                             sample_weight,
                                                                                                             unprocessed,
                                                                                                             valid_y)
                if metric_params["evaluationMetric"] == "CUSTOM":
                    results["metrics"]["customScore"] = get_custom_score_from_custom_metrics_results(
                        results["metrics"]["customMetricsResults"],
                        metric_params["customEvaluationMetricName"]
                    )
            average = get_multiclass_metrics_averaging_method(metric_params)
        else:
            average = "macro"

        logger.info("Computing multiclass metrics with \"{}\" class averaging method".format(average))
        prf_support = handle_failure(lambda: precision_recall_fscore_support(valid_y, preds, average=average, pos_label=None, sample_weight=sample_weight),
                                     treat_failure_as_error)
        # NOTE(review): handle_failure presumably returns None when a failure is tolerated (see
        # treat_failure_as_error above); unpacking None directly would raise a TypeError, so guard it.
        if prf_support is None:
            precision = recall = f1score = None
        else:
            precision, recall, f1score, _ = prf_support

        results["metrics"]["precision"] = precision
        results["metrics"]["recall"] = recall
        results["metrics"]["f1"] = f1score
        results["metrics"]["accuracy"] = handle_failure(lambda: accuracy_score(valid_y, preds, sample_weight=sample_weight), treat_failure_as_error)
        results["metrics"]["hammingLoss"] = handle_failure(lambda: hamming_loss(valid_y, preds, sample_weight=sample_weight), treat_failure_as_error)

        if probas is not None:
            # The calibration helper handles per-class failures itself, hence no handle_failure wrapper here
            results["metrics"]["mcalibrationLoss"], results["oneVsAllCalibrationLoss"], results["oneVsAllCalibrationCurves"] = \
                MulticlassModelScorer.get_calibration_metrics_and_curves(valid_y, probas, target_map, sample_weight, average=average)

            results["metrics"]["mrocAUC"] = handle_failure(lambda: mroc_auc_score(valid_y, probas, sample_weight=sample_weight, average=average), treat_failure_as_error)
            results["metrics"]["averagePrecision"] = handle_failure(lambda: m_average_precision_score(valid_y, probas, sample_weight=sample_weight, average=average),
                                                                     treat_failure_as_error)

            results["metrics"]["logLoss"] = handle_failure(lambda: log_loss(valid_y, probas, sample_weight=sample_weight), treat_failure_as_error)
        return results

    @staticmethod
    def get_multiclass_confusion_matrix(valid_y, preds, inv_map, sample_weight):
        """
        Build the confusion matrix between the ground truth and the predictions, for every class of inv_map
        (the inverted target_map). Counts are weighted by sample_weight when it is provided.

        At this point valid_y & preds are expected to only hold class indexes belonging to the target_map:
            - Preprocessing should have removed any rows containing a class not belonging to the target_map.
            - if any class was missing in model, the classes indexes were remapped to match the target_map order.

        :param pd.Series valid_y: Pandas Series of size nb_rows containing ground truth class-indexes.
        :param numpy.ndarray preds: Numpy array of shape (nb_rows,) containing model's prediction for each row (classes indexes only)
        :param dict inv_map: {class_id: class_label for all classes in the guess sample}
        :param numpy.ndarray or pd.Series or list sample_weight: array-like of shape (nb_rows,), or None. used to weight rows in metrics.
               weights must be strictly positive floats

        :return: dict of format:
                {
                    "totalRows": X,                  # nb samples after preprocessing was done
                    "perActual": {
                        "dog": {                     # 1 dict per class in the target map
                            "actualClassCount":  A,  # nb samples from the ground truth belonging to the "dog" class
                            "perPredicted": {        # 1 entry per class in the target map
                                "dog": B             # nb of samples from class dog predicted dog
                                "cat": C             # nb of samples from class dog predicted cat
                                "bird": D            # nb of samples from class dog predicted bird
                            }
                        }
                    }
                }
        """
        assert preds.shape == valid_y.shape
        (nb_rows,) = preds.shape

        # actual class id -> (possibly weighted) row count, and actual class id -> predicted class id -> count
        actual_totals = Counter()
        predicted_by_actual = defaultdict(Counter)
        if sample_weight is None:
            for (actual, predicted) in zip(valid_y, preds):
                actual_totals[actual] += 1
                predicted_by_actual[actual][predicted] += 1
        else:
            for (actual, predicted, weight) in zip(valid_y, preds, sample_weight):
                actual_totals[actual] += weight
                predicted_by_actual[actual][predicted] += weight

        # Report every class of inv_map, including the ones that never show up in the data
        per_actual = {}
        for actual_class_id in inv_map:
            actual_class_count = float(actual_totals[actual_class_id])
            per_predicted = {}
            for predicted_class_id in inv_map:
                per_predicted[inv_map[predicted_class_id]] = predicted_by_actual[actual_class_id][predicted_class_id]
            per_actual[inv_map[actual_class_id]] = {
                "actualClassCount": actual_class_count,
                "perPredicted": per_predicted,
            }

        confusion_matrix_dict = {"totalRows": nb_rows, "perActual": per_actual}
        logger.info("Calculated confusion matrix")
        return confusion_matrix_dict


class CVMulticlassModelScorer(BaseCVModelScorer):
    """Aggregate the performance data of the per-fold multiclass scorers of a cross-validated model."""

    # Entries copied as-is from the first fold when probas are available
    _FIRST_FOLD_PROBA_KEYS = (
        "oneVsAllRocCurves",
        "oneVsAllRocAUC",
        "oneVsAllPrCurves",
        "oneVsAllAveragePrecision",
        "densityData",
        "oneVsAllCalibrationCurves",
        "oneVsAllCalibrationLoss",
    )

    def __init__(self, scorers):
        """
        :param scorers: per-fold scorers; the first one tells whether probas are available
        """
        super(CVMulticlassModelScorer, self).__init__(scorers)
        self.r1 = None
        self.use_probas = scorers[0].use_probas

    def score(self):
        """
        Average every metric shared by all folds (mean under the metric name, standard deviation under
        "<metric>std") and take the non-aggregated data (confusion matrix, curves, ...) from the first fold.

        :return: the aggregated performance data
        :rtype: dict
        """
        super(CVMulticlassModelScorer, self).score()

        self.r1 = self.perfdatas[0]
        self.ret["metrics"] = {}

        # Only the metrics available in every fold can be aggregated
        metric_names_per_fold = [set(fold["metrics"].keys()) for fold in self.perfdatas]
        shared_metrics = reduce(set.intersection, metric_names_per_fold)

        for metric in shared_metrics:
            if metric in self.DISCARDED_METRICS_FOR_AGG:
                logger.info("Not aggregating metric '%s'" % metric)
                continue
            if metric == "customMetricsResults":
                # todo what logging do we want here - per custom metric ??
                custom_metric_data_per_fold = [fold["metrics"]["customMetricsResults"] for fold in self.perfdatas]
                self.ret["metrics"]["customMetricsResults"] = aggregate_custom_metrics_for_cross_val_model(custom_metric_data_per_fold)
                continue

            logger.info("Metric is %s" % metric)
            # Falsy values (e.g. None) are counted as 0.0
            metric_values = [fold["metrics"][metric] or 0.0 for fold in self.perfdatas]
            logger.info("Metric values : " + "; ".join([str(value) for value in metric_values]))
            data = np.array(metric_values)
            logger.info("AVG %s" % data)
            self.ret["metrics"][metric] = np.nanmean(data)
            self.ret["metrics"][metric + "std"] = np.nanstd(data)

        # Don't do much here ...
        self.ret["confusion"] = self.r1["confusion"]
        if "classes" not in self.ret:
            self.ret["classes"] = self.r1["classes"]
        if self.use_probas:
            for key in self._FIRST_FOLD_PROBA_KEYS:
                self.ret[key] = self.r1[key]

        return self.ret


def save_classification_statistics(predicted_class_series, base_folder_context, probas=None, sample_weight=None, target_map=None, filename=None,
                                   predicted_class_series_per_cut=None, cuts=None):
    """
    Write the statistics of the predicted classes as JSON through base_folder_context.

    :param pd.Series predicted_class_series: predicted class of each row
    :param base_folder_context: object exposing write_json(filename, data)
    :param probas: when given together with target_map, the probability densities are added to the statistics
    :param sample_weight: forwarded to the probability densities computation
    :param dict target_map: forwarded to the probability densities computation
    :param str filename: name of the written file, "prediction_statistics.json" when not provided
    :param predicted_class_series_per_cut: optional iterable holding one series of predicted classes per cut
    :param cuts: cuts matching predicted_class_series_per_cut; the per-cut counts are only written when it is given
    """
    # Class counts per cut; dropped altogether as soon as one cut holds no prediction
    counts_per_cut = []
    per_cut_series = predicted_class_series_per_cut if predicted_class_series_per_cut is not None else ()
    for cut_series in per_cut_series:
        if cut_series.shape[0] == 0:
            logger.warning("At least one of the predicted set had no value.")
            counts_per_cut = []
            break
        counts_per_cut.append(cut_series.value_counts().to_dict())

    statistics = {}
    if counts_per_cut and cuts is not None:
        statistics["predictedClassCountPerCut"] = counts_per_cut
        # We save again cuts here, in case we don't have the performance (for a model evaluation for instance)
        statistics["cuts"] = cuts
    if predicted_class_series.shape[0] != 0:
        statistics["predictedClassCount"] = predicted_class_series.value_counts().to_dict()
    if probas is not None and target_map is not None:
        statistics['probabilityDensities'] = format_all_proba_densities(probas, target_map=target_map, sample_weight=sample_weight)

    target_filename = filename if filename else "prediction_statistics.json"
    base_folder_context.write_json(target_filename, statistics)
