from typing import Iterator, List
from dataiku.eda.types import Literal

from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.grouping.grouping import GroupingResult
from dataiku.eda.types import UnionGroupingModel, UnionGroupingResultModel


class UnionGrouping(Grouping):
    """Grouping that combines several sub-groupings into one.

    Every sub-grouping is computed against the same data frame; the combined
    result keeps one computed result per sub-grouping, in list order.
    """

    def __init__(self, groupings: List[Grouping]):
        # Kept by reference (not copied); the list order drives result order.
        self.groupings = groupings

    @staticmethod
    def get_type() -> Literal["union"]:
        """Return the type tag identifying this grouping kind."""
        return "union"

    def describe(self) -> str:
        """Return the sub-grouping descriptions joined by commas (no spaces)."""
        descriptions = [sub.describe() for sub in self.groupings]
        return ",".join(descriptions)

    @staticmethod
    def build(params: UnionGroupingModel) -> 'UnionGrouping':
        """Create a UnionGrouping from its model.

        Each entry of ``params['groupings']`` is turned into a concrete
        grouping through ``Grouping.build``.
        """
        sub_models = params['groupings']
        children = [Grouping.build(sub_model) for sub_model in sub_models]
        return UnionGrouping(children)

    def compute_groups(self, idf: ImmutableDataFrame) -> 'UnionGroupingResult':
        """Compute each sub-grouping on ``idf`` and bundle the results."""
        sub_results = [sub.compute_groups(idf) for sub in self.groupings]
        return UnionGroupingResult(sub_results)


class UnionGroupingResult(GroupingResult):
    """Computed result of a union grouping: a list of sub-grouping results."""

    def __init__(self, computed_groups: List[GroupingResult]):
        # When built by UnionGrouping.compute_groups, holds one entry per
        # sub-grouping, in the sub-groupings' order.
        self.computed_groups = computed_groups

    def serialize(self) -> UnionGroupingResultModel:
        """Serialize as a "union"-tagged model listing each sub-result's serialization."""
        type_tag = UnionGrouping.get_type()
        children = [child.serialize() for child in self.computed_groups]
        return {"type": type_tag, "groupings": children}

    def iter_groups(self) -> Iterator[ImmutableDataFrame]:
        """Yield every group of every sub-result, exhausting each sub-result in turn."""
        for child in self.computed_groups:
            yield from child.iter_groups()
