import logging
import numpy as np

from dataiku.core import doctor_constants
from dataiku.doctor.causal.utils.models import CausalLearning
from dataiku.doctor.prediction.common import dump_pretrain_info
from dataiku.doctor.prediction.common import get_groups_for_hp_search_cv
from dataiku.doctor.prediction.common import get_initial_intrinsic_perf_data

logger = logging.getLogger(__name__)


def causal_classification_fit(core_params,
                              modeling_params,
                              transformed_train,
                              model_folder_context=None,
                              gridsearch_done_fn=None,
                              treatment_map=None):
    """
    Run the hyperparameter search and the final fit of a causal classification model.

    :param core_params: forwarded as-is to the algorithm's search runner factory
    :param modeling_params: modeling settings; this function reads the keys
        'algorithm', 'metrics' (optional sub-key 'causalWeighting') and
        'skipExpensiveReports'
    :param transformed_train: preprocessed train data; this function reads the
        entries "TRAIN" (features), "target" and "treatment"
    :param model_folder_context: forwarded to the search runner factory
    :param gridsearch_done_fn: optional no-argument callback, invoked once the
        best estimator has been obtained from the search runner
    :param treatment_map: forwarded to the search runner factory and to the
        causal model factory

    :return: a 5-tuple
        (dku_causal_model, actual_params, prepared_X, prepared_t, initial_intrinsic_perf_data)
        where dku_causal_model wraps the best estimator found by the search and
        has been fitted on the train data
    """
    train_X = transformed_train["TRAIN"].as_np_array()
    column_labels = list(transformed_train["TRAIN"].columns())
    train_y = transformed_train["target"].to_numpy()
    train_t = transformed_train["treatment"].to_numpy()

    causal_learning = CausalLearning(modeling_params)
    # The `True` flag mirrors the `False` used by causal_regression_fit
    # (presumably "is classification" -- confirm against CausalLearning).
    causal_algorithm = causal_learning.get_causal_algorithm(modeling_params['algorithm'], True)

    causal_hyperparameter_search_runner = causal_algorithm.get_search_runner(
        core_params,
        modeling_params,
        column_labels=column_labels,
        model_folder_context=model_folder_context,
        treatment_map=treatment_map,
    )

    groups = get_groups_for_hp_search_cv(modeling_params, transformed_train)
    # Propensity is only requested when the metrics use inverse propensity weighting
    compute_propensity = modeling_params["metrics"].get("causalWeighting") == doctor_constants.INVERSE_PROPENSITY
    causal_hyperparameter_search_runner.initialize_search_context(
        train_X, train_y, train_t, groups=groups, compute_propensity=compute_propensity
    )
    best_causal_model = causal_hyperparameter_search_runner.get_best_estimator()

    if gridsearch_done_fn is not None:
        gridsearch_done_fn()

    # Keep handles on train_{X,t} as prepared_{X,t} for the final output of
    # causal_classification_fit.
    # NOTE(review): a full slice of a numpy array is a view sharing the same
    # buffer, not an independent copy -- confirm no caller relies on isolation.
    prepared_X = train_X[::]
    prepared_t = train_t[::]

    dump_pretrain_info(best_causal_model, train_X, train_y)

    final_fit_parameters = causal_hyperparameter_search_runner.get_final_fit_parameters()

    dku_causal_model = causal_learning.get_dku_causal_model(best_causal_model, True, treatment_map=treatment_map)
    dku_causal_model.fit(train_X, train_y, train_t, **final_fit_parameters)

    initial_intrinsic_perf_data = get_initial_intrinsic_perf_data(train_X, False)
    # Search score info is only merged in when a search was actually run
    if not causal_hyperparameter_search_runner.search_skipped():
        initial_intrinsic_perf_data.update(causal_hyperparameter_search_runner.get_score_info())

    # Hand a shallow copy to the algorithm rather than the caller's own mapping
    modeling_params_copy = dict(modeling_params)
    actual_params = causal_algorithm.actual_params(modeling_params_copy, dku_causal_model, final_fit_parameters)
    actual_params["skipExpensiveReports"] = modeling_params["skipExpensiveReports"]
    # Lazy logger arguments: the message is only formatted if INFO is enabled
    logger.info("Output params are %s", actual_params)

    return dku_causal_model, actual_params, prepared_X, prepared_t, initial_intrinsic_perf_data


def causal_regression_fit(core_params,
                          modeling_params,
                          transformed_train,
                          model_folder_context=None,
                          gridsearch_done_fn=None,
                          treatment_map=None):
    """
    Run the hyperparameter search and the final fit of a causal regression model.

    :param core_params: forwarded as-is to the algorithm's search runner factory
    :param modeling_params: modeling settings; this function reads the keys
        'algorithm', 'metrics' (optional sub-key 'causalWeighting') and
        'skipExpensiveReports'
    :param transformed_train: preprocessed train data; this function reads the
        entries "TRAIN" (features), "target" and "treatment"
    :param model_folder_context: forwarded to the search runner factory
    :param gridsearch_done_fn: optional no-argument callback, invoked once the
        best estimator has been obtained from the search runner
    :param treatment_map: forwarded to the search runner factory and to the
        causal model factory

    :return: a 5-tuple
        (dku_causal_model, actual_params, prepared_X, prepared_t, initial_intrinsic_perf_data)
        where dku_causal_model wraps the best estimator found by the search and
        has been fitted on the train data
    :raises AssertionError: if the target array dtype is not float
    """
    train_X = transformed_train["TRAIN"].as_np_array()
    column_labels = list(transformed_train["TRAIN"].columns())
    train_y = transformed_train["target"].to_numpy()
    train_t = transformed_train["treatment"].to_numpy()

    causal_learning = CausalLearning(modeling_params)
    # The `False` flag mirrors the `True` used by causal_classification_fit
    # (presumably "is classification" -- confirm against CausalLearning).
    causal_algorithm = causal_learning.get_causal_algorithm(modeling_params['algorithm'], False)

    causal_hyperparameter_search_runner = causal_algorithm.get_search_runner(
        core_params,
        modeling_params,
        column_labels=column_labels,
        model_folder_context=model_folder_context,
        treatment_map=treatment_map,
    )

    # Sanity check kept as an assert so the raised exception type is unchanged
    assert train_y.dtype == float, "Regression target must have a float dtype, got %s" % train_y.dtype

    groups = get_groups_for_hp_search_cv(modeling_params, transformed_train)
    # Propensity is only requested when the metrics use inverse propensity weighting
    compute_propensity = modeling_params["metrics"].get("causalWeighting") == doctor_constants.INVERSE_PROPENSITY
    causal_hyperparameter_search_runner.initialize_search_context(
        train_X, train_y, train_t, groups=groups, compute_propensity=compute_propensity
    )
    best_causal_model = causal_hyperparameter_search_runner.get_best_estimator()

    if gridsearch_done_fn is not None:
        gridsearch_done_fn()

    # Keep handles on train_{X,t} as prepared_{X,t} for the final output.
    # NOTE(review): a full slice of a numpy array is a view sharing the same
    # buffer, not an independent copy -- confirm no caller relies on isolation.
    prepared_X = train_X[::]
    prepared_t = train_t[::]

    dump_pretrain_info(best_causal_model, train_X, train_y)

    final_fit_parameters = causal_hyperparameter_search_runner.get_final_fit_parameters()

    dku_causal_model = causal_learning.get_dku_causal_model(best_causal_model, False, treatment_map=treatment_map)
    dku_causal_model.fit(train_X, train_y, train_t, **final_fit_parameters)

    initial_intrinsic_perf_data = get_initial_intrinsic_perf_data(train_X, False)
    # Search score info is only merged in when a search was actually run
    if not causal_hyperparameter_search_runner.search_skipped():
        initial_intrinsic_perf_data.update(causal_hyperparameter_search_runner.get_score_info())

    # Hand a shallow copy to the algorithm rather than the caller's own mapping
    modeling_params_copy = dict(modeling_params)
    actual_params = causal_algorithm.actual_params(modeling_params_copy, dku_causal_model, final_fit_parameters)
    actual_params["skipExpensiveReports"] = modeling_params["skipExpensiveReports"]
    # Lazy logger arguments: the message is only formatted if INFO is enabled
    logger.info("Output params are %s", actual_params)

    return dku_causal_model, actual_params, prepared_X, prepared_t, initial_intrinsic_perf_data
