from abc import abstractmethod
import dataiku

# Translates the metric names accepted as the ``metric`` argument of the stepwise
# selections (the same keys are read from trained-model snippets) into the DSS
# ``evaluationMetric`` codes written to the ML task settings.
modeling_metrics = dict(
    auc='ROC_AUC',
    f1='F1',
    accuracy='ACCURACY',
    precision='PRECISION',
    recall='RECALL',
    costMatrixGain='COST_MATRIX',
    logLoss='LOG_LOSS',
    lift='CUMULATIVE_LIFT',
    customMetricsResults='CUSTOM',
)

class FeatureSelection():
    """Prepare a DSS ML task for feature-selection experiments.

    The constructor does the following, in order:

    - re-guesses the ML task settings;
    - trains on a random sample of at most 100000 rows;
    - disables every algorithm except the generalized-linear-models
      binary-classification plugin, whose penalty is set to 0;
    - assigns the feature roles;
    - saves the settings;
    - deletes every model already trained in the task.

    Keyword arguments:
        saved_model_id: id of the saved model that the subclasses in this
            module pass to ``redeploy_to_flow``.
        analysis_id, ml_task_id: identify the ML task in the current project.
        nb_features: number of features to select; stored for subclasses.
        id_column: feature given the REJECT role (default ``'id'``).
        target_column: feature given the TARGET role (default ``'credit_event'``).
    """

    def __init__(self, **kwargs):
        project_key = dataiku.get_custom_variables()["projectKey"]
        client = dataiku.api_client()
        project = client.get_project(project_key)
        self.saved_model_id = kwargs.get('saved_model_id')
        self.ml_task = project.get_ml_task(kwargs.get('analysis_id'), kwargs.get('ml_task_id'))
        self.ml_task.guess()
        self.ml_task.wait_guess_complete()
        self.settings = self.ml_task.get_settings()
        # Edited in place below; the edits are expected to be persisted by
        # self.settings.save().
        self.settings_raw = self.settings.get_raw()
        self.preprocessing = self.settings_raw['preprocessing']
        self.algorithm_name = 'CustomPyPredAlgo_generalized-linear-models_generalized-linear-models_binary-classification'
        self.selected_features = None

        split_selection = self.settings_raw['splitParams']['ssdSelection']
        split_selection['samplingMethod'] = 'RANDOM_FIXED_NB'
        split_selection['maxRecords'] = 100000

        modeling = self.settings_raw['modeling']
        self._disable_all_algorithms(modeling)
        modeling['plugin_python'][self.algorithm_name]['enabled'] = True

        self.features = list(self.preprocessing['per_feature'].keys())
        self._assign_feature_roles(self.preprocessing['per_feature'],
                                   id_column=kwargs.get('id_column', 'id'),
                                   target_column=kwargs.get('target_column', 'credit_event'))

        algo_settings = self.settings.get_algorithm_settings(self.algorithm_name)
        algo_settings['params']['penalty'] = [0]
        self.settings.save()
        self.nb_features = kwargs.get('nb_features')
        self.clean_all_training_sessions()

    @staticmethod
    def _disable_all_algorithms(modeling):
        """Set ``enabled`` to False on every built-in and Python-plugin algorithm entry.

        Only entries that are dicts containing an ``enabled`` key are touched.
        """
        # Python-plugin algorithms are nested one level deeper, under 'plugin_python'.
        for group in (modeling, modeling['plugin_python']):
            for algorithm_settings in group.values():
                # Check the entry itself (not its container) before using it as a
                # dict, so non-dict entries are skipped instead of crashing.
                if isinstance(algorithm_settings, dict) and 'enabled' in algorithm_settings:
                    algorithm_settings['enabled'] = False

    @staticmethod
    def _assign_feature_roles(per_feature, id_column='id', target_column='credit_event'):
        """Reject ``id_column``, mark ``target_column`` as target, use the rest as unscaled inputs.

        Args:
            per_feature: the ``preprocessing['per_feature']`` mapping, modified in place.
            id_column: feature given the REJECT role.
            target_column: feature given the TARGET role.
        """
        for feature, feature_params in per_feature.items():
            if feature == id_column:
                feature_params['role'] = 'REJECT'
            elif feature == target_column:
                feature_params['role'] = 'TARGET'
            else:
                feature_params['role'] = 'INPUT'
                feature_params['rescaling'] = 'NONE'

    def clean_all_training_sessions(self):
        """Delete every trained model currently listed in the ML task."""
        model_ids = self.ml_task.get_trained_models_ids()
        for model in model_ids:
            self.ml_task.delete_trained_model(model)

class StepwiseSelection(FeatureSelection):
    """Greedy stepwise feature selection driven by repeated DSS trainings.

    Each step trains one session per candidate feature (roles set by
    ``set_features``). The candidate whose model scores best on ``metric`` is
    kept and the step's other sessions are deleted. After the last step, the
    best model of that step is redeployed to the flow.

    Extra keyword argument:
        metric: key of ``modeling_metrics``. It is used both as the ML task
            evaluation metric and as the key read from each model snippet.
    """

    # Snippet metrics for which a smaller value means a better model.
    LOWER_IS_BETTER_METRICS = ('logLoss',)

    def __init__(self, **kwargs):
        super(StepwiseSelection, self).__init__(**kwargs)
        self.metric = kwargs.get('metric')
        self.settings_raw['modeling']['metrics']['evaluationMetric'] = modeling_metrics[self.metric]
        self.settings.save()
        self.features = [f for f in self.features if self.preprocessing['per_feature'][f]['role'] == 'INPUT']
        self.forward_backward = None
        # Id of the best model of the latest completed step.
        self.best_model = None

    @abstractmethod
    def set_features(self, tested_features, keep_feature):
        """Set the roles of ``tested_features`` for the session evaluating ``keep_feature``."""
        raise NotImplementedError

    def _session_prefix(self, step):
        """Session-name prefix shared by every training of the 0-based ``step``."""
        return self.forward_backward + ' Step ' + str(step + 1) + ' - '

    def launch_selection(self):
        """Run the selection steps, then redeploy the best model of the last step."""
        # A step beyond the number of candidates would train nothing and leave
        # no model to choose from.
        nb_steps = min(self.nb_features, len(self.features))
        for step in range(nb_steps):
            print(step)
            self.compute_step(step)
            self.update_selected_features(step)
            self.clean_training_sessions(step)
        self.deploy_model()

    def compute_step(self, step):
        """Train one session per feature not selected yet; the session is named after the feature."""
        already_selected = set(self.selected_features)
        tested_features = [f for f in self.features if f not in already_selected]
        # Features picked in earlier steps are always used as inputs.
        for feature in self.selected_features:
            self.preprocessing['per_feature'][feature]['role'] = 'INPUT'
        for keep_feature in tested_features:
            print(keep_feature)
            self.set_features(tested_features, keep_feature)
            self.settings.save()
            self.ml_task.start_train(session_name=self._session_prefix(step) + keep_feature)
            self.ml_task.wait_train_complete()

    def compute_selected_feature(self, step):
        """Return the feature whose session scored best at ``step`` and record its model id.

        Raises:
            ValueError: if no trained model belongs to ``step``.
        """
        prefix = self._session_prefix(step)
        lower_is_better = self.metric in self.LOWER_IS_BETTER_METRICS
        # Models without a score must rank last whatever the metric direction.
        # Cost-matrix gain can be negative, so 0 is not a safe "worst" value.
        worst_perf = float('inf') if lower_is_better else float('-inf')
        # NOTE(review): 'customMetricsResults' is probably not a scalar in the
        # snippet; comparing it may fail — confirm before using that metric.

        model_perf = dict()
        for model in self.ml_task.get_trained_models_ids():
            snippet = self.ml_task.get_trained_model_snippet(model)
            session_name = snippet['userMeta'].get('sessionName', '')
            if not session_name.startswith(prefix):
                continue
            perf = snippet.get(self.metric)
            model_perf[model] = {'variable': session_name[len(prefix):],
                                 'perf': worst_perf if perf is None else perf}

        if not model_perf:
            raise ValueError('No trained model found for sessions starting with %r' % prefix)

        pick_best = min if lower_is_better else max
        self.best_model = pick_best(model_perf, key=lambda m: model_perf[m]['perf'])
        return model_perf[self.best_model]['variable']

    def update_selected_features(self, step):
        """Append the best feature of ``step`` to ``selected_features``."""
        selected_feature = self.compute_selected_feature(step)
        self.selected_features.append(selected_feature)

    def clean_training_sessions(self, step):
        """Delete every model of ``step`` except ``self.best_model``."""
        prefix = self._session_prefix(step)
        for model in self.ml_task.get_trained_models_ids():
            if model == self.best_model:
                continue
            snippet = self.ml_task.get_trained_model_snippet(model)
            if snippet['userMeta'].get('sessionName', '').startswith(prefix):
                self.ml_task.delete_trained_model(model)

    def deploy_model(self):
        """Redeploy the best model of the last step to ``saved_model_id``.

        Raises:
            ValueError: if no step has selected a model yet.
        """
        if self.best_model is None:
            raise ValueError('No model selected: run at least one selection step before deploying')
        self.ml_task.redeploy_to_flow(model_id=self.best_model, saved_model_id=self.saved_model_id)

class ForwardStepwiseSelection(StepwiseSelection):
    """Forward stepwise selection: start from no feature and add the best one per step.

    Session names are prefixed with ``'Forward'``. Selected features are
    accumulated in ``selected_features``, starting from an empty list.
    """

    def __init__(self, **kwargs):
        print('forward')
        super(ForwardStepwiseSelection, self).__init__(**kwargs)
        self.selected_features = []
        self.forward_backward = 'Forward'

    def set_features(self, tested_features, keep_feature):
        """Use ``keep_feature`` as input and reject every other tested feature."""
        per_feature = self.preprocessing['per_feature']
        for feature in tested_features:
            per_feature[feature]['role'] = 'INPUT' if feature == keep_feature else 'REJECT'

    # update_selected_features is inherited unchanged from StepwiseSelection;
    # a verbatim copy here would only shadow future fixes to the parent.

class LassoSelection(FeatureSelection):
    """Select features by searching the GLM penalty that keeps ``nb_features`` variables.

    Each training uses a single penalty value and is named
    ``'Lasso Step <penalty>'``. The number of distinct variables listed in a
    model's ``lmCoefficients`` is compared with ``nb_features``. The penalty is
    then raised, lowered or bisected until some trained model has exactly that
    many variables; that model is redeployed to the flow.
    """

    # Prefix of the session name given to every penalty training.
    SESSION_PREFIX = 'Lasso Step '

    def __init__(self, **kwargs):
        super(LassoSelection, self).__init__(**kwargs)
        self.selected_features = None
        per_feature = self.preprocessing['per_feature']
        self.features = [f for f in per_feature if per_feature[f]['role'] == 'INPUT']
        # penalty -> distinct variables of the model trained with that penalty
        self.penalty_variables = dict()

    def launch_selection(self, initial_penalty=0.001, max_iterations=50):
        """Search a penalty giving exactly ``nb_features`` variables and deploy that model.

        Args:
            initial_penalty: penalty of the first training.
            max_iterations: maximum number of trainings after the first one.
                Without a bound, the search would never end when no penalty
                yields exactly ``nb_features`` variables.

        Raises:
            RuntimeError: if no such model is found within ``max_iterations``.
        """
        penalty = initial_penalty
        self.run_lasso_training(penalty)
        self.update_selection()
        penalty_by_count = self._penalty_by_variable_count()
        iterations = 0
        while self.nb_features not in penalty_by_count:
            if iterations >= max_iterations:
                raise RuntimeError(
                    'No penalty yielding exactly %s variables found after %d trainings; '
                    'variable counts observed: %s'
                    % (self.nb_features, iterations + 1, sorted(penalty_by_count)))
            iterations += 1
            penalty = self._next_penalty(penalty_by_count, penalty)
            self.run_lasso_training(penalty)
            self.update_selection()
            self.clean_models()
            penalty_by_count = self._penalty_by_variable_count()
        self.deploy_model()

    def _penalty_by_variable_count(self):
        """Map each observed variable count to a penalty that produced it (last one wins)."""
        return {len(variables): pen for pen, variables in self.penalty_variables.items()}

    def _next_penalty(self, penalty_by_count, current_penalty):
        """Choose the next penalty from the variable counts observed so far.

        If counts exist on both sides of the target, the penalties closest on
        each side are averaged. If every model keeps too many variables, the
        penalty is multiplied by 10; if every model keeps too few, it is
        divided by 10. With no observation, ``current_penalty`` is returned.
        """
        fewer = [nb for nb in penalty_by_count if nb < self.nb_features]
        more = [nb for nb in penalty_by_count if nb > self.nb_features]
        if fewer and more:
            return (penalty_by_count[max(fewer)] + penalty_by_count[min(more)]) / 2.0
        if more:
            return penalty_by_count[min(more)] * 10
        if fewer:
            return penalty_by_count[max(fewer)] / 10.0
        return current_penalty

    def run_lasso_training(self, penalty):
        """Train one session of the GLM plugin with the single penalty ``penalty``."""
        algo_settings = self.settings.get_algorithm_settings(self.algorithm_name)
        algo_settings['params']['penalty'] = [penalty]
        self.settings.save()
        self.ml_task.start_train(session_name=self.SESSION_PREFIX + str(penalty))
        self.ml_task.wait_train_complete()

    def get_model_and_variables(self, model_id):
        """Return ``(penalty, variables)`` for a trained model.

        Variable names containing ``':'`` are reduced to their second
        ``':'``-separated component before de-duplication. Presumably this
        groups the encoded columns of one feature — confirm the naming scheme.
        """
        raw_details = self.ml_task.get_trained_model_details(model_id).get_raw()
        variables = raw_details.get('iperf').get('lmCoefficients').get('variables')
        parsed_variables = list(set(v.split(':')[1] if ':' in v else v for v in variables))
        penalty = raw_details.get('modeling').get('plugin_python_grid').get('params').get('penalty')[0]
        return penalty, parsed_variables

    def update_selection(self):
        """Refresh ``penalty_variables`` from every model currently trained in the task."""
        model_ids = self.ml_task.get_trained_models_ids()
        models = [self.get_model_and_variables(model_id) for model_id in model_ids]
        self.penalty_variables = {pen: variables for pen, variables in models}

    def clean_models(self):
        """Keep one model per observed variable count and delete the others.

        Above the target count the largest penalty is kept; at or below it the
        smallest is kept. Assuming larger penalties keep fewer variables, this
        retains the penalties nearest the target.
        """
        penalties_by_count = dict()
        for pen, variables in self.penalty_variables.items():
            penalties_by_count.setdefault(len(variables), []).append(pen)
        for nb_variables, penalties in penalties_by_count.items():
            if len(penalties) < 2:
                continue  # a single model per count: nothing to delete
            keep = max(penalties) if nb_variables > self.nb_features else min(penalties)
            self.delete_models([pen for pen in penalties if pen != keep])

    def delete_models(self, penalties):
        """Delete the models whose ``'Lasso Step <penalty>'`` session used one of ``penalties``."""
        if not penalties:
            return
        # Compare numerically, not by string suffix: suffix matching confused
        # e.g. '0.01' with penalty 1, and would also match unrelated sessions.
        targets = set(float(pen) for pen in penalties)
        for model in self.ml_task.get_trained_models_ids():
            snippet = self.ml_task.get_trained_model_snippet(model)
            session_name = snippet['userMeta'].get('sessionName', '')
            if not session_name.startswith(self.SESSION_PREFIX):
                continue
            try:
                session_penalty = float(session_name[len(self.SESSION_PREFIX):])
            except ValueError:
                continue
            if session_penalty in targets:
                self.ml_task.delete_trained_model(model)

    def deploy_model(self):
        """Redeploy the first trained model having exactly ``nb_features`` variables.

        Raises:
            ValueError: if no trained model has that many variables.
        """
        for model_id in self.ml_task.get_trained_models_ids():
            _, variables = self.get_model_and_variables(model_id)
            if len(variables) == self.nb_features:
                self.ml_task.redeploy_to_flow(model_id=model_id, saved_model_id=self.saved_model_id)
                return
        raise ValueError('No trained model has exactly %s variables' % self.nb_features)

class TreeSelection(FeatureSelection):
    """Train a single random-forest classification session and redeploy it.

    Every algorithm enabled by the base class, including the GLM plugin, is
    disabled first; only ``random_forest_classification`` is enabled.
    """

    def __init__(self, **kwargs):
        super(TreeSelection, self).__init__(**kwargs)
        modeling = self.settings_raw['modeling']
        # Python-plugin algorithms are nested one level deeper, under 'plugin_python'.
        for group in (modeling, modeling['plugin_python']):
            for algorithm_settings in group.values():
                # Check the entry itself (not its container) before using it as a
                # dict, so non-dict entries are skipped instead of crashing.
                if isinstance(algorithm_settings, dict) and 'enabled' in algorithm_settings:
                    algorithm_settings['enabled'] = False
        modeling['random_forest_classification']['enabled'] = True
        self.settings.save()

    def launch_selection(self):
        """Train the 'Tree-Based Selection' session, then redeploy its model."""
        # NOTE(review): nb_features is not used by this class; confirm whether
        # the actual selection is expected to happen elsewhere.
        self.ml_task.start_train(session_name='Tree-Based Selection')
        self.ml_task.wait_train_complete()
        self.deploy_model()

    def deploy_model(self):
        """Redeploy the first trained model id to ``saved_model_id``."""
        # The base constructor deleted every previously trained model, so the
        # listed models come from the session trained in launch_selection.
        model_ids = self.ml_task.get_trained_models_ids()
        self.ml_task.redeploy_to_flow(model_id=model_ids[0], saved_model_id=self.saved_model_id)

