from typing import Optional
from dataiku.eda.types import Literal

from statsmodels.stats.weightstats import DescrStatsW

from dataiku.doctor.utils import dku_nonaninf
from dataiku.eda.computations.computation import UnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.exceptions import NoDataError
from dataiku.eda.types import MeanModel, MeanResultModel


class Mean(UnivariateComputation):
    """Univariate computation returning the mean of a single column.

    When a confidence level is configured, the result also carries the bounds
    of a Student-t confidence interval for the mean, as computed by statsmodels'
    ``DescrStatsW.tconfint_mean``.
    """

    def __init__(self, column: str, confidence: Optional[float]):
        """Create the computation.

        :param column: name of the column whose mean is computed.
        :param confidence: confidence level of the interval (the significance
            passed to statsmodels is ``1 - confidence``), or ``None`` to skip
            the interval entirely.
        """
        super().__init__(column)
        self.confidence = confidence

    @staticmethod
    def get_type() -> Literal["mean"]:
        """Return the discriminator string identifying this computation."""
        return "mean"

    @staticmethod
    def build(params: MeanModel) -> 'Mean':
        """Instantiate from its serialized parameters.

        ``column`` is required; ``confidence`` is optional and defaults to None.
        """
        column_name = params['column']
        confidence_level = params.get('confidence')
        return Mean(column_name, confidence_level)

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> MeanResultModel:
        """Compute the mean of the column, plus optional confidence bounds.

        :param idf: data frame to read the column from.
        :param ctx: computation context (unused here).
        :return: a dict with ``type`` and ``value``; when ``confidence`` is set,
            also ``lower`` and ``upper`` (both ``None`` if fewer than two values).
        :raises NoDataError: if the column yields no values.
        """
        # float_col_no_missing presumably drops missing entries — verify in ImmutableDataFrame
        values = idf.float_col_no_missing(self.column)
        sample_size = len(values)
        if not sample_size:
            raise NoDataError()

        # NOTE(review): unlike the bounds below, the mean itself is not passed
        # through dku_nonaninf — confirm that inf/NaN means cannot occur here.
        result: MeanResultModel = {"type": self.get_type(), "value": values.mean()}
        if self.confidence is None:
            return result

        # A t-interval needs at least two observations to estimate the spread.
        bounds = (None, None)
        if sample_size >= 2:
            bounds = DescrStatsW(values).tconfint_mean(alpha=1 - self.confidence)
        low, high = bounds

        # dku_nonaninf presumably maps NaN/inf to a JSON-safe value — confirm
        result["lower"] = dku_nonaninf(low)
        result["upper"] = dku_nonaninf(high)
        return result
