from __future__ import division

from abc import ABCMeta

import numpy as np
from scipy import interpolate
from scipy.stats import truncnorm
from six import add_metaclass
from sklearn.base import TransformerMixin
from sklearn.base import BaseEstimator

from dataiku.doctor.exploration.emu.handlers.base import DistanceName
from dataiku.doctor.exploration.emu.handlers.base import BaseFeatureHandler
from dataiku.doctor.exploration.emu.handlers.base import OOFeatureHandlerMixin
from dataiku.doctor.exploration.emu.handlers.base import CFFeatureHandlerMixin


@add_metaclass(ABCMeta)
class BaseNumericalFeatureHandler(BaseFeatureHandler):
    """
    Common logic for numerical feature handlers: learns the bounds and the standard deviation
    of the feature, rounds generated values according to the feature domain, and computes
    per-feature distances.
    """
    def __init__(self, feature_domain, distance_name=DistanceName.EUCLIDEAN):
        """
        :param NumericalFeatureDomain feature_domain: constraints
        :param DistanceName distance_name: Euclidean is only for numeric data
        """
        super(BaseNumericalFeatureHandler, self).__init__(feature_domain)
        self.distance_name = distance_name
        self.std_dev = None  # set by `fit`: std of the in-bounds values (0 if there are none)
        self.bounds = None  # constraints or dataset's limits

    def fit(self, x, y=None):
        """
        Learn the bounds (the domain constraints, falling back on the min / max of `x` when a
        constraint is None) and the standard deviation of the values of `x` within those bounds.

        :param x: values of the feature (presumably a numpy array or pandas Series — TODO confirm)
        :param y: ignored
        :return: self
        """
        min_value = x.min() if self.feature_domain.min_value is None else self.feature_domain.min_value
        max_value = x.max() if self.feature_domain.max_value is None else self.feature_domain.max_value
        self.bounds = (min_value, max_value)
        filtered_x = self._filter_x(x)
        self.std_dev = filtered_x.std() if len(filtered_x) > 0 else 0
        return self

    def _filter_x(self, x):
        """
        :return: the values of `x` that lie within `self.bounds` (NaNs fail both comparisons
            and are therefore dropped)
        """
        # Comparisons involving NaNs can emit spurious "invalid value" RuntimeWarnings: the
        # NaNs are dropped on purpose here, so the warnings carry no information.
        with np.errstate(invalid='ignore'):
            is_within_bounds = (x >= self.bounds[0]) & (x <= self.bounds[1])
        return x[is_within_bounds]

    def _round(self, arr):
        """
        Round generated values according to the feature domain.

        :param arr: generated values
        :return: `arr` rounded to `feature_domain.n_digits` decimals ('auto' delegates to
            `find_n_digits_by_heuristic`; None skips the decimal rounding) and, for integer
            features, rounded to the nearest integer (half to even) and cast to int
        """
        n_digits_ = self.feature_domain.n_digits
        if n_digits_ == 'auto':
            n_digits_ = self.find_n_digits_by_heuristic()
        if n_digits_ is not None:
            arr = np.round(arr, n_digits_)
        if self.feature_domain.is_integer:
            # Round before casting: a bare `astype(int)` truncates toward zero
            # (2.9 -> 2, -2.7 -> -2), which biases the generated values.
            # `np.round` also turns plain Python floats into numpy scalars exposing `astype`.
            arr = np.round(arr).astype(int)
        return arr

    def distance(self, x1, x2):
        """
        Per-feature distance between `x1` and `x2`.

        :return: |x1 - x2| for EUCLIDEAN; |x1 - x2| divided by the width of `self.bounds` for GOWER
        :raises NotImplementedError: for any other distance name
        """
        if self.distance_name == DistanceName.EUCLIDEAN:
            return np.abs(x1 - x2)
        elif self.distance_name == DistanceName.GOWER:
            # NOTE(review): the width is 0 when both bounds are equal, which yields nan/inf here.
            return np.abs(x1 - x2) / (self.bounds[1] - self.bounds[0])
        else:
            raise NotImplementedError('Supported distances are euclidean or gower.')

    def find_n_digits_by_heuristic(self):
        """
        :return: the number of decimals needed to represent 1% of `std_dev`, or None when
            `std_dev` is not strictly positive (nothing sensible to derive a precision from)
        """
        if self.std_dev > 0:
            # n decimals needed to represent 1% of std_dev
            return int(np.ceil(2 - np.log10(self.std_dev)))
        return None


@add_metaclass(ABCMeta)
class BaseNumericalCFFeatureHandler(BaseNumericalFeatureHandler, CFFeatureHandlerMixin):
    """
    Abstract numerical handler combining the shared numerical logic with `CFFeatureHandlerMixin`.
    Concrete subclasses provide `generate_cf_values`.
    """


@add_metaclass(ABCMeta)
class BaseNumericalOOFeatureHandler(BaseNumericalFeatureHandler, OOFeatureHandlerMixin):
    """
    Abstract numerical handler combining the shared numerical logic with `OOFeatureHandlerMixin`.
    """


class InverseSamplingTransformer(TransformerMixin, BaseEstimator):
    """
    Histogram-based estimate of a 1-D distribution.

    `transform` maps values to their percentile rank (CDF) and `inverse_transform` maps
    percentile ranks back to values (PPF, the inverse of the CDF). Both functions are linear
    interpolations between the histogram bin edges. Feeding random values between 0 and 1 to
    `inverse_transform` yields values that follow the fitted distribution.
    """
    def __init__(self, n_bins=200):
        """
        :param int n_bins: number of histogram bins used to approximate the distribution
        """
        self.n_bins = n_bins
        self.lb = None  # min of the fitted values
        self.ub = None  # max of the fitted values
        self._cumulative_distribution_function = None
        self._probability_point_function = None

    @staticmethod
    def _aggregate_values_for_interpolation(x, y):
        """
        Make the `x` values unique, and aggregate the `y` values accordingly.

        Context: `scipy.interpolate.interp1d` misbehaves when `x != np.unique(x)`.
            The doc explicitly states: "If the values in x are not unique, the resulting
            behavior is undefined."

        :param np.ndarray x: A 1-D array of values representing the independent variable.
        :param np.ndarray y: A 1-D array of values representing the dependent variable.
        :rtype: (np.ndarray, np.ndarray)
        :return: The first array contains the unique values of x in sorted order, and the
            second array contains the averaged values of y corresponding to each unique value of x.
        """
        unique_x, indices = np.unique(x, return_inverse=True)
        averaged_y = np.bincount(indices, weights=y) / np.bincount(indices)
        return unique_x, averaged_y

    def fit(self, x, y=None):
        """
        Fit the CDF and the PPF of `x` from a histogram with `n_bins` bins.

        :param np.ndarray x: 1-D array of values to fit
        :param y: ignored, present for scikit-learn API compatibility
        :return: self, so that `fit_transform` (inherited from TransformerMixin) and chaining work
        """
        # Computed first so that an empty `x` fails here, before any 0 / 0 below.
        self.lb = x.min()
        self.ub = x.max()

        counts, bin_edges = np.histogram(x, bins=self.n_bins)
        # CDF value at each bin edge. Dividing the integer cumulative counts by the total keeps
        # the sequence exactly non-decreasing and makes it end at exactly 1.0, unlike summing
        # `density * bin_width`, which can drift around 1 because of floating point errors.
        cum_values = np.zeros(bin_edges.shape)
        cum_values[1:] = np.cumsum(counts) / counts.sum()

        self._cumulative_distribution_function = interpolate.interp1d(bin_edges, cum_values)

        # Some bins may be empty, so `cum_values` may have duplicates. So, we must aggregate
        # them before fitting the PPF (probability point function).
        # On the other hand, we should not use the aggregated values for the CDF (cumulative
        # distribution function) because it would lower its precision.
        # A more visual explanation can be found in the PR description:
        #   https://github.com/dataiku/dip/pull/21800
        aggregated_cum_values, aggregated_bin_edges = self._aggregate_values_for_interpolation(cum_values, bin_edges)

        self._probability_point_function = interpolate.interp1d(aggregated_cum_values, aggregated_bin_edges)
        return self

    def transform(self, x):
        """
        Transform input array x to a new array by applying the cumulative distribution function (CDF)
        of the underlying probability distribution to each element of x. This function maps each element
        of x to its corresponding percentile rank in the distribution. The resulting array has values
        between 0 and 1, inclusive.

        :type x: np.ndarray
        :return: values between 0 and 1, inclusive
        """
        # Clipping to the fitted range keeps the interpolation inside its domain.
        return np.clip(self._cumulative_distribution_function(np.clip(x, self.lb, self.ub)), 0., 1.)

    def inverse_transform(self, x):
        """
        Transform input array x to a new array by applying the probability point function (PPF, the
        inverse of the CDF) of the underlying probability distribution to each element of x. This
        function maps each element of x from its corresponding percentile rank in the distribution
        to its original value.
        The resulting array has values between the lower bound and upper bound of the distribution.

        It's useful to map random values between 0 and 1 to values that follow a certain distribution.

        :type x: np.ndarray
        :return: values between the lower bound and upper bound of the distribution.
        """
        return np.clip(self._probability_point_function(x), self.lb, self.ub)


class NumericalDistributionHandler(BaseNumericalCFFeatureHandler):
    """
    Generates values around the reference by moving along the empirical distribution of the
    feature: the reference is mapped to its percentile rank, shifted by a radius, then mapped
    back to a value. The distribution is bounded by the constraints.
    """
    def __init__(self, feature_domain, distance_name=DistanceName.EUCLIDEAN):
        """
        :param NumericalFeatureDomain feature_domain: constraints
        :param DistanceName distance_name: EUCLIDEAN or GOWER (they behave identically here)
        """
        super(NumericalDistributionHandler, self).__init__(feature_domain, distance_name)
        # CDF / inverse-CDF estimator of the in-bounds values, fitted in `fit`
        self.ipt = InverseSamplingTransformer()

    def fit(self, x, y=None):
        """
        Learn bounds and std (base class), then the distribution of the values within the bounds.

        :return: self
        """
        super(NumericalDistributionHandler, self).fit(x, y)
        values_in_bounds = self._filter_x(x)
        self.ipt.fit(values_in_bounds)
        return self

    def distance(self, x1, x2):
        """
        Absolute difference between the percentile ranks of `x1` and `x2`.

        :return: values between 0 and 1
        :raises NotImplementedError: if the distance is neither euclidean nor gower
        """
        rank_1 = self.ipt.transform(x1)
        rank_2 = self.ipt.transform(x2)
        if self.distance_name in {DistanceName.EUCLIDEAN, DistanceName.GOWER}:
            # Ranks are already normalized to [0, 1], so GOWER and EUCLIDEAN coincide.
            return np.abs(rank_1 - rank_2)
        raise NotImplementedError('Supported distances are euclidean or gower.')

    def generate_cf_values(self, x_ref, y_ref, radii):
        """
        Shift the percentile rank of `x_ref` by each radius (clipped to [0, 1]) and map the
        shifted ranks back to feature values.

        :param int or float x_ref: value from the reference point
        :param y_ref: prediction for the reference (unused)
        :param np.ndarray radii: floats between -1. and +1. Generated values will generally be:
            - close to x_ref if the radii are close to zero
            - different from x_ref if the radii are close to 1 or -1
        :return: one rounded value for each radius that was given in input
        :rtype: np.ndarray
        """
        target_ranks = np.clip(self.ipt.transform(x_ref) + radii, 0, 1)
        return self._round(self.ipt.inverse_transform(target_ranks))


class NumericalUniformHandler(BaseNumericalCFFeatureHandler):
    """
    Generates values uniformly between the minimum bound and the maximum bound.
    """
    def generate_cf_values(self, x_ref, y_ref, radii):
        """
        Map the magnitude of each radius linearly onto the bounds: |radius| = 0 gives the
        minimum bound and |radius| = 1 gives the maximum bound. The reference is not used, so
        uniformly drawn radii yield values spread uniformly over the bounds.

        :param int or float x_ref: value from the reference point (unused)
        :param y_ref: prediction for the reference (unused)
        :param np.ndarray radii: floats between -1. and +1.
        :return: one rounded value for each radius that was given in input
        :rtype: np.ndarray
        """
        low, high = self.bounds
        span = high - low
        return self._round(low + span * np.abs(radii))


class NumericalNormalOOHandler(BaseNumericalOOFeatureHandler):
    """
    Perturbs samples locally with truncated gaussians and adds exploration.
    """
    def generate_from_replicates(self, x, std_factor, n_replicates):
        """
        Generate new samples that resemble the `x` parameter. `n_replicates` samples will be
        generated for each sample in `x`, and the lower `std_factor` will be, the more similar
        the samples will be to `x`. Generated values always stay within `self.bounds` (before
        rounding).

        :param np.ndarray x: initial values around which we will generate the new values
        :param float std_factor: to control how far should be the generated standard values from `x`
        :param int n_replicates: nb. of new variations of `x` using the standard drawing method
        :return: `(x.size, n_replicates)` new values
        :rtype: np.ndarray
        """
        x = np.clip(x, self.bounds[0], self.bounds[1])  # If `x` is out of bounds, truncnorm.rvs can return inf. values
        if self.bounds[0] == self.bounds[1]:
            # Single allowed value: every replicate is that value.
            values = np.tile(self.bounds[0], (n_replicates, x.shape[0]))
        else:
            # Truncation limits expressed in units of `std_dev`, so that `x + noise` lies within the bounds.
            left_bound = (self.bounds[0] - x).astype(float)
            right_bound = (self.bounds[1] - x).astype(float)
            if self.std_dev > 0:
                left_bound /= self.std_dev
                right_bound /= self.std_dev
            replicates = np.tile(x, (n_replicates, 1))
            noise = truncnorm.rvs(left_bound, right_bound, scale=self.std_dev, size=replicates.shape)
            # The truncation only guarantees the bounds for `replicates + noise`; the noise is then
            # scaled by `std_factor`, so a factor > 1 (or negative) could push values out of bounds.
            values = np.clip(replicates + std_factor * noise, self.bounds[0], self.bounds[1])
        return self._round(values.T)

    def generate_global_uniform_values(self, n):
        """
        Draw values uniformly between the bounds (global exploration), then round them.

        :param int n: size of the global uniform drawing
        :return: `(n,)` new values
        :rtype: np.ndarray
        """
        return self._round(np.random.uniform(self.bounds[0], self.bounds[1], size=n))
