import json
import logging
import os
import sys
import textwrap
import torch
from abc import ABC
from typing import Optional, Union

from transformers import PretrainedConfig

from dataiku.huggingface.pipeline import QuantizationMode
from dataiku.huggingface.torch_utils import is_bfloat16_supported_with_cuda
from dataiku.huggingface.types import ModelSettings
from dataiku.huggingface.torch_utils import is_single_24gb_gpu

logger = logging.getLogger(__name__)

IS_MAC = sys.platform == "darwin"


class ModelParamsLoader:
    """
    Resolve the vLLM engine parameters to use for a model, from the DSS model settings.

    Every setting is first read from `model_settings`. A setting left unset by the user is then resolved through
    a private hook (`_max_model_len`, `_enforce_eager`, `_max_num_seqs`, ...) that subclasses override to provide
    model-specific defaults. `build_params` turns the resolved settings into a dict of engine parameters.

    The settings are described below.

    ENABLE CHUNKED PREFILL
        See https://docs.vllm.ai/en/latest/models/performance.html#chunked-prefill

    ENABLE PREFIX CACHING
        See https://docs.vllm.ai/en/latest/automatic_prefix_caching/apc.html

    GPU MEMORY UTILIZATION
        vLLM tries to pre-allocate as much memory as available for the KV cache at startup.

        It uses this formula: `KV cache size = total_gpu_memory * gpu_memory_utilization - peak_memory`

        'peak_memory' is an estimation of the peak memory usage of the model, determined at runtime by profiling.
        'gpu_memory_utilization' is this configuration parameter. It defaults to 0.9.
        'total_gpu_memory' is the total memory of the GPU

        In practice, the profiling method does not seem accurate. In some cases, vLLM does not allow itself
        to use enough memory for KV cache, and this may cause a crash at startup
        if 'KV cache len < model context len' (e.g. it breaks Mistral 7B on a 24 GB GPU).

        Here are some pointers:
        - https://github.com/vllm-project/vllm/pull/2031
        - https://github.com/vllm-project/vllm/issues/2248

        This parameter can be useful for debugging.

    MAX MODEL LEN
        Limiting the context length at load time can help reduce memory usage because vLLM pre-allocates memory
        for the KV cache based on the maximum context length.

        Typical use case: Mistral 7B (32k context) will not fit on a 24GB GPU when using vLLM because the KV cache
        for 32k tokens is too large to fit in memory. Reducing the context length to ~13k tokens allows it to fit.

    KV CACHE DTYPE
        Quantization of the KV cache significantly reduces memory usage, especially for long sequences.

        Typical use case: as said above, Mistral 7B (32k context) will not fit on a 24GB GPU when using vLLM OOTB
        because the KV cache for 32k tokens is too large to fit in memory. However, with "fp8_e5m2" quantization,
        it fits and the model can be used with full 32k context length.

        Set this to "fp8_e5m2" (8 bits float) to enable quantization.

    ENFORCE EAGER
        CUDA graph is a performance optimization which is enabled by default in vLLM.

        This optimization has been observed to sometimes cause problems based on issues reported on vLLM repo
        (higher memory usage, memory leaks, slower startup time, incompatible with some other vLLM features, ...).

        Disabling this optimization (enforce_eager=True) can help mitigate these issues.

    MAX NUM SEQS
        For models with a big context size and additional block tables for cross attention layers,
        the default setting max_num_seqs = 128 does not work.
        Lowering the value (e.g. 16) may allow the model to run on some GPUs.
    """

    def __init__(self, model_config: PretrainedConfig, model_settings: ModelSettings, with_image_input: bool, lora_path: Optional[str]):
        """
        Read the user-provided settings, then fill the unset ones with model-specific defaults.

        :param model_config: model configuration; read for its `torch_dtype` and its number of attention heads
        :param model_settings: settings mapping; "quantizationMode" is mandatory, every other key is optional
        :param with_image_input: whether image inputs are enabled; image-related settings are ignored otherwise
        :param lora_path: directory containing a LoRA adapter (with its "adapter_config.json"), or None
        """
        # GPU detection is skipped on macOS
        self.is_24gb_gpu = (not IS_MAC) and is_single_24gb_gpu()
        self.with_image_input = with_image_input
        self.model_config = model_config
        self.lora_path = lora_path

        self.with_quantization = (QuantizationMode[model_settings["quantizationMode"]] == QuantizationMode.Q_4BIT)
        self.enable_chunked_prefill = model_settings.get('enableChunkedPrefill')
        self.enable_prefix_caching = model_settings.get('enablePrefixCaching')
        self.gpu_memory_utilization = model_settings.get('gpuMemoryUtilization')
        self.max_model_len = model_settings.get('maxModelLen')
        self.kv_cache_dtype = model_settings.get('kvCacheDType')
        self.enforce_eager = model_settings.get('enforceEager')
        self.limit_images_per_prompt = model_settings.get('limitImagesPerPrompt') if self.with_image_input else None
        self.max_num_seqs = model_settings.get('maxNumSeqs')
        self.trust_remote_code = model_settings.get("trustRemoteCode")
        self.guided_decoding_backend = model_settings.get('guidedDecodingBackend', 'auto')
        self.torch_dtype = model_settings.get("dtype")
        self.tensor_parallel_size = model_settings.get("tensorParallelSize")
        self.pipeline_parallel_size = model_settings.get("pipelineParallelSize")
        self.enable_expert_parallelism = model_settings.get("enableExpertParallelism")
        self.config_format = model_settings.get("configFormat")
        self.tokenizer_mode = model_settings.get("tokenizerMode")
        self.load_format = model_settings.get("loadFormat")
        self.ignore_patterns = model_settings.get("ignorePatterns")

        self._apply_overrides()

    def build_params(self):
        """
        Build the dict of engine parameters from the resolved settings.

        Only settings that hold a value are emitted, so that the engine defaults apply to the others.
        "enable_lora" is always present.

        :return: dict mapping engine parameter names to their values
        """
        params = {"enable_lora": self.lora_path is not None}
        if self.lora_path is not None:
            max_lora_rank = self._max_lora_rank()
            if max_lora_rank is not None:
                params["max_lora_rank"] = max_lora_rank

        if self.with_quantization:
            params["quantization"] = "bitsandbytes"
        if self.enable_chunked_prefill is not None:
            params["enable_chunked_prefill"] = self.enable_chunked_prefill
        if self.enable_prefix_caching is not None:
            params["enable_prefix_caching"] = self.enable_prefix_caching
        if self.gpu_memory_utilization is not None:
            params["gpu_memory_utilization"] = self.gpu_memory_utilization
        if self.max_model_len is not None:
            params["max_model_len"] = self.max_model_len
        # Blank strings are treated as "not set"
        if self.kv_cache_dtype and self.kv_cache_dtype.strip():
            params["kv_cache_dtype"] = self.kv_cache_dtype
        if self.enforce_eager is not None:
            params["enforce_eager"] = self.enforce_eager
        if self.with_image_input and self.limit_images_per_prompt is not None:
            params["limit_mm_per_prompt"] = {"image": self.limit_images_per_prompt}
        if self.max_num_seqs is not None:
            params["max_num_seqs"] = self.max_num_seqs
        if self.trust_remote_code is not None:
            params["trust_remote_code"] = self.trust_remote_code
        if self.guided_decoding_backend and self.guided_decoding_backend.strip():
            params["guided_decoding_backend"] = self.guided_decoding_backend
        # The dtype is either a torch.dtype or a (non-blank) string
        if self.torch_dtype and (isinstance(self.torch_dtype, torch.dtype) or self.torch_dtype.strip()):
            params["dtype"] = self.torch_dtype
        if self.tensor_parallel_size is not None:
            logger.info("Tensor parallelism: %s", self.tensor_parallel_size)
            params["tensor_parallel_size"] = self.tensor_parallel_size

            # Multiprocessing is preferred over Ray as distributed executor backend:
            # - When the process tree is killed by the Java backend, Ray leaves Dashboard and RuntimeAgent processes alive for one minute before they die on their own
            #   Multiprocessing behaves better: all processes are killed on the spot
            # - Ray processes can be killed when engine V0 is started in background thread, if the thread pool executor is shutdown at the end, so multiprocessing seems more robust
            # More details in https://github.com/dataiku/dip/pull/38752
            if self.tensor_parallel_size > 1:
                params["distributed_executor_backend"] = "mp"
        if self.pipeline_parallel_size is not None:
            logger.info("Pipeline parallel: %s", self.pipeline_parallel_size)
            params["pipeline_parallel_size"] = self.pipeline_parallel_size

            # Bug with vllm: engine V1 + PP requires distributed_executor_backend to be explicitly set (either to "ray" or "mp")
            # Multiprocessing is preferred over Ray, for the same reasons as for tensor parallelism:
            # - When the process tree is killed by the Java backend, Ray leaves Dashboard and RuntimeAgent processes alive for one minute before they die on their own
            #   Multiprocessing behaves better: all processes are killed on the spot
            # - Ray processes can be killed when engine V0 is started in background thread, if the thread pool executor is shutdown at the end, so multiprocessing seems more robust
            # More details in https://github.com/dataiku/dip/pull/38752
            if self.pipeline_parallel_size > 1:
                params["distributed_executor_backend"] = "mp"
        if self.enable_expert_parallelism is not None:
            params["enable_expert_parallel"] = self.enable_expert_parallelism
        if self.config_format is not None:
            params["config_format"] = self.config_format
        if self.load_format is not None:
            params["load_format"] = self.load_format
        if self.tokenizer_mode is not None:
            params["tokenizer_mode"] = self.tokenizer_mode
        if self.ignore_patterns is not None:
            params["ignore_patterns"] = self.ignore_patterns

        return params

    def _apply_overrides(self):
        """
        Fill every setting left unset by the user with the value returned by its hook method.

        The order matters: some hooks read settings resolved by previous ones.
        """
        if self.max_model_len is None:
            self.max_model_len = self._max_model_len()
        if self.with_image_input and self.limit_images_per_prompt is None:
            self.limit_images_per_prompt = self._limit_images_per_prompt()
        if self.max_num_seqs is None:
            # For Llama 3.2 11B/90B Vision models, this override MUST run after max_model_len and limit_images_per_prompt
            self.max_num_seqs = self._max_num_seqs()
        if self.enforce_eager is None:
            self.enforce_eager = self._enforce_eager()
        if self.trust_remote_code is None:
            self.trust_remote_code = self._trust_remote_code()
        # A missing or blank dtype counts as unset
        if not (self.torch_dtype and (isinstance(self.torch_dtype, torch.dtype) or self.torch_dtype.strip())):
            self.torch_dtype = self._dtype()
        if self.tensor_parallel_size is None:
            self.tensor_parallel_size = self._find_best_tensor_parallel_size()
        if self.pipeline_parallel_size is None:
            # We set PP after TP because pipeline parallelism should be used after we've already maxed out efficient tensor parallelism but need to distribute the model further
            # cf. https://docs.vllm.ai/en/v0.9.0.1/configuration/optimization.html#pipeline-parallelism-pp
            self.pipeline_parallel_size = self._find_best_pipeline_parallel_size()
        if self.tokenizer_mode is None:
            self.tokenizer_mode = self._tokenizer_mode()
        if self.config_format is None:
            self.config_format = self._config_format()
        if self.load_format is None:
            self.load_format = self._load_format()
        if self.ignore_patterns is None:
            self.ignore_patterns = self._ignore_patterns()
        self._log_parallelism_strategies()

    def _dtype(self) -> Optional[Union[torch.dtype, str]]:
        """
        Default dtype: float16 when the model is configured for bfloat16 but the CUDA device does not support it.

        :return: torch.float16 in that case, None (no override) otherwise
        """
        if (
                torch.cuda.is_available()
                and hasattr(self.model_config, "torch_dtype")
                and self.model_config.torch_dtype == torch.bfloat16
                and not is_bfloat16_supported_with_cuda()
        ):
            # Fix compatibility with some old GPUs that don't support bfloat16 (e.g. V100, T4, ...)
            # Example: "mistralai/Mistral-7B-Instruct-v0.2" on T4 or V100 (CC < 8) does not work OOTB in vLLM
            logger.warning(
                "Model is configured to use bfloat16 but the GPU device does not support it. "
                "Using float16 instead. This may degrade the quality of the model."
            )
            return torch.float16
        return None

    def _limit_images_per_prompt(self) -> Optional[int]:
        """Default maximum number of images per prompt (only used when image input is enabled)."""
        return 2

    def _ignore_patterns(self) -> list[str]:
        """Default file patterns to ignore."""
        return [
            "original/**/*",  # avoid repeated downloading of llama's checkpoint
            "consolidated*.safetensors"  # filter out Mistral-format weights
        ]

    def _max_model_len(self) -> Optional[int]:
        """Default context length; None means no override."""
        return None

    def _trust_remote_code(self) -> Optional[bool]:
        """Default for trust_remote_code; None means no override."""
        return None

    def _config_format(self) -> Optional[str]:
        """Default config format; None means no override."""
        return None

    def _load_format(self) -> Optional[str]:
        """Default load format; None means no override."""
        return None

    def _tokenizer_mode(self) -> Optional[str]:
        """Default tokenizer mode; None means no override."""
        return None

    def _max_num_seqs(self) -> Optional[int]:
        """Default for max_num_seqs; None means no override."""
        return None

    def _enforce_eager(self) -> Optional[bool]:
        """Default for enforce_eager: True when quantization is enabled, None (no override) otherwise."""
        if self.with_quantization:
            logger.warning(textwrap.dedent(
                """
                Inflight Bitsandbytes quantization with tensor parallelism is currently unstable with CUDA graph in VLLM.
                DSS automatically setting enforce_eager=True.
                """
            ))
            return True
        return None

    def _find_best_tensor_parallel_size(self):
        """
        VLLM can leverage multiple GPUs by setting tensor parallelism > 1. Ideally, we would like to use as many GPUs as possible.
        However, VLLM does not allow to use an arbitrary number of GPUs and there are constraints on the "tensor parallel size".

        For instance, it requires "tensor parallel size" to be a divisor of the number of attention heads and in reality the actual constraints
        are implemented in a model-dependent manner.
        Example for llama2: https://github.com/vllm-project/vllm/blob/5265631d15d59735152c8b72b38d960110987f10/vllm/model_executor/models/llama.py#L105

        The number of GPUs used is equal to the product (tensor parallelism * pipeline parallelism). This product must not exceed the number of available GPUs

        This method tries to use as many GPUs as possible while respecting these constraints and should be good enough in most cases
        If it gets too complicated to maintain, we could:
        - Consider "just trying" all possible tensor parallel sizes and picking the greatest one that works
        - Delegate the choice of tensor parallel size to the user via a config parameter

        These limitations may evolve in the future: https://github.com/vllm-project/vllm/pull/5367

        :return: the tensor parallel size to use (always >= 1; always 1 on macOS)
        """
        if IS_MAC:
            return 1
        tensor_parallel_size = 1
        n_gpus = torch.cuda.device_count()
        # Leave room for the pipeline parallelism, when the user explicitly set it
        pipeline_parallel = 1
        if self.pipeline_parallel_size is not None and isinstance(self.pipeline_parallel_size, int) and self.pipeline_parallel_size > 0:
            pipeline_parallel = self.pipeline_parallel_size

        # The nb of attention heads is not always in the same field depending on the model architecture (this is hacky)
        if hasattr(self.model_config, "num_attention_heads"):
            nb_attention_heads = self.model_config.num_attention_heads
        elif hasattr(self.model_config, "n_head"):
            nb_attention_heads = self.model_config.n_head
        else:
            # If we don't know we'll use all gpus
            logger.warning(
                "Could not determine the number of attention heads. "
                "Setting tensor parallelism to the highest possible value."
            )
            return max(n_gpus // pipeline_parallel, 1)

        # Pick the greatest candidate that is compatible with the number of attention heads and fits on the GPUs
        for candidate_tp in reversed(range(1, n_gpus + 1)):
            if (
                    nb_attention_heads % candidate_tp == 0
                    or candidate_tp % nb_attention_heads == 0
            ) and (candidate_tp * pipeline_parallel <= n_gpus):
                tensor_parallel_size = candidate_tp
                break

        return tensor_parallel_size

    def _find_best_pipeline_parallel_size(self):
        """
        VLLM can leverage multiple GPUs by setting pipeline parallelism > 1.
        It is advised to use this parallelism after leveraging tensor parallelism
        cf. https://docs.vllm.ai/en/v0.9.0.1/configuration/optimization.html#pipeline-parallelism-pp

        This PP value doesn't have a constraint similar to TP.
        We still need to not exceed the number of GPU available so the following formula must be true : TP * PP <= n_GPU

        :return: the pipeline parallel size to use (always >= 1)
        """
        n_gpus = torch.cuda.device_count()
        tensor_parallel_size = 1
        if self.tensor_parallel_size is not None and isinstance(self.tensor_parallel_size, int) and self.tensor_parallel_size > 0:
            tensor_parallel_size = self.tensor_parallel_size
        return max(n_gpus // tensor_parallel_size, 1)

    def _log_parallelism_strategies(self):
        """Log the resolved TP/PP sizes, and warn when their product does not match the number of GPUs."""
        n_gpus = torch.cuda.device_count()
        assert isinstance(self.tensor_parallel_size, int)
        assert isinstance(self.pipeline_parallel_size, int)
        used_gpus = self.tensor_parallel_size * self.pipeline_parallel_size
        logger.info(f"""Parallelism strategies with {n_gpus} GPUs available:
    - Tensor Parallelism: {self.tensor_parallel_size}
    - Pipeline Parallelism: {self.pipeline_parallel_size}
The total number of GPU used should be {self.tensor_parallel_size} * {self.pipeline_parallel_size} = {used_gpus}""")

        if used_gpus < n_gpus:
            logger.warning(f"Using less GPUs than available: {used_gpus} GPUs used out of {n_gpus}")
        if used_gpus > n_gpus:
            logger.warning(f"Using more GPUs than available: {used_gpus} GPUs requested out of {n_gpus} available")

    def _max_lora_rank(self) -> Optional[int]:
        """
        Detect the value to use as 'max_lora_rank' from the "adapter_config.json" file of the LoRA adapter.

        This is best-effort: any error while reading the file is logged and None is returned.

        :return: the smallest allowed rank that is >= the adapter rank (512 if the adapter rank is greater than
                 all allowed values), or None when there is no adapter or its rank cannot be determined
        """
        if self.lora_path is None:
            return None
        try:
            adapter_config_path = os.path.join(self.lora_path, "adapter_config.json")
            if not os.path.isfile(adapter_config_path):
                return None
            # JSON files are UTF-8 encoded: do not depend on the locale encoding
            with open(adapter_config_path, "r", encoding="utf-8") as f:
                adapter_config = json.load(f)
            if "r" not in adapter_config:
                return None
            adapter_lora_rank = int(adapter_config["r"])
            # vLLM only accepts a fixed set of values as 'max_lora_rank'
            # See https://github.com/vllm-project/vllm/blob/v0.8.4/vllm/config.py#L2555
            possible_max_ranks = (8, 16, 32, 64, 128, 256, 320, 512)
            max_lora_rank = next((r for r in possible_max_ranks if r >= adapter_lora_rank), 512)
            logger.info(
                f"Detected LoRA rank: {adapter_lora_rank}, setting max_lora_rank to {max_lora_rank}")
            return max_lora_rank
        except Exception:
            logger.exception("Error while reading adapter config, couldn't detect LoRA rank")
            return None


class Mistral7BModelParamsLoader(ModelParamsLoader):
    """Model-specific defaults for Mistral 7B."""

    model_display_name = "Mistral 7B"

    def _max_model_len(self) -> Optional[int]:
        """
        Default context length.

        :return: 10000 when `is_24gb_gpu` is set and the model is not quantized, None (no override) otherwise
        """
        needs_reduction = self.is_24gb_gpu and not self.with_quantization
        if not needs_reduction:
            return None

        capped_len = 10000
        message = textwrap.dedent(
            f"""
            {self.model_display_name} (or alike) cannot run on a single 24GB GPU, due to memory requirements.
            DSS automatically reducing context length to {capped_len} to prevent out-of-memory errors.
            """
        )
        logger.warning(message)
        return capped_len


class Llama318BModelParamsLoader(ModelParamsLoader):
    """Model-specific defaults for Llama 3.1 8B."""

    model_display_name = "Llama 3.1 8B"

    def _max_model_len(self) -> Optional[int]:
        """
        Default context length.

        :return: when `is_24gb_gpu` is set, 64000 with quantization or 15000 without;
                 None (no override) otherwise
        """
        if not self.is_24gb_gpu:
            return None

        # The quantized model leaves more memory for the KV cache, hence the larger context
        capped_len = 64000 if self.with_quantization else 15000
        message = textwrap.dedent(
            f"""
            {self.model_display_name} (or alike) cannot run on a single 24GB GPU, due to memory requirements.
            DSS automatically reducing context length to {capped_len} to prevent out-of-memory errors.
            """
        )
        logger.warning(message)
        return capped_len


class Llama323BModelParamsLoader(ModelParamsLoader):
    """Model-specific defaults for Llama 3.2 3B."""

    model_display_name = "Llama 3.2 3B"

    def _max_model_len(self) -> Optional[int]:
        """
        Default context length.

        :return: 15000 when `is_24gb_gpu` is set and the model is not quantized, None (no override) otherwise
        """
        needs_reduction = self.is_24gb_gpu and not self.with_quantization
        if not needs_reduction:
            return None

        capped_len = 15000
        message = textwrap.dedent(
            f"""
            {self.model_display_name} (or alike) cannot run on a single 24GB GPU, due to memory requirements.
            DSS automatically reducing context length to {capped_len} to prevent out-of-memory errors.
            """
        )
        logger.warning(message)
        return capped_len


class Phi35VisionModelParamsLoader(ModelParamsLoader):
    """Model-specific defaults for Phi 3.5 Vision."""

    model_display_name = "Phi 3.5 Vision"

    def _max_model_len(self) -> Optional[int]:
        """
        Default context length.

        :return: 10000 when `is_24gb_gpu` is set (with or without quantization), None (no override) otherwise
        """
        if not self.is_24gb_gpu:
            return None

        capped_len = 10000
        message = textwrap.dedent(
            f"""
            {self.model_display_name} (or alike) cannot run on a single 24GB GPU, due to memory requirements.
            DSS automatically reducing context length to {capped_len} to prevent out-of-memory errors.
            """
        )
        logger.warning(message)
        return capped_len

    def _enforce_eager(self) -> Optional[bool]:
        """Default for enforce_eager: always True for this model (a warning is logged)."""
        eager = True
        message = textwrap.dedent(
            f"""
            {self.model_display_name} (or alike) requires extra settings in order to run properly.
            DSS automatically setting enforce_eager={eager}.
            """
        )
        logger.warning(message)
        return eager

    def _trust_remote_code(self) -> Optional[bool]:
        """Default for trust_remote_code: always True for this model."""
        # Phi 3.5 Vision Instruct requires trust_remote_code to work
        return True


class Llava16Mistral7BVisionModelParamsLoader(ModelParamsLoader):
    """Model-specific defaults for Llava 1.6 Mistral 7B Vision."""

    model_display_name = "Llava 1.6 Mistral 7B Vision"

    def _max_model_len(self) -> Optional[int]:
        """
        Default context length.

        :return: 10000 when `is_24gb_gpu` is set and the model is not quantized, None (no override) otherwise
        """
        needs_reduction = self.is_24gb_gpu and not self.with_quantization
        if not needs_reduction:
            return None

        capped_len = 10000
        message = textwrap.dedent(
            f"""
            {self.model_display_name} (or alike) cannot run on a single 24GB GPU, due to memory requirements.
            DSS automatically reducing context length to {capped_len} to prevent out-of-memory errors.
            """
        )
        logger.warning(message)
        return capped_len


class MistralNemo12BModelParamsLoader(ModelParamsLoader):
    """Model-specific defaults for Mistral Nemo 12B."""

    model_display_name = "Mistral Nemo 12B"

    def _max_model_len(self) -> Optional[int]:
        """
        Default context length.

        :return: 32000 when `is_24gb_gpu` is set or when the model is not quantized; None (no override) otherwise,
                 i.e. only for a quantized model when `is_24gb_gpu` is not set
        """
        # Reduction of the context length of Mistral Nemo 12B to 32000 tokens. This is a very conservative
        # value that should work by default with inexpensive setups like:
        # - 2x24GB GPUs
        # - 1x24GB GPU with 4bit quantization.
        # NOTE(review): the condition below uses `or`, so the reduction also applies to any unquantized model
        # whatever the GPU setup — confirm this is intended.
        keeps_full_context = self.with_quantization and not self.is_24gb_gpu
        if keeps_full_context:
            return None

        capped_len = 32000
        message = textwrap.dedent(
            f"""
            {self.model_display_name} (or alike) has significant memory requirements due to its context length.
            DSS automatically reducing context length to {capped_len} to prevent out-of-memory errors.
            """
        )
        logger.warning(message)
        return capped_len


class MistralFormatParamsLoader(ModelParamsLoader):
    """Defaults for models loaded in "mistral" format (weights, config and tokenizer)."""

    def _load_format(self) -> Optional[str]:
        """Default load format: always "mistral"."""
        return "mistral"

    def _config_format(self) -> Optional[str]:
        """Default config format: always "mistral"."""
        return "mistral"

    def _tokenizer_mode(self) -> Optional[str]:
        """Default tokenizer mode: always "mistral"."""
        return "mistral"

    def _ignore_patterns(self) -> list[str]:
        """Default file patterns to ignore: the Hugging Face format weights."""
        # This corresponds to hugging face weight that won't be used since the model is loaded in mistral format
        hugging_face_weights = "model*.safetensors"
        return [hugging_face_weights]


class _GenericLlama32VisionModelParamsLoader(ABC, ModelParamsLoader):
    """Defaults shared by the Llama 3.2 Vision models; concrete subclasses set `model_display_name`."""

    model_display_name = "Llama 3.2 Vision"

    def _max_model_len(self) -> Optional[int]:
        """Default context length: always 32000 for these models (a warning is logged)."""
        # Unconditional reduction of context length for Llama 3.2 11B/90B Vision Instruct 128k -> 32k.
        # This is a very conservative value that should work by default with inexpensive setups like:
        # - 2x24GB GPUs for the 11B model
        # - 1x24GB GPU for the 11B model with 4bit quantization
        # - 4x24GB GPU for the 90B model with 4bit quantization
        #
        # Careful not to set a value too low, otherwise it may not be possible to set max_num_seqs >= 1
        # (also depending on limit_images_per_prompt). See the comments in _max_num_seqs.
        capped_len = 32000
        message = textwrap.dedent(
            f"""
            {self.model_display_name} (or alike) has significant memory requirements due to its context length.
            DSS automatically reducing context length to {capped_len} to prevent out-of-memory errors.
            """
        )
        logger.warning(message)
        return capped_len

    def _enforce_eager(self) -> Optional[bool]:
        """Default for enforce_eager: always True for these models (a warning is logged)."""
        eager = True
        message = textwrap.dedent(
            f"""
            {self.model_display_name} (or alike) requires extra settings in order to run properly.
            DSS automatically setting enforce_eager={eager}.
            """
        )
        logger.warning(message)
        return eager

    def _max_num_seqs(self) -> Optional[int]:
        """
        Default for max_num_seqs, derived from the context length and the number of images per prompt.

        Reads `max_model_len` and `limit_images_per_prompt`, which must therefore be resolved beforehand.

        :return: max_model_len // (limit_images_per_prompt * 6404), with a minimum of 1
        """
        # For Llama 3.2 11B/90B Vision Instruct, the value computed here satisfies
        # limit_images_per_prompt * 6404 * max_num_seqs <= max_model_len,
        # which is needed for the profiler to be able to run (and for the model to run without issues afterwards)
        # see https://github.com/vllm-project/vllm/blob/v0.8.4/vllm/multimodal/profiling.py#L217
        #
        # The 6404 number corresponds to ((model_config.vision_config.image_size/14)^2 +1) * model_config.vision_config.max_num_tiles,
        # with model_config.vision_config.image_size = 560 and model_config.vision_config.max_num_tiles = 4
        # see https://github.com/vllm-project/vllm/blob/v0.8.4/vllm/model_executor/models/mllama.py#L315
        tokens_per_image = 6404

        # default context length for these models is 131072
        context_len = self.max_model_len if self.max_model_len is not None else 131072

        images_per_prompt = 1
        if self.with_image_input and self.limit_images_per_prompt is not None:
            images_per_prompt = max(1, self.limit_images_per_prompt)

        computed_max_num_seqs = max(1, context_len // (images_per_prompt * tokens_per_image))

        message = textwrap.dedent(
            f"""
            {self.model_display_name} (or alike) requires extra settings in order to run properly.
            DSS automatically setting max_num_seqs={computed_max_num_seqs}.
            """
        )
        logger.warning(message)
        return computed_max_num_seqs


class Llama3211BVisionModelParamsLoader(_GenericLlama32VisionModelParamsLoader):
    """Llama 3.2 11B Vision: uses the generic Llama 3.2 Vision defaults unchanged."""

    # Name interpolated in the warnings logged by the inherited hooks
    model_display_name = "Llama 3.2 11B Vision"


class Llama3290BVisionModelParamsLoader(_GenericLlama32VisionModelParamsLoader):
    """Llama 3.2 90B Vision: generic Llama 3.2 Vision defaults, with a single image per prompt."""

    model_display_name = "Llama 3.2 90B Vision"

    def _limit_images_per_prompt(self) -> Optional[int]:
        """Default maximum number of images per prompt: always 1 for this model (a warning is logged)."""
        max_images = 1
        message = textwrap.dedent(
            f"""
            {self.model_display_name} (or alike) requires extra settings in order to run properly.
            DSS automatically setting limit_mm_per_prompt={{ "image": {max_images} }}.
            """
        )
        logger.warning(message)
        return max_images
