from dataiku.base import remoterun
from dataiku.doctor.utils.gpu_execution import KerasGPUCapability


def load_gpu_options(gpu_list, per_process_gpu_memory_fraction=1, allow_growth=False):
    """Restrict the process to ``gpu_list`` and install a configured TF session.

    :param gpu_list: sized collection of GPU identifiers handed to
        ``KerasGPUCapability.init_cuda_visible_devices``.
    :param per_process_gpu_memory_fraction: value written to the TF session's
        ``gpu_options.per_process_gpu_memory_fraction``.
    :param allow_growth: value written to ``gpu_options.allow_growth``.
    :return: dict with keys ``"gpu_list"`` (the input) and ``"n_gpu"``
        (its length).
    """
    from dataiku.doctor.deep_learning.tfcompat import ConfigProto, Session, set_session

    # Computed before any side effect so an unsized gpu_list fails early.
    options = dict(gpu_list=gpu_list, n_gpu=len(gpu_list))

    KerasGPUCapability.init_cuda_visible_devices(gpu_list)

    tf_config = ConfigProto()
    gpu_section = tf_config.gpu_options
    gpu_section.per_process_gpu_memory_fraction = per_process_gpu_memory_fraction
    gpu_section.allow_growth = allow_growth
    # Let TF fall back to another device when an op has no kernel on the chosen one.
    tf_config.allow_soft_placement = True

    set_session(Session(config=tf_config))
    return options


def load_gpu_options_only_allow_growth():
    """Install a TF session whose only customisation is ``gpu_options.allow_growth = True``.

    Unlike ``load_gpu_options``, this touches neither the visible CUDA devices
    nor soft placement, and returns nothing.
    """
    from dataiku.doctor.deep_learning.tfcompat import ConfigProto, Session, set_session

    tf_config = ConfigProto()
    tf_config.gpu_options.allow_growth = True
    set_session(Session(config=tf_config))


def deactivate_gpu():
    """Delegate to ``KerasGPUCapability.disable_all_cuda_devices()``.

    Presumably this hides every CUDA device from the current process; confirm
    against ``KerasGPUCapability``.
    """
    disable_devices = KerasGPUCapability.disable_all_cuda_devices
    disable_devices()


def get_num_gpu_used():
    """Return how many GPUs are listed in ``CUDA_VISIBLE_DEVICES``.

    The variable is read once through ``remoterun.get_env_var``.

    :return: 0 when the variable is unset, empty/blank, or exactly ``"-1"``
        (ignoring surrounding whitespace). Otherwise, the number of non-empty,
        comma-separated entries (e.g. ``"0,1"`` and ``"0, 1,"`` both give 2).
    """
    # Single lookup: the previous version queried the variable up to three times.
    cuda_visible_devices = remoterun.get_env_var("CUDA_VISIBLE_DEVICES", False)
    if not cuda_visible_devices:
        return 0

    # Ignore blank entries so trailing commas / stray spaces don't inflate the count.
    devices = [entry.strip() for entry in cuda_visible_devices.split(",") if entry.strip()]
    if devices == ["-1"]:
        # "-1" alone is treated as "no GPU".
        return 0
    return len(devices)
