import logging

from dataiku.core import doctor_constants
from dataiku.doctor.crossval.search_runner import TimeseriesForecastingSearchRunner
from dataiku.doctor.prediction.common import get_input_hyperparameter_space, ETSHyperparametersSpace
from dataiku.doctor.prediction.common import TrainableModel
from dataiku.doctor.prediction.common import TabularPredictionAlgorithm
from dataiku.doctor.prediction.common import HyperparametersSpace
from dataiku.doctor.prediction.common import GluonTSTransformerHyperparametersSpace
from dataiku.doctor.prediction.common import SeasonalTrendLoessHyperparametersSpace
from dataiku.doctor.prediction.common import IntegerHyperparameterDimension
from dataiku.doctor.prediction.common import OddIntegerHyperparameterDimension
from dataiku.doctor.prediction.common import FloatHyperparameterDimension
from dataiku.doctor.prediction.common import CategoricalHyperparameterDimension
from dataiku.doctor.prediction.custom_scoring import get_custom_evaluation_metric
from dataiku.doctor.prediction.metric import TIMESERIES_METRICS_NAME_TO_FIELD_NAME
from dataiku.doctor.prediction.metric import CUSTOM
from dataiku.doctor.timeseries.preparation.resampling.utils import get_frequency, get_monthly_day_alignment
from dataiku.doctor.timeseries.utils import prefix_custom_metric_name
from dataiku.core.doctor_constants import TIMESERIES_FORECAST
from dataiku.doctor.utils.gpu_execution import get_gpu_config_from_core_params, GluonTSMXNetGPUCapability, get_default_gpu_params, GluonTSTorchGPUCapability

logger = logging.getLogger(__name__)


class TimeseriesForecastingAlgorithm(TabularPredictionAlgorithm):
    """Base class of the time series forecasting algorithms.

    Subclasses are expected to provide ``model_from_params`` (returning a
    ``TrainableModel``) and ``actual_params``, to override
    ``get_min_size_for_scoring`` / ``get_min_size_for_training``, and to adjust
    the class-level capability flags below when they differ from the defaults.
    """

    class ModelRefit:
        # Allowed values of REFIT_FOR_SCORING, interpreted by should_fit_before_predict():
        # MANDATORY -> always fit before predicting,
        # SUPPORTED -> fit before predicting only when explicitly forced,
        # NOT_SUPPORTED -> never fit before predicting.
        MANDATORY = "MANDATORY"
        SUPPORTED = "SUPPORTED"
        NOT_SUPPORTED = "NOT_SUPPORTED"

    REFIT_FOR_SCORING = ModelRefit.NOT_SUPPORTED
    USE_GLUON_TS = False
    SUPPORTS_EXTERNAL_FEATURES = False
    SUPPORTS_MODEL_COEFFICIENTS = False
    SUPPORTS_INFORMATION_CRITERIA = False
    # SUPPORTS_EXTERNAL_FEATURES to be synced with:
    # platypus -> static/dataiku/js/analysis/prediction/timeseries/settings.js L.1678
    # java -> com/dataiku/dip/analysis/model/prediction/TimeseriesForecastingModelDetails.java L.43

    @staticmethod
    def build(algorithm):
        """Return the entry registered for ``algorithm`` in TIMESERIES_FORECASTING_ALGORITHMS_MAP.

        :raises KeyError: if ``algorithm`` is not a registered forecasting algorithm.
        """
        return TIMESERIES_FORECASTING_ALGORITHMS_MAP[algorithm]

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Return the minimum number of timesteps a series needs in order to be scored.

        :raises NotImplementedError: every concrete algorithm must override this.
        """
        raise NotImplementedError(
            "{} does not implement get_min_size_for_scoring".format(self.__class__.__name__)
        )

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Return the minimum number of timesteps a series needs in order to be trained on.

        :raises NotImplementedError: every concrete algorithm must override this.
        """
        raise NotImplementedError(
            "{} does not implement get_min_size_for_training".format(self.__class__.__name__)
        )

    def get_max_used_timesteps_for_scoring(self, modeling_params, prediction_length, frequency):
        """Return how many trailing timesteps scoring uses.

        By default this is the same as get_min_size_for_scoring(). ``frequency`` is
        unused here, but subclasses may need it (e.g. to account for lag features).
        """
        return self.get_min_size_for_scoring(modeling_params, prediction_length)

    def get_search_runner(self, split_handler, core_params, modeling_params, model_folder_context=None):
        """Build the hyperparameter search runner for this algorithm.

        :param split_handler: forwarded to TimeseriesForecastingSearchRunner.
        :param core_params: ML task core params; must contain the prediction length.
        :param modeling_params: must contain "algorithm" and "metrics";
            "grid_search_params" is optional (its "seed" defaults to 0).
        :param model_folder_context: optional, forwarded to the search runner.
        :return: a configured TimeseriesForecastingSearchRunner.
        """
        algorithm = modeling_params["algorithm"]
        # Lazy %-style arguments: modeling_params may be large, so only format it
        # when INFO logging is actually enabled. The rendered message is unchanged.
        logger.info("Create CLF from params: %s for algorithm %s", modeling_params, algorithm)

        input_hp_space = get_input_hyperparameter_space(modeling_params, algorithm)
        trainable_model = self.model_from_params(input_hp_space, modeling_params, core_params)

        hyperparameter_search_params = modeling_params.get("grid_search_params", {})
        trainable_model.hyperparameters_space.set_random_state(hyperparameter_search_params.get("seed", 0))

        search_settings = self.get_search_settings(hyperparameter_search_params, trainable_model)

        metrics_params = modeling_params["metrics"]
        evaluation_metric = metrics_params["evaluationMetric"]
        metric_sign = -1  # built-in timeseries metrics are all errors (i.e. lower is better)
        if evaluation_metric == CUSTOM.name:
            evaluation_metric_name = prefix_custom_metric_name(metrics_params["customEvaluationMetricName"])
            # A custom metric declares its own optimisation direction.
            metric_sign = 1 if get_custom_evaluation_metric(metrics_params)["greaterIsBetter"] else -1
        else:
            evaluation_metric_name = TIMESERIES_METRICS_NAME_TO_FIELD_NAME[evaluation_metric].get_field(
                prediction_type=TIMESERIES_FORECAST
            )

        return TimeseriesForecastingSearchRunner(
            trainable_model=trainable_model,
            search_settings=search_settings,
            split_handler=split_handler,
            fit_before_predict=self.should_fit_before_predict(),
            min_timeseries_size_for_training=self.get_min_size_for_training(modeling_params, core_params[doctor_constants.PREDICTION_LENGTH]),
            model_folder_context=model_folder_context,
            evaluation_metric=evaluation_metric_name,
            metric_sign=metric_sign
        )

    def should_fit_before_predict(self, force=False):
        """Return whether the model must be fit right before predicting.

        :param force: request a fit for algorithms where refitting is only SUPPORTED.
        :return: True if REFIT_FOR_SCORING is MANDATORY, or if it is SUPPORTED and
            ``force`` is set; False otherwise.
        """
        if self.REFIT_FOR_SCORING == TimeseriesForecastingAlgorithm.ModelRefit.MANDATORY:
            logger.info("Fit before predict mandatory for algorithm: %s", self.__class__.__name__)
            return True
        elif self.REFIT_FOR_SCORING == TimeseriesForecastingAlgorithm.ModelRefit.SUPPORTED and force:
            logger.info("Fit before predict requested for algorithm: %s", self.__class__.__name__)
            return True
        return False


class GluonTSTrivialIdentity(TimeseriesForecastingAlgorithm):
    """Trivial-identity forecaster backed by a GluonTS estimator wrapper.

    It has no tuned hyperparameter dimensions. Training and scoring both only
    require ``prediction_length`` timesteps.
    """
    USE_GLUON_TS = True

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Wrap a DkuTrivialIdentityEstimator in a TrainableModel (``modeling_params`` is unused)."""
        # Imported lazily, like every other model module in this file.
        from dataiku.doctor.timeseries.models.gluonts.trivial_identity import DkuTrivialIdentityEstimator
        identity_estimator = DkuTrivialIdentityEstimator(
            frequency=get_frequency(core_params),
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifiers=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            monthly_day_alignment=get_monthly_day_alignment(core_params),
        )
        return TrainableModel(
            identity_estimator,
            hyperparameters_space=HyperparametersSpace.from_definition(input_hp_space),
        )

    def actual_params(self, ret, model, fit_params):
        """Store an empty parameter section in ``ret`` and wrap it as the resolved params."""
        ret["trivial_identity_timeseries_params"] = dict()
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Scoring needs exactly one forecast horizon of history."""
        return prediction_length

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Training needs exactly one forecast horizon of history."""
        return prediction_length


class GluonTSSeasonalNaive(TimeseriesForecastingAlgorithm):
    """Seasonal-naive forecaster. Its only tuned hyperparameter is the integer ``season_length``."""
    USE_GLUON_TS = True

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Wrap a DkuSeasonalNaiveSEstimator in a TrainableModel searching over ``season_length``."""
        from dataiku.doctor.timeseries.models.gluonts.seasonal_naive import DkuSeasonalNaiveSEstimator
        naive_estimator = DkuSeasonalNaiveSEstimator(
            frequency=get_frequency(core_params),
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifiers=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            monthly_day_alignment=get_monthly_day_alignment(core_params),
        )

        dimension_classes = dict(season_length=IntegerHyperparameterDimension)
        search_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=dimension_classes,
        )
        return TrainableModel(naive_estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Record the season length the model was built with."""
        chosen_season_length = model.get_params()["season_length"]
        ret["seasonal_naive_timeseries_params"] = dict(season_length=chosen_season_length)
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Scoring needs at least one horizon and at least one (resolved) season of history."""
        resolved_season_length = modeling_params["seasonal_naive_timeseries_params"]["season_length"]
        return max(prediction_length, resolved_season_length)

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Training must fit the largest season length in the search space, and at least one horizon."""
        return max(prediction_length, get_parameter_max_value(modeling_params, "season_length"))


class GluonTSNPTSForecaster(TimeseriesForecastingAlgorithm):
    """Non-parametric time series (NPTS) forecaster from GluonTS. Supports external features."""
    USE_GLUON_TS = True
    SUPPORTS_EXTERNAL_FEATURES = True

    # Arbitrary but "reasonable" lower bound on the history needed for training and scoring.
    _MIN_HISTORY_LENGTH = 10
    # Value passed to get_used_context_length as its fallback
    # (presumably applied when full_context is on -- TODO confirm).
    _CONTEXT_LENGTH_FALLBACK = 1100

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Wrap a DkuNPTSEstimator in a TrainableModel over the NPTS hyperparameter space."""
        from dataiku.doctor.timeseries.models.gluonts.npts import DkuNPTSEstimator
        npts_estimator = DkuNPTSEstimator(
            frequency=get_frequency(core_params),
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifiers=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            full_context=input_hp_space["full_context"],
            use_seasonal_model=input_hp_space["use_seasonal_model"],
            use_default_time_features=input_hp_space["use_default_time_features"],
            seed=input_hp_space["seed"],
            monthly_day_alignment=get_monthly_day_alignment(core_params),
        )

        dimension_classes = dict(
            kernel_type=CategoricalHyperparameterDimension,
            exp_kernel_weights=FloatHyperparameterDimension,
            feature_scale=FloatHyperparameterDimension,
        )
        # context_length only becomes an integer search dimension when the full history is not used.
        if not input_hp_space["full_context"]:
            dimension_classes["context_length"] = IntegerHyperparameterDimension

        search_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=dimension_classes,
        )
        return TrainableModel(npts_estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Record the NPTS parameters the model was built with."""
        model_params = model.get_params()
        ret["gluonts_npts_timeseries_params"] = dict(
            context_length=model_params["context_length"],
            kernel_type=model_params["kernel_type"],
            exp_kernel_weights=model_params["exp_kernel_weights"],
            use_seasonal_model=model_params["use_seasonal_model"],
            use_default_time_features=model_params["use_default_time_features"],
            feature_scale=model_params["feature_scale"],
            seed=model_params["seed"],
        )
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """At least one forecast horizon, and never fewer than _MIN_HISTORY_LENGTH timesteps."""
        return max(prediction_length, self._MIN_HISTORY_LENGTH)

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """At least one forecast horizon, and never fewer than _MIN_HISTORY_LENGTH timesteps."""
        return max(prediction_length, self._MIN_HISTORY_LENGTH)

    def get_max_used_timesteps_for_scoring(self, modeling_params, prediction_length, frequency):
        """Return the context length the model uses. ``prediction_length`` and ``frequency`` are ignored."""
        return get_used_context_length(modeling_params, "gluonts_npts_timeseries_params", self._CONTEXT_LENGTH_FALLBACK)


class GluonTSTorchSimpleFeedForward(TimeseriesForecastingAlgorithm):
    """Simple feed-forward network forecaster built on the PyTorch flavour of GluonTS."""
    USE_GLUON_TS = True

    @staticmethod
    def _resolve_torch_gpu_config(core_params):
        """Return the GPU config from ``core_params``, hiding every CUDA device if the GPU must not be used.

        Must run before the torch model module is imported, because importing some
        gluonts models initialises a CUDA process.
        """
        config = get_gpu_config_from_core_params(core_params)
        if GluonTSTorchGPUCapability.should_use_gpu(config):
            return config
        # Since gluonts v0.15, torch predictors grab a GPU on their own even when the
        # trainer was built on CPU: https://github.com/awslabs/gluonts/issues/3208
        GluonTSTorchGPUCapability.disable_all_cuda_devices()
        # With a mismatched torch version the `is_gpu_available` check fails once and then
        # passes on later calls, so pin the CPU decision here rather than re-checking.
        config["params"]["useGpu"] = False
        return config

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Wrap a torch DkuSimpleFeedForwardEstimator in a TrainableModel."""
        gpu_config = self._resolve_torch_gpu_config(core_params)

        from dataiku.doctor.timeseries.models.gluonts.torch.simple_feed_forward import DkuSimpleFeedForwardEstimator

        feed_forward_estimator = DkuSimpleFeedForwardEstimator(
            frequency=get_frequency(core_params),
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifiers=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            full_context=input_hp_space["full_context"],
            num_hidden_dimensions=input_hp_space["num_hidden_dimensions"],
            batch_size=input_hp_space["batch_size"],
            epochs=input_hp_space["epochs"],
            auto_num_batches_per_epoch=input_hp_space["auto_num_batches_per_epoch"],
            num_batches_per_epoch=input_hp_space["num_batches_per_epoch"],
            gpu_config=gpu_config,
            seed=input_hp_space["seed"],
            weight_decay=input_hp_space["weight_decay"],
            monthly_day_alignment=get_monthly_day_alignment(core_params),
        )

        dimension_classes = dict(
            distr_output=CategoricalHyperparameterDimension,
            batch_normalization=CategoricalHyperparameterDimension,
            learning_rate=FloatHyperparameterDimension,
            weight_decay=FloatHyperparameterDimension,
        )
        # context_length is only searched over when the full history is not used.
        if not input_hp_space["full_context"]:
            dimension_classes["context_length"] = IntegerHyperparameterDimension

        search_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=dimension_classes,
        )
        return TrainableModel(feed_forward_estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Record the parameters the torch feed-forward model was built with."""
        model_params = model.get_params()
        ret["gluonts_torch_simple_feedforward_timeseries_params"] = dict(
            context_length=model_params["context_length"],
            distr_output=model_params["distr_output"],
            batch_normalization=model_params["batch_normalization"],
            num_hidden_dimensions=model_params["num_hidden_dimensions"],
            learning_rate=model_params["learning_rate"],
            batch_size=model_params["batch_size"],
            epochs=model_params["epochs"],
            num_batches_per_epoch=model_params["num_batches_per_epoch"],
            seed=model_params["seed"],
            weight_decay=model_params["weight_decay"],
        )
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Return the context length used at scoring time (``prediction_length`` is the fallback)."""
        return get_used_context_length(modeling_params, "gluonts_torch_simple_feedforward_timeseries_params", prediction_length)

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Training needs the largest searched context plus one forecast horizon."""
        return get_max_context_length(modeling_params, prediction_length) + prediction_length

class GluonTSTorchDeepAR(TimeseriesForecastingAlgorithm):
    """DeepAR forecaster built on the PyTorch flavour of GluonTS."""
    USE_GLUON_TS = True

    @staticmethod
    def _resolve_torch_gpu_config(core_params):
        """Return the GPU config from ``core_params``, hiding every CUDA device if the GPU must not be used.

        Must run before the torch model module is imported, because importing some
        gluonts models initialises a CUDA process.
        """
        config = get_gpu_config_from_core_params(core_params)
        if GluonTSTorchGPUCapability.should_use_gpu(config):
            return config
        # Since gluonts v0.15, torch predictors grab a GPU on their own even when the
        # trainer was built on CPU: https://github.com/awslabs/gluonts/issues/3208
        GluonTSTorchGPUCapability.disable_all_cuda_devices()
        # With a mismatched torch version the `is_gpu_available` check fails once and then
        # passes on later calls, so pin the CPU decision here rather than re-checking.
        config["params"]["useGpu"] = False
        return config

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Wrap a torch DkuDeepAREstimator in a TrainableModel."""
        gpu_config = self._resolve_torch_gpu_config(core_params)

        from dataiku.doctor.timeseries.models.gluonts.torch.deepar import DkuDeepAREstimator
        deepar_estimator = DkuDeepAREstimator(
            frequency=get_frequency(core_params),
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifiers=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            use_timeseries_identifiers_as_features=input_hp_space.get("use_timeseries_identifiers_as_features", False),
            full_context=input_hp_space["full_context"],
            scaling=input_hp_space["scaling"],
            num_parallel_samples=input_hp_space["num_parallel_samples"],
            batch_size=input_hp_space["batch_size"],
            epochs=input_hp_space["epochs"],
            auto_num_batches_per_epoch=input_hp_space["auto_num_batches_per_epoch"],
            num_batches_per_epoch=input_hp_space["num_batches_per_epoch"],
            gpu_config=gpu_config,
            seed=input_hp_space["seed"],
            weight_decay=input_hp_space["weight_decay"],
            patience=input_hp_space["patience"],
            monthly_day_alignment=get_monthly_day_alignment(core_params),
        )

        dimension_classes = dict(
            learning_rate=FloatHyperparameterDimension,
            num_layers=IntegerHyperparameterDimension,
            num_cells=IntegerHyperparameterDimension,
            dropout_rate=FloatHyperparameterDimension,
            distr_output=CategoricalHyperparameterDimension,
            weight_decay=FloatHyperparameterDimension,
        )
        # context_length is only searched over when the full history is not used.
        if not input_hp_space["full_context"]:
            dimension_classes["context_length"] = IntegerHyperparameterDimension

        search_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=dimension_classes,
        )
        return TrainableModel(deepar_estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Record the parameters the torch DeepAR model was built with."""
        model_params = model.get_params()
        ret["gluonts_torch_deepar_timeseries_params"] = dict(
            context_length=model_params["context_length"],
            distr_output=model_params["distr_output"],
            use_timeseries_identifiers_as_features=model_params.get("use_timeseries_identifiers_as_features", False),
            num_layers=model_params["num_layers"],
            num_cells=model_params["num_cells"],
            dropout_rate=model_params["dropout_rate"],
            scaling=model_params["scaling"],
            num_parallel_samples=model_params["num_parallel_samples"],
            learning_rate=model_params["learning_rate"],
            batch_size=model_params["batch_size"],
            epochs=model_params["epochs"],
            num_batches_per_epoch=model_params["num_batches_per_epoch"],
            seed=model_params["seed"],
            weight_decay=model_params["weight_decay"],
            patience=model_params["patience"],
        )
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Return the context length used at scoring time (``prediction_length`` is the fallback)."""
        return get_used_context_length(modeling_params, "gluonts_torch_deepar_timeseries_params", prediction_length)

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Training needs the largest searched context plus one forecast horizon."""
        return get_max_context_length(modeling_params, prediction_length) + prediction_length

    def get_max_used_timesteps_for_scoring(self, modeling_params, prediction_length, frequency):
        """Scoring history: the used context plus the largest lag gluonts derives for ``frequency``."""
        # Local import, so the forecasting ML task does not fail outright when gluonts is not installed.
        from gluonts.time_feature import get_lags_for_frequency
        context_steps = self.get_min_size_for_scoring(modeling_params, prediction_length)
        largest_lag = max(get_lags_for_frequency(freq_str=frequency))
        return context_steps + largest_lag


class GluonTSSimpleFeedForward(TimeseriesForecastingAlgorithm):
    """Simple feed-forward network forecaster built on the MXNet flavour of GluonTS."""
    USE_GLUON_TS = True

    @staticmethod
    def _init_mxnet_gpu_config(core_params):
        """Return the GPU config, after initialising the CUDA visible devices from its GPU list.

        Falls back to get_default_gpu_params() when the config has no "params" entry.
        """
        config = get_gpu_config_from_core_params(core_params)
        gpu_params = config.get("params", get_default_gpu_params())
        GluonTSMXNetGPUCapability.init_cuda_visible_devices(gpu_params.get("gpuList"))
        return config

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Wrap an MXNet DkuSimpleFeedForwardEstimator in a TrainableModel."""
        gpu_config = self._init_mxnet_gpu_config(core_params)
        # Visible devices must be set before this import: importing some gluonts models
        # initialises a CUDA process.
        from dataiku.doctor.timeseries.models.gluonts.mxnet.simple_feed_forward import DkuSimpleFeedForwardEstimator

        feed_forward_estimator = DkuSimpleFeedForwardEstimator(
            frequency=get_frequency(core_params),
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifiers=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            full_context=input_hp_space["full_context"],
            num_hidden_dimensions=input_hp_space["num_hidden_dimensions"],
            num_parallel_samples=input_hp_space["num_parallel_samples"],
            batch_size=input_hp_space["batch_size"],
            epochs=input_hp_space["epochs"],
            auto_num_batches_per_epoch=input_hp_space["auto_num_batches_per_epoch"],
            num_batches_per_epoch=input_hp_space["num_batches_per_epoch"],
            gpu_config=gpu_config,
            seed=input_hp_space["seed"],
            monthly_day_alignment=get_monthly_day_alignment(core_params),
        )

        dimension_classes = dict(
            distr_output=CategoricalHyperparameterDimension,
            batch_normalization=CategoricalHyperparameterDimension,
            mean_scaling=CategoricalHyperparameterDimension,
            learning_rate=FloatHyperparameterDimension,
        )
        # context_length is only searched over when the full history is not used.
        if not input_hp_space["full_context"]:
            dimension_classes["context_length"] = IntegerHyperparameterDimension

        search_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=dimension_classes,
        )
        return TrainableModel(feed_forward_estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Record the parameters the MXNet feed-forward model was built with."""
        model_params = model.get_params()
        ret["gluonts_simple_feedforward_timeseries_params"] = dict(
            context_length=model_params["context_length"],
            distr_output=model_params["distr_output"],
            batch_normalization=model_params["batch_normalization"],
            mean_scaling=model_params["mean_scaling"],
            num_hidden_dimensions=model_params["num_hidden_dimensions"],
            num_parallel_samples=model_params["num_parallel_samples"],
            learning_rate=model_params["learning_rate"],
            batch_size=model_params["batch_size"],
            epochs=model_params["epochs"],
            num_batches_per_epoch=model_params["num_batches_per_epoch"],
            seed=model_params["seed"],
        )
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Return the context length used at scoring time (``prediction_length`` is the fallback)."""
        return get_used_context_length(modeling_params, "gluonts_simple_feedforward_timeseries_params", prediction_length)

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Training needs the largest searched context plus one forecast horizon."""
        return get_max_context_length(modeling_params, prediction_length) + prediction_length


class GluonTSDeepAR(TimeseriesForecastingAlgorithm):
    """DeepAR forecaster built on the MXNet flavour of GluonTS. Supports external features."""
    USE_GLUON_TS = True
    SUPPORTS_EXTERNAL_FEATURES = True

    @staticmethod
    def _init_mxnet_gpu_config(core_params):
        """Return the GPU config, after initialising the CUDA visible devices from its GPU list.

        Falls back to get_default_gpu_params() when the config has no "params" entry.
        """
        config = get_gpu_config_from_core_params(core_params)
        gpu_params = config.get("params", get_default_gpu_params())
        GluonTSMXNetGPUCapability.init_cuda_visible_devices(gpu_params.get("gpuList"))
        return config

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Wrap an MXNet DkuDeepAREstimator in a TrainableModel."""
        gpu_config = self._init_mxnet_gpu_config(core_params)
        # Visible devices must be set before this import: importing DeepAREstimator from
        # gluonts initialises a CUDA process.
        from dataiku.doctor.timeseries.models.gluonts.mxnet.deepar import DkuDeepAREstimator

        deepar_estimator = DkuDeepAREstimator(
            frequency=get_frequency(core_params),
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifiers=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            use_timeseries_identifiers_as_features=input_hp_space.get("use_timeseries_identifiers_as_features", False),
            full_context=input_hp_space["full_context"],
            scaling=input_hp_space["scaling"],
            num_parallel_samples=input_hp_space["num_parallel_samples"],
            minimum_scale=input_hp_space["minimum_scale"],
            batch_size=input_hp_space["batch_size"],
            epochs=input_hp_space["epochs"],
            auto_num_batches_per_epoch=input_hp_space["auto_num_batches_per_epoch"],
            num_batches_per_epoch=input_hp_space["num_batches_per_epoch"],
            gpu_config=gpu_config,
            seed=input_hp_space["seed"],
            monthly_day_alignment=get_monthly_day_alignment(core_params),
        )

        dimension_classes = dict(
            learning_rate=FloatHyperparameterDimension,
            num_layers=IntegerHyperparameterDimension,
            num_cells=IntegerHyperparameterDimension,
            cell_type=CategoricalHyperparameterDimension,
            dropoutcell_type=CategoricalHyperparameterDimension,
            dropout_rate=FloatHyperparameterDimension,
            alpha=FloatHyperparameterDimension,
            beta=FloatHyperparameterDimension,
            distr_output=CategoricalHyperparameterDimension,
        )
        # context_length is only searched over when the full history is not used.
        if not input_hp_space["full_context"]:
            dimension_classes["context_length"] = IntegerHyperparameterDimension

        search_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=dimension_classes,
        )
        return TrainableModel(deepar_estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Record the parameters the MXNet DeepAR model was built with."""
        model_params = model.get_params()
        ret["gluonts_deepar_timeseries_params"] = dict(
            context_length=model_params["context_length"],
            distr_output=model_params["distr_output"],
            use_timeseries_identifiers_as_features=model_params.get("use_timeseries_identifiers_as_features", False),
            num_layers=model_params["num_layers"],
            num_cells=model_params["num_cells"],
            cell_type=model_params["cell_type"],
            dropoutcell_type=model_params["dropoutcell_type"],
            dropout_rate=model_params["dropout_rate"],
            alpha=model_params["alpha"],
            beta=model_params["beta"],
            scaling=model_params["scaling"],
            num_parallel_samples=model_params["num_parallel_samples"],
            minimum_scale=model_params["minimum_scale"],
            learning_rate=model_params["learning_rate"],
            batch_size=model_params["batch_size"],
            epochs=model_params["epochs"],
            num_batches_per_epoch=model_params["num_batches_per_epoch"],
            seed=model_params["seed"],
        )
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Return the context length used at scoring time (``prediction_length`` is the fallback)."""
        return get_used_context_length(modeling_params, "gluonts_deepar_timeseries_params", prediction_length)

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Training needs the largest searched context plus one forecast horizon."""
        return get_max_context_length(modeling_params, prediction_length) + prediction_length

    def get_max_used_timesteps_for_scoring(self, modeling_params, prediction_length, frequency):
        """Scoring history: the used context plus the largest lag gluonts derives for ``frequency``."""
        # Local import, so the forecasting ML task does not fail outright when gluonts is not installed.
        from gluonts.time_feature import get_lags_for_frequency
        context_steps = self.get_min_size_for_scoring(modeling_params, prediction_length)
        largest_lag = max(get_lags_for_frequency(freq_str=frequency))
        return context_steps + largest_lag
        

class GluonTSTransformer(TimeseriesForecastingAlgorithm):
    """Transformer forecaster built on the MXNet flavour of GluonTS. Supports external features."""
    USE_GLUON_TS = True
    SUPPORTS_EXTERNAL_FEATURES = True

    @staticmethod
    def _init_mxnet_gpu_config(core_params):
        """Return the GPU config, after initialising the CUDA visible devices from its GPU list.

        Falls back to get_default_gpu_params() when the config has no "params" entry.
        """
        config = get_gpu_config_from_core_params(core_params)
        gpu_params = config.get("params", get_default_gpu_params())
        GluonTSMXNetGPUCapability.init_cuda_visible_devices(gpu_params.get("gpuList"))
        return config

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Wrap an MXNet DkuTransformerEstimator in a TrainableModel with a transformer-specific space."""
        gpu_config = self._init_mxnet_gpu_config(core_params)

        # Visible devices must be set before this import: importing TransformerEstimator from
        # gluonts initialises a CUDA process.
        from dataiku.doctor.timeseries.models.gluonts.mxnet.transformer import DkuTransformerEstimator

        transformer_estimator = DkuTransformerEstimator(
            frequency=get_frequency(core_params),
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifiers=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            use_timeseries_identifiers_as_features=input_hp_space.get("use_timeseries_identifiers_as_features", False),
            full_context=input_hp_space["full_context"],
            num_parallel_samples=input_hp_space["num_parallel_samples"],
            batch_size=input_hp_space["batch_size"],
            epochs=input_hp_space["epochs"],
            auto_num_batches_per_epoch=input_hp_space["auto_num_batches_per_epoch"],
            num_batches_per_epoch=input_hp_space["num_batches_per_epoch"],
            gpu_config=gpu_config,
            seed=input_hp_space["seed"],
            monthly_day_alignment=get_monthly_day_alignment(core_params),
        )

        dimension_classes = dict(
            learning_rate=FloatHyperparameterDimension,
            model_dim=IntegerHyperparameterDimension,
            inner_ff_dim_scale=IntegerHyperparameterDimension,
            num_heads=IntegerHyperparameterDimension,
            dropout_rate=FloatHyperparameterDimension,
            distr_output=CategoricalHyperparameterDimension,
        )
        # context_length is only searched over when the full history is not used.
        if not input_hp_space["full_context"]:
            dimension_classes["context_length"] = IntegerHyperparameterDimension

        search_space = GluonTSTransformerHyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=dimension_classes,
        )
        return TrainableModel(transformer_estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Record the parameters the MXNet transformer model was built with."""
        model_params = model.get_params()
        ret["gluonts_transformer_timeseries_params"] = dict(
            context_length=model_params["context_length"],
            distr_output=model_params["distr_output"],
            use_timeseries_identifiers_as_features=model_params.get("use_timeseries_identifiers_as_features", False),
            model_dim=model_params["model_dim"],
            inner_ff_dim_scale=model_params["inner_ff_dim_scale"],
            num_heads=model_params["num_heads"],
            dropout_rate=model_params["dropout_rate"],
            num_parallel_samples=model_params["num_parallel_samples"],
            learning_rate=model_params["learning_rate"],
            batch_size=model_params["batch_size"],
            epochs=model_params["epochs"],
            num_batches_per_epoch=model_params["num_batches_per_epoch"],
            seed=model_params["seed"],
        )
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Return the context length used at scoring time (``prediction_length`` is the fallback)."""
        return get_used_context_length(modeling_params, "gluonts_transformer_timeseries_params", prediction_length)

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Training needs the largest searched context plus one forecast horizon."""
        return get_max_context_length(modeling_params, prediction_length) + prediction_length

    def get_max_used_timesteps_for_scoring(self, modeling_params, prediction_length, frequency):
        """Scoring history: the used context plus the largest lag gluonts derives for ``frequency``."""
        # Local import, so the forecasting ML task does not fail outright when gluonts is not installed.
        from gluonts.time_feature import get_lags_for_frequency
        context_steps = self.get_min_size_for_scoring(modeling_params, prediction_length)
        largest_lag = max(get_lags_for_frequency(freq_str=frequency))
        return context_steps + largest_lag


class GluonTSMQCNN(TimeseriesForecastingAlgorithm):
    """GluonTS MQ-CNN forecaster (MXNet implementation) that supports external features."""

    USE_GLUON_TS = True
    SUPPORTS_EXTERNAL_FEATURES = True

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the MQ-CNN estimator and the hyperparameter space to search over.

        :param input_hp_space: MQ-CNN hyperparameter definitions.
        :param modeling_params: not used by this algorithm.
        :param core_params: core task parameters (columns, horizon, quantiles, GPU settings...).
        :return: a ``TrainableModel`` wrapping the estimator and its search space.
        """
        gpu_config = get_gpu_config_from_core_params(core_params)
        gpu_params = gpu_config.get("params", get_default_gpu_params())
        GluonTSMXNetGPUCapability.init_cuda_visible_devices(gpu_params.get("gpuList"))

        # The visible devices must be set before this import: importing MQCNNEstimator
        # from gluonts initialises a CUDA process.
        from dataiku.doctor.timeseries.models.gluonts.mxnet.mq_cnn import DkuMQCNNEstimator

        hp = input_hp_space
        estimator = DkuMQCNNEstimator(
            frequency=get_frequency(core_params),
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifiers=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            use_timeseries_identifiers_as_features=hp.get("use_timeseries_identifiers_as_features", False),
            quantiles=core_params[doctor_constants.QUANTILES],
            full_context=hp["full_context"],
            scaling=hp["scaling"],
            seed=hp["seed"],
            decoder_mlp_dim_seq=hp["decoder_mlp_dim_seq"],
            channels_seq=hp["channels_seq"],
            dilation_seq=hp["dilation_seq"],
            kernel_size_seq=hp["kernel_size_seq"],
            batch_size=hp["batch_size"],
            epochs=hp["epochs"],
            auto_num_batches_per_epoch=hp["auto_num_batches_per_epoch"],
            num_batches_per_epoch=hp["num_batches_per_epoch"],
            gpu_config=gpu_config,
            monthly_day_alignment=get_monthly_day_alignment(core_params),
        )

        # learning_rate is always a search dimension; context_length only when full_context is off.
        dimension_classes = {"learning_rate": FloatHyperparameterDimension}
        if not hp["full_context"]:
            dimension_classes["context_length"] = IntegerHyperparameterDimension

        search_space = HyperparametersSpace.from_definition(hp, hp_names_to_dimension_class=dimension_classes)
        return TrainableModel(estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Store the MQ-CNN hyperparameters resolved on ``model`` into ``ret``.

        :param ret: dict mutated in place; receives ``"gluonts_mqcnn_timeseries_params"``.
        :param model: estimator exposing ``get_params()``.
        :param fit_params: not used by this algorithm.
        :return: ``{"resolved": ret}``
        """
        params = model.get_params()

        resolved = {"context_length": params["context_length"]}
        # This flag may be absent from the estimator params; report it as disabled then.
        resolved["use_timeseries_identifiers_as_features"] = params.get(
            "use_timeseries_identifiers_as_features", False
        )
        for name in (
            "scaling",
            "seed",
            "decoder_mlp_dim_seq",
            "channels_seq",
            "dilation_seq",
            "kernel_size_seq",
            "learning_rate",
            "batch_size",
            "epochs",
            "num_batches_per_epoch",
        ):
            resolved[name] = params[name]

        ret["gluonts_mqcnn_timeseries_params"] = resolved
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Return the resolved context_length, or 4 * prediction_length when none was resolved."""
        default_context_length = 4 * prediction_length
        return get_used_context_length(modeling_params, "gluonts_mqcnn_timeseries_params", default_context_length)

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Return the longest searchable context (4 * prediction_length with full_context) plus one horizon."""
        longest_context = get_max_context_length(modeling_params, 4 * prediction_length)
        return prediction_length + longest_context


class AutoArima(TimeseriesForecastingAlgorithm):
    """Auto-ARIMA forecaster that searches ARIMA orders within the configured bounds.

    Refit at scoring time is supported. The algorithm supports external features
    and exposes model coefficients and information criteria.
    """

    REFIT_FOR_SCORING = TimeseriesForecastingAlgorithm.ModelRefit.SUPPORTED
    SUPPORTS_EXTERNAL_FEATURES = True
    SUPPORTS_MODEL_COEFFICIENTS = True
    SUPPORTS_INFORMATION_CRITERIA = True

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the auto-ARIMA estimator and its hyperparameter search space.

        The order bounds (``start_*`` / ``max_*``) configure the estimator itself.
        The following become search dimensions:

        - ``m``, ``information_criterion``, ``test`` and ``method``;
        - ``seasonal_test``, only when the season length ``m`` may exceed 1.
        """
        from dataiku.doctor.timeseries.models.statistical.auto_arima import DkuAutoArimaEstimator

        hp = input_hp_space
        estimator = DkuAutoArimaEstimator(
            frequency=get_frequency(core_params),
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifier_columns=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            start_p=hp["start_p"],
            max_p=hp["max_p"],
            d_=hp.get("d"),
            max_d=hp.get("max_d"),
            start_q=hp["start_q"],
            max_q=hp["max_q"],
            start_P=hp["start_P"],
            max_P=hp["max_P"],
            D_=hp.get("D"),
            max_D=hp.get("max_D"),
            start_Q=hp["start_Q"],
            max_Q=hp["max_Q"],
            max_order=hp["max_order"],
            stationary=hp["stationary"],
            maxiter=hp["maxiter"],
            monthly_day_alignment=get_monthly_day_alignment(core_params),
        )

        dimension_classes = {"m": IntegerHyperparameterDimension}
        dimension_classes.update(
            dict.fromkeys(("information_criterion", "test", "method"), CategoricalHyperparameterDimension)
        )
        # seasonal_test is a search dimension only when some candidate m exceeds 1.
        if get_parameter_max_value(modeling_params, "m") > 1:
            dimension_classes["seasonal_test"] = CategoricalHyperparameterDimension

        search_space = HyperparametersSpace.from_definition(hp, hp_names_to_dimension_class=dimension_classes)
        return TrainableModel(estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Store the auto-ARIMA settings resolved on ``model`` into ``ret``.

        :param ret: dict mutated in place; receives ``"auto_arima_timeseries_params"``.
        :param model: estimator exposing ``get_params()``.
        :param fit_params: not used by this algorithm.
        :return: ``{"resolved": ret}``
        """
        params = model.get_params()

        resolved = {
            name: params[name]
            for name in ("information_criterion", "test", "seasonal_test", "method", "stationary")
        }
        # The orders (p, q, d, P, Q, D) exist only once the estimator is trained, and
        # actual_params can be called without training when there is no hp search.
        resolved.update((order, params.get(order)) for order in ("p", "d", "q", "P", "D", "Q"))
        resolved["m"] = params["m"]

        ret["auto_arima_timeseries_params"] = resolved
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Return one horizon plus the larger of the horizon and the resolved season length m."""
        season_length = modeling_params["auto_arima_timeseries_params"]["m"]
        return prediction_length + max(prediction_length, season_length)

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Return twice the horizon plus twice the largest searchable season length m."""
        longest_season = get_parameter_max_value(modeling_params, "m")
        return 2 * prediction_length + 2 * longest_season


class Arima(TimeseriesForecastingAlgorithm):
    """ARIMA forecaster configured with explicit (p, d, q)(P, D, Q, s) orders.

    Refit at scoring time is supported. The algorithm supports external features
    and exposes model coefficients.
    """

    REFIT_FOR_SCORING = TimeseriesForecastingAlgorithm.ModelRefit.SUPPORTED
    SUPPORTS_EXTERNAL_FEATURES = True
    SUPPORTS_MODEL_COEFFICIENTS = True

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the ARIMA estimator and its hyperparameter space.

        Each ARIMA setting is forwarded to the estimator under the same name as in
        ``input_hp_space``. No explicit dimension classes are supplied to the space.
        """
        from dataiku.doctor.timeseries.models.statistical.arima import DkuArimaEstimator

        estimator_kwargs = dict(
            frequency=get_frequency(core_params),
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifier_columns=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
        )
        for name in (
            "p",
            "q",
            "d",
            "P",
            "Q",
            "s",
            "D",
            "trend",
            "trend_offset",
            "enforce_stationarity",
            "enforce_invertibility",
            "concentrate_scale",
        ):
            estimator_kwargs[name] = input_hp_space[name]
        estimator = DkuArimaEstimator(**estimator_kwargs)

        search_space = HyperparametersSpace.from_definition(input_hp_space)
        return TrainableModel(estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Store the ARIMA settings found on ``model`` into ``ret``.

        Missing settings are reported as ``None``.

        :param ret: dict mutated in place; receives ``"arima_timeseries_params"``.
        :param model: estimator exposing ``get_params()``.
        :param fit_params: not used by this algorithm.
        :return: ``{"resolved": ret}``
        """
        params = model.get_params()
        ret["arima_timeseries_params"] = {
            name: params.get(name)
            for name in (
                "p",
                "d",
                "q",
                "P",
                "D",
                "Q",
                "s",
                "concentrate_scale",
                "enforce_invertibility",
                "enforce_stationarity",
                "trend",
                "trend_offset",
            )
        }
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Return one horizon plus the larger of the horizon and the resolved season length s."""
        season_length = modeling_params["arima_timeseries_params"]["s"]
        return prediction_length + max(prediction_length, season_length)

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Return twice the horizon plus twice the configured season length s."""
        season_length = modeling_params["arima_grid"]["s"]
        return 2 * prediction_length + 2 * season_length


class SeasonalTrendLoess(TimeseriesForecastingAlgorithm):
    """Seasonal-trend LOESS forecaster.

    The model is always refit at scoring time. It exposes model coefficients and
    information criteria.
    """

    REFIT_FOR_SCORING = TimeseriesForecastingAlgorithm.ModelRefit.MANDATORY
    SUPPORTS_MODEL_COEFFICIENTS = True
    SUPPORTS_INFORMATION_CRITERIA = True

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the STL estimator and its hyperparameter search space.

        ``trend`` and ``low_pass`` are searched only when their ``auto_*`` flag is off.
        """
        from dataiku.doctor.timeseries.models.statistical.seasonal_trend_loess import DkuSeasonalTrendLoessEstimator

        estimator = DkuSeasonalTrendLoessEstimator(
            frequency=get_frequency(core_params),
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifier_columns=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            auto_trend=input_hp_space["auto_trend"],
            auto_low_pass=input_hp_space["auto_low_pass"],
            monthly_day_alignment=get_monthly_day_alignment(core_params),
        )

        dimension_classes = {
            "period": IntegerHyperparameterDimension,
            "seasonal": OddIntegerHyperparameterDimension,
        }
        dimension_classes.update(
            dict.fromkeys(("seasonal_deg", "trend_deg", "low_pass_deg"), CategoricalHyperparameterDimension)
        )
        dimension_classes.update(
            dict.fromkeys(("seasonal_jump", "trend_jump", "low_pass_jump"), IntegerHyperparameterDimension)
        )
        for auto_flag, smoother in (("auto_trend", "trend"), ("auto_low_pass", "low_pass")):
            if not input_hp_space[auto_flag]:
                dimension_classes[smoother] = OddIntegerHyperparameterDimension

        search_space = SeasonalTrendLoessHyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=dimension_classes,
        )
        return TrainableModel(estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Store the STL settings resolved on ``model`` into ``ret``.

        :param ret: dict mutated in place; receives ``"seasonal_loess_timeseries_params"``.
        :param model: estimator exposing ``get_params()``.
        :param fit_params: not used by this algorithm.
        :return: ``{"resolved": ret}``
        """
        params = model.get_params()
        ret["seasonal_loess_timeseries_params"] = {
            name: params[name]
            for name in (
                "period",
                "seasonal",
                "trend",
                "low_pass",
                "seasonal_deg",
                "trend_deg",
                "low_pass_deg",
                "seasonal_jump",
                "trend_jump",
                "low_pass_jump",
            )
        }
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Return the scoring history length.

        Because scoring always refits (REFIT_FOR_SCORING is MANDATORY), the
        training floor of 10 time steps applies here too.
        """
        period = modeling_params["seasonal_loess_timeseries_params"]["period"]
        return max(prediction_length, period, 10)  # STL requires at least 10 time steps for training

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Return the largest of the horizon, the largest searchable period and 10."""
        longest_period = get_parameter_max_value(modeling_params, "period")
        return max(prediction_length, longest_period, 10)  # STL requires at least 10 time steps for training


class ETS(TimeseriesForecastingAlgorithm):
    """ETS forecaster searching over error / trend / seasonal component types.

    Refit at scoring time is supported. It exposes model coefficients and
    information criteria.
    """

    REFIT_FOR_SCORING = TimeseriesForecastingAlgorithm.ModelRefit.SUPPORTED
    SUPPORTS_MODEL_COEFFICIENTS = True
    SUPPORTS_INFORMATION_CRITERIA = True

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the ETS estimator and its hyperparameter search space."""
        from dataiku.doctor.timeseries.models.statistical.ets import DkuETSEstimator

        estimator = DkuETSEstimator(
            frequency=get_frequency(core_params),
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifier_columns=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            seasonal_periods=input_hp_space["seasonal_periods"],
            seed=input_hp_space["seed"],
            include_unstable=input_hp_space["include_unstable"],
        )

        # Every ETS component choice is a categorical search dimension.
        component_dimensions = dict.fromkeys(
            ("trend", "damped_trend", "seasonal", "error"), CategoricalHyperparameterDimension
        )
        # NOTE(review): the estimator reads include_unstable from input_hp_space while the space
        # reads it from modeling_params["ets_timeseries_grid"]; presumably the same value — confirm.
        search_space = ETSHyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=component_dimensions,
            include_unstable=modeling_params["ets_timeseries_grid"]["include_unstable"],
        )
        return TrainableModel(estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Store the ETS settings resolved on ``model`` into ``ret``.

        :param ret: dict mutated in place; receives ``"ets_params"``.
        :param model: estimator exposing ``get_params()``.
        :param fit_params: not used by this algorithm.
        :return: ``{"resolved": ret}``
        """
        params = model.get_params()
        ret["ets_params"] = {
            name: params[name]
            for name in ("trend", "damped_trend", "seasonal", "error", "seasonal_periods", "seed", "include_unstable")
        }
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Return the larger of the horizon and the resolved seasonal period count."""
        seasonal_periods = modeling_params["ets_params"]["seasonal_periods"]
        return max(prediction_length, seasonal_periods)

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Return the largest of the horizon, the configured seasonal periods and 10."""
        seasonal_periods = modeling_params["ets_timeseries_grid"]["seasonal_periods"]
        # Heuristic method needs at least 10 observations
        return max(prediction_length, seasonal_periods, 10)


class Prophet(TimeseriesForecastingAlgorithm):
    """Prophet forecaster.

    Refit at scoring time is supported. The algorithm supports external features
    and exposes model coefficients.
    """

    REFIT_FOR_SCORING = TimeseriesForecastingAlgorithm.ModelRefit.SUPPORTED
    SUPPORTS_EXTERNAL_FEATURES = True
    SUPPORTS_MODEL_COEFFICIENTS = True

    def model_from_params(self, input_hp_space, modeling_params, core_params):
        """Build the Prophet estimator and its hyperparameter search space.

        ``holidays_prior_scale`` is searched only when external features are used.
        """
        from dataiku.doctor.timeseries.models.statistical.prophet import DkuProphetEstimator

        estimator = DkuProphetEstimator(
            frequency=get_frequency(core_params),
            time_variable=core_params[doctor_constants.TIME_VARIABLE],
            prediction_length=core_params[doctor_constants.PREDICTION_LENGTH],
            target_variable=core_params[doctor_constants.TARGET_VARIABLE],
            timeseries_identifier_columns=core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS],
            growth=input_hp_space["growth"],
            floor=input_hp_space["floor"],
            cap=input_hp_space.get("cap"),  # cap can be undefined if growth is not logistic
            n_changepoints=input_hp_space["n_changepoints"],
            changepoint_range=input_hp_space["changepoint_range"],
            yearly_seasonality=input_hp_space["yearly_seasonality"],
            weekly_seasonality=input_hp_space["weekly_seasonality"],
            daily_seasonality=input_hp_space["daily_seasonality"],
            seed=input_hp_space["seed"],
            monthly_day_alignment=get_monthly_day_alignment(core_params),
        )

        dimension_classes = {"seasonality_mode": CategoricalHyperparameterDimension}
        dimension_classes.update(
            dict.fromkeys(("seasonality_prior_scale", "changepoint_prior_scale"), FloatHyperparameterDimension)
        )
        if input_hp_space["_use_external_features"]:
            dimension_classes["holidays_prior_scale"] = FloatHyperparameterDimension

        search_space = HyperparametersSpace.from_definition(
            input_hp_space,
            hp_names_to_dimension_class=dimension_classes,
        )
        return TrainableModel(estimator, hyperparameters_space=search_space)

    def actual_params(self, ret, model, fit_params):
        """Store the Prophet settings resolved on ``model`` into ``ret``.

        :param ret: dict mutated in place; receives ``"prophet_timeseries_params"``.
        :param model: estimator exposing ``get_params()``.
        :param fit_params: not used by this algorithm.
        :return: ``{"resolved": ret}``
        """
        params = model.get_params()
        ret["prophet_timeseries_params"] = {
            name: params[name]
            for name in (
                "growth",
                "floor",
                "cap",
                "n_changepoints",
                "changepoint_range",
                "yearly_seasonality",
                "weekly_seasonality",
                "daily_seasonality",
                "seed",
                "seasonality_mode",
                "seasonality_prior_scale",
                "changepoint_prior_scale",
                "holidays_prior_scale",
            )
        }
        return {"resolved": ret}

    def get_min_size_for_scoring(self, modeling_params, prediction_length):
        """Return 2.

        Prophet.predict would accept a single past value, but 2 are required in
        case the scoring/evaluation recipe refits the model.
        """
        return 2

    def get_min_size_for_training(self, modeling_params, prediction_length):
        """Return 2, the minimum number of past values the Prophet.fit method requires."""
        return 2


def get_parameter_max_value(modeling_params, parameter_name):
    """Return the largest value ``parameter_name`` can take in the configured search.

    The definition comes from the hyperparameter space of ``modeling_params["algorithm"]``.
    The mode key depends on the search strategy: ``"gridMode"`` for a GRID search,
    ``"randomMode"`` otherwise. In EXPLICIT mode the maximum of the listed ``values``
    is returned; in any other mode, the upper bound of ``range``.
    """
    hp_space = get_input_hyperparameter_space(modeling_params, modeling_params["algorithm"])
    definition = hp_space[parameter_name]

    is_grid_search = modeling_params["grid_search_params"]["strategy"] == "GRID"
    mode = definition["gridMode"] if is_grid_search else definition["randomMode"]
    if mode != "EXPLICIT":
        return definition["range"]["max"]
    return max(definition["values"])


def get_max_context_length(modeling_params, full_context_length):
    """Return the longest context length the search may use.

    This is ``full_context_length`` when the algorithm's ``full_context`` flag is
    set. Otherwise it is the largest searchable ``context_length``.
    """
    hp_space = get_input_hyperparameter_space(modeling_params, modeling_params["algorithm"])
    if not hp_space["full_context"]:
        return get_parameter_max_value(modeling_params, "context_length")
    return full_context_length


def get_used_context_length(modeling_params, model_params_key, full_context_length):
    """Return the resolved ``context_length`` stored under ``modeling_params[model_params_key]``.

    Falls back to ``full_context_length`` when the key is missing or ``None``.
    Falsy values such as 0 are returned as-is.
    """
    resolved = modeling_params[model_params_key].get("context_length")
    return full_context_length if resolved is None else resolved


# Registry of forecasting algorithm implementations; each value is a single shared
# instance created at import time. Keys are presumably the identifiers found in
# modeling_params["algorithm"] — confirm against callers.
TIMESERIES_FORECASTING_ALGORITHMS_MAP = dict(
    TRIVIAL_IDENTITY_TIMESERIES=GluonTSTrivialIdentity(),
    SEASONAL_NAIVE=GluonTSSeasonalNaive(),
    GLUONTS_NPTS_FORECASTER=GluonTSNPTSForecaster(),
    GLUONTS_SIMPLE_FEEDFORWARD=GluonTSSimpleFeedForward(),
    GLUONTS_DEEPAR=GluonTSDeepAR(),
    GLUONTS_TORCH_SIMPLE_FEEDFORWARD=GluonTSTorchSimpleFeedForward(),
    GLUONTS_TORCH_DEEPAR=GluonTSTorchDeepAR(),
    GLUONTS_TRANSFORMER=GluonTSTransformer(),
    GLUONTS_MQCNN=GluonTSMQCNN(),
    AUTO_ARIMA=AutoArima(),
    ARIMA=Arima(),
    ETS=ETS(),
    SEASONAL_LOESS=SeasonalTrendLoess(),
    PROPHET=Prophet(),
)
