from typing import List, Optional
from dataiku.eda.types import Literal

import numpy as np
import scipy.stats as sps

from dataiku.eda.computations.computation import UnivariateComputation
from dataiku.eda.computations.context import Context
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame
from dataiku.eda.exceptions import NoDataError
from dataiku.eda.types import QuantilesModel, QuantilesResultModel


class Quantiles(UnivariateComputation):
    """Quantiles of a single column, optionally with a confidence interval per quantile.

    ``freqs`` holds the probabilities passed to ``np.quantile``. When
    ``confidence`` is not None, lower/upper bounds are derived from the
    standard errors returned by ``scipy.stats.mstats.mjci``.
    """

    def __init__(self, column: str, confidence: Optional[float], freqs: List[float]):
        super().__init__(column)
        self.confidence = confidence
        self.freqs = freqs

    @staticmethod
    def get_type() -> Literal["quantiles"]:
        """Return the type tag placed in the ``"type"`` field of the result."""
        return "quantiles"

    def describe(self) -> str:
        """Return ``ClassName(column, f1, f2, ...)``; the confidence is not part of it."""
        joined_freqs = ', '.join(map(str, self.freqs))
        return "{!s}({!s}, {!s})".format(self.__class__.__name__, self.column, joined_freqs)

    @staticmethod
    def build(params: QuantilesModel) -> 'Quantiles':
        """Create a ``Quantiles`` from its parameter mapping.

        ``column`` and ``freqs`` are mandatory keys; ``confidence`` is optional
        and falls back to None when absent.
        """
        column = params['column']
        confidence = params.get('confidence')
        freqs = params['freqs']
        return Quantiles(column, confidence, freqs)

    def _confidence_bounds(self, series, quantiles):
        """Return ``(lower_bounds, upper_bounds)`` around ``quantiles``, or None without a confidence."""
        if self.confidence is None:
            return None

        # Same formula as scipy.stats.mstats.mquantiles_cimj(): the smaller of
        # `confidence` and `1 - confidence` is used as the two-sided tail mass,
        # so e.g. 0.95 and 0.05 lead to the same interval.
        tail_mass = min(self.confidence, 1 - self.confidence)
        critical_value = sps.norm.ppf(1 - tail_mass / 2.)
        margin = critical_value * sps.mstats.mjci(series, self.freqs)
        return quantiles - margin, quantiles + margin

    def apply(self, idf: ImmutableDataFrame, ctx: Context) -> QuantilesResultModel:
        """Compute the requested quantiles (and bounds, if any) of the column.

        Returns a dict with the computation type and one entry per requested
        frequency, each holding ``freq``, ``quantile``, ``lower`` and ``upper``.
        ``lower``/``upper`` stay None when no confidence was requested or when
        either bound is not finite.

        Raises:
            NoDataError: if the column yields no values.
        """
        # presumably the column's values as floats with missing entries dropped -- confirm in ImmutableDataFrame
        values = idf.float_col_no_missing(self.column)
        if not len(values):
            raise NoDataError()

        # np.quantile() is preferred here: much faster than scipy.stats.mquantiles()
        estimates = np.quantile(values, self.freqs)
        bounds = self._confidence_bounds(values, estimates)

        entries = []
        for position, (freq, estimate) in enumerate(zip(self.freqs, estimates)):
            low = high = None
            if bounds is not None:
                candidate_low = bounds[0][position]
                candidate_high = bounds[1][position]
                # Bounds are reported as a pair: both must be finite, otherwise neither is set
                if np.isfinite(candidate_low) and np.isfinite(candidate_high):
                    low, high = candidate_low, candidate_high
            entries.append({"freq": freq, "quantile": estimate, "lower": low, "upper": high})

        return {"type": Quantiles.get_type(), "quantiles": entries}
