from dataiku.doctor.prediction.common import CategoricalHyperparameterDimension
from dataiku.doctor.prediction.common import FloatHyperparameterDimension
from dataiku.doctor.prediction.common import HyperparametersSpace
from dataiku.doctor.prediction.common import IntegerHyperparameterDimension
from dataiku.doctor.prediction.common import ClassicalPredictionAlgorithm
from dataiku.doctor.prediction.common import safe_positive_int
from dataiku.doctor.prediction.lightgbm_trainable_model import LightGBMTrainableModel
from dataiku.doctor.utils import doctor_constants


class LightGBMHyperparametersSpace(HyperparametersSpace):
    """Hyperparameter space for LightGBM that also carries fixed bagging settings.

    ``subsample`` and ``subsample_freq`` are not searched dimensions of this
    space: they are constant values, injected into each hyperparameter point
    by :meth:`enrich_hyperparam_point` when bagging applies.
    """

    def __init__(self, space_definition, use_bagging, subsample, subsample_freq):
        super(LightGBMHyperparametersSpace, self).__init__(space_definition)

        # Fixed bagging configuration, consumed by enrich_hyperparam_point.
        self.use_bagging = use_bagging
        self.subsample = subsample
        self.subsample_freq = subsample_freq

    def enrich_hyperparam_point(self, point):
        """Add the fixed bagging parameters to ``point`` when they apply.

        Bagging parameters are only added when bagging is enabled AND the
        point's ``boosting_type`` is ``"gbdt"``: Gradient-based One-Side
        Sampling ("goss") is not compatible with bagging.

        :param point: mapping of hyperparameter name to value; it must contain
            ``"boosting_type"`` whenever bagging is enabled.
        :return: the same ``point`` object, modified in place.
        """
        # Guard clause: nothing to add unless bagging is on and we run gbdt.
        # ``use_bagging`` is checked first so ``boosting_type`` is only read
        # when bagging is enabled.
        if not self.use_bagging or point["boosting_type"] != "gbdt":
            return point

        point["subsample"] = self.subsample
        point["subsample_freq"] = self.subsample_freq
        return point


class _LightGBMPredictionAlgorithm(ClassicalPredictionAlgorithm):
    """Common base for the LightGBM prediction algorithms.

    Subclasses must implement :meth:`_create_estimator` to return the unfitted
    LightGBM estimator. This class builds the hyperparameter space, wraps the
    space and the estimator in a ``LightGBMTrainableModel``, and reports the
    resolved estimator parameters.
    """

    # Estimator parameters (keys of ``estimator.get_params()``) that
    # ``actual_params`` copies, in this order, into ``output["lightgbm"]``.
    _REPORTED_ESTIMATOR_PARAMS = (
        "boosting_type",
        "num_leaves",
        "max_depth",
        "learning_rate",
        "n_estimators",
        "subsample_for_bin",
        "objective",
        "min_split_gain",
        "min_child_weight",
        "min_child_samples",
        "subsample",
        "subsample_freq",
        "colsample_bytree",
        "reg_alpha",
        "reg_lambda",
        "random_state",
        "n_jobs",
        "importance_type",
    )

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the trainable LightGBM model for the given settings.

        :param input_hp_space: raw hyperparameter settings. Read for the
            searched dimensions and bagging settings (via ``_create_hp_space``),
            the estimator settings (via ``_create_estimator``), and
            ``"early_stopping"`` / ``"early_stopping_rounds"``.
        :param modeling_params: modeling settings; only
            ``["metrics"]["evaluationMetric"]`` is read here.
        :param core_params: core settings; only the prediction type is read.
        :return: a ``LightGBMTrainableModel``.
        """
        hp_space = self._create_hp_space(input_hp_space)
        estimator = self._create_estimator(input_hp_space)

        prediction_type = core_params[doctor_constants.PREDICTION_TYPE]
        # Early stopping is never enabled for causal prediction types,
        # regardless of the user's "early_stopping" setting.
        can_have_early_stopping = prediction_type not in doctor_constants.CAUSAL_PREDICTION_TYPES

        return LightGBMTrainableModel(
            estimator=estimator,
            hyperparameters_space=hp_space,
            is_early_stopping_enabled=can_have_early_stopping and input_hp_space["early_stopping"],
            early_stopping_rounds=input_hp_space["early_stopping_rounds"],
            evaluation_metric_name=modeling_params["metrics"]["evaluationMetric"],
            # Reuse the value read above instead of looking it up a second time.
            prediction_type=prediction_type,
        )

    def _create_estimator(self, input_hp_space):
        """Return the unfitted LightGBM estimator; implemented by subclasses."""
        raise NotImplementedError("must be implemented in subclasses")

    @staticmethod
    def _create_hp_space(input_hp_space):
        """Build the :class:`LightGBMHyperparametersSpace` for ``input_hp_space``.

        The names listed below become searched dimensions of the given types.
        ``use_bagging``, ``subsample`` and ``subsample_freq`` are passed as
        ``constructor_args``; presumably ``from_definition`` forwards them to
        ``LightGBMHyperparametersSpace.__init__``, whose signature matches
        (TODO confirm against HyperparametersSpace.from_definition).
        """
        return LightGBMHyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class={
                "boosting_type": CategoricalHyperparameterDimension,
                "num_leaves": IntegerHyperparameterDimension,
                "learning_rate": FloatHyperparameterDimension,
                "n_estimators": IntegerHyperparameterDimension,
                "min_split_gain": FloatHyperparameterDimension,
                "min_child_weight": FloatHyperparameterDimension,
                "min_child_samples": IntegerHyperparameterDimension,
                "colsample_bytree": FloatHyperparameterDimension,
                "reg_alpha": FloatHyperparameterDimension,
                "reg_lambda": FloatHyperparameterDimension
            },
            constructor_args={
                "use_bagging": input_hp_space["use_bagging"],
                "subsample": input_hp_space["subsample"],
                "subsample_freq": input_hp_space["subsample_freq"]
            }
        )

    def actual_params(self, output, estimator, fit_parameters):
        """Record the estimator's effective parameters under ``output["lightgbm"]``.

        :param output: dict that is mutated in place; its ``"lightgbm"`` key is
            overwritten.
        :param estimator: estimator exposing ``get_params()``; every name in
            ``_REPORTED_ESTIMATOR_PARAMS`` must be present in its result.
        :param fit_parameters: mapping of fit-time arguments. Early stopping is
            reported as enabled iff it contains ``"early_stopping_rounds"``.
        :return: ``{"resolved": output, "other": {}}``.
        """
        model_params = estimator.get_params()

        # Build the full dict before touching ``output`` so a missing key
        # leaves ``output`` unmodified.
        resolved = {name: model_params[name] for name in self._REPORTED_ESTIMATOR_PARAMS}
        resolved["early_stopping"] = "early_stopping_rounds" in fit_parameters
        resolved["early_stopping_rounds"] = fit_parameters.get("early_stopping_rounds")
        output["lightgbm"] = resolved

        return {
            "resolved": output,
            "other": {},
        }


class LightGBMClassification(_LightGBMPredictionAlgorithm):
    """LightGBM prediction algorithm backed by ``lightgbm.LGBMClassifier``."""

    algorithm = "LIGHTGBM_CLASSIFICATION"

    def _create_estimator(self, input_hp_space):
        """Return an unfitted ``LGBMClassifier`` built from ``input_hp_space``.

        Only ``"n_jobs"`` (sanitised with ``safe_positive_int``),
        ``"random_state"`` and ``"max_depth"`` are read from
        ``input_hp_space``; the remaining constructor settings are fixed.
        """
        # Imported lazily so lightgbm is only required when this algorithm runs.
        from lightgbm import LGBMClassifier

        return LGBMClassifier(
            n_jobs=safe_positive_int(input_hp_space["n_jobs"]),
            random_state=input_hp_space["random_state"],
            max_depth=input_hp_space["max_depth"],
            # Only matters when early stopping is used: LightGBM then evaluates
            # only the metric we provide, not the task's default metric.
            first_metric_only=True,
            # Feature importances are computed with LightGBM's "gain" mode.
            importance_type="gain"
        )


class LightGBMRegression(_LightGBMPredictionAlgorithm):
    """LightGBM prediction algorithm backed by ``lightgbm.LGBMRegressor``."""

    algorithm = "LIGHTGBM_REGRESSION"

    def _create_estimator(self, input_hp_space):
        """Return an unfitted ``LGBMRegressor`` built from ``input_hp_space``.

        Only ``"n_jobs"`` (sanitised with ``safe_positive_int``),
        ``"random_state"`` and ``"max_depth"`` are read from
        ``input_hp_space``; the remaining constructor settings are fixed.
        """
        # Imported lazily so lightgbm is only required when this algorithm runs.
        from lightgbm import LGBMRegressor

        return LGBMRegressor(
            n_jobs=safe_positive_int(input_hp_space["n_jobs"]),
            random_state=input_hp_space["random_state"],
            max_depth=input_hp_space["max_depth"],
            # Only matters when early stopping is used: LightGBM then evaluates
            # only the metric we provide, not the task's default metric.
            first_metric_only=True,
            # Feature importances are computed with LightGBM's "gain" mode.
            importance_type="gain"
        )
