from dataiku.base.utils import package_is_at_least
from dataiku.doctor.utils.skcompat.utils import _swap_variables, _replace_value
import sklearn


def gbt_skcompat_hp_space(input_hp_space):
    """
    Remap gradient-boosting loss names in a hyperparameter space to the names the installed sklearn accepts.

    Loss names changed across sklearn releases, and the frontend should not have to depend on the sklearn
    version of the selected code env. For example, if the frontend sends "deviance" but the env runs
    sklearn 1.3, the value is silently remapped to "log_loss" and processing continues.

    The version cuts are deliberately fuzzy: values are only touched from the release where the old name
    stops being usable, so releases that accept both names are left alone.
    See also PyGradientBoostingMeta.java for how the regridification is handled in those cases.

    :param input_hp_space: hyperparameter space; its ["loss"]["values"] entry is updated in place
        (only read when at least one rename applies).
    :return: the same ``input_hp_space`` object.
    """
    pending_renames = []
    if package_is_at_least(sklearn, "1.3"):
        pending_renames.append(("deviance", "log_loss"))
    if package_is_at_least(sklearn, "1.2"):
        pending_renames.extend([("lad", "absolute_error"), ("ls", "squared_error")])

    # Only look up the loss values when there is something to rename, as the original flow did.
    if pending_renames:
        loss_values = input_hp_space["loss"]["values"]
        for legacy_name, current_name in pending_renames:
            _swap_variables(loss_values, legacy_name, current_name)
    return input_hp_space


def gbt_skcompat_actual_params(gbt_params):
    """
    Map the sklearn loss name in ``gbt_params`` back to the name exposed in DSS.

    This runs regardless of the sklearn version, so the scikit compatibility layer is never visible
    in the reported actual parameters.

    :param gbt_params: actual gradient-boosting parameters, presumably a dict keyed by parameter name;
        its "loss" entry is rewritten in place by ``_replace_value`` (TODO confirm helper semantics).
    """
    sklearn_to_dss_loss = (
        ("log_loss", "deviance"),
        ("absolute_error", "lad"),
        ("squared_error", "ls"),
    )
    for sklearn_name, dss_name in sklearn_to_dss_loss:
        _replace_value(gbt_params, "loss", sklearn_name, dss_name)


# Name of the squared loss as accepted by the installed sklearn: "squared_error" from 1.2 on,
# the legacy "squared_loss" before that.
SQUARED_LOSS_NAME = "squared_error" if package_is_at_least(sklearn, "1.2") else "squared_loss"


def sgd_skcompat_hp_space(input_hp_space):
    """
    Remap SGD loss names in a hyperparameter space to the names accepted by sklearn 1.2 and later.

    On older sklearn versions the space is returned untouched.

    :param input_hp_space: hyperparameter space; its ["loss"]["values"] entry is updated in place
        (only read on sklearn >= 1.2).
    :return: the same ``input_hp_space`` object.
    """
    if not package_is_at_least(sklearn, "1.2"):
        return input_hp_space

    loss_values = input_hp_space["loss"]["values"]
    for legacy_name, current_name in (("log", "log_loss"), ("squared_loss", "squared_error")):
        _swap_variables(loss_values, legacy_name, current_name)
    return input_hp_space


def sgd_skcompat_actual_params(sgd_params):
    """
    Map the sklearn SGD loss name in ``sgd_params`` back to the name exposed in DSS, whatever the sklearn version.

    :param sgd_params: actual SGD parameters; its "loss" entry is rewritten in place by
        ``_replace_value`` (presumably only when it equals the sklearn-side name — TODO confirm).
    """
    sklearn_to_dss_loss = (("log_loss", "log"), ("squared_error", "squared_loss"))
    for sklearn_name, dss_name in sklearn_to_dss_loss:
        _replace_value(sgd_params, "loss", sklearn_name, dss_name)
