import logging
import numpy as np
import pandas as pd

from dataiku.base.utils import safe_unicode_str, get_argspec
from dataiku.doctor.utils import dku_indexing, dku_nonaninf
import numbers
from dataiku.doctor.diagnostics import diagnostics

# Module-level logger, named after this module (__name__).
logger = logging.getLogger(__name__)

# Python 2 refuses to compile a function that contains both a nested function
# and an exec() statement, so the exec() call is isolated in this helper.
def python2_friendly_exec(code, ctx_global, ctx_local):
    """Execute ``code`` with ``ctx_global`` / ``ctx_local`` as its namespaces.

    Names defined by ``code`` end up in ``ctx_local`` (pass the same dict for
    both namespaces to collect them in one place). Kept as a standalone
    function so callers that also define closures can still compile on
    Python 2.
    """
    exec(code, ctx_global, ctx_local)

def get_custom_scorefunc(code, unprocessed, indices=None, allow_naninf=False):
    """Build a scorer from user-written custom metric code.

    ``code`` is executed in a fresh namespace and must define a ``score``
    callable. The returned scorer has the signature
    ``(y_valid, y_pred, sample_weight=None)``. It forwards ``sample_weight``
    and/or ``X_valid`` to ``score`` only when ``score`` declares a parameter
    with that name. ``X_valid`` is ``unprocessed``, restricted with
    ``dku_indexing(unprocessed, indices)`` when ``indices`` is given.

    Any exception raised while calling ``score`` is logged and re-raised as a
    ValueError. The returned value is then validated by ``check_customscore``
    (NaN / infinity rejected unless ``allow_naninf``).

    If a numpy ``y_pred`` makes the evaluation fail with a ValueError, the
    scorer retries once with ``y_pred`` wrapped in a ``pd.Series`` indexed by
    ``indices``. This keeps older metric definitions working.

    :raises ValueError: if ``code`` does not define ``score``.
    """
    # NOTE(review): ``code`` is user-authored and runs with the full privileges
    # of this process; it must only ever come from a trusted metric definition.
    namespace = {}
    python2_friendly_exec(code, namespace, namespace)
    if "score" not in namespace:
        raise ValueError("Custom evaluation function not defined")
    user_score = namespace["score"]

    def _wrapped(y_valid, y_pred, sample_weight=None):
        try:
            accepted_args = get_argspec(user_score)[0]
            optional_kwargs = {}
            if 'X_valid' in accepted_args:
                if unprocessed is None:
                    raise ValueError("Cannot use the X_valid parameter in a custom metric with this model")
                if indices is None:
                    optional_kwargs['X_valid'] = unprocessed
                else:
                    optional_kwargs['X_valid'] = dku_indexing(unprocessed, indices)
            if 'sample_weight' in accepted_args:
                optional_kwargs['sample_weight'] = sample_weight
            val = user_score(y_valid, y_pred, **optional_kwargs)
        except Exception as e:
            logger.exception("Custom scoring function failed")
            raise ValueError("Custom scoring function failed: %s" % (e))
        check_customscore(val, allow_naninf=allow_naninf)
        return val

    def type_wrapper(y_valid, y_pred, sample_weight=None):
        # https://app.shortcut.com/dataiku/story/138579
        # Backwards compatibility: metrics written against pd.Series predictions
        # are retried with a pd.Series when a numpy y_pred makes them fail.
        can_retry_as_series = isinstance(y_pred, np.ndarray)
        try:
            return _wrapped(y_valid, y_pred, sample_weight)
        except ValueError as e:
            if not can_retry_as_series:
                raise e
            logger.warning("Evaluation failed with y_pred as np.array. Try as pd.Series")
            try:
                val = _wrapped(y_valid, pd.Series(y_pred, indices), sample_weight)
                logger.warning("Evaluation of custom metric recovered from failure with y_pred as pd.Series successful. Please consider updating your metric definition.")
            except Exception as pd_series_exception:
                logger.error("Evaluation failed with fallback y_pred as pd.Series: %s" % (pd_series_exception))
                # Surface the original (numpy) failure, not the fallback's.
                raise e
            return val
        except TypeError as e:
            logger.error("Unexpected type for score: %s" % (e))
            raise e

    return type_wrapper

def get_custom_evaluation_metric(metrics_params):
    """Return the custom metric selected as evaluation metric.

    Looks through ``metrics_params["customMetrics"]`` for the first entry whose
    ``name`` equals ``metrics_params["customEvaluationMetricName"]``.

    :raises ValueError: if no custom metric has that name.
    """
    selected = next(
        (candidate for candidate in metrics_params["customMetrics"]
         if candidate["name"] == metrics_params["customEvaluationMetricName"]),
        None,
    )
    if selected is None:
        raise ValueError("The selected custom optimisation metric does not exist in the defined custom metrics")
    return selected

def get_custom_metric_scorefunc(code, unprocessed, indices=None, allow_naninf=False):
    """Build a scorer for a custom metric, rejecting empty metric code.

    Thin guard around ``get_custom_scorefunc``; all arguments are forwarded.

    :raises ValueError: if ``code`` is empty or None.
    """
    if code:
        return get_custom_scorefunc(code, unprocessed, indices, allow_naninf)
    raise ValueError("You must write the custom metric code")

def get_custom_evaluation_metric_scorefunc(custom_evaluation_metric, unprocessed, indices=None):
    """Build the scorer for the custom metric selected as evaluation metric.

    :param custom_evaluation_metric: metric definition dict (its ``metricCode``
        is used); None is reported as a metric that does not exist
    :param unprocessed: data made available to the metric as ``X_valid`` (may be None)
    :param indices: optional indices used to subset ``unprocessed``
    :return: scorer built by ``get_custom_scorefunc`` with NaN/infinite results rejected
    :raises ValueError: if the metric is None or has no (or empty) code
    """
    if custom_evaluation_metric is None:
        raise ValueError("The selected custom optimisation metric does not exist in the defined custom metrics")

    code = custom_evaluation_metric.get("metricCode")
    # Treat empty code like missing code (as get_custom_metric_scorefunc does) so the
    # user gets this explicit message rather than "Custom evaluation function not defined".
    if not code:
        raise ValueError("The selected custom optimisation metric appears not to contain any custom code, please ensure that the metric is correctly defined.")

    return get_custom_scorefunc(code, unprocessed, indices)

def get_custom_score_from_custom_metrics_results(custom_metrics_results, custom_metric_name, per_cut_data=False):
    """Return the computed value(s) of the named custom metric.

    :param custom_metrics_results: iterable of result dicts, each holding the metric
        definition under ``metric`` (with a ``name``)
    :param custom_metric_name: name of the metric to look up; the first match wins
    :param per_cut_data: if True return the result's ``values`` entry, otherwise its ``value``
    :raises ValueError: if no result matches ``custom_metric_name``
    :raises KeyError: if the matching result has no ``value``/``values`` entry
        (e.g. a failed result, which only carries ``didSucceed`` and ``error``)
    """
    value_key = "values" if per_cut_data else "value"
    for custom_metric_result in custom_metrics_results:
        if custom_metric_result["metric"]["name"] == custom_metric_name:
            return custom_metric_result[value_key]
    # Raise explicitly instead of letting next() leak StopIteration, which is silently
    # swallowed (Python 2) or turned into RuntimeError (Python 3.7+) if it escapes
    # into an enclosing generator.
    raise ValueError("Custom metric '{}' not found in the custom metrics results".format(custom_metric_name))

def check_customscore(score, allow_naninf=False):
    """Raise if ``score`` is not an acceptable custom metric result.

    The validation rules are those of ``get_custom_score_or_None_and_error``.
    The error it reports (ValueError or TypeError) is raised as-is.
    """
    validation_error = get_custom_score_or_None_and_error(score, allow_naninf)[1]
    if validation_error is not None:
        raise validation_error

def get_custom_score_or_None_and_error(score, allow_naninf=False):
    """Validate a custom metric result without raising.

    :param score: value returned by a custom metric function
    :param allow_naninf: if True, NaN and +/-infinity are accepted as-is
    :return: ``(score, None)`` when ``score`` is acceptable, otherwise
        ``(None, error)`` where ``error`` is an exception instance (not raised):
        ValueError for None / NaN / infinity, TypeError for a non-numeric value
        (or a numeric value that cannot be checked for NaN/infinity).
    """
    if score is None:
        return None, ValueError("Custom evaluation function returned None. Illegal value")
    if not isinstance(score, numbers.Number):
        message = 'Result value {} of type {} did not match expected type of Number'.format(score, type(score))
        return None, TypeError(message)
    if allow_naninf:
        return score, None

    try:
        is_nan = bool(np.isnan(score))
        is_inf = bool(np.isinf(score))
    except TypeError:
        # numpy has no isnan/isinf loop for some numbers.Number types
        # (e.g. fractions.Fraction, decimal.Decimal): check their float value instead
        # of letting the TypeError escape this non-raising validator.
        try:
            as_float = float(score)
        except (TypeError, ValueError, OverflowError):
            message = 'Result value {} of type {} could not be converted to a float'.format(score, type(score))
            return None, TypeError(message)
        is_nan = bool(np.isnan(as_float))
        is_inf = bool(np.isinf(as_float))

    if is_nan:
        return None, ValueError("Custom evaluation function returned NaN. Illegal value")
    if is_inf:
        return None, ValueError("Custom evaluation function returned Infinity. Illegal value")
    return score, None

def calculate_overall_classification_custom_metrics(metric_params, preds, probas, sample_weight, unprocessed, valid_y):
    """
    Calculates custom metric values for classification models. Does not perform any per-cut calculations.

    Metrics flagged ``needsProbability`` are evaluated on ``probas``, all the others on ``preds``.
    Returns one result dict per entry of ``metric_params["customMetrics"]``, in the same order.
    """
    return [
        parse_and_compute_custom_metric(custom_metric=metric,
                                        valid_unprocessed=unprocessed,
                                        valid_y=valid_y,
                                        preds_or_probas=probas if metric["needsProbability"] else preds,
                                        sample_weight=sample_weight)
        for metric in metric_params["customMetrics"]
    ]

def calculate_regression_custom_metrics(metric_params, valid_unprocessed, valid_y, preds, sample_weight):
    """Evaluate every custom metric of a regression model on ``preds``.

    Returns one result dict per entry of ``metric_params["customMetrics"]``, in the same order.
    """
    metric_definitions = metric_params["customMetrics"]
    return [
        parse_and_compute_custom_metric(custom_metric=definition,
                                        valid_unprocessed=valid_unprocessed,
                                        valid_y=valid_y,
                                        preds_or_probas=preds,
                                        sample_weight=sample_weight)
        for definition in metric_definitions
    ]

def parse_and_compute_custom_metric(custom_metric, valid_unprocessed, valid_y, preds_or_probas, sample_weight):
    """Parse one custom metric definition and evaluate it.

    :param custom_metric: metric definition dict (``metricCode`` and ``name`` are read)
    :return: a result dict holding the definition under ``'metric'``. If the code
        cannot be turned into a scoring function, ``didSucceed`` is False and
        ``error`` holds the message (a modeling diagnostic is also recorded).
        Otherwise the result is completed by ``execute_parsed_custom_metric_function``.
    """
    outcome = {
        'metric': custom_metric
    }

    try:
        scorer = get_custom_metric_scorefunc(custom_metric["metricCode"], valid_unprocessed)
    except Exception as parse_error:
        outcome["didSucceed"] = False
        outcome['error'] = safe_unicode_str(parse_error)
        # exc_info must be captured while the exception is being handled.
        logger.warning("Custom metric function '{}' failed to parse".format(custom_metric['name']), exc_info=True)

        diagnostics.add_or_update(
            diagnostics.DiagnosticType.ML_DIAGNOSTICS_MODELING_PARAMETERS,
            "Calculation of '{}' failed: unable to parse metric code".format(outcome['metric']['name'])
        )
        return outcome
    else:
        return execute_parsed_custom_metric_function(scorer, outcome, valid_y, preds_or_probas,
                                                     sample_weight=sample_weight)

def execute_parsed_custom_metric_function(custom_metric_function, custom_metric_result, valid_y, preds_or_probas, sample_weight):
    """Run an already-parsed custom metric function and record the outcome.

    ``custom_metric_result`` must hold the metric definition under ``'metric'``.
    It is updated in place and returned:
    - on success, ``value`` is set and ``didSucceed`` is True;
    - if the call raises, ``didSucceed`` is False and ``error`` is set. The
      failure is logged and recorded as a modeling diagnostic instead of
      being propagated.
    """
    metric_definition = custom_metric_result['metric']
    try:
        score = custom_metric_function(valid_y, preds_or_probas, sample_weight=sample_weight)
    except Exception as execution_error:
        custom_metric_result["didSucceed"] = False
        custom_metric_result['error'] = safe_unicode_str(execution_error)
        logger.warning("Custom metric function '{}' failed to execute".format(metric_definition['name']), exc_info=True)

        diagnostics.add_or_update(
            diagnostics.DiagnosticType.ML_DIAGNOSTICS_MODELING_PARAMETERS,
            "Calculation of '{}' failed".format(custom_metric_result['metric']['name'])
        )
    else:
        custom_metric_result["value"] = score
        custom_metric_result["didSucceed"] = True

    return custom_metric_result

class CustomMetricAggregator:
    """Accumulates the successful per-fold values of one custom metric.

    ``has_failure`` starts False; aggregate_custom_metrics_for_cross_val_model
    sets it to True when a fold failed to compute the metric.
    """

    def __init__(self, metric):
        self.metric = metric
        self.values = []
        self.has_failure = False

    def append_value(self, value):
        """Record one successfully computed fold value."""
        self.values += [value]

def aggregate_custom_metrics_for_cross_val_model(custom_metric_data_per_cut):
    """Aggregate per-fold custom metric results into one result per metric.

    :param custom_metric_data_per_cut: one list of custom metric result dicts per
        fold. The metrics listed in the first fold define which metrics are
        aggregated and the output order.
    :return: list of result dicts. A metric that failed in any fold gets
        ``didSucceed: False`` and an ``error``. Otherwise ``value`` / ``valuestd``
        are the NaN-ignoring mean / std of the fold values, passed through
        ``dku_nonaninf``. An empty input yields an empty list.
    """
    if not custom_metric_data_per_cut:
        # No folds: nothing to aggregate (the [0] lookup below would fail).
        return []

    # Track the first fold's metric order explicitly: plain dicts are unordered on Python 2.
    metric_names = []
    values_per_metric = {}
    for custom_metric_result in custom_metric_data_per_cut[0]:
        metric = custom_metric_result['metric']
        if metric["name"] not in values_per_metric:
            metric_names.append(metric["name"])
        values_per_metric[metric["name"]] = CustomMetricAggregator(metric)

    for fold in custom_metric_data_per_cut:
        for item in fold:
            metric = item['metric']
            did_succeed = item['didSucceed']
            aggregator = values_per_metric[metric["name"]]
            if did_succeed:
                aggregator.append_value(item["value"])
            else:
                aggregator.has_failure = True

    aggregate_custom_metrics_results = []
    for metric_name in metric_names:
        aggregator = values_per_metric[metric_name]

        result = {
            "metric": aggregator.metric,
            "didSucceed": not aggregator.has_failure
        }

        if aggregator.has_failure:
            result['error'] = 'One or more folds failed to calculate the metric'
        else:
            values = np.array(aggregator.values)
            # nanmean/nanstd ignore NaN fold values; dku_nonaninf presumably maps a
            # non-finite result to a serializable value (confirm in dataiku.doctor.utils).
            result["value"] = dku_nonaninf(np.nanmean(values))
            result["valuestd"] = dku_nonaninf(np.nanstd(values))

        aggregate_custom_metrics_results.append(result)
    return aggregate_custom_metrics_results

def get_custom_metric_functions_binary_classif(metric_params, valid_unprocessed):
    """Parse every custom metric into a container used for per-cut evaluation.

    Each container is a dict with:
    - the metric definition (``metric``);
    - an empty ``values`` list, to be filled cut by cut;
    - a ``hasFailure`` flag;
    - when parsing succeeded, the scoring ``function``; otherwise an ``error``
      message (a modeling diagnostic is also recorded).

    :return: ``(threshold_dependent, threshold_independent)``. Metrics with
        ``needsProbability`` go to the second list, all others to the first.
    """
    threshold_dependent = []
    threshold_independent = []
    if "customMetrics" not in metric_params:
        return threshold_dependent, threshold_independent

    for custom_metric in metric_params["customMetrics"]:
        container = {
            "metric": custom_metric,
            "values": [],
        }

        try:
            container["function"] = get_custom_metric_scorefunc(custom_metric["metricCode"], valid_unprocessed)
            container['hasFailure'] = False
        except Exception as parse_error:
            container['hasFailure'] = True
            container['error'] = safe_unicode_str(parse_error)
            logger.warning("Custom metric function '{}' failed to parse".format(custom_metric['name']), exc_info=True)

            diagnostics.add_or_update(
                diagnostics.DiagnosticType.ML_DIAGNOSTICS_MODELING_PARAMETERS,
                "Calculation of '{}' failed: unable to parse metric code".format(container['metric']['name'])
            )

        destination = threshold_independent if custom_metric["needsProbability"] else threshold_dependent
        destination.append(container)

    return threshold_dependent, threshold_independent

def compute_custom_metrics_for_cut(custom_metric_list, test_y, decision_with_valid_index, test_sample_weight, cut):
    """Evaluate every still-healthy custom metric for one cut, in place.

    ``custom_metric_list`` holds containers like those built by
    ``get_custom_metric_functions_binary_classif`` (keys ``hasFailure``,
    ``metric``, ``values`` and, when parsed, ``function``).

    For each container not already marked as failed, the score for this cut is
    appended to ``values``. On error the container is marked failed (so later
    cuts skip it), its ``error`` is set, and a diagnostic is recorded.
    """
    for container in custom_metric_list:
        already_failed = container['hasFailure']
        metric = container['metric']
        if already_failed:
            continue

        try:
            scorer = container["function"]
            container["values"].append(scorer(test_y, decision_with_valid_index, sample_weight=test_sample_weight))
        except Exception as cut_error:
            container['hasFailure'] = True
            container['error'] = safe_unicode_str(cut_error)
            logger.warning("Multi-cut custom metric function '{}' failed to calculate for cut '{}', "
                           "the function will not be run for any further cut(s).".format(metric['name'], cut), exc_info=True)

            diagnostics.add_or_update(
                diagnostics.DiagnosticType.ML_DIAGNOSTICS_MODELING_PARAMETERS,
                "Calculation of '{}' failed".format(metric['name'])
            )

def build_cv_per_cut_custom_metrics(r1, perfdatas):
    """Aggregate per-cut custom metric values across cross-validation folds.

    :param r1: one fold's performance data. Its ``perCutData`` supplies the cut
        list (for the number of cuts) and the custom metrics to aggregate, in order.
    :param perfdatas: performance data of every fold
    :return: one result dict per metric. On success it holds ``values`` /
        ``valuesstd``, the per-cut mean and population standard deviation across
        folds. If any fold failed to compute the metric, it holds an ``error``
        message instead.
    """
    def _result_in_fold(fold, metric_name):
        # First result of this fold for the metric; raises StopIteration if absent.
        return next(x for x in fold["perCutData"]["customMetricsResults"] if x["metric"]["name"] == metric_name)

    def _any_fold_failed(metric_name):
        # Explicit loop rather than any(<genexpr>) so a StopIteration from
        # _result_in_fold propagates unchanged.
        for fold in perfdatas:
            if not _result_in_fold(fold, metric_name)['didSucceed']:
                return True
        return False

    def _cut_mean_and_std(metric_name, cut_index):
        cut_values = np.array([_result_in_fold(fold, metric_name)['values'][cut_index] for fold in perfdatas])
        return cut_values.mean(), cut_values.std()

    n_cuts = len(r1["perCutData"]["cut"])
    metric_definitions = [metric_result["metric"] for metric_result in r1["perCutData"]["customMetricsResults"]]

    results = []
    for metric in metric_definitions:
        name = metric["name"]

        if _any_fold_failed(name):
            results.append({
                "metric": metric,
                'error': 'One or more folds failed to calculate the metric',
                "didSucceed": False
            })
            continue

        per_cut_stats = [_cut_mean_and_std(name, cut_index) for cut_index in range(n_cuts)]
        results.append({
            "metric": metric,
            "didSucceed": True,
            "values": [mean for mean, _ in per_cut_stats],
            "valuesstd": [std for _, std in per_cut_stats]
        })
    return results
