from dataiku.eda.types import Literal

import scipy.stats as sps

from dataiku.eda.computations.univariate.abstract_multi_sample import AbstractMultiSampleUnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.types import MoodTestNSampModel, MoodTestNSampResultModel


class MoodTestNSamp(AbstractMultiSampleUnivariateComputation):
    """Mood's median test comparing N samples of a single column.

    The samples come from splitting the input data frame with ``grouping``.
    The statistical test itself is delegated to ``scipy.stats.median_test``.
    """

    def __init__(self, column: str, grouping: Grouping):
        # No extra state of its own: both arguments go straight to the base class.
        super().__init__(column, grouping)

    @staticmethod
    def get_type() -> Literal["mood_test_nsamp"]:
        """Return the type tag identifying this computation."""
        type_tag: Literal["mood_test_nsamp"] = "mood_test_nsamp"
        return type_tag

    @staticmethod
    def build(params: MoodTestNSampModel) -> 'MoodTestNSamp':
        """Create a ``MoodTestNSamp`` from its serialized parameters.

        :param params: mapping read for its ``"column"`` and ``"grouping"`` keys;
            the grouping part is turned into a ``Grouping`` via ``Grouping.build``.
        """
        # Read "column" before building the grouping, as the original expression did.
        column = params["column"]
        grouping = Grouping.build(params["grouping"])
        return MoodTestNSamp(column, grouping)

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> MoodTestNSampResultModel:
        """Run the test on ``idf`` and return its statistic and p-value.

        :param idf: data frame to split into samples with this computation's grouping.
        :param ctx: not used by this computation.
        :returns: dict with keys ``"type"``, ``"statistic"`` and ``"pvalue"``.
        """
        samples, groups, _unused = self._compute_groups(idf)
        # NOTE(review): presumably raises on overlapping groups or inputs the
        # test cannot handle -- confirm in the base class.
        self._check_disjoint_groups_and_not_degenerate_case(groups)

        # ties="ignore": observations equal to the grand median are left out of
        # the contingency table; the grand median and the table are discarded.
        stat, p_value, _grand_median, _table = sps.median_test(*samples, ties="ignore")
        return {
            "type": self.get_type(),
            "statistic": stat,
            "pvalue": p_value,
        }
