import numpy as np
import logging
from statsmodels.tsa.exponential_smoothing.ets import ETSModel

from dataiku.doctor.exception import InvalidModelException
from dataiku.doctor.timeseries.models.statistical.base_estimator import DkuStatisticalEstimator, \
    INFORMATION_CRITERION_TO_DISPLAY_NAME
from dataiku.doctor.timeseries.models.statistical.stats import sanitized_stats_value, build_coefficient_dict


logger = logging.getLogger(__name__)


class DkuETSEstimator(DkuStatisticalEstimator):
    """Exponential smoothing (ETS) forecaster backed by statsmodels' ``ETSModel``.

    One ETS model is fitted per time series. Hyperparameters are kept as the
    strings received from the modeling params (e.g. "add", "none", "true",
    "false") and converted with :meth:`string_to_param` right before fitting.
    """

    def __init__(
            self,
            frequency,
            time_variable,
            prediction_length,
            target_variable,
            timeseries_identifier_columns,
            error="add",
            trend="add",
            damped_trend="true",
            seasonal="add",
            seasonal_periods=7,
            seed=1337,
            include_unstable="false"
    ):
        super(DkuETSEstimator, self).__init__(
            frequency=frequency,
            time_variable=time_variable,
            prediction_length=prediction_length,
            target_variable=target_variable,
            timeseries_identifier_columns=timeseries_identifier_columns,
        )
        self.error = error
        self.trend = trend
        self.damped_trend = damped_trend
        self.seasonal = seasonal
        self.seasonal_periods = seasonal_periods
        self.seed = seed
        # Stored for the caller/search logic; not read anywhere in this class.
        self.include_unstable = include_unstable

    def initialize(self, core_params, modeling_params):
        """Initialize the base estimator, then load hyperparameters from ``modeling_params["ets_params"]``."""
        super(DkuETSEstimator, self).initialize(core_params, modeling_params)
        algo_params = modeling_params["ets_params"]
        self.error = algo_params["error"]
        self.trend = algo_params["trend"]
        self.damped_trend = algo_params["damped_trend"]
        self.seasonal = algo_params["seasonal"]
        self.seasonal_periods = algo_params["seasonal_periods"]
        self.seed = algo_params["seed"]
        self.include_unstable = algo_params["include_unstable"]

    @staticmethod
    def string_to_param(value):
        """Convert a string hyperparameter to the Python value statsmodels expects.

        "none" -> None, "false" -> False, "true" -> True; any other value
        (including non-strings such as an int ``seasonal_periods``) is returned as-is.
        """
        if value == "none":
            return None
        if value == "false":
            return False
        if value == "true":
            return True
        return value

    def _uses_multiplicative_component(self):
        """Return True when the error, trend or seasonal component is configured as multiplicative."""
        return any(
            str(component).lower().startswith("mul")
            for component in (self.error, self.trend, self.seasonal)
        )

    def _fit_single(self, target_values, date_values=None, external_features_values=None):
        """Fit one time series and return the fitted statsmodels ETS results.

        :param target_values: observed target values of the series.
        :param date_values: unused (kept for interface compatibility).
        :param external_features_values: unused; this ETS setup takes no exogenous features.
        :raises ValueError: from statsmodels. When the target contains non-positive
            values while a multiplicative component is configured, the error is
            rewritten into a user-facing message.
        :raises InvalidModelException: when the fit yields a NaN AIC (unusable model).
        """
        np.random.seed(self.seed)
        try:
            trained_model = ETSModel(
                endog=target_values,
                error=self.string_to_param(self.error),
                trend=self.string_to_param(self.trend),
                damped_trend=self.string_to_param(self.damped_trend),
                seasonal=self.string_to_param(self.seasonal),
                seasonal_periods=self.string_to_param(self.seasonal_periods),
                initialization_method="heuristic",
            ).fit()
        except ValueError as e:
            # Only translate the error into the "strictly positive" message when
            # that is actually the cause (multiplicative component + non-positive
            # data). Any other ValueError raised by statsmodels is re-raised
            # unchanged rather than being masked by a misleading message.
            if self._uses_multiplicative_component() and np.any(np.asarray(target_values) <= 0):
                raise ValueError("Target values must be strictly positive when using multiplicative error, trend or seasonal component.") from e
            raise
        # Depending on the data, the model may not be able to forecast at all, in those cases, the metrics computed during training are nans.
        if np.isnan(trained_model.aic):
            raise InvalidModelException(
                "The ETS model ({}) did not manage to fit, try to disable unstable models (if applicable) or consider using another algorithm.".format(trained_model.short_name)
            )
        return trained_model

    def _forecast_single_timeseries(
            self,
            trained_model,
            past_target_values,
            past_date_values,
            quantiles,
            past_external_features_values,
            future_external_features_values,
            fit_before_predict,
            prediction_length
    ):
        """Forecast ``prediction_length`` steps after the end of the model's sample.

        If ``fit_before_predict`` is true, a fresh model is fitted on
        ``past_target_values`` and ``trained_model`` is ignored. Otherwise
        ``trained_model`` is used as-is, and a warning is logged if its training
        target differs from ``past_target_values``. Date and external-feature
        arguments are part of the shared interface and are not used here.
        """
        if fit_before_predict:
            trained_model = self._fit_single(past_target_values)
        else:
            if not np.array_equal(trained_model.endog, past_target_values):
                logger.warning(
                    "Using an ETS model with different target than the one used during training"
                )

        # Out-of-sample range: indices nobs .. nobs + prediction_length - 1.
        prediction_results = trained_model.get_prediction(
            start=trained_model.model.nobs,
            end=trained_model.model.nobs + prediction_length - 1,
        )
        # NOTE(review): presumably _build_forecasts_dict reads `conf_int`, while the
        # ETS prediction results expose prediction intervals as `pred_int`; alias it.
        prediction_results.conf_int = prediction_results.pred_int
        return self._build_forecasts_dict(prediction_results, quantiles)

    def get_coefficients_map_and_names(self):
        """Collect fitted coefficients and their statistics for every trained series.

        For each name in ``fixed_coefficients``, the returned map holds a dict built
        by ``build_coefficient_dict()``. It is filled per series identifier with the
        attribute value (when the results object has it) and, when the name is an
        estimated parameter, its p-value, t-value and standard error.

        :returns: ``(coefficients_map, fixed_coefficients, None, None)``; the meaning
            of the two trailing ``None`` slots is defined by the base class contract.
        """
        coefficients_map = {}
        fixed_coefficients = ["initial_level", "initial_trend", "smoothing_level", "smoothing_trend"]
        for timeseries_identifier, trained_model in self.trained_models.items():
            for coeff_name in fixed_coefficients:
                if coeff_name not in coefficients_map:
                    coefficients_map[coeff_name] = build_coefficient_dict()
                # e.g. trend coefficients are absent when the model has no trend.
                if hasattr(trained_model, coeff_name):
                    coefficients_map[coeff_name]["values"][timeseries_identifier] = getattr(trained_model, coeff_name)
                if coeff_name in trained_model.param_names:
                    param_key = trained_model.param_names.index(coeff_name)
                    coefficients_map[coeff_name]["pvalues"][timeseries_identifier] = sanitized_stats_value(trained_model.pvalues[param_key])
                    coefficients_map[coeff_name]["tvalues"][timeseries_identifier] = sanitized_stats_value(trained_model.tvalues[param_key])
                    coefficients_map[coeff_name]["stderrs"][timeseries_identifier] = sanitized_stats_value(trained_model.bse[param_key])
        return coefficients_map, fixed_coefficients, None, None

    def get_information_criteria(self):
        """Return one entry per criterion (aic, bic, hqic, llf) with per-series values and a display name."""
        information_criteria = []
        for criterion_name in ["aic", "bic", "hqic", "llf"]:
            criterion = { "values": {}, "displayName": INFORMATION_CRITERION_TO_DISPLAY_NAME[criterion_name] }
            for timeseries_identifier, trained_model in self.trained_models.items():
                item = getattr(trained_model, criterion_name)
                criterion["values"][timeseries_identifier] = self.prepare_information_criteria(item)
            information_criteria.append(criterion)
        return information_criteria

    def get_fitted_values_and_residuals(self, identifier, df_of_identifier, min_scoring_size):
        """Return the in-sample fitted values and residuals of the model trained for ``identifier``.

        ``df_of_identifier`` and ``min_scoring_size`` are not used by ETS (interface compatibility).
        """
        trained_model = self.trained_models[identifier]
        return trained_model.fittedvalues, trained_model.resid