import random
import numpy as np
import pandas as pd
from deap import base, creator, tools
from sklearn.linear_model import LogisticRegression, Lasso
import fitness_function as ff

def get_df_with_interactions(df_X):
    """Return a copy of ``df_X`` extended with pairwise interaction features.

    For every column ``f`` a square column ``dku_gen:f_(^2)`` is generated,
    and for every unordered pair of distinct columns ``(f1, f2)`` (in column
    order) three columns are generated: the product ``dku_gen:f1_(*)_f2``,
    the sum ``dku_gen:f1_(+)_f2`` and the difference ``dku_gen:f1_(-)_f2``.

    Parameters
    ----------
    df_X : pandas.DataFrame
        Input features. It is not modified.

    Returns
    -------
    pandas.DataFrame
        The original columns followed by the generated ones, ordered as:
        square of the first feature, its (product, sum, difference) with each
        later feature, then the square of the second feature, and so on.
    """
    features = list(df_X.columns)
    # Generated columns are collected first and attached in a single concat:
    # inserting them one by one into the frame fragments it and makes pandas
    # emit a PerformanceWarning once more than 100 columns have been added.
    generated = {}
    for position, f1 in enumerate(features):
        s1 = df_X[f1]
        generated["dku_gen:{}_(^2)".format(f1)] = s1 * s1
        for f2 in features[position + 1:]:
            s2 = df_X[f2]
            generated["dku_gen:{}_(*)_{}".format(f1, f2)] = s1 * s2
            generated["dku_gen:{}_(+)_{}".format(f1, f2)] = s1 + s2
            generated["dku_gen:{}_(-)_{}".format(f1, f2)] = s1 - s2
    if not generated:
        # No columns at all: nothing to generate, just hand back a copy.
        return df_X.copy()
    return pd.concat([df_X, pd.DataFrame(generated)], axis=1)


def transform_df(df_X, columns):
    """Build a frame holding exactly ``columns``, recomputing generated ones.

    Names starting with ``dku_gen:`` are parsed back into their source
    column(s) and operator (``_(^2)``, ``_(*)_``, ``_(+)_`` or ``_(-)_``) and
    recomputed from ``df_X``; any other name is copied from ``df_X`` as is.

    Parameters
    ----------
    df_X : pandas.DataFrame
        Frame containing the raw (non-generated) source columns.
    columns : iterable of str
        Names of the columns to produce, in output order.

    Returns
    -------
    pandas.DataFrame
        One column per requested name, on the same index as ``df_X``.

    Raises
    ------
    ValueError
        If a ``dku_gen:`` name carries none of the known operator markers.
    KeyError
        If a referenced source column is missing from ``df_X``.
    """
    prefix = "dku_gen:"
    # Every column is kept as a Series on df_X's own index. Mixing raw
    # ``.values`` with index-aligned Series made generated columns turn into
    # NaN whenever df_X did not have a default RangeIndex.
    data = {}
    for column in columns:
        if not column.startswith(prefix):
            data[column] = df_X[column]
            continue

        # Strip only the leading prefix so the rest of the name is untouched.
        expression = column[len(prefix):]
        if "_(^2)" in expression:
            source = df_X[expression.split("_(^2)")[0]]
            data[column] = source * source
        elif "_(*)_" in expression:
            left, right = expression.split("_(*)_")
            data[column] = df_X[left] * df_X[right]
        elif "_(+)_" in expression:
            left, right = expression.split("_(+)_")
            data[column] = df_X[left] + df_X[right]
        elif "_(-)_" in expression:
            left, right = expression.split("_(-)_")
            data[column] = df_X[left] - df_X[right]
        else:
            raise ValueError("Column not recognized !")
    # Single construction instead of column-by-column insertion.
    return pd.DataFrame(data)
                

class FeatureSelectionGA:
    """
    FeaturesSelectionGA
    This class uses a Genetic Algorithm to find the best features for an input model
    using the Distributed Evolutionary Algorithms in Python (DEAP) package. The search
    runs over the input columns plus the interaction columns produced by
    ``get_df_with_interactions``. The default toolbox is used for the GA unless a
    custom one is assigned to ``toolbox`` and ``generate`` is called with
    ``set_toolbox=True``.
    """
    def __init__(self, df_X, y, model=None, gen_ratio=1., cv_split=5, verbose=0):
        """
        Parameters
        -----------
        df_X: pandas.DataFrame, shape = [n_samples, n_features]
           Training features, where n_samples is the number of samples
           and n_features is the number of features. Interaction features
           are generated from it before the search.
        y: {array-like}, shape = [n_samples]
           Target values. More than 10 distinct values selects the
           "regression" task, otherwise "classification".
        model: scikit-learn supported model, optional
               Defaults to Lasso() for regression and LogisticRegression()
               for classification.
        gen_ratio: float
                   Fraction of the current population selected as offspring
                   at each generation.
        cv_split: int
                  Number of splits for cross_validation to compute fitness.
        verbose: 0 or 1
        """
        self.toolbox = None
        self.creator = self._create()
        self.cv_split = cv_split
        # Number of raw features, recorded before interactions are appended.
        self.n_features_in = df_X.shape[1]
        df_X = get_df_with_interactions(df_X)
        self.df_X = df_X
        self.X = self.df_X.values
        self.n_features_out = df_X.shape[1]
        print(self.df_X.shape)
        print(self.X.shape)
        self.y = y
        self.gen_ratio = gen_ratio

        # Heuristic task detection from the number of distinct target values.
        if len(np.unique(self.y)) > 10:
            self.task = "regression"
        else:
            self.task = "classification"
        self.model = model
        if self.model is None:
            if self.task == "regression":
                self.model = Lasso()
            else:
                self.model = LogisticRegression()
        self.verbose = verbose
        if self.verbose==1:
            # Report the model actually used (the default one when none was given).
            print("Model {} will select best features among {} features using cv_split :{}.".format(self.model, df_X.shape[1], cv_split))
            print("Shape of train_X: {} and target: {}".format(df_X.shape, y.shape))
        self.final_fitness = []
        self.fitness_in_generation = {}
        self.best_ind = None

    def evaluate(self, individual):
        """
        Compute the fitness of one individual.
        Parameters
        ----------
        individual : sequence of 0/1 flags, one per column of ``self.X``
        Returns
        -------
            One-element tuple holding the fitness, as DEAP expects. An
            individual selecting no feature gets a fitness of 0.0.
        """
        np_ind = np.asarray(individual)
        if np.sum(np_ind) == 0:
            fitness = 0.0
        else:
            fit_obj = ff.FitnessFunction(n_splits=self.cv_split, task=self.task)
            feature_idx = np.where(np_ind == 1)[0]
            fitness = fit_obj.compute_fitness(self.model, self.X[:, feature_idx], self.y)

        if self.verbose == 1:
            print("Individual: {}  Fitness_score: {} ".format(individual, fitness))

        return fitness,

    def _create(self):
        """ Create the DEAP fitness (single objective, weight 1.0) and individual classes. """
        creator.create("FeatureSelect", base.Fitness, weights=(1.0,))
        creator.create("Individual", list, fitness=creator.FeatureSelect)
        return creator

    def _init_individual(self, threshold=0.2):
        """
        Build the gene list of a new individual: every raw feature is switched
        on, each generated feature is switched on with probability ``threshold``.
        Note that this draws from ``np.random`` while the evolution loop draws
        from ``random``.
        """
        raw_genes = [1] * self.n_features_in
        n_generated = self.n_features_out - self.n_features_in
        generated_genes = [int(np.random.rand() < threshold) for _ in range(n_generated)]
        return raw_genes + generated_genes

    def _init_toolbox(self):
        """ Create a toolbox with the individual and population initializers. """
        toolbox = base.Toolbox()
        toolbox.register("attr_bool", random.randint, 0, 1)
        # Structure initializers
        # Individuals come from _init_individual, so the factory takes no arguments.
        toolbox.register("individual", lambda: creator.Individual(self._init_individual()))
        toolbox.register("population", tools.initRepeat, list, toolbox.individual)
        return toolbox

    def _default_toolbox(self):
        """
        Build the default toolbox: two-point crossover, bit-flip mutation
        (indpb=0.1), tournament selection (tournsize=3) and ``self.evaluate``
        as the evaluation function.
        Returns
        -------
            Registered toolbox
        """
        toolbox = self._init_toolbox()
        toolbox.register("mate", tools.cxTwoPoint)
        toolbox.register("mutate", tools.mutFlipBit, indpb=0.1)
        toolbox.register("select", tools.selTournament, tournsize=3)
        toolbox.register("evaluate", self.evaluate)
        return toolbox

    def get_final_scores(self,pop,fits):
        """ Store the (individual, fitness) pairs in ``self.final_fitness``. """
        self.final_fitness = list(zip(pop, fits))

    def generate(self, n_pop, cxpb=0.5, mutxpb=0.2, n_gen=5, set_toolbox=False):
        """
        Generate evolved population
        Parameters
        ----------
        n_pop : {int}
                population size
        cxpb  : {float}
                crossover probability
        mutxpb: {float}
                mutation probability
        n_gen : {int}
                number of generations (the evolution loop runs n_gen + 1 times)
        set_toolbox : {boolean}
                      If True, the toolbox already assigned to ``self.toolbox``
                      is used and must have been set before calling this
                      method. If False use default toolbox.
        Returns
        -------
            Fittest population
        Raises
        ------
            Exception if set_toolbox is True and no toolbox has been set.
        """
        if self.verbose==1:
            print("Population: {}, crossover_probablity: {}, mutation_probablity: {}, total generations: {}".format(n_pop, cxpb, mutxpb, n_gen))

        if not set_toolbox:
            self.toolbox = self._default_toolbox()
        elif self.toolbox is None:
            # Only fail when there really is no custom toolbox to use.
            raise Exception("No custom toolbox is set. Assign a registered toolbox to the `toolbox` attribute before calling generate, or set set_toolbox = False to use default toolbox")
        pop = self.toolbox.population(n_pop)

        # Evaluate the entire population
        print("EVOLVING.......")
        fitnesses = list(map(self.toolbox.evaluate, pop))

        for ind, fit in zip(pop, fitnesses):
            ind.fitness.values = fit

        # NOTE(review): range(n_gen + 1) runs one more evolution step than n_gen;
        # kept as is, confirm whether this is intended.
        for g in range(n_gen + 1):
            print("-- GENERATION {} --".format(g))
            # With gen_ratio < 1 the population shrinks at every generation,
            # since it is replaced by the selected offspring below.
            offspring = self.toolbox.select(pop, int(self.gen_ratio*len(pop)))
            # Best fitness of the population entering this generation.
            self.fitness_in_generation[str(g)] = max(ind.fitness.values[0] for ind in pop)
            # Clone the selected individuals
            offspring = list(map(self.toolbox.clone, offspring))

            # Apply crossover and mutation on the offspring
            for child1, child2 in zip(offspring[::2], offspring[1::2]):
                if random.random() < cxpb:
                    self.toolbox.mate(child1, child2)
                    # Fitness is stale after crossover; invalidate it.
                    del child1.fitness.values
                    del child2.fitness.values

            for mutant in offspring:
                if random.random() < mutxpb:
                    self.toolbox.mutate(mutant)
                    del mutant.fitness.values

            # Evaluate the individuals with an invalid fitness
            weak_ind = [ind for ind in offspring if not ind.fitness.valid]
            fitnesses = list(map(self.toolbox.evaluate, weak_ind))
            for ind, fit in zip(weak_ind, fitnesses):
                ind.fitness.values = fit
            print("Evaluated %i individuals" % len(weak_ind))

            # The population is entirely replaced by the offspring
            pop[:] = offspring

        # Gather all the fitnesses in one list and print the stats
        fits = [ind.fitness.values[0] for ind in pop]

        length = len(pop)
        mean = sum(fits) / length
        sum2 = sum(x*x for x in fits)
        # abs() guards against a tiny negative variance from rounding.
        std = abs(sum2 / length - mean**2)**0.5
        if self.verbose==1:
            print("  Min %s" % min(fits))
            print("  Max %s" % max(fits))
            print("  Avg %s" % mean)
            print("  Std %s" % std)

        print("-- Only the fittest survives --")

        self.best_ind = tools.selBest(pop, 1)[0]
        print("Best individual is %s, %s" % (self.best_ind, self.best_ind.fitness.values))
        self.get_final_scores(pop,fits)

        return pop