from dataiku.doctor.timeseries.models.statistical.stats import build_coefficient_dict, sanitized_stats_value
from dataiku.doctor.timeseries.models.statistical.base_estimator import DkuStatisticalEstimator
from statsmodels.tsa.arima.model import ARIMA


class DkuArimaEstimator(DkuStatisticalEstimator):
    """Per-timeseries (seasonal) ARIMA estimator backed by ``statsmodels.tsa.arima.model.ARIMA``.

    One statsmodels model is fitted per timeseries. Fitted results are read back from
    ``self.trained_models``, which maps a timeseries identifier to a fitted results object.
    """

    def __init__(
            self,
            frequency,
            time_variable,
            prediction_length,
            target_variable,
            timeseries_identifier_columns,
            p=0,
            d=0,
            q=0,
            P=0,
            D=0,
            Q=0,
            s=0,
            trend=None,
            enforce_stationarity=True,
            enforce_invertibility=True,
            concentrate_scale=False,
            trend_offset=1
    ):
        """Store the timeseries settings and the ARIMA hyperparameters.

        :param frequency, time_variable, prediction_length, target_variable, timeseries_identifier_columns:
            forwarded unchanged to ``DkuStatisticalEstimator``
        :param p, d, q: non-seasonal order, passed to statsmodels as ``order=(p, d, q)``
        :param P, D, Q, s: seasonal order and period, passed as ``seasonal_order=(P, D, Q, s)``
        :param trend: statsmodels ``trend`` argument (``None`` lets statsmodels pick its default)
        :param enforce_stationarity, enforce_invertibility, concentrate_scale, trend_offset:
            forwarded unchanged to statsmodels ``ARIMA``
        """
        super(DkuArimaEstimator, self).__init__(
            frequency=frequency,
            time_variable=time_variable,
            prediction_length=prediction_length,
            target_variable=target_variable,
            timeseries_identifier_columns=timeseries_identifier_columns,
        )
        # Attribute names must match the constructor parameter names: they are read back
        # through ``self.get_params()`` in ``_fit_single`` (presumably sklearn-style
        # introspection provided by the base class — confirm).
        self.p = p
        self.d = d
        self.q = q

        self.P = P
        self.D = D
        self.Q = Q
        self.s = s

        self.trend = trend
        self.trend_offset = trend_offset

        self.enforce_invertibility = enforce_invertibility
        self.enforce_stationarity = enforce_stationarity
        self.concentrate_scale = concentrate_scale

    def _fit_single(self, target_values, date_values=None, external_features_values=None):
        """Fit a statsmodels ARIMA model on one time series.

        :param target_values: target observations of this series (statsmodels ``endog``)
        :param date_values: not used by this estimator; kept for interface compatibility
        :param external_features_values: optional external regressors (statsmodels ``exog``)
        :return: the fitted statsmodels results object returned by ``ARIMA.fit()``
        """
        params = self.get_params()
        # Fallbacks mirror the constructor defaults so that a missing key can never
        # silently change the model specification (e.g. force a seasonal period or
        # replace statsmodels' own trend default).
        model = ARIMA(
            endog=target_values,
            exog=external_features_values,
            order=(params.get("p", 0), params.get("d", 0), params.get("q", 0)),
            seasonal_order=(params.get("P", 0), params.get("D", 0), params.get("Q", 0), params.get("s", 0)),
            trend=params.get("trend"),
            trend_offset=params.get("trend_offset", 1),
            concentrate_scale=params.get("concentrate_scale", False),
            enforce_invertibility=params.get("enforce_invertibility", True),
            enforce_stationarity=params.get("enforce_stationarity", True),
        )
        return model.fit()

    def _forecast_single_timeseries(
            self,
            trained_model,
            past_target_values,
            past_date_values,
            quantiles,
            past_external_features_values,
            future_external_features_values,
            fit_before_predict,
            prediction_length
    ):
        """Forecast ``prediction_length`` steps after the end of ``past_target_values``.

        :param trained_model: fitted statsmodels results; ignored when ``fit_before_predict`` is true
        :param fit_before_predict: if true, refit a fresh model on the past values; otherwise
            reuse ``trained_model``'s parameters on the past values without re-estimating them
        :param future_external_features_values: exog for the forecast horizon (may be ``None``)
        :return: the forecasts dict built by ``self._build_forecasts_dict`` for ``quantiles``
        """
        if fit_before_predict:
            trained_model = self._fit_single(past_target_values, past_date_values, past_external_features_values)
        else:
            trained_model = trained_model.apply(endog=past_target_values, exog=past_external_features_values, refit=False)

        # statsmodels' ``end`` is inclusive: this covers exactly ``prediction_length`` steps
        # immediately following the last past observation.
        first_step = trained_model.nobs
        prediction_results = trained_model.get_prediction(
            exog=future_external_features_values,
            start=first_step,
            end=first_step + prediction_length - 1)
        return self._build_forecasts_dict(prediction_results, quantiles)

    def get_coefficients_map_and_names(self):
        """Collect the fitted coefficients (and their statistics) of every trained timeseries.

        :return: tuple ``(coefficients_map, None, variable_coefficients, external_features_coefficients)``

            - ``coefficients_map``: dict keyed by coefficient name (``"p_1"``, ``"Q_2"``, ... or an
              external feature name). Each value comes from ``build_coefficient_dict()`` and has its
              ``"values"``, ``"pvalues"``, ``"tvalues"`` and ``"stderrs"`` dicts filled per
              timeseries identifier. Statistics are omitted when statsmodels does not expose them.
            - second element is always ``None`` for this estimator (meaning defined by the caller)
            - ``variable_coefficients``: the order prefixes ``["p", "q", "P", "Q"]``
            - ``external_features_coefficients``: sorted external feature names seen across models
        """
        variable_coefficients = ["p", "q", "P", "Q"]
        # For each entry of ``variable_coefficients`` (same order): the results attribute holding
        # the fitted lag coefficients, and the prefix statsmodels uses for those parameter names.
        coefficient_sources = [
            ("arparams", "ar.L"),
            ("maparams", "ma.L"),
            ("seasonalarparams", "ar.S.L"),
            ("seasonalmaparams", "ma.S.L"),
        ]
        non_seasonal_orders = {"p", "q"}
        external_features_set = set()
        coefficients_map = {}

        for timeseries_identifier, trained_model in self.trained_models.items():
            p, _, q = trained_model.model.order
            P, _, Q, seasonal_period = trained_model.model.seasonal_order
            orders = [p, q, P, Q]

            for order_name, order_value, (order_method, stats_key_prefix) in zip(
                    variable_coefficients, orders, coefficient_sources):
                if order_value <= 0:
                    # The results attribute cannot be accessed when the order is 0.
                    continue
                for i, value in enumerate(getattr(trained_model, order_method)):
                    lag_index = i + 1
                    coeff_name = "{}_{}".format(order_name, lag_index)
                    if coeff_name not in coefficients_map:
                        coefficients_map[coeff_name] = build_coefficient_dict()
                    coefficient = coefficients_map[coeff_name]
                    coefficient["values"][timeseries_identifier] = value

                    # statsmodels names non-seasonal terms by their index ("ar.L1") and
                    # seasonal terms by their actual lag ("ar.S.L12" for the first term when s=12).
                    lag = lag_index if order_name in non_seasonal_orders else seasonal_period * lag_index
                    param_key = stats_key_prefix + str(lag)
                    try:
                        coefficient["pvalues"][timeseries_identifier] = sanitized_stats_value(trained_model.pvalues[param_key])
                        coefficient["tvalues"][timeseries_identifier] = sanitized_stats_value(trained_model.tvalues[param_key])
                        coefficient["stderrs"][timeseries_identifier] = sanitized_stats_value(trained_model.bse[param_key])
                    except KeyError:
                        # Best effort: statistics may be missing for this parameter; keep the
                        # coefficient value and whichever statistics were already recorded.
                        pass

            # NOTE(review): statsmodels' ARIMA may expose deterministic trend terms (e.g. "const")
            # through ``exog_names``, in which case they are reported here as external features — confirm.
            external_features_names = trained_model.model.exog_names
            if not external_features_names:
                continue
            coefficients_by_name = trained_model.params.to_dict()
            for external_feature_name in external_features_names:
                if external_feature_name not in coefficients_map:
                    coefficients_map[external_feature_name] = build_coefficient_dict()
                    external_features_set.add(external_feature_name)
                coefficient = coefficients_map[external_feature_name]
                coefficient["values"][timeseries_identifier] = coefficients_by_name[external_feature_name]
                coefficient["pvalues"][timeseries_identifier] = sanitized_stats_value(trained_model.pvalues[external_feature_name])
                coefficient["tvalues"][timeseries_identifier] = sanitized_stats_value(trained_model.tvalues[external_feature_name])
                coefficient["stderrs"][timeseries_identifier] = sanitized_stats_value(trained_model.bse[external_feature_name])

        external_features_coefficients = sorted(external_features_set)

        return coefficients_map, None, variable_coefficients, external_features_coefficients

    def get_fitted_values_and_residuals(self, identifier, df_of_identifier, min_scoring_size):
        """Return ``(fittedvalues, resid)`` of the model trained for ``identifier``.

        For ARIMA models these are computed at train time and stored on the fitted results,
        so ``df_of_identifier`` and ``min_scoring_size`` are not used here.
        """
        trained_model = self.trained_models[identifier]
        return trained_model.fittedvalues, trained_model.resid