import logging
import numpy as np
import pandas as pd

from dataiku.doctor import step_constants
from dataiku.doctor.causal.utils.misc import TreatmentMap
from dataiku.doctor.causal.utils.models import get_predictions_from_causal_model_single_treatment, get_predictions_from_causal_model_multi_treatment
from dataiku.doctor.preprocessing_handler import PredictionPreprocessingHandler
from dataiku.doctor.utils.listener import ProgressListener

# Module-level logger, registered under this module's dotted name.
logger = logging.getLogger(name=__name__)


class CausalPredictionScoringHandler(object):
    """Score dataframes with a causal model.

    For each input row the handler predicts the treatment effect (one column per
    treatment in the multi-treatment case). It optionally adds propensity scores
    when a propensity model is provided, and optionally flags the rows recommended
    for treatment (single-treatment models only).
    """

    def __init__(self, core_params, preprocessing_params, modeling_params, dku_causal_model, propensity_model, collector_data, preprocessing_folder_context):
        """
        :param core_params: core params dict. ``score`` reads "enable_multi_treatment",
            "treatment_values" and "control_value" from it, and passes it to
            ``PredictionPreprocessingHandler.build``.
        :param preprocessing_params: preprocessing params dict. ``score`` reads
            "drop_missing_treatment_values" from it in the multi-treatment case.
        :param modeling_params: modeling params (stored; not used by ``score``).
        :param dku_causal_model: causal model passed to the prediction helpers.
        :param propensity_model: optional model exposing ``predict_proba``.
            Propensity scoring is enabled only when it is not None.
        :param collector_data: collector data attached to the preprocessing handler
            before the pipeline is built.
        :param preprocessing_folder_context: folder context passed to
            ``PredictionPreprocessingHandler.build``.
        """
        self.core_params = core_params
        self.preprocessing_params = preprocessing_params
        self.modeling_params = modeling_params
        self.dku_causal_model = dku_causal_model
        self.propensity_model = propensity_model
        self.propensity_scoring_enabled = propensity_model is not None

        self.collector_data = collector_data
        self.preprocessing_folder_context = preprocessing_folder_context
        self.listener = ProgressListener()

    def score(self, input_df, assign_treatment=False, treatment_ratio=0.):
        """Preprocess ``input_df`` and predict treatment effects for each row.

        :param input_df: rows to score. Only a copy of it is passed to the
            preprocessing pipeline.
        :param assign_treatment: if True, add a boolean "treatment_recommended" column
            flagging rows whose predicted effect is strictly above the
            ``1 - treatment_ratio`` quantile. Only supported for single-treatment
            models; for multi-treatment models a warning is logged and it is skipped.
        :param treatment_ratio: approximate fraction of rows to flag, in [0, 1]
            (``np.quantile`` rejects values outside that range).
        :return: a pandas DataFrame indexed like ``input_df``, containing:
            - single treatment: "predicted_effect", plus "propensity" and/or
              "treatment_recommended" when enabled;
            - multi-treatment: one "predicted_effect_<t>" column per predicted
              treatment, "predicted_best_treatment" (the ``<t>`` whose predicted
              effect is highest), plus one "propensity_<t>" column per treatment map
              entry when propensity scoring is enabled.
        """
        logger.info("Got a dataframe : %s", str(input_df.shape))

        with self.listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING_PREPROCESSING_DATA):
            preprocessing_handler = PredictionPreprocessingHandler.build(self.core_params, self.preprocessing_params, self.preprocessing_folder_context)
            preprocessing_handler.collector_data = self.collector_data
            preprocessing_pipeline = preprocessing_handler.build_preprocessing_pipeline(with_target=False, with_treatment=False)

        with self.listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_FULL):
            logger.info("Processing it")
            # The pipeline gets a copy, presumably because it may mutate its input.
            input_df_copy = input_df.copy()
            transformed = preprocessing_pipeline.process(input_df_copy)

        with self.listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            logger.info("Predicting it")
            multi_treatment = self._is_multi_treatment()
            if multi_treatment:
                pred_df = self._predict_multi_treatment(transformed, input_df_copy.index)
            else:
                pred_df = self._predict_single_treatment(transformed, input_df_copy.index)
            logger.info("Done predicting it")

            if self.propensity_scoring_enabled:
                self._add_propensity_columns(pred_df, transformed, multi_treatment)

            if assign_treatment:
                if multi_treatment:
                    # Previously ignored silently; make the no-op visible.
                    logger.warning("Treatment assignment is only supported for single-treatment models, skipping it")
                else:
                    self._assign_treatment(pred_df, treatment_ratio)

        return pred_df

    def _is_multi_treatment(self):
        """Return True when multi-treatment is enabled and more than two treatment values are configured."""
        return bool(self.core_params.get("enable_multi_treatment", False)
                    and len(self.core_params.get("treatment_values", [])) > 2)

    def _predict_single_treatment(self, transformed, index):
        """Return a DataFrame with a single "predicted_effect" column, indexed by ``index``."""
        cate, _, _, _ = get_predictions_from_causal_model_single_treatment(self.dku_causal_model, transformed)
        return pd.DataFrame({"predicted_effect": cate}, index=index)

    def _predict_multi_treatment(self, transformed, index):
        """Return one "predicted_effect_<t>" column per treatment plus "predicted_best_treatment", indexed by ``index``."""
        prefix = "predicted_effect_"
        all_cates = get_predictions_from_causal_model_multi_treatment(self.dku_causal_model, transformed)
        pred_df = pd.DataFrame({prefix + t: cate for t, cate in all_cates.items()})
        # Align rows with the input, as the single-treatment output does. Without this the
        # frame keeps a default RangeIndex, which misaligns with inputs whose index is
        # not 0..n-1 (e.g. a filtered or chunked dataframe).
        pred_df.index = index
        # idxmax runs before any other column is added, so it only compares predicted effects.
        pred_df["predicted_best_treatment"] = pred_df.idxmax(axis=1).map(lambda col: col[len(prefix):])
        return pred_df

    def _add_propensity_columns(self, pred_df, transformed, multi_treatment):
        """Add propensity column(s) to ``pred_df`` in place, from the propensity model's ``predict_proba``."""
        proba_t = self.propensity_model.predict_proba(transformed["TRAIN"].as_np_array())
        if multi_treatment:
            treatment_map = TreatmentMap(self.core_params["control_value"], self.core_params["treatment_values"], self.preprocessing_params["drop_missing_treatment_values"])
            # Each treatment map entry gives the proba_t column to read for that treatment.
            for treatment, proba_column in treatment_map.items():
                pred_df["propensity_" + treatment] = proba_t[:, proba_column]
        else:
            # Column 1 is read as the propensity (presumably the "treated" class).
            pred_df["propensity"] = proba_t[:, 1]

    @staticmethod
    def _assign_treatment(pred_df, treatment_ratio):
        """Add a boolean "treatment_recommended" column to ``pred_df`` in place.

        Rows are flagged when their predicted effect is strictly above the
        ``1 - treatment_ratio`` quantile, i.e. roughly the top ``treatment_ratio``
        share of rows.
        """
        logger.info("Assigning treatment (sorting)")
        threshold = np.quantile(pred_df["predicted_effect"].values, (1 - treatment_ratio))
        # NOTE(review): with a strict ">", treatment_ratio=0 flags no row, but
        # treatment_ratio=1 leaves the row(s) with the minimum effect unflagged, and
        # ties at the threshold are never flagged. Confirm this is the intended convention.
        pred_df["treatment_recommended"] = pred_df["predicted_effect"] > threshold
        logger.info("Done assigning treatment")
