import logging
import traceback

import numpy as np
import pandas as pd
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

from dataiku.doctor.plugins.common_algorithm import PluginPredictionAlgorithm
from dataiku.doctor.prediction.common import CategoricalHyperparameterDimension
from dataiku.doctor.prediction.common import ClassicalPredictionAlgorithm
from dataiku.doctor.prediction.common import FloatHyperparameterDimension
from dataiku.doctor.prediction.common import HyperparametersSpace
from dataiku.doctor.prediction.common import IntegerHyperparameterDimension
from dataiku.doctor.prediction.common import SVMHyperparametersSpace
from dataiku.doctor.prediction.common import TrainableModel
from dataiku.doctor.prediction.common import TreesHyperparametersSpace
from dataiku.doctor.prediction.common import dump_pretrain_info
from dataiku.doctor.prediction.common import get_groups_for_hp_search_cv
from dataiku.doctor.prediction.common import get_initial_intrinsic_perf_data
from dataiku.doctor.prediction.common import get_max_features_dimension
from dataiku.doctor.prediction.common import get_selection_mode
from dataiku.doctor.prediction.common import get_svm_gamma_params_from_clf_params
from dataiku.doctor.prediction.common import prepare_multiframe
from dataiku.doctor.prediction.common import replace_value_by_empty
from dataiku.doctor.prediction.common import safe_del
from dataiku.doctor.prediction.common import safe_positive_int
from dataiku.doctor.prediction.common import scikit_model
from dataiku.doctor.prediction.deep_neural_network_prediction import DeepNeuralNetworkClassification
from dataiku.doctor.prediction.lars import DkuLassoLarsClassifier
from dataiku.doctor.prediction.lightgbm_prediction import LightGBMClassification
from dataiku.doctor.prediction.xgboost_trainable_model import XGBoostTrainableModel
from dataiku.doctor.sparse import prepare_multiframe_as_sparse_if_needed
from dataiku.doctor.utils import doctor_constants
from dataiku.doctor.utils.skcompat import gbt_skcompat_actual_params, dku_logistic_regression, dku_calibrated_classifier_cv
from dataiku.doctor.utils.skcompat import gbt_skcompat_hp_space
from dataiku.doctor.utils.skcompat import get_base_estimator
from dataiku.doctor.utils.skcompat import sgd_skcompat_actual_params
from dataiku.doctor.utils.skcompat import sgd_skcompat_hp_space
from dataiku.doctor.utils.gpu_execution import get_gpu_config_from_core_params, get_single_gpu_id_from_gpu_device, XGBOOSTGpuCapability

logger = logging.getLogger(__name__)

def get_class_weight_dict(train_y):
    """
    Compute one weight per class found in train_y.

    The weight of class y is len(train_y) / (n_classes * count(y)), so rarer classes get larger weights.
    Computing the weights once here enforces consistency across splits.

    :param train_y: array-like of class labels supporting element-wise `==` (e.g. a numpy array)
    :return: dict mapping each distinct label to its weight
    """
    labels = np.unique(train_y)
    total = float(len(train_y))
    n_labels = labels.size

    weights = {}
    for label in labels:
        occurrences = np.sum(train_y == label)
        weights[label] = total / (n_labels * occurrences)
    return weights

# Registry of the available classification algorithms: algorithm name -> algorithm instance
CLASSIFICATION_ALGORITHMS = {}


def register_classification_algorithm(algorithm):
    """
    Instantiate the given algorithm class and store the instance in CLASSIFICATION_ALGORITHMS.

    :param algorithm: class exposing an `algorithm` class attribute (used as registry key) and
                      instantiable without arguments
    """
    instance = algorithm()
    CLASSIFICATION_ALGORITHMS[algorithm.algorithm] = instance


##############################################################
# IMPORTANT
#    If you add any settings here, you MUST add them to
#    classification.tmpl / regression.tmpl for the notebook export
##############################################################

class ScikitClassification(ClassicalPredictionAlgorithm):
    """Classification algorithm whose estimator is built by `scikit_model` from the modeling params."""
    algorithm = "SCIKIT_MODEL"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Return a TrainableModel wrapping the `scikit_model` estimator and its hyperparameter space."""
        return TrainableModel(
            scikit_model(modeling_params),
            hyperparameters_space=HyperparametersSpace.from_definition(input_hp_space)
        )

    def actual_params(self, ret, clf, fit_params):
        """Return the modeling params untouched: nothing is read back from the fitted estimator."""
        return {"resolved": ret, "other": {}}

    def get_search_settings(self, hyperparameter_search_params, trainable_model):
        """Return the parent class' search settings with `n_iter` overridden."""
        settings = super(ScikitClassification, self).get_search_settings(hyperparameter_search_params,
                                                                         trainable_model)
        # Force hyperparameter search size to 1
        settings.n_iter = 1
        return settings


register_classification_algorithm(ScikitClassification)


class RFClassification(ClassicalPredictionAlgorithm):
    """Random forest classification, backed by scikit-learn's RandomForestClassifier."""
    algorithm = "RANDOM_FOREST_CLASSIFICATION"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the forest estimator and the hyperparameter space to search over."""
        dimension_classes = {
            "min_samples_leaf": IntegerHyperparameterDimension,
            "n_estimators": IntegerHyperparameterDimension
        }
        # These two dimensions do not map 1:1 to an entry of the input space, so they are built explicitly
        explicit_dimensions = {
            "max_features": get_max_features_dimension(input_hp_space),
            "max_depth": IntegerHyperparameterDimension(input_hp_space["max_tree_depth"])
        }
        search_space = TreesHyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=dimension_classes,
            hp_names_to_dimension=explicit_dimensions
        )

        forest = RandomForestClassifier(random_state=1337, n_jobs=input_hp_space["n_jobs"])
        return TrainableModel(forest, hyperparameters_space=search_space)

    def actual_params(self, ret, clf, fit_params):
        """
        Replace the "rf_classifier_grid" entry of `ret` by an "rf" entry read from the fitted forest.

        :return: dict with "resolved" (the updated `ret`) and "other" (extra fitted parameters)
        """
        safe_del(ret, "rf_classifier_grid")
        fitted = clf.get_params()
        logger.info("Obtained RF CLF params: %s " % fitted)

        resolved_rf = {
            "estimators": len(clf.estimators_),
            "max_tree_depth": fitted["max_depth"],
            "min_samples_leaf": fitted["min_samples_leaf"],
            "selection_mode": get_selection_mode(fitted["max_features"]),
        }
        ret["rf"] = resolved_rf

        # max_features is reported under a different key depending on how it was expressed
        mode = resolved_rf["selection_mode"]
        if mode == "number":
            resolved_rf["max_features"] = fitted["max_features"]
        elif mode == "prop":
            resolved_rf["max_feature_prop"] = fitted["max_features"]

        return {
            "resolved": ret,
            "other": {"rf_min_samples_split": fitted["min_samples_split"]}
        }


register_classification_algorithm(RFClassification)


class ExtraTreesClassification(ClassicalPredictionAlgorithm):
    """Extra trees classification, backed by scikit-learn's ExtraTreesClassifier."""
    algorithm = "EXTRA_TREES"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the extra trees estimator and the hyperparameter space to search over."""
        dimension_classes = {
            "min_samples_leaf": IntegerHyperparameterDimension,
            "n_estimators": IntegerHyperparameterDimension
        }
        # The depth definition goes through replace_value_by_empty(..., value=0) before being used
        depth_definition = replace_value_by_empty(input_hp_space["max_tree_depth"], value=0)
        search_space = TreesHyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=dimension_classes,
            hp_names_to_dimension={
                "max_features": get_max_features_dimension(input_hp_space),
                "max_depth": IntegerHyperparameterDimension(depth_definition)
            }
        )

        trees = ExtraTreesClassifier(random_state=1337, n_jobs=input_hp_space["n_jobs"])
        return TrainableModel(trees, hyperparameters_space=search_space)

    def actual_params(self, ret, clf, fit_params):
        """
        Replace the "extra_trees_grid" entry of `ret` by an "extra_trees" entry read from the fitted model.

        :return: dict with "resolved" (the updated `ret`) and "other" (extra fitted parameters)
        """
        safe_del(ret, "extra_trees_grid")
        fitted = clf.get_params()
        logger.info("Obtained ET CLF params: %s " % fitted)

        n_fitted_trees = len(clf.estimators_)
        # Any non-positive n_jobs is reported as -1
        if fitted["n_jobs"] > 0:
            reported_njobs = fitted["n_jobs"]
        else:
            reported_njobs = -1

        resolved_et = {
            "estimators": n_fitted_trees,
            "njobs": reported_njobs,
            "max_tree_depth": fitted["max_depth"],
            "min_samples_leaf": fitted["min_samples_leaf"],
            "selection_mode": get_selection_mode(fitted["max_features"]),
        }
        ret["extra_trees"] = resolved_et

        # max_features is reported under a different key depending on how it was expressed
        mode = resolved_et["selection_mode"]
        if mode == "number":
            resolved_et["max_features"] = fitted["max_features"]
        elif mode == "prop":
            resolved_et["max_feature_prop"] = fitted["max_features"]

        return {
            "resolved": ret,
            "other": {"rf_min_samples_split": fitted["min_samples_split"]}
        }


register_classification_algorithm(ExtraTreesClassification)


class GBTClassification(ClassicalPredictionAlgorithm):
    """Gradient boosted trees classification, backed by scikit-learn's GradientBoostingClassifier."""
    algorithm = "GBT_CLASSIFICATION"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the gradient boosting estimator and the hyperparameter space to search over."""
        # NOTE(review): the return value is discarded here whereas SGDClassification uses the value
        # returned by sgd_skcompat_hp_space - presumably this helper updates input_hp_space in place; confirm
        gbt_skcompat_hp_space(input_hp_space)

        dimension_classes = {
            "min_samples_leaf": IntegerHyperparameterDimension,
            "n_estimators": IntegerHyperparameterDimension,
            "learning_rate": FloatHyperparameterDimension,
            "loss": CategoricalHyperparameterDimension,
            "max_depth": IntegerHyperparameterDimension
        }
        search_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=dimension_classes,
            hp_names_to_dimension={
                "max_features": get_max_features_dimension(input_hp_space),
            }
        )

        booster = GradientBoostingClassifier(random_state=1337, verbose=1)
        return TrainableModel(booster, hyperparameters_space=search_space)

    def actual_params(self, ret, clf, fit_params):
        """
        Replace the "gbt_classifier_grid" entry of `ret` by a "gbt" entry read from the fitted model.

        :return: dict with "resolved" (the updated `ret`) and an empty "other"
        """
        safe_del(ret, "gbt_classifier_grid")
        fitted = clf.get_params()
        logger.info("GBT Params are %s " % fitted)

        resolved_gbt = {
            "n_estimators": len(clf.estimators_),
            "max_depth": fitted["max_depth"],
            "learning_rate": fitted["learning_rate"],
            "min_samples_leaf": fitted["min_samples_leaf"],
            "selection_mode": get_selection_mode(fitted["max_features"]),
            "loss": fitted["loss"]
        }
        ret["gbt"] = resolved_gbt

        # max_features is reported under a different key depending on how it was expressed
        mode = resolved_gbt["selection_mode"]
        if mode == "number":
            resolved_gbt["max_features"] = fitted["max_features"]
        elif mode == "prop":
            resolved_gbt["max_feature_prop"] = fitted["max_features"]
        gbt_skcompat_actual_params(ret["gbt"])

        return {"resolved": ret, "other": {}}


register_classification_algorithm(GBTClassification)


class DecisionTreeClassification(ClassicalPredictionAlgorithm):
    """Single decision tree classification, backed by scikit-learn's DecisionTreeClassifier."""
    algorithm = "DECISION_TREE_CLASSIFICATION"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the tree estimator and the hyperparameter space to search over."""
        dimension_classes = {
            "max_depth": IntegerHyperparameterDimension,
            "min_samples_leaf": IntegerHyperparameterDimension,
            "criterion": CategoricalHyperparameterDimension,
            "splitter": CategoricalHyperparameterDimension
        }
        search_space = HyperparametersSpace.from_definition(input_hp_space,
                                                            hp_names_to_dimension_class=dimension_classes)

        tree = DecisionTreeClassifier(random_state=1337)
        return TrainableModel(tree, hyperparameters_space=search_space)

    def actual_params(self, ret, clf, fit_params):
        """
        Replace the "dtc_classifier_grid" entry of `ret` by a "dt" entry read from the fitted tree.

        :return: dict with "resolved" (the updated `ret`) and an empty "other"
        """
        safe_del(ret, "dtc_classifier_grid")
        fitted = clf.get_params()
        logger.info("DT params are %s " % fitted)

        resolved_dt = {}
        for name in ("max_depth", "criterion", "min_samples_leaf", "splitter"):
            resolved_dt[name] = fitted[name]
        ret["dt"] = resolved_dt

        return {"resolved": ret, "other": {}}


register_classification_algorithm(DecisionTreeClassification)


class LogisticRegClassification(ClassicalPredictionAlgorithm):
    """Logistic regression classification, built through `dku_logistic_regression`."""
    algorithm = "LOGISTIC_REGRESSION"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the logistic regression estimator (solver picked here) and its hyperparameter space."""
        search_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class={
                "C": FloatHyperparameterDimension,
                "penalty": CategoricalHyperparameterDimension
            }
        )

        multi_class = input_hp_space["multi_class"]
        if multi_class != "multinomial":
            solver = "liblinear"
        elif input_hp_space["penalty"]["values"]["l1"]["enabled"]:
            # In the multinomial case only saga solver can be used with L1 regularization.
            solver = "saga"
        else:
            # saga seems to be slower than lbfgs, that is hence preferred when L1 regularization is not used.
            solver = "lbfgs"

        logit = dku_logistic_regression(multi_class=multi_class, solver=solver, random_state=1337)
        return TrainableModel(logit, hyperparameters_space=search_space)

    def actual_params(self, ret, clf, fit_params):
        """
        Replace the "logit_grid" entry of `ret` by a "logit" entry read from the fitted estimator.

        :return: dict with "resolved" (the updated `ret`) and an empty "other"
        """
        safe_del(ret, "logit_grid")
        fitted = clf.get_params()
        logger.info("LR Params are %s " % fitted)

        resolved_logit = {}
        for name in ("penalty", "multi_class", "C"):
            resolved_logit[name] = fitted[name]
        ret["logit"] = resolved_logit

        return {"resolved": ret, "other": {}}


register_classification_algorithm(LogisticRegClassification)


class SVCClassification(ClassicalPredictionAlgorithm):
    """Support vector classification, backed by scikit-learn's SVC with probability estimates enabled."""
    algorithm = "SVC_CLASSIFICATION"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the SVC estimator and its hyperparameter space (gamma is only searched when relevant)."""
        kernel_definitions = input_hp_space["kernel"]["values"]
        searches_gamma = any(kernel_definitions[name]["enabled"] for name in ["rbf", "sigmoid", "poly"])

        if searches_gamma:
            # At least one of the rbf / sigmoid / poly kernels is enabled: gamma dimensions are included
            search_space = SVMHyperparametersSpace.from_definition(
                input_hp_space,
                hp_names_to_dimension_class={
                    "C": FloatHyperparameterDimension,
                    "gamma": CategoricalHyperparameterDimension,
                    "custom_gamma": FloatHyperparameterDimension,
                    "kernel": CategoricalHyperparameterDimension
                }
            )
        else:
            search_space = HyperparametersSpace.from_definition(
                input_hp_space,
                hp_names_to_dimension_class={
                    "C": FloatHyperparameterDimension,
                    "kernel": CategoricalHyperparameterDimension
                }
            )

        svc = SVC(coef0=input_hp_space['coef0'],
                  tol=input_hp_space['tol'],
                  probability=True,
                  max_iter=input_hp_space['max_iter'])
        return TrainableModel(svc, hyperparameters_space=search_space)

    def actual_params(self, ret, clf, fit_params):
        """
        Replace the "svc_grid" entry of `ret` by an "svm" entry read from the fitted estimator.

        :return: dict with "resolved" (the updated `ret`) and an empty "other"
        """
        fitted = clf.get_params()
        logger.info("Selected SVC Params are %s " % fitted)
        safe_del(ret, "svc_grid")

        resolved_svm = {}
        for name in ("C", "kernel", "tol", "max_iter", "coef0"):
            resolved_svm[name] = fitted[name]
        ret["svm"] = resolved_svm
        # The gamma-related entries are derived by a shared helper
        resolved_svm.update(get_svm_gamma_params_from_clf_params(fitted))

        return {"resolved": ret, "other": {}}


register_classification_algorithm(SVCClassification)


class SGDClassification(ClassicalPredictionAlgorithm):
    """Stochastic gradient descent classification, backed by scikit-learn's SGDClassifier."""
    algorithm = "SGD_CLASSIFICATION"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the SGD estimator and the hyperparameter space to search over."""
        # Everything below reads the space returned by the scikit-learn compatibility helper
        compat_space = sgd_skcompat_hp_space(input_hp_space)
        search_space = HyperparametersSpace.from_definition(
            compat_space,
            hp_names_to_dimension_class={
                "alpha": FloatHyperparameterDimension,
                "loss": CategoricalHyperparameterDimension,
                "penalty": CategoricalHyperparameterDimension
            }
        )

        sgd = SGDClassifier(l1_ratio=compat_space["l1_ratio"],
                            shuffle=True,
                            max_iter=compat_space["max_iter"],
                            tol=compat_space["tol"],
                            n_jobs=compat_space["n_jobs"],
                            random_state=1337)
        return TrainableModel(sgd, hyperparameters_space=search_space)

    def actual_params(self, ret, clf, fit_params):
        """
        Replace the "sgd_grid" entry of `ret` by an "sgd" entry read from the fitted estimator.

        :return: dict with "resolved" (the updated `ret`) and an empty "other"
        """
        fitted = clf.get_params()
        logger.info("Selected SGD Params are %s " % fitted)
        safe_del(ret, "sgd_grid")

        resolved_sgd = {}
        for name in ("loss", "penalty", "alpha", "l1_ratio", "n_jobs"):
            resolved_sgd[name] = fitted[name]
        # n_iter is read from the fitted estimator's n_iter_ attribute, not from its params
        resolved_sgd["n_iter"] = clf.n_iter_
        ret["sgd"] = resolved_sgd
        sgd_skcompat_actual_params(ret["sgd"])

        return {"resolved": ret, "other": {}}


register_classification_algorithm(SGDClassification)


class KNNClassification(ClassicalPredictionAlgorithm):
    """K nearest neighbors classification, backed by scikit-learn's KNeighborsClassifier."""
    algorithm = "KNN"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the KNN estimator; the only searched hyperparameter is the number of neighbors."""
        search_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension={
                # The "k" entry of the input space drives the estimator's n_neighbors
                "n_neighbors": IntegerHyperparameterDimension(input_hp_space["k"])
            }
        )

        if input_hp_space["distance_weighting"]:
            weighting = "distance"
        else:
            weighting = "uniform"

        knn = KNeighborsClassifier(weights=weighting,
                                   algorithm=input_hp_space["algorithm"],
                                   leaf_size=input_hp_space["leaf_size"],
                                   p=input_hp_space["p"])
        return TrainableModel(knn, hyperparameters_space=search_space, supports_sample_weights=False)

    def actual_params(self, ret, clf, fit_params):
        """
        Replace the "knn_grid" entry of `ret` by a "knn" entry read from the fitted estimator.

        :return: dict with "resolved" (the updated `ret`) and an empty "other"
        """
        fitted = clf.get_params()
        logger.info("Selected KNN Params are %s " % fitted)
        safe_del(ret, "knn_grid")

        uses_distance_weighting = fitted["weights"] == "distance"
        ret["knn"] = {
            "k": fitted["n_neighbors"],
            "distance_weighting": uses_distance_weighting,
            "algorithm": fitted["algorithm"],
            "p": fitted["p"],
            "leaf_size": fitted["leaf_size"],
        }
        return {"resolved": ret, "other": {}}


register_classification_algorithm(KNNClassification)


class LARSClassification(ClassicalPredictionAlgorithm):
    """LARS classification, backed by DkuLassoLarsClassifier."""
    algorithm = "LARS"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the LARS estimator from modeling_params["lars_grid"], with an empty hyperparameter space."""
        # LARS Grid is not a real grid: its values are passed straight to the estimator
        lars_settings = modeling_params["lars_grid"]
        empty_space = HyperparametersSpace({})
        lars = DkuLassoLarsClassifier(max_var=lars_settings["max_features"], K=lars_settings["K"])
        return TrainableModel(lars, hyperparameters_space=empty_space, supports_sample_weights=False)

    def actual_params(self, ret, clf, fit_params):
        """Return the modeling params untouched: nothing is read back from the fitted estimator."""
        return {"resolved": ret, "other": {}}


register_classification_algorithm(LARSClassification)


class NeuralNetworkClassification(ClassicalPredictionAlgorithm):
    """Multi-layer perceptron classification, backed by scikit-learn's MLPClassifier."""
    algorithm = "NEURAL_NETWORK"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the MLP estimator; the only searched hyperparameter is the hidden layer sizes."""
        search_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension={
                'hidden_layer_sizes': IntegerHyperparameterDimension(input_hp_space["layer_sizes"])
            }
        )

        # Settings whose name is the same in the input space and in the MLPClassifier constructor
        same_name_settings = [
            "activation", "solver", "alpha", "max_iter", "tol", "early_stopping", "validation_fraction",
            "beta_1", "beta_2", "epsilon", "learning_rate", "power_t", "momentum", "nesterovs_momentum",
            "shuffle", "learning_rate_init"
        ]
        mlp_kwargs = {name: input_hp_space[name] for name in same_name_settings}

        # Settings that need a translation
        if input_hp_space["auto_batch"]:
            mlp_kwargs["batch_size"] = "auto"
        else:
            mlp_kwargs["batch_size"] = input_hp_space["batch_size"]
        mlp_kwargs["random_state"] = input_hp_space["seed"]

        mlp = MLPClassifier(**mlp_kwargs)
        return TrainableModel(mlp, hyperparameters_space=search_space, supports_sample_weights=False)

    def actual_params(self, ret, clf, fit_params):
        """
        Add a "neural_network" entry holding the fitted layer sizes to `ret`.

        :return: dict with "resolved" (the updated `ret`) and an empty "other"
        """
        fitted = clf.get_params()
        logger.info("Neural Network Params are %s " % fitted)

        ret["neural_network"] = {"layer_sizes": fitted["hidden_layer_sizes"]}
        return {"resolved": ret, "other": {}}


register_classification_algorithm(NeuralNetworkClassification)


class XGBClassification(ClassicalPredictionAlgorithm):
    """XGBoost classification, built through the helpers of `dataiku.doctor.prediction.dku_xgboost`."""
    algorithm = "XGBOOST_CLASSIFICATION"

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """
        Build the XGBoost estimator and the hyperparameter space to search over.

        :param input_hp_space: dict definition of the XGBoost settings and hyperparameter dimensions
        :param modeling_params: modeling params, read for ["metrics"]["evaluationMetric"]
        :param core_params: core params, read for the GPU configuration and the prediction type
        :return: XGBoostTrainableModel wrapping the estimator
        :raises Exception: if the XGBoost helpers cannot be imported, or if n_estimators is not positive
        """
        try:
            from dataiku.doctor.prediction.dku_xgboost import instantiate_xgb_classifier, expand_tree_method_for_xgboost
        except Exception:
            # Not a bare "except:" so that KeyboardInterrupt / SystemExit are not turned into a
            # misleading "failed to load" error
            logger.error("Failed to load xgboost package")
            traceback.print_exc()
            raise Exception("Failed to load XGBoost package")

        n_estimators = input_hp_space['n_estimators']
        if n_estimators <= 0:  # xgboost does not fail gracefully then
            raise Exception("The number of estimators must be a positive number")

        nthread = safe_positive_int(input_hp_space['nthread'])
        # Without imputation, missing values are represented by NaN
        missing = input_hp_space['missing'] if input_hp_space['impute_missing'] else np.nan

        hp_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class={
                "max_depth": IntegerHyperparameterDimension,
                "learning_rate": FloatHyperparameterDimension,
                "gamma": FloatHyperparameterDimension,
                "min_child_weight": FloatHyperparameterDimension,
                "max_delta_step": FloatHyperparameterDimension,
                "subsample": FloatHyperparameterDimension,
                "colsample_bytree": FloatHyperparameterDimension,
                "colsample_bylevel": FloatHyperparameterDimension,
                "booster": CategoricalHyperparameterDimension
                # no grid for "objective" in classification
            },
            hp_names_to_dimension={
                # "alpha" / "lambda" in the input space map to the estimator's reg_alpha / reg_lambda
                "reg_alpha": FloatHyperparameterDimension(input_hp_space['alpha']),
                "reg_lambda": FloatHyperparameterDimension(input_hp_space['lambda'])
            }
        )

        gpu_config = get_gpu_config_from_core_params(core_params)

        estimator = instantiate_xgb_classifier(
            n_estimators=n_estimators,
            silent=0,
            n_jobs=nthread,
            scale_pos_weight=input_hp_space['scale_pos_weight'],
            base_score=input_hp_space['base_score'],
            random_state=input_hp_space['seed'],
            missing=missing,
            tree_method=expand_tree_method_for_xgboost(input_hp_space, gpu_config),
            class_weight=None
        )

        device = XGBOOSTGpuCapability.get_device(gpu_config)
        if device != "cpu":
            estimator.set_params(gpu_id=get_single_gpu_id_from_gpu_device(device))

        prediction_type = core_params[doctor_constants.PREDICTION_TYPE]
        is_causal = prediction_type == doctor_constants.CAUSAL_BINARY_CLASSIFICATION

        return XGBoostTrainableModel(
            estimator,
            hyperparameters_space=hp_space,
            # Early stopping is never enabled for causal binary classification
            is_early_stopping_enabled=(not is_causal and input_hp_space['enable_early_stopping']),
            early_stopping_rounds=input_hp_space['early_stopping_rounds'],
            evaluation_metric_name=modeling_params["metrics"]["evaluationMetric"],
            prediction_type=prediction_type
        )

    def actual_params(self, ret, clf, fit_params):
        """
        Replace the "xgboost" entry of `ret` by the parameters read from the fitted estimator.

        :param ret: modeling params dict, updated in place
        :param clf: fitted estimator exposing get_params()
        :param fit_params: fit parameters, read for 'early_stopping_rounds'
        :return: dict with "resolved" (the updated `ret`) and an empty "other"
        """
        amp = {"resolved": ret, "other": {}}
        params = clf.get_params()

        # We serialize np.nan missing parameter as None in actual params
        missing = None if pd.isna(params["missing"]) else params["missing"]

        # "objective" is read with .get(), so it may be absent: only rewrite it when there is a value
        objective = params.get("objective")
        if objective is not None:
            objective = objective.replace(":", "_")

        early_stopping_rounds = fit_params.get('early_stopping_rounds')

        logger.info("Selected XGBoost Params are %s " % params)
        safe_del(ret, "xgboost")
        ret["xgboost"] = {
            "max_depth": params["max_depth"],
            "learning_rate": params["learning_rate"],
            "n_estimators": params["n_estimators"],
            # Any non-positive n_jobs is reported as -1
            "nthread": params["n_jobs"] if params["n_jobs"] > 0 else -1,
            "gamma": params["gamma"],
            "min_child_weight": params["min_child_weight"],
            "max_delta_step": params["max_delta_step"],
            "subsample": params["subsample"],
            "colsample_bytree": params["colsample_bytree"],
            "colsample_bylevel": params["colsample_bylevel"],
            "alpha": params["reg_alpha"],
            "lambda": params["reg_lambda"],
            "seed": params["random_state"],
            "impute_missing": missing is not None,
            "missing": missing,
            "base_score": params["base_score"],
            "scale_pos_weight": params["scale_pos_weight"],
            "enable_early_stopping": early_stopping_rounds is not None,
            "early_stopping_rounds": early_stopping_rounds,
            "booster": params.get("booster"),
            "objective": objective,
        }
        return amp


register_classification_algorithm(XGBClassification)

# The three algorithms below are imported from other modules (see the imports at the top of this file)
# and are registered the same way as the classes defined in this file.
register_classification_algorithm(PluginPredictionAlgorithm)

register_classification_algorithm(LightGBMClassification)

register_classification_algorithm(DeepNeuralNetworkClassification)


def classification_fit_ensemble(modeling_params, core_params, data, target, sample_weight=None):
    """
    Fit the ensemble described by modeling_params["ensemble_params"] on (data, target).

    Returns (clf, actual_params, prepared_train_X, initial_intrinsic_perf_data) where:
      - clf is the value returned by the ensemble's fit()
      - actual_params only holds the modeling params under "resolved"
      - prepared_train_X is `data`, returned as received
      - initial_intrinsic_perf_data is an empty dict
    """
    # Imported here to avoid circular imports
    from dataiku.doctor.prediction.ensembles import EnsembleRegressor

    logger.info("Fitting ensemble model")
    # NOTE(review): EnsembleRegressor is used on the classification path - presumably the class
    # handles both prediction types; confirm
    ensemble = EnsembleRegressor(modeling_params["ensemble_params"], core_params)
    # The target is cast to int before fitting
    fitted_ensemble = ensemble.fit(data, target.astype(int), sample_weight=sample_weight)

    return fitted_ensemble, {"resolved": modeling_params}, data, {}


def classification_fit(modeling_params, core_params, transformed_train, transformed_test=None, model_folder_context=None,
                       gridsearch_done_fn=None, target_map=None, with_sample_weight=False, with_class_weight=True,
                       calibration_method=None, calibration_ratio=None, calibrate_on_test=False, monotonic_cst=None):
    """
    Run the hyperparameter search for modeling_params['algorithm'], refit the best estimator and optionally
    calibrate its probabilities.

    Returns (clf, actual_params, prepared_train_X, initial_intrinsic_perf_data)
    Extracts the best estimator for grid search ones

    :param modeling_params: modeling params; 'algorithm' must be a key of CLASSIFICATION_ALGORITHMS
    :param core_params: core params, forwarded to the search runner
    :param transformed_train: mapping read for "TRAIN", "target", "UNPROCESSED" and, when
                              with_sample_weight is set, "weight"
    :param transformed_test: mapping read for "TRAIN", "target" (and "weight") when calibrating on the test set
    :param model_folder_context: forwarded to the search runner
    :param gridsearch_done_fn: optional callback invoked without arguments once the best estimator is known
    :param target_map: forwarded to the search runner
    :param with_sample_weight: whether to use transformed_train["weight"] as sample weights
    :param with_class_weight: whether to compute class weights from the train target
    :param calibration_method: calibration is performed when this is (case-insensitively) one of
                               doctor_constants.SIGMOID / doctor_constants.ISOTONIC
    :param calibration_ratio: share of the train set held out for calibration when not calibrating on the test
                              set; defaults to doctor_constants.DEFAULT_CALIBRATION_DATA_RATIO
    :param calibrate_on_test: whether to calibrate on transformed_test instead of a train subset
    :param monotonic_cst: forwarded to the search context
    :raises Exception: if the algorithm is not registered
    :raises ValueError: if calibration on the test set is requested without a test set

    Note on calibration:
      - calibrate_on_test: calibrate_on_test=True requires a relevant transformed_test dataset. Thus, in the following
        cases, because the training is performed on the full dataset, we enforce calibrate_on_test=False:
          - the training of the final model with k-fold cross-test evaluation
          - the training of the model with the option "TRAIN_FULL_ONLY" in training recipes.
      - calibration_ratio is only relevant when calibrate_on_test=False. When calibrate_on_test=True, the calibration is
        performed on the full transformed_test dataset.
    """
    # Resolve the calibration settings first so that an unusable configuration fails before the
    # (potentially long) hyperparameter search instead of after it
    calibrate_proba = calibration_method is not None and calibration_method.upper() in [doctor_constants.SIGMOID, doctor_constants.ISOTONIC]
    if calibrate_proba and calibrate_on_test and transformed_test is None:
        raise ValueError("Probabilities calibration on the test set was requested but no test set was provided")

    train_X = transformed_train["TRAIN"]
    column_labels = list(train_X.columns())
    train_y = transformed_train["target"].astype(int)
    train_X, is_sparse = prepare_multiframe(train_X, modeling_params)

    algorithm_name = modeling_params['algorithm']
    if algorithm_name not in CLASSIFICATION_ALGORITHMS:
        raise Exception("Algorithm not available in Python: %s" % algorithm_name)
    algorithm = CLASSIFICATION_ALGORITHMS[algorithm_name]

    hyperparameter_search_runner = algorithm.get_search_runner(core_params, modeling_params, column_labels=column_labels,
                                                               model_folder_context=model_folder_context,
                                                               target_map=target_map,
                                                               unprocessed=transformed_train["UNPROCESSED"])

    train_w = transformed_train["weight"] if with_sample_weight else None
    class_weight_dict = get_class_weight_dict(train_y) if with_class_weight else None

    groups = get_groups_for_hp_search_cv(modeling_params, transformed_train)

    # grid searcher figures out whether or not the algorithm supports sample weights
    hyperparameter_search_runner.initialize_search_context(train_X, train_y,
                                                           groups=groups,
                                                           sample_weight=train_w,
                                                           class_weight=class_weight_dict,
                                                           monotonic_cst=monotonic_cst)
    clf = hyperparameter_search_runner.get_best_estimator()

    if gridsearch_done_fn:
        gridsearch_done_fn()

    # save a copy of train_X as prepared_X for the final output of classification_fit
    # (train_X may be rebound to a subset below when calibrating on a split of the train set)
    prepared_X = train_X[::]

    if calibrate_proba:
        logger.info("Performing probabilities calibration")
        if calibrate_on_test:
            logger.info("Calibrating probabilities on the TEST set")
            calib_X = transformed_test["TRAIN"]
            calib_y = transformed_test["target"].astype(int)
            if with_sample_weight:
                calib_w = transformed_test["weight"].astype(float)
            else:
                calib_w = None
            calib_X = prepare_multiframe_as_sparse_if_needed(calib_X, is_sparse)
        else:
            logger.info("Calibrating probabilities on a subset of the TRAIN set")
            # For calibrated models, if the user chooses not to calibrate on the test set,
            # train_X will be a 80% split of the original train_X (the remaining 20%
            # is used to compute the calibration parameters)
            # Note that the 80/20 split can be modified by the user
            if calibration_ratio is None:
                calibration_ratio = doctor_constants.DEFAULT_CALIBRATION_DATA_RATIO
            train_ratio = 1 - calibration_ratio
            if with_sample_weight:
                train_X, calib_X, train_y, calib_y, train_w, calib_w = train_test_split(train_X, train_y, train_w, train_size=train_ratio, random_state=1234)
            else:
                train_X, calib_X, train_y, calib_y = train_test_split(train_X, train_y, train_size=train_ratio, random_state=1234)
                calib_w = None

    dump_pretrain_info(clf, train_X, train_y, train_w, (calibrate_proba and not calibrate_on_test))

    # Final fit, on the data left after the optional calibration split
    final_fit_parameters = hyperparameter_search_runner.get_final_fit_parameters(sample_weight=train_w)
    clf.fit(train_X, train_y, **final_fit_parameters)

    if calibrate_proba:
        # cv="prefit": the estimator fitted above is passed as is, and the calibrator is fitted on the calibration data
        calibrated_clf = dku_calibrated_classifier_cv(clf, cv="prefit", method=calibration_method.lower())
        calibrated_clf.fit(calib_X, calib_y, sample_weight=calib_w)
        clf = calibrated_clf

    initial_intrinsic_perf_data = get_initial_intrinsic_perf_data(train_X, is_sparse)
    if not hyperparameter_search_runner.search_skipped():
        initial_intrinsic_perf_data.update(hyperparameter_search_runner.get_score_info())

    # get_actual_params performs the translation sklearn params (after refit) (e.g. n_estimators)
    # to DSS(raw) params (e.g rf_n_estimators)
    # When calibrated, the params are read from the wrapped base estimator rather than from the calibrator
    fitted_estimator = get_base_estimator(clf) if calibrate_proba else clf
    actual_params = algorithm.get_actual_params(modeling_params, fitted_estimator, final_fit_parameters)
    logger.info("Output params are %s" % actual_params)

    return clf, actual_params, prepared_X, initial_intrinsic_perf_data
