from typing import Iterator
from dataiku.eda.types import Literal

import numpy as np

from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.grouping.grouping import Grouping, GroupingResult
from dataiku.eda.types import TopNTimeGroupingModel, TopNTimeGroupingResultModel


class TopNTimeGrouping(Grouping):
    """Grouping producing a single group made of the ``n`` rows with the smallest dates in ``column``.

    Rows whose value in ``column`` is NaT are never selected. Selected rows keep their original
    relative order.
    """

    def __init__(self, column: str, n: int):
        # Name of the date column used to rank rows
        self.column = column
        # Maximum number of rows to keep; values <= 0 select no rows
        self.n = n

    @staticmethod
    def get_type() -> Literal["topn_time"]:
        return "topn_time"

    def describe(self) -> str:
        return "{}(column={}, n={})".format(
            self.__class__.__name__,
            self.column,
            self.n,
        )

    @staticmethod
    def build(params: TopNTimeGroupingModel) -> 'TopNTimeGrouping':
        """Create the grouping from its parameters (reads the ``column`` and ``n`` keys)."""
        return TopNTimeGrouping(
            params["column"],
            params["n"],
        )

    def compute_groups(self, idf: ImmutableDataFrame) -> 'TopNTimeGroupingResult':
        """Select at most ``n`` rows of ``idf`` having the smallest non-NaT values in ``self.column``.

        :param idf: data frame to group
        :return: a result exposing a single group containing the selected rows, in their original
            order. When there are no more than ``n`` non-NaT rows, all of them are kept. When
            ``n <= 0``, the group is empty. Among rows tied on the cut-off date, which ones are kept
            is unspecified (partial sort).
        """
        # Nothing to select. Negative values must be caught here too: passed to np.argpartition,
        # a negative kth counts from the end and [:n] drops trailing elements, so the call would
        # either raise (kth out of bounds) or silently return an arbitrary subset of rows.
        if self.n <= 0:
            return TopNTimeGroupingResult(idf[[]])

        # Get timestamps in a numpy array
        timestamps = idf.date_col(self.column).values

        # Positions of the rows holding a valid (non-NaT) date
        non_nat_indexes, = np.nonzero(~np.isnat(timestamps))

        # No need to sort if there are no more values than we need
        if len(non_nat_indexes) <= self.n:
            return TopNTimeGroupingResult(idf[non_nat_indexes])

        # Partial sort: after argpartition, the first n positions hold the n smallest timestamps
        # (in no particular order). These are positions within non_nat_indexes, not row indexes.
        limited_non_nat_indexes_indexes = np.argpartition(timestamps[non_nat_indexes], kth=self.n - 1)[:self.n]

        # Row order in the result should not matter to consumers, but it is cheap to preserve the
        # initial row order: turning the selected positions into a boolean mask over the (already
        # sorted) non_nat_indexes yields the selected row indexes in ascending order.
        mask = np.zeros(len(non_nat_indexes), dtype=np.bool_)
        mask[limited_non_nat_indexes_indexes] = True
        limited_non_nat_indexes = non_nat_indexes[mask]

        return TopNTimeGroupingResult(idf[limited_non_nat_indexes])


class TopNTimeGroupingResult(GroupingResult):
    """Outcome of a top-N-by-time grouping: exactly one group, the selected rows."""

    def __init__(self, idf: ImmutableDataFrame):
        # The rows retained by the grouping, exposed as the only group
        self.idf = idf

    def iter_groups(self) -> Iterator[ImmutableDataFrame]:
        """Yield the single group held by this result."""
        yield self.idf

    def serialize(self) -> TopNTimeGroupingResultModel:
        """Return the serialized form, which only carries the grouping type."""
        grouping_type = TopNTimeGrouping.get_type()
        return {"type": grouping_type}
