from enum import Enum
from typing import List, Optional
from dataiku.eda.types import Literal

import numpy as np
import scipy.stats as sps

from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame, FloatVector
from dataiku.eda.computations.univariate.abstract_pairwise import AbstractPairwiseUnivariateComputation
from dataiku.eda.exceptions import InvalidParams, DegenerateCaseError, UnknownObjectType
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.stats.multitest import multitest_correction
from dataiku.eda.types import LeveneCenter, PairwiseLeveneTestModel, PairwiseLeveneTestResultModel, PValueAdjustmentMethod


class LeveneCenterEnum(str, Enum):
    """Names of the center statistics accepted by ``PairwiseLeveneTest``.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (e.g. ``LeveneCenterEnum.MEDIAN == "MEDIAN"``). Iterating the
    enum yields the members in declaration order.
    """

    # Center each sample on its median (np.median in PairwiseLeveneTest.apply)
    MEDIAN = "MEDIAN"
    # Center each sample on its arithmetic mean (np.mean)
    MEAN = "MEAN"
    # Center each sample on its trimmed mean (scipy.stats.trim_mean); requires a proportion to trim
    TRIMMED_MEAN = "TRIMMED_MEAN"


# Pairwise unpaired Levene test
class PairwiseLeveneTest(AbstractPairwiseUnivariateComputation):
    """Pairwise Levene test on the grouped samples of a column.

    Every sample is replaced by its absolute deviations from a per-sample
    center (median, mean or trimmed mean), then a one-way ANOVA
    (``scipy.stats.f_oneway``) is run on each pair of transformed samples
    yielded by ``_iterate_pairs``. The raw p-values are finally passed through
    ``multitest_correction`` with the configured adjustment method.
    """

    def __init__(self, column: str, grouping: Grouping, one_vs_all: bool, adjustment_method: PValueAdjustmentMethod, center: LeveneCenter, proportion_to_trim: Optional[float]):
        """Store the test configuration.

        :param column: name of the tested column (forwarded to the parent class)
        :param grouping: grouping defining the samples (forwarded to the parent class)
        :param one_vs_all: whether the first sample is a reference compared against all the others,
            as opposed to comparing all samples with each other
        :param adjustment_method: p-value adjustment method handed to ``multitest_correction``
        :param center: one of the ``LeveneCenterEnum`` values
        :param proportion_to_trim: proportion cut from each end of a sample by ``scipy.stats.trim_mean``;
            only read (and then required, within [0, 0.5)) when ``center`` is ``TRIMMED_MEAN``
        """
        super().__init__(column, grouping, one_vs_all, adjustment_method)
        self.center = center
        self.proportion_to_trim = proportion_to_trim

    @staticmethod
    def get_type() -> Literal["pairwise_levene_test"]:
        """Return the type identifier of this computation."""
        return "pairwise_levene_test"

    @staticmethod
    def build(params: PairwiseLeveneTestModel) -> 'PairwiseLeveneTest':
        """Instantiate the computation from its serialized parameters.

        ``proportionToTrim`` is optional in ``params``; every other key is required.
        """
        return PairwiseLeveneTest(
            params['column'],
            Grouping.build(params["grouping"]),
            params['oneVsAll'],
            params['adjustmentMethod'],
            params['center'],
            params.get('proportionToTrim'),
        )

    def _check_degenerate_cases(self, samples: List[FloatVector]):
        """Ensure each tested pair contains at least one non-degenerate sample.

        A sample is degenerate when it holds fewer than two distinct values, or
        exactly two distinct values occurring the same number of times.

        :param samples: samples of the tested column; in one-vs-all mode the first one is the reference
        :raises DegenerateCaseError: if at least one pair would only contain degenerate samples
        """

        def _is_degenerate(sample: FloatVector) -> bool:
            unique_values, value_counts = np.unique(sample, return_counts=True)
            if len(unique_values) < 2:
                return True
            return len(unique_values) == 2 and value_counts[0] == value_counts[1]

        def _raise_if_too_many_degenerate_samples(candidates: List[FloatVector], allowed_degenerate_samples: int):
            degenerate_samples_found = 0
            for sample in candidates:
                if not _is_degenerate(sample):
                    continue

                degenerate_samples_found += 1
                # Fail as soon as the budget is exceeded, no need to inspect the remaining samples
                if degenerate_samples_found > allowed_degenerate_samples:
                    raise DegenerateCaseError(
                        ("In each pair of populations, at least one sample must have at least three different values of {}, "
                         "or exactly two different values but not in the same proportion").format(self.column)
                    )

        if self.one_vs_all:
            # In 1 vs ALL every pair contains the reference: a valid reference validates all pairs,
            # otherwise none of the other samples may be degenerate
            if _is_degenerate(samples[0]):
                _raise_if_too_many_degenerate_samples(samples[1:], allowed_degenerate_samples=0)
        else:
            # In ALL vs ALL two degenerate samples would necessarily be paired together,
            # so at most one degenerate sample is tolerated
            _raise_if_too_many_degenerate_samples(samples, allowed_degenerate_samples=1)

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> PairwiseLeveneTestResultModel:
        """Run the pairwise Levene test on ``idf``.

        :param idf: data frame holding the tested column and the grouping columns
        :param ctx: computation context (not used by this computation)
        :return: dict with the computation ``type``, the per-pair ``statistics`` and ``pvalues``,
            the ``adjustedPvalues`` and the per-sample ``centers``
        :raises DegenerateCaseError: if a pair only contains degenerate samples
        :raises InvalidParams: if the center is the trimmed mean and the proportion to trim
            is missing or outside [0, 0.5)
        :raises UnknownObjectType: if the center is not one of the ``LeveneCenterEnum`` values
        """
        samples, grouped_idfs, _ = self._compute_groups(idf)

        # check degenerate cases: each pair must have at least one non-degenerate sample
        self._check_degenerate_cases(samples)

        if self.center == LeveneCenterEnum.MEDIAN:
            centers = [np.median(s) for s in samples]
        elif self.center == LeveneCenterEnum.MEAN:
            centers = [np.mean(s) for s in samples]
        elif self.center == LeveneCenterEnum.TRIMMED_MEAN:
            proportion = self.proportion_to_trim
            # Negated chained comparison on purpose: it also rejects NaN, which would slip through
            # separate "< 0" / ">= 0.5" checks and make trim_mean fail with an obscure ValueError
            if proportion is None or not (0 <= proportion < 0.5):
                raise InvalidParams("Proportion to trim must be set and within [0, 0.5) when the selected center is the trimmed mean")
            centers = [sps.trim_mean(s, proportiontocut=proportion) for s in samples]
        else:
            raise UnknownObjectType("Center must be one of {}".format(", ".join(va.value for va in LeveneCenterEnum)))

        # Levene transformation: absolute deviation of each observation from the center of its own sample
        scaled_samples = [np.abs(s - center) for (s, center) in zip(samples, centers)]

        statistics = []
        pvalues = []
        for series_i, series_j in self._iterate_pairs(scaled_samples, grouped_idfs):
            # One-way ANOVA on the absolute deviations of the two samples of the pair
            statistic, pvalue = sps.f_oneway(series_i, series_j)
            statistics.append(statistic)
            pvalues.append(pvalue)

        adjusted_pvalues = multitest_correction(pvalues, self.adjustment_method)

        return {
            "type": self.get_type(),
            "statistics": statistics,
            "pvalues": pvalues,
            "adjustedPvalues": adjusted_pvalues,
            "centers": centers,
        }
