import pandas as pd
from lifelines import statistics, CoxPHFitter
from recipes.base_recipe import BaseRecipe

class SurvivalAnalysisStatistics(BaseRecipe):
    """Recipe computing survival-comparison statistics between groups.

    For every pair of groups it runs a log-rank test and fits a Cox
    proportional-hazards model whose only covariate is group membership;
    across all groups at once it runs a multivariate log-rank test.
    """

    def __init__(self):
        # NOTE(review): BaseRecipe.__init__ is not invoked here -- confirm
        # that the base class needs no initialisation.
        pass

    def get_multivariate_logrank_test(self, duration_column, event_indicator_column, groupby_columns, groups):
        """Run one log-rank test across all groups at once.

        Args:
            duration_column: Name of the column holding the durations.
            event_indicator_column: Name of the column holding the event
                indicator.
            groupby_columns: Columns the data was grouped by; forwarded to
                ``get_group_label`` to build each group's label.
            groups: Iterable of ``(groupby_values, group_df)`` pairs, such as
                the object returned by ``DataFrame.groupby``.

        Returns:
            The result object of
            ``lifelines.statistics.multivariate_logrank_test``.
        """
        # Label a copy of each group so the frames handed in by the caller
        # are never mutated (writing into a groupby slice also triggers
        # pandas' SettingWithCopyWarning).
        labelled_frames = []
        for groupby_values, group_df in groups:
            labelled_df = group_df.copy()
            labelled_df[BaseRecipe.GROUP_BY_COLUMN_NAME] = self.get_group_label(groupby_columns, groupby_values)
            labelled_frames.append(labelled_df)

        df_with_labels = pd.concat(labelled_frames, axis=0)

        # Fall back to an "index" column as the duration source only when the
        # duration column itself is missing; renaming unconditionally would
        # create a duplicate column name when both are present.
        if duration_column not in df_with_labels.columns:
            df_with_labels = df_with_labels.rename(columns={"index": duration_column})

        return statistics.multivariate_logrank_test(
            df_with_labels[duration_column],
            df_with_labels[BaseRecipe.GROUP_BY_COLUMN_NAME],
            df_with_labels[event_indicator_column],
        )

    def get_pairwise_tests(self, duration_column, event_indicator_column, groupby_columns, groups):
        """Compare every unordered pair of groups.

        For each pair the following are computed:
            - log rank test
            - cox univariate test (group membership as the only covariate)
            - cox log likelihood ratio test

        Args:
            duration_column: Name of the column holding the durations.
            event_indicator_column: Name of the column holding the event
                indicator.
            groupby_columns: Columns the data was grouped by; forwarded to
                ``get_group_label`` to build each group's label.
            groups: Iterable of ``(groupby_values, group_df)`` pairs.

        Returns:
            A DataFrame with one row per pair of groups. It is empty (but has
            all columns) when fewer than two groups are supplied.
        """
        labelled_groups = [
            (self.get_group_label(groupby_columns, groupby_values), group_df)
            for groupby_values, group_df in groups
        ]

        rows = []
        for i, (label1, group1_df) in enumerate(labelled_groups):
            # Only pair with later groups so each unordered pair appears once.
            for label2, group2_df in labelled_groups[i + 1:]:
                durations1 = group1_df[duration_column]
                durations2 = group2_df[duration_column]
                events1 = group1_df[event_indicator_column]
                events2 = group2_df[event_indicator_column]

                logrank_test = statistics.logrank_test(durations1, durations2, events1, events2)

                # Stack both groups with a 0/1 membership covariate for the
                # Cox model.
                dfAB = pd.concat([
                    pd.DataFrame({'durations': durations1, 'events': events1, 'group': 0}),
                    pd.DataFrame({'durations': durations2, 'events': events2, 'group': 1}),
                ])
                cph = CoxPHFitter().fit(dfAB, 'durations', 'events')

                # 'group' is the only covariate, so the summary has a single
                # row; use .iloc because positional [] lookups on a
                # label-indexed Series are deprecated in pandas.
                summary = cph.summary
                likelihood_ratio_test = cph.log_likelihood_ratio_test()

                rows.append([
                    label1,
                    label2,
                    logrank_test.test_statistic,
                    logrank_test.p_value,
                    summary["z"].iloc[0],
                    summary["p"].iloc[0],
                    likelihood_ratio_test.test_statistic,
                    likelihood_ratio_test.p_value,
                ])

        # NOTE(review): "logrank_test_z_value" holds lifelines' `test_statistic`;
        # the column name is kept unchanged for output compatibility.
        column_names = [
            "group_A",
            "group_B",
            "logrank_test_z_value",
            "logrank_test_p_value",
            "cox_univariate_test_z_value",
            "cox_univariate_test_p_value",
            "cox_univariate_log_likelihood_ratio_test_statistic",
            "cox_univariate_log_likelihood_ratio_test_p_value",
        ]
        return pd.DataFrame(data=rows, columns=column_names)

    def get_output_df(self, df, duration_column, event_indicator_column, groupby_columns):
        """Build the recipe output: pairwise statistics plus the overall test.

        Args:
            df: Input DataFrame.
            duration_column: Name of the column holding the durations.
            event_indicator_column: Name of the column holding the event
                indicator.
            groupby_columns: Non-empty collection of column names defining
                the groups to compare.

        Returns:
            The DataFrame produced by ``get_pairwise_tests`` with the
            multivariate log-rank p-value and test statistic repeated on
            every row, all numeric values rounded to 3 decimals.

        Raises:
            ValueError: If ``groupby_columns`` is ``None`` or empty.
        """
        self.check_duration_column(df, duration_column)
        self.check_event_indicator_column(df, event_indicator_column)
        # NOTE(review): the return value is discarded -- presumably
        # get_processed_data modifies `df` in place; confirm against BaseRecipe.
        self.get_processed_data(df, event_indicator_column)

        if groupby_columns is None or len(groupby_columns) == 0:
            raise ValueError("Please select groupby columns")

        # dropna=False keeps rows whose group key is missing as their own group.
        groups = df.groupby(groupby_columns, dropna=False)

        statistics_df = self.get_pairwise_tests(duration_column, event_indicator_column, groupby_columns, groups)

        # get multivariate log rank test
        multivariate_logrank_test = self.get_multivariate_logrank_test(duration_column, event_indicator_column, groupby_columns, groups)

        # Scalars are broadcast, so every pairwise row carries the overall result.
        statistics_df["multivariate_logrank_test_p_value"] = multivariate_logrank_test.p_value
        statistics_df["multivariate_logrank_test_z_value"] = multivariate_logrank_test.test_statistic

        return statistics_df.round(3)

