# encoding: utf-8

"""
Execute an evaluation recipe in Keras mode
Must be called in a Flow environment
"""
import logging
import sys

import pandas as pd

import dataiku
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.base.utils import safe_unicode_str
from dataiku.core import debugging
from dataiku.core import dkujson
from dataiku.core import doctor_constants
from dataiku.core import schema_handling
from dataiku.doctor.deep_learning.keras_support import scored_dataset_generator
from dataiku.doctor.evaluation.base import EvaluateRecipe, load_input_dataframe, process_input_df_skip_predict, \
    compute_custom_evaluation_metrics_df
from dataiku.doctor.evaluation.base import add_evaluation_columns
from dataiku.doctor.evaluation.base import compute_metrics_df
from dataiku.doctor.evaluation.base import run_binary_scoring
from dataiku.doctor.evaluation.base import run_multiclass_scoring
from dataiku.doctor.evaluation.base import run_regression_scoring
from dataiku.doctor.evaluation.base import sample_and_store_dataframe
from dataiku.doctor.prediction.classification_scoring import save_classification_statistics
from dataiku.doctor.prediction.decisions_and_cuts import DecisionsAndCuts
from dataiku.doctor.prediction.regression_scoring import save_regression_statistics
from dataiku.doctor.preprocessing.assertions import MLAssertion
from dataiku.doctor.preprocessing_handler import PredictionPreprocessingHandler
from dataiku.doctor.utils import normalize_dataframe
from dataiku.doctor.utils.api_logs import API_NODE_EVALUATION_DATASET_TYPE, CLOUD_API_NODE_EVALUATION_DATASET_TYPE
from dataiku.doctor.utils.api_logs import CLASSICAL_EVALUATION_DATASET_TYPE
from dataiku.doctor.utils.api_logs import normalize_api_node_logs_dataset
from dataiku.doctor.utils.gpu_execution import log_nvidia_smi_if_use_gpu
from dataikuscoring.utils.prediction_result import ClassificationPredictionResult
from dataikuscoring.utils.prediction_result import PredictionResult

# Set up the debugging handler and the root logging configuration as soon as
# this module is imported, before any recipe code runs.
debugging.install_handler()
logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)

logger = logging.getLogger(name=__name__)


class KerasEvaluateRecipe(EvaluateRecipe):
    """Evaluation recipe for models trained in Keras mode.

    In scoring mode, the Keras model is applied through
    ``scored_dataset_generator``. For every batch it yields the input rows,
    the scored rows, the predictions and, when performance is computed, the
    target. The input, output and prediction dataframes are therefore built
    together by ``generate_input_and_output_df`` and reused by the
    ``EvaluateRecipe`` hooks overridden below.

    When ``recipe_desc["skipScoring"]`` is true, the generator is not used.
    The input dataset is read directly, and its existing prediction and
    probability columns are processed by ``process_input_df_skip_predict``.
    """

    def __init__(self, model_folder, input_dataset_smartname, output_dataset_smartname, metrics_dataset_smartname,
                 recipe_desc,
                 script, preparation_output_schema, cond_outputs=None, preprocessing_params=None,
                 evaluation_store_folder=None,
                 evaluation_dataset_type=None,
                 api_node_logs_config=None,
                 diagnostics_folder=None, fmi=None):
        # Keras evaluation receives no managed-folder id (legacy behaviour), hence
        # the None passed as third positional argument to the base constructor
        # (presumably its managed-folder slot - confirm against EvaluateRecipe).
        super(KerasEvaluateRecipe, self).__init__(model_folder, input_dataset_smartname, None, output_dataset_smartname,
                                                  metrics_dataset_smartname,
                                                  recipe_desc, script, preparation_output_schema, cond_outputs,
                                                  preprocessing_params,
                                                  evaluation_store_folder, evaluation_dataset_type,
                                                  api_node_logs_config, diagnostics_folder, fmi)
        # Filled by generate_input_and_output_df() in scoring mode.
        self.input_df = None
        self.y = None
        self.input_df_copy_unnormalized = None
        self.pred_df = None
        self.y_notnull = None
        self.output_df = None
        # Loaded from the model folder by _fetch_input_dataset_and_model_params().
        self.modeling_params = None
        self.collector_data = None
        # Class label -> integer class id; only filled for classification tasks.
        self.target_mapping = {}
        self.preprocessing_handler = None

    def _fetch_input_dataset_and_model_params(self):
        """Load the saved model parameters and resolve the input dataset schema.

        Reads ``core_params.json``, ``rmodeling_params.json`` and
        ``collector_data.json`` from the model folder and builds the
        preprocessing handler. For binary/multiclass tasks, derives
        ``self.target_mapping`` from the handler's target map. Finally computes
        the columns, dtypes and date columns used to read the input dataset
        from ``preparation_output_schema``.
        """
        self.core_params = self.model_folder_context.read_json("core_params.json")
        self.prediction_type = self.core_params["prediction_type"]
        self.modeling_params = self.model_folder_context.read_json("rmodeling_params.json")
        self.collector_data = self.model_folder_context.read_json("collector_data.json")

        self.preprocessing_handler = PredictionPreprocessingHandler.build(self.core_params, self.preprocessing_params,
                                                                          self.model_folder_context)
        self.preprocessing_handler.collector_data = self.collector_data

        if self.prediction_type in [doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS]:
            # Class ids are cast to int so they can be used as integer labels downstream.
            self.target_mapping = {
                label: int(class_id)
                for label, class_id in self.preprocessing_handler.target_map.items()
            }
        self.target_column_in_dataset = self.core_params.get("target_variable")
        self.model_target_column = self.core_params.get("target_variable")
        self.input_dataset = dataiku.Dataset(self.input_dataset_smartname)

        self.columns, self.dtypes, self.parse_date_columns = dataiku.Dataset.get_dataframe_schema_st(
            self.preparation_output_schema["columns"], parse_dates=True, infer_with_pandas=False)

        self.input_dataset.preparation_requested_output_schema = self.preparation_output_schema

    def _get_input_df(self):
        """Return the dataframe to evaluate.

        In skip-scoring mode, reads the input dataset with the configured
        sampling (full by default) and returns it. Otherwise runs the model over
        the input dataset, which also fills ``self.output_df``,
        ``self.pred_df``, ``self.y`` and ``self.y_notnull``, and returns the
        concatenated input rows.
        """
        if self.recipe_desc.get('skipScoring', False):
            input_df = load_input_dataframe(
                input_dataset=self.input_dataset,
                sampling=self.recipe_desc.get('selection', {"samplingMethod": "FULL"}),
                columns=self.columns,
                dtypes=self.dtypes,
                parse_date_columns=self.parse_date_columns,
            )
            return input_df

        else:
            # With Keras, the input and output datasets come together in the same method
            self.generate_input_and_output_df()

            return self.input_df

    def generate_input_and_output_df(self):
        """Score the input dataset with the Keras model and store the results.

        Iterates over ``scored_dataset_generator`` batch by batch and
        concatenates the per-batch frames into ``self.input_df``,
        ``self.output_df`` (scored rows) and ``self.pred_df``. When
        performance is computed, also fills ``self.y`` and ``self.y_notnull``.
        Otherwise both are set to None.

        :raises ValueError: if the generator yields no batch at all.
        """
        output_generator = scored_dataset_generator(self.model_folder_context, self.input_dataset, self.recipe_desc,
                                                    self.script,
                                                    self.preparation_output_schema, self.cond_outputs,
                                                    output_y=not self.dont_compute_performance,
                                                    output_input_df=True,
                                                    evaluation_dataset_type=self.evaluation_dataset_type,
                                                    filter_input_columns=False)  # Done later

        logger.info("Starting to iterate")
        y_list = []
        pred_df_list = []
        y_notnull_list = []
        output_list = []
        input_df_list = []
        for output_dict in output_generator:
            input_df_list.append(output_dict["input_df"])
            output_list.append(output_dict["scored"])
            pred_df_list.append(output_dict["pred_df"])
            if not self.dont_compute_performance:
                y_list.append(output_dict["y"])
                y_notnull_list.append(output_dict["y_notnull"])
            logger.info("Generator generated a df %s", str(output_dict["scored"].shape))

        if not output_list:
            # pd.concat([]) would otherwise fail with an opaque "No objects to concatenate".
            raise ValueError("Keras scoring produced no data: the scored dataset generator yielded no batch "
                             "(is the input dataset empty?)")

        if self.dont_compute_performance:
            self.y = None
            self.y_notnull = None
        else:
            self.y = pd.concat(y_list)
            self.y_notnull = pd.concat(y_notnull_list)

        self.input_df = pd.concat(input_df_list)
        self.output_df = pd.concat(output_list)
        self.pred_df = pd.concat(pred_df_list)

    def _get_sample_dfs(self, input_df):
        """Sample the evaluated rows and store the sample in the evaluation store.

        When performance is computed, rows with a missing target are dropped
        before sampling. In skip-scoring mode, returns a
        ``(sample_input_df, sample_output_df)`` pair built from the input rows.
        In scoring mode, returns the result of ``sample_and_store_dataframe``
        called with the matching, normalized output rows.

        :raises ValueError: in scoring mode, if ``evaluation_dataset_type`` is
            not supported for Keras models.
        """
        schema = {"columns": schema_handling.get_schema_from_df(input_df)}

        if self.recipe_desc.get('skipScoring', False):
            input_df_copy = input_df.copy()

            if not self.dont_compute_performance:
                input_df_copy = input_df_copy.dropna(subset=[self.model_target_column])
            sample_input_df = sample_and_store_dataframe(self.model_evaluation_store_folder_context, input_df_copy, schema, limit_sampling=self.recipe_desc.get('limitSampling', True))

            sample_pred_df = sample_input_df[[self.prediction_column] + self.proba_columns]

            # also remove ml assertions mask columns from the output
            clean_kept_columns = [c for c in sample_input_df.columns if c not in sample_pred_df.columns
                                  and not c.startswith(MLAssertion.ML_ASSERTION_MASK_PREFIX)]
            sample_output_df = pd.concat([sample_input_df[clean_kept_columns], sample_pred_df], axis=1)

            return sample_input_df, sample_output_df
        else:
            # In keras, the output dataset has already been generated. We need to sample the same rows as the input
            input_df_copy = input_df.copy()
            output_df_copy = self.output_df.copy()

            if not self.dont_compute_performance:
                input_df_copy = input_df_copy.dropna(subset=[self.model_target_column])
                # Keep only output rows whose index survived the target dropna above.
                output_df_copy = output_df_copy.loc[output_df_copy.index.isin(input_df_copy.index)]

            if self.evaluation_dataset_type in [API_NODE_EVALUATION_DATASET_TYPE, CLOUD_API_NODE_EVALUATION_DATASET_TYPE]:
                output_df_copy = normalize_api_node_logs_dataset(output_df_copy, self.feature_preproc, self.evaluation_dataset_type)
            elif self.evaluation_dataset_type == CLASSICAL_EVALUATION_DATASET_TYPE:
                # Result is not reassigned: normalize_dataframe presumably works in place - confirm.
                normalize_dataframe(output_df_copy, self.feature_preproc)
            else:
                raise ValueError("Evaluation dataset type %s is not handled for keras models." % self.evaluation_dataset_type)

            return sample_and_store_dataframe(self.model_evaluation_store_folder_context, input_df_copy, schema,
                                             output_df=output_df_copy, limit_sampling=self.recipe_desc.get('limitSampling', True))

    def _compute_output_and_pred_df(self, input_df, input_df_copy_unnormalized):
        """Build the output dataframe and the prediction dataframe.

        In skip-scoring mode, the input rows go through the preprocessing
        pipeline via ``process_input_df_skip_predict``, which also sets
        ``self.pred_df``, ``self.y_notnull`` and ``self.unprocessed``.

        In scoring mode, evaluation columns are added to ``self.output_df`` when
        the target is available. The output is then: kept input columns, then
        prediction columns, then the ``recipe_desc["outputs"]`` columns (only
        when the target is available).

        :param input_df: the evaluated input rows
        :param input_df_copy_unnormalized: input rows before normalization; in
            scoring mode its columns are kept when ``filterInputColumns`` is unset
        :returns: ``(output_df, pred_df)``
        """
        if self.recipe_desc.get('skipScoring', False):
            pipeline = self.preprocessing_handler.build_preprocessing_pipeline(
                with_target=not self.dont_compute_performance)

            self.pred_df, self.y_notnull, self.unprocessed, _, _ = process_input_df_skip_predict(input_df,
                                                                                                 self.model_folder_context,
                                                                                                 pipeline,
                                                                                                 self.modeling_params,
                                                                                                 self.target_mapping,
                                                                                                 self.prediction_type,
                                                                                                 self.prediction_column,
                                                                                                 self.proba_columns,
                                                                                                 None,
                                                                                                 self.dont_compute_performance,
                                                                                                 self.recipe_desc,
                                                                                                 None,
                                                                                                 self.model_evaluation_store_folder_context)
            return self._get_output_from_pred(input_df_copy_unnormalized, self.pred_df), self.pred_df

        else:
            if self.y is not None:
                self.output_df = add_evaluation_columns(self.prediction_type, self.output_df, self.y,
                                                        self.recipe_desc["outputs"], self.target_mapping)

            if self.recipe_desc.get("filterInputColumns", False):
                clean_kept_columns = [c for c in self.recipe_desc["keptInputColumns"] if c not in self.pred_df.columns]
            else:
                clean_kept_columns = [c for c in input_df_copy_unnormalized.columns if c not in self.pred_df.columns]

            return pd.concat([self.output_df[clean_kept_columns],
                              self.output_df[self.pred_df.columns],
                              self.output_df[self.recipe_desc["outputs"] if self.y is not None else []]],
                             axis=1), self.pred_df

    def _compute_metrics_df(self, output_df, pred_df):
        """Compute the performance metrics dataframe for the evaluated rows.

        The unprocessed input passed to ``compute_metrics_df`` is
        ``self.input_df`` in scoring mode and ``self.unprocessed`` in
        skip-scoring mode.
        """
        return compute_metrics_df(self.prediction_type, self.target_mapping, self.modeling_params, output_df,
                                  self.recipe_desc.get("metrics"),
                                  self.recipe_desc.get("customMetrics"),
                                  self.y_notnull,
                                  self.input_df if not self.recipe_desc.get('skipScoring') else self.unprocessed,
                                  self.recipe_desc.get("outputProbabilities"), None,
                                  treat_metrics_failure_as_error=self.recipe_desc.get("treatPerfMetricsFailureAsError", True))

    def _compute_custom_evaluation_metrics_df(self, output_df, pred_df, unprocessed_input_df, ref_sample_df):
        """Compute the user-defined custom evaluation metrics dataframe.

        ``pred_df`` is accepted for interface compatibility but not used here.
        """
        return compute_custom_evaluation_metrics_df(
            output_df,
            self.recipe_desc.get("customEvaluationMetrics"),
            self.prediction_type,
            self.target_mapping,
            self.y_notnull,
            unprocessed_input_df,
            ref_sample_df,
            self.recipe_desc.get("outputProbabilities"),
            None,
            treat_metrics_failure_as_error=self.recipe_desc.get("treatPerfMetricsFailureAsError", True))

    def _proba_columns_by_class_id(self):
        """Return the ``proba_<label>`` column names ordered by increasing class id."""
        sorted_classes = sorted(self.target_mapping.keys(), key=lambda label: self.target_mapping[label])
        return [u"proba_{}".format(safe_unicode_str(c)) for c in sorted_classes]

    def _perform_other_mes_actions(self):
        """Save prediction statistics, and performance when available, to the model evaluation store.

        Does nothing in skip-scoring mode, where this work is handled elsewhere
        (``process_input_df_skip_predict`` receives the evaluation store context).
        For each prediction type, runs the full scoring when the target is
        available. Otherwise only saves prediction statistics.
        """
        if self.recipe_desc.get('skipScoring', False):
            return  # Already done
        treat_perf_metrics_failure_as_error = self.recipe_desc.get("treatPerfMetricsFailureAsError", True)
        if self.prediction_type == doctor_constants.BINARY_CLASSIFICATION:
            proba_cols = self._proba_columns_by_class_id()
            decisions_and_cuts = DecisionsAndCuts.from_probas(self.pred_df[proba_cols].values, self.target_mapping)
            # The model's optimized threshold or the user forced threshold is stored under the same key in the recipe.
            threshold = self.recipe_desc["forcedClassifierThreshold"]
            if not self.dont_compute_performance:
                # Force scoring to use the recipe's threshold instead of re-optimizing one.
                self.modeling_params["autoOptimizeThreshold"] = False
                self.modeling_params["forcedClassifierThreshold"] = threshold
                run_binary_scoring(self.modeling_params, decisions_and_cuts, self.y_notnull,
                                   self.target_mapping, None, self.model_evaluation_store_folder_context,
                                   treat_metrics_failure_as_error=treat_perf_metrics_failure_as_error)
            else:
                prediction_result = decisions_and_cuts.get_prediction_result_for_nearest_cut(threshold)
                mapped_pred_col_series = pd.Series(prediction_result.preds)
                save_classification_statistics(mapped_pred_col_series,
                                               self.model_evaluation_store_folder_context,
                                               probas=self.pred_df[proba_cols].values,
                                               sample_weight=None,
                                               target_map=self.target_mapping)

        elif self.prediction_type == doctor_constants.MULTICLASS:
            probas = self.pred_df[self._proba_columns_by_class_id()].values
            if not self.dont_compute_performance:
                prediction_result = ClassificationPredictionResult(self.target_mapping, probas=probas,
                                                                   preds=self.pred_df["prediction"].values)
                run_multiclass_scoring(self.modeling_params, prediction_result,
                                       self.y_notnull.astype(int), self.target_mapping, None,
                                       self.model_evaluation_store_folder_context,
                                       treat_metrics_failure_as_error=treat_perf_metrics_failure_as_error)
            else:
                save_classification_statistics(self.pred_df["prediction"],
                                               self.model_evaluation_store_folder_context,
                                               probas=probas,
                                               sample_weight=None,
                                               target_map=self.target_mapping)

        elif self.prediction_type == doctor_constants.REGRESSION:
            if not self.dont_compute_performance:
                run_regression_scoring(self.modeling_params, PredictionResult(self.pred_df["prediction"]),
                                       self.y_notnull, None, self.model_evaluation_store_folder_context,
                                       treat_metrics_failure_as_error=treat_perf_metrics_failure_as_error)
            else:
                save_regression_statistics(self.pred_df["prediction"], self.model_evaluation_store_folder_context)


if __name__ == "__main__":
    # The debugging handler and the logging configuration are already installed
    # at module import time, so they are not repeated here (a second
    # logging.basicConfig call would be a no-op anyway).
    read_dku_env_and_set()

    with ErrorMonitoringWrapper():
        runner = KerasEvaluateRecipe(sys.argv[1],  # model_folder
                                     sys.argv[2],  # input_dataset_smartname
                                     # there is no 'managedFolderSmartId' (sys.argv[3]) for Keras for legacy reasons
                                     sys.argv[4],  # output_dataset_smartname
                                     sys.argv[5],  # metrics_dataset_smartname
                                     dkujson.load_from_filepath(sys.argv[6]),  # recipe_desc
                                     dkujson.load_from_filepath(sys.argv[7]),  # script
                                     dkujson.load_from_filepath(sys.argv[8]),  # preparation_output_schema
                                     dkujson.load_from_filepath(sys.argv[9]),  # cond_outputs
                                     dkujson.load_from_filepath(sys.argv[10]),  # preprocessing_params
                                     sys.argv[11],  # evaluation_store_folder
                                     evaluation_dataset_type=sys.argv[12],
                                     api_node_logs_config=dkujson.loads(sys.argv[13]),
                                     diagnostics_folder=sys.argv[14],
                                     fmi=sys.argv[15])
        log_nvidia_smi_if_use_gpu(recipe_desc=runner.recipe_desc)
        runner.run()
