from typing import Iterator
from dataiku.eda.types import Literal

from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.grouping.grouping import GroupingResult
from dataiku.eda.types import MergeGroupingModel, MergeGroupingResultModel


class MergeGrouping(Grouping):
    """Grouping that collapses all groups of an inner grouping into a single group.

    The inner grouping is computed first; its groups are then combined one by
    one with the ``|`` operator of ``ImmutableDataFrame`` into one data frame.
    """

    def __init__(self, inner_grouping: Grouping):
        # Grouping whose groups get merged together.
        self.inner_grouping = inner_grouping

    @staticmethod
    def get_type() -> Literal["merge"]:
        """Return the type tag of this grouping (``"merge"``)."""
        return "merge"

    def describe(self) -> str:
        """Return a human-readable description: ``Merge(<inner description>)``."""
        return "Merge(%s)" % self.inner_grouping.describe()

    @staticmethod
    def build(params: MergeGroupingModel) -> 'MergeGrouping':
        """Build a MergeGrouping from its model.

        :param params: model whose ``'innerGrouping'`` entry is passed to
            ``Grouping.build`` to create the wrapped grouping
        """
        return MergeGrouping(Grouping.build(params['innerGrouping']))

    def compute_groups(self, idf: ImmutableDataFrame) -> 'MergeGroupingResult':
        """Compute the inner groups on ``idf`` and merge them into a single group.

        :param idf: data frame to group
        :returns: a MergeGroupingResult holding the inner grouping result and
            the merged data frame
        """
        inner_grouping_result = self.inner_grouping.compute_groups(idf)

        # Seed with idf[[]] (presumably an empty row selection of idf -- TODO
        # confirm), so the merge is well-defined even when there are no groups.
        merged_groups_idf = idf[[]]
        # Use a distinct loop variable rather than reusing `idf`; reusing it
        # would shadow the parameter with the last group after the loop.
        for group_idf in inner_grouping_result.iter_groups():
            merged_groups_idf |= group_idf

        return MergeGroupingResult(inner_grouping_result, merged_groups_idf)


class MergeGroupingResult(GroupingResult):
    """Result of a merge grouping: a single group plus the inner grouping result.

    Only the inner grouping result is serialized; the merged data frame is
    kept in memory for iteration but is not part of the serialized model.
    """

    def __init__(self, inner_grouping_result: GroupingResult, idf: ImmutableDataFrame):
        # Result produced by the wrapped (inner) grouping.
        self.inner_grouping_result = inner_grouping_result
        # The one data frame that represents the merged group.
        self.idf = idf

    def serialize(self) -> MergeGroupingResultModel:
        """Return the model: the ``"merge"`` type tag and the serialized inner result."""
        grouping_type = MergeGrouping.get_type()
        serialized_inner = self.inner_grouping_result.serialize()
        return {"type": grouping_type, "innerGroupingResult": serialized_inner}

    def iter_groups(self) -> Iterator[ImmutableDataFrame]:
        """Yield exactly one group: the merged data frame."""
        yield from (self.idf,)
