import logging
from typing import Tuple

import pandas as pd
import numpy as np

from dataiku.core.doctor_constants import PREDICTION_LENGTH
from dataiku.core.doctor_constants import TIMESERIES_IDENTIFIER_COLUMNS
from dataiku.core.doctor_constants import TIME_VARIABLE
from dataiku.core.doctor_constants import TEST_SIZE
from dataiku.core.doctor_constants import EVALUATION_PARAMS
from dataiku.core.doctor_constants import CUSTOM_TRAIN_TEST_INTERVALS
from dataiku.doctor.diagnostics.diagnostics import DiagnosticType
from dataiku.doctor.timeseries.utils import add_ignored_timeseries_diagnostics_and_logs, _groupby_compat
from dataiku.doctor.timeseries.utils import encode_timeseries_identifier
from dataiku.doctor.timeseries.utils import pretty_timeseries_identifiers
from dataiku.doctor.timeseries.utils import SINGLE_TIMESERIES_IDENTIFIER


# Module-level logger, named after this module's import path (__name__).
logger = logging.getLogger(__name__)

class AbstractTimeseriesSplitHandler(object):
    """Base class for objects that split a time series dataframe into (train, test, historical) dataframes.

    Subclasses implement `prepare_split_dataframe`, `split`, `_single_split`,
    `_retrieve_short_timeseries_identifiers_with_reasons` and `_generate_error_message_cta`.

    Attributes:
        timeseries_identifier_columns (list): List of time series identifier column names
            (falsy when the dataframe holds a single time series)
        time_column (str): Name of the time variable column
        prediction_length (int): Length of the forecast horizon
        test_size (int): Length of the test set
    """
    def __init__(self, core_params):
        self.timeseries_identifier_columns = core_params[TIMESERIES_IDENTIFIER_COLUMNS]
        self.time_column = core_params[TIME_VARIABLE]
        self.prediction_length = core_params[PREDICTION_LENGTH]
        self.test_size = core_params[EVALUATION_PARAMS][TEST_SIZE]

        # Filled by `compute_time_steps_ranks`; stays None until it is called.
        self._time_steps_ranks = None

    def prepare_split_dataframe(self, df, required_timeseries_training_size, skip_too_short_timeseries=False):
        """
        Prepare dataframe for splitting (checking time series sizes and potentially skipping
        those that are too short)

        :param df: Input dataframe to be split. **Must be sorted by time.**
        :type df: pd.DataFrame
        :param required_timeseries_training_size: Minimum required size for the train set of the smallest split
               for any time series.
        :type required_timeseries_training_size: int
        :param skip_too_short_timeseries: Whether to skip time series that are too short for training or not
        :type skip_too_short_timeseries: boolean
        :return: The prepared dataframe to use in `split`
        :rtype: pd.DataFrame
        """
        raise NotImplementedError

    def _skip_too_short_timeseries(self, df, min_required_length_for_split, too_short_timeseries_identifiers_with_reasons, kept_timeseries_identifier_values):
        """
        Filter the full dataframe based on valid and invalid (too short) identifiers while adding relevant diagnostics and logging.

        Without identifier columns the dataframe is returned unfiltered; a too-short single series is only
        reported to `add_ignored_timeseries_diagnostics_and_logs` through `all_timeseries_ignored`.

        :param df: the full timeseries dataframe
        :type df: pd.DataFrame
        :param min_required_length_for_split: the minimum required length
        :type min_required_length_for_split: int
        :param too_short_timeseries_identifiers_with_reasons: (identifier, dataset part, length) tuples of the series to drop
        :type too_short_timeseries_identifiers_with_reasons: list[Tuple[str, str, int]]
        :param kept_timeseries_identifier_values: identifier values of the series to keep
        :type kept_timeseries_identifier_values: list or set
        :return: the filtered timeseries dataframe
        :rtype: pd.DataFrame
        """
        add_ignored_timeseries_diagnostics_and_logs(
            timeseries_identifier_columns=self.timeseries_identifier_columns,
            unseen_timeseries_identifiers=[],
            too_short_timeseries_identifiers=[ts[0] for ts in too_short_timeseries_identifiers_with_reasons],
            all_timeseries_ignored=len(kept_timeseries_identifier_values) == 0,
            min_required_length=min_required_length_for_split,
            recipe_type="training",
            diagnostic_type=DiagnosticType.ML_DIAGNOSTICS_DATASET_SANITY_CHECKS
        )

        if self.timeseries_identifier_columns and too_short_timeseries_identifiers_with_reasons:
            # Keep only the rows whose identifier values belong to a time series that is long enough
            identifiers_index = df.set_index(self.timeseries_identifier_columns).index
            return df[identifiers_index.isin(kept_timeseries_identifier_values)]

        return df

    def _retrieve_short_timeseries_identifiers_with_reasons(self, df, min_required_size):
        """
        Iterate over the dataframe time series and gather the valid and invalid identifiers.
        :param df: the full timeseries dataframe
        :type df: pd.DataFrame
        :param min_required_size: optional minimum required size.
        :type min_required_size: Optional[int]
        :return: A tuple containing the valid identifiers and the invalid (too short) ones. Each invalid
                 identifier is a tuple with the identifier name, the invalid part of the dataset
                 (e.g. "train dataset", "test dataset" or "full dataset") and that part's actual length.
        :rtype: Tuple[Collection, list[Tuple[str, str, int]]]
        """
        raise NotImplementedError

    def _generate_error_message_cta(self):
        """
        Adds a call-to-action message to the error message depending on the timeseries split configuration.
        :return: str: the call-to-action error message.
        """
        raise NotImplementedError

    def filter_timeseries(self, df, required_timeseries_size = None, skip_too_short_timeseries=False):
        """
        Filter out too short time series if `skip_too_short_timeseries` is True; otherwise raise if any is too short.
        :param df: the full input dataframe
        :type df: pd.DataFrame
        :param required_timeseries_size: the minimum required timeseries size
        :type required_timeseries_size: Optional[int]
        :param skip_too_short_timeseries: Whether to raise an exception or filter when there are too short timeseries.
        :type skip_too_short_timeseries: bool
        :return: the filtered dataframe
        :rtype: pd.DataFrame
        :raises ValueError: if some time series are too short and `skip_too_short_timeseries` is False.
        """
        kept_timeseries_identifier_values, too_short_timeseries_identifiers_with_reasons = \
            self._retrieve_short_timeseries_identifiers_with_reasons(df, required_timeseries_size)
        if skip_too_short_timeseries:
            return self._skip_too_short_timeseries(df, required_timeseries_size, too_short_timeseries_identifiers_with_reasons, kept_timeseries_identifier_values)
        if too_short_timeseries_identifiers_with_reasons:
            error_message = self.format_too_short_timeseries_error_message(too_short_timeseries_identifiers_with_reasons, required_timeseries_size)
            raise ValueError(error_message + self._generate_error_message_cta())
        return df

    def format_too_short_timeseries_error_message(self, too_short_timeseries_identifiers_with_reasons, required_training_size):
        """
        Creates an error message based on the list of too short time series and the required training size.
        :param too_short_timeseries_identifiers_with_reasons: invalid identifiers with the problematic part (train/test/full) and that part's length
        :type too_short_timeseries_identifiers_with_reasons: list[Tuple[str, str, int]]
        :param required_training_size: minimum length reported for every part other than "test dataset"
        :type required_training_size: int
        :return: A formatted error message with specific information for every identifier.
        :rtype: str
        """
        details = []
        for identifier, dataset, length in too_short_timeseries_identifiers_with_reasons:
            displayed_identifier = pretty_timeseries_identifiers(identifier) if self.timeseries_identifier_columns else identifier
            # A too short test part is measured against the forecast horizon; any other part
            # (train / full dataset) against the required training size.
            min_length = self.prediction_length if dataset == "test dataset" else required_training_size
            details.append("{}: {} length: {} < Min length: {}; ".format(displayed_identifier, dataset, length, min_length))
        return "The following time series are too short for training: " + "".join(details)

    def split(self, df, split_id=None):
        """
        Returns an iterator of splits of a prepared dataframe
        :param df: The prepared dataframe
        :type df: pd.DataFrame
        :param split_id: a split id
        :type split_id: Optional[int]
        :return: an iterator of splits
        :rtype: Iterator[Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]]
        """
        raise NotImplementedError

    def _single_split(self, split_id, df):
        """
        Create a single split
        :param split_id: A split id: fold id or interval id
        :type split_id: int
        :param df: the prepared dataframe
        :type df: pd.DataFrame
        :return: A tuple of a train, test and historical dataframes
        :rtype: Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]
        """
        raise NotImplementedError

    def compute_time_steps_ranks(self, df):
        """Compute time-steps ranks and store them in `self._time_steps_ranks`.

        - the result is a pd.Series of the rank of each time step relative to its time series.
        - time steps are ranked in reverse chronological order (most recent of each time series is 0).
        - the series is zero-based and shares the index of `df`.
        :param df: a dataframe
        :type df: pd.DataFrame
        :return: the computed ranks (also stored in `self._time_steps_ranks`)
        :rtype: pd.Series
        """
        if self.timeseries_identifier_columns:
            time_values = df.groupby(_groupby_compat(self.timeseries_identifier_columns))[self.time_column]
        else:
            time_values = df[self.time_column]
        # method="first" breaks ties by order of appearance so every row gets a distinct rank;
        # ascending=False gives the most recent time step rank 1, shifted to 0 by the "- 1".
        self._time_steps_ranks = time_values.rank(method="first", ascending=False) - 1
        return self._time_steps_ranks

class CustomTrainTestTimeseriesSplitHandler(AbstractTimeseriesSplitHandler):
    """Creates a split based on explicit train/test date intervals.

    Attributes:
        timestep_params (dict): Time step settings read from ``core_params["timestepParams"]``
            (``timeunit`` and, for weekly data, ``endOfWeekDay``), used to align interval boundaries.
        intervals (list): Zero or one dict with ``"train"`` and ``"test"`` keys, each mapping to an inclusive
            ``[start, end]`` pair of ``np.datetime64`` values, expanded to whole time periods.
    """

    # Maps the Java-style day numbers of ``endOfWeekDay`` (1 = Sunday ... 7 = Saturday) to the weekday
    # numbering of pandas' ``Week`` offset (0 = Monday ... 6 = Sunday).
    _JAVA_TO_PYTHON_WEEKDAY = {1: 6, 2: 0, 3: 1, 4: 2, 5: 3, 6: 4, 7: 5}

    # Time units rounded by flooring: the pandas frequency alias used to floor a timestamp...
    _FLOOR_FREQUENCIES = {"SECOND": "s", "MINUTE": "min", "HOUR": "h", "DAY": "D", "BUSINESS_DAY": "D"}
    # ...and the duration of one period of that time unit.
    _PERIOD_OFFSETS = {
        "BUSINESS_DAY": pd.DateOffset(days=1),
        "DAY": pd.DateOffset(days=1),
        "HOUR": pd.DateOffset(hours=1),
        "MINUTE": pd.DateOffset(minutes=1),
        "SECOND": pd.DateOffset(seconds=1),
    }

    def __init__(self, core_params):
        super(CustomTrainTestTimeseriesSplitHandler, self).__init__(core_params)
        intervals = core_params.get(CUSTOM_TRAIN_TEST_INTERVALS, [])
        self.timestep_params = core_params["timestepParams"]
        if len(intervals) > 1:
            raise ValueError("Several intervals were specified. Interval based splitting only support one interval")
        # Each raw interval maps "train"/"test" to a (start, end) pair of date strings
        self.intervals = [
            {
                "train": [np.datetime64(interval["train"][0]), np.datetime64(interval["train"][1])],
                "test": [np.datetime64(interval["test"][0]), np.datetime64(interval["test"][1])],
            }
            for interval in intervals
        ]
        self._round_interval_boundaries()

    def _round_interval_boundaries(self):
        """
        Expands user-defined intervals to align with the time series' frequency,
        ensuring splits don't cut across time periods (e.g., a month).

        The method expands the interval to cover the full duration of the periods
        containing the user-provided start and end dates. For example, for monthly
        data, '2020-03-15' to '2020-05-10' becomes '2020-03-01' to '2020-05-31'.

        As a result, the final interval can be substantially larger and differ from
        the one specified by the user. Does nothing when no time unit is configured.
        """
        timeunit = self.timestep_params.get("timeunit")
        if not timeunit:
            return
        for interval in self.intervals:
            for part in ["train", "test"]:
                start_date = pd.Timestamp(interval[part][0])
                end_date = pd.Timestamp(interval[part][1])
                new_start, new_end = self._expand_to_whole_periods(timeunit, start_date, end_date)
                interval[part] = [np.datetime64(new_start), np.datetime64(new_end)]

    def _expand_to_whole_periods(self, timeunit, start_date, end_date):
        """Return (start, end) widened to whole `timeunit` periods.

        The returned end is the last millisecond of its period. Unknown time units leave both dates unchanged.
        """
        one_millisecond = pd.Timedelta(milliseconds=1)
        if timeunit == "WEEK":
            end_of_week_day = self._JAVA_TO_PYTHON_WEEKDAY[self.timestep_params.get("endOfWeekDay", 1)]
            week = pd.tseries.offsets.Week(weekday=end_of_week_day)
            return (week.rollback(start_date).normalize(),
                    week.rollforward(end_date).normalize() + pd.Timedelta(days=1) - one_millisecond)
        if timeunit in ("MONTH", "QUARTER", "HALF_YEAR"):
            # Quarters and half-years are only aligned on month boundaries
            month_begin = pd.tseries.offsets.MonthBegin()
            return (month_begin.rollback(start_date).normalize(),
                    month_begin.rollback(end_date).normalize() + pd.DateOffset(months=1) - one_millisecond)
        if timeunit == "YEAR":
            year_begin = pd.tseries.offsets.YearBegin()
            return (year_begin.rollback(start_date).normalize(),
                    year_begin.rollback(end_date).normalize() + pd.DateOffset(years=1) - one_millisecond)
        if timeunit in self._FLOOR_FREQUENCIES:
            # For smaller time units, floor the start date and push the end date to the end of its period
            freq = self._FLOOR_FREQUENCIES[timeunit]
            return (start_date.floor(freq),
                    end_date.floor(freq) + self._PERIOD_OFFSETS[timeunit] - one_millisecond)
        return start_date, end_date

    def _single_split(self, interval_id, df):
        """Split `df` with the interval at index `interval_id` of `self.intervals`.

        Train and test rows are those whose time falls within the inclusive train / test bounds.
        Historical rows go from the train start up to, but excluding, the test start.
        """
        interval = self.intervals[interval_id]
        train_start, train_end = interval["train"][0], interval["train"][1]
        test_start, test_end = interval["test"][0], interval["test"][1]
        times = df[self.time_column]
        train_df = df.loc[(times >= train_start) & (times <= train_end)]
        test_df = df.loc[(times >= test_start) & (times <= test_end)]
        historical_df = df.loc[(times >= train_start) & (times < test_start)]
        return train_df, test_df, historical_df

    def split(self, df, split_id=None):
        """Yield the (train, test, historical) split of every interval, or only of interval `split_id` if given.

        Split info is logged only when iterating over all intervals.
        """
        if split_id is None:
            for interval_id in range(len(self.intervals)):
                split = self._single_split(interval_id, df)
                self.log_split_info(split, interval_id)
                yield split
        else:
            yield self._single_split(split_id, df)

    def prepare_split_dataframe(self, df, required_timeseries_training_size, skip_too_short_timeseries=False):
        """Validate the intervals, then check (and optionally skip) too short time series."""
        self.compute_time_steps_ranks(df)
        for interval in self.intervals:
            self._assert_valid_interval(interval)
        return self.filter_timeseries(df,
                                      required_timeseries_size=required_timeseries_training_size,
                                      skip_too_short_timeseries=skip_too_short_timeseries)

    def _find_too_short_part(self, timeseries_df, min_required_training_size):
        """Return (part name, part length) for the first too short part of the first interval split, or None."""
        train_df, test_df, _ = self._single_split(0, timeseries_df)
        # We need to be able to train and to score at least once
        if len(train_df) < min_required_training_size:
            return "train dataset", len(train_df)
        if len(test_df) < self.prediction_length:
            return "test dataset", len(test_df)
        return None

    def _retrieve_short_timeseries_identifiers_with_reasons(self, df, min_required_training_size):
        too_short_timeseries_identifiers_with_reasons = []
        kept_timeseries_identifier_values = []
        if self.timeseries_identifier_columns:
            df_grouped_by_timeseries = df.groupby(_groupby_compat(self.timeseries_identifier_columns))
            for identifier_values, timeseries_df in df_grouped_by_timeseries:
                reason = self._find_too_short_part(timeseries_df, min_required_training_size)
                if reason is None:
                    kept_timeseries_identifier_values.append(identifier_values)
                else:
                    encoded_identifier = encode_timeseries_identifier(identifier_values, self.timeseries_identifier_columns)
                    too_short_timeseries_identifiers_with_reasons.append((encoded_identifier,) + reason)
        else:
            reason = self._find_too_short_part(df, min_required_training_size)
            if reason is None:
                kept_timeseries_identifier_values.append(SINGLE_TIMESERIES_IDENTIFIER)
            else:
                too_short_timeseries_identifiers_with_reasons.append((SINGLE_TIMESERIES_IDENTIFIER,) + reason)

        return kept_timeseries_identifier_values, too_short_timeseries_identifiers_with_reasons

    @staticmethod
    def _assert_valid_interval(interval):
        """
        Check that the interval's bounds are consistent.
        :param interval: a dict with "train" and "test" keys mapping to [start, end] pairs
        :raises ValueError: if a start is after its end, or if the train end is not strictly before the test start
        """
        train_start_date, train_end_date = interval["train"]
        test_start_date, test_end_date = interval["test"]
        if train_start_date > train_end_date:
            raise ValueError("Invalid interval: train start date is after train end date")
        if test_start_date > test_end_date:
            raise ValueError("Invalid interval: test start date is after test end date")
        if train_end_date >= test_start_date:
            raise ValueError("Invalid interval: train end date is after test start date")

    def _generate_error_message_cta(self):
        return "Ensure that you have enough training data in every defined fold and/or select 'Skip too short time series'."

    def log_split_info(self, split, interval_id):
        """Log the interval bounds and the first/last time values of each dataframe of `split`.

        Best effort: malformed splits and splits with an empty part are ignored.
        """
        # Safety net, we don't want to crash logging.
        if split is None or len(split) != 3:
            return
        train_df, test_df, historical_df = split
        # Exception or skipped timeseries: nothing meaningful to log
        if not len(train_df) or not len(test_df) or not len(historical_df):
            return
        interval = self.intervals[interval_id]
        train_times = train_df[self.time_column]
        test_times = test_df[self.time_column]
        historical_times = historical_df[self.time_column]
        # Lazy %-style arguments: the message is only formatted if INFO records are emitted
        logger.info(
            "Interval split description: train: (%s, %s) , test: (%s, %s). "
            "Dataframes: train (%s, %s), test (%s, %s), historical (%s, %s)",
            interval["train"][0], interval["train"][1],
            interval["test"][0], interval["test"][1],
            train_times.iloc[0], train_times.iloc[-1],
            test_times.iloc[0], test_times.iloc[-1],
            historical_times.iloc[0], historical_times.iloc[-1],
        )

class KFoldTimeseriesSplitHandler(AbstractTimeseriesSplitHandler):
    """ Similar to sklearn.model_selection.TimeSeriesSplit but with fixed test set size.
    Class to compute rolling windows of train and test sets.
    Attributes:
        n_splits (int): Number of splits.
        fold_offset (bool, optional): Add a gap of `test_size` time steps between consecutive test sets. Defaults to False.
        equal_duration_folds (bool, optional): Whether every fold's train set should span the same number of
            time steps. Defaults to False.
    """
    def __init__(self, n_splits, core_params, fold_offset=False, equal_duration_folds=False):
        super(KFoldTimeseriesSplitHandler, self).__init__(core_params)
        self.n_splits = n_splits
        self.fold_offset = fold_offset
        self.equal_duration_folds = equal_duration_folds

    def split(self, df, split_id=None):
        """Yield every fold (skipping those with an empty train or test set), or only fold `split_id` if given."""
        if split_id is None:
            for split_idx in range(self.n_splits):
                train_df, test_df, historical_df = self._single_split(split_idx, df)
                if train_df.empty or test_df.empty:
                    continue
                yield train_df, test_df, historical_df
        else:
            yield self._single_split(split_id, df)

    def _single_split(self, split_id, df):
        """
        Split (by time) the dataframe into a (train, test, historical) triple for one fold.
        In the context of KFoldTimeseriesSplit, the historical dataframe is equal to the train one.
        df MUST be 0 indexed, otherwise the logic here is incorrect

        Example:
            - Timeseries [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] with test_size=2 and n_splits=3 yields the splits:
                1. train: [0, 1, 2, 3, 4], test: [5, 6], historical: [0, 1, 2, 3, 4]
                2. train: [0, 1, 2, 3, 4, 5, 6], test: [7, 8], historical: [0, 1, 2, 3, 4, 5, 6]
                3. train: [0, 1, 2, 3, 4, 5, 6, 7, 8], test: [9, 10], historical: [0, 1, 2, 3, 4, 5, 6, 7, 8]

            - Timeseries [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] with test_size=2, n_splits=2 and fold_offset=True yields:
                1. train: [0, 1, 2, 3, 4], test: [5, 6], historical: [0, 1, 2, 3, 4]
                2. train: [0, 1, 2, 3, 4, 5, 6, 7, 8], test: [9, 10], historical: [0, 1, 2, 3, 4, 5, 6, 7, 8]

            - Timeseries [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] with test_size=2, n_splits=3 and equal_duration_folds=True yields:
                1. train: [0, 1, 2, 3, 4], test: [5, 6], historical: [0, 1, 2, 3, 4]
                2. train: [2, 3, 4, 5, 6], test: [7, 8], historical: [2, 3, 4, 5, 6]
                3. train: [4, 5, 6, 7, 8], test: [9, 10], historical: [4, 5, 6, 7, 8]

        :param split_id: Id of the split to generate (0 is the first split with the smallest train set).
        :param df: Input dataframe to be split. **Must be sorted by time.**
        :return: (train_df, test_df, train_df)
        """
        if self._time_steps_ranks is None:
            logger.warning("Time steps ranks not available. Make sure `prepare_split_dataframe` is called before `split`.")
            self.compute_time_steps_ranks(df)
        ranks = self._time_steps_ranks
        # With fold_offset, consecutive test windows are separated by an extra test_size gap
        window_stride = 2 if self.fold_offset else 1
        # time step rank where the test time series ends (inclusive)
        most_recent_test_timestep = self.test_size * (self.n_splits - split_id - 1) * window_stride
        # time step rank where the train time series ends (inclusive).
        # It is greater than most_recent_test_timestep as the time steps are ranked by reverse chronological order
        most_recent_train_timestep = most_recent_test_timestep + self.test_size
        # A higher time step rank means an older time step
        train_indices = ranks >= most_recent_train_timestep

        # If every split is of equal duration, the train start is offset from the beginning
        if self.equal_duration_folds:
            beginning_offset = self.test_size * split_id * window_stride
            # NOTE(review): `.iloc[0]` is the rank of the first row only. It is the oldest rank of every
            # series only for a single series (sorted by time) or series of equal length, and the cached
            # ranks may include rows later filtered out of `df` — confirm for multi-series inputs.
            oldest_train_timestamp = ranks.iloc[0] - beginning_offset
            train_indices &= ranks <= oldest_train_timestamp
        test_indices = (ranks >= most_recent_test_timestep) & (ranks < most_recent_train_timestep)
        train_df = df.loc[train_indices]
        test_df = df.loc[test_indices]
        return train_df, test_df, train_df

    def prepare_split_dataframe(self, df, required_timeseries_training_size, skip_too_short_timeseries=False):
        # train set of first window (length = full_length - n_splits * test_size)
        # (if fold_offset: length = full_length - (2 * n_splits - 1) * test_size)
        # must be bigger than the required time series size for training a given algorithm, i.e.:
        self.compute_time_steps_ranks(df)
        min_required_length_for_split = self.n_splits * self.test_size + required_timeseries_training_size
        if self.fold_offset:
            min_required_length_for_split += (self.n_splits - 1) * self.test_size
        return self.filter_timeseries(df, min_required_length_for_split, skip_too_short_timeseries)

    def _retrieve_short_timeseries_identifiers_with_reasons(self, df, min_required_length_for_split):
        if self.timeseries_identifier_columns:
            timeseries_sizes = df.groupby(_groupby_compat(self.timeseries_identifier_columns)).size()
            too_short_timeseries_identifiers_with_reasons = [
                (encode_timeseries_identifier(identifier_values, self.timeseries_identifier_columns), "full dataset", timeseries_length)
                for (identifier_values, timeseries_length) in timeseries_sizes[timeseries_sizes < min_required_length_for_split].items()
            ]
            kept_timeseries_identifier_values = set(timeseries_sizes[timeseries_sizes >= min_required_length_for_split].index)
            return kept_timeseries_identifier_values, too_short_timeseries_identifiers_with_reasons

        timeseries_length = len(df.index)
        if timeseries_length < min_required_length_for_split:
            return set(), [(SINGLE_TIMESERIES_IDENTIFIER, "full dataset", timeseries_length)]
        # Set literal on purpose: set(SINGLE_TIMESERIES_IDENTIFIER) would iterate over the identifier
        # (e.g. split a string into characters) instead of holding the identifier itself.
        return {SINGLE_TIMESERIES_IDENTIFIER}, []

    def _generate_error_message_cta(self):
        helpers = []
        if self.test_size > 1:
            helpers.append("the evaluation set size")
        if self.n_splits > 1:
            helpers.append("the number of folds")
        helpers.append("the season length (if applicable)")
        return " Try to decrease {}. Or select 'Skip too short time series'.".format(" and/or ".join(helpers))

class TimeseriesInteractiveScoringSplitHandler(KFoldTimeseriesSplitHandler):
    """
    Single-fold split handler used when creating interactive scoring scenarios.

    Its test set is exactly the last forecast horizon: `test_size` is overridden with `prediction_length`.
    """
    def __init__(self, core_params):
        super(TimeseriesInteractiveScoringSplitHandler, self).__init__(n_splits=1, core_params=core_params)
        # Replace the configured evaluation test size with the forecast horizon
        self.test_size = self.prediction_length
