import sklearn
from sklearn.linear_model import SGDClassifier

from dataiku.base.utils import package_is_at_least
from dataiku.doctor.utils.skcompat.utils import _replace_value


def update_sgd_model_state(d):
    """Rewrite legacy ``loss`` values in an SGD model state dict.

    Applies only when the installed scikit-learn is at least 1.2. On older
    versions the state is left untouched. The renames applied, in order, are:

    * ``"log"`` -> ``"log_loss"``
    * ``"squared_loss"`` -> ``"squared_error"``

    Each rename is delegated to ``_replace_value``. Its return value is not
    used, so it presumably mutates ``d`` in place when ``d["loss"]`` matches
    the legacy name (confirm against ``skcompat.utils``). The gating is
    presumably because newer sklearn releases no longer accept the legacy
    names; TODO confirm against the sklearn changelog.

    :param d: state dict, as received by ``__setstate__`` during unpickling.
    :returns: None.
    """
    if not package_is_at_least(sklearn, "1.2"):
        return

    # (legacy value, current value) pairs for the ``loss`` key.
    loss_renames = (
        ("log", "log_loss"),
        ("squared_loss", "squared_error"),
    )
    for legacy_name, current_name in loss_renames:
        _replace_value(d, "loss", legacy_name, current_name)


class UnpicklableSGDClassifier(SGDClassifier, object):
    """``SGDClassifier`` whose unpickling first migrates legacy state.

    Before the regular ``SGDClassifier`` state restoration runs, the state
    dict is passed through ``update_sgd_model_state``. That function rewrites
    legacy ``loss`` names when running on scikit-learn >= 1.2. This
    presumably lets models pickled under older sklearn versions load on
    newer ones.
    """

    def __setstate__(self, d):
        # Migrate the raw state before the parent class consumes it.
        update_sgd_model_state(d)
        # Explicit two-argument super keeps Python 2 compatibility.
        parent = super(UnpicklableSGDClassifier, self)
        parent.__setstate__(d)
