import gzip
import logging
import os
import uuid
from math import sqrt

import numpy as np
import pandas as pd
from six.moves import xrange
from sklearn.metrics import accuracy_score
from sklearn.metrics import explained_variance_score
from sklearn.metrics import f1_score
from sklearn.metrics import hamming_loss
from sklearn.metrics import matthews_corrcoef
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import precision_score
from sklearn.metrics import r2_score
from sklearn.metrics import recall_score

from dataiku import Dataset
from dataiku.base.folder_context import build_folder_context
from dataiku.base.utils import safe_unicode_str
from dataiku.core import doctor_constants
from dataiku.core import schema_handling
from dataiku.core.dku_pandas_csv import dataframe_to_csv, pandas_date_parser_compat
from dataiku.doctor import step_constants, utils
from dataiku.doctor.diagnostics import default_diagnostics
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.exception import DriftException
from dataiku.doctor.exception import EmptyDatasetException
from dataiku.doctor.prediction.classification_scoring import BinaryClassificationModelScorer, \
    save_classification_statistics, compute_optimized_threshold
from dataiku.doctor.prediction.classification_scoring import MulticlassModelScorer
from dataiku.doctor.prediction.common import make_cost_matrix_score
from dataiku.doctor.prediction.common import make_lift_score
from dataiku.doctor.prediction.custom_evaluation_scoring import calculate_custom_evaluation_metrics
from dataiku.doctor.prediction.custom_scoring import calculate_overall_classification_custom_metrics
from dataiku.doctor.prediction.custom_scoring import calculate_regression_custom_metrics
from dataiku.doctor.prediction.custom_scoring import get_custom_score_from_custom_metrics_results
from dataiku.doctor.prediction.decisions_and_cuts import DecisionsAndCuts
from dataiku.doctor.prediction.regression_scoring import RegressionModelScorer, save_regression_statistics
from dataiku.doctor.prediction.regression_scoring import pearson_correlation
from dataiku.doctor.prediction.scoring_base import PERF_FILENAME
from dataiku.doctor.preprocessing.assertions import cast_assertions_masks_bool
from dataiku.doctor.preprocessing.assertions import MLAssertion
from dataiku.doctor.utils import normalize_dataframe
from dataiku.doctor.utils.api_logs import API_NODE_EVALUATION_DATASET_TYPE, CLOUD_API_NODE_EVALUATION_DATASET_TYPE, CLOUD_API_NODE_LOGS_FEATURE_PREFIX
from dataiku.doctor.utils.api_logs import API_NODE_LOGS_FEATURE_PREFIX
from dataiku.doctor.utils.api_logs import CLASSICAL_EVALUATION_DATASET_TYPE
from dataiku.doctor.utils.api_logs import SAGEMAKER_EVALUATION_DATASET_TYPE
from dataiku.doctor.utils.api_logs import create_filter_on_smvid_and_deployment_step
from dataiku.doctor.utils.api_logs import normalize_api_node_logs_dataset
from dataiku.doctor.utils.listener import DiagOnlyContext
from dataiku.doctor.utils.listener import ProgressListener
from dataiku.doctor.utils.metrics import calibration_loss_binary, handle_failure
from dataiku.doctor.utils.metrics import log_loss
from dataiku.doctor.utils.metrics import mean_absolute_percentage_error
from dataiku.doctor.utils.metrics import mroc_auc_score
from dataiku.doctor.utils.metrics import m_average_precision_score
from dataiku.doctor.utils.metrics import rmsle_score
from dataiku.modelevaluation.data_types import DataDriftParams, cast_as_string
from dataiku.modelevaluation.drift.drift_univariate import DriftUnivariate
from dataiku.modelevaluation.drift.prediction_drift import PredictionDrift, create_prediction_series_from_statistics
from dataikuscoring.utils.prediction_result import ClassificationPredictionResult, PredictionResult

logger = logging.getLogger(__name__)
try:
    from dataiku.modelevaluation.drift.drift_model import DriftModel
    from dataiku.modelevaluation.drift.drift_preparator import DriftPreparator
    from dataiku.modelevaluation.server import DriftProtocol, ModelLikeInfo
except ModuleNotFoundError as e:
    # This can happen for MLflow where there is no guarantee that scipy will be in the code env.
    logger.exception("Could not import DriftModel, drift computation won't be available")

# Sampling limits (row count and byte size).
# NOTE(review): presumably applied when building the evaluation sample — confirm against the code using them.
MAX_SAMPLING_ROWS = 50000
MAX_SAMPLING_BYTES = 2 ** 23  # 8 MiB
# File names for a gzipped CSV sample and its schema.
# NOTE(review): presumably written in the model evaluation store folder — confirm.
SAMPLE_CSV = "sample.csv.gz"
SAMPLE_SCHEMA_JSON = "sample_schema.json"
# Error raised when 'Skip scoring' is enabled on a classification task without the expected proba_<class> columns.
NO_PROBA_ERR_MESSAGE = "With the 'Skip scoring' option, non probabilistic models are not supported. You need to have a proba_<class_name> column for each class of your classification task."


class EvaluateRecipe(object):
    """
    Base class orchestrating an evaluation recipe run.

    ``run()`` loads the evaluation dataset, normalizes it according to the evaluation dataset type, optionally
    computes drift when a model evaluation store is configured, then writes whichever outputs are configured:
    the scored dataset, the metrics dataset and the model evaluation store content.

    The following hooks raise ``NotImplementedError`` here and must be provided by subclasses:
    ``_fetch_input_dataset_and_model_params``, ``_get_input_df``, ``_get_sample_dfs``,
    ``_compute_output_and_pred_df``, ``_compute_metrics_df`` and ``_compute_custom_evaluation_metrics_df``.
    ``_perform_other_mes_actions`` and ``_fix_output_dataset_schema`` are optional hooks that do nothing by default.
    """

    def __init__(self, model_folder, input_dataset_smartname, output_dataset_smartname, metrics_dataset_smartname,
                 recipe_desc,
                 script, preparation_output_schema, cond_outputs=None, preprocessing_params=None,
                 model_evaluation_store_folder=None,
                 evaluation_dataset_type=None,
                 api_node_logs_config=None,
                 diagnostics_folder=None, fmi=None):
        """
        :param model_folder: folder of the evaluated model, wrapped with ``build_folder_context``
        :param input_dataset_smartname: name of the evaluation dataset
        :param output_dataset_smartname: name of the scored output dataset (None or empty: no output dataset)
        :param metrics_dataset_smartname: name of the metrics dataset (None or empty: no metrics dataset)
        :param dict recipe_desc: recipe settings (read with keys such as 'skipScoring', 'customEvaluationMetrics')
        :param script: stored as-is on the instance
        :param preparation_output_schema: schema used to derive the dtypes of the input dataframe
        :param cond_outputs: stored as-is on the instance
        :param preprocessing_params: preprocessing settings; its "per_feature" entry becomes ``feature_preproc``.
                                     When None, ``feature_preproc`` is left to None.
        :param model_evaluation_store_folder: folder of the model evaluation store (None or empty: no store)
        :param evaluation_dataset_type: one of the ``*_EVALUATION_DATASET_TYPE`` constants
        :param api_node_logs_config: dict possibly holding 'smvToFilterLogsOn' and 'deploymentToFilterLogsOn'.
                                     When None, no log filtering is configured.
        :param diagnostics_folder: folder receiving diagnostics (None or empty: no folder context)
        :param fmi: stored as-is on the instance
        """
        self.fmi = fmi
        self.dtypes = None
        self.columns = None
        self.target_mapping = None
        self.core_params = None
        self.preprocessing_params = None
        self.feature_preproc = None
        self.input_dataset = None
        self.model_folder_context = build_folder_context(model_folder)
        self.input_dataset_smartname = input_dataset_smartname
        self.output_dataset_smartname = output_dataset_smartname
        self.metrics_dataset_smartname = metrics_dataset_smartname
        self.recipe_desc = recipe_desc
        self.prediction_type = None
        self.has_custom_evaluation_metrics = len(self.recipe_desc.get("customEvaluationMetrics", [])) > 0

        self.dont_compute_performance = not self._has_metrics_dataset() and recipe_desc.get('dontComputePerformance',
                                                                                            False)
        if self.dont_compute_performance:
            logger.info("Will only score and compute statistics")

        self.script = script
        self.preparation_output_schema = preparation_output_schema
        self.cond_outputs = cond_outputs
        self.model_evaluation_store_folder_context = self._build_folder_context_or_none(model_evaluation_store_folder)
        self.evaluation_dataset_type = evaluation_dataset_type
        self.api_node_logs_config = api_node_logs_config
        self.diagnostics_folder_context = self._build_folder_context_or_none(diagnostics_folder)
        self.model_target_column = None

        self.preprocessing_params = preprocessing_params
        # The signature allows preprocessing_params=None: do not crash on the subscript in that case
        self.feature_preproc = preprocessing_params["per_feature"] if preprocessing_params is not None else None

        # The signature allows api_node_logs_config=None: treat it as "no filter configured"
        logs_config = api_node_logs_config if api_node_logs_config is not None else {}
        self.smv_to_filter_logs_on = self._none_if_empty(logs_config.get("smvToFilterLogsOn"))
        self.deploymentToFilterOn = self._none_if_empty(logs_config.get("deploymentToFilterLogsOn"))

        self.infer_output_dataset_schema = False
        self.infer_metrics_dataset_schema = False

        self.target_column_in_dataset = None  # Can differ from model_target_column regarding the evaluation dataset type
        self.prediction_column = "prediction"
        self.proba_columns = []

    def run(self):
        """
        Execute the evaluation recipe end to end.

        Steps: fetch the input dataset and model params, compute the dtypes, load and normalize the evaluation
        dataframe, optionally rebuild predictions from probas ('Skip scoring' on binary classification), optionally
        score a sample and compute drift (model evaluation store), compute the output/prediction dataframes, then
        write the configured outputs.

        :raises EmptyDatasetException: when the loaded evaluation dataframe is empty
        :raises Exception: with 'Skip scoring', when the prediction column is missing, or when a binary
                           classification task does not have exactly two proba columns
        """
        # Fetch input information depending on the backend and model type
        self._fetch_input_dataset_and_model_params()

        self._fix_feature_and_target_dtypes()

        if self.recipe_desc.get('skipScoring', False):
            # Predictions are read from the dataset: declare the dtypes of the prediction / proba columns
            if self.prediction_type in [doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS]:
                self.proba_columns = ["proba_" + target_class for target_class in self.target_mapping.keys()]
                for proba_column in self.proba_columns:
                    self.dtypes[proba_column] = np.float64
                self.dtypes[self.prediction_column] = object
            else:
                self.dtypes[self.prediction_column] = np.float64

        # If the dataset is API Logs, add a filter to consider only the ones from the evaluated smv
        if self.smv_to_filter_logs_on is not None:
            create_filter_on_smvid_and_deployment_step(self.input_dataset, self.smv_to_filter_logs_on, self.deploymentToFilterOn, self.evaluation_dataset_type)

        context = DiagOnlyContext(self.diagnostics_folder_context)
        listener = ProgressListener(context=context)

        default_diagnostics.register_evaluation_callbacks()

        with listener.push_step(step_constants.ProcessingStep.STEP_LOADING_EVALUATION):

            input_df = self._get_input_df()

            if input_df.empty:
                if self.evaluation_dataset_type != CLASSICAL_EVALUATION_DATASET_TYPE:
                    raise EmptyDatasetException(
                        "The evaluation dataset can not be empty. Check the input dataset or the recipe sampling "
                        "configuration. Note that you marked this dataset as " + self.evaluation_dataset_type)
                else:
                    raise EmptyDatasetException(
                        "The evaluation dataset can not be empty. Check the input dataset or the recipe sampling "
                        "configuration.")

            # Keep the data as loaded: it is used for diagnostics and to build the output dataset
            input_df_copy_unnormalized = input_df.copy()

            if self.evaluation_dataset_type in [API_NODE_EVALUATION_DATASET_TYPE, CLOUD_API_NODE_EVALUATION_DATASET_TYPE]:
                # Log the actual dataset type (this branch also handles the cloud API node type)
                logger.info("Normalizing %s dataset" % self.evaluation_dataset_type)
                input_df = normalize_api_node_logs_dataset(input_df, self.feature_preproc, self.evaluation_dataset_type)
            elif self.evaluation_dataset_type == CLASSICAL_EVALUATION_DATASET_TYPE:
                logger.info("Normalizing dataset")
                normalize_dataframe(input_df, self.feature_preproc)
            elif self.evaluation_dataset_type == SAGEMAKER_EVALUATION_DATASET_TYPE:
                logger.info("The dataset type is %s, it has been normalized in the logs decoding step" % SAGEMAKER_EVALUATION_DATASET_TYPE)

            if self.recipe_desc.get('skipScoring', False):
                logger.info("The 'Skip scoring' option is enabled. This recipe will use the prediction(s) column(s) of the evaluated dataset to compute metrics.")
                if self.prediction_column not in input_df.columns:
                    raise Exception("The 'prediction' column is mandatory in the evaluation dataset with the 'Skip scoring' option")
                if self.prediction_type == doctor_constants.BINARY_CLASSIFICATION:
                    if len(self.proba_columns) != 2:
                        raise Exception(NO_PROBA_ERR_MESSAGE)
                    probas = input_df[self.proba_columns].values
                    preds = recreate_preds_from_probas_and_threshold(probas, self.target_mapping, self.prediction_type, self._get_used_threshold())
                    # When we skip scoring, we should rebuild the preds from the probas, since they could be invalid if the threshold has been overridden.
                    input_df[self.prediction_column] = preds

            for col in input_df:
                logger.info("NORMALIZED: %s -> %s" % (col, input_df[col].dtype))

            logger.info("Got a dataframe : %s" % str(input_df.shape))

            univariate_drift = None
            sample_ref_df = None
            if self._has_model_evaluation_store():
                # there is a sample of the input that needs scoring. let's score it first, since it's smaller than
                # the full data and will trigger errors earlier if there's any to be
                # triggered
                sample_input_df, sample_output_df = self._get_sample_dfs(input_df)

                logger.info("Loaded sample : %s" % str(sample_input_df.shape))

                # dump data with the predictions
                # Write out the data from python => you lose the "strict" typing, ie bigint becomes double for
                # example, but strings are correct and dates are iso8601 => should be fine to reuse as input in other
                # ML-related operations, where you're going to call ml_dtypes_from_dss_schema() anyway. The proper
                # way would of course be to stream the data back to the JEK and have the jek write with the usual
                # java machinery, but that's a lot of code and calls (at least 1 call to stream, and 1 to verif)
                with self.model_evaluation_store_folder_context.get_file_path_to_write('sample_scored.csv.gz') as sample_path:
                    dataframe_to_csv(sample_output_df, sample_path, gzip.open)
                # don't forget the schema
                sample_output_schema = schema_handling.get_schema_from_df(sample_output_df)
                self.model_evaluation_store_folder_context.write_json('sample_scored_schema.json',
                                                                      {'columns': sample_output_schema})

                drift_columns_settings = {}
                if self.recipe_desc.get('treatDataDriftColumnHandling'):
                    drift_columns_settings = format_data_drift_column_handling(self.recipe_desc.get('dataDriftColumnHandling'),
                                                                               self.evaluation_dataset_type)

                reference_threshold = get_reference_threshold_from_evaluation(evaluation_folder_context=self.model_evaluation_store_folder_context)
                univariate_drift, sample_ref_df = compute_drift(sample_input_df, sample_output_df, self.preprocessing_params,
                                                 self.model_evaluation_store_folder_context,
                                                 self.recipe_desc.get('treatDriftFailureAsError', False),
                                                 self.prediction_type,
                                                 drift_columns_settings,
                                                 self.recipe_desc.get('driftConfidenceLevel', 0.95),
                                                 reference_threshold=reference_threshold,
                                                 should_text_drift=self.recipe_desc.get('hasTextDrift', False),
                                                 text_drift_params=self.recipe_desc.get('textDriftParams',None))

            cast_assertions_masks_bool(input_df)
            diagnostics.on_load_evaluation_dataset_end(df=input_df_copy_unnormalized, univariate_drift=univariate_drift,
                                                       per_feature=self.feature_preproc,
                                                       target_column=self.target_column_in_dataset,
                                                       target_remapping=self.preprocessing_params.get("target_remapping"),
                                                       prediction_type=self.prediction_type)

        # Only assigned a real value when custom evaluation metrics are configured
        custom_eval_metrics = None
        with listener.push_step(step_constants.ProcessingStep.STEP_EVAL_PROCESSING):
            output_df, pred_df = self._compute_output_and_pred_df(input_df, input_df_copy_unnormalized)
            if self.has_custom_evaluation_metrics:
                custom_eval_metrics = self._compute_custom_evaluation_metrics_df(output_df, pred_df, input_df_copy_unnormalized, ref_sample_df=sample_ref_df)

        # write scored data
        if self._has_output_dataset():
            output_dataset = Dataset(self.output_dataset_smartname)
            logger.info("writing scored data")
            self._fix_output_dataset_schema(output_dataset, output_df)
            output_dataset.write_from_dataframe(output_df, infer_schema=self.infer_output_dataset_schema)

        # write metrics dataset
        if self._has_metrics_dataset():
            with listener.push_step(step_constants.ProcessingStep.STEP_EVAL_COMPUTING_METRICS_DATASET):
                metrics_df = self._compute_metrics_df(output_df, pred_df)
            if self.has_custom_evaluation_metrics:
                # One column per custom evaluation metric, None when its computation did not succeed
                custom_eval_metrics_df = pd.DataFrame.from_dict({
                    m['metric']['name']: [m['value'] if m['didSucceed'] else None] for m in custom_eval_metrics
                })
                metrics_df = pd.concat([metrics_df, custom_eval_metrics_df], axis=1)
            metrics_dataset = Dataset(self.metrics_dataset_smartname)
            logger.info("writing metrics data")
            metrics_dataset.write_from_dataframe(metrics_df, infer_schema=self.infer_metrics_dataset_schema)

        # write model evaluation store
        if self._has_model_evaluation_store():
            self._perform_other_mes_actions()
            # compute row count on the evaluated data
            clean_all_columns = [c for c in input_df.columns if c not in output_df.columns]
            all_df = pd.concat([input_df[clean_all_columns], output_df], axis=1)
            add_statistics_to_evaluation(all_df, self.model_evaluation_store_folder_context)
            if self.has_custom_evaluation_metrics:
                add_custom_metrics_to_perf(custom_eval_metrics, self.model_evaluation_store_folder_context, self.prediction_type)

    @staticmethod
    def _none_if_empty(value):
        """Return None when `value` is None or has a length of 0, `value` itself otherwise."""
        if value is None or len(value) == 0:
            return None
        return value

    @staticmethod
    def _build_folder_context_or_none(folder_path):
        """Return a folder context for `folder_path`, or None when the path is None or empty."""
        if folder_path is not None and len(folder_path) > 0:
            return build_folder_context(folder_path)
        else:
            return None

    def _has_output_dataset(self):
        """Tell whether a non-empty output dataset name was configured."""
        return self.output_dataset_smartname is not None and len(self.output_dataset_smartname) > 0

    def _has_metrics_dataset(self):
        """Tell whether a non-empty metrics dataset name was configured."""
        return self.metrics_dataset_smartname is not None and len(self.metrics_dataset_smartname) > 0

    def _has_model_evaluation_store(self):
        """Tell whether a model evaluation store folder context was built."""
        return self.model_evaluation_store_folder_context is not None

    def _fetch_input_dataset_and_model_params(self):
        """Hook called first by `run()`; must be implemented by subclasses."""
        raise NotImplementedError()

    def _get_input_df(self):
        """Hook returning the evaluation dataframe; must be implemented by subclasses."""
        raise NotImplementedError()

    def _get_sample_dfs(self, input_df):
        """Hook returning a `(sample_input_df, sample_output_df)` pair; must be implemented by subclasses."""
        raise NotImplementedError()

    def _compute_output_and_pred_df(self, input_df, input_df_copy_unnormalized):
        """Hook returning an `(output_df, pred_df)` pair; must be implemented by subclasses."""
        raise NotImplementedError()

    def _compute_metrics_df(self, output_df, pred_df):
        """Hook returning the dataframe written to the metrics dataset; must be implemented by subclasses."""
        raise NotImplementedError()

    def _compute_custom_evaluation_metrics_df(self, output_df, pred_df, unprocessed_input_df, ref_sample_df):
        """Hook returning the custom evaluation metrics results; must be implemented by subclasses."""
        raise NotImplementedError()

    def _perform_other_mes_actions(self):
        """Optional hook called before updating the model evaluation store; does nothing by default."""
        return

    def _fix_output_dataset_schema(self, output_dataset, output_df):
        """Optional hook called before writing the output dataset; does nothing by default."""
        return

    def _get_output_from_pred(self, input_df_copy_unnormalized, pred_df):
        """
        Build the output dataframe: kept input columns followed by the columns of `pred_df`.

        Input columns that also exist in `pred_df` are dropped in favor of the `pred_df` ones. When the recipe does
        not filter input columns, the ML assertions mask columns are dropped as well.
        """
        if self.recipe_desc.get("filterInputColumns", False):
            clean_kept_columns = [c for c in self.recipe_desc["keptInputColumns"] if c not in pred_df.columns]
        else:
            # also remove  ml assertions mask columns from the output
            clean_kept_columns = [c for c in input_df_copy_unnormalized.columns
                                  if
                                  c not in pred_df.columns and not c.startswith(MLAssertion.ML_ASSERTION_MASK_PREFIX)]
        output_df = pd.concat([input_df_copy_unnormalized[clean_kept_columns], pred_df], axis=1)

        return output_df

    def _fix_feature_and_target_dtypes(self):
        """
        Compute `self.dtypes` from the preparation output schema and the per-feature preprocessing.

        On (cloud) API node logs datasets the feature columns are prefixed, so the per-feature settings and
        `self.target_column_in_dataset` are prefixed accordingly.
        """
        if self.evaluation_dataset_type == API_NODE_EVALUATION_DATASET_TYPE:
            prefix = API_NODE_LOGS_FEATURE_PREFIX
        elif self.evaluation_dataset_type == CLOUD_API_NODE_EVALUATION_DATASET_TYPE:
            prefix = CLOUD_API_NODE_LOGS_FEATURE_PREFIX
        else:
            prefix = None

        if prefix is None:
            feature_preproc = self.feature_preproc
        else:
            # features columns are prefixed on api node logs datasets
            feature_preproc = {prefix + x: y for x, y in self.feature_preproc.items()}
            self.target_column_in_dataset = prefix + self.target_column_in_dataset
        self.dtypes = utils.ml_dtypes_from_dss_schema(self.preparation_output_schema, feature_preproc, prediction_type=self.core_params["prediction_type"])

    def _get_used_threshold(self):
        """
        Return the classifier threshold to use: the one forced in the recipe when
        'overrideModelSpecifiedThreshold' is set, otherwise the 'activeClassifierThreshold' of the model's
        user_meta.json.
        """
        if self.recipe_desc["overrideModelSpecifiedThreshold"]:
            return self.recipe_desc.get("forcedClassifierThreshold")
        else:
            return self.model_folder_context.read_json("user_meta.json").get("activeClassifierThreshold")

def _read_table(stream, names, dtypes, parse_dates):
    """
    Parse a header-less, tab-separated stream into a dataframe.

    :param stream: file-like object to read from
    :param names: column names to assign
    :param dtypes: dtypes forwarded to pandas
    :param parse_dates: date columns specification forwarded to pandas and to the date compatibility helper
    :return: the parsed dataframe, after the date compatibility pass
    """
    read_options = dict(
        names=names,
        dtype=dtypes,
        header=None,
        sep='\t',
        doublequote=True,
        quotechar='"',
        parse_dates=parse_dates,
        float_precision="round_trip",
    )
    parsed = pd.read_table(stream, **read_options)
    # NOTE(review): presumably re-parses the date columns as UTC where pandas did not do it — confirm in the helper
    return pandas_date_parser_compat(parsed, parse_dates, lambda col: pd.to_datetime(col, utc=True))


def load_input_dataframe(input_dataset, sampling, columns, dtypes, parse_date_columns):
    """
    Stream `input_dataset` into a dataframe and verify the read session.

    :param input_dataset: dataset exposing `_stream` and `_verify_read`
    :param sampling: sampling settings forwarded to the stream
    :param columns: columns to read, also used as the dataframe column names
    :param dtypes: dtypes used to parse the stream
    :param parse_date_columns: date columns specification used to parse the stream
    :return: the loaded dataframe
    """
    session_id = str(uuid.uuid4())
    stream_context = input_dataset._stream(infer_with_pandas=True, sampling=sampling, columns=columns,
                                           read_session_id=session_id)
    with stream_context as stream:
        loaded_df = _read_table(stream, names=columns, dtypes=dtypes, parse_dates=parse_date_columns)

        # The verification must happen once the dataframe is built (the data is streamed), while the stream
        # context is still open.
        input_dataset._verify_read(session_id)
    return loaded_df


def add_statistics_to_evaluation(df, evaluation_folder_context):
    """
    Record the number of rows of `df` as 'nbEvaluationRows' in the "_evaluation.json" file of the folder.

    :param df: evaluated data (its first shape dimension is stored)
    :param evaluation_folder_context: folder context exposing `read_json` and `write_json`
    """
    filename = "_evaluation.json"
    payload = evaluation_folder_context.read_json(filename)
    payload['nbEvaluationRows'] = df.shape[0]
    evaluation_folder_context.write_json(filename, payload)

def add_custom_metrics_to_perf(custom_evaluation_metrics, evaluation_folder_context, prediction_type):
    """
    Append custom evaluation metrics results to the perf file of the evaluation folder.

    The results are appended to the 'customMetricsResults' list of the 'tiMetrics' section for binary
    classification, of the 'metrics' section otherwise. Missing file, section or list are created.

    :param list custom_evaluation_metrics: results to append
    :param evaluation_folder_context: folder context exposing `isfile`, `read_json` and `write_json`
    :param prediction_type: prediction type of the evaluated model
    """
    has_perf_file = evaluation_folder_context.isfile(PERF_FILENAME)
    perf_json = evaluation_folder_context.read_json(PERF_FILENAME) if has_perf_file else {}

    section_key = 'tiMetrics' if prediction_type == doctor_constants.BINARY_CLASSIFICATION else 'metrics'
    metric_results = perf_json.setdefault(section_key, {})
    metric_results.setdefault('customMetricsResults', []).extend(custom_evaluation_metrics)

    evaluation_folder_context.write_json(PERF_FILENAME, perf_json)

def get_reference_threshold_from_evaluation(evaluation_folder_context):
    """
    Read 'referenceClassifierThreshold' from the "_evaluation.json" file of the folder.

    :param evaluation_folder_context: folder context exposing `read_json`
    :return: the stored threshold, or None when the key is absent
    """
    return evaluation_folder_context.read_json("_evaluation.json").get('referenceClassifierThreshold')


# wrap scoring code for model evaluation
# main task of the 'wrapper' is to do a few checks before calling the scorers
def run_binary_scoring(modeling_params, decisions_and_cuts, target, target_map, sample_weight,
                       out_folder_context, assertions=None, test_unprocessed=None, test_X=None,
                       treat_metrics_failure_as_error=True):
    """
    Score a binary classification model with `BinaryClassificationModelScorer` and save the results.

    Nothing is scored (an error is logged) when `decisions_and_cuts` is empty.

    :param boolean treat_metrics_failure_as_error: forwarded to the scorer's `score` method
    """
    if decisions_and_cuts.is_empty():
        logger.error("Missing predictions")
        return

    scorer_options = dict(
        test_unprocessed=test_unprocessed,
        test_X=test_X,
        test_df_index=None,
        test_sample_weight=sample_weight,
        assertions=assertions,
    )
    scorer = BinaryClassificationModelScorer(modeling_params, out_folder_context, decisions_and_cuts,
                                             target, target_map, **scorer_options)

    scorer.score(with_assertions=True, treat_metrics_failure_as_error=treat_metrics_failure_as_error)
    scorer.save(dump_predicted=False)


def run_multiclass_scoring(modeling_params, prediction_result, target, target_map, sample_weight,
                           out_folder_context, assertions=None, test_unprocessed=None, test_X=None,
                           treat_metrics_failure_as_error=True):
    """
    Score a multiclass model with `MulticlassModelScorer` and save the results.

    Nothing is scored when `prediction_result` is empty or when `target` holds fewer than two distinct classes.
    A warning is logged when the predictions hold fewer distinct classes than the target.

    :param boolean treat_metrics_failure_as_error: forwarded to the scorer's `score` method
    """
    if prediction_result.is_empty():
        return

    # Scoring fails unless at least two classes are present in the target
    target_class_count = np.unique(target).size
    if target_class_count < 2:
        return

    predicted_class_count = np.unique(prediction_result.unmapped_preds_not_declined).size
    if predicted_class_count < target_class_count:
        logger.warning("Some classes are not represented in the prediction column.")

    scorer_options = dict(
        test_unprocessed=test_unprocessed,
        test_X=test_X,
        test_df_index=None,
        test_sample_weight=sample_weight,
        assertions=assertions,
    )
    scorer = MulticlassModelScorer(modeling_params, out_folder_context, prediction_result,
                                   target, target_map, **scorer_options)
    scorer.score(with_assertions=True, treat_metrics_failure_as_error=treat_metrics_failure_as_error)
    scorer.save(dump_predicted=False)


def run_regression_scoring(modeling_params, prediction_result, target, sample_weight, out_folder_context,
                           assertions=None, test_unprocessed=None, test_X=None, treat_metrics_failure_as_error=True):
    """
    Score a regression model with `RegressionModelScorer` and save the results.

    Nothing is scored when `prediction_result` is empty.

    :param boolean treat_metrics_failure_as_error: forwarded to the scorer's `score` method
    """
    if prediction_result.is_empty():
        return

    scorer_options = dict(
        test_unprocessed=test_unprocessed,
        test_X=test_X,
        test_df_index=None,
        test_sample_weight=sample_weight,
        assertions=assertions,
    )
    scorer = RegressionModelScorer(modeling_params, prediction_result, target, out_folder_context,
                                   **scorer_options)
    scorer.score(with_assertions=True, treat_metrics_failure_as_error=treat_metrics_failure_as_error)
    scorer.save(dump_predicted=False)


# computing output metrics (the one-line-in-a-dataset version)
def compute_metrics_df(prediction_type, target_map, modeling_params, output_df, metrics, custom_metrics,
                       y, input_df, output_probabilities, sample_weight=None, treat_metrics_failure_as_error=True):
    """
    Compute the evaluation metrics and return them as a one-line dataframe.

    :param prediction_type: binary classification, multiclass or regression
    :param dict target_map: class label -> class index mapping, used on classification tasks
    :param dict modeling_params: modeling settings forwarded to the metrics computation
    :param output_df: scored data; rows with a null "prediction" are ignored
    :param metrics: names of the computed metrics to output (one column each)
    :param custom_metrics: names of the custom metrics to output (as "custom_<name>" columns)
    :param y: the target values
    :param input_df: the unprocessed input data, for the custom scoring
    :param output_probabilities: whether the "proba_<class>" columns must be read from `output_df`
    :param sample_weight: optional sample weights
    :param boolean treat_metrics_failure_as_error: forwarded to the metrics computation
    :raises ValueError: when `prediction_type` is not supported
    :return: a dataframe of a line of metrics, starting with a 'date' column
    """
    scored_rows = output_df[pd.notnull(output_df["prediction"])]
    preds = scored_rows["prediction"]
    if prediction_type in [doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS]:
        # Map the predicted labels to their class index
        preds.replace(target_map, inplace=True)
        probas = None
        if output_probabilities:
            # Proba columns ordered by class index
            ordered_labels = sorted(target_map.keys(), key=lambda label: target_map[label])
            probas = scored_rows[["proba_%s" % label for label in ordered_labels]].values
    logger.info("Computing metrics")

    # compute metrics
    if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
        computed_metrics = compute_binary_classification_metrics(
            modeling_params, y, preds, probas, sample_weight, input_df,
            treat_failure_as_error=treat_metrics_failure_as_error)
    elif prediction_type == doctor_constants.MULTICLASS:
        multiclass_results = MulticlassModelScorer.compute_multiclass_metrics(
            y.astype(int), preds, target_map, probas, sample_weight, input_df, modeling_params["metrics"],
            treat_failure_as_error=treat_metrics_failure_as_error)
        computed_metrics = multiclass_results["metrics"]
    elif prediction_type == doctor_constants.REGRESSION:
        computed_metrics = compute_regression_metrics(
            modeling_params, y, preds, sample_weight, input_df,
            treat_failure_as_error=treat_metrics_failure_as_error)
    else:
        raise ValueError("Evaluation not supported for %s" % prediction_type)
    # TODO @deepHub evaluation metrics for new prediction types
    logger.info("Metrics computed : ")
    logger.info(computed_metrics)

    dt_now = utils.get_datetime_now_utc()

    # Optional frame holding the requested custom metrics (None when the computation did not succeed)
    custom_frame = None
    if "customMetricsResults" in computed_metrics:
        requested_results = [result for result in computed_metrics["customMetricsResults"]
                             if result["metric"]["name"] in custom_metrics]
        custom_frame = pd.DataFrame.from_dict({
            "custom_" + result["metric"]["name"]: [result["value"] if result["didSucceed"] else None]
            for result in requested_results
        })

    frames = [
        pd.DataFrame.from_dict({'date': [dt_now]}),
        pd.DataFrame.from_dict({name: [computed_metrics.get(name, None)] for name in metrics}),
    ]
    if custom_frame is not None:
        frames.append(custom_frame)
    return pd.concat(frames, axis=1)

def compute_custom_evaluation_metrics_df(
        output_df, # normalised input + preds + probas
        custom_evaluation_metrics,
        prediction_type,
        target_map,
        y_valid,
        unprocessed_input_df, # true input (Eval dataset), no normalisation
        ref_sample_df,
        output_probabilities,
        sample_weight=None,
        treat_metrics_failure_as_error=True):
    """
    Compute the custom evaluation metrics on copies of the output and input dataframes.

    On classification tasks, the predictions are mapped through `target_map` and, when `output_probabilities` is
    set, the "proba_<class>" columns are read in class index order.

    :return: the result of `calculate_custom_evaluation_metrics`
    """
    is_classification = prediction_type in doctor_constants.CLASSIFICATION_TYPES

    # Work on copies so that the caller's dataframes are not modified
    output_copy = output_df.copy()
    input_copy = unprocessed_input_df.copy()

    probas = None
    preds = output_copy["prediction"]
    if is_classification:
        # NOTE(review): whether this in-place replace also updates output_copy may depend on the pandas
        # version / copy-on-write mode — confirm which one is expected by the custom metrics
        preds.replace(target_map, inplace=True)
        if output_probabilities:
            ordered_labels = sorted(target_map.keys(), key=lambda label: target_map[label])
            proba_columns = ["proba_%s" % label for label in ordered_labels]
            probas = output_copy[proba_columns].values
    logger.info("Computing metrics")

    custom_results = calculate_custom_evaluation_metrics(
        custom_evaluation_metrics,
        y_valid,
        preds,
        probas,
        input_copy,
        output_copy,
        ref_sample_df,
        sample_weight,
        is_classification,
        treat_metrics_failure_as_error
    )
    return custom_results

# Note: some of this could (should ?) be factored with the classical model scoring
def compute_regression_metrics(modeling_params, valid_y, preds, sample_weight=None, unprocessed=None, treat_failure_as_error=True):
    """
    Compute the performance metrics of a regression model.

    :param dict modeling_params: modeling parameters; only its "metrics" entry is read here
    :param valid_y: true target values
    :param preds: predicted values
    :param sample_weight: optional sample weights forwarded to every metric
    :param unprocessed: forwarded to the custom metrics computation
    :param boolean treat_failure_as_error: forwarded to ``handle_failure`` for every built-in metric
    :return dict: regression metrics keyed by metric name (plus custom metrics entries when configured)
    """

    def guarded(score_fn):
        # All built-in metrics share the same call shape and the same failure policy
        return handle_failure(lambda: score_fn(valid_y, preds, sample_weight=sample_weight), treat_failure_as_error)

    metrics = {}
    metrics["evs"] = guarded(explained_variance_score)
    metrics["mape"] = guarded(mean_absolute_percentage_error)
    metrics["mae"] = guarded(mean_absolute_error)
    metrics["mse"] = guarded(mean_squared_error)

    # RMSE is derived from MSE, which may be None
    mse = metrics["mse"]
    metrics["rmse"] = None if mse is None else sqrt(mse)

    metrics["rmsle"] = guarded(rmsle_score)
    metrics["r2"] = guarded(r2_score)
    metrics["pearson"] = guarded(pearson_correlation)

    metrics_params = modeling_params["metrics"]
    if "customMetrics" in metrics_params:
        custom_results = calculate_regression_custom_metrics(metrics_params, unprocessed, valid_y, preds, sample_weight)
        metrics["customMetricsResults"] = custom_results
        if metrics_params["evaluationMetric"] == "CUSTOM":
            metrics["customScore"] = get_custom_score_from_custom_metrics_results(
                custom_results,
                metrics_params["customEvaluationMetricName"]
            )

    return metrics


def compute_binary_classification_metrics(modeling_params, valid_y, preds, probas=None, sample_weight=None, unprocessed=None, treat_failure_as_error=True):
    """
    Compute the performance metrics of a binary classifier.

    Custom metrics (when configured) come first, then the metrics based on predicted classes, and
    finally the metrics based on probabilities when ``probas`` is provided.

    :param dict modeling_params: modeling parameters; only its "metrics" entry is read here
    :param valid_y: true target values
    :param preds: predicted classes
    :param probas: predicted probabilities, or None to skip the probability-based metrics
    :param sample_weight: optional sample weights forwarded to every metric
    :param unprocessed: forwarded to the custom metrics computation
    :param boolean treat_failure_as_error: if true, raise error on metrics computation failures instead of logging a warning & returning None.
    :return dict: binary performance metrics in the form {metric_name: metric_value | None}
    """
    metrics_params = modeling_params["metrics"]

    def guarded(compute):
        return handle_failure(compute, treat_failure_as_error)

    def from_preds(score_fn):
        return guarded(lambda: score_fn(valid_y, preds, sample_weight=sample_weight))

    def from_probas(score_fn):
        return guarded(lambda: score_fn(valid_y, probas, sample_weight=sample_weight))

    metrics = {}

    if "customMetrics" in metrics_params:
        custom_results = calculate_overall_classification_custom_metrics(metrics_params,
                                                                         preds,
                                                                         probas,
                                                                         sample_weight,
                                                                         unprocessed,
                                                                         valid_y)
        metrics["customMetricsResults"] = custom_results
        if metrics_params["evaluationMetric"] == "CUSTOM":
            metrics["customScore"] = get_custom_score_from_custom_metrics_results(
                custom_results,
                metrics_params["customEvaluationMetricName"]
            )

    metrics["precision"] = from_preds(precision_score)
    metrics["recall"] = from_preds(recall_score)
    metrics["f1"] = from_preds(f1_score)
    metrics["accuracy"] = from_preds(accuracy_score)
    metrics["mcc"] = from_preds(matthews_corrcoef)
    metrics["hammingLoss"] = from_preds(hamming_loss)
    # The scorer is built inside the guarded call so that a failure there follows the same policy
    metrics["costMatrixGain"] = guarded(
        lambda: make_cost_matrix_score(metrics_params)(valid_y, preds, sample_weight=sample_weight) / valid_y.shape[0]
    )

    if probas is None:
        return metrics

    metrics["auc"] = from_probas(mroc_auc_score)
    metrics["logLoss"] = from_probas(log_loss)
    metrics["lift"] = guarded(lambda: make_lift_score(metrics_params)(valid_y, probas, sample_weight=sample_weight))
    metrics["calibrationLoss"] = from_probas(calibration_loss_binary)
    metrics["averagePrecision"] = from_probas(m_average_precision_score)
    return metrics


# extra columns added based on prediction+labels
def add_evaluation_columns(prediction_type, pred_df, y, outputs, target_mapping):
    """
    Add to ``pred_df`` the evaluation columns listed in ``outputs`` and return it.

    Regression: "error", "error_decile", "abs_error_decile" and "relative_error" (infinite relative
    errors are replaced by NaN). Classification: "prediction_correct".
    Other prediction types leave ``pred_df`` untouched.

    :param prediction_type: prediction type of the model
    :param pred_df: dataframe holding a "prediction" column; modified in place
    :param y: true target values
    :param outputs: names of the requested evaluation columns
    :param target_mapping: mapping applied to the predictions before comparing them to ``y``
    :return: ``pred_df``
    """

    def log_shapes():
        logger.info("PRED_DF = %s" % (pred_df.shape,))
        logger.info("Y = %s" % (y.shape,))

    def residuals():
        # Evaluated lazily: "prediction" is only read when an output actually needs it
        return pred_df["prediction"] - y

    def decile_of(values):
        return pd.cut(values, 10, labels=xrange(0, 10), retbins=True)[0]

    if prediction_type == doctor_constants.REGRESSION:
        log_shapes()

        if "error" in outputs:
            pred_df["error"] = residuals()
        if "error_decile" in outputs:
            pred_df["error_decile"] = decile_of(residuals())
        if "abs_error_decile" in outputs:
            pred_df["abs_error_decile"] = decile_of(residuals().abs())
        if "relative_error" in outputs:
            pred_df["relative_error"] = residuals() / y
            infinities_to_nan = {np.inf: np.nan, -np.inf: np.nan}
            pred_df["relative_error"] = pred_df["relative_error"].replace(infinities_to_nan)
    elif prediction_type in (doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS):
        log_shapes()
        if "prediction_correct" in outputs:
            mapped_predictions = pred_df["prediction"].map(target_mapping)
            pred_df["prediction_correct"] = mapped_predictions == y
    # TODO @deepHub evaluation columns for new prediction types
    return pred_df


# NOTE: "Scoring" in the methods used here should not be confused with the Scoring recipe. Here it means giving scores to a model
# based on its predictions (not calling .predict()).
def process_input_df_skip_predict(input_df, model_folder_context, pipeline, modeling_params, target_map, prediction_type, prediction_column, proba_columns,
                                  with_sample_weight, dont_compute_performance, recipe_desc, cond_outputs, model_evaluation_store_folder_context):
    """
    Evaluate predictions that are read from the data itself instead of being computed by a model.

    The input is run through ``pipeline``; the prediction and probability columns are then taken from
    the pipeline's 'UNPROCESSED' output. Depending on ``prediction_type``, and only when
    ``model_evaluation_store_folder_context`` is not None, either prediction statistics are saved
    (``dont_compute_performance`` true) or the model scoring is run (``dont_compute_performance`` false).

    :param input_df: dataframe given to ``pipeline.process``
    :param model_folder_context: model folder, forwarded to ``handle_percentiles_and_cond_outputs`` (binary only)
    :param pipeline: preprocessing pipeline; its output must expose 'UNPROCESSED' (and 'target', 'weight', 'TRAIN' when used)
    :param modeling_params: forwarded to the scoring functions
    :param target_map: class label -> index mapping (classification)
    :param prediction_type: prediction type of the model
    :param prediction_column: name of the column holding the predictions
    :param proba_columns: list of names of the columns holding the probabilities (may be empty)
    :param with_sample_weight: whether to read sample weights from the pipeline's 'weight' output
    :param dont_compute_performance: when true, no target is read and no scoring is run
    :param recipe_desc: recipe description; "outputs" and "treatPerfMetricsFailureAsError" are read here
    :param cond_outputs: conditional outputs definitions (binary only)
    :param model_evaluation_store_folder_context: evaluation store folder, or None to skip statistics and scoring
    :return: tuple (pred_df, y, unprocessed, sample_weight, assertions); y and sample_weight are None when
             performance is not computed, assertions is None when the pipeline output has none
    :raises Exception: binary classification without exactly 2 probability columns, or multiclass without any
    """
    # We need the pipeline to drop rows, compute assertions, etc.
    transformed = pipeline.process(input_df)
    unprocessed = transformed['UNPROCESSED']

    # Isolate prediction columns after the drop rows
    pred_df = unprocessed[[prediction_column] + proba_columns]

    # Get target column
    y = None
    sample_weight = None
    if not dont_compute_performance:
        y = transformed['target']
        sample_weight = transformed['weight'] if with_sample_weight else None
    # Add evaluation columns and format prediction
    if y is not None:
        pred_df = add_evaluation_columns(prediction_type, pred_df, y, recipe_desc["outputs"], target_map)

    # ------ Scoring --------
    # First column of pred_df is the prediction column (see the column selection above)
    preds_df = pred_df.iloc[:, 0]
    probas_values = pred_df[proba_columns].values if proba_columns else []
    # Metrics failures are treated as errors unless the recipe says otherwise
    treat_perf_metrics_failure_as_error = recipe_desc.get("treatPerfMetricsFailureAsError", True)

    if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
        if len(proba_columns) != 2:
            raise Exception(NO_PROBA_ERR_MESSAGE)

        # Without performance computation, only the statistics of the predictions are saved
        if model_evaluation_store_folder_context is not None and dont_compute_performance:
            save_classification_statistics(preds_df,
                                           model_evaluation_store_folder_context,
                                           probas=probas_values,
                                           sample_weight=None,
                                           target_map=target_map)

        # Probability percentile & Conditional outputs
        # (this call may add columns to pred_df in place)
        handle_percentiles_and_cond_outputs(pred_df, recipe_desc, cond_outputs, model_folder_context, target_map)

        if model_evaluation_store_folder_context is not None and not dont_compute_performance:
            decisions_and_cuts = DecisionsAndCuts.from_probas(probas_values, target_map)
            run_binary_scoring(modeling_params,
                               decisions_and_cuts,
                               transformed["target"].astype(int),
                               target_map,
                               transformed["weight"] if with_sample_weight else None,
                               model_evaluation_store_folder_context,
                               assertions=transformed.get("assertions", None),
                               test_unprocessed=transformed["UNPROCESSED"],
                               test_X=transformed["TRAIN"],
                               treat_metrics_failure_as_error=treat_perf_metrics_failure_as_error)

    # NOTE: this is a plain `if` (not `elif`); the regression branch below is chained to it
    if prediction_type == doctor_constants.MULTICLASS:
        if len(proba_columns) == 0:
            raise Exception(NO_PROBA_ERR_MESSAGE)

        prediction_result = ClassificationPredictionResult(target_map,
                                                           probas_values,
                                                           preds_df.values,
                                                           pred_df[prediction_column].map(target_map).values)

        # Without performance computation, only the statistics of the predictions are saved
        if model_evaluation_store_folder_context is not None and dont_compute_performance:
            save_classification_statistics(preds_df,
                                           model_evaluation_store_folder_context,
                                           probas=probas_values,
                                           sample_weight=None,
                                           target_map=target_map)

        if model_evaluation_store_folder_context is not None and not dont_compute_performance:
            run_multiclass_scoring(modeling_params, prediction_result,
                                   transformed["target"].astype(int),
                                   target_map,
                                   transformed["weight"] if with_sample_weight else None,
                                   model_evaluation_store_folder_context,
                                   assertions=transformed.get("assertions", None),
                                   test_unprocessed=transformed["UNPROCESSED"],
                                   test_X=transformed["TRAIN"],
                                   treat_metrics_failure_as_error=treat_perf_metrics_failure_as_error)

    elif prediction_type == doctor_constants.REGRESSION:
        prediction_result = PredictionResult(pred_df[prediction_column])

        # Without performance computation, only the statistics of the predictions are saved
        if model_evaluation_store_folder_context is not None and dont_compute_performance:
            save_regression_statistics(pred_df.iloc[:, 0], model_evaluation_store_folder_context)

        if model_evaluation_store_folder_context is not None and not dont_compute_performance:
            run_regression_scoring(modeling_params, prediction_result, transformed["target"],
                                   transformed["weight"] if with_sample_weight else None,
                                   model_evaluation_store_folder_context,
                                   assertions=transformed.get("assertions", None),
                                   test_unprocessed=transformed["UNPROCESSED"],
                                   test_X=transformed["TRAIN"],
                                   treat_metrics_failure_as_error=treat_perf_metrics_failure_as_error)

    return pred_df, y, unprocessed, sample_weight, transformed.get("assertions", None)


def make_drift_mli_from_file(data_filename, mli_schema, preprocessing_params, prediction_type, prediction_column_name):
    """
    Build a ``ModelLikeInfo`` whose sample is read from a gzipped data file.

    When a prediction column name is given, the type of the matching schema column(s) is overridden
    in ``mli_schema`` (in place) before the file is parsed: 'string' for classification, 'float' for regression.

    :param data_filename: path of the gzipped file to read (opened in text mode)
    :param dict mli_schema: schema with a 'columns' list of dicts having 'name' and 'type'
    :param preprocessing_params: stored on the returned object
    :param prediction_type: prediction type of the model
    :param prediction_column_name: name of the prediction column, or None / empty to leave the schema untouched
    :return: the populated ``ModelLikeInfo``
    :raises EmptyDatasetException: if the parsed sample is empty
    """
    mli = ModelLikeInfo()
    mli.preprocessing_params = preprocessing_params

    has_prediction_column = prediction_column_name is not None and len(prediction_column_name) > 0
    if has_prediction_column:
        prediction_columns = (col for col in mli_schema['columns'] if col['name'] == prediction_column_name)
        for col in prediction_columns:
            if prediction_type in {doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS}:
                logger.info("prediction column (%s) in reference cast as string" % prediction_column_name)
                # Cast prediction type as string so that make_df does not transform int to float for classification sc-124329
                col['type'] = 'string'
            elif prediction_type == doctor_constants.REGRESSION:
                logger.info("prediction column (%s) in reference cast as float" % prediction_column_name)
                col['type'] = 'float'

    with gzip.open(data_filename, "rt") as data_file:
        sample_df = DriftProtocol.make_df(data_file, mli_schema)
        mli.sample_df = sample_df
        if sample_df.empty:
            raise EmptyDatasetException(
                "The dataset can not be empty. Check the input dataset or the recipe sampling configuration.")
    return mli


def make_drift_mli_from_df(df, preprocessing_params):
    """
    Build a ``ModelLikeInfo`` around an in-memory dataframe.

    :param df: sample dataframe, stored as-is (not copied)
    :param preprocessing_params: stored on the returned object
    :return: the populated ``ModelLikeInfo``
    :raises EmptyDatasetException: if ``df`` is empty
    """
    if not df.empty:
        mli = ModelLikeInfo()
        mli.sample_df = df
        mli.preprocessing_params = preprocessing_params
        return mli

    raise EmptyDatasetException(
        "The dataset can not be empty. Check the input dataset or the recipe sampling configuration.")


# Lightweight wrapper on top of drift internals to simplify drift computation within Keras/Regular evaluation recipes
def compute_input_data_drift(reference, current, data_drift_column_handling, treat_drift_failure_as_error, confidence_level, should_text_drift, text_drift_params):
    """
    Compute the drift between a reference sample and a current sample.

    :param reference: object exposing the reference ``sample_df``
    :param current: object exposing the current ``sample_df``
    :param data_drift_column_handling: per-column handling, forwarded to ``DataDriftParams``
    :param treat_drift_failure_as_error: when true, a column setting carrying an 'errorMessage' raises
    :param confidence_level: forwarded to ``DataDriftParams``
    :param should_text_drift: forwarded to ``DriftPreparator``
    :param text_drift_params: forwarded to ``DriftPreparator``
    :return: tuple (drift model result, univariate drift result, text drift result, per column settings,
             reference sample size, current sample size); the three results are None when not computed
    :raises DriftException: see ``treat_drift_failure_as_error``
    """
    # Exact values are not really relevant here since we only need drift model accuracy (data drift accuracy):
    # - Confidence level is only used for interpreting the score
    # - Column importances are ignored
    params = DataDriftParams(data_drift_column_handling, 20, True, confidence_level)
    prepared = DriftPreparator(reference, current, params, can_text_drift=True, should_text_drift=should_text_drift,
                               text_drift_params=text_drift_params).prepare()
    ref_tabular_df, cur_tabular_df, ref_embeddings, cur_embeddings, per_column_settings, can_compute_drift = prepared

    def as_result(model_result, univariate_result, text_result):
        return model_result, univariate_result, text_result, per_column_settings, len(reference.sample_df), len(current.sample_df)

    if not can_compute_drift:
        return as_result(None, None, None)

    for settings in per_column_settings:
        if treat_drift_failure_as_error and 'errorMessage' in settings:
            raise DriftException(settings['errorMessage'])

    model_result = None
    univariate_result = None
    if cur_tabular_df.empty or ref_tabular_df.empty:
        logger.info("No tabular columns : global and univariate drift not computed")
    else:
        logger.info("Computing drift of the following tabular columns : %s" % cur_tabular_df.columns)
        model_result = DriftModel(ref_tabular_df, cur_tabular_df, None, params.confidence_level,
                                  treat_drift_failure_as_error).compute_drift()
        univariate = DriftUnivariate(
            ref_tabular_df, cur_tabular_df,
            params.nb_bins,
            params.compute_histograms,
            None,
            treat_drift_failure_as_error
        )
        univariate_result = univariate.compute_drift()

    text_result = None
    if cur_embeddings and ref_embeddings:
        from dataiku.modelevaluation.drift.drift_embedding import DriftEmbedding
        logger.info("Computing drift on embeddings of the following columns : %s" % set(cur_embeddings.keys()))
        text_result = DriftEmbedding(ref_embeddings, cur_embeddings).compute_drift()

    return as_result(model_result, univariate_result, text_result)


def compute_drift(sample_input_df, sample_output_df, preprocessing_params, evaluation_store_folder_context, treat_drift_failure_as_error, prediction_type,
                  data_drift_column_handling, confidence_level, prediction_column_name=None, should_text_drift=False,
                  text_drift_params=None, reference_threshold=None, current_threshold=None):
    """
    Compute input data drift and prediction drift against the stored reference sample, and write the
    result to ``data_metrics.json`` in the evaluation store folder.

    Any exception raised along the way is routed to ``handle_drift_failure``, which either raises a
    ``DriftException`` or logs a warning, depending on ``treat_drift_failure_as_error``.

    :param sample_input_df: current input sample
    :param sample_output_df: current output sample, used for prediction drift
    :param preprocessing_params: preprocessing params of the model; drift is not computed when falsy
    :param evaluation_store_folder_context: folder holding the reference files and receiving ``data_metrics.json``
    :param treat_drift_failure_as_error: whether a failure raises or is only logged
    :param prediction_type: prediction type of the model
    :param data_drift_column_handling: forwarded to ``compute_input_data_drift``
    :param confidence_level: forwarded to ``compute_input_data_drift``
    :param prediction_column_name: name of the prediction column, if any
    :param should_text_drift: forwarded to ``compute_input_data_drift``
    :param text_drift_params: forwarded to ``compute_input_data_drift``
    :param reference_threshold: stored in the result and forwarded to ``compute_prediction_drift``
    :param current_threshold: stored in the result
    :return: tuple (univariate drift result's "columns" entry or None, reference sample dataframe or None)
    """
    reference = None
    drift_reference_csv_filename = "drift_reference.csv.gz"
    logger.info("Start computing drift metric...")

    try:
        if not evaluation_store_folder_context.isfile(drift_reference_csv_filename):
            raise DriftException("Drift metric won't be computed because reference data are not available")

        if not preprocessing_params:
            raise DriftException("No preprocessing params were found for the model. Cannot compute drift")

        # Load reference dataset sample
        try:
            drift_reference_schema = evaluation_store_folder_context.read_json("drift_reference.json")
            with evaluation_store_folder_context.get_file_path_to_read(drift_reference_csv_filename) as drift_ref_path:
                reference = make_drift_mli_from_file(drift_ref_path, drift_reference_schema, preprocessing_params, prediction_type, prediction_column_name)
        except EmptyDatasetException as e:
            # Exceptions have no `message` attribute on Python 3 unless the class defines one:
            # fall back to the exception itself so that the original message is never lost.
            raise EmptyDatasetException("%s (empty reference dataset)" % str(getattr(e, "message", e)))

        # Load current dataset sample
        try:
            current = make_drift_mli_from_df(sample_input_df, preprocessing_params)
        except EmptyDatasetException as e:
            raise EmptyDatasetException("%s (empty current dataset)" % str(getattr(e, "message", e)))

        # Compute data drift and prediction drift
        drift_model_result, univariate_drift_result, text_drift_result, per_column_settings, \
            ref_sample_size, cur_sample_size = compute_input_data_drift(reference, current, data_drift_column_handling, treat_drift_failure_as_error,
                                                                        confidence_level, should_text_drift, text_drift_params)

        if drift_model_result is not None:
            logger.info("Drift model accuracy (data drift) is %s" % drift_model_result.get("driftModelAccuracy", {"value": None})["value"])

        prediction_drift = compute_prediction_drift(sample_output_df, reference.sample_df, prediction_type, evaluation_store_folder_context,
                                                    treat_drift_failure_as_error, prediction_column_name, reference_threshold=reference_threshold)

        data_metrics = {"driftModelAccuracy": drift_model_result["driftModelAccuracy"] if drift_model_result else None,  # For backward compatibility
                        "driftResult": {
                            "univariateDriftResult": univariate_drift_result,
                            "driftModelResult": drift_model_result,
                            "perColumnSettings": per_column_settings,
                            "referenceSampleSize": ref_sample_size,
                            "currentSampleSize": cur_sample_size,
                            "referenceThreshold": reference_threshold,
                            "currentThreshold": current_threshold,
                            "predictionDriftResult": prediction_drift,
                            "textDriftResult": text_drift_result
                        },
                        }
        evaluation_store_folder_context.write_json('data_metrics.json', data_metrics)
        return univariate_drift_result["columns"] if univariate_drift_result else None, reference.sample_df

    except Exception as e:
        # Raises when failures are treated as errors; otherwise logs and falls through to the return below
        handle_drift_failure(str(e), treat_drift_failure_as_error)
        return None, reference.sample_df if reference else None


def compute_prediction_drift(sample_output_df, sample_reference_df, prediction_type, evaluation_store_folder_context, treat_drift_failure_as_error, prediction_column_name=None, reference_threshold=None):
    """
    Compute the drift between the reference predictions and the current predictions.

    When the reference prediction statistics file does not exist yet, it is created from the
    reference sample. The reference predictions then come from the reference sample when it holds
    the prediction column, and are otherwise rebuilt from the statistics file.

    :param sample_output_df: current sample holding the prediction column
    :param sample_reference_df: reference sample
    :param prediction_type: prediction type of the model
    :param evaluation_store_folder_context: folder holding the reference prediction statistics
    :param treat_drift_failure_as_error: forwarded to ``handle_drift_failure`` on failure
    :param prediction_column_name: name of the prediction column, defaults to "prediction"
    :param reference_threshold: used when rebuilding the reference predictions from the statistics
    :return: the result of ``PredictionDrift.compute_drift``, or None when there is no reference
             prediction to compare to or when a failure was only logged
    """
    column = "prediction" if prediction_column_name is None else prediction_column_name
    try:
        stats_filename = 'drift_reference_prediction_statistics.json'
        classification = prediction_type in (doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS)

        def prepared(series):
            # Classification predictions are compared as strings
            return cast_as_string(series) if classification else series

        if not evaluation_store_folder_context.isfile(stats_filename):
            if column not in sample_reference_df:
                logger.info("No {} column in reference. No prediction drift computation".format(column))
                return None
            if classification:
                save_classification_statistics(sample_reference_df[column], evaluation_store_folder_context, filename=stats_filename)
            elif prediction_type == doctor_constants.REGRESSION:
                numeric_reference = sample_reference_df[column].astype(float).dropna()
                save_regression_statistics(numeric_reference, evaluation_store_folder_context, filename=stats_filename)

        reference_has_column = sample_reference_df is not None and column in sample_reference_df
        if reference_has_column:
            ref_prediction = prepared(sample_reference_df[column])
        else:
            ref_statistics = evaluation_store_folder_context.read_json(stats_filename)
            ref_prediction = create_prediction_series_from_statistics(ref_statistics, prediction_type, reference_threshold)

        cur_prediction = prepared(sample_output_df[column])
        drift = PredictionDrift(ref_prediction, cur_prediction, column)
        return drift.compute_drift()

    except Exception as e:
        handle_drift_failure("Failure during prediction drift computation : {}".format(e), treat_drift_failure_as_error)


def sample_and_store_dataframe(folder_context, input_df, schema, filename=SAMPLE_CSV, schema_filename=SAMPLE_SCHEMA_JSON, output_df=None, limit_sampling=True):
    """
    Write a (possibly down-sampled) gzipped CSV of ``input_df`` and its schema into ``folder_context``.

    With ``limit_sampling``, the sample is first capped to ``MAX_SAMPLING_ROWS`` rows, then shrunk
    (up to 3 write attempts) until the written file fits in ``MAX_SAMPLING_BYTES``. ``output_df``,
    when given, is sampled with the same row count and random state as ``input_df``.

    :param folder_context: folder receiving the sample file and the schema file
    :param input_df: dataframe to sample and store
    :param schema: schema written as JSON to ``schema_filename``
    :param filename: name of the gzipped CSV file
    :param schema_filename: name of the schema JSON file
    :param output_df: optional dataframe sampled alongside ``input_df`` (it is not written to disk here)
    :param limit_sampling: whether to enforce the row / byte limits
    :return: ``(sample_input_df, sample_output_df)`` when ``output_df`` is given, else ``sample_input_df``
    :raises EmptyDatasetException: if the resulting input sample is empty
    """
    sample_output_df = None

    with folder_context.get_file_path_to_write(filename) as sample_file:
        if limit_sampling:
            logger.info("Limit sampling is enabled. Will make the sample fit the {} max sampling rows / {} max sampling bytes".format(str(MAX_SAMPLING_ROWS), str(MAX_SAMPLING_BYTES)))
            if len(input_df.index) > MAX_SAMPLING_ROWS:
                sample_input_df = input_df.sample(n=MAX_SAMPLING_ROWS, random_state=1337)
                if output_df is not None:
                    sample_output_df = output_df.sample(n=MAX_SAMPLING_ROWS, random_state=1337)
                actual_rows = MAX_SAMPLING_ROWS
            else:
                sample_input_df = input_df.copy()
                if output_df is not None:
                    sample_output_df = output_df.copy()
                actual_rows = len(input_df.index)

            # We make up to 3 attempts at writing a sample file that fits in MAX_SAMPLING_BYTES. If it is
            # still too big after those attempts, we keep it and only log a warning below.
            # NOTE(review): when the third attempt is still too big, the dataframes are re-sampled once
            # more without the file being rewritten, so they can hold fewer rows than the file.
            for i in range(3):
                dataframe_to_csv(sample_input_df, sample_file, gzip.open)
                sample_size = os.stat(sample_file).st_size
                if sample_size > MAX_SAMPLING_BYTES:
                    # Shrink proportionally to the overshoot, with a 2% safety margin
                    actual_rows = int(0.98 * actual_rows * (MAX_SAMPLING_BYTES / float(sample_size)))
                    sample_input_df = input_df.sample(n=actual_rows, random_state=1337)
                    # `is not None`: the truth value of a DataFrame is ambiguous and raises a ValueError
                    if output_df is not None:
                        sample_output_df = output_df.sample(n=actual_rows, random_state=1337)
                else:
                    break

            sample_size = os.stat(sample_file).st_size
            if sample_size > MAX_SAMPLING_BYTES:
                logger.warning("Warning, the sample size {}MB exceeds the soft limit of {}MB".format(round(sample_size * (2 ** -20), 2),
                                round(MAX_SAMPLING_BYTES * (2 ** -20), 2)))

            if len(input_df.index) != actual_rows:
                logger.warning("The sample of the evaluation dataset is too big. It has been truncated to " + str(
                    actual_rows) + " rows for statistic computation.")

        else:
            logger.info("The 'Limit sampling' option is disabled. We will save the entire sample and use it for drift "
                         "computations.")
            sample_input_df = input_df.copy()
            dataframe_to_csv(sample_input_df, sample_file, gzip.open)
            if output_df is not None:
                sample_output_df = output_df.copy()

        if sample_input_df.empty:
            raise EmptyDatasetException(
                "The sample of the evaluation dataset can not be empty. Check the input dataset or the recipe sampling "
                "configuration.")

    folder_context.write_json(schema_filename, schema)

    if sample_output_df is not None:
        return sample_input_df, sample_output_df

    else:
        return sample_input_df


def format_data_drift_column_handling(data_drift_column_handling, evaluation_dataset_type='CLASSIC'):
    """
    Strip the API node logs prefix from the column names of a drift column handling mapping.

    The mapping is modified in place and returned. Columns are only renamed when
    ``evaluation_dataset_type`` is one of the (cloud) API node evaluation dataset types.

    :param dict data_drift_column_handling: column name -> handling
    :param evaluation_dataset_type: type of the evaluation dataset
    :return dict: ``data_drift_column_handling``
    """

    def strip_prefix(prefix):
        # Iterate over a snapshot of the keys: renaming entries while iterating the dict itself is
        # unsafe (renamed keys get visited again, or a RuntimeError is raised).
        for column in list(data_drift_column_handling):
            if column.startswith(prefix):
                # Slice instead of split(): only the leading prefix is removed, even if the
                # prefix appears again later in the column name.
                data_drift_column_handling[column[len(prefix):]] = data_drift_column_handling.pop(column)

    if evaluation_dataset_type == API_NODE_EVALUATION_DATASET_TYPE:
        strip_prefix(API_NODE_LOGS_FEATURE_PREFIX)

    if evaluation_dataset_type == CLOUD_API_NODE_EVALUATION_DATASET_TYPE:
        strip_prefix(CLOUD_API_NODE_LOGS_FEATURE_PREFIX)

    return data_drift_column_handling


def handle_percentiles_and_cond_outputs(pred_df, recipe_desc, cond_outputs, model_folder_context, target_map):
    """
    Add the probability percentile and the conditional output columns to ``pred_df`` (in place).

    The "proba_percentile" column is computed when the recipe outputs it or when a conditional
    output uses it as input; in the latter case it is dropped again once the conditional outputs
    are computed.

    For each conditional output, rules are applied in order: a row gets the output of the first rule
    it matches, rows with a non-null input matching no rule get the default output, and rows with a
    null input are left unset.

    :param pred_df: predictions dataframe, modified in place
    :param dict recipe_desc: recipe description; "outputProbabilities" and "outputProbaPercentiles" are read
    :param cond_outputs: list of conditional outputs (dicts with "input", "name", "rules" and optionally "defaultOutput")
    :param model_folder_context: model folder; "perf.json" is read from it when percentiles are needed
    :param dict target_map: class label -> index mapping; the class mapped to 1 gives the probability column
    :raises Exception: if percentiles are needed but missing from the model
    :raises ValueError: if a rule uses an operation other than GT, GE, LT or LE
    """
    has_cond_output = recipe_desc["outputProbabilities"] and cond_outputs
    has_percentiles = recipe_desc["outputProbaPercentiles"] or (
        has_cond_output and any(co["input"] == "proba_percentile" for co in cond_outputs))

    if has_percentiles:
        model_perf = model_folder_context.read_json("perf.json")
        if "probaPercentiles" in model_perf and model_perf["probaPercentiles"]:
            percentile = pd.Series(model_perf["probaPercentiles"])
            proba_1 = u"proba_{}".format(
                safe_unicode_str(next(k for k, v in target_map.items()
                                      if v == 1)))
            # Percentile rank = number of percentile bounds lower than or equal to the probability, plus one
            pred_df["proba_percentile"] = pred_df[proba_1].apply(
                lambda p: percentile.where(percentile <= p).count() + 1)
        else:
            raise Exception("Probability percentiles are missing from model.")

    if has_cond_output:
        for co in cond_outputs:
            inp = pred_df[co["input"]]
            acc = inp.notnull()  # condition accumulator: rows not matched by any rule yet
            for r in co["rules"]:
                operation = r["operation"]
                if operation == 'GT':
                    cond = inp > r["operand"]
                elif operation == 'GE':
                    cond = inp >= r["operand"]
                elif operation == 'LT':
                    cond = inp < r["operand"]
                elif operation == 'LE':
                    cond = inp <= r["operand"]
                else:
                    # Fail explicitly rather than hitting an unbound `cond` (first rule) or silently
                    # reusing the mask of the previous rule (following rules).
                    raise ValueError("Unsupported conditional output operation: %s" % operation)
                pred_df.loc[acc & cond, co["name"]] = r["output"]
                acc = acc & (~cond)
            pred_df.loc[acc, co["name"]] = co.get("defaultOutput", "")
    if has_percentiles and not recipe_desc["outputProbaPercentiles"]:  # was only for conditional outputs
        pred_df.drop("proba_percentile", axis=1, inplace=True)


def recreate_preds_from_probas(probas, target_column_name, target, target_mapping, prediction_type, modeling_params):
    """
    Rebuild the predicted classes from the probabilities, and return them with the threshold used.

    For binary classification the threshold is either optimized on the target (when
    ``modeling_params['autoOptimizeThreshold']`` is set) or taken from
    ``modeling_params['forcedClassifierThreshold']``. For other prediction types it stays None.

    :param probas: probabilities, one row per sample
    :param target_column_name: name of the target column (only used in the error message)
    :param target: true target values, used to optimize the threshold
    :param dict target_mapping: its keys give the classes, in order
    :param prediction_type: prediction type of the model
    :param dict modeling_params: modeling parameters
    :return: tuple (predictions as a pandas Series, threshold or None)
    :raises Exception: if the probabilities contain NaN, or if no target value matches the classes
    """
    class_labels = list(target_mapping.keys())

    if np.isnan(probas).any():
        raise Exception("Cannot recreate prediction from probabilities : some nan values are found")

    threshold = None
    if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
        if not modeling_params['autoOptimizeThreshold']:
            logger.info("No threshold optimization, using the forced classifier threshold")
            threshold = modeling_params['forcedClassifierThreshold']
        else:  # autoOptimizeThreshold = True => dont_compute_performance = False
            logger.info("No forced threshold for this Standalone Evaluation. Computing threshold for metric: %s" % modeling_params['metrics']['thresholdOptimizationMetric'])
            # The threshold calculation method can only use numeric values.
            numeric_target = target.map({class_labels[0]: 0, class_labels[1]: 1})
            known_class_mask = numeric_target.notna()
            if not known_class_mask.any():
                raise Exception("No values in your target column %s %s matches the classes defined in your probability mapping %s" % (target_column_name, str(target.unique()), str(class_labels)))
            # We need to filter out nan to compute the optimized threshold
            usable_probas = probas[known_class_mask]
            numeric_target = numeric_target.dropna()
            decisions_and_cuts = DecisionsAndCuts.from_probas(usable_probas, target_mapping)
            threshold = compute_optimized_threshold(numeric_target, decisions_and_cuts, modeling_params['metrics'])
        logger.info("Threshold is set to: %s" % threshold)

    predictions = recreate_preds_from_probas_and_threshold(probas, target_mapping, prediction_type, threshold)
    return predictions, threshold


def recreate_preds_from_probas_and_threshold(probas, target_mapping, prediction_type, threshold):
    """
    Rebuild the predicted classes from the probabilities.

    Binary classification: the second class is predicted when its probability reaches ``threshold``,
    the first class otherwise. Other types: the class with the highest probability is predicted
    (None when its index has no matching class).

    :param probas: probabilities, one row per sample
    :param dict target_mapping: its keys give the classes, in order
    :param prediction_type: prediction type of the model
    :param threshold: decision threshold (binary classification only)
    :return: pandas Series of predicted classes
    """
    class_labels = list(target_mapping.keys())

    def label_for(row):
        if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
            return class_labels[1] if row[1] >= threshold else class_labels[0]
        best_index = np.argmax(row)
        if best_index >= len(class_labels):
            return None
        return class_labels[best_index]

    return pd.Series([label_for(row) for row in probas])


def handle_drift_failure(custom_err_message, treat_drift_failure_as_error):
    """
    Report a drift computation failure.

    :param custom_err_message: message describing the failure
    :param treat_drift_failure_as_error: if true, log the message and raise; otherwise only log a warning
    :return: None (when not raising)
    :raises DriftException: when ``treat_drift_failure_as_error`` is true
    """
    if not treat_drift_failure_as_error:
        logger.warning(custom_err_message)
        return None

    logger.exception(custom_err_message)
    raise DriftException(custom_err_message)


def is_binary_classification(prediction_type):
    """Tell whether ``prediction_type`` is the binary classification type."""
    is_binary = prediction_type == doctor_constants.BINARY_CLASSIFICATION
    return is_binary


def is_classification(prediction_type):
    """Tell whether ``prediction_type`` is one of the classification types."""
    classification_types = doctor_constants.CLASSIFICATION_TYPES
    return prediction_type in classification_types


def is_regression(prediction_type):
    """Tell whether ``prediction_type`` is the regression type."""
    is_reg = prediction_type == doctor_constants.REGRESSION
    return is_reg


def proba_definitions_from_probas(probas):
    """
    Keep only the well defined probability definitions.

    :param probas: iterable of dicts
    :return list: the dicts having both a truthy 'key' and a truthy 'value', in their original order
    """
    return [
        definition for definition in probas
        if 'key' in definition and 'value' in definition and definition['key'] and definition['value']
    ]
