import dataiku
import dash
from dash import dcc, html, Input, Output, State, ctx, dash_table
import dash_bootstrap_components as dbc
import base64
import pandas as pd
from datetime import datetime
from pandas.api.types import is_datetime64_any_dtype as is_datetime
import sklearn
from sklearn.cluster import KMeans
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.impute import SimpleImputer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import io
import ast
import pickle
from flask import request
from layout_components_style import  get_dropdown_style, get_subheader_style, get_div_message_style, get_button_style,get_table_styles

# Style options: shared style dicts for the layout components, provided by the
# project-local layout_components_style module.
subheader_style = get_subheader_style()
dropdown_style = get_dropdown_style()
div_message_style = get_div_message_style()
table_styles = get_table_styles()
button_style=get_button_style()
# Dataiku setup: API client and a handle on the project this webapp runs in.
client = dataiku.api_client()
project = client.get_default_project()

# Variables
# MODIFY these variables to your own folder, webapp and dataset IDs (check them on your Dataiku flow).
# The *_folder_var values look like managed-folder IDs rather than names — confirm in your flow.
images_folder_var = "ZwIBEZhA"
output_folder_var = "YIM9iOet"
versions_folder_var = "SyyDkxtP"
compare_seg_folder_var = "LvSKpnPJ"
output_model_folder_var = "TdqwWquP"
# Webapp IDs; presumably passed to restart_app_silently() to refresh the sibling apps — verify against callers.
new_seg_designer_webapp_id = "JbkofrY"
seg_manager_webapp_id = "JmCL7n0"
multi_seg_explorer_webapp_id = "KGDl9Ji"
# Dataset names (these are also the names excluded from the dataset picker).
metadata_dataset_var = "metadata_dataset"
rule_based_dataset_var = "rule_based_specs_dataset"

# Dataiku Files and Folders
# The metadata DataFrame is read once, when this module is imported (a snapshot).
metadata_dataset = dataiku.Dataset(metadata_dataset_var)
metadata_df = metadata_dataset.get_dataframe()


images_folder = dataiku.Folder(images_folder_var)
output_folder = dataiku.Folder(output_folder_var)
versions_folder = dataiku.Folder(versions_folder_var)
output_model_folder = dataiku.Folder(output_model_folder_var)
compare_seg_folder = dataiku.Folder(compare_seg_folder_var)

# Load the method-selection images (KMeans / rule-based tiles) from the Dataiku images folder.
image_names = ["MLS.png", "RBS.png"]


def _read_folder_bytes(folder, path):
    """Return the full byte content of *path* in a Dataiku managed folder, closing the download stream."""
    with folder.get_download_stream(path) as stream:
        return stream.read()


kmeans_image, rule_based_image = [_read_folder_bytes(images_folder, image) for image in image_names]

# --- Silent webapp restart helper (no UI output) ---
DEBUG_RESTART = True  # set False to silence prints

def restart_app_silently(webapp_id: str):
    """Restart the backend of another DSS webapp in this project.

    Any exception raised by the restart call is caught. The outcome
    (success or failure) is only printed to the backend logs, and only when
    ``DEBUG_RESTART`` is true; nothing is returned.
    """
    try:
        project.get_webapp(webapp_id).start_or_restart_backend()
    except Exception as exc:
        outcome = f"Webapp restart FAILED (webapp_id={webapp_id}): {exc}"
    else:
        outcome = f"Webapp restart SUCCESS (webapp_id={webapp_id})"
    if DEBUG_RESTART:
        print(outcome)
            
def create_empty_figure():
    """Return an empty Plotly figure dict with hidden axes and a blank annotation.

    NOTE(review): create_empty_figure is defined again further down in this
    module; that later definition replaces this one when the module is
    imported, so this version is effectively unused.
    """
    blank_note = {'text': " ", 'xref': "paper", 'yref': "paper", 'showarrow': False, 'font': {'size': 12}}
    layout = {
        'xaxis': {'visible': False},
        'yaxis': {'visible': False},
        'annotations': [blank_note],
    }
    return {'data': [], 'layout': layout}

def encode_image(image):
    """Return the base64 encoding of the raw bytes *image* as a text string."""
    encoded_bytes = base64.b64encode(image)
    return encoded_bytes.decode()

# Base64 text of both tile images, computed once at import time
# (presumably embedded in the layout as data URIs — confirm in the layout code).
kmeans_image_encoded, rule_based_image_encoded = map(encode_image, (kmeans_image, rule_based_image))

# Function to get datasets from the Dataiku project
def get_datasets():
    """Return dropdown options for every dataset in the current Dataiku project.

    The app's own bookkeeping datasets, the ones configured in
    ``metadata_dataset_var`` and ``rule_based_dataset_var``, are left out of
    the options. They are read from those variables (not hard-coded) so the
    exclusion keeps working when the configured names are changed.

    Returns:
        A list of ``{"label": name, "value": name}`` dicts.
    """
    excluded_names = {metadata_dataset_var, rule_based_dataset_var}
    return [
        {"label": ds["name"], "value": ds["name"]}
        for ds in project.list_datasets()
        if ds["name"] not in excluded_names
    ]

def display_file_content(session_name):
    """Load ``/<session_name>.csv`` from the session output folder for a DataTable.

    Args:
        session_name: Saved session name; the CSV is looked up at the folder
            root as ``/<session_name>.csv``.

    Returns:
        ``(columns, data)``: *columns* is a list of ``{"name", "id"}`` dicts
        and *data* is the rows as records. When the file does not exist,
        returns ``(None, html.P(...))`` carrying a user-facing message instead.
    """
    file_name = f"/{session_name}.csv"

    # Check if the file exists in the managed folder
    if file_name not in output_folder.list_paths_in_partition():
        print(f"Warning: {file_name} not found in output folder.")
        return None, html.P(f"No {session_name}.csv exists in the managed session folder.", style=div_message_style)

    with output_folder.get_download_stream(file_name) as stream:
        file_content = stream.read().decode('utf-8')

    file_df = pd.read_csv(io.StringIO(file_content))

    # Show 'cluster' as the second column when present. move_columns raises
    # ValueError for a missing column, so guard instead of crashing.
    if 'cluster' in file_df.columns:
        file_df = move_columns(file_df, 1, 'cluster')

    columns = [{"name": col, "id": col} for col in file_df.columns]
    return columns, file_df.to_dict('records')

def read_and_merge_files(folder, file_paths, sessions_list):
    merged_df = pd.DataFrame()
    for session_file, session_column in zip(file_paths, sessions_list):
        try:
            with folder.get_download_stream(session_file) as stream:
                df = pd.read_csv(stream, usecols=['account_id', 'cluster'])
                df.rename(columns={'cluster': session_column}, inplace=True)
                if merged_df.empty:
                    merged_df = df
                else:
                    merged_df = merged_df.merge(df, on='account_id', how='inner')
                    merged_df = merged_df.drop_duplicates()
        except Exception as e:
            print(f"Error reading or merging file {session_file}: {e}")
    return merged_df

# Construct file paths
def construct_file_paths(sessions_list):
    """Return the managed-folder CSV path ``/<name>.csv`` for each session name.

    Entries that are not strings (e.g. None or NaN) are skipped, so the result
    can be shorter than *sessions_list*.
    """
    paths = []
    for session in sessions_list:
        if not isinstance(session, str):
            continue
        paths.append("/" + session + ".csv")
    return paths

def format_large_number(num):
    """Abbreviate *num* for slider marks: thousands as "K", millions as "M".

    The quotient is truncated toward zero and shown as an integer
    (1_999 -> "1K", 1500.0 -> "1K"). The sign is kept for negative numbers
    (-2_500 -> "-2K"). Values whose magnitude is below 1,000 are returned as
    ``str(num)``.
    """
    magnitude = abs(num)
    sign = "-" if num < 0 else ""
    if magnitude >= 1_000_000:
        return f'{sign}{int(magnitude // 1_000_000)}M'
    if magnitude >= 1_000:
        return f'{sign}{int(magnitude // 1_000)}K'
    return str(num)
    
def generate_filter_input(dataset, column):
    """Build the Dash filter control that matches the dtype of ``dataset[column]``.

    - numeric (non-bool): ``RangeSlider`` spanning the column's values, with
      marks at min, mean and max; id type ``filter-slider``.
    - categorical/object: multi-select ``Dropdown`` of the distinct values;
      id type ``filter-dropdown``.
    - bool: multi-select ``Dropdown`` offering True/False; id type
      ``filter-dropdown``.
    - datetime: ``DatePickerRange`` preset to the column's date span; id type
      ``filter-date-picker``.

    Every id is a pattern-matching dict ``{"type": ..., "index": column}``.

    Returns:
        An ``html.Div`` with a label and the control, or None for any other dtype.
    """
    import math  # local: only this helper needs floor/ceil

    df = dataset
    col = column
    types = pd.api.types

    if types.is_numeric_dtype(df[col]) and not types.is_bool_dtype(df[col]):
        # floor/ceil rather than round: rounding inward would make the default
        # [min, max] selection drop the rows holding the extreme values.
        min_val = math.floor(df[col].min())
        max_val = math.ceil(df[col].max())
        mean_val = round(df[col].mean())

        # Adjust step size dynamically. A constant column has zero range, and
        # a zero step is not a usable slider step, so fall back to 1.
        range_diff = max_val - min_val
        if range_diff == 0:
            step = 1
        elif range_diff > 1000:
            step = range_diff / 1000
        else:
            step = range_diff / 100

        # Marks only for min, mean, and max
        marks = {
            int(min_val): format_large_number(min_val),
            int(mean_val): format_large_number(mean_val),
            int(max_val): format_large_number(max_val),
        }

        return html.Div([
            html.Label(f"Filter {col}", style=dropdown_style),
            dcc.RangeSlider(
                id={"type": "filter-slider", "index": col},
                min=min_val,
                max=max_val,
                step=step,
                marks=marks,
                value=[min_val, max_val],
                tooltip={"placement": "bottom", "always_visible": True},
                updatemode='drag',
                allowCross=False
            )
        ])

    # isinstance check replaces pd.api.types.is_categorical_dtype, which is
    # deprecated in recent pandas.
    elif isinstance(df[col].dtype, pd.CategoricalDtype) or types.is_object_dtype(df[col]):
        unique_values = df[col].unique()
        return html.Div([
            html.Label(f"Filter {col}", style=dropdown_style),
            dcc.Dropdown(
                id={"type": "filter-dropdown", "index": col},
                options=[{"label": val, "value": val} for val in unique_values],
                multi=True,
                placeholder=f"Select values for {col}",
                style=dropdown_style
            )
        ])

    elif types.is_bool_dtype(df[col]):
        return html.Div([
            html.Label(f"Filter {col}", style=dropdown_style),
            dcc.Dropdown(
                id={"type": "filter-dropdown", "index": col},
                options=[{"label": str(val), "value": val} for val in [True, False]],
                multi=True,
                placeholder=f"Select values for {col}",
                style=dropdown_style
            )
        ])

    elif types.is_datetime64_any_dtype(df[col]):
        return html.Div([
            html.Label(f"Filter {col}", style=dropdown_style),
            dcc.DatePickerRange(
                id={"type": "filter-date-picker", "index": col},
                start_date=df[col].min().date(),
                end_date=df[col].max().date(),
                start_date_placeholder_text='Start Date',
                end_date_placeholder_text='End Date'
            )
        ])

    return None

def filter_dataframe(df, selected_columns, filter_values, filter_ranges, start_dates, end_dates):
    """Filter *df* by the current values of the per-column filter controls.

    Columns are grouped by dtype. Each filter list is matched to its group by
    position (the i-th range goes to the i-th numeric column, and so on), so
    every list must follow the order of *selected_columns*:

    - numeric (non-bool) columns: ``filter_ranges[i] = [low, high]``, inclusive.
    - datetime columns: ``start_dates[i]`` / ``end_dates[i]``, applied only
      when both are set.
    - all other columns (object, category, bool): ``filter_values[i]`` is a
      list of allowed values; ``None`` (nothing selected yet) applies no filter.

    Returns:
        ``(filtered_df, date_columns, value_columns, numeric_columns)``.
    """
    types = pd.api.types
    bool_columns = {col for col in selected_columns if types.is_bool_dtype(df[col])}
    selected_columns_date = [col for col in selected_columns if types.is_datetime64_any_dtype(df[col])]
    selected_columns_num = [
        col for col in selected_columns
        if types.is_numeric_dtype(df[col]) and col not in bool_columns
    ]
    range_columns = set(selected_columns_num) | set(selected_columns_date)
    selected_columns_str = [col for col in selected_columns if col not in range_columns]

    filtered_df = df.copy()

    # Apply numeric filters cumulatively
    for i, col in enumerate(selected_columns_num):
        low, high = filter_ranges[i][0], filter_ranges[i][1]
        filtered_df = filtered_df[(filtered_df[col] >= low) & (filtered_df[col] <= high)]

    # Apply categorical filters cumulatively
    for i, col in enumerate(selected_columns_str):
        values = filter_values[i]
        if values is None:
            # A dropdown whose value was never set reports None, and
            # Series.isin(None) raises TypeError, so treat it as "no filter".
            continue
        filtered_df = filtered_df[filtered_df[col].isin(values)]

    # Apply date filters cumulatively
    for i, col in enumerate(selected_columns_date):
        start, end = start_dates[i], end_dates[i]
        if start is not None and end is not None:
            filtered_df = filtered_df[(filtered_df[col] >= start) & (filtered_df[col] <= end)]

    return filtered_df, selected_columns_date, selected_columns_str, selected_columns_num


# Function to apply filters
def apply_filters(original_df, filter_dict):
    """Apply a saved filter dictionary to a copy of *original_df*.

    The kind of each column in *filter_dict* is inferred from its dtype in
    *original_df*:

    - numeric (non-bool) and datetime columns: ``filter_dict[col]`` is an
      inclusive ``[low, high]`` range; a ``None`` bound is left open.
    - every other column: ``filter_dict[col]`` is a list of allowed values.

    A ``None`` value for a column applies no filter to that column.

    Returns:
        The filtered copy; *original_df* is not modified.
    """
    df = original_df.copy()
    types = pd.api.types
    selected_columns = list(filter_dict)
    filter_columns_date = [col for col in selected_columns if types.is_datetime64_any_dtype(df[col])]
    filter_columns_num = [
        col for col in selected_columns
        if types.is_numeric_dtype(df[col]) and not types.is_bool_dtype(df[col])
    ]
    range_columns = set(filter_columns_num) | set(filter_columns_date)
    filter_columns_str = [col for col in selected_columns if col not in range_columns]

    for col in filter_columns_str:
        allowed = filter_dict[col]
        if allowed is None:
            continue  # Series.isin(None) would raise TypeError
        df = df[df[col].isin(allowed)]

    for col in filter_columns_num + filter_columns_date:
        bounds = filter_dict[col]
        if bounds is None:
            continue
        low, high = bounds[0], bounds[1]
        if low is not None:
            df = df[df[col] >= low]
        if high is not None:
            df = df[df[col] <= high]
    return df

def build_filter_dictionary(selected_columns_date, selected_columns_str, selected_columns_num, start_dates, end_dates, filter_values, filter_ranges):
    """Combine the per-column filter settings into one ``{column: filter}`` dict.

    - value columns map to their entry of *filter_values* (by position).
    - numeric columns map to their ``[low, high]`` entry of *filter_ranges*.
    - date columns map to ``[start, end]``, taken pairwise from *start_dates*
      and *end_dates*. A date column whose start or end is unset is left out,
      mirroring filter_dataframe, which applies no filter in that case.

    Extra columns or values beyond the shorter of each pair of lists are ignored.
    """
    # Pair each date column with its own (start, end). Concatenating the start
    # and end lists would only work for a single date column.
    res_date = {
        col: [start, end]
        for col, start, end in zip(selected_columns_date, start_dates, end_dates)
        if start is not None and end is not None
    }
    res_cat = dict(zip(selected_columns_str, filter_values))
    res_num = dict(zip(selected_columns_num, filter_ranges))
    return {**res_cat, **res_num, **res_date}

def outer_join(dataset_1, dataset_2, key_column, cluster_name):
    """Outer-join *dataset_1* with the *key_column* and *cluster_name* columns of *dataset_2*.

    If both frames have a *cluster_name* column, the left one keeps its name
    and the right one gets the ``_old`` suffix.
    """
    # The original body referenced undefined names df1/df2 (NameError on every call).
    return pd.merge(
        dataset_1,
        dataset_2[[key_column, cluster_name]],
        on=key_column,
        how='outer',
        suffixes=('', '_old'),
    )

def move_columns(dataset, order, col_name):
    """Return *dataset* with column *col_name* moved to position *order*.

    Only the first occurrence of *col_name* is moved. Raises ValueError if the
    column is absent. The input frame is not modified.
    """
    reordered = list(dataset.columns)
    reordered.remove(col_name)
    reordered.insert(order, col_name)
    return dataset[reordered]

def _image_tile_style(border):
    """Return the CSS for a clickable, centred 140x140 method image tile with the given border."""
    return {
        'cursor': 'pointer',
        'width': '140px',
        'height': '140px',
        'border': border,
        'display': 'block',
        'margin-left': 'auto',
        'margin-right': 'auto'
    }


# CSS for selected and unselected states: the selected method gets a thick
# highlight border, the others a thin neutral one.
selected_style = _image_tile_style('5px solid CornflowerBlue')
unselected_style = _image_tile_style('1px solid #ccc')

def rule_based(dataset, weights, column_features, cluster_name, num_bins):
    """Segment rows by a weighted average of per-feature percentile ranks.

    Each feature in *column_features* is turned into a percentile rank. Ranks
    of features listed in *weights* are multiplied by their weight, and the
    row-wise mean ("Average") is cut into *num_bins* quantile bins labelled
    ``Segment_1`` ... ``Segment_<num_bins>``.

    Returns:
        ``(segmented_df, bin_edges, weights, message)``. On success *message*
        is "" and *segmented_df* is a copy of *dataset* with a new
        *cluster_name* column. If a feature is constant, or the data are too
        tied to form *num_bins* distinct quantile bins, the first three items
        are None and *message* is a user-facing warning.
    """
    # A constant column carries no ranking signal; reject it before doing any work.
    uniform_features = [col for col in column_features if dataset[col].nunique() == 1]
    if uniform_features:
        return None, None, None, f"Warning: The following features are uniform and cannot be used for segmentation: {', '.join(uniform_features)}"

    performance_data = dataset.copy()
    rank_df = performance_data[column_features].apply(lambda x: x.rank(pct=True))
    for feature, weight in weights.items():
        if feature in rank_df.columns:
            rank_df[feature] = rank_df[feature] * weight
    rank_df['Average'] = rank_df.mean(axis=1)

    labels = [f'Segment_{i+1}' for i in range(num_bins)]
    try:
        engagement_qcut, bins = pd.qcut(rank_df['Average'], num_bins,
                                        labels=labels,
                                        retbins=True,
                                        duplicates='drop')
    except ValueError as e:
        # duplicates='drop' removes tied quantile edges, leaving fewer bins than
        # labels; pandas then raises "Bin labels must be one fewer than the
        # number of bin edges". (The previous check searched for a
        # "Segment labels ... segment edges" text that pandas never emits.)
        if "must be one fewer than the number of" in str(e):
            return None, None, None, "Warning: Unable to create the requested number of segments due to insufficient variability in the data. Please reduce the number of segments or select more diverse features."
        raise

    performance_data[cluster_name] = engagement_qcut
    return performance_data, bins, weights, ""
            
def update_clusters(df, features):
    """Add ``updated_cluster``: keep the old segment for rows whose features did not change.

    For every name in *features*, ``df`` must hold both ``<feature>_original``
    and ``<feature>``. A row keeps its ``cluster`` value when every pair is
    equal, and takes ``new_cluster`` otherwise. NaN never equals NaN, so a
    NaN feature counts as changed. With no features, every row keeps ``cluster``.

    *df* is modified in place and also returned.
    """
    unchanged = pd.Series(True, index=df.index)
    for feature in features:
        unchanged = unchanged & (df[f"{feature}_original"] == df[feature])

    # Vectorised replacement for a row-wise apply. astype(object) lets the
    # result mix values from both columns even if 'cluster' is categorical.
    df['updated_cluster'] = df['cluster'].astype(object).where(unchanged, df['new_cluster'])

    return df
  
def apply_existing_bounds(filename, column_features, output_file, dataset, remap_dict):
    """Re-segment *dataset* with the rule-based weights and bin bounds saved for a session.

    Rows whose feature values match the saved session output keep their
    previous ``cluster`` (see update_clusters). Every other row gets the
    segment implied by the saved bounds, renamed through *remap_dict*.

    Args:
        filename: Session name, looked up in the ``session_name`` column of
            the rule-based specs dataset.
        column_features: Feature columns used for ranking.
        output_file: Previous session output; must contain ``account_id``,
            ``cluster`` and the feature columns.
        dataset: Data to segment; must contain ``account_id`` and the features.
        remap_dict: Maps ``Segment_<n>`` labels to display names. Labels
            missing from it become NaN.

    Returns:
        ``(df_results, weights, "")`` on success, otherwise
        ``(None, None, message)`` with a user-facing error or warning.
    """
    # Load pre-existing segmentation specs. This local is deliberately not
    # called ``rule_based`` so it does not shadow the module-level function.
    specs_df = dataiku.Dataset(rule_based_dataset_var).get_dataframe()
    rule_based_info = specs_df[specs_df['session_name'] == filename]

    if rule_based_info.empty:
        return None, None, f"Error: No saved rule-based parameters found for session: {filename}"

    # Parse weights and bounds safely
    try:
        weights = ast.literal_eval(rule_based_info['weights'].iloc[0])
        # Copy into a list: the edge adjustments below assign into it, which
        # would fail if the bounds had been saved as a tuple.
        bounds = list(ast.literal_eval(rule_based_info['bounds'].iloc[0]))
    except Exception as e:
        return None, None, f"Error parsing weights or bounds: {str(e)}"

    df_results = dataset.copy()
    rank_df = df_results[column_features].apply(lambda x: x.rank(pct=True))

    # Handle uniform features
    # NOTE(review): uniform features are dropped from column_features here but
    # are still part of rank_df, so they still contribute to 'Average' — confirm intended.
    uniform_features = [col for col in column_features if df_results[col].nunique() == 1]
    if uniform_features:
        column_features = [col for col in column_features if col not in uniform_features]
        if not column_features:
            return None, None, "Warning: All selected features are uniform and cannot be used for segmentation."

    # Apply weights
    missing_features = [feature for feature in weights if feature not in rank_df.columns]
    if missing_features:
        return None, None, f"Error: The following features are missing from the dataset: {', '.join(missing_features)}"

    for feature, weight in weights.items():
        rank_df[feature] *= weight

    rank_df['Average'] = rank_df.mean(axis=1)

    # Widen the outermost saved bounds so every current score falls inside a bin.
    avg_min = rank_df['Average'].min()
    avg_max = rank_df['Average'].max()

    if bounds[0] > avg_min:
        new_lower = avg_min - 1e-5
        print(f"Adjusting lower bound: {bounds[0]} -> {new_lower}")
        bounds[0] = new_lower

    if bounds[-1] < avg_max:
        new_upper = avg_max + 1e-5
        print(f"Adjusting upper bound: {bounds[-1]} -> {new_upper}")
        bounds[-1] = new_upper

    try:
        labels = [f'Segment_{i+1}' for i in range(len(bounds) - 1)]

        # Cut using adjusted bounds
        engagement_cut = pd.cut(rank_df['Average'], bins=bounds,
                                labels=labels, right=True, include_lowest=True)

        df_results['new_cluster'] = engagement_cut.astype(str).map(remap_dict)

        # Columns needed from the previous output; dict.fromkeys drops duplicates, keeping order.
        merged_columns = list(dict.fromkeys(['account_id', 'cluster'] + list(column_features)))

        # Right join keeps one row per row of df_results; the previous values
        # get the '_original' suffix, which update_clusters compares against.
        df_merged = pd.merge(
            output_file[merged_columns],
            df_results[['account_id', 'new_cluster'] + column_features],
            on='account_id',
            how='right',
            suffixes=('_original', '')
        )

        # Recompute updated clusters
        df_updated_results = update_clusters(df_merged, column_features)
        df_updated_results.reset_index(drop=True, inplace=True)
        df_results.reset_index(drop=True, inplace=True)

        # Finalize cluster assignment
        df_results['cluster'] = df_updated_results['updated_cluster']
        df_results.drop(columns=['new_cluster'], inplace=True)

        return df_results, weights, ""

    except ValueError as e:
        # pandas reports label/edge mismatches as "Bin labels must be one
        # fewer than the number of bin edges"; the old check searched for a
        # "Segment labels ... segment edges" text that pandas never emits.
        if "must be one fewer than the number of" in str(e):
            return None, None, "Warning: Unable to apply pre-existing segments due to insufficient variability in the data."
        raise
    
def kmeans_clustering(dataset, selected_features, num_clusters, cluster_name):
    """Cluster *dataset* with a preprocessing + KMeans pipeline.

    Features with a single distinct value are dropped first. Numeric features
    (any numpy number dtype, e.g. int32/float32 as well as int64/float64) are
    mean-imputed and standard-scaled. Object/category features are
    mode-imputed and one-hot encoded. Columns of other dtypes (e.g. bool,
    datetime) go to no transformer, so the ColumnTransformer drops them.

    Numeric detection uses ``'number'`` so the preprocessor's column split
    matches the one feature_importance_rfclassifier derives from the same data.

    Note: on success the labels are written into *dataset* itself (column
    *cluster_name*), and that same object is returned.

    Returns:
        ``(dataset, pipeline, preprocessor, None)`` on success, or
        ``(None, None, None, message)`` with a user-facing error.
    """
    X = dataset[selected_features]
    # Check if X is empty
    if X.empty:
        return None, None, None, "Error: The input dataset has no samples after filtering/processing."

    # Check for uniform features and remove them
    uniform_features = [feature for feature in selected_features if dataset[feature].nunique() == 1]
    if uniform_features:
        X = X.drop(columns=uniform_features)
        print(f"Warning: The following uniform features were removed: {uniform_features}")
    # Recheck if X is empty after removing uniform features
    if X.empty:
        return None, None, None, "Error: The input dataset has no usable features after removing uniform features."

    try:
        numerical_features = X.select_dtypes(include=['number']).columns.tolist()
        categorical_features = X.select_dtypes(include=['object', 'category']).columns.tolist()

        transformers = []

        if numerical_features:
            numerical_transformer = Pipeline(steps=[
                ('imputer', SimpleImputer(strategy='mean')),
                ('scaler', StandardScaler())
            ])
            transformers.append(('num', numerical_transformer, numerical_features))

        if categorical_features:
            categorical_transformer = Pipeline(steps=[
                ('imputer', SimpleImputer(strategy='most_frequent')),
                ('encoder', OneHotEncoder(handle_unknown='ignore'))
            ])
            transformers.append(('cat', categorical_transformer, categorical_features))
        # If both transformers are empty, raise an error
        if not transformers:
            return None, None, None, "Error: No valid features to process for clustering."

        preprocessor = ColumnTransformer(transformers=transformers)

        pipeline = Pipeline(steps=[
            ('preprocessor', preprocessor),
            ('kmeans', KMeans(n_clusters=num_clusters, random_state=0))
        ])

        pipeline.fit(X)
        dataset[cluster_name] = pipeline.named_steps['kmeans'].labels_
        return dataset, pipeline, preprocessor, None

    except ValueError as e:
        if "n_samples=" in str(e):
            return None, None, None, f"Error: The number of clusters ({num_clusters}) exceeds the number of samples ({X.shape[0]}). Please reduce the number of clusters."
        elif "Number of distinct clusters (X) found smaller than n_clusters" in str(e):
            return None, None, None, "Error: The data does not have enough distinct clusters to match the requested number of clusters. Try reducing the number of clusters."
        else:
            return None, None, None, f"Unexpected error during KMeans clustering: {str(e)}"

    except MemoryError:
        return None, None, None, "Error: The dataset is too large to fit into memory for KMeans clustering. Consider using a smaller dataset or applying more aggressive filtering."

    except Exception as e:
        return None, None, None, f"An unexpected error occurred during KMeans clustering: {str(e)}"

# Function to apply existing model
def apply_existing_model(cluster_column, df, model_folder_id, session, remap_dict):
    """Score *df* with the pipeline pickled for *session* and relabel the predicted clusters.

    The pipeline is read from ``<session>.pkl`` in the managed folder
    *model_folder_id*. Its integer predictions ``k`` become ``Segment_<k+1>``,
    which are then mapped through *remap_dict*; labels missing from
    *remap_dict* become NaN.

    Returns:
        ``(df_results, preprocessor)``: a copy of *df* with *cluster_column*
        set, and the pipeline's ``'preprocessor'`` step.
    """
    df_results = df.copy()
    pipeline = read_pickle_from_managed_folder(model_folder_id, session + '.pkl')
    df_results[cluster_column] = pipeline.predict(df_results)

    # Predictions are 0-based; displayed segment names are 1-based.
    segment_labels = 'Segment_' + (df_results[cluster_column] + 1).astype(str)
    df_results[cluster_column] = segment_labels.map(remap_dict)
    print("KMeans model successfully applied!")
    return df_results, pipeline.named_steps['preprocessor']

# Function to read pickle file from Dataiku folder
def read_pickle_from_managed_folder(managed_folder_id, pickle_name):
    """Download *pickle_name* from the Dataiku managed folder *managed_folder_id* and unpickle it.

    Security: ``pickle.loads`` can execute arbitrary code. Only use this on
    files this app itself wrote into a trusted managed folder.
    """
    folder = dataiku.Folder(managed_folder_id)
    with folder.get_download_stream(pickle_name) as stream:
        restored = pickle.loads(stream.read())
    print(f"'{pickle_name}' successfully read!")
    return restored

def process_table_data_with_remap(table_data, remap_data, cluster_column_name='cluster'):
    """Build a DataFrame from table records and optionally rename its cluster labels.

    Args:
        table_data: Row records (e.g. a DataTable's ``data``).
        remap_data: Optional dict; when truthy and containing ``'mapping_dict'``,
            that dict renames cluster labels (matched on their string form).
            Labels without a mapping keep their original value.
        cluster_column_name: Name of the cluster column.

    Returns:
        The DataFrame, or None when *table_data* is empty or lacks the cluster column.
    """
    if not table_data:
        return None

    frame = pd.DataFrame(table_data)
    if cluster_column_name not in frame.columns:
        return None

    has_mapping = bool(remap_data) and 'mapping_dict' in remap_data
    if has_mapping:
        original_labels = frame[cluster_column_name]
        renamed = original_labels.astype(str).map(remap_data['mapping_dict'])
        frame[cluster_column_name] = renamed.fillna(original_labels)

    return frame

def feature_importance_rfclassifier(dataset, features, preprocessor, target_column):
    """Estimate how strongly each feature explains the labels in *target_column*.

    Fits a RandomForestClassifier (100 trees, random_state=42) behind
    *preprocessor* on a 70% train split (train_test_split, random_state=42)
    and returns the forest's feature importances, labelled with feature names.

    Args:
        dataset: DataFrame holding *features* and *target_column*.
        features: Columns used to predict the target.
        preprocessor: Transformer placed before the classifier. The naming
            logic below assumes a ColumnTransformer whose categorical
            sub-pipeline has an 'encoder' step and whose numeric transformer
            is listed first — presumably the one built by kmeans_clustering;
            TODO confirm with callers.
        target_column: Label column (presumably the segment/cluster column).

    Returns:
        DataFrame with columns 'Feature' and 'Importance', sorted by
        descending importance.
        NOTE(review): when no usable features remain, a 4-tuple
        ``(None, None, None, message)`` is returned instead; callers must
        handle both shapes.
    """
    # Select the feature matrix and the target
    df = dataset
    X = df[features]
    y = df[target_column]
    
    # Check for uniform features and remove them
    # NOTE(review): if *preprocessor* was configured with a dropped column, fitting will presumably fail on the missing column — confirm.
    uniform_features = [feature for feature in features if df[feature].nunique() == 1]
    if uniform_features:
        X = X.drop(columns=uniform_features)
        print(f"Warning: The following uniform features were removed: {uniform_features}")
    
    # Recheck if X is empty after removing uniform features
    if X.empty:
        return None, None, None, "Error: The input dataset has no usable features after removing uniform features."

    # Dtype split used only to rebuild feature names below; it must agree with
    # the column lists inside *preprocessor* for the name count to match.
    numerical_features = X.select_dtypes(include=['number']).columns.tolist()
    categorical_features = X.select_dtypes(include=['object', 'category']).columns.tolist()
    
    # Create the pipeline with preprocessing and classifier
    # NOTE(review): fitting this pipeline presumably refits the caller's *preprocessor* object in place (sklearn does not clone steps by default) — confirm this side effect is acceptable.
    pipeline = Pipeline(steps=[
        ('preprocessor', preprocessor),
        ('classifier', RandomForestClassifier(n_estimators=100, random_state=42))
    ])
    
    # Train-test split (only the 70% training part is used; X_test/y_test are unused)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    
    # Train the model
    pipeline.fit(X_train, y_train)
    
    # Get feature importances from the trained model
    rf_model = pipeline.named_steps['classifier']
    importance_values = rf_model.feature_importances_
    
    # Feature names are assembled in the order numeric, then encoded categorical,
    # which assumes the ColumnTransformer emits its outputs in that order.
    all_feature_names = []
    
    # Add numerical feature names if they exist
    if numerical_features:
        all_feature_names += numerical_features
    
    # Add encoded categorical feature names if they exist
    if categorical_features:
        # Access the ColumnTransformer inside the preprocessor
        column_transformer = pipeline.named_steps['preprocessor']
        
        # Find the transformer that handles categorical features
        for name, transformer, cols in column_transformer.transformers_:
            if any(col in categorical_features for col in cols):
                # hasattr is checked on the step itself (e.g. a Pipeline), but the
                # names come from its 'encoder' step (assumed to be a OneHotEncoder)
                if hasattr(transformer, 'get_feature_names_out'):
                    encoded_feature_names = transformer.named_steps['encoder'].get_feature_names_out(cols)
                    all_feature_names += list(encoded_feature_names)
                else:
                    all_feature_names += cols
    
    # Create a DataFrame to store feature importance; this would raise ValueError
    # if the number of names differs from the number of importances.
    importance_df = pd.DataFrame({
        'Feature': all_feature_names,
        'Importance': importance_values
    }).sort_values(by='Importance', ascending=False)
    
    return importance_df

def delete_confirmation_modal():
    """Build the modal that asks the user to confirm deleting a session version.

    The modal starts closed (``is_open=False``). It exposes the component ids
    ``delete-confirmation-modal``, ``cancel-delete`` and ``confirm-delete``;
    both buttons start with ``n_clicks=0``.
    """
    overlay = {
        'zIndex': '1001',
        'position': 'fixed',
        'top': 0,
        'left': 0,
        'width': '100%',
        'height': '100%',
        'backgroundColor': 'rgba(0, 0, 0, 0.4)',
        'display': 'flex',
        'alignItems': 'center',
        'justifyContent': 'center'
    }
    dialog_box = {
        'backgroundColor': 'white',
        'padding': '20px',
        'borderRadius': '6px',
        'boxShadow': '0 4px 10px rgba(0,0,0,0.15)',
        'width': '320px',
        'fontFamily': 'Helvetica Neue, sans-serif',
        'fontSize': '13px'
    }
    heading = {'fontWeight': '600', 'fontSize': '16px', 'marginBottom': '10px'}
    message = {'marginBottom': '20px', 'lineHeight': '1.4', 'fontSize': '13px'}
    actions_row = {'display': 'flex', 'justifyContent': 'flex-end', 'gap': '10px'}

    # Secondary (outlined) cancel button and primary (filled) delete button.
    cancel_button = html.Button(
        "Cancel", id="cancel-delete", n_clicks=0,
        style={**get_button_style(), 'background': 'white', 'color': '#2D86FB'},
    )
    confirm_button = html.Button(
        "Delete", id="confirm-delete", n_clicks=0,
        style={**get_button_style(), 'background': '#2D86FB', 'color': 'white'},
    )

    dialog = html.Div(
        [
            html.Div("Confirm Delete", style=heading),
            html.Div("This will permanently delete this session version. This action cannot be undone.", style=message),
            html.Div([cancel_button, confirm_button], style=actions_row),
        ],
        style=dialog_box,
    )

    return dbc.Modal(
        [dialog],
        id="delete-confirmation-modal",
        is_open=False,
        style=overlay,
        backdrop=True,
        centered=True,
    )

def create_empty_figure():
    """Return an empty Plotly figure dict with hidden axes and a "No data yet." note.

    This is the last definition of create_empty_figure in the module, so it
    is the one callers get at runtime.
    """
    placeholder = {'text': "No data yet.", 'xref': "paper", 'yref': "paper", 'showarrow': False, 'font': {'size': 12}}
    layout = {
        'xaxis': {'visible': False},
        'yaxis': {'visible': False},
        'annotations': [placeholder],
    }
    return {'data': [], 'layout': layout}


def load_segmentation_parameters(metadata_sess, index):
    """Read and parse one session's saved segmentation settings.

    Args:
        metadata_sess: Metadata DataFrame. The ``selected_features``,
            ``specifications`` and ``cluster_name_remap`` columns hold Python
            literals stored as strings and are parsed with ``ast.literal_eval``.
        index: Positional row index (``iloc``) of the session to read.

    Returns:
        ``(selected_features, method, specifications, cluster_name_remap,
        original_dataset)``.
    """
    def cell(column):
        return metadata_sess[column].iloc[index]

    selected_feat = ast.literal_eval(cell('selected_features'))
    segmentation_meth = cell('method')
    specs = ast.literal_eval(cell('specifications'))
    cluster_remap = ast.literal_eval(cell('cluster_name_remap'))
    orig_data_name = cell('original_dataset')
    return selected_feat, segmentation_meth, specs, cluster_remap, orig_data_name

def generate_feature_options(df_results, cluster_id_columns):
    """Return dropdown options for the plottable feature columns of *df_results*.

    Numeric columns come first, then object/category columns, each in
    DataFrame order. Bool and other dtypes are not offered, and any column
    listed in *cluster_id_columns* is skipped.
    """
    numeric = df_results.select_dtypes(include=['number']).columns.tolist()
    categorical = df_results.select_dtypes(include=['object', 'category']).columns.tolist()
    options = []
    for feature in numeric + categorical:
        if feature in cluster_id_columns:
            continue
        options.append({"label": feature, "value": feature})
    return options
 
        
# Segment color palette shared by all visualizations, keyed 0-9; the plotting
# helpers cycle through it with ``index % 10``.
color_palette = dict(enumerate([
    'burlywood',
    'CornflowerBlue',
    'lightcoral',
    'DarkSeaGreen',
    'orange',
    'mediumpurple',
    'CadetBlue',
    'lightgoldenrodyellow',
    'DarkSalmon',
    'plum',
]))

def _segment_boxplot_figure(df_result, selected_feature, cluster_unique):
    """Box plot (all points, mean marker) of a numeric feature per cluster, colored from color_palette."""
    fig = px.box(df_result, x="cluster", y=selected_feature, points="all", color="cluster")
    fig.update_traces(boxmean=True)
    fig.update_layout(
        font=dict(family="Helvetica Neue", size=12),
        margin=dict(l=10, r=10, t=30, b=10),
        paper_bgcolor='white',
        plot_bgcolor='white',
        xaxis_title='Segment',
        yaxis_title=selected_feature,
        showlegend=False
    )
    # px names each trace after its cluster value; recolor by order of first appearance.
    for position, cluster in enumerate(cluster_unique):
        fig.update_traces(
            selector=dict(name=f'{cluster}'),
            marker_color=color_palette[position % 10]
        )
    return fig


def _segment_share_bar_figure(df_result, selected_feature):
    """Stacked bar chart of each category's percentage share within every cluster."""
    counts = df_result.groupby(['cluster', selected_feature]).size().reset_index(name='count')
    totals = df_result.groupby('cluster').size().reset_index(name='total')
    shares = counts.merge(totals, on='cluster')
    shares['percentage'] = (shares['count'] / shares['total']) * 100

    fig = px.bar(
        shares,
        x='cluster',
        y='percentage',
        color=selected_feature,
        labels={'percentage': 'Percentage Distribution', 'cluster': 'Segment'},
        hover_data={'percentage': ':.2f', 'count': ':d'},
    )
    fig.update_layout(
        font=dict(family="Helvetica Neue", size=12),
        barmode='stack',
        margin=dict(l=10, r=10, t=30, b=10),
        paper_bgcolor='white',
        plot_bgcolor='white',
        xaxis=dict(
            showline=True, linewidth=2, linecolor='black', showgrid=True,
            title_standoff=40  # extra space between the x-axis title and the plot
        ),
        yaxis=dict(showline=True, linewidth=2, linecolor='black', showgrid=True),
        showlegend=True,
        legend=dict(
            orientation="h",  # horizontal legend below the chart
            yanchor="bottom",
            y=-0.35,
            xanchor="center",
            x=0.5
        )
    )
    return fig


def update_mixed_graph(selected_feature, table_data):
    """Plot *selected_feature* per cluster from table records.

    Numeric features get per-cluster box plots; any other dtype gets a stacked
    percentage bar chart. Returns ``[]`` when no feature or no data is
    given, otherwise an ``html.Div`` wrapping the ``dcc.Graph``.
    """
    if not (selected_feature and table_data):
        return []

    df_result = pd.DataFrame(table_data)
    cluster_unique = df_result['cluster'].unique()

    if pd.api.types.is_numeric_dtype(df_result[selected_feature]):
        fig = _segment_boxplot_figure(df_result, selected_feature, cluster_unique)
    else:
        fig = _segment_share_bar_figure(df_result, selected_feature)
    return html.Div([dcc.Graph(figure=fig)])


def histogram_figure(dataset, x_column, y_column, labels):
    """Return a horizontal ``px.bar`` figure of ``x_column`` against ``y_column``.

    Args:
        dataset: Data handed directly to ``px.bar``.
        x_column: Column giving the bar lengths (x axis).
        y_column: Column giving the bar categories (y axis).
        labels: ``px.bar`` label overrides.

    Returns:
        The plotly figure, with the y axis reversed (so categories read
        top-to-bottom) and white plot/paper backgrounds.
    """
    figure = px.bar(dataset, x=x_column, y=y_column, orientation='h', labels=labels)

    # NOTE(review): the explicit axis titles below replace any axis titles that
    # ``labels`` produced in px.bar; ``labels`` still affects hover text.
    white = 'white'
    figure.update_layout(
        xaxis_title=x_column,
        yaxis_title=y_column,
        yaxis={"autorange": "reversed"},
        plot_bgcolor=white,
        paper_bgcolor=white,
    )
    return figure


def update_cluster_pie_chart(table_data):
    """Build a pie chart of how many rows fall into each cluster.

    Args:
        table_data: Row records with a ``'cluster'`` column (presumably the
            DataTable ``data`` list of dicts -- confirm against the caller).

    Returns:
        A ``go.Figure`` pie chart, or ``{}`` when ``table_data`` is empty/None.
    """
    if not table_data:
        return {}

    df_result = pd.DataFrame(table_data)
    cluster_counts = df_result['cluster'].value_counts().reset_index()
    cluster_counts.columns = ['cluster', 'count']

    # Colors are assigned by order of first appearance in the data, not by
    # slice size. Cycle through the whole palette instead of assuming it has
    # exactly 10 entries (a shorter palette would raise IndexError).
    palette_size = len(color_palette)
    cluster_color_map = {
        cluster: color_palette[i % palette_size]
        for i, cluster in enumerate(df_result['cluster'].unique())
    }

    # value_counts orders slices by size, so look each slice's color up by name.
    cluster_colors = [cluster_color_map[cluster] for cluster in cluster_counts['cluster']]

    fig = go.Figure(data=[go.Pie(
        labels=cluster_counts['cluster'],
        values=cluster_counts['count'],
        marker=dict(colors=cluster_colors),
        textinfo='percent',
        insidetextorientation='radial'
    )])

    # Place the legend horizontally, centered below the pie chart.
    fig.update_layout(
        uniformtext_minsize=8,
        uniformtext_mode='hide',
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="top",
            y=-0.1,  # Adjust this value to move the legend closer or further from the chart
            xanchor="center",
            x=0.5
        )
    )

    return fig


def update_numerical_feature_multiselect(table_data):
    """Build dropdown options listing the numeric columns of ``table_data``.

    Returns ``[]`` when ``table_data`` is empty or None. Otherwise it returns a
    leading ``"Select All"`` option (value ``"select_all"``), followed by one
    ``{'label': col, 'value': col}`` entry per numeric column, in column order.
    """
    if not table_data:
        return []

    frame = pd.DataFrame(table_data)
    options = [{"label": "Select All", "value": "select_all"}]
    for column in frame.select_dtypes(include='number').columns:
        # Second numeric check on each candidate column, kept for parity with
        # pandas' own notion of "numeric".
        if pd.api.types.is_numeric_dtype(frame[column]):
            options.append({'label': column, 'value': column})
    return options

def create_average_heatmap(df, selected_features, cluster_column_name):
    """
    Render a heatmap of per-cluster averages of the selected numerical features.

    Each numerical feature is min-max normalized to [0, 1] on its own. The
    normalized values are averaged by cluster, and each cell's color is that
    normalized average. The hover text shows the actual (un-normalized)
    per-cluster average instead.

    Args:
        df: DataFrame holding the feature columns and ``cluster_column_name``.
        selected_features: Column names to plot. Non-numeric columns and the
            cluster column itself are skipped.
        cluster_column_name: Column holding each row's segment label.

    Returns:
        html.Div wrapping a dcc.Graph with the heatmap.
    """
    # Keep numeric columns only, and never average the grouping key itself.
    numerical_features = [
        feature
        for feature in selected_features
        if feature != cluster_column_name and pd.api.types.is_numeric_dtype(df[feature])
    ]

    def _normalize(column):
        column_min = column.min()
        value_range = column.max() - column_min
        # A constant column has zero range. Dividing would yield 0/0 = NaN for
        # every row and leave a blank heatmap row, so map it to 0 instead.
        if pd.notna(value_range) and value_range == 0:
            return column - column_min
        return (column - column_min) / value_range

    normalized_df = df[numerical_features].apply(_normalize)
    normalized_df[cluster_column_name] = df[cluster_column_name]

    # Color values: mean of the normalized features per cluster.
    average_df = normalized_df.groupby(cluster_column_name)[numerical_features].mean()

    # Hover values: mean of the raw features per cluster (same row/column order).
    actual_averages = df.groupby(cluster_column_name)[numerical_features].mean()

    # Transpose so features are rows (y) and segments are columns (x).
    fig = px.imshow(
        average_df.T,
        aspect="auto",
        color_continuous_scale=px.colors.sequential.Viridis[::-1],
        labels=dict(x="Segment", y="Feature", color=" ")
    )

    fig.update_traces(
        hovertemplate="Segment: %{x}<br>Feature: %{y}<br>Average Value: %{customdata:.2f}",
        customdata=actual_averages.T.values
    )

    fig.update_xaxes(showticklabels=True)

    # The color scale spans the normalized [0, 1] range; label its ends.
    fig.update_coloraxes(colorbar=dict(
        tickvals=[0, 1],
        ticktext=["Low", "High"]
    ))

    return html.Div([dcc.Graph(figure=fig)])

def create_sankey(dataset, column_1, column_2):
    """Build a Sankey diagram of how rows move between two segmentations.

    Each row of ``dataset`` contributes one unit of flow from its ``column_1``
    value (left nodes) to its ``column_2`` value (right nodes). Rows missing a
    value in either column are ignored.

    Args:
        dataset: DataFrame containing both columns.
        column_1: Source column (e.g. the old segmentation's clusters).
        column_2: Target column (e.g. the new segmentation's clusters).

    Returns:
        A ``go.Figure``, or ``create_empty_figure()`` when ``dataset`` is empty.
    """
    if dataset.empty:
        return create_empty_figure()

    df = dataset.copy()

    # Count transitions from column_1 (old_version) to column_2 (new_segmentation).
    # groupby drops NaN keys by default, matching the dropna() used for the nodes.
    transition_counts = df.groupby([column_1, column_2]).size().reset_index(name='count')
    # Categorical columns can yield zero-count combinations of categories that
    # never occur in the data. Those have no node below, so drop them.
    transition_counts = transition_counts[transition_counts['count'] > 0]

    # Sort clusters for consistent visual order.
    old_clusters = sorted(df[column_1].dropna().unique())
    new_clusters = sorted(df[column_2].dropna().unique())

    # Left nodes come first (indices 0..len(old)-1), then the right nodes.
    cluster_indices_old = {cluster: i for i, cluster in enumerate(old_clusters)}
    cluster_indices_new = {cluster: i + len(old_clusters) for i, cluster in enumerate(new_clusters)}

    # Left nodes cycle through the palette. A right node reuses the color of the
    # left node with the same label (or the first color if there is none), so a
    # cluster present on both sides gets the same color.
    node_colors_old = [color_palette[i % len(color_palette)] for i in range(len(old_clusters))]
    node_colors_new = [color_palette[cluster_indices_old.get(cluster, 0) % len(color_palette)] for cluster in new_clusters]
    node_colors = node_colors_old + node_colors_new

    sources = transition_counts[column_1].tolist()
    targets = transition_counts[column_2].tolist()
    counts = transition_counts['count'].tolist()

    links = {
        'source': [cluster_indices_old[src] for src in sources],
        'target': [cluster_indices_new[tgt] for tgt in targets],
        'value': counts,
        'customdata': [f"{src} → {tgt}" for src, tgt in zip(sources, targets)],
        'label': [f"{src} → {tgt} ({count})" for src, tgt, count in zip(sources, targets, counts)],
    }

    # Each link takes the color of its source node.
    link_colors = [node_colors[source] for source in links['source']]

    def safe_label_truncate(label, max_len=20):
        """Return ``label`` as text, cut to ``max_len`` characters plus '...'."""
        if pd.isna(label):
            return 'Unknown'
        # Cluster labels may be numeric; len() and slicing need a string.
        text = str(label)
        return text[:max_len] + '...' if len(text) > max_len else text

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=30,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=[safe_label_truncate(cluster) for cluster in old_clusters + new_clusters],
            color=node_colors
        ),
        link=dict(
            source=links['source'],
            target=links['target'],
            value=links['value'],
            customdata=links['customdata'],
            label=links['label'],  # Optional visual labels for links (might not render on all plotly versions)
            hovertemplate='Transition: %{customdata}<br>Count: %{value}<extra></extra>',
            color=link_colors
        )
    )])

    # Column names, prettified, used as headers above each side.
    source_label = column_1.replace('_', ' ').title()
    target_label = column_2.replace('_', ' ').title()

    fig.update_layout(
        title_text="",
        font=dict(family="Helvetica Neue", size=12, color="black"),
        margin=dict(l=50, r=50, t=50, b=50),
        annotations=[
            dict(
                x=-0.05, y=1.05, xref="paper", yref="paper",
                text=source_label, showarrow=False,
                font=dict(size=14, color="black"),
                align="left"
            ),
            dict(
                x=1.05, y=1.05, xref="paper", yref="paper",
                text=target_label, showarrow=False,
                font=dict(size=14, color="black"),
                align="right"
            )
        ]
    )

    return fig


def create_metadata_record(session_name, segmentation_version, status, original_dataset, segmentation_method, description, specifications, selected_features, cluster_remap_dict):
    """Assemble one metadata row describing a saved segmentation.

    The username comes from ``client.get_auth_info_from_browser_headers``
    applied to the current Flask request's headers, so this must run inside a
    request context (e.g. a Dash callback). The timestamp is the server's local
    time, formatted ``dd/mm/YYYY HH:MM:SS``.

    Returns:
        dict whose keys are: session_name, segmentation_version, status,
        original_dataset, username, datetime, method, description,
        specifications, selected_features, cluster_name_remap.
    """
    caller = client.get_auth_info_from_browser_headers(dict(request.headers))
    timestamp = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
    return dict(
        session_name=session_name,
        segmentation_version=segmentation_version,
        status=status,
        original_dataset=original_dataset,
        username=caller["authIdentifier"],
        datetime=timestamp,
        method=segmentation_method,
        description=description,
        specifications=specifications,
        selected_features=selected_features,
        cluster_name_remap=cluster_remap_dict,
    )

def save_data_to_dataiku(session_name, remapped_data, metadata_record, method, kmeans_pipeline_serialized=None, rb_bounds=None, rb_weights=None):
    """Persist a new segmentation session to the Dataiku project.

    Steps, in order. Each failing step returns an error message and stops,
    except the version copy, which only prints a warning:
      1. write ``{session_name}.csv`` to ``output_folder``;
      2. write ``{session_name}_0.csv`` to ``versions_folder`` (best effort);
      3. append ``metadata_record`` to the metadata dataset;
      4. method-specific artifacts. For ``"kmeans"``, the decoded pipeline
         goes to ``{session_name}.pkl`` in ``output_model_folder``. For
         ``"rule_based"``, the bounds/weights are appended to the rule-based
         specs dataset.

    Returns:
        An ``html.P`` describing success or the first failure.
    """
    def append_row(dataset_name, record):
        # Read-modify-write: load the whole dataset, add one row, rewrite it.
        dataset = dataiku.Dataset(dataset_name)
        current = dataset.get_dataframe()
        dataset.write_with_schema(pd.concat([current, pd.DataFrame([record])], ignore_index=True))

    try:
        try:
            results_df = pd.DataFrame(remapped_data)
            if results_df.empty:
                return html.P("Error: Cannot save empty results dataset.", style=div_message_style)
        except Exception as df_error:
            return html.P(f"Error creating DataFrame from segmentation results: {str(df_error)}", style=div_message_style)

        # Encode once; the same bytes go to the output and versions folders.
        try:
            csv_bytes = results_df.to_csv(index=False).encode("utf-8")
            output_folder.upload_data(f"{session_name}.csv", csv_bytes)
        except Exception as csv_error:
            return html.P(f"Error saving main results file: {str(csv_error)}", style=div_message_style)

        try:
            versions_folder.upload_data(f"{session_name}_0.csv", csv_bytes)
        except Exception as version_error:
            # Deliberately non-fatal: the main file is already saved.
            print(f"Warning: Could not save version file: {str(version_error)}")

        try:
            append_row(metadata_dataset_var, metadata_record)
        except Exception as metadata_error:
            return html.P(f"Error updating metadata: {str(metadata_error)}", style=div_message_style)

        if method == "kmeans" and kmeans_pipeline_serialized:
            try:
                # NOTE(review): pickle.loads executes arbitrary code from its
                # input. If this base64 string round-trips through the browser
                # (e.g. a dcc.Store), a user could craft it -- confirm the source.
                raw_pickle = base64.b64decode(kmeans_pipeline_serialized.encode('utf-8'))
                model_bytes = pickle.dumps(pickle.loads(raw_pickle))
                with output_model_folder.get_writer(f"{session_name}.pkl") as writer:
                    writer.write(model_bytes)
            except Exception as model_error:
                return html.P(f"Error saving KMeans model (data was saved but model might be unavailable for future use): {str(model_error)}", style=div_message_style)
        elif method == "rule_based" and rb_bounds and rb_weights:
            try:
                append_row(rule_based_dataset_var, {
                    "session_name": session_name,
                    "bounds": rb_bounds,
                    "weights": rb_weights
                })
            except Exception as rb_error:
                return html.P(f"Error saving rule-based parameters (data was saved but parameters might be unavailable for future use): {str(rb_error)}", style=div_message_style)

        return html.P(f"{session_name}.csv saved in the output_data_folder and metadata dataset in the project flow is updated with a new record!", style=div_message_style)
    except Exception as e:
        print(f"Error saving data: {e}")
        return html.P(f"Error: Failed to save data. {str(e)}", style=div_message_style)

    
def save_updated_data_to_dataiku(session_name, updated_version, output_data, metadata_record):
    """Save an edited segmentation as a new version and record it in the metadata.

    Steps:
      1. Find the session's active metadata row(s). If there is none, return
         an error without writing anything, so uploaded files never end up
         with no matching metadata record.
      2. Overwrite ``{session_name}.csv`` in ``output_folder`` and write
         ``{session_name}_{updated_version}.csv`` to ``versions_folder``.
      3. Mark the previously active row(s) ``'inactive'``, append
         ``metadata_record``, and rewrite the metadata dataset.

    Returns:
        An ``html.P`` with a success or error message.
    """
    try:
        metadata_dataset = dataiku.Dataset(metadata_dataset_var)
        metadata_df = metadata_dataset.get_dataframe()
        active_mask = (metadata_df['session_name'] == session_name) & (metadata_df['status'] == 'active')
        if not active_mask.any():
            return html.P(
                f"Error: No active metadata record found for session '{session_name}'. Nothing was saved.",
                style=div_message_style,
            )

        # Encode once; the same bytes go to the output and versions folders.
        csv_bytes = pd.DataFrame(output_data).to_csv(index=False).encode("utf-8")
        output_folder.upload_data(f"{session_name}.csv", csv_bytes)
        versions_folder.upload_data(f"{session_name}_{updated_version}.csv", csv_bytes)

        # Retire every currently active record for this session, so the
        # appended record is not left alongside a stale 'active' row.
        metadata_df.loc[active_mask, 'status'] = 'inactive'
        updated_metadata_df = pd.concat([metadata_df, pd.DataFrame([metadata_record])], ignore_index=True)
        metadata_dataset.write_with_schema(updated_metadata_df)
        return html.P(f"{session_name}.csv saved in the output_data_folder and metadata dataset is updated with a new record!", style=div_message_style)
    except Exception as e:
        print(f"Error saving data: {e}")
        return html.P(f"Error: Failed to save data. {str(e)}", style=div_message_style)