# Keep on top, to handle missing libraries for MxNet (SC-98195)
from dataiku.doctor.timeseries.models.gluonts.base_estimator import DkuGluonTSEstimator

import gluonts
from gluonts.model.seasonal_naive import SeasonalNaivePredictor

from dataiku.base.utils import package_is_at_least


class DkuSeasonalNaiveSEstimator(DkuGluonTSEstimator):
    """Seasonal-naive forecaster backed by gluonts' ``SeasonalNaivePredictor``.

    The only model-specific setting is ``season_length``. Every other
    constructor argument is forwarded unchanged to ``DkuGluonTSEstimator``.
    """

    def __init__(
            self,
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            season_length=1,
            monthly_day_alignment=None,
    ):
        """Store the shared time-series settings and the seasonal period.

        :param frequency: time-series frequency, forwarded to the base class
        :param prediction_length: forecast horizon, forwarded to the base class
        :param time_variable: name of the time column, forwarded to the base class
        :param target_variable: name of the target column, forwarded to the base class
        :param timeseries_identifiers: identifier columns, forwarded to the base class
        :param season_length: seasonal period handed to ``SeasonalNaivePredictor``
        :param monthly_day_alignment: forwarded to the base class unchanged
        """
        super().__init__(
            frequency=frequency,
            prediction_length=prediction_length,
            time_variable=time_variable,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            monthly_day_alignment=monthly_day_alignment,
        )

        # Hyperparameter exposed to the search ("searchable parameter").
        self.season_length = season_length

    def fit(self, train_df, external_features=None, shift_map=None):
        """Build the underlying ``SeasonalNaivePredictor`` and return ``self``.

        ``train_df``, ``external_features`` and ``shift_map`` are accepted to
        match the estimator interface, but this method does not read them.
        Nothing is learned from the data here.

        :return: this estimator, with ``self.predictor`` set
        """
        # prediction_length / frequency are presumably set by the base-class
        # constructor -- TODO confirm against DkuGluonTSEstimator.
        predictor_kwargs = dict(
            prediction_length=self.prediction_length,
            season_length=self.season_length,
        )

        # The predictor receives ``freq`` only when the installed gluonts is
        # older than 0.14.0; newer versions are called without it.
        if not package_is_at_least(gluonts, "0.14.0"):
            predictor_kwargs["freq"] = self.frequency

        self.predictor = SeasonalNaivePredictor(**predictor_kwargs)
        return self
