# Keep on top, to handle missing libraries for MxNet (SC-98195)

import gluonts
from dataiku.base.utils import package_is_at_least
from dataiku.doctor.timeseries.models.gluonts.mxnet.mxnet_base_estimator import DkuGluonTSMXNetDeepLearningEstimator

if package_is_at_least(gluonts, "0.12.0"):
    from gluonts.mx.model.deepar import DeepAREstimator
else:
    from gluonts.model.deepar import DeepAREstimator

from dataiku.doctor.utils.gpu_execution import get_default_gpu_config


class DkuDeepAREstimator(DkuGluonTSMXNetDeepLearningEstimator):
    """Dataiku wrapper around the GluonTS (MXNet) DeepAR estimator.

    Stores the DeepAR-specific hyperparameters and turns them into a GluonTS
    ``DeepAREstimator`` in ``_get_estimator``. Settings forwarded to the parent
    class (frequency, prediction length, variables/identifiers, context,
    batching, epochs, learning rate, GPU config, seed, monthly alignment) are
    handled there.
    """

    # Private sentinel meaning "gpu_config was not supplied by the caller".
    # Calling get_default_gpu_config() directly in the signature would evaluate
    # it only once, at import time, and share that single object between every
    # instance built without an explicit gpu_config. A sentinel (rather than
    # None) keeps an explicit ``gpu_config=None`` passed through unchanged.
    _GPU_CONFIG_NOT_SET = object()

    def __init__(
        self,
        frequency,
        prediction_length,
        time_variable,
        target_variable,
        timeseries_identifiers,
        use_timeseries_identifiers_as_features,
        full_context,
        batch_size,
        epochs,
        auto_num_batches_per_epoch,
        num_batches_per_epoch,
        scaling,
        num_parallel_samples,
        minimum_scale,
        seed,
        learning_rate=.001,
        context_length=1,
        num_layers=2,
        num_cells=40,
        cell_type="lstm",
        dropoutcell_type="ZoneoutCell",
        dropout_rate=0.1,
        alpha=0,
        beta=0,
        distr_output="StudentTOutput",
        gpu_config=_GPU_CONFIG_NOT_SET,
        monthly_day_alignment=None,
    ):
        """Store the estimator configuration.

        Parameters shared with other GluonTS deep learning models are forwarded
        to the parent constructor. ``scaling``, ``num_parallel_samples``,
        ``minimum_scale``, ``num_layers``, ``num_cells``, ``cell_type``,
        ``dropoutcell_type``, ``dropout_rate``, ``alpha`` and ``beta`` are
        stored as-is and later passed, under the same names, to the GluonTS
        ``DeepAREstimator``.

        :param distr_output: name of the output distribution, resolved to a
            GluonTS distribution output by ``get_distr_output_class`` when the
            estimator is built.
        :param gpu_config: GPU configuration forwarded to the parent class.
            When omitted, a fresh ``get_default_gpu_config()`` result is used
            for this instance.
        """
        if gpu_config is DkuDeepAREstimator._GPU_CONFIG_NOT_SET:
            # Resolve the default per instance, so instances never share one config object.
            gpu_config = get_default_gpu_config()

        super(DkuDeepAREstimator, self).__init__(
            frequency=frequency,
            prediction_length=prediction_length,
            time_variable=time_variable,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            use_timeseries_identifiers_as_features=use_timeseries_identifiers_as_features,
            context_length=context_length,
            full_context=full_context,
            batch_size=batch_size,
            epochs=epochs,
            auto_num_batches_per_epoch=auto_num_batches_per_epoch,
            num_batches_per_epoch=num_batches_per_epoch,
            learning_rate=learning_rate,
            gpu_config=gpu_config,
            seed=seed,
            monthly_day_alignment=monthly_day_alignment,
        )

        self.scaling = scaling
        self.num_parallel_samples = num_parallel_samples
        self.minimum_scale = minimum_scale

        # Searchable parameters
        # Learning rate & context length are in the parent class
        self.num_layers = num_layers
        self.num_cells = num_cells
        self.cell_type = cell_type
        self.dropoutcell_type = dropoutcell_type
        self.dropout_rate = dropout_rate
        self.alpha = alpha
        self.beta = beta
        self.distr_output = distr_output

    def initialize(self, core_params, modeling_params):
        """Run the parent initialization, then apply DeepAR-specific settings.

        Reads ``modeling_params["gluonts_deepar_timeseries_params"]``:
        ``use_timeseries_identifiers_as_features`` defaults to False when
        absent, while a missing ``seed`` raises KeyError.
        """
        super(DkuDeepAREstimator, self).initialize(core_params, modeling_params)
        algo_params = modeling_params["gluonts_deepar_timeseries_params"]
        self.use_timeseries_identifiers_as_features = algo_params.get("use_timeseries_identifiers_as_features", False)
        self.seed = algo_params["seed"]

    def _get_estimator(self, train_data, identifier_cardinalities):
        """Build the GluonTS ``DeepAREstimator`` from the stored configuration.

        :param train_data: forwarded to ``get_trainer`` to build the trainer.
        :param identifier_cardinalities: passed as ``cardinality``. Its
            truthiness also enables static categorical features. Presumably a
            list of per-identifier cardinalities; TODO confirm with the caller.
        :return: an unfitted GluonTS ``DeepAREstimator``.
        """
        # Lags and time features are derived by the parent class from time-based settings.
        time_parameters = self.get_time_based_parameters(parameters_to_set=["lags_seq", "time_features"])

        trainer = self.get_trainer(train_data)

        return DeepAREstimator(
            trainer=trainer,
            freq=self.frequency,
            prediction_length=self.prediction_length,
            batch_size=self.batch_size,
            # Dynamic real features only when external features are configured.
            use_feat_dynamic_real=bool(self.external_features),
            # Static categorical features only when identifier cardinalities are provided.
            use_feat_static_cat=bool(identifier_cardinalities),
            cardinality=identifier_cardinalities,
            context_length=self.context_length,
            scaling=self.scaling,
            num_parallel_samples=self.num_parallel_samples,
            minimum_scale=self.minimum_scale,
            num_layers=self.num_layers,
            num_cells=self.num_cells,
            cell_type=self.cell_type,
            dropoutcell_type=self.dropoutcell_type,
            dropout_rate=self.dropout_rate,
            alpha=self.alpha,
            beta=self.beta,
            distr_output=DkuGluonTSMXNetDeepLearningEstimator.get_distr_output_class(self.distr_output),
            lags_seq=time_parameters.get("lags_seq"),
            time_features=time_parameters.get("time_features"),
        )
