from typing import Dict, List
from dataiku.eda.types import Literal

import numpy as np
import pandas as pd

from dataiku.eda.computations.computation import UnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame, RawVector
from dataiku.eda.exceptions import InvalidParams, NotEnoughDataError
from dataiku.eda.types import GuessTimeStepModel, GuessTimeStepResultModel


def delta(n: int, unit: Literal["ms", "s", "m", "h", "d"]) -> int:
    """
    Express ``n`` units of time as a number of milliseconds.

    :param n: how many units of time
    :param unit: "ms" (milliseconds), "s" (seconds), "m" (minutes),
        "h" (hours) or "d" (days)
    :return: the same duration expressed in milliseconds
    :raises ValueError: when ``unit`` is none of the supported values
    """
    if unit == "ms":
        # already in the target unit
        return n

    milliseconds_per_unit = (
        ("s", 1000),
        ("m", 60 * 1000),
        ("h", 60 * 60 * 1000),
        ("d", 24 * 60 * 60 * 1000),
    )
    for unit_name, factor in milliseconds_per_unit:
        if unit == unit_name:
            return n * factor

    raise ValueError("Unexpected unit {}".format(unit))


class GuessTimeStep(UnivariateComputation):
    """
    Guess the time step of a date column.

    The step is derived from the median of the differences between
    consecutive sorted timestamps. When series identifiers are given, the
    differences are computed within each series (long format) and then pooled
    together before taking the median.
    """

    def __init__(self, column: str, series_identifiers: List[str]):
        super(GuessTimeStep, self).__init__(column)
        self.series_identifiers = series_identifiers

    @staticmethod
    def get_type() -> Literal["guess_time_step"]:
        return "guess_time_step"

    def describe(self) -> str:
        return "{}(column={}, series_identifiers={})".format(
            self.__class__.__name__,
            self.column,
            self.series_identifiers
        )

    @staticmethod
    def build(params: GuessTimeStepModel) -> 'GuessTimeStep':
        return GuessTimeStep(
            params["column"],
            params["seriesIdentifiers"]
        )

    def _build_series_dataframes(self, idf: ImmutableDataFrame):
        """
        Return an iterable of ``(series_key, dataframe)`` pairs, one per series.

        Rows with a missing timestamp are dropped first. Without series
        identifiers, a single pair ``("single_series", dataframe)`` is returned.
        Otherwise the pandas ``groupby`` over the identifier columns is returned.
        With pandas' default ``dropna=True``, any row whose identifier value is
        NaN/None belongs to no group.
        """
        df_data: Dict[str, RawVector] = {
            self.column: idf.date_col(self.column)
        }

        for identifier in self.series_identifiers:
            df_data[identifier] = idf.text_col(identifier)

        dataframe = pd.DataFrame(df_data)

        # filter pd.NaT out of the dataframe: we are not interested in missing timestamps.
        # note: pd.notnull(pd.NaT) == False
        dataframe = dataframe[dataframe[self.column].notnull()]

        if len(self.series_identifiers) == 0:
            return [("single_series", dataframe)]

        return dataframe.groupby(self.series_identifiers)

    @staticmethod
    def _time_step_from_median(median_step: float):
        """
        Map a median step in milliseconds to a ``(time_unit, n_time_units)`` pair.

        A step close to exactly one second, minute, hour, day, week or month
        (30 days) snaps to that unit with ``n_time_units == 1``. The tolerance
        band around each unit is given by the thresholds below.

        Any other step is expressed as a rounded number of MILLISECOND, SECOND,
        MINUTE, HOUR, DAY, MONTH (30 days) or YEAR (365 days). WEEK is only ever
        returned with a count of 1.

        ``n_time_units`` is always a Python ``int``.
        """
        if median_step < delta(990, "ms"):
            # NOTE(review): a median below 0.5 ms (e.g. mostly duplicate
            # timestamps) rounds to 0 units - confirm callers handle a 0 step.
            return "MILLISECOND", int(round(median_step))
        if median_step < delta(1010, "ms"):
            return "SECOND", 1
        if median_step < delta(59, "s"):
            return "SECOND", int(round(median_step / delta(1, "s")))
        if median_step < delta(61, "s"):
            return "MINUTE", 1
        if median_step < delta(59, "m"):
            return "MINUTE", int(round(median_step / delta(1, "m")))
        if median_step < delta(61, "m"):
            return "HOUR", 1
        if median_step < (delta(1, "d") - delta(10, "m")):
            return "HOUR", int(round(median_step / delta(1, "h")))
        if median_step < (delta(1, "d") + delta(10, "m")):
            return "DAY", 1
        if median_step < (delta(7, "d") - delta(1, "h")):
            return "DAY", int(round(median_step / delta(1, "d")))
        if median_step < (delta(7, "d") + delta(1, "h")):
            return "WEEK", 1
        if median_step < (delta(30, "d") - delta(10, "h")):
            return "DAY", int(round(median_step / delta(1, "d")))
        if median_step < (delta(30, "d") + delta(10, "h")):
            return "MONTH", 1
        if median_step < delta(365 - 2, "d"):
            return "MONTH", int(round(median_step / delta(30, "d")))
        return "YEAR", int(round(median_step / delta(365, "d")))

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> GuessTimeStepResultModel:
        """
        Compute the guessed time step of ``self.column``.

        :return: dict with the computation ``type``, the guessed ``timeUnit``
            and the integer number of units ``nTimeUnits``
        :raises InvalidParams: if the time column is also listed as a series identifier
        :raises NotEnoughDataError: if no series contains at least two non-missing timestamps
        """
        if self.column in self.series_identifiers:
            raise InvalidParams("The time variable '{}' cannot be used as a series identifier".format(self.column))

        # Collect the per-series steps and concatenate them once at the end.
        # Concatenating inside the loop would copy the accumulated array on
        # every iteration (quadratic in the number of series).
        step_chunks_ms = []

        for _, series_df in self._build_series_dataframes(idf):
            series_timestamps = series_df[self.column]

            if len(series_timestamps) < 2:
                # at most one value in the group => cannot compute any step
                continue

            sorted_timestamps = np.array(series_timestamps.sort_values())
            steps = sorted_timestamps[1:] - sorted_timestamps[:-1]
            # it is easier to manipulate an array of numbers (float milliseconds)
            step_chunks_ms.append(steps.astype('timedelta64[ns]') / np.timedelta64(1, "ms"))

        if not step_chunks_ms:
            error_message = "Not enough data in the series" \
                if len(self.series_identifiers) == 0 \
                else "Not enough data in each individual series. Check the multiple series definition."

            raise NotEnoughDataError(error_message)

        median_step = np.median(np.concatenate(step_chunks_ms))
        time_unit, n_time_units = self._time_step_from_median(median_step)

        return {
            "type": GuessTimeStep.get_type(),
            "timeUnit": time_unit,
            "nTimeUnits": n_time_units,
        }
