import logging
import math

import numpy as np

from importlib.util import find_spec

from dataiku.doctor.timeseries.models.base_estimator import BaseTimeseriesEstimator

if find_spec("mxnet") and np.lib.NumpyVersion(np.__version__) < "2.0.0":
    # alias to have mxnet support numpy 1.24 and above as mxnet code is frozen (see https://app.shortcut.com/dataiku/story/177201/state-of-numpy-in-dss)
    # this is necessary if mxnet is installed but using a torch lib, due to https://github.com/awslabs/gluonts/blob/7668ce11fc296de18daf71d718f91b2f447e3615/src/gluonts/model/forecast_generator.py#L85
    # `np.bool` was removed in numpy 1.24, but numpy 2.0 re-introduced it as the canonical name of `np.bool_`.
    # Only patch older numpy versions so that numpy's own `np.bool` scalar type is never overwritten by the builtin.
    np.bool = bool

from gluonts.dataset.common import ListDataset
from gluonts.dataset.field_names import FieldName
from pandas.tseries.offsets import Milli, Second
from pandas.tseries.frequencies import to_offset
from sklearn.utils import murmurhash3_32

from dataiku.core import dkujson
from dataiku.doctor.timeseries.preparation.resampling.utils import get_frequency, get_monthly_day_alignment
from dataiku.doctor.timeseries.utils import FULL_TIMESERIES_DF_IDENTIFIER, log_df
from dataiku.doctor.timeseries.utils import get_dataframe_of_timeseries_identifier
from dataiku.doctor.timeseries.utils import timeseries_iterator
from dataiku.doctor.timeseries.utils import future_date_range
from dataiku.doctor.timeseries.utils import ModelForecast
from dataiku.doctor.utils import doctor_constants
from dataiku.doctor.utils.gpu_execution import get_default_gpu_config

logger = logging.getLogger(__name__)


class DkuGluonTSEstimator(BaseTimeseriesEstimator):
    """
        Contains the GluonTS predictor, external features and several model & algo specific parameters.

        Warning: from release 11.2 to 12.0 (included), only predictor and external_features were serialized and
        params needed for scoring/evaluation were set via the initialize method.
        We still need to use the initialize method for backward compatibility.
    """
    def __init__(
        self,
        frequency,
        prediction_length,
        time_variable,
        target_variable,
        timeseries_identifiers,
        use_timeseries_identifiers_as_features=False,
        monthly_day_alignment=None
    ):
        super(DkuGluonTSEstimator, self).__init__(
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            monthly_day_alignment
        )
        # Identifiers can only be used as features when the dataset actually has identifier columns
        self.use_timeseries_identifiers_as_features = (
            bool(self.timeseries_identifiers) and use_timeseries_identifiers_as_features
        )

        # Trained GluonTS predictor; stays None until a subclass' fit() sets it
        self.predictor = None

    def _create_single_timeseries_dict(
        self, timeseries_identifier, past_df_of_timeseries_identifier, future_df_of_timeseries_identifier=None
    ):
        """
        Return a dictionary for a single time series (with identifier timeseries_identifier), to be used for the
        creation of a gluon ListDataset.

        The dict always contains the start date, the target values and the item id. When external features are
        used, their values are stored under FieldName.FEAT_DYNAMIC_REAL as a (n_features, n_timesteps) array made
        of the past values, followed by the future values if future_df_of_timeseries_identifier is given. When time
        series identifiers are used as features, each identifier value is hashed with murmurhash3 and stored under
        FieldName.FEAT_STATIC_CAT.
        """

        single_timeseries_dict = {
            FieldName.START: past_df_of_timeseries_identifier[self.time_variable].iloc[0],
            FieldName.TARGET: past_df_of_timeseries_identifier[self.target_variable],
            FieldName.ITEM_ID: timeseries_identifier,
        }

        if self.external_features:
            # For GluonTS models, preprocessing is always done on the full dataframe so we use the key
            # FULL_TIMESERIES_DF_IDENTIFIER (guaranteed to always be present) to get preprocessed external features
            # columns. Incidentally, they are the same for every time series.
            external_features_for_identifier = self.external_features[FULL_TIMESERIES_DF_IDENTIFIER]
            if external_features_for_identifier:
                # Transpose so that rows are features and columns are timesteps, as expected by GluonTS
                past_external_features = past_df_of_timeseries_identifier[external_features_for_identifier].values.T
                if future_df_of_timeseries_identifier is None:
                    single_timeseries_dict[FieldName.FEAT_DYNAMIC_REAL] = past_external_features
                else:
                    future_external_features = future_df_of_timeseries_identifier[external_features_for_identifier].values.T
                    single_timeseries_dict[FieldName.FEAT_DYNAMIC_REAL] = np.append(
                        past_external_features, future_external_features, axis=1
                    )

        if self.use_timeseries_identifiers_as_features:
            # The identifier is a JSON string mapping each identifier column to its value
            timeseries_identifier_dict = dkujson.loads(timeseries_identifier)
            single_timeseries_dict[FieldName.FEAT_STATIC_CAT] = [
                murmurhash3_32(timeseries_identifier_dict[col], seed=1337, positive=True)
                for col in self.timeseries_identifiers
            ]

        return single_timeseries_dict

    def _align_external_features_single_ts(self, d):
        """
        Return a shallow copy of the single time series dict `d` in which the FEAT_DYNAMIC_REAL array (if any) has
        exactly len(target) + self.prediction_length columns: extra columns are trimmed and missing columns are
        zero-padded on the right. `d` itself is not modified.
        """
        new_d = d.copy()
        if FieldName.FEAT_DYNAMIC_REAL in d:
            features = d[FieldName.FEAT_DYNAMIC_REAL]
            expected_length = len(d[FieldName.TARGET]) + self.prediction_length
            if features.shape[1] < expected_length:
                # Not enough external data, some padding is performed
                padded_features = np.zeros((features.shape[0], expected_length))
                # Fill the empty block with all the available values from external features
                padded_features[:, :features.shape[1]] = features
                new_d[FieldName.FEAT_DYNAMIC_REAL] = padded_features
            else:
                # Trim the extra values of external features beyond self.prediction_length
                new_d[FieldName.FEAT_DYNAMIC_REAL] = features[:, :expected_length]

        return new_d

    def _align_external_features(self, timeseries_dicts):
        """
        Align the external features of every time series (see _align_external_features_single_ts).

        :param timeseries_dicts: dict as follows:
             - key: time series identifier
             - value: single time series dict (as built by _create_single_timeseries_dict), keyed by FieldName
        :return: dict with the same structure. In each value, the series of values for the external features is
                 exactly self.prediction_length longer than the series of values for the target. If there is
                 not enough data in external features, zero padding is performed. The input dicts are not modified.
        """
        return {
            timeseries_identifier: self._align_external_features_single_ts(timeseries_dict)
            for timeseries_identifier, timeseries_dict in timeseries_dicts.items()
        }

    def _create_gluon_list_dataset(self, timeseries_dicts):
        """Create a gluonts ListDataset from list of time series dicts

        Args:
            timeseries_dicts (list): List of time series dicts.
        """
        return ListDataset(timeseries_dicts, freq=self.frequency)

    def _build_single_forecasts_dict(self, sample_forecasts, last_past_date, quantiles, prediction_length_override=None):
        """
        Convert a GluonTS forecast into a dict holding:
            - ModelForecast.TIMESTAMPS: the self.prediction_length dates following last_past_date
            - ModelForecast.FORECAST_VALUES: the mean forecast values
            - ModelForecast.QUANTILES_FORECASTS: a 2D array with one row per requested quantile
        If prediction_length_override is truthy, every entry is truncated to its first prediction_length_override
        steps.
        """
        quantiles_forecasts = [sample_forecasts.quantile(quantile) for quantile in quantiles]

        forecast_dates = future_date_range(
            last_past_date,
            self.prediction_length,
            self.frequency,
            self.monthly_day_alignment,
        )

        ret = {
            ModelForecast.TIMESTAMPS: forecast_dates,
            ModelForecast.FORECAST_VALUES: sample_forecasts.mean,
            ModelForecast.QUANTILES_FORECASTS: np.array(quantiles_forecasts),
        }

        if prediction_length_override:
            ret[ModelForecast.TIMESTAMPS] = ret[ModelForecast.TIMESTAMPS][:prediction_length_override]
            ret[ModelForecast.FORECAST_VALUES] = ret[ModelForecast.FORECAST_VALUES][:prediction_length_override]
            ret[ModelForecast.QUANTILES_FORECASTS] = ret[ModelForecast.QUANTILES_FORECASTS][:,:prediction_length_override]

        return ret

    def initialize(self, core_params, modeling_params):
        """ Sets the params that are needed for scoring (required for models that were not serialized with all parameters from release 11.2 to 12.0)

        Args:
            core_params (dict): Core params of the model
            modeling_params (dict): Resolved modeling params
        """
        self.frequency = get_frequency(core_params)
        self.prediction_length = core_params[doctor_constants.PREDICTION_LENGTH]
        self.time_variable = core_params[doctor_constants.TIME_VARIABLE]
        self.target_variable = core_params[doctor_constants.TARGET_VARIABLE]
        self.timeseries_identifiers = core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS]
        self.use_timeseries_identifiers_as_features = False
        self.monthly_day_alignment = get_monthly_day_alignment(core_params)

    def fit(self, train_df, external_features=None, shift_map=None):
        raise NotImplementedError()

    def _prepare_predict(self):
        """
        Perform basic checks before running the predict method. Subclasses may extend it to set random seeds.

        Raises:
            ValueError: if the estimator has no trained predictor.
        """
        if self.predictor is None:
            raise ValueError("Trying to predict an estimator that has not been trained")

    def predict_single(
        self,
        past_df_of_timeseries_identifier,
        future_df_of_timeseries_identifier,
        quantiles,
        timeseries_identifier,
        fit_before_predict=False,
        prediction_length_override=None
    ):
        """
        Produce the forecast values for a single time series, with identifier timeseries_identifier.

        A single GluonTS forecast of self.prediction_length steps is made. If prediction_length_override is truthy,
        external features are aligned and the result is truncated to its first prediction_length_override steps.
        A larger override does not extend the forecast beyond self.prediction_length steps (use predict for that).
        """
        self._prepare_predict()

        timeseries_dict = self._create_single_timeseries_dict(
            timeseries_identifier, past_df_of_timeseries_identifier, future_df_of_timeseries_identifier
        )

        if prediction_length_override:
            aligned_timeseries_dict = self._align_external_features_single_ts(timeseries_dict)
            predict_data = self._create_gluon_list_dataset([aligned_timeseries_dict])
        else:
            predict_data = self._create_gluon_list_dataset([timeseries_dict])

        # here predictor.predict outputs a generator of a single gluonts.model.forecast.SampleForecast object because
        # predict_data contains a single timeseries
        sample_forecasts = next(self.predictor.predict(predict_data))

        last_past_date = past_df_of_timeseries_identifier[self.time_variable].iloc[-1]

        return self._build_single_forecasts_dict(sample_forecasts, last_past_date, quantiles, prediction_length_override)

    def predict(self, past_df, future_df, quantiles, fit_before_predict=False, prediction_length_override=None):
        """
        Produce the forecast values for all time series

        Args:
            past_df: history of all time series
            future_df: future values of the external features (only read when external features are used)
            quantiles: quantiles to forecast
            prediction_length_override: number of steps to forecast, if different from self.prediction_length.
                When larger, the forecast is made recursively and quantiles beyond the first horizon are NaN.

        Return:
            Dictionary where keys are time series identifiers and values are the forecast values for the time series.
            Each forecast contains the time stamps, the mean forecast values, and the quantile forecasts (2D-array)
        """
        self._prepare_predict()

        last_past_dates = {}
        timeseries_dicts = {}
        for timeseries_identifier, past_df_of_timeseries_identifier in timeseries_iterator(
                past_df, self.timeseries_identifiers
        ):
            logger.info("Predicting model for time series %s", timeseries_identifier)
            log_df(logger, past_df_of_timeseries_identifier, self.time_variable, None, "\t - Past")
            future_df_of_timeseries_identifier = None
            if self.external_features:
                future_df_of_timeseries_identifier = get_dataframe_of_timeseries_identifier(
                    future_df, timeseries_identifier
                )
                log_df(logger, future_df_of_timeseries_identifier, self.time_variable, None, "\t - External features future")

            last_past_dates[timeseries_identifier] = past_df_of_timeseries_identifier[self.time_variable].iloc[-1]

            timeseries_dicts[timeseries_identifier] = self._create_single_timeseries_dict(
                timeseries_identifier, past_df_of_timeseries_identifier, future_df_of_timeseries_identifier
            )

        forecasts_by_timeseries = {}

        if prediction_length_override is not None:
            prediction_length = prediction_length_override
        else:
            prediction_length = self.prediction_length
        n_horizons = int(np.ceil(prediction_length / self.prediction_length))
        # GluonTS only allows to forecast self.prediction_length values, with strong assumptions on the format of inputs
        # We iterate n_horizons times the following steps:
        # - format the input (predict_data) as expected by the GluonTS API
        # - forecast of a batch of size self.prediction_length
        # - update of forecasts_by_timeseries (result)
        # - update of timeseries_dicts and last_past_dates (inputs for the next iteration)
        for _ in range(n_horizons):
            aligned_timeseries_dicts = self._align_external_features(timeseries_dicts)
            predict_data = self._create_gluon_list_dataset(list(aligned_timeseries_dicts.values()))

            for sample_forecasts in self.predictor.predict(predict_data):
                timeseries_identifier = sample_forecasts.item_id
                forecasts_dict = self._build_single_forecasts_dict(
                    sample_forecasts, last_past_dates[timeseries_identifier], quantiles
                )
                if timeseries_identifier not in forecasts_by_timeseries:
                    # First horizon
                    forecasts_by_timeseries[timeseries_identifier] = forecasts_dict
                else:
                    # Additional horizons: add the new timestamps and forecast_values + remove the inaccurate quantiles
                    # (quantiles of a recursive forecast ignore the uncertainty of the previously forecasted values)
                    accumulated = forecasts_by_timeseries[timeseries_identifier]
                    accumulated[ModelForecast.TIMESTAMPS] = np.hstack(
                        (accumulated[ModelForecast.TIMESTAMPS], forecasts_dict[ModelForecast.TIMESTAMPS])
                    )
                    accumulated[ModelForecast.FORECAST_VALUES] = np.hstack(
                        (accumulated[ModelForecast.FORECAST_VALUES], forecasts_dict[ModelForecast.FORECAST_VALUES])
                    )
                    accumulated[ModelForecast.QUANTILES_FORECASTS] = np.hstack(
                        (
                            accumulated[ModelForecast.QUANTILES_FORECASTS],
                            np.full_like(forecasts_dict[ModelForecast.QUANTILES_FORECASTS], np.nan),
                        )
                    )

                # Feed the mean forecast back as target history for the next horizon
                timeseries_dicts[timeseries_identifier][FieldName.TARGET] = np.hstack(
                    (timeseries_dicts[timeseries_identifier][FieldName.TARGET], forecasts_dict[ModelForecast.FORECAST_VALUES])
                )
                last_past_dates[timeseries_identifier] = forecasts_dict[ModelForecast.TIMESTAMPS][-1]

        if prediction_length < n_horizons * self.prediction_length:
            # The overriding prediction_length is not an integer multiple of self.prediction_length (used in training)
            n_excess_steps = n_horizons * self.prediction_length - prediction_length
            for forecast in forecasts_by_timeseries.values():
                forecast[ModelForecast.TIMESTAMPS] = forecast[ModelForecast.TIMESTAMPS][:-n_excess_steps]
                forecast[ModelForecast.FORECAST_VALUES] = forecast[ModelForecast.FORECAST_VALUES][:-n_excess_steps]
                forecast[ModelForecast.QUANTILES_FORECASTS] = forecast[ModelForecast.QUANTILES_FORECASTS][:, :-n_excess_steps]

        return forecasts_by_timeseries

    def get_fitted_values_and_residuals(self, identifier, df_of_identifier, min_scoring_size):
        """
        For GluonTS models, we compute a prediction for every possible timestep in the historical data.
        This is a simplified version of `predict_single`: each window of min_scoring_size past rows (followed by
        self.prediction_length future rows for external features) yields the mean of its first forecasted step.

        Returns:
            fitted_values: list with one value per row of df_of_identifier. It is np.nan for the first
                min_scoring_size rows, the last self.prediction_length rows, or every row when the time series is
                too short to contain a single window.
            residuals: target values minus fitted values.
        """
        # NOTE(review): unlike predict/predict_single, this does not call _prepare_predict (no trained-predictor
        # check, no seeding in subclasses) - confirm this is intended.
        n_rows = len(df_of_identifier)
        n_windows = n_rows - self.prediction_length - min_scoring_size
        if n_windows <= 0:
            # Too short for a single window: no fitted value can be computed. The NaN padding used below would
            # otherwise produce more values than rows and make the residuals subtraction fail.
            fitted_values = [np.nan] * n_rows
        else:
            timeseries_dicts = []
            for i in range(n_windows):
                past_df = df_of_identifier.iloc[i: i + min_scoring_size]
                future_df = df_of_identifier.iloc[i + min_scoring_size: i + min_scoring_size + self.prediction_length]
                timeseries_dicts.append(self._create_single_timeseries_dict(identifier, past_df, future_df))
            predict_data = self._create_gluon_list_dataset(timeseries_dicts)
            # np.nan here are used to align fitted_values size with the historical data. The rows are then dropped once the initial residuals dataframe has been created. This is to preserve alignment.
            fitted_values = (
                [np.nan] * min_scoring_size
                + [f.mean[0] for f in self.predictor.predict(predict_data)]
                + [np.nan] * self.prediction_length
            )
        residuals = df_of_identifier[self.target_variable] - fitted_values

        return fitted_values, residuals


class DkuGluonTSDeepLearningEstimator(DkuGluonTSEstimator):
    """
    Common base for GluonTS deep learning estimators.

    Subclasses provide get_distr_output_class, _get_estimator and get_trainer. This class builds the training
    dataset, seeds numpy's global random generator before training and prediction, computes the automatic number
    of batches per epoch, and provides the time-based parameters that GluonTS cannot infer for second and
    millisecond frequencies.
    """
    def __init__(
            self,
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            use_timeseries_identifiers_as_features=False,
            context_length=1,
            full_context=True,
            auto_num_batches_per_epoch=True,
            num_batches_per_epoch=50,
            epochs=10,
            batch_size=32,
            learning_rate=.001,
            gpu_config=None,
            seed=1337,
            monthly_day_alignment=None,
    ):
        super(DkuGluonTSDeepLearningEstimator, self).__init__(
            frequency=frequency,
            prediction_length=prediction_length,
            time_variable=time_variable,
            target_variable=target_variable,
            timeseries_identifiers=timeseries_identifiers,
            use_timeseries_identifiers_as_features=use_timeseries_identifiers_as_features,
            monthly_day_alignment=monthly_day_alignment,
        )

        self.full_context = full_context
        self.auto_num_batches_per_epoch = auto_num_batches_per_epoch
        self.num_batches_per_epoch = num_batches_per_epoch
        self.epochs = epochs
        self.batch_size = batch_size
        self.seed = seed
        # The default GPU config is resolved per instance rather than once in the signature, so that it is not
        # computed at import time. It also ensures instances never share one default object (presumably a mutable
        # dict - TODO confirm).
        self.gpu_config = get_default_gpu_config() if gpu_config is None else gpu_config

        # Searchable parameters
        self.learning_rate = learning_rate
        self.context_length = context_length

    @staticmethod
    def get_distr_output_class(distr_output):
        raise NotImplementedError()

    def set_params(self, **params):
        """Set estimator params; with full_context, context_length is reset to None (i.e. not set explicitly)."""
        super(DkuGluonTSDeepLearningEstimator, self).set_params(**params)

        if self.full_context:
            self.context_length = None

        return self

    def _get_estimator(self, train_data, identifier_cardinalities):
        raise NotImplementedError()

    def fit(self, train_df, external_features=None, shift_map=None):
        """
        Train the GluonTS predictor on all the time series of train_df.

        Args:
            train_df: training dataframe containing every time series
            external_features: if not None, replaces self.external_features before building the dataset
            shift_map: unused here

        Returns:
            self
        """
        if external_features is not None:
            self.external_features = external_features

        timeseries_dicts = [
            self._create_single_timeseries_dict(timeseries_identifier, train_df_of_timeseries_identifier)
            for timeseries_identifier, train_df_of_timeseries_identifier
            in timeseries_iterator(train_df, self.timeseries_identifiers)
        ]

        train_data = self._create_gluon_list_dataset(timeseries_dicts)

        np.random.seed(self.seed)

        if self.use_timeseries_identifiers_as_features:
            identifier_cardinalities = train_df[self.timeseries_identifiers].nunique().tolist()
            if identifier_cardinalities:
                logger.info(
                    """Time series identifiers {ids} of cardinality {cardinality} will be encoded to be used as external
                    features of the model""".format(
                        ids=self.timeseries_identifiers, cardinality=identifier_cardinalities
                    )
                )
        else:
            identifier_cardinalities = None
        estimator = self._get_estimator(train_data, identifier_cardinalities=identifier_cardinalities)
        self.predictor = estimator.train(train_data, cache_data=True)
        return self

    def _prepare_predict(self):
        """Check the estimator is trained, then reseed numpy's global random generator for reproducible sampling."""
        super(DkuGluonTSDeepLearningEstimator, self)._prepare_predict()
        np.random.seed(self.seed)

    def get_trainer(self, train_data):
        raise NotImplementedError()

    def _compute_auto_num_batches_per_epoch(self, train_data):
        """Compute the number of batches per epoch based on the training data size.
        With this formula, each timestep will on average be once in the prediction length part.
        The result is never lower than 50.
        """
        num_samples_total = sum(
            math.ceil(len(timeseries[FieldName.TARGET]) / self.prediction_length)
            for timeseries in train_data
        )
        return max(math.ceil(num_samples_total / self.batch_size), 50)

    def get_time_based_parameters(self, parameters_to_set=None):
        """For seconds and milliseconds frequencies, we need to set some parameters ourselves.
        If not set, these parameters are automatically computed by gluonts using the frequency (this fails with
        seconds and milliseconds).

        Args:
            parameters_to_set: names of the parameters to return, among "lags_seq", "time_features" and
                "add_time_feature" (presumably those accepted by the concrete GluonTS estimator - confirm in callers).

        Returns:
            dict mapping parameter name to value; empty unless the frequency is second- or millisecond-based.
        """
        time_parameters = {}
        # Checking the type of the parsed offset (e.g. "30S" -> Second(30)) is equivalent to comparing its name with
        # the offset classes' prefixes, without relying on the private pandas attribute `_prefix`.
        offset = to_offset(self.frequency)
        if isinstance(offset, (Second, Milli)) and parameters_to_set:
            if "lags_seq" in parameters_to_set:
                # this is the default list of gluonts (gluonts.time_feature.get_lags_for_frequency)
                time_parameters["lags_seq"] = [1, 2, 3, 4, 5, 6, 7]
            if "time_features" in parameters_to_set:
                time_parameters["time_features"] = []
            if "add_time_feature" in parameters_to_set:
                time_parameters["add_time_feature"] = False
        return time_parameters



