from typing import Iterator, List, Optional
from dataiku.eda.types import Literal, Final

import numpy as np
import pandas as pd

from dataiku.core.binning_utils import nice_bin_edges
from dataiku.eda.computations.immutable_data_frame import ImmutableDataFrame, FloatVector
from dataiku.eda.exceptions import InvalidParams
from dataiku.eda.grouping.grouping import Grouping
from dataiku.eda.grouping.grouping import GroupingResult
from dataiku.eda.types import BinnedGroupingModel, BinnedGroupingResultModel, BinningMode


class BinnedGrouping(Grouping):
    """Grouping which splits the rows of a dataframe into bins of a numerical column.

    The way bin edges are obtained depends on ``mode``:

    - ``'AUTO'``: the nb. of bins is estimated from the data (see ``compute_auto_bin_edges``)
    - ``'FIXED_NB'``: the nb. of bins is given by ``nb_bins`` (see ``compute_fixed_bin_edges``)
    - ``'CUSTOM'``: the inner edges are given by ``custom_bounds`` (see ``compute_custom_bin_edges``)
    """

    # With automatic binning (mode==AUTO), the nb. of bins is determined using a heuristic.
    # It is important to remember that the estimated ideal nb. of bins is unbounded.
    # - The nb. of bins can be limited with parameter 'nbBins' (but this parameter is optional)
    # - Constant AUTO_NB_BINS_MAX is an additional (hardcoded) limit which keeps the heuristic
    #   from producing an excessive nb. of bins
    AUTO_NB_BINS_MAX: Final[int] = 100

    def __init__(self, column: str, mode: BinningMode, nb_bins: Optional[int], custom_bounds: Optional[List[float]], keep_na: bool):
        """
        :param column: name of the column whose values are binned
        :param mode: binning mode ('AUTO', 'FIXED_NB' or 'CUSTOM')
        :param nb_bins: nb. of bins in FIXED_NB mode, optional upper limit in AUTO mode
        :param custom_bounds: edges provided by the user, only read in CUSTOM mode
        :param keep_na: whether rows with a NaN value are returned as an additional group
        """
        self.column = column
        self.mode = mode
        self.nb_bins = nb_bins
        self.custom_bounds = custom_bounds
        self.keep_na = keep_na

    @staticmethod
    def get_type() -> Literal["binned"]:
        """Return the type identifier of this grouping."""
        return "binned"

    def describe(self) -> str:
        """Return a short description of the grouping which mentions the binned column."""
        return "Binned(%s)" % self.column

    @staticmethod
    def build(params: BinnedGroupingModel) -> 'BinnedGrouping':
        """Instantiate the grouping from its serialized parameters.

        'column' and 'mode' are mandatory. 'nbBins', 'customBounds' and 'keepNA' are
        read with ``get()`` and are therefore None when absent.
        """
        return BinnedGrouping(
            params['column'],
            params['mode'],
            params.get('nbBins'),
            params.get('customBounds'),
            params.get('keepNA')
        )

    def compute_auto_bin_edges(self, series: FloatVector) -> FloatVector:
        """Compute bin edges for an automatically estimated nb. of bins.

        The estimate is capped by ``AUTO_NB_BINS_MAX`` and, when it is set, by ``nb_bins``.

        :param series: values to bin (``compute_groups`` only passes finite values)
        :return: edges produced by ``nice_bin_edges`` for the retained nb. of bins
        """
        # Do not use 'auto' or 'fd' mode (https://github.com/numpy/numpy/issues/11879)
        estimated_edges = np.histogram_bin_edges(series, bins='doane')
        # N bins are delimited by N + 1 edges: counting the edges themselves would
        # overestimate the nb. of bins by one
        estimated_nb_bins = len(estimated_edges) - 1
        nb_bins = min(estimated_nb_bins, BinnedGrouping.AUTO_NB_BINS_MAX)
        if self.nb_bins is not None:
            # NOTE(review): unlike FIXED_NB mode, a limit lower than 1 is not rejected here - confirm
            nb_bins = min(self.nb_bins, nb_bins)

        return nice_bin_edges(series, nb_bins)

    def compute_fixed_bin_edges(self, series: FloatVector) -> FloatVector:
        """Compute bin edges for the nb. of bins requested through ``nb_bins``.

        :param series: values to bin (``compute_groups`` only passes finite values)
        :return: edges produced by ``nice_bin_edges``
        :raises InvalidParams: if ``nb_bins`` is missing or lower than 1
        """
        nb_bins = self.nb_bins
        if nb_bins is None or nb_bins < 1:
            raise InvalidParams("Expected the number of bins to be greater than or equal to 1")

        return nice_bin_edges(series, nb_bins)

    def compute_custom_bin_edges(self) -> FloatVector:
        """Compute bin edges from the bounds provided by the user.

        The bounds are sorted, then surrounded by two extra edges so that the first
        and the last bins are open-ended in practice.

        :return: sorted edges, including the two extra outer edges
        :raises InvalidParams: if ``custom_bounds`` is missing
        """
        if self.custom_bounds is None:
            raise InvalidParams("Expected custom bounds, found none")

        # Largest finite float64, used as a stand-in for infinity
        # NOTE(review): presumably preferred to np.inf so that serialized edges remain finite - confirm
        float64_max = np.finfo(np.float64).max
        bin_edges = np.sort(self.custom_bounds)
        return np.concatenate(([-float64_max], bin_edges, [float64_max]))

    def compute_groups(self, idf: ImmutableDataFrame) -> 'BinnedGroupingResult':
        """Split ``idf`` into one dataframe per bin.

        - Only rows holding a finite value are binned
        - A value v is assigned to bin i when edges[i] <= v < edges[i + 1] (assuming
          increasing edges); values which do not fall into any bin are left out
        - Bins which do not receive any row are represented by an empty dataframe
        - When ``keep_na`` is enabled and at least one value is NaN, the corresponding
          rows are returned as an additional group. Infinite values are neither binned
          nor part of this additional group.

        :param idf: dataframe to split
        :return: the bin edges along with the dataframe of each bin
        :raises NotImplementedError: if the binning mode is not supported
        """
        series = idf.float_col(self.column)
        idf_no_missing = idf[np.isfinite(series)]
        series_no_missing = idf_no_missing.float_col(self.column)

        if self.mode == 'AUTO':
            bin_edges = self.compute_auto_bin_edges(series_no_missing)
        elif self.mode == 'FIXED_NB':
            bin_edges = self.compute_fixed_bin_edges(series_no_missing)
        elif self.mode == 'CUSTOM':
            bin_edges = self.compute_custom_bin_edges()
        else:
            raise NotImplementedError("Not implemented binning mode: %s" % self.mode)

        # np.digitize() numbers bins from 1: shift the result to get 0-based bin indices
        bin_map = np.digitize(series_no_missing, bin_edges) - 1
        # Mapping: bin index -> positions of the rows belonging to this bin
        indices = pd.Series(bin_map, copy=False).groupby(bin_map).indices

        # Always produce one dataframe per bin, including for bins without any row
        empty_idf = idf_no_missing[[]]
        idfs = [
            idf_no_missing[indices[i]] if i in indices else empty_idf
            for i in range(len(bin_edges) - 1)
        ]

        idf_missing = None
        if self.keep_na:
            nan_mask = np.isnan(series)
            if np.any(nan_mask):
                idf_missing = idf[nan_mask]

        return BinnedGroupingResult(self.column, bin_edges, idfs, idf_missing)


class BinnedGroupingResult(GroupingResult):
    """Result of a binned grouping: one dataframe per bin, plus an optional one for missing values."""

    def __init__(self, column: str, bin_edges: FloatVector, idfs: List[ImmutableDataFrame], idf_missing: Optional[ImmutableDataFrame]):
        """
        :param column: name of the binned column
        :param bin_edges: edges delimiting the bins
        :param idfs: dataframes of the bins
        :param idf_missing: dataframe of the rows with a missing value, or None if there is no such group
        """
        self.column, self.bin_edges = column, bin_edges
        self.idfs, self.idf_missing = idfs, idf_missing

    def serialize(self) -> BinnedGroupingResultModel:
        """Return the serializable description of the result (type, edges, column & presence of a missing-values group)."""
        has_missing_group = self.idf_missing is not None
        return {
            "type": BinnedGrouping.get_type(),
            "edges": list(self.bin_edges),
            "column": self.column,
            "hasNA": has_missing_group,
        }

    def iter_groups(self) -> Iterator[ImmutableDataFrame]:
        """Iterate over the dataframes of the bins, followed by the missing-values dataframe if there is one."""
        yield from self.idfs
        if self.idf_missing is None:
            return
        yield self.idf_missing
