from typing import Dict
from dataiku.eda.types import Literal

import pandas as pd

from dataiku.doctor.timeseries.preparation.resampling import Resampler, supports_monthly_alignment
from dataiku.eda.computations.computation import Computation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame, RawVector
from dataiku.eda.exceptions import InvalidParams
from dataiku.eda.types import ResampledComputationModel, ResampledComputationResultModel


class ResampledComputation(Computation):
    """Resample time series columns, then run a nested computation on the result.

    The selected columns of the input frame are gathered into a pandas
    DataFrame, resampled with :class:`Resampler`, wrapped back into an
    :class:`ImmutableDataFrame` and handed to the wrapped ``computation``.
    """

    def __init__(self, computation, series_columns, series_identifiers,
                 time_column, n_units, time_unit,
                 interpolation_method, interpolation_constant_value,
                 extrapolation_method, extrapolation_constant_value,
                 start_date_mode, custom_start_date,
                 end_date_mode, custom_end_date,
                 time_unit_end_of_week,
                 time_unit_end_of_quarter,
                 time_unit_end_of_half_year,
                 time_unit_end_of_year,
                 time_unit_monthly_alignment,
                 duplicate_timestamps_handling_method):
        """Store the nested computation and the resampling settings.

        :param computation: computation applied to the resampled data
        :param series_columns: names of the columns to resample (read as floats)
        :param series_identifiers: names of the columns identifying each series
            (read as text)
        :param time_column: name of the date column used as the time axis
        :param n_units: step size, passed to ``Resampler`` as ``time_step``
        :param time_unit: unit of the step (e.g. "QUARTER", "HALF_YEAR", "YEAR")

        The remaining parameters are resampling settings. They are forwarded
        to ``Resampler`` by ``_build_resampler``, some of them after
        validation/conversion (end of week, unit alignment, monthly alignment).
        """
        self.computation = computation
        self.series_columns = series_columns
        self.series_identifiers = series_identifiers
        self.time_column = time_column
        self.n_units = n_units
        self.time_unit = time_unit
        self.interpolation_method = interpolation_method
        self.interpolation_constant_value = interpolation_constant_value
        self.extrapolation_method = extrapolation_method
        self.extrapolation_constant_value = extrapolation_constant_value
        self.start_date_mode = start_date_mode
        self.custom_start_date = custom_start_date
        self.end_date_mode = end_date_mode
        self.custom_end_date = custom_end_date
        self.time_unit_end_of_week = time_unit_end_of_week
        self.time_unit_end_of_quarter = time_unit_end_of_quarter
        self.time_unit_end_of_half_year = time_unit_end_of_half_year
        self.time_unit_end_of_year = time_unit_end_of_year
        self.time_unit_monthly_alignment = time_unit_monthly_alignment
        self.duplicate_timestamps_handling_method = duplicate_timestamps_handling_method

    @staticmethod
    def get_type() -> Literal["resampled"]:
        """Return the type tag of this computation."""
        return "resampled"

    def describe(self) -> str:
        """Return a human-readable summary of the main resampling parameters."""
        parameters = [
            ("series_columns", self.series_columns),
            ("time_column", self.time_column),
            ("n_units", self.n_units),
            ("time_unit", self.time_unit),
            ("interpolation_method", self.interpolation_method),
            ("extrapolation_method", self.extrapolation_method),
            ("duplicate_timestamps_handling_method", self.duplicate_timestamps_handling_method)
        ]

        parameters_desc = ", ".join("{}={}".format(k, v) for (k, v) in parameters)
        return "{}({})".format(self.__class__.__name__, parameters_desc)

    @staticmethod
    def build(params: ResampledComputationModel) -> 'ResampledComputation':
        """Build an instance from its serialized model.

        Column selection is read from ``params["spec"]``, resampling settings
        from ``params["spec"]["settings"]``, and the nested computation from
        ``params["computation"]``.
        """
        spec = params["spec"]
        settings = spec["settings"]

        return ResampledComputation(
            Computation.build(params["computation"]),
            spec["seriesColumns"],
            spec["seriesIdentifiers"],
            spec["timeColumn"],
            settings["nUnits"],
            settings["timeUnit"],
            settings["interpolationMethod"],
            settings["interpolationConstantValue"],
            settings["extrapolationMethod"],
            settings["extrapolationConstantValue"],
            settings["startDateMode"],
            settings["customStartDate"],
            settings["endDateMode"],
            settings["customEndDate"],
            settings["timeUnitEndOfWeek"],
            settings["timeUnitEndOfQuarter"],
            settings["timeUnitEndOfHalfYear"],
            settings["timeUnitEndOfYear"],
            settings["timeUnitMonthlyAlignment"],
            settings["duplicateTimestampsHandlingMethod"]
        )

    @staticmethod
    def _require_result_checking() -> bool:
        # Opts this computation out of the base class's result checking
        # (exact semantics are defined in Computation — not visible here).
        return False

    def _get_end_of_week(self):
        """Map the configured end-of-week day name to its three-letter code.

        :raises InvalidParams: if the day name is not recognized
        """
        # NOTE(review): the three-letter codes look like pandas weekly anchor
        # suffixes (e.g. "W-MON") — confirm against Resampler.
        enum_mapping = {
            "MONDAY": "MON",
            "TUESDAY": "TUE",
            "WEDNESDAY": "WED",
            "THURSDAY": "THU",
            "FRIDAY": "FRI",
            "SATURDAY": "SAT",
            "SUNDAY": "SUN",
        }

        if self.time_unit_end_of_week not in enum_mapping:
            raise InvalidParams("Unknown end of week: {}".format(self.time_unit_end_of_week))

        return enum_mapping[self.time_unit_end_of_week]

    def _get_monthly_alignment(self):
        """Return the monthly alignment day, or None if the time unit does not use one.

        :raises InvalidParams: if the value is outside [0, 31]
        """
        if not supports_monthly_alignment(self.time_unit):
            return None

        if self.time_unit_monthly_alignment < 0 or self.time_unit_monthly_alignment > 31:
            raise InvalidParams("Invalid day of the month: {}".format(self.time_unit_monthly_alignment))

        return self.time_unit_monthly_alignment

    def _get_unit_alignment(self):
        """Return the 1-based index of the selected period anchor.

        Applies to the QUARTER, HALF_YEAR and YEAR time units. For example,
        for QUARTER, "JAN_APR_JUL_OCT" -> 1 and "MAR_JUN_SEP_DEC" -> 3.
        Returns None for any other time unit.

        :raises InvalidParams: if the configured anchor is not recognized
        """
        if self.time_unit == "QUARTER":
            enums = [
                "JAN_APR_JUL_OCT",
                "FEB_MAY_AUG_NOV",
                "MAR_JUN_SEP_DEC"
            ]

            if self.time_unit_end_of_quarter not in enums:
                raise InvalidParams("Unknown end of quarter: {}".format(self.time_unit_end_of_quarter))

            return enums.index(self.time_unit_end_of_quarter) + 1

        if self.time_unit == "HALF_YEAR":
            enums = [
                "JAN_JUL",
                "FEB_AUG",
                "MAR_SEP",
                "APR_OCT",
                "MAY_NOV",
                "JUN_DEC"
            ]

            if self.time_unit_end_of_half_year not in enums:
                raise InvalidParams("Unknown end of half year: {}".format(self.time_unit_end_of_half_year))

            return enums.index(self.time_unit_end_of_half_year) + 1

        if self.time_unit == "YEAR":
            enums = [
                "JANUARY", "FEBRUARY", "MARCH", "APRIL", "MAY", "JUNE",
                "JULY", "AUGUST", "SEPTEMBER", "OCTOBER", "NOVEMBER", "DECEMBER"
            ]

            if self.time_unit_end_of_year not in enums:
                raise InvalidParams("Unknown end of year: {}".format(self.time_unit_end_of_year))

            return enums.index(self.time_unit_end_of_year) + 1

        return None

    def _build_resampler(self) -> Resampler:
        """Create the Resampler configured from this computation's settings.

        :raises InvalidParams: if any of the alignment settings is invalid
        """
        return Resampler(
            time_step=self.n_units,
            time_unit=self.time_unit,
            interpolation_method=self.interpolation_method,
            interpolation_constant_value=self.interpolation_constant_value,
            extrapolation_method=self.extrapolation_method,
            extrapolation_constant_value=self.extrapolation_constant_value,
            start_date_mode=self.start_date_mode,
            custom_start_date=self.custom_start_date,
            end_date_mode=self.end_date_mode,
            custom_end_date=self.custom_end_date,
            time_unit_end_of_week=self._get_end_of_week(),
            duplicate_timestamps_handling_method=self.duplicate_timestamps_handling_method,
            unit_alignment=self._get_unit_alignment(),
            monthly_alignment=self._get_monthly_alignment(),
        )

    def _build_dataframe(self, idf: ImmutableDataFrame) -> pd.DataFrame:
        """Gather the time, identifier and series columns into a pandas DataFrame.

        The time column is read with ``date_col``, identifiers with
        ``text_col`` and series columns with ``float_col``. Columns are keyed by
        name, so overlapping names overwrite earlier entries — ``apply`` rejects
        such overlaps before calling this.
        """
        df_data: Dict[str, RawVector] = {
            self.time_column: idf.date_col(self.time_column)
        }

        for identifier in self.series_identifiers:
            df_data[identifier] = idf.text_col(identifier)

        for series_column in self.series_columns:
            df_data[series_column] = idf.float_col(series_column)

        return pd.DataFrame(df_data)

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> ResampledComputationResultModel:
        """Resample ``idf`` and apply the nested computation to the result.

        :param idf: input data
        :param ctx: computation context (used for the "Resample" and "Compute"
            sub-steps)
        :return: ``{"type": "resampled", "result": <nested computation result>}``
        :raises InvalidParams: if the time, identifier and series column
            selections overlap, or if a resampling setting is invalid
        """
        if self.time_column in self.series_identifiers:
            raise InvalidParams("The time variable '{}' cannot be used as a series identifier".format(self.time_column))

        # Without this check, _build_dataframe would overwrite the date column
        # with a float column of the same name, and the resampler would lose
        # its time axis.
        if self.time_column in self.series_columns:
            raise InvalidParams("The time variable '{}' cannot be used as a series variable".format(self.time_column))

        variables_identifiers_intersection = [value for value in self.series_columns if value in self.series_identifiers]
        if len(variables_identifiers_intersection) > 0:
            raise InvalidParams(
                "The series variable '{}' cannot be used as a series identifier".format(
                    variables_identifiers_intersection[0]
                )
            )

        with ctx.sub("Resample"):
            resampler = self._build_resampler()
            dataframe = self._build_dataframe(idf)

            resampled_dataframe = resampler.transform(
                dataframe,
                self.time_column,
                timeseries_identifier_columns=self.series_identifiers,
                numerical_columns=self.series_columns
            )

            resampled_idf = ImmutableDataFrame.from_df(resampled_dataframe)

        with ctx.sub("Compute"):
            computation_result = self.computation.apply_safe(resampled_idf, ctx)

        return {
            "type": self.get_type(),
            "result": computation_result,
        }
