import logging

from dataiku.core import doctor_constants
from dataiku.doctor import step_constants
from dataiku.doctor.prediction.common import needs_hyperparameter_search
from dataiku.doctor.utils.model_io import to_pkl
from dataiku.doctor.utils import get_hyperparams_search_time_traininfo
from dataiku.doctor.utils import write_hyperparam_search_time_traininfo
from dataiku.doctor.causal.perf.model_perf import causal_prediction_scorer_with_valid
from dataiku.doctor.causal.train.fit import causal_classification_fit, causal_regression_fit
from dataiku.doctor.causal.utils.misc import check_causal_prediction_type
from dataiku.doctor.causal.utils.models import train_propensity_model
from dataiku.doctor.diagnostics.causal import check_treatment_randomness, check_treatment_positivity, check_propensity_model_calibration

logger = logging.getLogger(__name__)


class CausalTrainingHandler(object):
    """Run the fit, save and score steps of a causal prediction training.

    Every step is reported to ``listener`` through ``push_step`` and all
    artifacts are written through ``run_folder_context``.
    """

    def __init__(self, core_params, modeling_params, run_folder_context, listener, target_map=None, treatment_map=None):
        """Store the training configuration and validate the prediction type.

        :param core_params: dict-like; must hold ``"prediction_type"`` and may hold ``"diagnosticsSettings"``
        :param modeling_params: modeling settings forwarded to the fit, scoring and propensity helpers
        :param run_folder_context: folder context used to read train info and to write model artifacts
        :param listener: progress listener exposing ``push_step`` / ``pop_step`` / ``save_status``
        :param target_map: outcome remapping; required (with exactly 2 entries) for causal binary classification
        :param treatment_map: treatment remapping forwarded to the fit, scoring and diagnostics helpers
        :raises AssertionError: for causal binary classification, when ``target_map`` is missing or does not
            hold exactly 2 classes
        """
        self.core_params = core_params
        self.modeling_params = modeling_params
        self.run_folder_context = run_folder_context
        self.listener = listener
        self.target_map = target_map
        self.treatment_map = treatment_map
        self.prediction_type = self.core_params["prediction_type"]
        self.diagnostics_params = core_params.get("diagnosticsSettings", {})
        check_causal_prediction_type(self.prediction_type)
        if self.prediction_type == doctor_constants.CAUSAL_BINARY_CLASSIFICATION:
            # NOTE: kept as `assert` so callers still observe AssertionError; be aware that these
            # checks are skipped when Python runs with -O.
            assert(self.target_map is not None), "Missing outcome remapping."
            nb_classes = len(self.target_map)
            assert(nb_classes == 2), \
                "The outcome remapping is expected to hold exactly 2 classes for binary classifications, found {}.".format(nb_classes)

    def train_propensity(self, transformed_mf):
        """Train the propensity model on the preprocessed train data.

        :param transformed_mf: mapping holding a ``"treatment"`` entry exposing ``.values`` and a ``"TRAIN"``
            entry exposing ``as_np_array()`` (presumably the preprocessing output - confirm with caller)
        :return: whatever ``train_propensity_model`` returns
        """
        propensity_params = self.modeling_params["propensityModeling"]
        t = transformed_mf["treatment"].values
        X = transformed_mf["TRAIN"].as_np_array()
        return train_propensity_model(X, t,
                                      propensity_params["calibrateProbabilities"],
                                      propensity_params["calibrationDataRatio"])

    def train(self, transformed_mf):
        """Fit the causal model, optionally preceded by a hyperparameter search.

        :param transformed_mf: preprocessed train data, forwarded untouched to the fit function
        :return: tuple ``(dku_causal_model, actual_params, iipd)`` taken from the fit function's 5-tuple result
            (the 3rd and 4th values of that result are not used here)
        :raises ValueError: when the prediction type is neither causal binary classification nor causal regression
        """
        if needs_hyperparameter_search(self.modeling_params):
            previous_search_time = get_hyperparams_search_time_traininfo(self.run_folder_context)
            initial_state = step_constants.ProcessingStep.STEP_HYPERPARAMETER_SEARCHING

            def gridsearch_done_fn():
                # Close the search step, persist its duration, then switch the listener to the fitting step
                step = self.listener.pop_step()
                write_hyperparam_search_time_traininfo(self.run_folder_context, step["time"])
                self.listener.push_step(step_constants.ProcessingStep.STEP_FITTING)
                self.listener.save_status()

        else:
            initial_state = step_constants.ProcessingStep.STEP_FITTING
            previous_search_time = None

            def gridsearch_done_fn():
                # No search step to close: the listener is already on the fitting step
                pass

        with self.listener.push_step(initial_state, previous_duration=previous_search_time):
            # Only the fit function depends on the prediction type; the call itself is shared
            if self.prediction_type == doctor_constants.CAUSAL_BINARY_CLASSIFICATION:
                fit_fn = causal_classification_fit
            elif self.prediction_type == doctor_constants.CAUSAL_REGRESSION:
                fit_fn = causal_regression_fit
            else:
                raise ValueError("Invalid prediction type: {}".format(self.prediction_type))

            dku_causal_model, actual_params, _, _, iipd = fit_fn(self.core_params,
                                                                 self.modeling_params,
                                                                 transformed_mf,
                                                                 model_folder_context=self.run_folder_context,
                                                                 gridsearch_done_fn=gridsearch_done_fn,
                                                                 treatment_map=self.treatment_map)

        return dku_causal_model, actual_params, iipd

    def save_model(self, actual_params, dku_causal_model, propensity_model=None):
        """Persist the trained model(s) and the actual parameters in the run folder.

        Writes ``causal_model.pkl`` and ``actual_params.json``, plus ``propensity_model.pkl`` when a
        propensity model is given.

        :param actual_params: JSON-serializable parameters written to ``actual_params.json``
        :param dku_causal_model: trained causal model to pickle
        :param propensity_model: optional propensity model to pickle
        """
        with self.listener.push_step(step_constants.ProcessingStep.STEP_SAVING):
            # Pass the object as a logging argument: its string form is only built if the record is emitted,
            # and formatting cannot fail when the object happens to be a tuple.
            logger.info("PICKLING %s", dku_causal_model)
            to_pkl(dku_causal_model, self.run_folder_context, "causal_model.pkl")
            self.run_folder_context.write_json("actual_params.json", actual_params)
            if propensity_model is not None:
                logger.info("PICKLING %s", propensity_model)
                to_pkl(propensity_model, self.run_folder_context, "propensity_model.pkl")

    def causal_score(self, dku_causal_model, test_df_index, transformed_test, propensity_model=None):
        """Score the causal model on the test data and save the results.

        When a propensity model is given, the treatment randomness, treatment positivity and propensity
        calibration diagnostics are run on the scorer's ``perf_data`` after scoring and before saving.

        :param dku_causal_model: trained causal model to score
        :param test_df_index: index of the test dataframe, forwarded to the scorer
        :param transformed_test: preprocessed test data, forwarded to the scorer
        :param propensity_model: optional propensity model, forwarded to the scorer
        """
        with self.listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            is_regression = self.prediction_type == doctor_constants.CAUSAL_REGRESSION
            scorer = causal_prediction_scorer_with_valid(self.modeling_params, dku_causal_model, transformed_test,
                                                         self.run_folder_context, test_df_index, is_regression,
                                                         propensity_model=propensity_model, treatment_map=self.treatment_map)
            scorer.score()
            if propensity_model is not None:
                check_treatment_randomness(self.diagnostics_params, scorer.perf_data, self.treatment_map)
                check_treatment_positivity(self.diagnostics_params, scorer.perf_data, self.treatment_map)
                check_propensity_model_calibration(self.diagnostics_params, scorer.perf_data, self.treatment_map)
            scorer.save()
