import torch
import logging

logger = logging.getLogger(__name__)


def is_bfloat16_supported_with_cuda():
    """Return True if the current CUDA device should be used with bfloat16.

    Support requires both of the following:

    * ``torch.cuda.is_bf16_supported()`` returns True, and
    * the device's compute capability major version is at least 8
      (Ampere or newer).

    The extra capability check is stricter than ``is_bf16_supported`` alone.
    Some PyTorch versions may report bf16 as supported on older GPUs via
    emulation — confirm against the installed PyTorch version.

    Returns:
        bool: False when CUDA is not available (instead of raising),
        otherwise the combined result of the two checks above.
    """
    # Guard explicitly so this helper is safe to call on CPU-only hosts.
    # Otherwise the device queries below may raise on some PyTorch builds.
    if not torch.cuda.is_available():
        return False
    if not torch.cuda.is_bf16_supported():
        return False
    return torch.cuda.get_device_capability()[0] >= 8


def best_supported_dtype():
    """Choose the preferred reduced-precision floating-point dtype.

    Returns:
        torch.dtype: ``torch.bfloat16`` when CUDA is unavailable, or when CUDA
        is available and ``is_bfloat16_supported_with_cuda()`` is truthy.
        Otherwise ``torch.float16``, after logging a warning.
    """
    # Without CUDA, bfloat16 is chosen with no further capability check.
    # With CUDA, defer to the bf16 capability helper. The `or` short-circuits,
    # so the helper is only consulted when CUDA is present.
    cuda_present = torch.cuda.is_available()
    if not cuda_present or is_bfloat16_supported_with_cuda():
        return torch.bfloat16

    logger.warning("bfloat16 not supported, falling back to float16.")
    return torch.float16


def is_single_24gb_gpu():
    """Return True if exactly one CUDA GPU is visible and it has at most 24 GiB.

    The memory limit is binary: 24 GiB == 24 * 1024**3 == 25769803776 bytes.
    "24GB" cards report ``total_memory`` slightly below 24 GiB because some
    memory is reserved. That figure can still exceed the decimal
    24 * 1000**3 bytes, so a decimal threshold would wrongly reject them.
    Experimentally, an NVIDIA L4 reports 23633985536 bytes.

    Returns:
        bool: False if no CUDA GPU is visible, if more than one is visible, or
        if GPU 0 reports more than 24 GiB of total memory; True otherwise.
    """
    logger.info("Checking if using a single 24GB GPU")

    # Count devices first. device_count() returns 0 (rather than raising) when
    # CUDA is unavailable, whereas get_device_properties(0) would fail there.
    gpu_count = torch.cuda.device_count()
    logger.info("Detected %d GPUs", gpu_count)
    if gpu_count == 0:
        logger.info("No CUDA GPU detected")
        return False
    if gpu_count > 1:
        logger.info("Multiple GPUs detected")
        return False

    # Check that the single GPU (GPU0) has no more than 24 GiB of memory.
    max_memory_bytes = 24 * 1024**3
    gpu_memory = torch.cuda.get_device_properties(0).total_memory
    logger.info("GPU0 has %d bytes of memory", gpu_memory)
    if gpu_memory > max_memory_bytes:
        logger.info("GPU0 has more than 24GB of memory")
        return False
    return True
