from typing import Iterator, List, Tuple

from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame, FloatVector
from dataiku.eda.computations.univariate.abstract_multi_sample import AbstractMultiSampleUnivariateComputation
from dataiku.eda.exceptions import GroupsAreNotDisjoint
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.types import PValueAdjustmentMethod


class AbstractPairwiseUnivariateComputation(AbstractMultiSampleUnivariateComputation):
    """Base class for univariate computations that compare groups two at a time.

    On top of the column and grouping handled by the parent class, it stores
    whether comparisons are restricted to "one vs all" and which p-value
    adjustment method was requested. It also provides the pair enumeration
    helper `_iterate_pairs`.
    """

    def __init__(self, column: str, grouping: Grouping, one_vs_all: bool, adjustment_method: PValueAdjustmentMethod):
        """Store the pairwise options after delegating column/grouping to the parent.

        :param column: forwarded unchanged to the parent constructor
        :param grouping: forwarded unchanged to the parent constructor
        :param one_vs_all: when true, only pairs involving the first group are produced
        :param adjustment_method: kept on the instance; not used by this class itself
        """
        super().__init__(column, grouping)
        self.one_vs_all = one_vs_all
        self.adjustment_method = adjustment_method

    def _iterate_pairs(self, samples: List[FloatVector], grouped_idfs: List[ImmutableDataFrame]) -> Iterator[Tuple[FloatVector, FloatVector]]:
        """Lazily yield the pairs of samples to compare.

        Pairs are produced as ``(samples[i], samples[j])`` for every ``i < j``, in
        increasing order of ``i`` then ``j``. In "one vs all" mode only ``i == 0``
        is considered, i.e. the first group acts as the reference group.

        :param samples: one sample per group, indexed like ``grouped_idfs``
        :param grouped_idfs: one data frame per group, used for the disjointness check
        :raises GroupsAreNotDisjoint: if ``idf_i & idf_j`` is non-empty for a visited pair
        """
        group_count = len(grouped_idfs)

        for ref_idx, ref_idf in enumerate(grouped_idfs):
            if self.one_vs_all and ref_idx > 0:
                # 1 vs ALL: the reference (first) group has been fully handled, nothing left to compare
                return

            # Starting right after ref_idx visits each unordered pair exactly once
            for other_idx in range(ref_idx + 1, group_count):
                overlap = ref_idf & grouped_idfs[other_idx]
                if len(overlap) > 0:
                    # Not expected to happen: most likely a programming mistake from the caller of EDA compute
                    raise GroupsAreNotDisjoint()

                yield samples[ref_idx], samples[other_idx]
