from dataiku.eda.types import Literal

import numpy as np
import scipy.stats as sps

from dataiku.eda.computations.computation import UnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.exceptions import GroupsAreNotDisjoint, DegenerateCaseError
from dataiku.eda.exceptions import NotEnoughDataError
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.types import KsTest2SampModel, KsTest2SampResultModel


class KsTest2Samp(UnivariateComputation):
    """Two-sample Kolmogorov-Smirnov test on one column.

    The two samples are the two populations produced by ``grouping``; the
    test itself is delegated to ``scipy.stats.ks_2samp``.
    """

    def __init__(self, column: str, grouping: Grouping):
        super().__init__(column)
        self.grouping = grouping

    @staticmethod
    def get_type() -> Literal["ks_test_2samp"]:
        """Return the type identifier of this computation."""
        return "ks_test_2samp"

    @staticmethod
    def build(params: KsTest2SampModel) -> 'KsTest2Samp':
        """Create the computation from its ``column`` and ``grouping`` params."""
        column = params['column']
        grouping = Grouping.build(params["grouping"])
        return KsTest2Samp(column, grouping)

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> KsTest2SampResultModel:
        """Run the test and return its statistic and p-value.

        Raises:
            DegenerateCaseError: the grouping does not yield exactly two
                populations.
            GroupsAreNotDisjoint: the two populations share at least one row.
            NotEnoughDataError: at least one population has no value left
                for the column.
        """
        # Restrict every group to the rows where the column is finite
        # (np.isfinite rejects NaN and +/-inf).
        finite_groups = []
        for group in self.grouping.compute_groups(idf).iter_groups():
            finite_mask = np.isfinite(group.float_col(self.column))
            finite_groups.append(group[finite_mask])

        group_count = len(finite_groups)
        if group_count != 2:
            if group_count > 2:
                comparative = "more"
            else:
                comparative = "less"
            raise DegenerateCaseError("There are {} than two populations".format(comparative))

        first_group, second_group = finite_groups

        # The two samples must not have any row in common.
        shared_rows = first_group & second_group
        if len(shared_rows) > 0:
            raise GroupsAreNotDisjoint()

        first_sample = first_group.float_col_no_missing(self.column)
        second_sample = second_group.float_col_no_missing(self.column)

        if any(len(sample) == 0 for sample in (first_sample, second_sample)):
            raise NotEnoughDataError("At least one population is empty or does not have any value for {}".format(self.column))

        ks_statistic, ks_pvalue = sps.ks_2samp(first_sample, second_sample)

        return {
            "type": self.get_type(),
            "statistic": ks_statistic,
            "pvalue": ks_pvalue,
        }
