from typing import Iterator, List
from dataiku.eda.types import Literal

from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.grouping.grouping import GroupingResult
from dataiku.eda.types import CrossGroupingModel, CrossGroupingResultModel


class CrossGrouping(Grouping):
    """Grouping that crosses several sub-groupings together.

    Computing it computes every sub-grouping on the same data frame and wraps
    the per-dimension results in a ``CrossGroupingResult``.
    """

    def __init__(self, groupings: List[Grouping]):
        # Sub-groupings, one per crossed dimension, in order.
        self.groupings = groupings

    @staticmethod
    def get_type() -> Literal["cross"]:
        """Return the type tag of this grouping."""
        return "cross"

    @staticmethod
    def build(params: CrossGroupingModel) -> 'CrossGrouping':
        """Create a CrossGrouping from its model.

        Every entry of ``params["groupings"]`` is turned into a sub-grouping
        through ``Grouping.build``, preserving order.
        """
        sub_models = params["groupings"]
        sub_groupings = list(map(Grouping.build, sub_models))
        return CrossGrouping(sub_groupings)

    def describe(self) -> str:
        """Return the sub-grouping descriptions separated by ``' x '``."""
        descriptions = [sub.describe() for sub in self.groupings]
        return ' x '.join(descriptions)

    def compute_groups(self, idf: ImmutableDataFrame) -> 'CrossGroupingResult':
        """Compute every sub-grouping on ``idf`` and bundle the results.

        :param idf: data frame the sub-groupings are computed on; it is also
            kept by the result as the starting frame for crossing.
        :returns: a ``CrossGroupingResult`` holding one result per dimension.
        """
        per_dimension = [sub.compute_groups(idf) for sub in self.groupings]
        return CrossGroupingResult(per_dimension, idf)


class CrossGroupingResult(GroupingResult):
    """Result of crossing several groupings.

    Holds one ``GroupingResult`` per crossed dimension plus the data frame
    they were computed on. ``iter_groups`` lazily yields one frame per
    combination of groups (one group taken from each dimension).
    """

    def __init__(self, groups: List[GroupingResult], original_idf: ImmutableDataFrame):
        # One sub-result per crossed dimension, in order.
        self.groups = groups
        # Frame every combination starts from before applying `&` with the groups.
        self.original_idf = original_idf

    def serialize(self) -> CrossGroupingResultModel:
        """Serialize as the ``"cross"`` type tag plus each sub-result's serialization."""
        return {
            "type": CrossGrouping.get_type(),
            "groups": [g.serialize() for g in self.groups]
        }

    def _iter_groups_rec(self, idf: ImmutableDataFrame, dim: int) -> Iterator[ImmutableDataFrame]:
        """Yield ``idf`` combined with each group combination from dimension ``dim`` onwards.

        Dimension ``dim`` is the outer loop, so the last dimension varies
        fastest. Once all dimensions are consumed, ``idf`` is yielded as a
        finished cross group.

        :param idf: frame already combined (via ``&``) with one group from each
            of the dimensions before ``dim``.
        :param dim: index into ``self.groups`` of the next dimension to cross.
        """
        if dim >= len(self.groups):
            yield idf
            return

        # Note: iter_groups() of this dimension is re-invoked once per
        # combination of the earlier dimensions.
        for group_idf in self.groups[dim].iter_groups():
            # The combined frame is passed down, so each level applies `&`
            # once instead of recombining from the original frame.
            # NOTE(review): `&` is presumably a row intersection on
            # ImmutableDataFrame — confirm.
            yield from self._iter_groups_rec(idf & group_idf, dim + 1)

    def iter_groups(self) -> Iterator[ImmutableDataFrame]:
        """Yield every cross group, starting from the original data frame.

        With no dimensions, the original data frame is yielded exactly once.
        """
        yield from self._iter_groups_rec(self.original_idf, 0)
