from dataiku.base.utils import detect_usable_cpu_count
from typing_extensions import NotRequired, TypedDict, Union
import logging

logger = logging.getLogger(__name__)


class EnvData(TypedDict):
    """
    Environment characteristics reported for analytics, as built by ``collect_env``.

    Every key is optional (``NotRequired``): a key is absent when the probe that fills it failed.
    """

    # Cloud provider, as returned by vLLM's private ``_detect_cloud_provider`` helper
    cloud_provider: NotRequired[str]

    # Number of CUDA devices reported by ``torch.cuda.device_count()`` (0 when CUDA is unavailable)
    gpu_count: NotRequired[int]
    # Name of CUDA device 0 (only set when at least one device is visible)
    gpu_type: NotRequired[str]
    # ``total_memory`` of CUDA device 0 as reported by torch (presumably bytes)
    gpu_memory_per_device: NotRequired[int]
    # ``torch.version.cuda`` when CUDA is available, otherwise the literal "not available"
    cuda: NotRequired[str]

    # ``platform.platform()`` string
    platform: NotRequired[str]
    # Result of ``detect_usable_cpu_count()``; may be fractional (float) — presumably under a CPU quota, TODO confirm
    cpu_count: NotRequired[Union[float, int]]
    # CPU brand string from py-cpuinfo (``brand_raw``), "" when not reported
    cpu_type: NotRequired[str]
    # Usable memory in bytes: host total memory capped by the cgroup hard limit, if any
    cpu_memory: NotRequired[int]

    # ``__version__`` strings of the corresponding installed packages
    vllm: NotRequired[str]
    torch: NotRequired[str]
    transformers: NotRequired[str]


def _log_detection_failure(what, error):
    """
    Log a failed environment probe of ``collect_env``.

    Must be called from inside the ``except`` clause that caught ``error`` (``logger.exception`` reads the exception
    currently being handled).

    A missing optional package (``ImportError``, e.g. vLLM when running a non-vLLM backend) is an expected situation, so
    it is logged at INFO level without a traceback. Any other error is logged with its full traceback.

    :param what: human-readable name of what was being detected (e.g. "GPU")
    :param error: the exception raised by the probe
    """
    if isinstance(error, ImportError):
        logger.info("Failed to detect %s (%s)", what, error)
    else:
        logger.exception("Failed to detect %s", what)


def collect_env():
    """
    Capture characteristics of the HF environment for analytics.

    Partially extracted from https://github.com/vllm-project/vllm/blob/main/vllm/usage/usage_lib.py so that it can be
    used with non-vLLM backends.

    Each probe is best-effort and independent: when one fails, the failure is logged and its keys are left out of the
    result. Keys a probe already wrote before failing are kept.

    :return: an ``EnvData`` dict holding only the keys that could be detected
    """
    env_data: EnvData = {}

    try:
        # Private vLLM helper: only importable when vLLM is installed
        from vllm.usage.usage_lib import _detect_cloud_provider

        env_data["cloud_provider"] = _detect_cloud_provider()
    except Exception as e:
        _log_detection_failure("cloud provider", e)

    try:
        import torch

        if torch.cuda.is_available():
            env_data["cuda"] = torch.version.cuda
            env_data["gpu_count"] = torch.cuda.device_count()
            if env_data["gpu_count"] > 0:
                # Only device 0 is inspected; devices are assumed to be homogeneous
                device_property = torch.cuda.get_device_properties(0)
                env_data["gpu_memory_per_device"] = device_property.total_memory
                env_data["gpu_type"] = device_property.name
        else:
            env_data["gpu_count"] = 0
            env_data["cuda"] = "not available"
    except Exception as e:
        _log_detection_failure("GPU", e)

    try:
        import cpuinfo

        info = cpuinfo.get_cpu_info()
        env_data["cpu_type"] = info.get("brand_raw", "")
    except Exception as e:
        _log_detection_failure("CPU type", e)

    try:
        env_data["cpu_count"] = detect_usable_cpu_count()
    except Exception as e:
        _log_detection_failure("CPU count", e)

    try:
        env_data["cpu_memory"] = detect_usable_memory_bytes()
    except Exception as e:
        _log_detection_failure("available memory", e)

    try:
        import vllm

        env_data["vllm"] = vllm.__version__
    except Exception as e:
        _log_detection_failure("vllm version", e)

    try:
        import torch

        env_data["torch"] = torch.__version__
    except Exception as e:
        _log_detection_failure("torch version", e)

    try:
        import transformers

        env_data["transformers"] = transformers.__version__
    except Exception as e:
        _log_detection_failure("transformers version", e)

    try:
        import platform

        env_data["platform"] = platform.platform()
    except Exception as e:
        _log_detection_failure("platform", e)

    return env_data


def detect_usable_memory_bytes():
    """
    Detect available memory while taking into account K8S limits via cgroups.

    Starts from the total physical memory reported by psutil and lowers it to the cgroup hard memory limit when one is
    set. Both the cgroup v1 and the cgroup v2 limit files are consulted; the smallest positive limit wins. Missing files
    and non-numeric contents (cgroup v2 writes "max" when there is no limit) are ignored.

    :return: usable memory, in bytes
    """
    import psutil

    limit = psutil.virtual_memory().total
    # Only consider the hard limits
    cgroup_limit_files = (
        "/sys/fs/cgroup/memory/memory.limit_in_bytes",  # cgroup v1
        "/sys/fs/cgroup/memory.max",  # cgroup v2
    )
    for path in cgroup_limit_files:
        try:
            with open(path) as f:
                cgroups_limit_str = f.read().strip()
        except FileNotFoundError:
            # This cgroup hierarchy is not mounted here: nothing to cap
            continue
        except Exception:
            logger.exception("Error occurred while reading cgroup quotas")
            continue

        # Accept plain ASCII decimal integers only: str.isnumeric() also accepts characters such as "½" that int()
        # rejects, and "max" (no limit) must be skipped
        if cgroups_limit_str.isascii() and cgroups_limit_str.isdigit():
            cgroups_limit = int(cgroups_limit_str)
            if cgroups_limit > 0:
                limit = min(limit, cgroups_limit)
    return limit


def extract_weight_quantization_method(pretrained_config) -> str:
    """
    Safely extract the quantization method from a transformers's PretrainedConfig for analytics.

    :param pretrained_config: a transformers ``PretrainedConfig``, or any object with an optional
        ``quantization_config`` attribute holding either a dict or an object with a ``quant_method`` attribute
    :return: the quantization method name (e.g. "awq"); "none" when no quantization method is declared; "unknown" when
        the method cannot be determined
    """
    import enum

    try:
        quantization_config = getattr(pretrained_config, "quantization_config", None)
        if isinstance(quantization_config, dict):
            # Sometimes it's a dict...
            method = quantization_config.get("quant_method", "unknown")
        elif hasattr(quantization_config, "quant_method"):
            # ... and sometimes it's a data class ¯\_(ツ)_/¯
            method = quantization_config.quant_method
        else:
            return "none"

        # The method may be an Enum member (transformers uses a str-mixin ``QuantizationMethod`` enum). Formatting such
        # a member directly yields "QuantizationMethod.AWQ" on recent Python versions, so report its value instead.
        if isinstance(method, enum.Enum):
            method = method.value
        return f"{method}"
    except Exception:
        logger.exception("Failed to get model quantization config")
        return "unknown"


def extract_model_architecture(pretrained_config) -> str:
    """
    Safely extract the model architecture from a transformers's PretrainedConfig for analytics.

    :param pretrained_config: a transformers ``PretrainedConfig``, or any object with an optional ``architectures``
        attribute
    :return: the declared architecture names joined with ","; "unknown" when none is declared or they cannot be read
    """
    try:
        architectures = getattr(pretrained_config, "architectures", None)
        if architectures is None:
            # Nothing declared (e.g. a config without architectures): an expected case, no need for a traceback
            return "unknown"
        if isinstance(architectures, str):
            # A single name: joining would split it into characters
            return architectures
        return ",".join(architectures)
    except Exception:
        logger.exception("Failed to get model architecture from config")
        return "unknown"
