import numpy as np
from statsmodels.tsa.stattools import pacf
from dataiku.eda.types import Literal

from dataiku.eda.computations.timeseries.time_series_computation import TimeSeriesComputation
from dataiku.eda.exceptions import InvalidParams, NotEnoughDataError


class PACF(TimeSeriesComputation):
    """Partial autocorrelation function (PACF) of a time series.

    The parameters are validated in ``apply`` and the estimation itself is
    delegated to ``statsmodels.tsa.stattools.pacf``.
    """

    # Maps the method identifiers accepted in the computation parameters to
    # the method names passed to statsmodels' ``pacf``.
    _DKU_TO_SM_METHODS = {
        "YULE_WALKER": "yw_mle",
        "OLS": "ols",
        "OLS_UNBIASED": "ols-adjusted",
        "LEVINSON_DURBIN": "ld_biased",
    }

    def __init__(self, series_column, time_column, alpha, method, n_lags):
        """Store the computation parameters (they are validated in ``apply``).

        :param series_column: forwarded to ``TimeSeriesComputation``.
        :param time_column: forwarded to ``TimeSeriesComputation``.
        :param alpha: passed as ``alpha`` to statsmodels' ``pacf`` for the
            confidence intervals; must lie in [0, 0.5].
        :param method: one of the keys of ``_DKU_TO_SM_METHODS``.
        :param n_lags: number of lags to compute, or ``None`` to let
            ``apply`` derive a value from the series size.
        """
        super().__init__(series_column, time_column)
        self.alpha = alpha
        self.method = method
        self.n_lags = n_lags

    @staticmethod
    def get_type() -> Literal["pacf"]:
        """Return the type identifier of this computation."""
        return "pacf"

    def describe(self):
        """Return a one-line, human-readable summary of the parameters."""
        return "{}(series_column={}, time_column={}, alpha={}, method={}, n_lags={})".format(
            self.__class__.__name__,
            self.series_column,
            self.time_column,
            self.alpha,
            self.method,
            self.n_lags,
        )

    @staticmethod
    def build(params):
        """Create a ``PACF`` from a parameter dict.

        ``seriesColumn``, ``timeColumn``, ``alpha`` and ``method`` are
        required keys; ``nLags`` is optional and defaults to ``None``.
        """
        return PACF(
            params["seriesColumn"],
            params["timeColumn"],
            params["alpha"],
            params["method"],
            params.get("nLags"),
        )

    def apply(self, idf, ctx):
        """Compute the partial autocorrelation coefficients of the series.

        :param idf: data handed to ``_get_time_series`` to extract the series.
        :param ctx: not used by this computation.
        :returns: a dict with the keys ``type``, ``coefficients`` and
            ``confidenceIntervals``; the last two are the arrays returned by
            statsmodels' ``pacf`` converted to plain lists.
        :raises NotEnoughDataError: if the series holds fewer than 2 values.
        :raises InvalidParams: if ``alpha`` is outside [0, 0.5] (NaN
            included), if ``n_lags`` is not in [1, series size - 1], or if
            ``method`` is not a supported identifier.
        """
        # Only the values are needed here; the timestamps are ignored.
        series, _ = self._get_time_series(idf)
        n_obs = len(series)

        if n_obs < 2:
            raise NotEnoughDataError("At least 2 values are required in the series (current size: {})".format(n_obs))

        # Written as a chained comparison so that NaN is rejected as well
        # (NaN fails every ordering comparison).
        if not (0 <= self.alpha <= 0.5):
            raise InvalidParams("alpha must belong to [0,0.5]")

        n_lags = self.n_lags

        if n_lags is None:
            # Compute a reasonably good value for n_lags
            n_lags = min(int(10 * np.log10(n_obs)), n_obs - 1)

        if n_lags <= 0:
            raise InvalidParams("n_lags must be greater than 0")

        # NOTE(review): recent statsmodels releases appear to reject nlags
        # above half of the sample size, which is stricter than this check -
        # confirm against the statsmodels version in use.
        if n_lags >= n_obs:
            raise InvalidParams("n_lags must be lower than the series size")

        # An unsupported method is an invalid parameter like the ones above,
        # so it is reported with the same exception type.
        if self.method not in PACF._DKU_TO_SM_METHODS:
            raise InvalidParams("Unknown method ({})".format(self.method))

        method = PACF._DKU_TO_SM_METHODS[self.method]

        coefficients, confidence_intervals = pacf(series, nlags=n_lags, alpha=self.alpha, method=method)

        return {
            "type": self.get_type(),
            "coefficients": coefficients.tolist(),
            "confidenceIntervals": confidence_intervals.tolist(),
        }
