import datetime as dt

import logging
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, log_loss

from dataiku.core import doctor_constants
from dataiku.doctor import utils, step_constants
from dataiku.doctor.causal.perf.model_perf import _compute_causal_metrics_multi_treatment
from dataiku.doctor.causal.utils.metrics import compute_auuc_score, compute_qini_score, compute_net_uplift_score
from dataiku.doctor.causal.utils.misc import TreatmentMap
from dataiku.doctor.causal.utils.models import get_predictions_from_causal_model_single_treatment, get_predictions_from_causal_model_multi_treatment
from dataiku.doctor.prediction.classification_scoring import MulticlassModelScorer
from dataiku.doctor.preprocessing_handler import PredictionPreprocessingHandler
from dataiku.doctor.utils.listener import ProgressListener
from dataiku.doctor.utils.metrics import calibration_loss_binary
from dataiku.doctor.utils.metrics import mroc_auc_score

# Module-level logger named after this module; used for progress and metric logging below.
logger = logging.getLogger(__name__)


class CausalPredictionEvaluationHandler(object):
    """Score a dataset with a trained causal model and optionally compute evaluation metrics.

    The handler rebuilds the preprocessing pipeline (with target and treatment), predicts the
    treatment effect of every row, adds propensity scores when a propensity model is available
    and, on request, computes uplift metrics (AUUC, Qini, net uplift) plus propensity-model
    metrics (AUC, log loss, calibration loss).
    """

    def __init__(self, core_params, preprocessing_params, modeling_params, dku_causal_model, propensity_model, collector_data, model_folder_context):
        """
        :param core_params: core params; "control_value", "treatment_values" and
            "enable_multi_treatment" are read from it
        :param preprocessing_params: preprocessing params; "drop_missing_treatment_values" is read
        :param modeling_params: modeling params; the "metrics" sub-dict is read
        :param dku_causal_model: trained causal model used to predict treatment effects
        :param propensity_model: object exposing ``predict_proba`` (propensity model), or None
        :param collector_data: collector data injected into the preprocessing handler
        :param model_folder_context: model folder context used to build the preprocessing handler
        """
        self.core_params = core_params
        self.preprocessing_params = preprocessing_params
        self.modeling_params = modeling_params
        self.dku_causal_model = dku_causal_model
        self.propensity_model = propensity_model
        self.collector_data = collector_data
        self.model_folder_context = model_folder_context
        self.listener = ProgressListener()

    def evaluate(self, input_df, output_metrics, compute_metrics=False):
        """Preprocess ``input_df``, predict treatment effects and optionally compute metrics.

        :param input_df: DataFrame to score; a copy of it is passed to the preprocessing pipeline
        :param output_metrics: names of the metrics to report (e.g. "auuc", "qini", "netUplift",
            "propensityAuc", ...); they also define the columns of the returned metrics frame
        :param compute_metrics: when False, only predictions are produced
        :return: ``(pred_df, metrics_df)``. ``pred_df`` holds "predicted_effect" (single treatment)
            or one "predicted_effect_<t>" column per treatment plus "predicted_best_treatment"
            (multi-treatment), and propensity column(s) when a propensity model is set; it is
            indexed like the preprocessed input copy in both cases. ``metrics_df`` is None when
            ``compute_metrics`` is False, otherwise a one-row frame with a "date" column followed
            by one column per entry of ``output_metrics`` (None for metrics not computed).
        :raises ValueError: if inverse-propensity weighting is requested without a propensity model
        """
        with self.listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING_PREPROCESSING_DATA):
            preprocessing_handler = PredictionPreprocessingHandler.build(self.core_params, self.preprocessing_params, self.model_folder_context)
            preprocessing_handler.collector_data = self.collector_data
            pipeline = preprocessing_handler.build_preprocessing_pipeline(with_target=True, with_treatment=True)

        with self.listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_FULL):
            logger.info("Processing it")
            # The pipeline receives a copy, never the caller's frame
            input_df_copy = input_df.copy()
            transformed = pipeline.process(input_df_copy)

        with self.listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            logger.info("Predicting it")
            multi_treatment = self.core_params.get("enable_multi_treatment", False) and len(self.core_params.get("treatment_values", [])) > 2
            treatment_map = TreatmentMap(self.core_params["control_value"], self.core_params["treatment_values"], self.preprocessing_params["drop_missing_treatment_values"])
            # Index is read after pipeline.process, exactly as the original single-treatment code did
            pred_df, all_cates, proba_t = self._predict(transformed, input_df_copy.index, multi_treatment, treatment_map)

            if not compute_metrics:
                return pred_df, None

            y = transformed["target"].to_numpy()
            if multi_treatment:
                treatment = transformed["treatment"].to_numpy().astype(int)
                causal_metrics = _compute_causal_metrics_multi_treatment(y, treatment, all_cates, proba_t, treatment_map,
                                                                         self.modeling_params["metrics"])
                # Currently sticking to the same metrics dataset schema for binary and multi-valued treatments
                # We keep only the "*All" metrics, i.e. computed on {treatment=control or treatment=t} subsets
                # for each possible treatment t
                computed_metrics = {
                    "auuc": causal_metrics["causalPerfMultiAll"]["normalized"]["auuc"] if "auuc" in output_metrics else None,
                    "qini": causal_metrics["causalPerfMultiAll"]["normalized"]["qini"] if "qini" in output_metrics else None
                }
            else:
                cate = pred_df["predicted_effect"].to_numpy()
                treatment = transformed["treatment"].to_numpy()
                computed_metrics = self._compute_single_treatment_metrics(y, treatment, cate, proba_t, output_metrics)

            if self.propensity_model is not None:
                computed_metrics.update(
                    self._compute_propensity_metrics(pred_df, transformed, treatment, proba_t, treatment_map, multi_treatment))

            logger.info("Metrics computed : ")
            logger.info(computed_metrics)
            metrics_df = pd.concat([pd.DataFrame.from_dict({'date': [utils.get_datetime_now_utc()]}),
                                    pd.DataFrame.from_dict(
                                        {a: [computed_metrics.get(a, None)] for a in output_metrics})], axis=1)
            return pred_df, metrics_df

    def _predict(self, transformed, index, multi_treatment, treatment_map):
        """Build the prediction frame; return ``(pred_df, all_cates, proba_t)``.

        ``all_cates`` is None in the single-treatment case and ``proba_t`` is None when there is no
        propensity model.
        """
        all_cates = None
        if multi_treatment:
            all_cates = get_predictions_from_causal_model_multi_treatment(self.dku_causal_model, transformed)
            # Use the same row index as the single-treatment branch so predictions stay aligned with
            # the scored rows; np.asarray(...).ravel() guarantees no foreign index is aligned in.
            pred_df = pd.DataFrame({"predicted_effect_" + t: np.asarray(cate).ravel() for t, cate in all_cates.items()},
                                   index=index)
            # At this point only "predicted_effect_*" columns exist, so idxmax picks the best treatment
            pred_df["predicted_best_treatment"] = pred_df.idxmax(axis=1).map(lambda x: x[len("predicted_effect_"):])
        else:
            cate, _, _, _ = get_predictions_from_causal_model_single_treatment(self.dku_causal_model, transformed)
            pred_df = pd.DataFrame({"predicted_effect": cate.ravel()}, index=index)

        proba_t = None
        if self.propensity_model is not None:
            proba_t = self.propensity_model.predict_proba(transformed["TRAIN"].as_np_array())
            if multi_treatment:
                # treatment_map yields (treatment value, column index into proba_t) pairs
                for k, v in treatment_map.items():
                    pred_df["propensity_" + k] = proba_t[:, v]
            else:
                pred_df["propensity"] = proba_t[:, 1]
        return pred_df, all_cates, proba_t

    def _compute_single_treatment_metrics(self, y, treatment, cate, proba_t, output_metrics):
        """Compute AUUC / Qini / net uplift for a binary treatment (None for metrics not requested)."""
        metrics_params = self.modeling_params["metrics"]

        net_uplift_point = None
        if "netUplift" in output_metrics:
            if "netUpliftPoint" in metrics_params:
                net_uplift_point = metrics_params["netUpliftPoint"]
            else:
                net_uplift_point = 0.5
                logger.warning("Undefined net uplift point, using {} as default value".format(net_uplift_point))

        if metrics_params.get("causalWeighting") == doctor_constants.INVERSE_PROPENSITY:
            if proba_t is None:
                raise ValueError("Inverse propensity weighting requires a propensity model, but none is available")
            # NOTE(review): each row is weighted by the propensity of its observed treatment, not by its
            # inverse; this is only inverse-propensity weighting if the compute_*_score helpers invert
            # the weights internally - confirm.
            sample_weights = np.ones(y.shape[0])
            treatment_0_mask = treatment == 0
            treatment_1_mask = treatment == 1
            sample_weights[treatment_0_mask] = proba_t[treatment_0_mask, 0]
            sample_weights[treatment_1_mask] = proba_t[treatment_1_mask, 1]
        else:
            sample_weights = None

        return {
            "auuc": compute_auuc_score(y, treatment, cate, sample_weights=sample_weights) if "auuc" in output_metrics else None,
            "qini": compute_qini_score(y, treatment, cate, sample_weights=sample_weights) if "qini" in output_metrics else None,
            "netUplift": compute_net_uplift_score(y, treatment, cate, net_uplift_point, sample_weights=sample_weights) if "netUplift" in output_metrics else None
        }

    @staticmethod
    def _compute_propensity_metrics(pred_df, transformed, treatment, proba_t, treatment_map, multi_treatment):
        """Compute AUC, log loss and calibration loss of the propensity model against observed treatments."""
        metrics = {}
        if multi_treatment:
            # Order the propensity columns by their encoded treatment value (the proba_t column index)
            selection = ["propensity_" + k for k, v in sorted(treatment_map.items(), key=lambda x: x[1])]
            all_propensity_arr = pred_df[selection].to_numpy()
            metrics["propensityAuc"] = mroc_auc_score(treatment, all_propensity_arr)
            metrics["propensityLogLoss"] = log_loss(treatment, all_propensity_arr)
            metrics["propensityCalibrationLoss"], _, _ = \
                MulticlassModelScorer.get_calibration_metrics_and_curves(pd.Series(treatment), all_propensity_arr, treatment_map.mapping, None)
        else:
            metrics["propensityAuc"] = roc_auc_score(treatment, pred_df["propensity"])
            metrics["propensityLogLoss"] = log_loss(treatment, pred_df["propensity"])
            metrics["propensityCalibrationLoss"] = calibration_loss_binary(transformed["treatment"], proba_t)
        return metrics
