import sys
import numpy as np
from statsmodels.tsa.stattools import acf
from dataiku.eda.types import Literal

from dataiku.eda.computations.timeseries.time_series_computation import TimeSeriesComputation
from dataiku.eda.exceptions import InvalidParams, NotEnoughDataError


class ACF(TimeSeriesComputation):
    """Autocorrelation function (ACF) of a time series, computed with statsmodels' ``acf``.

    Settings held by an instance:

    - ``alpha``: forwarded to ``acf`` as ``alpha``; ``apply`` rejects values
      outside ``[0, 0.5]``.
    - ``adjusted``: forwarded to ``acf`` as ``adjusted`` (``unbiased`` when
      running under Python 2).
    - ``n_lags``: number of lags requested, or ``None`` to let ``apply``
      derive one from the series length.
    """

    def __init__(self, series_column, time_column, alpha, adjusted, n_lags):
        super(ACF, self).__init__(series_column, time_column)
        self.n_lags = n_lags
        self.adjusted = adjusted
        self.alpha = alpha

    @staticmethod
    def get_type() -> Literal["acf"]:
        """Return the type tag of this computation (``"acf"``)."""
        return "acf"

    def describe(self):
        """Return a one-line textual summary of this computation and its settings."""
        template = "{}(series_column={}, time_column={}, alpha={}, adjusted={}, n_lags={})"
        fields = (
            self.__class__.__name__,
            self.series_column,
            self.time_column,
            self.alpha,
            self.adjusted,
            self.n_lags,
        )
        return template.format(*fields)

    @staticmethod
    def build(params):
        """Create an ``ACF`` from a parameter mapping.

        ``seriesColumn``, ``timeColumn``, ``alpha`` and ``adjusted`` are
        mandatory keys; ``nLags`` is optional and defaults to ``None``.
        """
        return ACF(
            series_column=params["seriesColumn"],
            time_column=params["timeColumn"],
            alpha=params["alpha"],
            adjusted=params["adjusted"],
            n_lags=params.get("nLags"),
        )

    def apply(self, idf, ctx):
        """Compute the autocorrelation coefficients and their confidence intervals.

        :param idf: data source handed to ``_get_time_series`` to extract the series.
        :param ctx: computation context; not used by this implementation.
        :return: dict with the keys ``type``, ``coefficients`` and
            ``confidenceIntervals`` (the last two converted to plain lists).
        :raises NotEnoughDataError: if the series holds fewer than 2 values.
        :raises InvalidParams: if ``alpha`` lies outside ``[0, 0.5]``, or if the
            number of lags is not strictly between 0 and the series size.
        """
        series, _timestamps = self._get_time_series(idf)
        n_obs = len(series)

        if n_obs < 2:
            raise NotEnoughDataError(
                "At least 2 values are required in the series (current size: {})".format(n_obs)
            )

        alpha_out_of_range = self.alpha < 0 or self.alpha > 0.5
        if alpha_out_of_range:
            raise InvalidParams("alpha must belong to [0,0.5]")

        if self.n_lags is None:
            # No explicit lag count: fall back to 10 * log10(n_obs),
            # capped so that it stays below the series size.
            n_lags = min(int(10 * np.log10(n_obs)), n_obs - 1)
        else:
            n_lags = self.n_lags

        if n_lags <= 0:
            raise InvalidParams("n_lags must be greater than 0")

        if n_lags >= n_obs:
            raise InvalidParams("n_lags must be lower than the series size")

        # The keyword carrying the "adjusted" flag depends on the statsmodels version:
        #  - Python 2 builtin env ships statsmodels == 0.10, which expects `unbiased`
        #  - all Python 3 builtin envs since DSS 12 ship statsmodels >= 0.12, which expects `adjusted`
        # See https://github.com/statsmodels/statsmodels/commit/5c8dcca793e8d538300a95924ef958f79412b3e4
        running_python2 = sys.version_info.major == 2
        flag_keyword = "unbiased" if running_python2 else "adjusted"

        coefficients, confidence_intervals = acf(
            series,
            nlags=n_lags,
            fft=True,
            alpha=self.alpha,
            **{flag_keyword: self.adjusted}
        )

        return {
            "type": self.get_type(),
            "coefficients": coefficients.tolist(),
            "confidenceIntervals": confidence_intervals.tolist(),
        }
