from typing import Iterator, List
from dataiku.eda.types import Literal

import pandas as pd

from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.exceptions import InvalidParams
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.grouping.grouping import GroupingResult
from dataiku.eda.types import MultiAnumGroupingModel, MultiAnumGroupingResultModel


class MultiAnumGrouping(Grouping):
    """Grouping over the distinct value combinations of one or more columns.

    Column values are read through ``ImmutableDataFrame.text_col`` and the
    actual grouping is delegated to ``pandas.DataFrame.groupby``.
    """

    def __init__(self, columns: List[str]):
        # Names of the columns whose values define the groups.
        self.columns = columns

    @staticmethod
    def get_type() -> Literal["multi_anum"]:
        """Return the type tag identifying this grouping."""
        return "multi_anum"

    def describe(self) -> str:
        """Return a short human-readable description of this grouping."""
        return "{}(columns={})".format(type(self).__name__, self.columns)

    @staticmethod
    def build(params: MultiAnumGroupingModel) -> 'MultiAnumGrouping':
        """Create a grouping from its serialized parameters (``params["columns"]``)."""
        columns = params["columns"]
        return MultiAnumGrouping(columns)

    def compute_groups(self, idf: ImmutableDataFrame) -> 'MultiAnumGroupingResult':
        """Split ``idf`` into one sub-frame per combination of column values.

        :param idf: the data frame to split
        :return: a result holding, in the same order, the sub-frames and the
            key values of every non-empty group; both lists are empty when
            ``idf`` has no rows
        :raises InvalidParams: if no column was configured
        """
        column_count = len(self.columns)
        if column_count == 0:
            raise InvalidParams("There must be at least one column to group over")

        if len(idf) == 0:
            return MultiAnumGroupingResult(self.columns, [], [])

        frame = pd.DataFrame(
            {name: idf.text_col(name) for name in self.columns},
            copy=False,
        )

        # sort=True makes pandas order the group keys, so the result index
        # yields them already sorted. The `indices` mapping is only used below
        # to look up the row positions of each key, never for ordering.
        # NOTE(review): `GroupBy.grouper` is deprecated in recent pandas
        # releases - confirm the supported pandas range before upgrading.
        grouped = frame.groupby(self.columns, sort=True, as_index=False)
        ordered_keys = grouped.grouper.result_index.tolist()

        group_values = []
        group_idfs = []

        for key in ordered_keys:
            positions = grouped.indices.get(key)
            has_rows = positions is not None and len(positions) > 0

            if not has_rows:
                # When grouping over a single column, pandas.groupby produces
                # an empty group with an empty label, because the immutable
                # data frame represents missing values as empty strings.
                # Such a group carries no rows, so it is left out.
                continue

            group_idfs.append(idf[positions])

            # A single-column key is a scalar while a multi-column key is a
            # tuple; both are exposed as a list of values.
            if column_count == 1:
                group_values.append([key])
            else:
                group_values.append(list(key))

        return MultiAnumGroupingResult(self.columns, group_idfs, group_values)


class MultiAnumGroupingResult(GroupingResult):
    """Outcome of a ``MultiAnumGrouping``: the group frames and their key values."""

    def __init__(self, columns: List[str], idfs: List[ImmutableDataFrame], group_values: List[List[str]]):
        # Columns the grouping was computed over.
        self.columns = columns
        # One data frame per group.
        self.idfs = idfs
        # Key values of each group, one inner list per group.
        self.group_values = group_values

    def serialize(self) -> MultiAnumGroupingResultModel:
        """Return the serialized form: grouping type, columns and group values.

        The group data frames themselves are not part of the output.
        """
        grouping_type = MultiAnumGrouping.get_type()
        return {"type": grouping_type, "columns": self.columns, "groupValues": self.group_values}

    def iter_groups(self) -> Iterator[ImmutableDataFrame]:
        """Yield the data frame of each group, in stored order."""
        yield from self.idfs
