import logging

from dataiku.doctor import step_constants
from dataiku.doctor.causal.perf.model_perf import CausalModelIntrinsicScorer
from dataiku.doctor.causal.train.training_handler import CausalTrainingHandler
from dataiku.doctor.causal.utils.misc import check_causal_prediction_type, TreatmentMap
from dataiku.doctor.commands import build_preprocessing_handler
from dataiku.doctor.diagnostics import default_diagnostics
from dataiku.doctor.diagnostics.causal import check_imbalanced_treatment, check_imbalanced_treatment_outcome
from dataiku.doctor.prediction.common import needs_hyperparameter_search
from dataiku.doctor.prediction.common import regridify_optimized_params
from dataiku.doctor.preprocessing_collector import PredictionPreprocessingDataCollector
from dataiku.doctor.utils.gpu_execution import log_nvidia_smi_if_use_gpu
from dataiku.doctor.utils.listener import ModelStatusContext
from dataiku.doctor.utils.listener import ProgressListener
from dataiku.doctor.utils.split import load_train_set, load_test_set
from dataiku.doctor.utils import unix_time_millis
from dataiku.doctor.utils import write_done_traininfo

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


def _build_preprocessing(df, core_params, preprocessing_params, preprocessing_folder_context):
    """Run the preprocessing data collector on ``df`` and build the preprocessing handler and pipeline.

    :param df: dataframe the collector is built from
    :param core_params: core parameters, forwarded to ``build_preprocessing_handler``
    :param preprocessing_params: preprocessing parameters, forwarded to the collector and the handler
    :param preprocessing_folder_context: folder context forwarded to ``build_preprocessing_handler``
    :return: tuple ``(preproc_handler, pipeline)``; the pipeline is built here but not fitted
    """
    collector_data = PredictionPreprocessingDataCollector(df, preprocessing_params).build()
    preproc_handler = build_preprocessing_handler(collector_data, core_params, preprocessing_folder_context,
                                                  preprocessing_params)
    pipeline = preproc_handler.build_preprocessing_pipeline(with_target=True, with_treatment=True, allow_empty_mf=False)
    return preproc_handler, pipeline


def _fit_pipeline_and_report(pipeline, preproc_handler, df, core_params, diagnostics_params, treatment_map):
    """Fit ``pipeline`` on ``df``, run the treatment imbalance diagnostics, then save and report the preprocessing.

    :return: the result of ``pipeline.fit_and_process(df)``
    """
    transformed = pipeline.fit_and_process(df)
    is_regression = core_params["prediction_type"] == "CAUSAL_REGRESSION"
    check_imbalanced_treatment(diagnostics_params, transformed["treatment"], treatment_map)
    check_imbalanced_treatment_outcome(diagnostics_params, transformed["treatment"], transformed["target"],
                                       is_regression, treatment_map=treatment_map)
    preproc_handler.save_data()
    preproc_handler.report(pipeline)
    return transformed


def _train_propensity_if_enabled(training_handler, modeling_params, transformed):
    """Train the propensity model when ``modelingParams.propensityModeling.enabled`` is set.

    :return: the trained propensity model, or ``None`` when propensity modeling is not enabled
    """
    if not modeling_params.get("propensityModeling", {}).get("enabled", False):
        return None
    logger.info("Training propensity model")
    return training_handler.train_propensity(transformed)


def _train_and_save_on_full_dataset(core_params, modeling_sets, preprocessing_params, preprocessing_folder_context,
                                    split_folder_context, split_desc, preprocessing_listener, diagnostics_params,
                                    treatment_map):
    """Handle the "full" pass of the ``TRAIN_SPLITTED_AND_FULL`` operation mode.

    Loads the "full" set, fits the preprocessing on it, then for each modeling set trains, saves and
    intrinsically scores a model. Each modeling set's ``"modelingParams"`` entry is replaced in place by the
    output of ``regridify_optimized_params`` so that the later train/test pass reuses the selected point.

    Kept as a separate function so that the full-dataset frames go out of scope when it returns, instead of
    staying referenced for the whole train/test pass.
    """
    # Do hyperparameter search on full dataset + save the optimal model
    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TRAIN):
        full_df = load_train_set(core_params, preprocessing_params, split_desc, "full", split_folder_context)
        preproc_handler, pipeline = _build_preprocessing(full_df, core_params, preprocessing_params,
                                                         preprocessing_folder_context)

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_FULL):
        transformed_full = _fit_pipeline_and_report(pipeline, preproc_handler, full_df, core_params,
                                                    diagnostics_params, treatment_map)

    for modeling_set in modeling_sets:
        modeling_params = modeling_set["modelingParams"]
        model_folder_context = modeling_set["model_folder_context"]
        listener = modeling_set["listener"]
        training_full_handler = CausalTrainingHandler(core_params, modeling_params, model_folder_context, listener,
                                                      target_map=preproc_handler.target_map,
                                                      treatment_map=treatment_map)
        # In this pass the causal model is trained before the propensity model.
        dku_causal_model, actual_params, iipd = training_full_handler.train(transformed_full)
        propensity_model = _train_propensity_if_enabled(training_full_handler, modeling_params, transformed_full)
        training_full_handler.save_model(actual_params, dku_causal_model, propensity_model)
        # Replace original hyperparameter space by best point from the search (no search on single point)
        modeling_set["modelingParams"] = regridify_optimized_params(actual_params["resolved"], modeling_params)

        with listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            CausalModelIntrinsicScorer(dku_causal_model, transformed_full["TRAIN"], model_folder_context,
                                       iipd).score_and_save()


def launch_training(core_params, modeling_sets, preprocessing_params, preprocessing_folder_context, split_folder_context, split_desc, operation_mode=None):
    """Preprocess the data and train, score and save one causal model per modeling set.

    The train split is loaded and used to fit the preprocessing pipeline, the test split is processed with
    that fitted pipeline, and each modeling set is then trained on the train split and causally scored on the
    test split. When ``operation_mode`` is ``"TRAIN_SPLITTED_AND_FULL"``, a first pass on the "full" set trains,
    saves and intrinsically scores the models; the train/test pass then only produces the causal scores.

    :param core_params: dict of core parameters; reads ``prediction_type`` and, optionally,
        ``diagnosticsSettings``, ``enable_multi_treatment``, ``control_value``, ``treatment_values`` and ``time``
    :param modeling_sets: iterable of dicts, each providing ``"modelingParams"`` and ``"model_folder_context"``;
        a ``"listener"`` entry is added to each, and ``"modelingParams"`` is replaced in the full-dataset mode
    :param preprocessing_params: preprocessing parameters (``drop_missing_treatment_values`` is read when
        multi-treatment is enabled)
    :param preprocessing_folder_context: folder context handed to the preprocessing handler
    :param split_folder_context: folder context handed to the train/test set loaders
    :param split_desc: split description handed to the train/test set loaders
    :param operation_mode: ``"TRAIN_SPLITTED_AND_FULL"`` enables the extra full-dataset pass; any other value
        (including the default ``None``) runs the train/test pass only
    :return: the string ``"ok"``
    :raises AssertionError: if ``core_params["time"]["enabled"]`` is truthy
    """
    start = unix_time_millis()
    check_causal_prediction_type(core_params["prediction_type"])
    log_nvidia_smi_if_use_gpu(core_params=core_params)
    diagnostics_params = core_params.get("diagnosticsSettings", {})

    # Lazy logger arguments: the params are only formatted if the record is actually emitted.
    logger.info("PPS is %s", preprocessing_params)
    preprocessing_listener = ProgressListener()
    # Fill all the listeners ASAP to have correct progress data
    preprocessing_listener.add_future_steps(step_constants.PRED_REGULAR_PREPROCESSING_STEPS)
    for modeling_set in modeling_sets:
        listener = preprocessing_listener.new_child(ModelStatusContext(modeling_set["model_folder_context"], start))
        if needs_hyperparameter_search(modeling_set.get('modelingParams', {})):
            listener.add_future_step(step_constants.ProcessingStep.STEP_HYPERPARAMETER_SEARCHING)
        listener.add_future_steps(step_constants.PRED_REGULAR_TRAIN_STEPS)
        modeling_set["listener"] = listener

    if core_params.get("enable_multi_treatment", False):
        treatment_map = TreatmentMap(core_params["control_value"], core_params["treatment_values"], preprocessing_params["drop_missing_treatment_values"])
    else:
        treatment_map = None

    if core_params.get("time", {}).get("enabled", False):
        # Explicit raise rather than `assert` so the guard is not stripped when running with `python -O`.
        # AssertionError is kept so callers observe the same exception type and message as before.
        raise AssertionError("Time ordering must be disabled for causal prediction")

    default_diagnostics.register_causal_predictions_callbacks(core_params)

    if operation_mode == "TRAIN_SPLITTED_AND_FULL":
        _train_and_save_on_full_dataset(core_params, modeling_sets, preprocessing_params,
                                        preprocessing_folder_context, split_folder_context, split_desc,
                                        preprocessing_listener, diagnostics_params, treatment_map)

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TRAIN):
        train_df = load_train_set(core_params, preprocessing_params, split_desc, "train", split_folder_context)
        for col in train_df:
            logger.info("Train col : %s (%s)", col, train_df[col].dtype)

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING):
        preproc_handler, pipeline = _build_preprocessing(train_df, core_params, preprocessing_params,
                                                         preprocessing_folder_context)

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_TRAIN):
        transformed_train = _fit_pipeline_and_report(pipeline, preproc_handler, train_df, core_params,
                                                     diagnostics_params, treatment_map)

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TEST):
        test_df = load_test_set(core_params, preprocessing_params, split_desc, split_folder_context)

    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_TEST):
        # Copy the index before processing; it is handed to causal_score alongside the processed test set.
        test_df_index = test_df.index.copy()
        transformed_test = pipeline.process(test_df)

    preprocessing_listener.save_status()
    preprocessing_end = unix_time_millis()

    for modeling_set in modeling_sets:
        model_start = unix_time_millis()

        modeling_params = modeling_set["modelingParams"]
        model_folder_context = modeling_set["model_folder_context"]
        listener = modeling_set["listener"]
        training_handler = CausalTrainingHandler(core_params, modeling_params, model_folder_context, listener,
                                                 target_map=preproc_handler.target_map, treatment_map=treatment_map)
        # In this pass the propensity model is trained before the causal model.
        propensity_model = _train_propensity_if_enabled(training_handler, modeling_params, transformed_train)
        dku_causal_model, actual_params, iipd = training_handler.train(transformed_train)
        training_handler.causal_score(dku_causal_model, test_df_index, transformed_test, propensity_model=propensity_model)
        if operation_mode != "TRAIN_SPLITTED_AND_FULL":
            # intrinsic scoring + model saving already done on model trained on full dataset for TRAIN_SPLITTED_AND_FULL
            CausalModelIntrinsicScorer(dku_causal_model, transformed_train["TRAIN"], model_folder_context, iipd).score_and_save()
            training_handler.save_model(actual_params, dku_causal_model, propensity_model)
        end = unix_time_millis()

        listeners_json = preprocessing_listener.merge(listener)
        write_done_traininfo(model_folder_context, start, model_start, end, listeners_json,
                             end_preprocessing_time=preprocessing_end)
    return "ok"
