import numpy as np
import sklearn

from sklearn.metrics import roc_curve as sk_roc_curve
from dataiku.base.utils import package_is_at_least


def roc_curve(y_true, y_score, pos_label=None, sample_weight=None, drop_intermediate=True):
    """
    Compute a ROC curve with sklearn while keeping the pre-1.3.0 ``thresholds[0]`` convention.

    Ensures that with sklearn >= 1.3.0, the thresholds output behaves as before
    (thresholds[0] = max(y_score) + 1).

    The contract changed as follows:
    * sklearn <  1.3.0 : thresholds[0] = max(y_score) + 1
    * sklearn >= 1.3.0 : thresholds[0] = np.inf

    All arguments are forwarded unchanged to ``sklearn.metrics.roc_curve``.

    :param y_true: true labels, as accepted by ``sklearn.metrics.roc_curve``
    :param y_score: target scores, as accepted by ``sklearn.metrics.roc_curve``
    :param pos_label: label of the positive class (forwarded as-is)
    :param sample_weight: optional sample weights (forwarded as-is)
    :param drop_intermediate: forwarded as-is to ``sklearn.metrics.roc_curve``
    :return: the ``(fpr, tpr, thresholds)`` tuple returned by sklearn, where a leading
             ``np.inf`` threshold (sklearn >= 1.3) is replaced by ``max(y_score) + 1``
    """
    fpr, tpr, thresholds = sk_roc_curve(
        y_true,
        y_score,
        pos_label=pos_label,
        sample_weight=sample_weight,
        drop_intermediate=drop_intermediate,
    )

    # Only sklearn >= 1.3 emits the np.inf sentinel; older versions already follow the legacy contract.
    if package_is_at_least(sklearn, "1.3"):
        if thresholds.size > 0 and thresholds[0] == np.inf:
            # np.max is vectorised and always reduces to a scalar, whatever container
            # (list, ndarray, column vector, ...) y_score was given as; the builtin max()
            # would iterate in Python and return a 1-element array for a (n, 1) input.
            thresholds[0] = np.max(np.asarray(y_score)) + 1

    return fpr, tpr, thresholds
