from dataiku.eda.types import Literal

from dataiku.eda.computations.timeseries.time_series_computation import TimeSeriesComputation
from dataiku.eda.exceptions import NotEnoughDataError, InvalidParams
from dataiku.vendor.pymannkendall import original_test as mann_kendall_test


class MannKendallTest(TimeSeriesComputation):
    """Mann-Kendall test for a monotonic trend in a time series.

    The series is obtained through ``TimeSeriesComputation._get_time_series``
    and passed to the vendored ``pymannkendall.original_test``.
    """

    # Maps the trend labels returned by the vendored pymannkendall test to the
    # constants exposed in the computation result.
    _TRENDS = {
        "no trend": "NO_TREND",
        "increasing": "INCREASING",
        "decreasing": "DECREASING",
    }

    def __init__(self, series_column, time_column, alpha):
        """
        :param series_column: column holding the series values (forwarded to the parent class)
        :param time_column: column holding the timestamps (forwarded to the parent class)
        :param alpha: significance level of the test; must lie in [0, 1].
            It is validated in :meth:`apply`, not here.
        """
        super(MannKendallTest, self).__init__(series_column, time_column)
        self.alpha = alpha

    @staticmethod
    def get_type() -> Literal["mann_kendall"]:
        """Return the identifier of this computation type."""
        return "mann_kendall"

    def describe(self):
        """Return a human-readable description of this computation and its parameters."""
        return "{}(series_column={}, time_column={}, alpha={})".format(
            self.__class__.__name__,
            self.series_column,
            self.time_column,
            self.alpha,
        )

    @staticmethod
    def build(params):
        """Create an instance from a parameter dict.

        :param params: mapping with keys ``seriesColumn``, ``timeColumn`` and ``alpha``
        :raises KeyError: if one of these keys is missing
        """
        return MannKendallTest(
            params["seriesColumn"],
            params["timeColumn"],
            params["alpha"],
        )

    @staticmethod
    def _is_valid_alpha(alpha):
        """Return True if *alpha* is a number in [0, 1].

        The chained comparison also rejects NaN, for which ``alpha < 0`` and
        ``alpha > 1`` are both False. Values that cannot be compared with
        numbers (e.g. ``None``) are rejected instead of raising ``TypeError``.
        """
        try:
            return bool(0 <= alpha <= 1)
        except TypeError:
            return False

    def apply(self, idf, ctx):
        """Run the Mann-Kendall test on the series extracted from *idf*.

        :param idf: input data, forwarded to ``_get_time_series``
        :param ctx: computation context (not used by this computation)
        :returns: dict with ``type``, ``statistic`` (z), ``pValue``, ``trend``
            (one of ``NO_TREND``/``INCREASING``/``DECREASING``), ``tau``,
            ``score`` (S), ``variance`` (variance of S), ``slope`` and ``intercept``
        :raises InvalidParams: if alpha is not a number in [0, 1]
        :raises NotEnoughDataError: if the series has fewer than 2 values
        :raises ValueError: if the test reports an unknown trend label
        """
        # Validate the pure parameter first so that a bad alpha fails fast,
        # without extracting the data.
        if not self._is_valid_alpha(self.alpha):
            raise InvalidParams("alpha must belong to [0,1]")

        # The timestamps are not needed by the test itself.
        series, _ = self._get_time_series(idf)

        if len(series) < 2:
            raise NotEnoughDataError(
                "At least 2 values are required in the series (current size: {})".format(len(series))
            )

        res = mann_kendall_test(series, self.alpha)

        if res.trend not in self._TRENDS:
            raise ValueError("Unexpected value for trend: {}".format(res.trend))

        return {
            "type": self.get_type(),
            "statistic": res.z,
            "pValue": res.p,
            "trend": self._TRENDS[res.trend],
            "tau": res.Tau,
            "score": res.s,
            "variance": res.var_s,
            "slope": res.slope,
            "intercept": res.intercept,
        }
