# Keep on top, to handle missing libraries for MxNet (SC-98195)
from dataiku.doctor.timeseries.models.gluonts.base_estimator import DkuGluonTSEstimator

import gluonts
from dataiku.base.utils import package_is_at_least

from gluonts.model.npts import NPTSPredictor

import numpy as np


class DkuNPTSEstimator(DkuGluonTSEstimator):
    """Estimator that wraps the GluonTS ``NPTSPredictor``.

    ``fit`` only instantiates the predictor from the stored hyper-parameters;
    the NumPy global random generator is seeded with ``seed`` right before
    prediction.
    """

    def __init__(
        self,
        frequency,
        prediction_length,
        time_variable,
        target_variable,
        timeseries_identifiers,
        full_context,
        use_seasonal_model,
        use_default_time_features,
        seed,
        context_length=1,
        kernel_type="exponential",
        exp_kernel_weights=1,
        feature_scale=1000,
        monthly_day_alignment=None,
    ):
        """Store the estimator configuration.

        ``frequency``, ``prediction_length``, ``time_variable``,
        ``target_variable``, ``timeseries_identifiers`` and
        ``monthly_day_alignment`` are forwarded to the parent constructor.

        :param full_context: when truthy, ``set_params`` resets
            ``context_length`` to ``None``.
        :param use_seasonal_model: forwarded to ``NPTSPredictor`` by ``fit``.
        :param use_default_time_features: forwarded to ``NPTSPredictor`` by ``fit``.
        :param seed: value used to seed ``numpy.random`` before predicting
            (overwritten by ``initialize``).
        :param context_length: searchable; forwarded to ``NPTSPredictor`` by ``fit``.
        :param kernel_type: searchable; forwarded to ``NPTSPredictor`` by ``fit``.
        :param exp_kernel_weights: searchable; forwarded to ``NPTSPredictor`` by ``fit``.
        :param feature_scale: searchable; forwarded to ``NPTSPredictor`` by ``fit``.
        """
        super(DkuNPTSEstimator, self).__init__(
            frequency=frequency,
            prediction_length=prediction_length,
            time_variable=time_variable,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            monthly_day_alignment=monthly_day_alignment,
        )

        # Searchable parameters
        self.context_length = context_length
        self.kernel_type = kernel_type
        self.exp_kernel_weights = exp_kernel_weights
        self.feature_scale = feature_scale

        # Fixed (non-searchable) settings
        self.use_seasonal_model = use_seasonal_model
        self.use_default_time_features = use_default_time_features
        self.full_context = full_context
        self.seed = seed

    def initialize(self, core_params, modeling_params):
        """Run the parent initialization, then reload the seed.

        The seed is read from
        ``modeling_params["gluonts_npts_timeseries_params"]["seed"]`` and
        replaces the value given to the constructor.
        """
        super(DkuNPTSEstimator, self).initialize(core_params, modeling_params)
        npts_params = modeling_params["gluonts_npts_timeseries_params"]
        self.seed = npts_params["seed"]

    def set_params(self, **params):
        """Apply ``params`` through the parent, then enforce ``full_context``.

        :return: ``self``.
        """
        super(DkuNPTSEstimator, self).set_params(**params)

        # With full context enabled, any context_length that was just set is
        # discarded in favour of None.
        # NOTE(review): this override only happens here, not in __init__ -
        # confirm callers always go through set_params before fit.
        if self.full_context:
            self.context_length = None

        return self

    def fit(self, train_df, external_features=None, shift_map=None):
        """Instantiate the ``NPTSPredictor`` from the current parameters.

        ``train_df`` and ``shift_map`` are not read by this method.

        :param external_features: when not ``None``, stored on
            ``self.external_features``.
        :return: ``self``.
        """
        if external_features is not None:
            self.external_features = external_features

        predictor_kwargs = dict(
            prediction_length=self.prediction_length,
            context_length=self.context_length,
            use_seasonal_model=self.use_seasonal_model,
            use_default_time_features=self.use_default_time_features,
            kernel_type=self.kernel_type,
            exp_kernel_weights=self.exp_kernel_weights,
            feature_scale=self.feature_scale,
        )

        # `freq` is only handed to the predictor for gluonts versions older
        # than 0.14.0.
        needs_freq = not package_is_at_least(gluonts, "0.14.0")
        if needs_freq:
            predictor_kwargs["freq"] = self.frequency

        self.predictor = NPTSPredictor(**predictor_kwargs)
        return self

    def _prepare_predict(self):
        """Run the parent preparation, then seed NumPy's global RNG with ``self.seed``."""
        super(DkuNPTSEstimator, self)._prepare_predict()
        np.random.seed(self.seed)
