from __future__ import division

from abc import ABCMeta

import pandas as pd
import numpy as np
from six import add_metaclass

from dataiku.doctor.exploration.emu.handlers.base import CFFeatureHandlerMixin
from dataiku.doctor.exploration.emu.handlers.base import OOFeatureHandlerMixin
from dataiku.doctor.exploration.emu.handlers.base import BaseFeatureHandler


@add_metaclass(ABCMeta)
class BaseCategoryFeatureHandler(BaseFeatureHandler):
    """
    Shared logic for categorical feature handlers: `fit` learns the distinct
    categories of the feature together with their relative frequencies.
    """
    def __init__(self, feature_domain):
        """
        :param CategoricalFeatureDomain feature_domain: constraints
        """
        super(BaseCategoryFeatureHandler, self).__init__(feature_domain)
        # Both are set by `fit`: sorted distinct categories and their frequencies
        self.categories = None
        self.ratios = None

    def fit(self, x, y=None):
        """
        Learn the categories present in `x` and how often each one appears.

        When the domain declares its own categories, values rejected by
        `feature_domain.check_validity` are ignored.

        :param np.ndarray x: values of the feature
        :param y: unused here
        :return: self
        """
        kept_values = x
        if self.feature_domain.categories is not None:
            kept_values = x[self.feature_domain.check_validity(x)]
        self.categories, occurrences = np.unique(kept_values, return_counts=True)
        # Relative frequency of each category, in the order of `self.categories`
        self.ratios = occurrences / np.sum(occurrences)
        return self

    def distance(self, x1, x2):
        """
        :return: 0 when both values are equal, 1 otherwise
        :rtype: int
        """
        return int(x1 != x2)


@add_metaclass(ABCMeta)
class BaseCategoryCFFeatureHandler(BaseCategoryFeatureHandler, CFFeatureHandlerMixin):
    """
    Base class for categorical handlers that generate counterfactual values.
    """
    def __init__(self, feature_domain, target=None):
        """
        :param CategoricalFeatureDomain feature_domain: constraints
        :param None or int target: desired class for the counterfactuals
        """
        super(BaseCategoryCFFeatureHandler, self).__init__(feature_domain)
        # Class the generated counterfactuals should aim for (may be None)
        self.target = target


@add_metaclass(ABCMeta)
class BaseCategoryOOFeatureHandler(BaseCategoryFeatureHandler, OOFeatureHandlerMixin):
    """
    Base class combining the categorical handler logic with `OOFeatureHandlerMixin`;
    it adds nothing on top of its two parents.
    """


class CategoryDistributionHandler(BaseCategoryCFFeatureHandler):
    """
    Draws values according to the observed frequencies of the categories of `x`.
    How often `x_ref` is kept in the output is controlled by the provided `radii`.
    """
    def generate_cf_values(self, x_ref, y_ref, radii):
        """
        :param x_ref: categorical value of the reference point
        :param y_ref: prediction for the reference (not used by this handler)
        :param np.ndarray radii: floats between -1. and +1. Each generated value is:
            - likely to equal x_ref when its radius is close to zero
            - likely to differ from x_ref when its radius is close to 1 or -1
        :return: one generated value per radius
        :rtype: np.ndarray
        """
        n_values = radii.shape[0]
        # By default x_ref is never kept (e.g. when it is not a known category)
        keep_x_ref = np.zeros_like(radii, dtype=bool)
        if x_ref in self.categories:
            # x_ref is kept with probability 1 - |radius|
            keep_x_ref = np.abs(radii) < np.random.uniform(size=n_values)
        generated = np.random.choice(self.categories, size=n_values, p=self.ratios)
        # In-place assignment keeps the dtype of the drawn array
        generated[keep_x_ref] = x_ref
        return generated


class CategoryTargetEncodingHandler(BaseCategoryCFFeatureHandler):
    """
    Picks values from `x` following the distribution of the values of `x` for which the model's
    prediction would be `target` (or anything but `y_ref` if `target` is None).
    The proportion of `x_ref` in the drawing depends on the `radii` that were provided.
    """
    def __init__(self, feature_domain, target=None):
        """
        :param CategoricalFeatureDomain feature_domain: constraints
        :param None or int target: desired class for the counterfactuals
        """
        super(CategoryTargetEncodingHandler, self).__init__(feature_domain, target)
        # Sorted distinct values of `y` seen by `fit`
        self.classes = None
        # pd.DataFrame set by `fit`: rows are categories (in the order of
        # `self.categories`), columns are classes
        self.category_distribution_per_class = None
        # pd.DataFrame with the same layout, only set by `fit` when `target` is None
        self.category_distribution_all_classes_but_one = None

    def _get_base_distribution(self, y_ref):
        """
        :param y_ref: prediction (class label) for the reference
        :return: array that contains the weight of each category (in the order of
            `self.categories`) in the distribution of points for which the prediction is:
              - anything but y_ref if `target` is None
              - `target` otherwise
        :rtype: np.ndarray
        """
        if self.target is None:
            return self.category_distribution_all_classes_but_one[y_ref].values
        return self.category_distribution_per_class[self.target].values

    def fit(self, x, y=None):
        """
        Learn, for each class of `y`, the distribution of the categories of `x`.

        :param np.ndarray x: values of the feature
        :param np.ndarray y: prediction associated with each value of `x`
        :return: self
        """
        super(CategoryTargetEncodingHandler, self).fit(x, y)
        self.classes = np.unique(y)

        df = pd.DataFrame({"column": x, "target": y})

        # Only keep relevant categories
        df = df[df["column"].isin(self.categories)]

        categories_counts = df.groupby(["column", "target"]).size().unstack()

        # Add potential missing (class, category) in the count
        categories_counts = (categories_counts.reindex(columns=self.classes)
                             .reindex(index=self.categories)
                             .fillna(0))

        # Count each class over the kept rows only, not over the whole of `y`. Otherwise,
        # rows discarded above would make the per-class distributions sum to less than 1,
        # and `generate_cf_values` relies on them summing to 1.
        class_counts = categories_counts.sum(axis=0)

        # category_distribution_per_class_df[y_0][x_0] is the weight of the
        # category `x_0` in the distribution of `x` when considering only class `y_0`.
        # Example:
        #    x = [A  A  B  B  B  B ]
        #    y = [t1 t1 t1 t2 t2 t2]
        #    category_distribution_per_class_df =     t1  t2
        #                                         A  2/3   0
        #                                         B  1/3   1
        # A class without any kept row gives 0 / 0: its distribution is set to all zeros.
        self.category_distribution_per_class = (categories_counts / class_counts).fillna(0)  # cols: classes, rows: categories

        if self.target is not None:
            return self

        # When the target is None, we need to aggregate the distributions
        # of all classes but y_ref. Since we can't know y_ref in advance, we
        # compute these distributions for each possible y_ref.

        # Example:
        #    x = [A  A  B  B  B  B ]
        #    y = [t1 t1 t1 t2 t2 t2]
        #    sum_category_distribution = A  2/3 + 0
        #                                B  1/3 + 1
        sum_category_distribution = self.category_distribution_per_class.sum(axis=1)

        # For each class `j`, we build the unweighted average of the per-class
        # distributions of every class other than `j`: each remaining class counts
        # equally, whatever its number of samples.
        #    x = [A  A  B  B  B  B ]
        #    y = [t1 t1 t1 t2 t2 t2]
        #    not_normalized =           t1        t2
        #                     A  2/3 - 2/3   2/3 - 0
        #                     B  4/3 - 1/3   4/3 - 1
        not_normalized = (-self.category_distribution_per_class).add(sum_category_distribution, axis=0)
        # With a single class every column is zero; mapping 0 / 0 to 0 (instead of NaN)
        # makes `generate_cf_values` fall back to a uniform draw instead of crashing.
        self.category_distribution_all_classes_but_one = (not_normalized / not_normalized.sum(axis=0)).fillna(0)
        return self

    def generate_cf_values(self, x_ref, y_ref, radii):
        """
        :param x_ref: categorical value from the reference point
        :param y_ref: prediction for the reference
        :param np.ndarray radii: floats between -1. and +1. Generated values will generally be:
            - close to x_ref if the radii are close to zero
            - different from x_ref if the radii are close to 1 or -1
        :return: one value for each radius that was given in input
        :rtype: np.ndarray
        """
        categories = self.categories
        batch_size = radii.shape[0]

        # Example:
        #    abs_radii = [.2, .4]
        abs_radii = np.abs(radii)

        if len(categories) == 1:
            return np.tile(categories[0], (batch_size,))

        # The goal is to generate `batch_size` points, each following slightly different distributions depending on
        # abs_radii. Because calling np.random.choice `batch_size` times is costly, we use a different approach.

        # We want to apply the following method:
        # - compute the proba of keeping x_ref
        # - see if we keep x_ref or not
        # - if we keep it, end
        # - if we do not, then draw among the remaining categories

        # First, we draw values following the distribution we would have had if keeping x_ref wasn't allowed.
        base_distribution = self._get_base_distribution(y_ref)

        without_xref_mask = categories != x_ref
        distribution_without_xref = base_distribution[without_xref_mask]
        sum_distribution_without_xref = distribution_without_xref.sum()
        if np.isclose(sum_distribution_without_xref, 0):  # array is full of zeros, draw uniformly
            distribution_without_xref = np.ones_like(distribution_without_xref) / len(distribution_without_xref)
        else:
            distribution_without_xref = distribution_without_xref / sum_distribution_without_xref
        categories_without_xref = categories[without_xref_mask]
        draw_without_xref = np.random.choice(categories_without_xref, size=batch_size, p=distribution_without_xref)

        # Now, we compute the proba of keeping x_ref, which depends on both the radii and the base distribution.
        if x_ref not in categories:
            return draw_without_xref

        # The formula to find p(x_ref) is a little bit complex, so to illustrate it:
        # Say we have 4 categories a, b, c, d. The ref value is b. So we have:
        #   base_distribution = bd
        #   assert(sum(base_distribution) == 1)
        # The first step is to compute p from the base distribution bd:
        #   p(b) = bd(b) * (1 - radius)
        # But this is not correct because we need to normalize. The whole process is therefore:
        #   p(b) = bd(b) * (1 - radius) / [bd(b) * (1 - radius) + (bd(a) + bd(c) + bd(d)) * radius]
        # The normalisation term is cumbersome. To make the computation easier, we decompose its
        # first part using:
        #   bd(b) * (1 - radius) = bd(b) * (1 - radius - radius) + bd(b) * radius
        # We obtain:
        #   p(b) = bd(b) * (1 - radius) / [bd(b) * (1 - radius) + (bd(a) + bd(c) + bd(d)) * radius]
        #   p(b) = bd(b) * (1 - radius) / [bd(b) * (1 - radius - radius) + bd(b) * radius + (bd(a) + bd(c) + bd(d)) * radius]
        #   p(b) = bd(b) * (1 - radius) / [bd(b) * (1 - 2 * radius) + (bd(a) + bd(b) + bd(c) + bd(d)) * radius]
        # And since the sum of all base distributions is 1 (guaranteed by `fit`, or zero everywhere):
        #   p(b) = bd(b) * (1 - radius) / [bd(b) * (1 - 2 * radius) + radius]
        bd_x_ref = base_distribution[categories == x_ref]
        # The only way to reach a zero denominator is 0 / 0 (e.g. bd(b) == 0 and radius == 0).
        # The resulting NaN makes the comparison below False, i.e. x_ref is not kept,
        # so the floating-point warning is expected and silenced.
        with np.errstate(divide="ignore", invalid="ignore"):
            proba_keeping_x_ref = ((1 - abs_radii) * bd_x_ref) / ((1 - 2 * abs_radii) * bd_x_ref + abs_radii)

        # Finally, we replace some values we drew by x_ref to get the final drawing.
        # np.random.random draws in [0, 1), so a strict comparison keeps x_ref with
        # probability exactly `proba_keeping_x_ref` (never when it is 0).
        return np.where(np.random.random(size=batch_size) < proba_keeping_x_ref, x_ref, draw_without_xref)


class CategoryDistributionOOHandler(BaseCategoryOOFeatureHandler):
    """
    Perturbs samples locally by resampling categories according to their observed
    frequencies, and explores globally with a uniform draw over all categories.
    """
    def generate_from_replicates(self, x, std_factor, n_replicates):
        """
        Build `n_replicates` variations of every sample of `x`. A lower `std_factor`
        makes the variations stay closer to the original values.

        :param np.ndarray x: initial values around which we will generate the new values
        :param float std_factor: upper bound of the probability of replacing a value of `x`
            by a freshly sampled category
        :param int n_replicates: nb. of new variations of each value of `x`
        :return: `(x.size, n_replicates)` new values
        :rtype: np.ndarray
        """
        n_samples = x.shape[0]
        # Per-cell radius: the probability of NOT keeping the original value
        replace_radii = np.random.uniform(0., std_factor, (n_samples, n_replicates))

        # Candidate replacements, drawn following the categories' frequencies
        candidates = np.random.choice(self.categories, size=replace_radii.shape, p=self.ratios)

        # One column per replicate, each holding a copy of `x`
        originals = np.tile(x.reshape((n_samples, 1)), n_replicates)
        keep_original = np.abs(replace_radii) < np.random.uniform(size=replace_radii.shape)
        # Values that are not known categories are always replaced
        keep_original &= np.isin(originals, self.categories)
        return np.where(keep_original, originals, candidates)

    def generate_global_uniform_values(self, n):
        """
        :param int n: size of the global uniform drawing
        :return: `(n,)` new values
        :rtype: np.ndarray
        """
        # No `p` argument: every known category is equally likely
        return np.random.choice(self.categories, size=n)
