import numpy as np
import pandas as pd
from dataiku.eda.types import Literal

from dataiku.eda.computations.timeseries.time_series_computation import TimeSeriesComputation, format_iso8601
from dataiku.eda.exceptions import EdaComputeError
from dataiku.eda.exceptions import InvalidParams


class STLDecomposition(TimeSeriesComputation):
    """Seasonal-Trend decomposition using LOESS (STL) of a single time series.

    Thin wrapper around ``statsmodels.tsa.seasonal.STL``. Two decomposition
    types are supported:

    * ``"ADDITIVE"``: STL runs on the raw series, so
      observed = trend + seasonal + resid.
    * ``"MULTIPLICATIVE"``: STL runs on ``log(series)`` and every returned
      component (including ``observed``) is exponentiated back, so
      observed = trend * seasonal * resid. Requires strictly positive values.
    """

    # Degree modes accepted in the request, mapped to the LOESS polynomial
    # degree expected by STL's seasonal_deg / trend_deg / low_pass_deg.
    _DEGREE_MODES = {
        "CONSTANT": 0,
        "CONSTANT_WITH_TREND": 1,
    }

    def __init__(self, series_column, time_column, seasonal, period, trend, low_pass, decomposition_type, robust,
                 seasonal_deg, trend_deg, low_pass_deg, seasonal_jump, trend_jump, low_pass_jump):
        """
        :param series_column: series column, handled by the base class
        :param time_column: time column, handled by the base class
        :param seasonal: forwarded to STL as ``seasonal``
        :param period: forwarded to STL as ``period``; when None, statsmodels
            tries to guess it and a failed guess is reported as InvalidParams
        :param trend: forwarded to STL as ``trend`` (may be None)
        :param low_pass: forwarded to STL as ``low_pass`` (may be None)
        :param decomposition_type: "ADDITIVE" or "MULTIPLICATIVE"
        :param robust: forwarded to STL as ``robust``
        :param seasonal_deg: degree mode ("CONSTANT" / "CONSTANT_WITH_TREND")
        :param trend_deg: degree mode ("CONSTANT" / "CONSTANT_WITH_TREND")
        :param low_pass_deg: degree mode ("CONSTANT" / "CONSTANT_WITH_TREND")
        :param seasonal_jump: forwarded to STL as ``seasonal_jump``
        :param trend_jump: forwarded to STL as ``trend_jump``
        :param low_pass_jump: forwarded to STL as ``low_pass_jump``
        """
        super(STLDecomposition, self).__init__(series_column, time_column)
        self.seasonal = seasonal
        self.period = period
        self.trend = trend
        self.low_pass = low_pass
        self.decomposition_type = decomposition_type
        self.robust = robust
        self.seasonal_deg = seasonal_deg
        self.trend_deg = trend_deg
        self.low_pass_deg = low_pass_deg
        self.seasonal_jump = seasonal_jump
        self.trend_jump = trend_jump
        self.low_pass_jump = low_pass_jump

    @staticmethod
    def get_type() -> Literal["stl_decomposition"]:
        return "stl_decomposition"

    def describe(self):
        """Return a human-readable summary of this computation and all its parameters."""
        return (
            "{}(series_column={}, time_column={}, seasonal={}, period={}, trend={}, low_pass={},"
            " decomposition_type={}, robust={}, seasonal_deg={}, trend_deg={}, low_pass_deg={},"
            " seasonal_jump={}, trend_jump={}, low_pass_jump={})"
        ).format(
            self.__class__.__name__,
            self.series_column,
            self.time_column,
            self.seasonal,
            self.period,
            self.trend,
            self.low_pass,
            self.decomposition_type,
            self.robust,
            self.seasonal_deg,
            self.trend_deg,
            self.low_pass_deg,
            self.seasonal_jump,
            self.trend_jump,
            self.low_pass_jump
        )

    @staticmethod
    def build(params):
        """Build an STLDecomposition from the request payload.

        ``period``, ``trend`` and ``lowPass`` are optional (default None);
        every other key is required and a missing one raises KeyError.
        """
        stl_params = params["params"]
        return STLDecomposition(
            params["seriesColumn"],
            params["timeColumn"],
            stl_params["seasonal"],
            stl_params.get("period"),
            stl_params.get("trend"),
            stl_params.get("lowPass"),
            stl_params["decompositionType"],
            stl_params["robust"],
            stl_params["seasonalDeg"],
            stl_params["trendDeg"],
            stl_params["lowPassDeg"],
            stl_params["seasonalJump"],
            stl_params["trendJump"],
            stl_params["lowPassJump"]
        )

    def _get_degree_param(self, key):
        """Translate a degree mode into the LOESS degree expected by STL.

        :param key: degree mode, "CONSTANT" or "CONSTANT_WITH_TREND"
        :return: 0 for "CONSTANT", 1 for "CONSTANT_WITH_TREND"
        :raises InvalidParams: if ``key`` is not a known degree mode
        """
        # Validate the mode actually being translated (previously this always
        # checked seasonal_deg, letting bad trend/low-pass modes raise KeyError).
        try:
            return self._DEGREE_MODES[key]
        except KeyError:
            raise InvalidParams("Unknown degree mode for: {}".format(key)) from None

    def _preprocess_series(self, series):
        """Return the values STL should decompose for the configured decomposition type.

        :raises InvalidParams: on an unknown decomposition type, or on
            non-positive values in multiplicative mode
        """
        if self.decomposition_type == "ADDITIVE":
            return series
        if self.decomposition_type == "MULTIPLICATIVE":
            # log() turns the multiplicative model into an additive one, but it
            # is only defined for strictly positive values (0 -> -inf, <0 -> nan).
            if np.any(np.asarray(series) <= 0):
                raise InvalidParams("Multiplicative decomposition requires strictly positive values")
            return np.log(series)
        raise InvalidParams("Unknown decomposition type")

    def _component_to_list(self, component):
        """Convert an STL output component to a list in the scale of the input series."""
        if self.decomposition_type == "MULTIPLICATIVE":
            # Undo the log transform applied in _preprocess_series.
            return np.exp(component).tolist()
        return component.tolist()

    def apply(self, idf, ctx):
        """Run STL on the selected series and return its components.

        :param idf: input data; the base class extracts the series and timestamps from it
        :param ctx: computation context (not used here)
        :return: dict with ``type``, ``observed``, ``time`` (ISO-8601 strings),
            ``guessedParams`` (period and smoother lengths actually used by STL),
            ``trend``, ``seasonal`` and ``resid``. All component lists are in
            the scale of the input series.
        :raises InvalidParams: on an unknown decomposition type or degree mode,
            non-positive values in multiplicative mode, or a failed period guess
        :raises EdaComputeError: if statsmodels' STL cannot be imported
        """
        series, timestamps = self._get_time_series(idf)
        preprocessed_series = self._preprocess_series(series)

        try:
            from statsmodels.tsa.seasonal import STL
        except ImportError as e:
            raise EdaComputeError("statsmodels.tsa.seasonal.STL is not available") from e

        # Resolved outside the try below so that an invalid degree mode is
        # never confused with a ValueError raised by STL itself.
        seasonal_deg = self._get_degree_param(self.seasonal_deg)
        trend_deg = self._get_degree_param(self.trend_deg)
        low_pass_deg = self._get_degree_param(self.low_pass_deg)

        try:
            stl = STL(
                pd.Series(preprocessed_series, index=timestamps),
                period=self.period,
                trend=self.trend,
                low_pass=self.low_pass,
                seasonal=self.seasonal,
                robust=self.robust,
                seasonal_deg=seasonal_deg,
                trend_deg=trend_deg,
                low_pass_deg=low_pass_deg,
                seasonal_jump=self.seasonal_jump,
                trend_jump=self.trend_jump,
                low_pass_jump=self.low_pass_jump
            )
        except ValueError as e:
            message = str(e)
            # These statsmodels messages mean the period could not be inferred;
            # only relevant when the user did not provide one explicitly.
            period_guess_failed = (
                "period must be a positive integer >= 2" in message
                or "Unable to determine period from endog" in message
            )
            if period_guess_failed and self.period is None:
                raise InvalidParams(
                    "The period guessing failed with error: \"{}\". "
                    "Please provide a period in the configuration.".format(e)
                ) from e
            raise

        decomposition = stl.fit()
        time_iso = timestamps.to_series().apply(format_iso8601)

        return {
            "type": self.get_type(),
            # Converted like the other components so that, in multiplicative
            # mode, observed == trend * seasonal * resid (not log(observed)).
            "observed": self._component_to_list(decomposition.observed),
            "time": time_iso.tolist(),
            "guessedParams": {
                "period": stl.config["period"],
                "lowPassSmoother": stl.config["low_pass"],
                "trendSmoother": stl.config["trend"],
            },
            "trend": self._component_to_list(decomposition.trend),
            "seasonal": self._component_to_list(decomposition.seasonal),
            "resid": self._component_to_list(decomposition.resid),
        }
