from __future__ import division

from abc import ABCMeta
from abc import abstractmethod
from warnings import catch_warnings
from warnings import simplefilter

import numpy as np
import pandas as pd
from six import add_metaclass
from sklearn.cluster import KMeans

from dataiku.doctor.exploration.emu.algorithms.utils import init_df_from

# Lineage tracking: when a sample is generated by perturbing another sample, both samples share
# the same genus. Samples generated without any parent (e.g. the global uniform draws) have no
# genus yet; if they survive a selection step, each of them founds a brand-new genus.
GENUS_WHEN_NO_PARENT = -1  # designates the samples that don't have a genus (yet)


def _merge_populations(*args):
    """
    :param args: pairs of concatenable pd.Dataframe, and pairs of 1D np.ndarray
    :return: concatenated values, list containing dataframes and numpy arrays
    """
    for (a, b) in args:
        if isinstance(a, pd.DataFrame) and isinstance(b, pd.DataFrame):
            yield pd.concat((a, b)).reset_index(drop=True)
        elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
            yield np.concatenate((a, b))
        else:
            raise TypeError("tuple must contain only 'pd.DataFrame' or only 'np.ndarray'")


class SampleGenerator(object):
    """
    Produce new candidate samples, feature by feature, through the feature handlers.

    `generate_from` returns local perturbations of a population followed by samples drawn
    uniformly over the whole feature space.
    """
    def __init__(self, handlers, std_factor, n_replicates, n_global_uniform_values):
        """
        :param handlers: per-feature objects used to sample values
        :type handlers: dict[str, (BaseNumericalOOFeatureHandler | BaseCategoryOOFeatureHandler | FrozenFeatureHandler)]
        :param float std_factor: how far perturbed values may move away from the original ones,
            from 0 (no perturbation) to 1 (global exploration)
        :param int n_replicates: number of perturbed copies generated for each input sample
        :param int n_global_uniform_values: number of samples drawn by global uniform sampling
        """
        self.n_global_uniform_values = n_global_uniform_values
        self.n_replicates = n_replicates
        self.std_factor = std_factor
        self.handlers = handlers

    def _generate_from_replicates(self, pop):
        """
        Build `n_replicates` local perturbations of every sample of `pop`.

        Expected row order, for n_replicates = 2 and pop = [a, b, c]:
            => [a', b', c', a*, b*, c*]

        :param pd.DataFrame pop: samples to perturb
        :return: perturbed samples, len(pop) * n_replicates rows
        :rtype: pd.DataFrame
        """
        n_generated = pop.shape[0] * self.n_replicates
        generated = init_df_from(pop)
        for name, feature_handler in self.handlers.items():
            replicated = feature_handler.generate_from_replicates(pop[name].values,
                                                                  self.std_factor,
                                                                  self.n_replicates)
            # Flattened in row-major order; the documented row order relies on the handler's
            # output layout (presumably (n_replicates, len(pop))) -- TODO confirm.
            generated[name] = replicated.reshape((n_generated,))
        return generated

    def _generate_global_uniform_values(self, pop):
        """
        Draw `n_global_uniform_values` samples over the whole feature space.

        :param pd.DataFrame pop: only used as a template for the output frame
        :return: uniformly drawn samples
        :rtype: pd.DataFrame
        """
        generated = init_df_from(pop)
        for name, feature_handler in self.handlers.items():
            generated[name] = feature_handler.generate_global_uniform_values(self.n_global_uniform_values)
        return generated

    def generate_from(self, pop):
        """
        Perturb every sample of `pop` locally, then append global uniform draws.

        Row order of the result, for pop = [a, b, c], n_replicates = 2 and
        n_global_uniform_values = 3:
            => [a', b', c', a*, b*, c*, g1, g2, g3]

        :param pd.DataFrame pop: samples around which new points are generated
        :return: new population, indexed from 0
        :rtype: pd.DataFrame
        """
        local_samples = self._generate_from_replicates(pop)
        uniform_samples = self._generate_global_uniform_values(pop)
        return next(_merge_populations((local_samples, uniform_samples)))


@add_metaclass(ABCMeta)
class EvolutionaryOutcomeOptimizer(object):
    """
    Base class of the evolution strategies that search for samples minimizing a black-box loss.

    Subclasses implement `_evolve`. `find_optima` runs it, then keeps only the samples that are
    valid and both low-loss and diverse.
    """
    def __init__(self, handlers, loss, preprocess, std_factor=0.5, n_replicates=10, n_global_uniform_values=50):
        """
        Evolution strategy to optimize a black-box function

        :param handlers: to sample values for individual features
        :type handlers: dict[str, (BaseNumericalOOFeatureHandler | BaseCategoryOOFeatureHandler | FrozenFeatureHandler)]
        :param function loss: function to minimize. e.g. the predict function of a model
        :param function preprocess: function that preprocesses the data to transform them into numerical numpy arrays.
        :param float std_factor: perturbation strength forwarded to the SampleGenerator
        :param int n_replicates: perturbed copies generated per sample at each generation
        :param int n_global_uniform_values: uniform draws added at each generation
        """
        self.loss = loss
        self.preprocess = preprocess
        self.handlers = handlers
        self.n_replicates = n_replicates
        self.n_global_uniform_values = n_global_uniform_values
        self.sample_generator = SampleGenerator(handlers, std_factor, n_replicates, n_global_uniform_values)

    def _drop_invalid_samples(self, samples):
        """
        Keep only the samples that every handler's feature domain accepts.

        :param pd.DataFrame samples: candidates to filter
        :return: the rows of `samples` considered valid by all feature domains
        :rtype: pd.DataFrame
        """
        keep = np.ones(samples.shape[0], dtype=bool)
        for name, handler in self.handlers.items():
            keep &= handler.feature_domain.check_validity(samples[name].values)
        return samples.iloc[keep]

    def _select_diverse_records(self, population, pop_size):
        """
        Reduce `population` to at most `pop_size` low-loss samples spread over the input space.

        The preprocessed samples are clustered with KMeans into `pop_size` clusters, and the
        lowest-loss sample of each cluster is kept.

        :param pd.DataFrame population: candidates; its index must be positional (0..n-1)
            because it is used to index the array returned by `self.loss`
        :param int pop_size: number of clusters, i.e. upper bound of the output size
        :return: `population` unchanged if it is small enough, else one sample per cluster
        :rtype: pd.DataFrame
        """
        if population.shape[0] <= pop_size:
            return population
        features = self.preprocess(population)
        with catch_warnings():
            simplefilter("ignore")  # duplicate clusters are harmless here
            clustering = KMeans(n_clusters=pop_size, random_state=0).fit(features)
        cluster_ids = clustering.predict(features)
        losses = self.loss(population)

        def lowest_loss_row(cluster_df):
            return cluster_df.iloc[losses[cluster_df.index.values].argmin()]

        representatives = population.groupby(cluster_ids).apply(lowest_loss_row)
        # Rows went through per-row Series (possibly object dtype): restore the column dtypes.
        return representatives.astype(population.dtypes)

    @abstractmethod
    def _evolve(self, n_generations, pop_size, population, loss_values):
        """
        Run the evolution and return the candidate samples to post-process.

        :param int n_generations: number of generations to run
        :param int pop_size: population size kept at each generation
        :param pd.DataFrame population: initial pool of data
        :param np.ndarray loss_values: loss of each row of `population`
        :rtype: pd.DataFrame
        """
        pass

    def find_optima(self, population, n_generations=40, pop_size=50):
        """
        Evolution that alternates between expansion and reduction of the population.

        :param pd.DataFrame population: initial pool of data
        :param int n_generations: number of generations to run the evolution
        :param int pop_size: population size at the end of each generation
        :return: deduplicated population that gives low values for the loss function
        :rtype: pd.DataFrame
        """
        initial_loss = self.loss(population)
        evolved = self._evolve(n_generations, pop_size, population, initial_loss)

        # Samples of the initial pool may violate the constraints and survive the evolution,
        # hence the validity filter.
        valid = self._drop_invalid_samples(evolved).reset_index(drop=True)
        diverse = self._select_diverse_records(valid, pop_size).reset_index(drop=True)
        return diverse.drop_duplicates()


class EfficientEvolutionaryOutcomeOptimizer(EvolutionaryOutcomeOptimizer):
    """
    Elitist evolution: parents and offspring compete together, and only the `pop_size`
    lowest-loss samples survive each generation.
    """
    @staticmethod
    def _select(pop, loss_values, new_pop, new_loss_values, pop_size):
        """
        Merge parents and offspring, then keep the best samples.

        :param pd.DataFrame pop: current population
        :param np.ndarray loss_values: loss function's result for `pop`
        :param pd.DataFrame new_pop: newly generated samples
        :param np.ndarray new_loss_values: loss function's result for `new_pop`
        :param int pop_size: max population size after reduction
        :return: surviving samples sorted by increasing loss, and their losses
        :rtype: (pd.DataFrame, np.ndarray)
        """
        merged_pop, merged_loss = _merge_populations((pop, new_pop), (loss_values, new_loss_values))
        survivors = np.argsort(merged_loss)[:pop_size]
        return merged_pop.reset_index(drop=True).iloc[survivors], merged_loss[survivors]

    def _evolve(self, n_generations, pop_size, population, loss_values):
        """
        Alternate generation and elitist selection `n_generations` times.

        :return: the final population
        :rtype: pd.DataFrame
        """
        current_pop, current_loss = population, loss_values
        for _generation in range(n_generations):
            offspring = self.sample_generator.generate_from(current_pop)
            offspring_loss = self.loss(offspring)
            current_pop, current_loss = self._select(current_pop, current_loss,
                                                     offspring, offspring_loss, pop_size)
        return current_pop


class DiverseEvolutionaryOutcomeOptimizer(EvolutionaryOutcomeOptimizer):
    """
    Evolution strategy that preserves diversity by tracking the lineage ("genus") of each sample.

    A sample generated by perturbing another sample inherits its parent's genus. Samples drawn
    without a parent have GENUS_WHEN_NO_PARENT and found a new genus if they survive a
    selection step. At each generation every genus receives a quota of survivors (see
    `_get_quota_per_genus`). Genera whose quota is 0 go extinct, and their best individual is
    stored in the "pantheon".
    """

    @staticmethod
    def _members_per_genus(genera):
        """
        Group sample positions by genus.

        :param np.ndarray genera: genus id of each sample
        :return: genus id -> ascending positions of the samples of that genus
            (keys inserted in ascending genus order)
        :rtype: dict
        """
        return {genus: np.flatnonzero(genera == genus) for genus in np.unique(genera)}

    @staticmethod
    def _best_positions(members_per_genus, loss_values, genus_ids):
        """
        Locate the lowest-loss individual of each requested genus.

        :param dict members_per_genus: output of `_members_per_genus`
        :param np.ndarray loss_values: loss of each sample, aligned with the positions
        :param genus_ids: genera to look up, in the desired output order; genera without
            any member are skipped
        :return: one position per genus found (the first one in case of ties)
        :rtype: np.ndarray
        """
        best = []
        for genus in genus_ids:
            members = members_per_genus.get(genus)
            if members is not None:
                best.append(members[loss_values[members].argmin()])
        return np.asarray(best, dtype=np.int64)

    @staticmethod
    def _select(pop, genera, loss_values, new_pop, new_genera, new_loss_values, pop_size):
        """
        Selection step that reduces the population to its best elements.

        :param pd.DataFrame pop: current population
        :param np.ndarray genera: for each row in `pop`, id of its genus
        :param np.ndarray loss_values: loss function's result for `pop`
        :param pd.DataFrame new_pop: newly generated samples
        :param np.ndarray new_genera: for each row in `new_pop`, id of its genus
        :param np.ndarray new_loss_values: loss function's result for `new_pop`
        :param int pop_size: max population size after reduction
        :return: - selected population, grouped by ascending genus and sorted by loss within a genus
                 - corresponding genus ids
                 - corresponding loss values
                 - pantheon (i.e. best individual from each genus that just went extinct)
        :rtype: (pd.DataFrame, np.ndarray, np.ndarray, pd.DataFrame)
        """
        cls = DiverseEvolutionaryOutcomeOptimizer

        # Merge old and new populations (the merged frame is indexed by position)
        all_pop, all_genera, all_loss_values = _merge_populations((pop, new_pop),
                                                                  (genera, new_genera),
                                                                  (loss_values, new_loss_values))

        quotas = cls._get_quota_per_genus(loss_values, genera, new_loss_values, new_genera, pop_size)
        members_per_genus = cls._members_per_genus(all_genera)

        # Genera with 0 quota are dead, so we put their best individual in the pantheon to honor its memory
        dead_genera = quotas.index[quotas == 0]
        pantheon = all_pop.iloc[cls._best_positions(members_per_genus, all_loss_values, dead_genera)]
        pantheon = pantheon.reset_index(drop=True)

        # Fill each genus' quota with its best individuals, in ascending genus order.
        chunks = []
        for genus, quota in quotas.sort_index().items():
            members = members_per_genus.get(genus)
            if members is None or quota <= 0:
                # A genus may be listed in `quotas` without any member (e.g. GENUS_WHEN_NO_PARENT
                # when no global uniform value was drawn): nobody can be selected from it.
                continue
            ranked = members[all_loss_values[members].argsort()]
            # When the quota exceeds the genus size, np.resize cycles through `ranked`,
            # so the best individuals are the first to be duplicated.
            chunks.append(np.resize(ranked, quota))
        selected_positions = np.concatenate(chunks) if chunks else np.array([], dtype=np.int64)

        selected_pop = all_pop.iloc[selected_positions].reset_index(drop=True)
        selected_loss_values = all_loss_values[selected_positions]
        # Genera are read from the selected rows themselves, so they always stay aligned with them.
        selected_genera = all_genera[selected_positions]

        # The selected individuals that don't have parents start new genera
        no_parent = selected_genera == GENUS_WHEN_NO_PARENT
        n_individuals_with_no_parent = int(no_parent.sum())
        if n_individuals_with_no_parent > 0:
            selected_genera[no_parent] = (np.arange(n_individuals_with_no_parent)
                                          + selected_genera.max() + 1)

        return selected_pop, selected_genera, selected_loss_values, pantheon

    @staticmethod
    def _get_quota_per_genus(loss_values, genera, new_loss_values, new_genera, pop_size):
        """
        The "quota" of a given genus is the number of individuals from that
        genus that will be allowed to make it to the next generation of the
        population.
        e.g. If the quota of a genus is lower than the size of the genus,
        some individuals from that genus won't be selected and, subsequently,
        will die.

        The sum of the quotas of all genera must be approximately equal to
        `pop_size`.

        To choose the quota for each genus, we apply the following rule:
        - ~10% of the population should be made of individuals that don't have
            a genus yet, to favor exploration (e.g. samples resulting from the
            global uniform drawing).
        - ~70% of the population should be made of individuals that come from
            the most interesting genus. The "most interesting" genus is the
            genus with the minimal loss among the genera that improved
            their loss during the last sample generation. If no genus improved,
            it is the genus with the minimal loss overall, which can be the
            group of parentless samples.
        - ~20% of the population should be made of other genera that improved
            their loss during the last sample generation. If there is none, this
            share goes to the most interesting genus.

        :param np.ndarray loss_values: loss function's result for the current pop
        :param np.ndarray genera: for each row in the current pop, id of its genus
        :param np.ndarray new_loss_values: loss function's result for the newly generated samples
        :param np.ndarray new_genera: for each row in the newly generated samples, id of its genus
        :param int pop_size: max population size after reduction
        :return: Series with: index=genus and values=number_of_individuals_to_keep
        :rtype: pd.Series
        """
        base_quotas = {
            "MOST_INTERESTING_GENUS": .7 * pop_size,
            "GENERA_THAT_IMPROVED": .2 * pop_size,
            "EXPLORATION": .1 * pop_size,
        }

        loss_df = pd.DataFrame({"loss_old_pop": loss_values, "genus": genera})
        new_loss_df = pd.DataFrame({"loss_new_pop": new_loss_values, "genus": new_genera})

        loss_per_genus = loss_df.groupby('genus').min()
        new_loss_per_genus = new_loss_df.groupby('genus').min()

        genus_scores = loss_per_genus.join(new_loss_per_genus, how='outer')
        genus_scores["improvement"] = (genus_scores["loss_old_pop"] - genus_scores["loss_new_pop"]).clip(0)
        genus_scores.at[GENUS_WHEN_NO_PARENT, "improvement"] = 0.
        genus_scores["loss"] = genus_scores[["loss_old_pop", "loss_new_pop"]].min(axis=1)

        # Find most interesting genus
        interesting_genus_scores = genus_scores
        if (genus_scores["improvement"] > 0).any():
            interesting_genus_scores = genus_scores[genus_scores["improvement"] > 0]
        most_interesting_genus = interesting_genus_scores["loss"].idxmin()

        # Compute quota (on a copy, so that `genus_scores` is not mutated through a view)
        quotas = genus_scores["improvement"].copy()
        quotas.at[most_interesting_genus] = 0  # Temporarily set to 0 to permit normalization
        if (quotas > 0).any():
            quotas = (quotas / quotas.sum()) * base_quotas["GENERA_THAT_IMPROVED"]
            quotas.loc[most_interesting_genus] = base_quotas["MOST_INTERESTING_GENUS"]
        else:
            quotas.loc[most_interesting_genus] = base_quotas["MOST_INTERESTING_GENUS"] + base_quotas["GENERA_THAT_IMPROVED"]
        # Added rather than assigned: the parentless group's quota is 0 here unless it is itself
        # the most interesting genus, in which case assigning would throw away ~90% of pop_size.
        quotas.loc[GENUS_WHEN_NO_PARENT] += base_quotas["EXPLORATION"]

        return np.ceil(quotas).astype(int)

    def _get_new_genera(self, genera):
        """
        The sample generator puts the different generated values in a well-defined order:
            pop = [a, b, c]
            n_replicates = 2
            n_global_uniform_values = 3
            => new_pop = [a', b', c', a*, b*, c*, g1, g2, g3]
        Using this information, we can find the genus of each sample in new_pop.

        :param genera: current genera
        :return: genera for the new population
        :rtype: np.ndarray
        """
        return np.concatenate((
            np.repeat(genera, self.n_replicates),
            np.repeat(GENUS_WHEN_NO_PARENT, self.n_global_uniform_values)
        ))

    def _evolve(self, n_generations, pop_size, population, loss_values):
        """
        Run the genus-aware evolution.

        :param int n_generations: number of generations to run
        :param int pop_size: population size kept at each generation
        :param pd.DataFrame population: initial population
        :param np.ndarray loss_values: loss of each row of `population`
        :return: the best individual of each extinct genus (pantheon), followed by the best
            individual of each genus still alive after the last generation
        :rtype: pd.DataFrame
        """
        # If we generated a sample by perturbing another sample, then both samples will share the same genus
        genera = np.arange(population.shape[0])  # Initially, all samples have a different genus
        pantheon = init_df_from(population)  # Contains the best individual from each extinct genus
        for _ in range(n_generations):
            new_population = self.sample_generator.generate_from(population)
            new_genera = self._get_new_genera(genera)
            new_loss_values = self.loss(new_population)
            population, genera, loss_values, new_notable_deaths = self._select(population, genera, loss_values,
                                                                               new_population, new_genera, new_loss_values,
                                                                               pop_size)
            pantheon = pd.concat([pantheon, new_notable_deaths])

        # `_select` only honors genera that die; the surviving ones (which always include the
        # most interesting genus) must be added too, or the best samples found would be lost.
        members_per_genus = self._members_per_genus(np.asarray(genera))
        survivor_positions = self._best_positions(members_per_genus, np.asarray(loss_values),
                                                  list(members_per_genus))
        survivors = population.iloc[survivor_positions]
        return pd.concat([pantheon, survivors]).reset_index(drop=True)
