from typing import List, Optional
from dataiku.eda.types import Literal

import pandas as pd
from pandas.tseries.offsets import CustomBusinessDay, MonthEnd

from dataiku.eda.computations.computation import UnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.exceptions import NotEnoughDataError, InvalidParams
from dataiku.eda.types import MatchTimeStepModel, MatchTimeStepResultModel, Step


def _find_possible_offsets(timestamps):
    step = timestamps[1] - timestamps[0]
    candidates = []

    if step.days >= 365:
        n_years = int(round(step.days / 365.))
        candidates.append(pd.DateOffset(years=n_years))

    if step.days >= 28:
        n_months = int(round(step.days / 30.))
        candidates.append(pd.DateOffset(months=n_months))
        candidates.append(MonthEnd(n=n_months))

    if 0 < step.days < 7:
        candidates.append(CustomBusinessDay(weekmask="Mon Tue Wed Thu Fri"))
        candidates.append(CustomBusinessDay(weekmask="Sun Mon Tue Wed Thu"))
        candidates.append(CustomBusinessDay(weekmask="Sat Sun"))
        candidates.append(CustomBusinessDay(weekmask="Fri Sat"))

    return candidates


def _has_regular_offset(timestamps, offset):
    start = timestamps[0]
    for i, ts in enumerate(timestamps[1:], start=1):
        candidate_ts = start + i * offset
        if candidate_ts != ts:
            return False

    # all timestamps matched the offset
    return True


def _detect_offset(timestamps):
    """Return the first candidate offset that regenerates the whole series.

    Candidates come from ``_find_possible_offsets`` and are checked in order
    with ``_has_regular_offset``.

    :returns: the first matching offset, or None when none matches.
    """
    matching = (
        candidate
        for candidate in _find_possible_offsets(timestamps)
        if _has_regular_offset(timestamps, candidate)
    )
    return next(matching, None)


def _early_detect_time_step(timestamps) -> Optional[Step]:
    """Try to describe the series with a single calendar-aware step.

    :param timestamps: sorted timestamps (at least two are expected by the
        offset-detection helpers).
    :returns: one of:

        * ``{"type": "business_days_step", "weekMask": <weekmask>}`` when a
          custom business-day offset matches,
        * ``{"type": "iso_duration_step", "duration": "P<n>Y" | "P<n>M"}``
          when a whole-year, whole-month or month-end offset matches,
        * None when no simple offset explains every timestamp.
    """
    offset = _detect_offset(timestamps)
    if offset is None:
        return None

    if isinstance(offset, CustomBusinessDay):
        return {
            "type": "business_days_step",
            "weekMask": offset.weekmask,
        }

    # Order matters: pandas reports every offset class as a pd.DateOffset
    # instance, so the attribute checks distinguish DateOffset(years=...)
    # from DateOffset(months=...) before falling back to MonthEnd.
    if isinstance(offset, pd.DateOffset) and hasattr(offset, "years"):
        count, unit = offset.years, "Y"
    elif isinstance(offset, pd.DateOffset) and hasattr(offset, "months"):
        count, unit = offset.months, "M"
    elif isinstance(offset, MonthEnd):
        count, unit = offset.n, "M"
    else:
        return None

    return {
        "type": "iso_duration_step",
        "duration": "P{}{}".format(int(count), unit),
    }


class MatchTimeStep(UnivariateComputation):
    """Infer the time step(s) between consecutive values of a date column.

    A single calendar-aware step (business days, whole months or whole years)
    is reported when it reproduces every timestamp exactly. Otherwise the
    distinct differences between consecutive sorted timestamps are reported
    as ISO 8601 durations, together with warnings about missing values,
    duplicate timestamps or an irregular step.
    """

    def __init__(self, column: str, max_steps: int):
        super(MatchTimeStep, self).__init__(column)
        # Maximum number of distinct steps serialized in the result payload.
        self.max_steps = max_steps

    @staticmethod
    def get_type() -> Literal["match_time_step"]:
        return "match_time_step"

    def describe(self) -> str:
        return "{}(column={}, max_steps={})".format(
            self.__class__.__name__,
            self.column,
            self.max_steps,
        )

    @staticmethod
    def build(params: MatchTimeStepModel) -> 'MatchTimeStep':
        return MatchTimeStep(
            params["column"],
            params["maxSteps"],
        )

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> MatchTimeStepResultModel:
        """Compute the time step(s) of ``self.column``.

        :raises InvalidParams: if ``max_steps`` is not strictly positive.
        :raises NotEnoughDataError: if fewer than two non-missing dates remain.
        :returns: a result dict with ``type``, ``steps`` (at most ``max_steps``
            entries), ``nSteps`` (number of distinct steps found) and
            ``warnings``.
        """
        if self.max_steps <= 0:
            raise InvalidParams("max_steps must be > 0")

        warnings = []
        if pd.isnull(idf.date_col(self.column)).any():
            warnings.append(
                "The time variable '{}' has missing values. "
                "It is strongly advised to select another time variable or select an interpolation/extrapolation method in the resampling settings.".format(self.column)
            )

        dates = idf.date_col_no_missing(self.column)

        if len(dates) < 2:
            raise NotEnoughDataError("At least two values are required to compute the step")

        # dates is expected to be a pd.DatetimeIndex
        sorted_dates = dates.sort_values()

        # Fast path: a single calendar-aware step explains the whole series.
        detected_step = _early_detect_time_step(sorted_dates)

        if detected_step is not None:
            return {
                "type": MatchTimeStep.get_type(),
                "steps": [detected_step],
                "nSteps": 1,
                "warnings": warnings,
            }

        # Slow path: collect the distinct raw gaps between consecutive dates
        # (a pd.TimedeltaIndex, in order of first occurrence).
        all_steps = sorted_dates[1:] - sorted_dates[:-1]
        steps = all_steps.drop_duplicates()

        # Compare against a zero Timedelta directly rather than matching the
        # "P0DT0H0M0S" ISO string, which depends on the exact isoformat output.
        if (steps == pd.Timedelta(0)).any():
            # A zero-length step means duplicate timestamps.
            warnings.append(
                "The time variable '{}' contains conflicting duplicate timestamps. "
                "It is strongly advised to select another time variable or select a duplicate timestamps handling method in the resampling settings.".format(self.column)
            )
        elif len(steps) > 1:
            warnings.append(
                "The time variable '{}' does not have a regular time step. "
                "It is strongly advised to resample the series first or select another time variable.".format(self.column)
            )

        # Only the first `max_steps` distinct steps are sent, to keep the
        # payload small. `nSteps` still reports the total number found.
        detected_steps: List[Step] = [
            {
                "type": "iso_duration_step",
                "duration": step.isoformat(),
            }
            for step in steps[:self.max_steps]
        ]

        return {
            "type": MatchTimeStep.get_type(),
            "steps": detected_steps,
            "nSteps": len(steps),
            "warnings": warnings,
        }
