from pandas.api.types import is_numeric_dtype

def filter_workflow(workflow_df, nb_variants, variants_count, variants):
    """Restrict ``workflow_df`` to the cases that follow one of the leading variants.

    Args:
        workflow_df: Event rows with a ``'case'`` column.
        nb_variants: How many leading rows of ``variants_count`` to keep. The
            slice ``[:nb_variants]`` is applied as-is, so ``None`` keeps all.
        variants_count: Table with an ``'activity'`` column naming each variant.
            NOTE(review): presumably ordered by frequency so the slice picks
            the most common variants — confirm with the caller.
        variants: Table with ``'case'`` and ``'activity'`` columns assigning
            each case to its variant.

    Returns:
        The rows of ``workflow_df`` whose case belongs to a selected variant.
    """
    leading = variants_count[:nb_variants]
    chosen_variants = leading['activity'].tolist()
    in_chosen = variants['activity'].isin(chosen_variants)
    chosen_cases = variants.loc[in_chosen, 'case'].tolist()
    return workflow_df[workflow_df['case'].isin(chosen_cases)]

def filter_timestamp(workflow, start, end, contained):
    """Keep the cases whose time span overlaps, or lies inside, ``[start, end]``.

    A case's span runs from its smallest ``'epoch_timestamp'`` to its largest
    ``'epoch_end_timestamp'``.

    Args:
        workflow: Event rows with ``'case'``, ``'epoch_timestamp'`` and
            ``'epoch_end_timestamp'`` columns.
        start: Window lower bound; any object exposing ``.timestamp()``.
        end: Window upper bound; any object exposing ``.timestamp()``.
            NOTE(review): a naive ``datetime.datetime`` is interpreted as local
            time by ``.timestamp()`` — confirm this matches how the epoch
            columns were produced.
        contained: If truthy, keep only cases whose whole span lies inside the
            window (bounds inclusive). Otherwise keep every case whose span
            touches the window.

    Returns:
        All rows of ``workflow`` belonging to the retained cases.
    """
    window_start = start.timestamp()
    window_end = end.timestamp()
    span = workflow.groupby('case').agg(
        {'epoch_timestamp': 'min', 'epoch_end_timestamp': 'max'}
    )
    first_begin = span['epoch_timestamp']
    last_finish = span['epoch_end_timestamp']
    if contained:
        keep = (first_begin >= window_start) & (last_finish <= window_end)
    else:
        # Overlap test: the case begins before the window closes and
        # finishes after the window opens.
        keep = (first_begin <= window_end) & (last_finish >= window_start)
    kept_cases = set(span.loc[keep].index)
    return workflow[workflow['case'].isin(kept_cases)]

def filter_start_step(workflow, start_steps):
    """Keep the cases whose first real activity is one of ``start_steps``.

    Rows whose ``'activity'`` equals ``'START'`` are ignored when picking a
    case's first activity. The first activity is the row with the smallest
    ``'sorting'`` value; on ties, the first occurrence wins (``idxmin``).

    Args:
        workflow: Event rows with ``'case'``, ``'activity'`` and ``'sorting'``
            columns. The index need not be unique.
        start_steps: Collection of activity names accepted as a first step.

    Returns:
        All rows of ``workflow`` (``'START'`` rows included) belonging to the
        retained cases.
    """
    # ``idxmin`` yields index *labels*, and ``.loc`` selects every row carrying
    # a label. With a non-unique index (e.g. after ``pd.concat``), that pulls
    # in rows from other cases and mis-attributes first activities. A fresh
    # positional index makes each label identify exactly one row.
    candidates = workflow[workflow['activity'] != 'START'].reset_index(drop=True)
    first_rows = candidates.groupby('case')['sorting'].idxmin()
    start_activities = candidates.loc[first_rows, ['case', 'activity']]
    is_allowed = start_activities['activity'].isin(start_steps)
    cases = set(start_activities.loc[is_allowed, 'case'])
    return workflow[workflow['case'].isin(cases)]

def filter_end_step(workflow, end_steps):
    """Keep the cases whose last real activity is one of ``end_steps``.

    Rows whose ``'activity'`` equals ``'END'`` are ignored when picking a
    case's last activity. The last activity is the row with the largest
    ``'sorting'`` value; on ties, the first occurrence wins (``idxmax``).

    Args:
        workflow: Event rows with ``'case'``, ``'activity'`` and ``'sorting'``
            columns. The index need not be unique.
        end_steps: Collection of activity names accepted as a final step.

    Returns:
        All rows of ``workflow`` (``'END'`` rows included) belonging to the
        retained cases.
    """
    # ``idxmax`` yields index *labels*, and ``.loc`` selects every row carrying
    # a label. With a non-unique index (e.g. after ``pd.concat``), that pulls
    # in rows from other cases and mis-attributes last activities. A fresh
    # positional index makes each label identify exactly one row.
    candidates = workflow[workflow['activity'] != 'END'].reset_index(drop=True)
    last_rows = candidates.groupby('case')['sorting'].idxmax()
    end_activities = candidates.loc[last_rows, ['case', 'activity']]
    is_allowed = end_activities['activity'].isin(end_steps)
    cases = set(end_activities.loc[is_allowed, 'case'])
    return workflow[workflow['case'].isin(cases)]

def filter_case_performance(workflow, min_performance, max_performance, case_performance_scale):
    """Keep the cases whose total duration falls within a selected range.

    A case's duration is its largest ``'epoch_end_timestamp'`` minus its
    smallest ``'epoch_timestamp'``.

    Args:
        workflow: Event rows with ``'case'``, ``'epoch_timestamp'`` and
            ``'epoch_end_timestamp'`` columns.
        min_performance: Key/position into ``case_performance_scale`` giving
            the lower duration bound.
        max_performance: Key/position into ``case_performance_scale`` giving
            the upper duration bound.
        case_performance_scale: Indexable lookup of duration thresholds
            (presumably in the same units as the epoch columns — confirm).

    Returns:
        All rows of ``workflow`` whose case duration lies in
        ``[scale[min_performance], scale[max_performance]]``, both inclusive.
    """
    bounds = workflow.groupby('case').agg(
        {'epoch_timestamp': 'min', 'epoch_end_timestamp': 'max'}
    )
    duration = bounds['epoch_end_timestamp'] - bounds['epoch_timestamp']
    shortest = case_performance_scale[min_performance]
    longest = case_performance_scale[max_performance]
    in_range = (duration >= shortest) & (duration <= longest)
    kept_cases = set(duration[in_range].index)
    return workflow[workflow['case'].isin(kept_cases)]

def filter_numeric_attributes(workflow, attribute_column, attribute_filter, attribute_scales):
    """Keep the cases having at least one row whose numeric attribute is in range.

    Args:
        workflow: Event rows with a ``'case'`` column and an
            ``attribute_column + '_attribute'`` column.
        attribute_column: Base name of the attribute (without the
            ``'_attribute'`` suffix).
        attribute_filter: Pair whose first two items are converted with
            ``int()`` and used as positions into the attribute's scale for
            the lower and upper bound.
        attribute_scales: Mapping from attribute name to an indexable scale of
            threshold values.

    Returns:
        All rows of ``workflow`` whose case has any row with a value in
        ``[low, high]`` (both inclusive).
    """
    values = workflow[attribute_column + '_attribute']
    scale = attribute_scales[attribute_column]
    low = scale[int(attribute_filter[0])]
    high = scale[int(attribute_filter[1])]
    # ``between`` is inclusive on both ends, i.e. (values >= low) & (values <= high).
    matching_cases = set(workflow.loc[values.between(low, high), 'case'])
    return workflow[workflow['case'].isin(matching_cases)]

def filter_categorical_attributes(workflow, attribute_column, attribute_filter):
    """Keep the cases having at least one row whose attribute value is selected.

    Args:
        workflow: Event rows with a ``'case'`` column and an
            ``attribute_column + '_attribute'`` column.
        attribute_column: Base name of the attribute (without the
            ``'_attribute'`` suffix).
        attribute_filter: Collection of accepted attribute values.

    Returns:
        All rows of ``workflow`` whose case has any row with an accepted value.
    """
    hits = workflow[attribute_column + '_attribute'].isin(attribute_filter)
    wanted_cases = set(workflow.loc[hits, 'case'])
    return workflow[workflow['case'].isin(wanted_cases)]

def filter_attributes(workflow, attribute_column, attribute_filter, attribute_scales):
    """Filter cases on an attribute, choosing numeric or categorical matching.

    The choice is made from the dtype of ``attribute_column + '_attribute'``:
    numeric dtypes (as judged by ``is_numeric_dtype``) use range filtering via
    ``filter_numeric_attributes``. Anything else uses value membership via
    ``filter_categorical_attributes``.

    Args:
        workflow: Event rows containing the ``'_attribute'``-suffixed column.
        attribute_column: Base name of the attribute.
        attribute_filter: Range positions (numeric) or accepted values
            (categorical), forwarded unchanged.
        attribute_scales: Per-attribute scales; only used for numeric columns.

    Returns:
        The filtered rows as returned by the selected helper.
    """
    column = workflow[attribute_column + '_attribute']
    if not is_numeric_dtype(column):
        return filter_categorical_attributes(workflow, attribute_column, attribute_filter)
    return filter_numeric_attributes(workflow, attribute_column, attribute_filter, attribute_scales)