import pandas as pd
import dataiku
import numpy as np
from copy import deepcopy
import logging

# Configure logger
# NOTE: basicConfig is called at import time, so importing this module requests
# INFO-level root logging as a side effect. The module-level `logger` below is
# what every function in this file uses for its info/warning messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global factors used for capability and control chart calculations.
#
# Both tables are keyed by subgroup size n. The functions in this module look
# up the *closest available* key to the observed average subgroup size, so the
# two tables should expose the same set of sizes.

# d2: divisor used to estimate within-subgroup sigma from the mean range
# (sigma = R-bar / d2).
d2_factors = {2: 1.128,
              3: 1.693,
              4: 2.059,
              5: 2.326,
              6: 2.534,
              7: 2.704,
              8: 2.847,
              9: 2.97,
              10: 3.078,
              15: 3.472,
              20: 3.735,
              25: 3.931}

# A2: multiplier of R-bar for the X-bar chart limits (X-bar +/- A2 * R-bar).
# D3 / D4: multipliers of R-bar for the lower / upper R chart limits.
# The n=20 row was missing (d2_factors has it); without it, subgroup sizes
# around 20 silently fell back to the n=15 or n=25 factors.
D_factors = {2: {'A2': 1.88, 'D3': 0.0, 'D4': 3.267},
             3: {'A2': 1.023, 'D3': 0.0, 'D4': 2.575},
             4: {'A2': 0.729, 'D3': 0.0, 'D4': 2.282},
             5: {'A2': 0.577, 'D3': 0.0, 'D4': 2.115},
             6: {'A2': 0.483, 'D3': 0.0, 'D4': 2.004},
             7: {'A2': 0.419, 'D3': 0.076, 'D4': 1.924},
             8: {'A2': 0.373, 'D3': 0.136, 'D4': 1.864},
             9: {'A2': 0.337, 'D3': 0.184, 'D4': 1.816},
             10: {'A2': 0.308, 'D3': 0.223, 'D4': 1.777},
             15: {'A2': 0.223, 'D3': 0.347, 'D4': 1.653},
             20: {'A2': 0.18, 'D3': 0.415, 'D4': 1.585},
             25: {'A2': 0.153, 'D3': 0.459, 'D4': 1.541}}


def generate_capability_metrics_data(df, measurement_columns):
    """
    Compute range-based capability indices (Cp, Cpk, Cpm, Cpkm) per column.

    For every measurement column:
      * the column is coerced to numeric (non-numeric values become NaN);
      * the overall mean and sample standard deviation are stored in the
        Dataiku project variables ``mean_<col>`` / ``std_<col>``;
      * if a ``_subgroup`` column exists, sigma is estimated as R-bar / d2
        (d2 taken for the table size closest to the average subgroup size)
        and Cp / Cpk are computed from the ``USL_<col>`` / ``LSL_<col>``
        project variables;
      * Cpm / Cpkm are added when a target is defined. The target is read
        from ``TARGET_<col>`` and, if absent, from ``target_<col>`` (the
        spelling used by the other functions in this module).

    Columns that cannot be evaluated (no data, no subgroups, subgroup size
    below 2, zero range, missing or non-numeric limits) are skipped with a
    log message; they simply do not appear in the result.

    Parameters:
    - df: DataFrame holding the measurements. NOTE: the numeric coercion is
      written back into this DataFrame (pre-existing behaviour, kept so that
      callers relying on it are not broken).
    - measurement_columns: iterable of column names to evaluate.

    Returns:
    - Single-row DataFrame with columns ``<col>_cp``, ``<col>_cpk`` and,
      when a target exists, ``<col>_cpm`` / ``<col>_cpkm``.

    Side effects:
    - Updates the project's standard variables through ``set_variables``.
    """

    project = dataiku.api_client().get_default_project()
    variables = project.get_variables()
    project_vars = variables["standard"]

    result_dict = {}

    # Loop-invariant: the subgroup sizes for which a d2 factor is tabulated.
    available_sizes = np.array(list(d2_factors.keys()))

    for col in measurement_columns:
        df[col] = pd.to_numeric(df[col], errors='coerce')
        numeric_series = df[col].dropna()

        if numeric_series.empty:
            logger.warning(f"Column '{col}' has no valid data.")
            continue

        # Save mean and std dev to project variables.
        # NOTE(review): std is NaN when the column holds a single value;
        # confirm the variables API accepts NaN.
        mean = numeric_series.mean()
        std_dev = numeric_series.std()
        project_vars[f"mean_{col}"] = round(mean, 5)
        project_vars[f"std_{col}"] = round(std_dev, 5)

        # Capability indices need subgroups to estimate within-group sigma.
        if "_subgroup" not in df.columns:
            logger.info(f"'_subgroup' column missing. Skipping capability indices for '{col}'.")
            continue

        subgroup_stats = df.groupby('_subgroup').agg(
            range=(col, lambda x: x.max() - x.min()),
            subgroup_size=(col, 'count')
        ).reset_index()

        # 'count' ignores NaN, so this drops subgroups without usable values.
        subgroup_stats = subgroup_stats[subgroup_stats['subgroup_size'] > 0]
        if subgroup_stats.empty:
            logger.warning(f"No valid subgroups for '{col}'.")
            continue

        R_bar = subgroup_stats['range'].mean()
        avg_subgroup_size = subgroup_stats['subgroup_size'].mean()
        logger.info(f"For variable '{col}', R̄ = {R_bar:.5f} and avg_subgroup_size = {avg_subgroup_size:.2f}")

        # Size is checked first: single-item subgroups always have range 0,
        # and the size message is the accurate diagnosis in that case.
        if avg_subgroup_size < 2:
            logger.warning(f"Subgroup size < 2 for '{col}' → Can't compute range-based sigma. Skipping metrics.")
            continue

        if R_bar <= 0:
            logger.warning(f"R̄ = 0 for '{col}' → No within-subgroup variation. Skipping metrics.")
            continue

        # Pick the tabulated size nearest to the rounded average size.
        n = int(round(avg_subgroup_size))
        closest_n = int(available_sizes[np.abs(available_sizes - n).argmin()])

        # R_bar > 0 and every tabulated d2 is > 0, so sigma is strictly positive.
        sigma = R_bar / d2_factors[closest_n]

        raw_usl = project_vars.get(f"USL_{col}")
        raw_lsl = project_vars.get(f"LSL_{col}")

        if raw_usl is None or raw_lsl is None:
            logger.warning(f"Missing USL or LSL for '{col}'. Skipping.")
            continue

        try:
            usl = float(raw_usl)
            lsl = float(raw_lsl)
        except (TypeError, ValueError):
            # One bad limit should not abort the metrics of all other columns.
            logger.warning(f"Non-numeric USL or LSL for '{col}'. Skipping.")
            continue

        Cp = (usl - lsl) / (6 * sigma)
        Cpk = min(usl - mean, mean - lsl) / (3 * sigma)

        result_dict[f"{col}_cp"] = Cp
        result_dict[f"{col}_cpk"] = Cpk

        # Historical key first, then the lower-case key used elsewhere in
        # this module, so targets defined either way are honoured.
        raw_target = project_vars.get(f"TARGET_{col}")
        if raw_target is None:
            raw_target = project_vars.get(f"target_{col}")

        if raw_target is None:
            logger.info(f"Target missing for '{col}'. Skipping Cpm/Cpkm.")
        else:
            try:
                # Cast like USL/LSL: variables may be stored as strings.
                target = float(raw_target)
            except (TypeError, ValueError):
                logger.warning(f"Non-numeric target for '{col}'. Skipping Cpm/Cpkm.")
            else:
                # Penalise Cp/Cpk by the squared, sigma-scaled offset from target.
                dev_sq = ((mean - target) / sigma) ** 2
                result_dict[f"{col}_cpm"] = Cp / np.sqrt(1 + dev_sq)
                result_dict[f"{col}_cpkm"] = Cpk / np.sqrt(1 + dev_sq)

        logger.info(f"Metrics generated for '{col}': Cp = {Cp:.4f}, Cpk = {Cpk:.4f}")

    # Save updated variables
    variables["standard"] = project_vars
    project.set_variables(variables)

    # Return as single-row DataFrame
    return pd.DataFrame([result_dict])

def generate_process_capability_visualization_data(df_input, measurement_columns):
    """
    Prepare data for process capability visualization (distribution charts).

    Works on a copy of ``df_input``. For each measurement column, the project
    variables ``USL_<col>``, ``LSL_<col>`` and ``target_<col>`` are added as
    constant columns when defined, and the column's mean and sample standard
    deviation are stored in the project variables ``mean_<col>`` /
    ``std_<col>``.

    Parameters:
    - df_input: source DataFrame (not modified).
    - measurement_columns: iterable of measurement column names.

    Returns:
    - DataFrame restricted to the measurement columns followed by the
      limit/target columns, in the order they were added. The order is
      deterministic (it previously depended on set iteration order).

    Side effects:
    - Updates the project's standard variables through ``set_variables``.
    """
    df = df_input.copy()
    project = dataiku.api_client().get_default_project()
    all_vars = project.get_variables()
    project_vars = all_vars["standard"]

    # list(...) instead of .copy() so any iterable of names is accepted.
    columns_to_include = list(measurement_columns)

    for col in measurement_columns:
        logger.info(f"Handling column '{col}'")

        USL = project_vars.get(f'USL_{col}')
        LSL = project_vars.get(f'LSL_{col}')
        target_value = project_vars.get(f'target_{col}')

        logger.info(f"Column '{col}' → USL = {USL}, LSL = {LSL}, Target = {target_value}")

        if USL is not None:
            df[f'USL_{col}'] = USL
            columns_to_include.append(f'USL_{col}')
        else:
            logger.warning(f"USL not found for column '{col}'. Skipping USL column for visualization.")

        if LSL is not None:
            df[f'LSL_{col}'] = LSL
            columns_to_include.append(f'LSL_{col}')
        else:
            logger.warning(f"LSL not found for column '{col}'. Skipping LSL column for visualization.")

        if target_value is not None:
            df[f'target_{col}'] = target_value
            columns_to_include.append(f'target_{col}')
        else:
            logger.info(f"Target not found for column '{col}'. Skipping target column for visualization.")

        # Coerce on a local series (as the sibling functions do) so that
        # text-typed numeric columns still yield statistics; the returned
        # column itself is left untouched.
        numeric_values = pd.to_numeric(df[col], errors='coerce')
        project_vars[f'mean_{col}'] = round(numeric_values.mean(), 5)
        project_vars[f'std_{col}'] = round(numeric_values.std(), 5)

    all_vars["standard"] = project_vars
    project.set_variables(all_vars)

    # De-duplicate while preserving insertion order, then keep only columns
    # that really exist in the frame.
    ordered_columns = [
        name for name in dict.fromkeys(columns_to_include) if name in df.columns
    ]
    return df[ordered_columns]

def generate_selected_study_data_enhanced(df_input, measurement_columns):
    """
    Enhance study data with rank, subgroup, control chart limits, targets and
    constant zone/range columns taken from the project variables.

    Steps, all on a copy of ``df_input``:
      1. add ``_rank`` (1..N in row order) when it is missing;
      2. derive ``_subgroup`` from ``_rank`` and the ``subgroup_size``
         project variable (zero-padded 3-character label);
      3. look for a date/time column by name, falling back to the index;
      4. per measurement column, add subgroup mean/range columns
         (``mean_<col>``, ``range_<col>``), X-bar and R chart limits
         (``UCL_Xbar_<col>``, ``LCL_Xbar_<col>``, ``UCL_R_<col>``,
         ``LCL_R_<col>``) and ``target_<col>`` when defined;
      5. add every project variable whose name ends with ``_lower_zone``,
         ``_upper_zone``, ``_selected_range_min`` or ``_selected_range_max``
         as a constant column.

    Parameters:
    - df_input: source DataFrame (not modified).
    - measurement_columns: iterable of measurement column names.

    Returns:
    - DataFrame with the measurement columns, ``_rank``, ``_subgroup``, the
      generated columns and the detected date/time column, in that
      (deterministic) order.

    Raises:
    - ValueError: if ``subgroup_size`` is undefined or not a positive integer.
    """
    df = df_input.copy()

    if '_rank' not in df.columns:
        logger.info("'_rank' column not found. Creating it based on index for study data enhancement.")
        df['_rank'] = range(1, len(df) + 1)

    project_vars = dataiku.get_custom_variables(typed=True)
    subgroup_size_str = project_vars.get('subgroup_size', None)

    if subgroup_size_str is None:
        raise ValueError("Project variable 'subgroup_size' is not defined. Cannot compute control chart inputs.")

    try:
        subgroup_size = int(subgroup_size_str)
        if subgroup_size <= 0:
            raise ValueError("Subgroup size must be a positive integer.")
    except (ValueError, TypeError) as e:
        # Chain the cause so the original traceback is not lost.
        raise ValueError(
            f"Invalid 'subgroup_size' project variable: {subgroup_size_str}. Error: {e}"
        ) from e

    # Consecutive blocks of `subgroup_size` ranks share one subgroup label.
    df['_subgroup'] = ((df['_rank'] - 1) // subgroup_size + 1).astype(str).str.zfill(3)

    # The first column whose name hints at a date/time and whose content
    # parses as datetime wins ('timestamp' is already covered by 'time').
    date_time_column = None
    for col in df.columns:
        lowered = col.lower()
        if 'date' in lowered or 'time' in lowered:
            try:
                pd.to_datetime(df[col], errors='raise')
                date_time_column = col
                logger.info(f"Identified '{col}' as a date/time column for study data enhancement.")
                break
            except (ValueError, TypeError):
                pass

    if not date_time_column and df.index.name is None:
        # NOTE(review): pd.to_datetime presumably also accepts a plain integer
        # index (read as epoch offsets), in which case a default RangeIndex
        # ends up as an 'index_timestamp' column - confirm this is intended.
        try:
            df.index = pd.to_datetime(df.index)
            # The index is unnamed in this branch, so a fixed name is used.
            date_time_column = 'index_timestamp'
            df.index.name = date_time_column
            df = df.reset_index()
            logger.info("Identified index as a date/time for study data enhancement.")
        except (ValueError, TypeError):
            logger.warning("Could not identify index as date/time for study data enhancement.")

    new_columns_to_add = []

    # Loop-invariant: subgroup sizes with tabulated control chart factors.
    available_sizes = np.array(list(D_factors.keys()))

    for col in measurement_columns:
        df[col] = pd.to_numeric(df[col], errors='coerce')
        subgroup_stats = df.groupby('_subgroup').agg(
            mean=(col, 'mean'),
            range=(col, lambda x: x.max() - x.min()),
            subgroup_size=(col, 'count')
        ).reset_index()
        # 'count' ignores NaN, so this drops subgroups without usable values.
        subgroup_stats = subgroup_stats[subgroup_stats['subgroup_size'] > 0]

        if subgroup_stats.empty:
            logger.warning(f"No valid subgroups found for column '{col}'. Skipping control chart limits.")
            continue

        R_bar = subgroup_stats['range'].mean()
        X_bar = subgroup_stats['mean'].mean()
        avg_subgroup_size = subgroup_stats['subgroup_size'].mean()

        if avg_subgroup_size < 2:
            logger.warning(f"Average subgroup size for column '{col}' is less than 2. Skipping control chart limits.")
            continue

        # Use the factors of the tabulated size nearest to the average size.
        n = int(round(avg_subgroup_size))
        closest_n = available_sizes[np.abs(available_sizes - n).argmin()]
        factors = D_factors.get(closest_n)
        if not factors:
            logger.warning(f"Control chart factors not found for subgroup size {closest_n}.")
            continue

        A2, D3, D4 = factors['A2'], factors['D3'], factors['D4']

        UCL_Xbar = X_bar + A2 * R_bar
        LCL_Xbar = X_bar - A2 * R_bar
        UCL_R = D4 * R_bar
        LCL_R = D3 * R_bar

        # Broadcast each subgroup's mean/range back onto its rows.
        subgroup_mean_col = f'mean_{col}'
        subgroup_range_col = f'range_{col}'
        subgroup_stats_renamed = subgroup_stats[['_subgroup', 'mean', 'range']].rename(
            columns={'mean': subgroup_mean_col, 'range': subgroup_range_col}
        )
        df = pd.merge(df, subgroup_stats_renamed, on='_subgroup', how='left')
        new_columns_to_add.extend([subgroup_mean_col, subgroup_range_col])

        limit_columns = {
            f'UCL_Xbar_{col}': UCL_Xbar,
            f'LCL_Xbar_{col}': LCL_Xbar,
            f'UCL_R_{col}': UCL_R,
            f'LCL_R_{col}': LCL_R,
        }
        for limit_name, limit_value in limit_columns.items():
            df[limit_name] = limit_value
        new_columns_to_add.extend(limit_columns)

        target_value = project_vars.get(f'target_{col}')
        if target_value is not None:
            df[f'target_{col}'] = target_value
            new_columns_to_add.append(f'target_{col}')
        else:
            logger.info(f"Target variable 'target_{col}' not found for column '{col}'.")

    # Constant zone / selected-range columns coming from project variables.
    constant_suffixes = ("_lower_zone", "_upper_zone",
                         "_selected_range_min", "_selected_range_max")
    for key, value in project_vars.items():
        if key.endswith(constant_suffixes):
            df[key] = value
            new_columns_to_add.append(key)

    # De-duplicate while preserving insertion order so the output column order
    # is stable between runs (it used to depend on set iteration order).
    columns_to_keep = list(dict.fromkeys(
        list(measurement_columns) + ['_rank', '_subgroup'] + new_columns_to_add
    ))
    if date_time_column and date_time_column not in columns_to_keep:
        columns_to_keep.append(date_time_column)

    final_df_columns = [name for name in columns_to_keep if name in df.columns]
    return df[final_df_columns]


def insight_builder(project, base_info, variable_name, name_suffix,
                    measures, reference_column=None, custom_colors=None,
                    custom_formula=None, reference_lines=None,
                    additional_full_measures=None):
    """
    Create a chart insight from a template and return its id.

    A deep copy of ``base_info`` is customised (the template itself is left
    untouched) and handed to ``project.create_insight``.

    Parameters:
    - project: object exposing ``create_insight(info)``; the result must
      have an ``insight_id`` attribute.
    - base_info: template dict; the chart definition is read from
      ``base_info['params']['def']``.
    - variable_name / name_suffix: combined as "<name_suffix> - <variable_name>"
      for both the insight name and the chart name.
    - measures: one entry per existing ``genericMeasures`` slot. A string
      only replaces the slot's ``column``; any other value replaces the
      whole slot.
    - reference_column: column for the first reference line's aggregated column.
    - custom_colors: mapping merged into ``colorOptions.customColors``; its
      entries win over existing entries with the same key.
    - custom_formula: ``customFunction`` for the second measure, applied only
      when at least two measures exist.
    - reference_lines: dataset columns assigned, by position, to the existing
      reference lines.
    - additional_full_measures: complete measure dicts appended after
      ``measures`` has been applied.

    Returns:
    - The ``insight_id`` of the created insight.
    """
    insight_info = deepcopy(base_info)
    title = f"{name_suffix} - {variable_name}"
    insight_info['name'] = title
    definition = insight_info['params']['def']
    definition['name'] = title

    # Apply the requested measures onto the template's measure slots.
    for position, measure in enumerate(measures):
        if not isinstance(measure, str):
            definition['genericMeasures'][position] = measure
        else:
            definition['genericMeasures'][position]['column'] = measure

    # Extra fully-specified measure blocks (e.g., CPM, CPKM) go at the end.
    if additional_full_measures:
        definition['genericMeasures'].extend(additional_full_measures)

    # ---- Y-axis title ----
    # The chart flavour (mean vs. range) is inferred from the measure columns.
    measure_columns = []
    for measure in measures:
        if isinstance(measure, str):
            measure_columns.append(measure)
        else:
            measure_columns.append(measure.get('column', ''))

    if f"mean_{variable_name}" in measure_columns:
        axis_title = f"Mean of {variable_name}"
    elif f"range_{variable_name}" in measure_columns:
        axis_title = f"Range of {variable_name}"
    else:
        axis_title = None

    if axis_title:
        # Only the primary left axis gets the explicit title.
        for axis in definition.get('yAxesFormatting', []):
            if axis.get('id') == 'y_left_0':
                axis['axisTitle'] = axis_title

    # Reference line driven by an aggregated column.
    if reference_column:
        first_line = definition['referenceLines'][0]
        first_line['aggregatedColumn']['column'] = reference_column

    # Merge custom colors into whatever mapping the template already has.
    if custom_colors:
        color_options = definition.setdefault("colorOptions", {})
        color_options.setdefault('customColors', {}).update(custom_colors)

    # Optional custom formula for the second measure.
    if custom_formula and len(definition['genericMeasures']) > 1:
        definition['genericMeasures'][1]['customFunction'] = custom_formula

    # Reference lines bound to dataset columns (e.g., USL/LSL).
    if reference_lines:
        for position, column in enumerate(reference_lines):
            definition['referenceLines'][position]['datasetColumn']['column'] = column

    created = project.create_insight(insight_info)
    return created.insight_id


def make_metric_block(column_name, label):
    """
    Build a chart measure definition that displays the average of a column.

    Parameters:
    - column_name: dataset column the measure aggregates (function 'AVG').
    - label: text stored as the measure's ``displayLabel``.

    Returns:
    - A new dict on every call; nested dicts and lists are not shared
      between calls or between keys.
    """
    def _small_dark_text():
        # Fresh dict each time so value and label formatting stay independent.
        return dict(fontSize=11, fontColor='#333', hasBackground=False)

    in_chart_text = dict(
        fontSize=11,
        fontColor='AUTO',
        hasBackground=False,
        backgroundColor='#D9D9D9BF',
    )
    in_chart_values = dict(
        displayValues=True,
        textFormatting=in_chart_text,
        addDetails=False,
        additionalMeasures=[],
    )

    # What is measured and how it is aggregated / drawn.
    block = dict(
        column=column_name,
        function='AVG',
        type='NUMERICAL',
        displayed=True,
        isA='measure',
        displayAxis='axis1',
        displayType='column',
        computeMode='NORMAL',
        computeModeDim=0,
    )
    # Number formatting.
    block.update(
        multiplier='Auto',
        decimalPlaces=2,
        digitGrouping='DEFAULT',
        useParenthesesForNegativeValues=False,
        shouldFormatInPercentage=False,
        hideTrailingZeros=False,
        prefix='',
        suffix='',
    )
    # Value / label display and KPI layout.
    block.update(
        showValue=True,
        displayLabel=label,
        showDisplayLabel=True,
        labelPosition='BOTTOM',
        labelFontSize=16,
        percentile=50.0,
        isCustomPercentile=False,
        kpiTextAlign='CENTER',
        kpiValueFontSizeMode='RESPONSIVE',
        kpiValueFontSize=32,
        responsiveTextAreaFill=100,
    )
    # Text styling and colouring.
    block.update(
        valueTextFormatting=_small_dark_text(),
        labelTextFormatting=_small_dark_text(),
        valuesInChartDisplayOptions=in_chart_values,
        colorRules=[],
    )
    return block


def is_under_control(df, parameter, sort_column, ucl_xbar, lcl_xbar, ucl_r, lcl_r):
    """
    Determine if a parameter is statistically under control using:
    - X-bar check: any subgroup mean strictly outside [lcl_xbar, ucl_xbar]
    - R check: any subgroup range strictly outside [lcl_r, ucl_r]

    Values exactly on a limit are in control. Missing (NaN) means or ranges
    never compare as outside a limit, so they do not count as violations.

    Parameters:
    - df: DataFrame with the subgroup statistics; must contain the columns
      ``mean_{parameter}`` and ``range_{parameter}`` as well as ``sort_column``
    - parameter: base column name of the parameter (e.g. "pressure")
    - sort_column: column name for ordering (e.g. "_subgroup")
    - ucl_xbar/lcl_xbar: numeric limits for Xbar chart
    - ucl_r/lcl_r: numeric limits for R chart

    Returns:
    - is_controlled (bool): True if neither check is violated
    - violations (dict): {"rule1_xbar": bool, "rule1_r": bool}, plain Python
      booleans (True means the corresponding check is violated)

    Raises:
    - KeyError: if ``sort_column`` or one of the statistic columns is missing.
    """
    # The checks below do not depend on row order; the sort is kept so that a
    # missing sort_column still fails exactly as it did before.
    df = df.sort_values(by=sort_column).reset_index(drop=True)

    mean_col = f"mean_{parameter}"
    range_col = f"range_{parameter}"

    # X-bar chart: subgroup means outside control limits
    out_of_bounds_xbar = (df[mean_col] > ucl_xbar) | (df[mean_col] < lcl_xbar)

    # R chart: subgroup ranges outside control limits
    out_of_bounds_r = (df[range_col] > ucl_r) | (df[range_col] < lcl_r)

    # Series.any() yields numpy.bool_; cast so the flags are real bools
    # (identity checks and JSON serialization behave as documented).
    violations = {
        "rule1_xbar": bool(out_of_bounds_xbar.any()),
        "rule1_r": bool(out_of_bounds_r.any()),
    }

    is_controlled = not any(violations.values())
    return is_controlled, violations
