from dataiku.doctor.timeseries.models.statistical.stats import build_coefficient_dict, sanitized_stats_value
from dataiku.doctor.timeseries.utils.arima_compat import dku_auto_arima
from dataiku.doctor.timeseries.utils.arima_compat import instantiate_arima_model
from dataiku.doctor.timeseries.utils.arima_compat import SUPPORTS_MODEL_COEFFICIENTS
from dataiku.doctor.timeseries.models.statistical.base_estimator import DkuStatisticalEstimator, \
    INFORMATION_CRITERION_TO_DISPLAY_NAME


class DkuAutoArimaEstimator(DkuStatisticalEstimator):
    """Statistical estimator that selects ARIMA orders with ``dku_auto_arima``.

    For each time series, ``dku_auto_arima`` is run with the search bounds stored on
    this estimator, then an ARIMA model is instantiated and fitted again with the
    parameters reported by the search. The fitted models are read back from
    ``self.trained_models`` (managed by the base class) by the reporting methods below.
    """

    def __init__(
        self,
        frequency,
        time_variable,
        prediction_length,
        target_variable,
        timeseries_identifier_columns,
        start_p,
        max_p,
        d_,
        max_d,
        start_q,
        max_q,
        start_P,
        max_P,
        D_,
        max_D,
        start_Q,
        max_Q,
        max_order,
        stationary,
        maxiter,
        m=1,
        information_criterion="aic",
        test="kpss",
        seasonal_test="ocsb",
        method="lbfgs",
        monthly_day_alignment=None,
    ):
        """Store the search configuration.

        The first five arguments and ``monthly_day_alignment`` are handed to the base
        class constructor. Every other argument is kept on the instance under its own
        name and later forwarded to ``dku_auto_arima`` by ``_fit_single`` (``d_`` and
        ``D_`` are forwarded as ``d`` and ``D``).
        """
        super(DkuAutoArimaEstimator, self).__init__(
            frequency=frequency,
            time_variable=time_variable,
            prediction_length=prediction_length,
            target_variable=target_variable,
            timeseries_identifier_columns=timeseries_identifier_columns,
            monthly_day_alignment=monthly_day_alignment,
        )

        # Differencing orders requested for the whole estimator. The trailing "_" keeps
        # them distinct from the per-series "d" / "D" entries that get_params adds to its
        # output, so the estimator-level values are never overridden there.
        self.d_ = d_
        self.D_ = D_

        # bounds for the non-seasonal orders
        self.start_p = start_p
        self.max_p = max_p
        self.max_d = max_d
        self.start_q = start_q
        self.max_q = max_q

        # bounds for the seasonal orders
        self.start_P = start_P
        self.max_P = max_P
        self.max_D = max_D
        self.start_Q = start_Q
        self.max_Q = max_Q

        # remaining search settings
        self.max_order = max_order
        self.stationary = stationary
        self.maxiter = maxiter

        # searchable parameters
        self.m = m
        self.information_criterion = information_criterion
        self.test = test
        self.seasonal_test = seasonal_test
        self.method = method

    def set_params(self, **params):
        """Apply ``params`` through the base class, then reset the differencing settings
        when ``stationary`` is truthy.

        :return: this estimator
        """
        super(DkuAutoArimaEstimator, self).set_params(**params)

        if not self.stationary:
            return self

        # The differencing arguments are unused when stationary is True: clear the
        # requested orders and put fixed values back on the bounds
        # (2 and 1 are presumably the library defaults - TODO confirm).
        self.d_, self.D_ = None, None
        self.max_d, self.max_D = 2, 1
        return self

    def get_params(self, deep=True):
        """Return the base class parameters, extended after training with the orders
        found for each time series.

        When ``self.trained_models`` is not None, the keys "p", "d", "q", "P", "D" and
        "Q" each map to a dict of ``{timeseries_identifier: order}``.
        """
        params = super(DkuAutoArimaEstimator, self).get_params(deep=deep)

        if self.trained_models is None:
            return params

        order_names = ("p", "d", "q", "P", "D", "Q")
        for order_name in order_names:
            params[order_name] = {}

        for timeseries_identifier, trained_model in self.trained_models.items():
            arima_model = trained_model.model
            p, d, q = arima_model.order
            # the seasonal period (last item) is not reported here
            P, D, Q, _ = arima_model.seasonal_order

            for order_name, order_value in zip(order_names, (p, d, q, P, D, Q)):
                params[order_name][timeseries_identifier] = order_value

        return params

    def _fit_single(self, target_values, date_values=None, external_features_values=None):
        """Fit one time series.

        Runs the auto-ARIMA search, then fits a new ARIMA model with the parameters the
        search reports. ``date_values`` is not used by this method.

        :return: the fitted ARIMA results, with the data arrays of its model cleared
        """
        search_settings = dict(
            start_p=self.start_p,
            max_p=self.max_p,
            d=self.d_,
            max_d=self.max_d,
            start_q=self.start_q,
            max_q=self.max_q,
            start_P=self.start_P,
            max_P=self.max_P,
            D=self.D_,
            max_D=self.max_D,
            start_Q=self.start_Q,
            max_Q=self.max_Q,
            max_order=self.max_order,
            stationary=self.stationary,
            maxiter=self.maxiter,
            m=self.m,
            information_criterion=self.information_criterion,
            test=self.test,
            seasonal_test=self.seasonal_test,
            method=self.method,
            # the seasonal search is only enabled for a period above 1
            seasonal=self.m > 1,
        )
        search_result = dku_auto_arima(target_values, external_features_values, **search_settings)
        selected_params = search_result.get_params()

        # instantiate and fit an ARIMA model with same parameters as auto_arima used
        arima_model = instantiate_arima_model(target_values, external_features_values, selected_params)
        trained_model = arima_model.fit()

        # Clear the data held by the fitted model (presumably so the training data is
        # not kept alongside the stored model - TODO confirm).
        trained_model.model.endog = None
        if self.external_features:
            trained_model.model.exog = None

        return trained_model

    def _forecast_single_timeseries(
        self,
        trained_model,
        past_target_values,
        past_date_values,
        quantiles,
        past_external_features_values,
        future_external_features_values,
        fit_before_predict,
        prediction_length
    ):
        """Forecast ``prediction_length`` steps after the end of the past values.

        With ``fit_before_predict``, a new ARIMA model with the same order, seasonal
        order and trend as ``trained_model`` is fitted on the past values. Otherwise
        ``trained_model`` is applied to the past values with ``refit=False``.
        ``past_date_values`` is not used by this method.

        :return: the output of ``self._build_forecasts_dict`` for the predictions and ``quantiles``
        """
        if not fit_before_predict:
            forecasting_model = trained_model.apply(
                endog=past_target_values, exog=past_external_features_values, refit=False
            )
        else:
            # instantiate and fit an ARIMA model with the same ARIMA orders and parameters
            # found by pm.auto_arima during training
            reference_model = trained_model.model
            arima_spec = {
                "order": reference_model.order,
                "seasonal_order": reference_model.seasonal_order,
                "trend": reference_model.trend,
            }
            forecasting_model = instantiate_arima_model(
                past_target_values, past_external_features_values, arima_spec
            ).fit()

        # predictions start right after the last observation
        first_step = forecasting_model.nobs
        prediction_results = forecasting_model.get_prediction(
            start=first_step,
            end=first_step + prediction_length - 1,
            exog=future_external_features_values,
        )

        return self._build_forecasts_dict(prediction_results, quantiles)

    def get_coefficients_map_and_names(self):
        """Collect the coefficients of every trained model.

        :return: a 4-tuple ``(coefficients_map, None, variable_coefficients,
            external_features_coefficients)``. ``coefficients_map`` maps a coefficient
            name (``"<order>_<index>"`` or an external feature name) to a dict created by
            ``build_coefficient_dict`` and filled per time series under "values",
            "pvalues", "tvalues" and "stderrs". ``variable_coefficients`` is
            ``["p", "q", "P", "Q"]`` and ``external_features_coefficients`` is the sorted
            list of external feature names added to the map.
        :raises NotImplementedError: when ``SUPPORTS_MODEL_COEFFICIENTS`` is falsy
        """
        if not SUPPORTS_MODEL_COEFFICIENTS:
            raise NotImplementedError("pmdarima package version must be at least 1.5.0 in order to extract the model coefficients")

        # (order name, results attribute holding the coefficients,
        #  prefix of the matching key in pvalues/tvalues/bse, whether lags are seasonal)
        coefficient_specs = (
            ("p", "arparams", "ar.L", False),
            ("q", "maparams", "ma.L", False),
            ("P", "seasonalarparams", "ar.S.L", True),
            ("Q", "seasonalmaparams", "ma.S.L", True),
        )
        # (key in the coefficient dict, results attribute holding that statistic)
        statistic_sources = (("pvalues", "pvalues"), ("tvalues", "tvalues"), ("stderrs", "bse"))

        variable_coefficients = [spec[0] for spec in coefficient_specs]
        external_features_set = set()
        coefficients_map = {}

        for timeseries_identifier, trained_model in self.trained_models.items():
            p, _, q = trained_model.model.order
            P, _, Q, seasonal_period = trained_model.model.seasonal_order
            order_values = {"p": p, "q": q, "P": P, "Q": Q}

            for order_name, coefficients_attribute, key_prefix, is_seasonal in coefficient_specs:
                if not order_values[order_name] > 0:
                    # method cannot be called if order is 0
                    continue

                coefficients = getattr(trained_model, coefficients_attribute)
                for position, value in enumerate(coefficients, start=1):
                    coeff_name = "{}_{}".format(order_name, position)
                    if coeff_name not in coefficients_map:
                        coefficients_map[coeff_name] = build_coefficient_dict()
                    coefficient_entry = coefficients_map[coeff_name]
                    coefficient_entry["values"][timeseries_identifier] = value

                    # seasonal lags are numbered in multiples of the seasonal period
                    lag = seasonal_period * position if is_seasonal else position
                    param_key = key_prefix + str(lag)
                    try:
                        for statistic_name, statistic_attribute in statistic_sources:
                            statistic = getattr(trained_model, statistic_attribute)[param_key]
                            coefficient_entry[statistic_name][timeseries_identifier] = sanitized_stats_value(statistic)
                    except KeyError:
                        # statistics missing under this key are left out of the map
                        pass

            external_features_names = trained_model.model.exog_names
            if not external_features_names:
                continue

            coefficients_by_name = trained_model.params.to_dict()
            for external_feature_name in external_features_names:
                if external_feature_name not in coefficients_map:
                    coefficients_map[external_feature_name] = build_coefficient_dict()
                    external_features_set.add(external_feature_name)
                feature_entry = coefficients_map[external_feature_name]
                feature_entry["values"][timeseries_identifier] = coefficients_by_name[external_feature_name]
                for statistic_name, statistic_attribute in statistic_sources:
                    statistic = getattr(trained_model, statistic_attribute)[external_feature_name]
                    feature_entry[statistic_name][timeseries_identifier] = sanitized_stats_value(statistic)

        external_features_coefficients = sorted(external_features_set)

        return coefficients_map, None, variable_coefficients, external_features_coefficients

    def get_information_criteria(self):
        """Return one entry per criterion ("aic", "bic", "hqic", "llf").

        Each entry is a dict with "values" (``{timeseries_identifier: value}``, each value
        passed through ``self.prepare_information_criteria``) and "displayName".
        """
        information_criteria = []
        for criterion_name in ("aic", "bic", "hqic", "llf"):
            display_name = INFORMATION_CRITERION_TO_DISPLAY_NAME[criterion_name]
            values_by_timeseries = {
                timeseries_identifier: self.prepare_information_criteria(getattr(trained_model, criterion_name))
                for timeseries_identifier, trained_model in self.trained_models.items()
            }
            information_criteria.append({"values": values_by_timeseries, "displayName": display_name})
        return information_criteria

    def get_fitted_values_and_residuals(self, identifier, df_of_identifier, min_scoring_size):
        """
        For AutoArima models, fitted_values and residuals are computed/stored at train time.

        They are read from the trained model of ``identifier``; ``df_of_identifier`` and
        ``min_scoring_size`` are not used by this method.
        """
        model_of_identifier = self.trained_models[identifier]
        fitted_values = model_of_identifier.fittedvalues
        residuals = model_of_identifier.resid
        return fitted_values, residuals
