import numpy as np
import sklearn

from scipy.sparse import coo_matrix
from sklearn.metrics import confusion_matrix as sk_cm

from dataiku.base.utils import package_is_at_least


# This is extracted from scikit-learn's confusion matrix computation as it was before sklearn 1.0, to avoid data sanity
# checks that take a long time but aren't needed, since the doctor already does what's needed to have sane data.
# Original source: https://github.com/scikit-learn/scikit-learn/blob/0.20.4/sklearn/metrics/classification.py#L187
# Licensed under BSD-3-Clause: https://github.com/scikit-learn/scikit-learn/blob/0.20.4/COPYING
# This was fixed in 1.0, so we prefer using the fixed implementation for that version and later.
def confusion_matrix(y_true, y_pred, sample_weight):
    """Compute the 2x2 confusion matrix of a binary classification.

    With scikit-learn >= 1.0 the call is delegated to
    ``sklearn.metrics.confusion_matrix`` with ``labels=[0, 1]``. With older
    versions the matrix is built directly from a sparse COO matrix, skipping
    scikit-learn's input validation.

    :param y_true: true class indices; used as row coordinates.
        On the legacy path with ``sample_weight=None`` it must expose ``.shape``.
    :param y_pred: predicted class indices; used as column coordinates.
    :param sample_weight: per-sample weights, or ``None`` to count every sample once.
    :return: 2x2 array where cell ``[i, j]`` accumulates the weight of samples
        with true class ``i`` and predicted class ``j``. On the legacy path the
        dtype is ``int64`` for integer/unsigned/boolean weights, ``float64`` otherwise.
    """
    if package_is_at_least(sklearn, "1.0"):
        # In the doctor, we always use internally 0 as negative class and 1 as positive class
        return sk_cm(y_true, y_pred, labels=[0, 1], sample_weight=sample_weight)

    # Legacy path (sklearn < 1.0): unit integer weights when none are supplied.
    weights = (
        np.ones(y_true.shape[0], dtype=np.int64)
        if sample_weight is None
        else np.asarray(sample_weight)
    )

    # Accumulate in a wide dtype matching the nature of the weights.
    accumulator_dtype = np.int64 if weights.dtype.kind in ('i', 'u', 'b') else np.float64

    # Each sample becomes one sparse entry at coordinates (y_true, y_pred) carrying its weight.
    # Densifying with .toarray() sums the entries sharing the same coordinates, which
    # effectively yields the confusion matrix.
    sparse_counts = coo_matrix(
        (weights, (y_true, y_pred)),
        shape=(2, 2),
        dtype=accumulator_dtype,
    )
    return sparse_counts.toarray()
