import logging
import numpy as np
import pandas as pd
from math import sqrt

from sklearn.metrics import mean_squared_error
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.preprocessing import label_binarize

from dataiku.base.utils import package_is_at_least
from dataiku.doctor.utils import dku_nonaninf
from dataiku.doctor.utils.calibration import dku_calibration_curve, dku_calibration_loss
from dataiku.doctor.utils.pandascompat import to_numpy
# Rename import to avoid a conflict with sklearn.metrics confusion_matrix() function
from dataiku.doctor.utils.skcompat import confusion_matrix as compat_confusion_matrix

logger = logging.getLogger(__name__)


def handle_failure(metric_lambda, treat_failure_as_error):
    """
    Run a metric computation, either propagating its failure or downgrading it to a warning.

    Used for now in the evaluation recipe, where an optional metric failing should not
    stop the whole process.

    :param metric_lambda: zero-argument callable computing the metric
    :param treat_failure_as_error: (bool) if True, any exception raised by the callable is re-raised;
        otherwise it is logged as a warning (with its traceback) and swallowed
    :return: the value returned by `metric_lambda`, or None when a failure was swallowed
    """
    try:
        result = metric_lambda()
    except Exception as e:
        if not treat_failure_as_error:
            # Best-effort mode: keep a trace of the failure but let the caller carry on
            logger.warning(str(e), exc_info=True)
            return None
        raise e
    return result

##################
# Classification #
##################

def log_odds(array, clip_min=0., clip_max=1.):
    """ Compute the log-odds (logit) of each element of an array
    logodd = log(p / (1-p)) with p a probability

    Values are cast to float and clipped to [clip_min, clip_max] before the transformation.
    With the default bounds, inputs equal to 0 or 1 are not clipped away and therefore
    produce infinite values.

    :param array: numpy array or pandas Series
    :param clip_min: (float) minimum value
    :param clip_max: (float) maximum value
    :return: the element-wise log-odds, with the same dimension as the input array
    """
    probas = np.clip(array.astype(float), clip_min, clip_max)
    odds = probas / (1 - probas)
    return np.log(odds)


def check_test_set_ok_for_classification(y_true, treat_failure_as_error=True):
    """
    Check that the test set target contains at least two distinct classes.

    :param y_true: observed classes of the test set
    :param treat_failure_as_error: (bool) if True, a single-class test set raises;
        otherwise the problem is only logged as a warning
    :raises ValueError: when fewer than two classes are found and `treat_failure_as_error` is True
    """
    if len(np.unique(y_true)) >= 2:
        return
    if not treat_failure_as_error:
        logger.warning("Ended up with only one class in the test set. trying to proceed anyway")
        return
    raise ValueError("Ended up with only one class in the test set. Cannot proceed")


def log_loss(y_true, y_pred, normalize=True, sample_weight=None):
    """Log loss, aka logistic loss or cross-entropy loss.

    sk-learn version is bugged when a class
    never appears in the predictions.
    Check test_cases_binary_extreme in test_metrics for examples reproducing - the stance of sklearn on those cases
    is that the user is expected to provide class labels, but for us that's not practicable.

    :param y_true: observed classes, shape (n_samples,); values are cast to int and used as column
        indices into `y_pred`
    :param y_pred: predicted probabilities, shape (n_samples, n_classes)
    :param normalize: (bool) if True, divide the summed loss by the number of rows
        (or by the sum of `sample_weight` when provided)
    :param sample_weight: optional per-row weights
    :return: the (optionally normalized) log loss
    """
    import sklearn
    # The clipping epsilon depends on the installed sklearn version
    # (presumably to stay aligned with sklearn's own log_loss - to be confirmed)
    eps = np.finfo(y_pred.dtype).eps if package_is_at_least(sklearn, "1.3") else 1e-15

    (nb_rows, nb_classes) = y_pred.shape
    assert y_true.shape == (nb_rows,)
    assert y_true.max() <= nb_classes - 1

    if sample_weight is None:
        renorm = nb_rows
        sample_weight = np.ones(nb_rows)
    else:
        renorm = np.sum(sample_weight)
        sample_weight = to_numpy(sample_weight)

    # Weighted one-hot encoding of the target: row r holds its weight in the column of its true class
    T = np.zeros((nb_rows, nb_classes))
    true_class_columns = to_numpy(y_true.astype(int))[:, np.newaxis]
    np.put_along_axis(T, true_class_columns, sample_weight[:, np.newaxis], axis=1)

    # Clip the probabilities away from 0 and 1, then renormalize each row so it sums to 1
    Y = np.clip(y_pred, eps, 1 - eps)
    Y /= Y.sum(axis=1)[:, np.newaxis]
    loss = -(T * np.log(Y)).sum()
    if normalize:
        return loss / renorm
    return loss


def mroc_auc_score(y_true, y_predictions, sample_weight=None, average="macro"):
    """ Returns a auc score. Handles multi-class

    For multi-class, the AUC score is in fact the MAUC
    score described in


    David J. Hand and Robert J. Till. 2001.
    A Simple Generalisation of the Area Under the ROC Curve
    for Multiple Class Classification Problems.
    Mach. Learn. 45, 2 (October 2001), 171-186.
    DOI=10.1023/A:1010920819831

    http://dx.doi.org/10.1023/A:1010920819831

    In the binary case (two probability columns), observed classes are remapped to 0/1 following
    their sorted order and the second column of `y_predictions` is scored.
    In the multi-class case, class values found in `y_true` are used directly as column indices
    of `y_predictions`.

    :param y_true: observed classes, shape = [n_samples]; must support `.map` in the binary case
        (e.g. a pd.Series)
    :param y_predictions: predicted class probabilities, shape = [n_samples, n_classes]
    :param sample_weight: optional sample weights, shape = [n_samples]
    :param average: "macro" (every ordered pair of classes has the same weight) or "weighted"
        (each pair is weighted by the number of rows, or sum of sample weights, belonging to it)
    :return: the (multi-class) AUC score
    :raises ValueError: if the test set holds more classes than there are probability columns,
        fewer than two classes, or if `average` is unknown
    """
    _, max_nb_classes = y_predictions.shape
    # Today, it may happen that if a class appears only once in a dataset
    # it can appear in the train and not in the validation set.
    # In this case it will not be in y_true and
    # y_predictions.nb_cols is not exactly the number of class
    # to consider when computing the mroc_auc_score.
    classes = np.unique(y_true)
    nb_classes = len(classes)
    if nb_classes > max_nb_classes or nb_classes < 2:
        raise ValueError("Could not compute AUC: {reason}"
                         .format(reason="Your test set contained more classes than the train set. Check your dataset or try a different split."
                                 if nb_classes > max_nb_classes else "Ended up with less than two-classes in the validation set."))

    if max_nb_classes == 2:  # assuming binary classification
        assert nb_classes == 2  # otherwise a ValueError should have been raised earlier
        classes = classes.tolist()
        y_true = y_true.map(lambda c: classes.index(c))  # ensure classes are [0 1]
        return roc_auc_score(y_true, y_predictions[:, 1], sample_weight=sample_weight)

    def pair_mask(i, j):
        """Boolean mask of the rows whose observed class is either i or j."""
        # np.isin replaces np.in1d, which is deprecated since NumPy 2.0 (same result on 1-D input)
        return np.isin(y_true, np.array([i, j]))

    def A(i, j):
        """
        Returns an asymmetric proximity metric, written A(i | j)
        in the paper.

        The sum of all (i, j) with  i != j
        will give us the symmetry.
        """
        mask_i_j = pair_mask(i, j)
        y_true_i_j = y_true[mask_i_j] == i
        proba_i_j = y_predictions[mask_i_j][:, i]
        if sample_weight is not None:
            sample_weight_i_j = sample_weight[mask_i_j]
        else:
            sample_weight_i_j = None
        return roc_auc_score(y_true_i_j, proba_i_j, sample_weight=sample_weight_i_j)

    if average == "macro":
        weight_matrix = {i: {j: 1 for j in classes} for i in classes}
        C = 1.0 / (nb_classes * (nb_classes - 1))
    elif average == "weighted":
        weight_matrix = {c: {} for c in classes}
        for i in classes:
            for j in classes:
                if i < j:
                    mask_i_j = pair_mask(i, j)
                    if sample_weight is None:
                        pair_weight = sum(mask_i_j)
                    else:
                        pair_weight = sum(sample_weight[mask_i_j])
                    # The weight of a pair is symmetric: fill both entries at once
                    weight_matrix[i][j] = pair_weight
                    weight_matrix[j][i] = pair_weight
        C = 1. / sum(weight_matrix[i][j] for i in classes for j in classes if i != j)
    else:
        raise ValueError("Unknown multiclass averaging method: {}".format(average))
    return C * sum(
        A(i, j) * weight_matrix[i][j]
        for i in classes
        for j in classes
        if i != j)


def m_average_precision_score(y_true, probas, sample_weight=None, average="macro"):
    """
    Compute an Average Precision score, for binary or multi-class problems.
    For multi-class, the score is the average of the one-vs-rest Average Precision
    scores of the classes.

    :param pd.Series y_true: Observed classes, shape = [n_samples]
    :param np.ndarray probas: Predicted class probabilities, shape = [n_samples, n_classes]
    :param numpy.ndarray or pd.Series or list sample_weight: Optional sample weights, shape = [n_samples]
    :param average: if set to "weighted", the per-class scores are averaged with sklearn's "weighted"
        mode (weighted by support). Any other value yields the unweighted mean of the per-class
        scores, ignoring classes whose score is NaN.
    :return float: the average precision score
    :raises ValueError: if the test set holds more classes than there are probability columns,
        or fewer than two classes
    """
    (_, max_nb_classes) = probas.shape

    classes = np.unique(y_true)
    nb_classes = len(classes)
    if nb_classes > max_nb_classes or nb_classes < 2:
        if nb_classes > max_nb_classes:
            reason = "Your test set contained more classes than the train set. Check your dataset or try a different split."
        else:
            reason = "Ended up with less than two-classes in the validation set."
        raise ValueError("Could not compute Average Precision: : {reason}".format(reason=reason))

    if max_nb_classes == 2:  # assuming binary classification
        assert nb_classes == 2  # otherwise a ValueError should have been raised earlier
        ordered_classes = classes.tolist()
        # Remap the two observed classes to 0/1, following their sorted order
        binary_target = y_true.map(lambda label: ordered_classes.index(label))
        return average_precision_score(binary_target, probas[:, 1], sample_weight=sample_weight)

    if nb_classes < max_nb_classes:
        logger.warning("Some classes are missing in the test set. Average classes scores will be computed using non missing classes")

    # One indicator column per probability column, including classes absent from the test set
    y_true_binarized = label_binarize(y_true, classes=np.arange(max_nb_classes))
    if sample_weight is not None and not isinstance(sample_weight, np.ndarray):
        sample_weight = to_numpy(sample_weight)  # Sk-learn scoring functions requires array-like function

    if average != "weighted":
        # In case of missing classes (nb_classes < max_nb_classes), we can't rely on average='macro'
        # because missing classes scores will be computed as 'nan', and the returned average will be 'nan' too.
        # This is why we request non-averaged classes scores (average=None), filter out 'nan' values,
        # and compute the average manually on the available classes scores.
        per_class_scores = average_precision_score(y_true_binarized, probas, average=None, sample_weight=sample_weight)
        available_scores = per_class_scores[~np.isnan(per_class_scores)]
        return np.average(available_scores)

    return average_precision_score(y_true_binarized, probas, average="weighted", sample_weight=sample_weight)


def calibration_loss_binary(y_true, probas, sample_weight=None):
    """
    Compute the calibration loss of a binary classifier.

    The two observed classes are remapped to 0/1 following their sorted order, and the
    second column of `probas` is used as the predicted probability.

    :param y_true: observed classes; must support `.map` and `.values` (e.g. a pd.Series)
    :param probas: predicted class probabilities, indexed as probas[:, 1]
    :param sample_weight: optional sample weights, forwarded to the calibration curve computation
    :return: the value returned by dku_calibration_loss on the computed calibration curve
    :raises ValueError: if `y_true` does not contain exactly two distinct classes
    """
    # binary case only. for multiclass see MulticlassModelScorer.get_calibration_metrics_and_curves
    observed_classes = np.unique(y_true)
    if len(observed_classes) != 2:
        raise ValueError("Could not compute calibration loss: Ended up with {} classes in the validation set, expected 2 classes for binary classification"
                         .format(len(observed_classes)))

    ordered_classes = observed_classes.tolist()
    binary_target = y_true.map(lambda label: ordered_classes.index(label))  # ensure classes are [0 1]
    positive_probas = probas[:, 1]
    freqs, avg_preds, weights = dku_calibration_curve(binary_target.values, positive_probas, sample_weight=sample_weight)
    return dku_calibration_loss(freqs, avg_preds, weights)


def confusion_matrix_derived_metrics(y_true, y_pred, sample_weights):
    """
    Compute the binary confusion matrix and the usual metrics derived from it.

    The matrix is read as [0, 0] = tn, [0, 1] = fp, [1, 0] = fn, [1, 1] = tp.
    Every ratio falls back to 0 when its denominator is 0.

    :param y_true: observed classes
    :param y_pred: predicted classes
    :param sample_weights: sample weights, forwarded to the confusion matrix computation
    :return: dict with keys "tp", "tn", "fp", "fn", "precision", "recall", "accuracy",
        "f1", "mcc" and "hamming" (hamming being 1 - accuracy)
    """
    matrix = compat_confusion_matrix(y_true, y_pred, sample_weights)
    tn, fp = matrix[0, 0], matrix[0, 1]
    fn, tp = matrix[1, 0], matrix[1, 1]

    if (tp + fp) == 0:
        precision = 0.
    else:
        precision = float(tp) / (tp + fp)

    if (tp + fn) == 0:
        recall = 0.
    else:
        recall = float(tp) / (tp + fn)

    if (tp + tn + fp + fn) == 0:
        accuracy = 0.
    else:
        accuracy = float(tp + tn) / (tp + tn + fp + fn)

    if (precision + recall) == 0:
        f1 = 0.
    else:
        f1 = 2 * (precision * recall) / (precision + recall)

    # Matthews correlation coefficient: the denominator is computed once and reused
    mcc_denominator = float(tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    if mcc_denominator == 0:
        mcc = 0.
    else:
        mcc = (float(tp * tn) - (fp * fn)) / sqrt(mcc_denominator)

    return {
        "tp": tp,
        "tn": tn,
        "fp": fp,
        "fn": fn,
        "precision": precision,
        "recall": recall,
        "accuracy": accuracy,
        "f1": f1,
        "mcc": mcc,
        "hamming": 1. - accuracy,
    }


##############
# Regression #
##############


def rmse_score(y, y_pred, sample_weight=None):
    """Root Mean Square Error, more readable than MSE

    :param y: observed values
    :param y_pred: predicted values
    :param sample_weight: optional sample weights, forwarded to mean_squared_error
    :return: the square root of the (optionally weighted) mean squared error
    """
    mse = mean_squared_error(y, y_pred, sample_weight=sample_weight)
    return sqrt(mse)


def rmsle_score(y, y_pred, sample_weight=None):
    """Root Mean Square Logarithmic Error
    https://www.kaggle.com/wiki/RootMeanSquaredLogarithmicError

    :param y: observed values (numpy array or pandas Series)
    :param y_pred: predicted values (numpy array or pandas Series)
    :param sample_weight: optional sample weights
    :return: the RMSLE, or 0 when it is not computed (negative value in `y` or `y_pred`)
        or when the result is infinite or NaN
    """
    if (y < 0).any() or (y_pred < 0).any():
        logger.info("Negative values, not computing RMSLE")
        return 0

    squared_log_errors = np.power(np.log(y + 1) - np.log(y_pred + 1), 2)
    if sample_weight is not None:
        mean_squared_log_error = (sample_weight * squared_log_errors).sum(0) / np.sum(sample_weight)
    else:
        mean_squared_log_error = squared_log_errors.sum(0) / y.shape[0]
    rmsle = sqrt(mean_squared_log_error)

    if not np.isfinite(rmsle):
        logger.warning("Unexpected RMSLE: %s - ignoring", rmsle)
        return 0

    return rmsle


def mean_absolute_percentage_error(y_true, y_pred, sample_weight=None):
    """
    Compute the (optionally weighted) Mean Absolute Percentage Error.

    Rows whose observed value has an absolute value not greater than 1e-12 are left out,
    since the relative error is not defined for them.

    :param y_true: observed values
    :param y_pred: predicted values
    :param sample_weight: optional sample weights; when given, the result is the weighted mean
        of the relative errors over the retained rows
    :return: the mean of |y_true - y_pred| / y_true (absolute value) over the retained rows
    """
    columns = {"y_true": y_true, "y_pred": y_pred}
    if sample_weight is not None:
        columns["sample_weight"] = sample_weight
    frame = pd.DataFrame(columns)

    # Drop the rows where the observed value is (almost) zero
    frame = frame[frame["y_true"].abs() > 1e-12]
    observed = frame["y_true"]
    predicted = frame["y_pred"]
    relative_errors = np.abs((observed - predicted) / (observed))

    if sample_weight is None:
        return np.mean(relative_errors)
    weights = frame["sample_weight"]
    return np.sum(weights * relative_errors) / np.sum(weights)


def mean_absolute_scaled_error(y_true, y_pred, naive_error=None):
    """
    Compute the Mean Absolute Scaled Error: the mean absolute error divided by `naive_error`.

    :param y_true: observed values
    :param y_pred: predicted values
    :param naive_error: error used as the scale of the metric
    :return: the scaled error, or NaN when `naive_error` is missing (None), NaN or not strictly positive
    """
    error = np.mean(np.abs(y_true - y_pred))
    # `None > 0` raises a TypeError in Python 3: a missing scale is handled explicitly and,
    # like any other unusable scale, yields an undefined (NaN) metric.
    if naive_error is None or not naive_error > 0:
        return np.nan
    return error / naive_error


def mean_scaled_interval_error(y_true, lower_quantile, upper_quantile, alpha, naive_error=None):
    """MSIS was used in the M4 competition, alpha=0.05 corresponds to the 95% confidence interval

    The interval error is the mean of the interval width, plus a penalty of 2 / alpha times the
    distance to the interval for each observation falling outside of it. It is then divided by
    `naive_error`.

    :param y_true: observed values
    :param lower_quantile: lower bound of the prediction interval
    :param upper_quantile: upper bound of the prediction interval
    :param alpha: (float) significance level of the interval
    :param naive_error: error used as the scale of the metric
    :return: the scaled interval error, or NaN when `naive_error` is missing (None), NaN or not strictly positive
    """
    penalty_factor = 2.0 / alpha
    error = np.mean(
        upper_quantile
        - lower_quantile
        + penalty_factor * (lower_quantile - y_true) * (y_true < lower_quantile)
        + penalty_factor * (y_true - upper_quantile) * (y_true > upper_quantile)
    )
    # `None > 0` raises a TypeError in Python 3: a missing scale is handled explicitly and,
    # like any other unusable scale, yields an undefined (NaN) metric.
    if naive_error is None or not naive_error > 0:
        return np.nan
    return error / naive_error


def symetric_mean_absolute_percentage_error(y_true, y_pred):
    """
    Compute the symmetric Mean Absolute Percentage Error.

    Rows where both the observed and the predicted values are 0 are left out, since the
    ratio is not defined for them.

    :param y_true: observed values
    :param y_pred: predicted values
    :return: 2 * mean(|y_true - y_pred| / (|y_true| + |y_pred|)) over the retained rows
    """
    frame = pd.DataFrame({"y_true": y_true, "y_pred": y_pred})
    has_magnitude = np.abs(frame["y_true"]) + np.abs(frame["y_pred"]) != 0.
    kept = frame[has_magnitude]

    observed = kept["y_true"]
    predicted = kept["y_pred"]
    absolute_errors = np.abs(observed - predicted)
    magnitudes = np.abs(observed) + np.abs(predicted)
    return 2 * np.mean(absolute_errors / magnitudes)


def quantile_loss(y_true, forecast_quantile, quantile):
    """
    Compute the summed quantile loss of a quantile forecast, multiplied by 2.

    :param y_true: observed values
    :param forecast_quantile: forecasted values for the given quantile
    :param quantile: (float) level of the forecasted quantile
    :return: 2 * sum(|(forecast_quantile - y_true) * ((y_true <= forecast_quantile) - quantile)|)
    """
    deviations = forecast_quantile - y_true
    # Indicator of the observations lying at or below the forecast, shifted by the quantile level
    shifted_indicator = (y_true <= forecast_quantile) - quantile
    return 2 * np.sum(np.abs(deviations * shifted_indicator))


def abs_error(y_true, y_pred):
    """
    Compute the sum of the absolute errors.

    :param y_true: observed values
    :param y_pred: predicted values
    :return: sum(|y_true - y_pred|)
    """
    residuals = y_true - y_pred
    return np.sum(np.absolute(residuals))


def abs_target_sum(y_true):
    """
    Compute the sum of the absolute values of the target.

    :param y_true: observed values
    :return: sum(|y_true|)
    """
    magnitudes = np.absolute(y_true)
    return np.sum(magnitudes)

# TODO @timeseries note that newer versions of scikit-learn may contain these metrics
