import gzip
import logging
from time import time

import pandas as pd
from sklearn.base import BaseEstimator
import numpy as np
from scipy import stats
from statsmodels.tsa.stattools import acf
from statsmodels.stats.diagnostic import acorr_ljungbox

from dataiku.core.dku_pandas_csv import dataframe_to_csv
from dataiku.doctor.timeseries.utils import timeseries_iterator, _groupby_compat, get_random_id
from dataiku.doctor.utils.stats import jarque_bera

# Module-level logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)


class BaseTimeseriesEstimator(BaseEstimator):
    """Base class of timeseries forecasting estimators.

    Subclasses must implement ``get_fitted_values_and_residuals``, ``fit``, ``predict`` and
    ``predict_single``. This base class provides the shared in-sample residuals post-processing
    (``compute_residuals``) and a ``get_params`` that tolerates models pickled by older versions.
    """

    def __init__(
            self,
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            monthly_day_alignment=None
    ):
        """
        :param frequency: frequency of the timeseries (format defined by the caller)
        :param prediction_length: number of steps to predict (presumably the forecast horizon - confirm in subclasses)
        :param time_variable: name of the column holding the timestamps
        :param target_variable: name of the column holding the values to forecast
        :param timeseries_identifiers: names of the columns identifying each timeseries; may be empty
            when the data contains a single timeseries
        :param monthly_day_alignment: optional alignment setting, defaults to None
        """
        self.frequency = frequency
        self.prediction_length = prediction_length
        self.time_variable = time_variable
        self.target_variable = target_variable
        self.timeseries_identifiers = timeseries_identifiers
        self.monthly_day_alignment = monthly_day_alignment

        # Not a constructor parameter; presumably populated by subclasses' `fit` (which accepts
        # `external_features`) - confirm in subclasses.
        self.external_features = None

    def compute_residuals(self, train_full_df, min_scoring_size, residuals_folder_context, progress_handler=None):
        """Compute and persist in-sample residuals and their diagnostics, one timeseries at a time.

        For each timeseries of ``train_full_df``, writes into ``residuals_folder_context``:
          - ``residuals-<key>.csv.gz``: the (sorted) identifier columns, dates, target, fitted values,
            residuals, standardized residuals and standard-normal theoretical quantiles;
          - ``stats-<key>.json``: ACF (10 lags, 95% confidence interval) and Ljung-Box statistics (when
            defined), plus Jarque-Bera statistic, p-value, skew and kurtosis.
        Finally writes ``identifiers_mapping.json``, mapping each timeseries identifier to its random ``<key>``.

        :param train_full_df: DataFrame containing every timeseries
        :param min_scoring_size: minimum timeseries size required for the model to score, forwarded to
            ``get_fitted_values_and_residuals``
        :param residuals_folder_context: folder context receiving the output files
        :param progress_handler: optional object whose ``set_percentage`` is called after each timeseries
        """
        # Models pickled by older versions may lack `timeseries_identifiers`, which is read below.
        self._ensure_legacy_attributes()
        start_residuals = time()
        if len(self.timeseries_identifiers):
            nb_identifiers = len(train_full_df.groupby(_groupby_compat(self.timeseries_identifiers)))
        else:
            nb_identifiers = 1
        identifiers_columns = sorted(self.timeseries_identifiers)
        identifiers_mapping = {}
        timeseries = timeseries_iterator(train_full_df, self.timeseries_identifiers)
        for idx, (timeseries_identifier, df_of_timeseries_identifier) in enumerate(timeseries, 1):
            logger.info("Processing residuals of identifier %s on full timeseries df of shape %s",
                        timeseries_identifier, df_of_timeseries_identifier.shape)
            df_residuals, residuals_are_constant = self._build_residuals_df(
                timeseries_identifier, df_of_timeseries_identifier, min_scoring_size, identifiers_columns)

            csv_key = get_random_id()
            identifiers_mapping[timeseries_identifier] = csv_key

            output_columns = identifiers_columns + ['dates', 'target', 'fittedValues', 'residuals', 'stdResiduals', "theoreticalQuantiles"]
            with residuals_folder_context.get_file_path_to_write("residuals-{}.csv.gz".format(csv_key)) as dataset_path:
                dataframe_to_csv(df_residuals[output_columns], dataset_path, gzip.open)

            residuals_stats = self._compute_residuals_stats(
                timeseries_identifier, df_residuals["residuals"], residuals_are_constant)
            residuals_folder_context.write_json("stats-{}.json".format(csv_key), residuals_stats)

            logger.info("%s/%s residuals processed.", idx, nb_identifiers)
            if progress_handler:
                # Integer arithmetic: int(idx / n * 100) can under-report through float rounding
                # (e.g. 29 / 100 * 100 == 28.999999999999996).
                progress_handler.set_percentage(idx * 100 // nb_identifiers)
        residuals_folder_context.write_json("identifiers_mapping.json", identifiers_mapping)
        logger.info("Computed residuals in %s seconds", time() - start_residuals)

    def _build_residuals_df(self, timeseries_identifier, df_of_timeseries_identifier, min_scoring_size, identifiers_columns):
        """Build the residuals DataFrame of a single timeseries.

        :return: tuple ``(df_residuals, residuals_are_constant)``; rows containing any missing value are
            dropped, and ``residuals_are_constant`` is True when the kept residuals have a zero standard deviation.
        """
        fitted_values, residuals = self.get_fitted_values_and_residuals(
            timeseries_identifier, df_of_timeseries_identifier, min_scoring_size)
        df_residuals = pd.DataFrame.from_dict({
            "target": df_of_timeseries_identifier[self.target_variable].to_list(),
            "dates": df_of_timeseries_identifier[self.time_variable].apply(
                lambda d: d.strftime('%Y-%m-%dT%H:%M:%S.%fZ')).to_list(),
            "fittedValues": self._values_to_list(fitted_values),
            "residuals": self._values_to_list(residuals),
        })
        # Kept rows retain their positional (RangeIndex) labels, used below to realign identifier columns.
        df_residuals = df_residuals.dropna()

        residuals_std = df_residuals["residuals"].std()
        residuals_are_constant = bool(residuals_std == 0)
        if residuals_are_constant:
            df_residuals["stdResiduals"] = 0
        else:
            df_residuals["stdResiduals"] = (df_residuals["residuals"] - df_residuals["residuals"].mean()) / residuals_std

        nb_residuals = len(df_residuals)
        # Standard-normal quantiles at positions i / (n + 1), i = 1..n (presumably for a normal Q-Q plot).
        df_residuals["theoreticalQuantiles"] = stats.norm.ppf(np.arange(1.0, nb_residuals + 1) / (nb_residuals + 1))

        if identifiers_columns:
            # Reset once (not once per column) so the positional labels kept by dropna select the matching rows.
            aligned_source = df_of_timeseries_identifier.reset_index().loc[df_residuals.index]
            for identifier_col in identifiers_columns:
                df_residuals[identifier_col] = aligned_source[identifier_col].values
        return df_residuals, residuals_are_constant

    @staticmethod
    def _compute_residuals_stats(timeseries_identifier, residuals, residuals_are_constant):
        """Compute autocorrelation and normality statistics of a single timeseries' residuals."""
        residuals_stats = {"timeseriesIdentifier": timeseries_identifier}
        nb_residuals = len(residuals)
        # ACF / Ljung-Box are skipped when residuals are constant (the autocorrelation divides by a zero
        # variance and yields NaN) and when fewer than two residuals remain (the Ljung-Box lag count
        # min(n - 1, 10) would not be positive).
        if not residuals_are_constant and nb_residuals > 1:
            acf_x, confint = acf(residuals, nlags=10, fft=False, adjusted=False, alpha=0.05)[:2]
            lb = acorr_ljungbox(residuals, min(nb_residuals - 1, 10), return_df=True)
            residuals_stats["acf"] = {
                "x": list(acf_x),
                "confint": list(confint),
            }
            # Only the statistic and p-value at the largest lag are reported.
            residuals_stats["ljungBox"] = lb["lb_stat"].to_list()[-1]
            residuals_stats["ljungBoxPValue"] = lb["lb_pvalue"].to_list()[-1]

        jb, jb_pvalue, skew, kurtosis = jarque_bera(residuals)
        residuals_stats.update({
            "jarqueBera": jb,
            "jarqueBeraPValue": jb_pvalue,
            "skew": skew,
            "kurtosis": kurtosis,
        })
        return residuals_stats

    @staticmethod
    def _values_to_list(values):
        """Return ``values`` as a list; accepts a list, a numpy array or a pandas Series/Index."""
        if isinstance(values, list):
            return values
        if isinstance(values, np.ndarray):
            return values.tolist()
        return values.to_list()

    def _ensure_legacy_attributes(self):
        """Backfill attributes missing from models pickled by older versions of this class."""
        # Models pickled before the introduction of `monthly_day_alignment`.
        if not hasattr(self, "monthly_day_alignment"):
            self.monthly_day_alignment = None
        # Models pickled before the homogeneization of the `timeseries_identifiers` attribute name.
        if not hasattr(self, "timeseries_identifiers"):
            self.timeseries_identifiers = self.timeseries_identifier_columns

    def get_params(self, deep=True):
        """sklearn ``get_params``, after backfilling attributes missing from models pickled by older versions."""
        self._ensure_legacy_attributes()
        return super(BaseTimeseriesEstimator, self).get_params(deep=deep)

    def get_fitted_values_and_residuals(self, identifier, df_of_identifier, min_scoring_size):
        """
        Computes residuals for a given timeseries identifier. This is algorithm dependent since
        every algorithm has different requirements regarding minimum sample scoring size and prediction length.
        :param identifier: identifier of the timeseries
        :param df_of_identifier: DataFrame of the identifier
        :param min_scoring_size: Minimum timeseries size required for the model to score
        :return: Tuple of fitted values and residuals (each a list, numpy array or pandas Series,
            with one entry per row of df_of_identifier)
        """
        raise NotImplementedError

    def fit(self, train_df, external_features=None, shift_map=None):
        """Fit the model on ``train_df``. Must be implemented by subclasses."""
        raise NotImplementedError

    def predict(self, past_df, future_df, quantiles, fit_before_predict=False, prediction_length_override=None):
        """Forecast every timeseries. Must be implemented by subclasses."""
        raise NotImplementedError

    def predict_single(self, past_df, future_df, quantiles, timeseries_identifier, fit_before_predict=False, prediction_length_override=None):
        """Forecast the single timeseries ``timeseries_identifier``. Must be implemented by subclasses."""
        raise NotImplementedError
