from enum import Enum
from typing import List, Optional
from dataiku.eda.types import Literal

import numpy as np
import scipy.stats as sps

from dataiku.eda.computations.univariate.abstract_multi_sample import AbstractMultiSampleUnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame, FloatVector
from dataiku.eda.exceptions import InvalidParams, DegenerateCaseError, UnknownObjectType
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.types import LeveneCenter, LeveneTestNSampModel, LeveneTestNSampResultModel


class LeveneCenterEnum(str, Enum):
    """Center statistics selectable for the Levene test.

    The ``str`` mixin makes each member compare equal to its plain string
    value, so a center received as a raw string can be compared directly
    against these members.
    """

    # Declaration order is the iteration order of the enum; keep it stable
    # because code iterating over the members lists them in this order.
    MEDIAN = "MEDIAN"
    MEAN = "MEAN"
    TRIMMED_MEAN = "TRIMMED_MEAN"


class LeveneTestNSamp(AbstractMultiSampleUnivariateComputation):
    """N-sample Levene test computed on a column split by a grouping.

    Every sample is turned into absolute deviations from its own center
    (median, mean or trimmed mean, depending on ``center``) and a one-way
    ANOVA (``scipy.stats.f_oneway``) is run on those deviations.
    """

    def __init__(self, column: str, grouping: Grouping, center: LeveneCenter, proportion_to_trim: Optional[float]):
        """Store the test parameters.

        :param column: name of the column under test.
        :param grouping: grouping used to split the data into samples.
        :param center: which center statistic to use (see ``LeveneCenterEnum``).
        :param proportion_to_trim: fraction passed to ``scipy.stats.trim_mean``;
            only read when ``center`` is the trimmed mean.
        """
        super().__init__(column, grouping)
        self.center = center
        self.proportion_to_trim = proportion_to_trim

    @staticmethod
    def get_type() -> Literal["levene_test_nsamp"]:
        """Return the type identifier of this computation."""
        return "levene_test_nsamp"

    @staticmethod
    def build(params: LeveneTestNSampModel) -> 'LeveneTestNSamp':
        """Create the computation from its parameter dictionary.

        ``column``, ``grouping`` and ``center`` are required keys;
        ``proportionToTrim`` is optional and defaults to ``None``.
        """
        return LeveneTestNSamp(
            params['column'],
            Grouping.build(params["grouping"]),
            params['center'],
            params.get('proportionToTrim')
        )

    def _check_degenerate_cases(self, samples: List[FloatVector]) -> None:
        """Raise unless at least one sample has enough spread for the test.

        A sample qualifies when it holds more than two distinct values, or
        exactly two distinct values occurring with different counts. A sample
        with a single value, or with two values in equal counts, has constant
        absolute deviations from its mean or median and carries no
        information for the ANOVA on deviations.

        :raises DegenerateCaseError: when no sample qualifies (including when
            ``samples`` is empty).
        """
        for sample in samples:
            _, value_counts = np.unique(sample, return_counts=True)
            nb_distinct = len(value_counts)
            if nb_distinct > 2:
                return
            if nb_distinct == 2 and value_counts[0] != value_counts[1]:
                return

        raise DegenerateCaseError(
            ("At least one sample must have at least three different values of {}, "
             "or exactly two different values but not in the same proportion").format(self.column)
        )

    def _compute_centers(self, samples: List[FloatVector]) -> List[float]:
        """Return the center of each sample according to ``self.center``.

        :raises InvalidParams: when the trimmed mean is selected and
            ``proportion_to_trim`` is missing, NaN or outside ``[0, 0.5)``.
        :raises UnknownObjectType: when ``self.center`` is not a known center.
        """
        if self.center == LeveneCenterEnum.MEDIAN:
            return [np.median(s) for s in samples]

        if self.center == LeveneCenterEnum.MEAN:
            return [np.mean(s) for s in samples]

        if self.center == LeveneCenterEnum.TRIMMED_MEAN:
            proportion = self.proportion_to_trim
            # The range is tested positively on purpose: every comparison with
            # NaN is False, so "p < 0 or p >= 0.5" would let NaN through to
            # scipy, whereas "not (0 <= p < 0.5)" rejects it here.
            if proportion is None or not (0 <= proportion < 0.5):
                raise InvalidParams("Proportion to trim must be set and within [0, 0.5) when the selected center is the trimmed mean")
            return [sps.trim_mean(s, proportiontocut=proportion) for s in samples]

        raise UnknownObjectType("Center must be one of {}".format(", ".join(va.value for va in LeveneCenterEnum)))

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> LeveneTestNSampResultModel:
        """Run the test on ``idf`` and return the serializable result.

        The result contains the computation type, the ANOVA statistic and
        p-value computed on the absolute deviations, the serialized groups
        and the center of each sample.

        :raises DegenerateCaseError: when no sample has enough distinct values.
        :raises InvalidParams: when the trimmed-mean proportion is invalid.
        :raises UnknownObjectType: when the requested center is unknown.
        """
        # Both helpers below come from the parent class; the group validation
        # may raise on its own conditions (not visible from this file).
        samples, grouped_idfs, computed_groups = self._compute_groups(idf)
        self._check_disjoint_groups_and_not_degenerate_case(grouped_idfs)
        self._check_degenerate_cases(samples)

        centers = self._compute_centers(samples)

        # Levene transformation: absolute deviation of each value from the
        # center of its own sample, followed by a one-way ANOVA.
        scaled_samples = [np.abs(s - center) for (s, center) in zip(samples, centers)]
        statistic, pvalue = sps.f_oneway(*scaled_samples)

        return {
            "type": self.get_type(),
            "statistic": statistic,
            "pvalue": pvalue,
            "groups": computed_groups.serialize(),
            "centers": centers
        }
