import sklearn

from dataiku.base.utils import package_is_at_least
from sklearn.calibration import _CalibratedClassifier, CalibratedClassifierCV


def get_calibrators(calibrated_classifier_cv):
    """
    Return the calibrators of the first inner calibrated classifier of a fitted ``CalibratedClassifierCV``.

    A binary problem uses a single calibrator; otherwise one calibrator per class is returned.

    Recent scikit-learn stores the calibrators as ``calibrators`` and older releases as ``calibrators_``.
    The attribute present on the object is probed directly instead of inferring its name from the installed
    version. This avoids reaching the legacy name on releases that already use ``calibrators``, and it also
    works for objects whose attributes were renamed while unpickling.

    :param calibrated_classifier_cv: fitted object exposing ``classes_`` and ``calibrated_classifiers_``
    :return: list of the first ``n`` calibrators, ``n`` being 1 for binary problems and the class count otherwise
    """
    classes = calibrated_classifier_cv.classes_
    n_calibrators = 1 if len(classes) == 2 else len(classes)
    first_calibrated = calibrated_classifier_cv.calibrated_classifiers_[0]
    calibrators = getattr(first_calibrated, "calibrators", None)
    if calibrators is None:
        # Legacy attribute name used by older scikit-learn releases
        calibrators = first_calibrated.calibrators_
    return [calibrators[c] for c in range(n_calibrators)]


def get_base_estimator(calibrated_classifier):
    """
    Return the estimator wrapped by a calibrated classifier, across scikit-learn versions.

    The wrapped-estimator parameter went through a deprecation cycle between 1.0 and 1.2. Rather than branching
    on one version, every known layout is tried in turn:

    1. a ``base_estimator`` attribute, unless it holds the ``'deprecated'`` placeholder;
    2. otherwise the ``estimator`` attribute, unwrapped from ``FrozenEstimator`` when scikit-learn >= 1.6.
    """
    legacy_estimator = getattr(calibrated_classifier, 'base_estimator', 'deprecated')
    if legacy_estimator != 'deprecated':
        return legacy_estimator
    if not package_is_at_least(sklearn, "1.6"):
        return calibrated_classifier.estimator
    from sklearn.frozen import FrozenEstimator
    wrapped = calibrated_classifier.estimator
    # Prefit models are wrapped in FrozenEstimator (see dku_calibrated_classifier_cv): return the inner model
    return wrapped.estimator if isinstance(wrapped, FrozenEstimator) else wrapped


def dku_calibrated_classifier_cv(clf, **kwargs):
    """
    Build a ``CalibratedClassifierCV`` around ``clf``, translating ``cv="prefit"`` for recent scikit-learn.

    ``cv="prefit"`` is deprecated in scikit-learn 1.6 and removed in 1.8. From 1.6 on, a prefit request is
    expressed by wrapping ``clf`` in ``FrozenEstimator`` and NOT forwarding ``cv`` at all. ``ensemble=False``
    makes the calibrator fit on cross-validated predictions of the frozen model. Those equal its plain
    predictions, so the result is the single calibrator that ``cv="prefit"`` used to produce.

    :param clf: classifier to calibrate (already fitted when ``cv="prefit"``)
    :param kwargs: keyword arguments forwarded to ``CalibratedClassifierCV``
    :return: an unfitted ``CalibratedClassifierCV``
    """
    cv = kwargs.get("cv")
    # isinstance guard: cv may also be an int, a splitter or an array-like of splits
    is_prefit = isinstance(cv, str) and cv == "prefit"
    if is_prefit and package_is_at_least(sklearn, "1.6"):
        from sklearn.frozen import FrozenEstimator
        # Forwarding "prefit" alongside FrozenEstimator warns on 1.6/1.7 and breaks once the value is removed
        kwargs.pop("cv")
        # ``ensemble`` was ignored under prefit; force the single-calibrator path to keep prefit semantics.
        # NOTE(review): this path now splits the calibration data into folds (default cv), so tiny
        # calibration sets may be rejected where prefit accepted them — confirm acceptable.
        kwargs["ensemble"] = False
        return CalibratedClassifierCV(FrozenEstimator(clf), **kwargs)
    return CalibratedClassifierCV(clf, **kwargs)


class UnpicklableCalibratedClassifier(_CalibratedClassifier, object):
    """
    ``_CalibratedClassifier`` whose unpickling renames stored attributes to the names the installed
    scikit-learn expects, so pickles written under other scikit-learn versions can be loaded.
    """

    def __setstate__(self, state):
        """
        Restore the instance from its pickled attribute dict.

        :param state: attribute dict coming from the pickle. Legacy keys are renamed in place, then every
            entry is set on the instance.
        """
        if package_is_at_least(sklearn, "0.24"):
            # From 0.24 on these attributes are expected without the trailing underscore: move them over
            for legacy_name, current_name in (('calibrators_', 'calibrators'), ('classes_', 'classes')):
                if legacy_name in state:
                    state[current_name] = state.pop(legacy_name)
        if package_is_at_least(sklearn, "1.1") and 'base_estimator' in state:
            # Also expose the wrapped estimator as ``estimator``; ``base_estimator`` itself is kept (copied, not moved)
            state['estimator'] = state['base_estimator']
        # Mirror pickle.py's default state restoration: _CalibratedClassifier is one of the very rare sklearn
        # classes without a __setstate__ of its own, so there is nothing to delegate to via super()
        from six.moves import intern  # Python2/3 compat import
        self.__dict__.update(
            (intern(name) if type(name) is str else name, value) for name, value in state.items()
        )
