from sklearn.base import BaseEstimator
from lifelines import CoxPHFitter, exceptions
import numpy as np
import pandas as pd
from lab.multivariate_model import MultivariateModel

class ClassCoxPH(BaseEstimator, MultivariateModel):
    """
    Cox proportional-hazards survival model (lifelines ``CoxPHFitter``) exposed
    as a scikit-learn estimator.

    After fitting, the model's hazard ratios are stored in ``coef_`` (one entry
    per label of ``column_labels``) so they can be shown in the
    "Regression coefficients" section of the Lab.
    """

    def __init__(self,
                 training_dataset,
                 event_indicator_column=None,
                 alpha_penalizer=0.01,
                 l1_ratio=0,
                 have_already_lived=False,
                 conditional_after_column=None,
                 prediction_type="predict_expected_time",
                 survival_quantile=0.75,
                 time_for_proba=30,
                 column_labels=None):
        """
        :param training_dataset: forwarded to ``MultivariateModel.__init__``.
        :param event_indicator_column: name of the column used as lifelines' ``event_col``.
        :param alpha_penalizer: ``penalizer`` passed to ``CoxPHFitter``.
        :param l1_ratio: ``l1_ratio`` passed to ``CoxPHFitter``.
        :param have_already_lived: when True, predictions are conditioned on the
            values of ``conditional_after_column``.
        :param conditional_after_column: column excluded from the covariates and
            used as lifelines' ``conditional_after`` at prediction time.
        :param prediction_type: forwarded to ``MultivariateModel.__init__``.
        :param survival_quantile: forwarded to ``MultivariateModel.__init__``;
            used as the percentile in ``get_times_at_probability``.
        :param time_for_proba: forwarded to ``MultivariateModel.__init__``.
        :param column_labels: forwarded to ``MultivariateModel.__init__``; the
            feature labels iterated by the methods of this class.
        """
        self.training_dataset = training_dataset
        self.event_indicator_column = event_indicator_column
        self.alpha_penalizer = alpha_penalizer
        self.l1_ratio = l1_ratio
        self.prediction_type = prediction_type
        self.have_already_lived = have_already_lived
        self.conditional_after_column = conditional_after_column
        self.coef_ = None
        # A Cox PH model has no intercept term, so these are fixed.
        self.fit_intercept = False
        self.intercept_scaling = 0
        self.intercept_ = 0
        # Populated by set_prediction_object and read by the prediction methods.
        self.prediction_covariates = None
        self.conditional_after_values = None

        MultivariateModel.__init__(self, training_dataset, event_indicator_column, prediction_type, survival_quantile, time_for_proba, column_labels)

    def compute_coefs(self):
        """
        Store in ``coef_`` one value per label of ``column_labels`` for the
        "Regression coefficients" section of the Lab.

        Labels that were not fed to lifelines as covariates (the event
        indicator, the conditional-after column and every "is NA" column) get a
        coefficient of 0. The remaining labels consume the fitted hazard ratios
        in order, which assumes lifelines keeps the covariate column order of the
        frame built by ``get_lifelines_dataframe``.
        """
        remaining_hazard_ratios = iter(self.fitted_model.hazard_ratios_.tolist())
        list_coefs = []
        for label in self.column_labels:
            is_excluded = (
                label == self.event_indicator_column
                or label == self.conditional_after_column
                or self.is_NA_column(label)
            )
            if is_excluded:
                list_coefs.append(0)
            else:
                # Single pass over the hazard ratios (avoids repeated list.pop(0)).
                list_coefs.append(next(remaining_hazard_ratios))
        self.coef_ = np.array(list_coefs)

    def fit(self, X, y):
        """
        Fit a lifelines ``CoxPHFitter`` (Breslow baseline) on X and the durations y.

        :param X: feature matrix, read through ``get_columns``; must contain the
            event indicator column.
        :param y: durations (target column), one per row of X.
        :return: self, following the scikit-learn estimator convention.
        :raises Exception: when lifelines raises a ``ConvergenceError``.
        """
        self.check_event_indicator_values(X)
        self.set_max_duration(y)
        # In the current state of the plugin it is impossible to retrieve the
        # confidence intervals around the hazard ratios (a model view is
        # required to do so), so the confidence level is a fixed default.
        confidence_interval_alpha = 0.05
        model = CoxPHFitter(alpha=confidence_interval_alpha, baseline_estimation_method="breslow", penalizer=self.alpha_penalizer, l1_ratio=self.l1_ratio)
        df = self.get_lifelines_dataframe(X, y)

        try:
            model.fit(df, duration_col=MultivariateModel.DURATION_COLUMN_NAME, event_col=self.event_indicator_column, show_progress=True)
        except exceptions.ConvergenceError as err:
            # Chain the original error so the lifelines diagnostics are kept.
            raise Exception("Convergence error while fitting model. Please check that all categorical variables in the feature handling are set to 'Drop one dummy'.") from err
        self.fitted_model = model
        self.compute_coefs()
        return self

    def get_lifelines_dataframe(self, X, y):
        """
        Build the training frame passed to lifelines.

        Its columns are:
            - every label of ``column_labels`` except the "is NA" columns and the
              conditional-after column (the event indicator column is kept),
            - the duration column (``MultivariateModel.DURATION_COLUMN_NAME``) holding y.

        :param X: feature matrix, read column by column through ``get_columns``.
        :param y: durations, aligned with the rows of X by position.
        :return: a pandas DataFrame.
        """
        data = {}
        for label in self.column_labels:
            if self.is_NA_column(label) or label == self.conditional_after_column:
                continue
            column_values, _ = self.get_columns(X, [label])
            data[label] = column_values[:, 0]

        # Assign durations by position: if y carries a pandas index, assigning
        # it directly into a frame would align on that index and could produce NaN.
        data[MultivariateModel.DURATION_COLUMN_NAME] = np.asarray(y).ravel()
        # Building the frame in one go avoids the fragmentation caused by
        # inserting columns one at a time.
        return pd.DataFrame(data)

    def get_expected_time(self):
        """Return lifelines' expected survival time for the prepared covariates."""
        return self.fitted_model.predict_expectation(self.prediction_covariates, self.conditional_after_values)

    def get_times_at_probability(self):
        """Return lifelines' survival-time percentile ``survival_quantile`` for the prepared covariates."""
        return self.fitted_model.predict_percentile(self.prediction_covariates, self.survival_quantile, self.conditional_after_values)

    def get_probabilities_at_time(self):
        """Not supported by this model."""
        raise NotImplementedError("Probabilities at a given time are not implemented for the Cox proportional hazards model.")

    def process_predictions(self, predictions):
        """
        Convert raw lifelines predictions to a numpy array.

        Infinite predictions are replaced by ``max_duration``. A scalar
        prediction (e.g. from a single-row ``.squeeze()``) is wrapped in a Series first.
        """
        if np.isscalar(predictions):
            predictions = pd.Series(predictions)
        return predictions.replace(np.inf, self.max_duration).to_numpy()

    def is_NA_column(self, label):
        """
        Return True if the column was generated by DSS to represent potential NA values
        (its label ends with 'N/A').
        """
        return label.endswith('N/A')

    def set_prediction_object(self, X):
        """
        Prepare the inputs read by the prediction methods.

        Removes from X the "is NA" columns, the event indicator column and, when
        ``have_already_lived`` is set, the conditional-after column. Stores the
        remaining matrix in ``prediction_covariates``. Stores the flattened
        conditional-after values, or None, in ``conditional_after_values``.
        """
        na_labels = [label for label in self.column_labels if self.is_NA_column(label)]
        _, na_column_indices = self.get_columns(X, na_labels)
        _, event_indicator_index = self.get_columns(X, [self.event_indicator_column])

        # Copy so the appends below never mutate the list returned by get_columns.
        delete_column_indices = list(na_column_indices)
        delete_column_indices.append(event_indicator_index[0])

        if self.have_already_lived:
            conditional_after_values, conditional_after_index = self.get_columns(X, [self.conditional_after_column])
            self.conditional_after_values = [value for row in conditional_after_values for value in row]
            delete_column_indices.append(conditional_after_index[0])
        else:
            self.conditional_after_values = None

        self.prediction_covariates = np.delete(X, delete_column_indices, axis=1)
