# Keep on top, to handle missing libraries for MxNet (SC-98195)
from dataiku.doctor.timeseries.models.gluonts.mxnet.mxnet_base_estimator import DkuGluonTSMXNetDeepLearningEstimator

import gluonts
from dataiku.base.utils import package_is_at_least

if package_is_at_least(gluonts, "0.12.0"):
    from gluonts.mx.model.seq2seq import MQCNNEstimator
else:
    from gluonts.model.seq2seq import MQCNNEstimator

from dataiku.doctor.utils.gpu_execution import get_default_gpu_config


class DkuMQCNNEstimator(DkuGluonTSMXNetDeepLearningEstimator):
    """Estimator wrapper that configures and builds a GluonTS ``MQCNNEstimator``.

    Settings shared with the other MXNet deep-learning estimators are handed
    to the parent constructor; the MQ-CNN specific ones (``scaling``,
    ``decoder_mlp_dim_seq``, ``channels_seq``, ``dilation_seq``,
    ``kernel_size_seq`` and ``quantiles``) are kept on the instance and passed
    to ``MQCNNEstimator`` in :meth:`_get_estimator`.
    """

    def __init__(
        self,
        frequency,
        prediction_length,
        time_variable,
        target_variable,
        timeseries_identifiers,
        use_timeseries_identifiers_as_features,
        full_context,
        batch_size,
        epochs,
        auto_num_batches_per_epoch,
        num_batches_per_epoch,
        scaling,
        seed,
        decoder_mlp_dim_seq,
        channels_seq,
        dilation_seq,
        kernel_size_seq,
        quantiles,
        learning_rate=.001,
        context_length=1,
        gpu_config=get_default_gpu_config(),
        monthly_day_alignment=None,
    ):
        """Store the MQ-CNN settings and delegate the common ones to the parent.

        NOTE(review): the ``gpu_config`` default is evaluated once, when the
        class body is executed, so every instance created without an explicit
        ``gpu_config`` receives the same object.
        """
        # Everything in this mapping is handled by the parent constructor.
        # Learning rate & context length (the searchable parameters) are
        # among them.
        common_settings = dict(
            frequency=frequency,
            prediction_length=prediction_length,
            time_variable=time_variable,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            use_timeseries_identifiers_as_features=use_timeseries_identifiers_as_features,
            context_length=context_length,
            full_context=full_context,
            batch_size=batch_size,
            epochs=epochs,
            auto_num_batches_per_epoch=auto_num_batches_per_epoch,
            num_batches_per_epoch=num_batches_per_epoch,
            learning_rate=learning_rate,
            gpu_config=gpu_config,
            seed=seed,
            monthly_day_alignment=monthly_day_alignment,
        )
        super(DkuMQCNNEstimator, self).__init__(**common_settings)

        # MQ-CNN specific settings, consumed later by _get_estimator.
        self.scaling = scaling
        self.decoder_mlp_dim_seq = decoder_mlp_dim_seq
        self.channels_seq = channels_seq
        self.dilation_seq = dilation_seq
        self.kernel_size_seq = kernel_size_seq
        self.quantiles = quantiles

    def initialize(self, core_params, modeling_params):
        """Run the parent initialization, then load the MQ-CNN specific params.

        Reads ``modeling_params["gluonts_mqcnn_timeseries_params"]`` and takes
        from it:

        * ``use_timeseries_identifiers_as_features`` (``False`` when absent);
        * ``seed`` (looked up by key, so it must be present).
        """
        super(DkuMQCNNEstimator, self).initialize(core_params, modeling_params)
        mqcnn_params = modeling_params["gluonts_mqcnn_timeseries_params"]
        self.use_timeseries_identifiers_as_features = mqcnn_params.get(
            "use_timeseries_identifiers_as_features", False
        )
        self.seed = mqcnn_params["seed"]

    def _get_estimator(self, train_data, identifier_cardinalities):
        """Build the ``MQCNNEstimator`` from the current instance settings.

        :param train_data: passed to ``self.get_trainer`` to build the trainer.
        :param identifier_cardinalities: passed as ``cardinality``; its
            truthiness also decides ``use_feat_static_cat``.
        :return: a configured ``MQCNNEstimator`` instance.
        """
        time_based = self.get_time_based_parameters(parameters_to_set=["add_time_feature"])
        mxnet_trainer = self.get_trainer(train_data)

        # Fall back to True when no "add_time_feature" entry was produced.
        add_time_feature = time_based.get("add_time_feature", True)

        estimator_kwargs = dict(
            trainer=mxnet_trainer,
            freq=self.frequency,
            prediction_length=self.prediction_length,
            batch_size=self.batch_size,
            use_feat_dynamic_real=bool(self.external_features),
            use_feat_static_cat=bool(identifier_cardinalities),
            cardinality=identifier_cardinalities,
            context_length=self.context_length,
            scaling=self.scaling,
            decoder_mlp_dim_seq=self.decoder_mlp_dim_seq,
            channels_seq=self.channels_seq,
            dilation_seq=self.dilation_seq,
            kernel_size_seq=self.kernel_size_seq,
            quantiles=self.quantiles,
            add_time_feature=add_time_feature,
        )
        return MQCNNEstimator(**estimator_kwargs)
