import numbers

import numpy as np
import pandas as pd
import logging

from prophet import Prophet
from prophet.utilities import regressor_coefficients

from dataiku.doctor.timeseries.utils import future_date_range, timeseries_iterator
from dataiku.doctor.timeseries.utils import ModelForecast
from dataiku.doctor.timeseries.models.statistical.base_estimator import DkuStatisticalEstimator
from dataiku.doctor.timeseries.models.statistical.stats import build_coefficient_dict


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Maps the seasonality setting strings accepted by the estimator to the value
# handed to Prophet's `yearly_seasonality` / `weekly_seasonality` /
# `daily_seasonality` constructor arguments (see `DkuProphetEstimator._parse_seasonality`).
SEASONALITY_ARG = {
    "auto": "auto",
    "always": True,
    "never": False,
}


class DkuProphetEstimator(DkuStatisticalEstimator):
    """Statistical estimator backed by Prophet.

    `_fit_single` trains one `Prophet` model on a single time series. The other methods read fitted models
    from `self.trained_models` (keyed by time series identifier) to forecast, to expose the model
    coefficients and to compute fitted values and residuals.
    """

    def __init__(
        self,
        frequency,
        time_variable,
        prediction_length,
        target_variable,
        timeseries_identifier_columns,
        growth,
        floor,
        cap,
        n_changepoints,
        changepoint_range,
        yearly_seasonality,
        weekly_seasonality,
        daily_seasonality,
        seed,
        seasonality_mode="additive",
        seasonality_prior_scale=10.0,
        changepoint_prior_scale=0.05,
        holidays_prior_scale=10.0,
        monthly_day_alignment=None,
    ):
        """Store the Prophet settings and forward the generic ones to the base estimator.

        :param frequency: forwarded to the base estimator; also used to build the forecast dates
        :param time_variable: forwarded to the base estimator; name of the date column
        :param prediction_length: forwarded to the base estimator
        :param target_variable: forwarded to the base estimator; name of the target column
        :param timeseries_identifier_columns: forwarded to the base estimator
        :param growth: Prophet `growth` argument; "logistic" additionally enables the floor/cap columns
        :param floor: constant written to the "floor" column when growth is "logistic"
        :param cap: constant written to the "cap" column when growth is "logistic"
        :param n_changepoints: Prophet `n_changepoints` argument
        :param changepoint_range: Prophet `changepoint_range` argument
        :param yearly_seasonality: one of the keys of SEASONALITY_ARG
        :param weekly_seasonality: one of the keys of SEASONALITY_ARG
        :param daily_seasonality: one of the keys of SEASONALITY_ARG
        :param seed: value given to `np.random.seed` before each fit and each forecast
        :param seasonality_mode: Prophet `seasonality_mode` argument (searchable)
        :param seasonality_prior_scale: Prophet `seasonality_prior_scale` argument (searchable)
        :param changepoint_prior_scale: Prophet `changepoint_prior_scale` argument (searchable)
        :param holidays_prior_scale: Prophet `holidays_prior_scale` argument (searchable)
        :param monthly_day_alignment: forwarded to the base estimator; also used to build the forecast dates
        """
        super(DkuProphetEstimator, self).__init__(
            frequency=frequency,
            time_variable=time_variable,
            prediction_length=prediction_length,
            target_variable=target_variable,
            timeseries_identifier_columns=timeseries_identifier_columns,
            monthly_day_alignment=monthly_day_alignment,
        )

        self.growth = growth
        self.floor = floor
        self.cap = cap
        self.n_changepoints = n_changepoints
        self.changepoint_range = changepoint_range
        self.yearly_seasonality = yearly_seasonality
        self.weekly_seasonality = weekly_seasonality
        self.daily_seasonality = daily_seasonality
        self.seed = seed

        # searchable parameters
        self.seasonality_mode = seasonality_mode
        self.seasonality_prior_scale = seasonality_prior_scale
        self.changepoint_prior_scale = changepoint_prior_scale
        self.holidays_prior_scale = holidays_prior_scale

    def initialize(self, core_params, modeling_params):
        """Set the Prophet settings from `modeling_params["prophet_timeseries_params"]`.

        Params added after 12.1 don't need to be initialized here because they will always be serialized.
        Note that `floor` and `cap` are not read by this method.
        """
        super(DkuProphetEstimator, self).initialize(core_params, modeling_params)
        algo_params = modeling_params["prophet_timeseries_params"]

        self.growth = algo_params["growth"]
        self.n_changepoints = algo_params["n_changepoints"]
        self.changepoint_range = algo_params["changepoint_range"]
        self.yearly_seasonality = algo_params["yearly_seasonality"]
        self.weekly_seasonality = algo_params["weekly_seasonality"]
        self.daily_seasonality = algo_params["daily_seasonality"]
        self.seed = algo_params["seed"]
        self.seasonality_mode = algo_params["seasonality_mode"]
        self.seasonality_prior_scale = algo_params["seasonality_prior_scale"]
        self.changepoint_prior_scale = algo_params["changepoint_prior_scale"]
        self.holidays_prior_scale = algo_params["holidays_prior_scale"]

    def _add_logistic_bounds(self, df):
        """Add the constant "floor" and "cap" columns to `df` (in place) when growth is "logistic"."""
        if self.growth == "logistic":
            df["floor"] = self.floor
            df["cap"] = self.cap

    @staticmethod
    def _with_external_features(df, external_features_values):
        """Return `df` with the external feature columns appended, or `df` itself when there are none.

        `pd.concat(axis=1)` matches rows on index labels, so both indexes are reset first to pair the rows
        by position: the two frames may have been built or sliced independently and carry different labels.
        """
        if external_features_values is None:
            return df
        return pd.concat(
            [df.reset_index(drop=True), external_features_values.reset_index(drop=True)],
            axis=1,
        )

    @staticmethod
    def _set_coefficient(coefficients_map, coeff_name, timeseries_identifier, value):
        """Store `value` for one time series under `coeff_name`, creating the coefficient entry if needed."""
        if coeff_name not in coefficients_map:
            coefficients_map[coeff_name] = build_coefficient_dict()
        coefficients_map[coeff_name]["values"][timeseries_identifier] = value

    def _fit_single(self, target_values, date_values=None, external_features_values=None):
        """Fit one time series

        :param target_values: values stored in the "y" column given to Prophet
        :param date_values: values stored in the "ds" column given to Prophet
        :param external_features_values: optional dataframe; each of its columns is added as a Prophet regressor
        :return: the fitted Prophet model
        """
        np.random.seed(self.seed)

        prophet_df = pd.DataFrame({
            "ds": date_values,
            "y": target_values,
        })

        # - prophet uses holidays_prior_scale by default to fit external features if no prior_scale is provided in the add_regressor methods
        # - in our case, holidays_prior_scale is only used with external features because we don't provide country holidays

        model = Prophet(
            growth=self.growth,
            n_changepoints=self.n_changepoints,
            changepoint_range=self.changepoint_range,
            yearly_seasonality=self._parse_seasonality(self.yearly_seasonality),
            weekly_seasonality=self._parse_seasonality(self.weekly_seasonality),
            daily_seasonality=self._parse_seasonality(self.daily_seasonality),
            seasonality_mode=self.seasonality_mode,
            seasonality_prior_scale=self.seasonality_prior_scale,
            changepoint_prior_scale=self.changepoint_prior_scale,
            holidays_prior_scale=self.holidays_prior_scale,
        )

        self._add_logistic_bounds(prophet_df)

        if external_features_values is not None:
            prophet_df = self._with_external_features(prophet_df, external_features_values)
            # iterating over a dataframe yields its column names
            for external_feature_column in external_features_values:
                model.add_regressor(external_feature_column)

        model.fit(prophet_df)

        return model

    def _forecast_single_timeseries(
        self,
        trained_model,
        past_target_values,
        past_date_values,
        quantiles,
        past_external_features_values,
        future_external_features_values,
        fit_before_predict,
        prediction_length
    ):
        """Forecast `prediction_length` steps after the last past date of one time series.

        :param trained_model: fitted Prophet model; replaced by a freshly fitted one when `fit_before_predict` is set
        :param past_target_values: past target values, only used when refitting
        :param past_date_values: past dates; the last one is the starting point of the forecast dates
        :param quantiles: iterable of quantiles expressed as fractions (each is multiplied by 100 for `np.percentile`)
        :param past_external_features_values: past external features, only used when refitting
        :param future_external_features_values: optional dataframe of external features for the forecast dates
        :param fit_before_predict: whether to refit a model on the past values before forecasting
        :param prediction_length: number of steps to forecast
        :return: dict with the "yhat" forecast under ModelForecast.FORECAST_VALUES and, under
                 ModelForecast.QUANTILES_FORECASTS, an array holding one row per requested quantile
        """
        np.random.seed(self.seed)

        if fit_before_predict:
            # instantiate and fit a Prophet model with the same hyperparameters as the trained model used during training
            trained_model = self._fit_single(past_target_values, past_date_values, past_external_features_values)

        forecast_dates = future_date_range(
            past_date_values.iloc[-1],
            prediction_length,
            self.frequency,
            self.monthly_day_alignment,
        )
        forecast_df = pd.DataFrame({'ds': forecast_dates})
        self._add_logistic_bounds(forecast_df)
        forecast_df = self._with_external_features(forecast_df, future_external_features_values)

        # "yhat" is the column containing the average forecast values
        forecast_values = trained_model.predict(forecast_df)["yhat"].to_numpy()

        # quantiles are estimated from the predictive samples, across the samples of each forecast row
        samples = trained_model.predictive_samples(forecast_df)["yhat"]
        quantiles_forecasts = np.array([np.percentile(samples, 100 * quantile, axis=1) for quantile in quantiles])

        return {
            ModelForecast.FORECAST_VALUES: forecast_values,
            ModelForecast.QUANTILES_FORECASTS: quantiles_forecasts,
        }

    @staticmethod
    def _parse_seasonality(seasonality_arg):
        """Translate a seasonality setting into the value expected by Prophet.

        :param seasonality_arg: one of the keys of SEASONALITY_ARG
        :return: the matching SEASONALITY_ARG value
        :raises ValueError: if `seasonality_arg` is not a key of SEASONALITY_ARG
        """
        if seasonality_arg not in SEASONALITY_ARG:
            raise ValueError("Seasonality parameter must be in {}".format(list(SEASONALITY_ARG)))
        return SEASONALITY_ARG[seasonality_arg]

    def get_coefficients_map_and_names(self):
        """Collect the coefficients of every trained model, per coefficient name and time series.

        Keys of the returned map are "k", "m", "delta_<i>", "<seasonality>_<i>" (for the yearly, weekly and
        daily seasonalities, with <i> starting at 1) and the names of the external features.

        :return: tuple (coefficients_map, fixed_coefficients, variable_coefficients, external_features_coefficients)
                 where `coefficients_map[name]["values"][timeseries_identifier]` holds the coefficient value and
                 `external_features_coefficients` is the sorted list of external feature names
        """
        fixed_coefficients = ["k", "m"]
        variable_coefficients = ["delta", "yearly", "weekly", "daily"]

        coefficients_map = {}
        external_features_set = set()
        for timeseries_identifier, trained_model in self.trained_models.items():
            for coeff_name in fixed_coefficients:
                coeff = trained_model.params[coeff_name][0]
                if not isinstance(coeff, numbers.Number):
                    # A plain number happens when the target is constant, otherwise we get an array
                    # TODO timeseries : understand whether the length of the array is always 1
                    coeff = coeff[0]
                self._set_coefficient(coefficients_map, coeff_name, timeseries_identifier, coeff)

            for i, value in enumerate(trained_model.params["delta"][0], start=1):
                self._set_coefficient(coefficients_map, "delta_{}".format(i), timeseries_identifier, value)

            beta_index = 0
            for seasonality in ["yearly", "weekly", "daily"]:
                if seasonality in trained_model.seasonalities:
                    # each seasonality contains 2 * seasonality_fourier_order fourier coefficients that are ordered in the 'beta' parameter (yearly, then weekly, then daily)
                    seasonality_coefficients_number = 2 * trained_model.seasonalities[seasonality]["fourier_order"]
                    seasonality_betas = trained_model.params["beta"][0][beta_index:beta_index + seasonality_coefficients_number]
                    for i, value in enumerate(seasonality_betas, start=1):
                        self._set_coefficient(coefficients_map, "{}_{}".format(seasonality, i), timeseries_identifier, value)
                    beta_index += seasonality_coefficients_number

            if trained_model.extra_regressors:
                for row in regressor_coefficients(trained_model).itertuples(index=False):
                    external_feature_name = row.regressor
                    # NOTE(review): an external feature named like an existing key (e.g. "k") overwrites that
                    # entry and is not listed as an external feature; confirm such names cannot occur.
                    if external_feature_name not in coefficients_map:
                        external_features_set.add(external_feature_name)
                    self._set_coefficient(coefficients_map, external_feature_name, timeseries_identifier, row.coef)

        external_features_coefficients = sorted(external_features_set)

        return coefficients_map, fixed_coefficients, variable_coefficients, external_features_coefficients

    def get_fitted_values_and_residuals(self, identifier, df_of_identifier, min_scoring_size):
        """
        For Prophet models, we compute a prediction for every possible timestep in the historical data.
        This is a simplified version of `_forecast_single_timeseries`

        :param identifier: key of the trained model in `self.trained_models` (and of `self.external_features`)
        :param df_of_identifier: historical dataframe of this time series
        :param min_scoring_size: not used by this implementation
        :return: tuple (fitted_values, residuals) where residuals = target - fitted values
        """
        trained_model = self.trained_models[identifier]
        forecast_df = pd.DataFrame({'ds': df_of_identifier[self.time_variable]})
        self._add_logistic_bounds(forecast_df)
        if self.external_features.get(identifier) is not None:
            forecast_df = self._with_external_features(
                forecast_df, df_of_identifier[self.external_features[identifier]]
            )
        fitted_values = trained_model.predict(forecast_df)["yhat"]
        # the target index is dropped so that the subtraction pairs the values with the `predict` output
        # (presumably indexed from 0 and in the same row order - to be confirmed)
        residuals = df_of_identifier[self.target_variable].reset_index(drop=True) - fitted_values

        return fitted_values, residuals