import os
import logging
import sys
import asyncio
import torch
import transformers

from typing import AsyncIterator

from dataiku.base.async_link import FatalException
from dataiku.base.utils import package_is_exactly
from dataiku.huggingface.chat_template import ChatTemplateRenderer
from dataiku.huggingface.pipeline import ModelPipeline, QuantizationMode
from dataiku.huggingface.types import ProcessSinglePromptCommand
from dataiku.huggingface.types import ProcessSinglePromptResponseText
from dataiku.huggingface.types import ToolSettings
from dataiku.huggingface.types import ChatTemplateSettings
from dataiku.huggingface.types import ModelSettings
from dataiku.huggingface.env_collector import extract_model_architecture, extract_weight_quantization_method
from dataiku.huggingface.vllm_backend.model_params import get_model_params_loader

# Module-level logger used by the vLLM pipeline below.
logger = logging.getLogger(__name__)

# True when running on macOS; supports_model skips the CUDA availability / compute-capability checks in that case.
IS_MAC = sys.platform == "darwin"


class ModelPipelineTextGenerationVLLM(ModelPipeline[ProcessSinglePromptCommand, ProcessSinglePromptResponseText]):
    """Text-generation pipeline backed by a vLLM ``AsyncLLMEngine``.

    The constructor loads the model config and builds the engine. ``initialize_model`` must then be
    awaited: it sets up the chat template, tool support, vLLM's OpenAI-compatible chat server and the
    stat logger, and runs a smoke-test prompt. Prompts are served through ``run_single_async``.
    """

    def __init__(
            self,
            hf_handling_mode,
            model_name_or_path,
            base_model_name_or_path,  # will be None if not an adapter model
            model_settings: ModelSettings,
            with_image_input
    ):
        """Load the model config and create the vLLM engine.

        :param hf_handling_mode: DSS handling mode, forwarded to the chat template renderer.
        :param model_name_or_path: model to serve; for a LoRA adapter, the adapter's local path or HF id.
        :param base_model_name_or_path: base model of a LoRA adapter, or None if not an adapter model.
        :param model_settings: model settings (chat template, tools, JSON constraints, ...).
        :param with_image_input: forwarded to the model params loader.
        """
        super().__init__()
        self.log_stat_task = None
        self.model_settings = model_settings
        self.hf_handling_mode = hf_handling_mode
        self.model_name_or_path = model_name_or_path

        try:
            # Import only to detect whether vLLM's compiled MoE kernels load on this OS.
            import vllm._moe_C
        except Exception:
            logger.exception("Some VLLM features like MoE models may not work on this OS, make sure to use Almalinux 9.")

        if base_model_name_or_path is not None:
            # LoRa adapter model: the engine loads the base model, the adapter is registered later
            model_to_load = base_model_name_or_path
            self.lora_path = self._resolve_lora_path(model_name_or_path)
            logger.info(f"Loading LoRA adapter from local path: {self.lora_path}")
        else:
            # Not an adapter model
            model_to_load = model_name_or_path
            self.lora_path = None

        logger.info("Loading model config")
        transformers_model_config = transformers.PretrainedConfig.from_pretrained(model_to_load)
        logger.info("Model config loaded")

        self.model_tracking_data["model_architecture"] = extract_model_architecture(transformers_model_config)
        self.model_tracking_data["used_engine"] = "vllm"
        self.model_tracking_data["task"] = "text-generation"
        self.model_tracking_data["adapter"] = "lora" if base_model_name_or_path is not None else "none"
        self.model_tracking_data["weights_quantization"] = extract_weight_quantization_method(transformers_model_config)

        model_params_loader = get_model_params_loader(transformers_model_config, model_settings, with_image_input, self.lora_path)

        load_params = {
            "disable_log_stats": False,  # Enable logging of performance metrics
            "model": model_to_load,
            "ignore_patterns": [
                "original/**/*",  # avoid repeated downloading of llama's checkpoint
                "consolidated*.safetensors"  # filter out Mistral-format weights
            ],
            **model_params_loader.build_params()
        }
        logger.info(f"Loading model with args {load_params}")

        # remove this hack after next vllm bump (not necessary anymore after https://github.com/vllm-project/vllm/pull/23298)
        from vllm.engine.arg_utils import _warn_or_fallback as _original_warn_or_fallback

        def _patched_warn_or_fallback(feature_name: str):
            # Returning False tells vLLM not to fall back to engine V0 when running in a background thread.
            if feature_name == "Engine in background thread":
                logger.info("Overriding vllm _is_v1_supported_oracle to allow running engine V1 in background thread")
                return False
            return _original_warn_or_fallback(feature_name)
        import vllm.engine.arg_utils
        vllm.engine.arg_utils._warn_or_fallback = _patched_warn_or_fallback

        from vllm import AsyncEngineArgs, AsyncLLMEngine
        self.engine_client = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**load_params))
        logger.info("Model loaded")

    async def initialize_model(self):
        """Finish setting up the pipeline once the engine exists.

        Resets the multimodal cache, loads the tokenizer and chat template, resolves tool-calling support,
        reads the JSON-constraint setting, creates vLLM's OpenAI chat server (registering the LoRA adapter as
        model 'model' when one is used), starts a periodic stat logger for the V1 engine, and runs a test prompt.

        :raises FatalException: if the test prompt yields no result or no completion tokens.
        """
        await self.engine_client.reset_mm_cache()  # inspired from https://github.com/vllm-project/vllm/blob/5fbbfe9a4c13094ad72ed3d6b4ef208a7ddc0fd7/vllm/entrypoints/openai/api_server.py#L192
        vllm_model_config = await self.engine_client.get_model_config()

        # -------------------------------
        # CHAT TEMPLATE
        # having tool template here does not mean the model necessarily supports tool call (mistral 7b v0.1 does not for example)
        chat_template_settings: ChatTemplateSettings = self.model_settings.get("chatTemplateSettings", {})
        # The custom template is only used when the override flag is set; an empty string means "no override".
        if chat_template_settings.get('overrideChatTemplate', False):
            chat_template_override: str = chat_template_settings.get('chatTemplate') or ""
        else:
            chat_template_override = ""

        logger.info("Loading tokenizer")
        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=vllm_model_config.trust_remote_code or False)
        logger.info("Tokenizer loaded")

        self.chat_template_renderer = ChatTemplateRenderer(
            tokenizer=tokenizer,
            hf_handling_mode=self.hf_handling_mode,
            supports_message_parts=True,
            chat_template_override=chat_template_override,
            vllm_model_config=vllm_model_config
        )
        # -------------------------------

        # -------------------------------
        # TOOLS
        tool_settings: ToolSettings = self.model_settings.get("toolSettings", {})
        enable_tools: bool = tool_settings.get('enableTools', False)
        tool_parser: str = tool_settings.get('toolParser') or ""

        if enable_tools:
            if not tool_parser:
                logger.warning("Tools disabled: please specify a tool parser")
                self.tools_supported = False
            elif self.chat_template_renderer.supports_tool:
                logger.info("Model supports tools")
                logger.info(f"Using tool parser {tool_parser}")
                self.tools_supported = True
            else:
                logger.warning("Model does not support tools")
                self.tools_supported = False
        else:
            logger.info("Tools disabled for this model")
            self.tools_supported = False
        # -------------------------------

        # -------------------------------
        # GUIDED GENERATION
        # Default to True only when the setting is absent/None. The former `... or True` forced True
        # even when the user explicitly set the option to False.
        json_constraints_setting = self.model_settings.get('enableJsonConstraintsInPrompt')
        self.enable_json_constraints_in_prompt = True if json_constraints_setting is None else bool(json_constraints_setting)
        logger.info(f"Json constraints injected in prompt: {self.enable_json_constraints_in_prompt}")
        # -------------------------------

        # -------------------------------
        # OPENAI SERVER
        from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
        from vllm.entrypoints.openai.serving_models import OpenAIServingModels, LoRAModulePath, BaseModelPath
        from dataiku.huggingface.vllm_backend.oai_mappers import generate_dummy_dss_request

        if self.lora_path:
            # 'base_model' is directly the base model, but it's not queried
            base_model_path = BaseModelPath(name="base_model", model_path="fake_path")
        else:
            # 'model' is directly the ID of the queried model
            base_model_path = BaseModelPath(name="model", model_path="fake_path")

        openai_serving_models = OpenAIServingModels(
            self.engine_client,
            vllm_model_config,
            base_model_paths=[base_model_path],
            lora_modules=(
                # Add a model ID 'model' that will be queried when LoRA is enabled
                [LoRAModulePath(name="model", path=self.lora_path)]
                if self.lora_path
                else []
            )
        )

        if self.lora_path:
            await openai_serving_models.init_static_loras()

        # Create the OpenAI Chat API server of vLLM
        self.openai_server = OpenAIServingChat(
            engine_client=self.engine_client,
            models=openai_serving_models,
            model_config=vllm_model_config,
            response_role="assistant",
            request_logger=None,
            chat_template=self.chat_template_renderer.get_chat_template(),
            chat_template_content_format="auto",
            **({
                'enable_auto_tools': True,
                'tool_parser': tool_parser,
            } if self.tools_supported else {})
        )
        # -------------------------------
        # START THE STAT LOGGER (V1 engine doesn't log anything by default)
        from vllm.v1.engine.async_llm import AsyncLLM as V1AsyncLLMEngine

        if isinstance(self.engine_client, V1AsyncLLMEngine):
            logger.info("VLLM V1 is being used, starting stat logger")

            async def _force_log():
                # Periodically ask the engine to log its stats; runs until the task is cancelled.
                from vllm import envs
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await self.engine_client.do_log_stats()

            self.log_stat_task = asyncio.create_task(_force_log())
        else:
            logger.info("VLLM V0 is being used")

        # -------------------------------
        # TEST QUERY
        prompt = "Explain in simple terms what Generative AI is and why prompts matter for it"
        logger.info(
            "Testing engine with basic prompt: {input_text}".format(input_text=prompt)
        )
        request = generate_dummy_dss_request(prompt)

        # Only the last yielded item is kept
        result = None
        async for result in self.run_single_async(request=request):
            pass

        # Models like OpenAI OSS-GPT won't necessarily return a non-reasoning text response due to the small max token limit
        # => It is more reliable to check presence of completion tokens in general than presence of text output
        if not result or (result.get("usage") or {}).get("completionTokens", 0) == 0:
            raise FatalException(
                "Something went wrong at initialization. Engine did not return any result for basic prompt"
            )
        # Use .get: per the comment above, the result may legitimately carry no "text" entry.
        logger.info("Test prompt executed successfully: {result}".format(result=result.get("text")))
        # -------------------------------

    @staticmethod
    def supports_model(hf_handling_mode, model_name_or_path, base_model_name_or_path, model_settings, expected_vllm_version):
        """Return whether this model and environment can be served with vLLM.

        Returns False for RS-LoRA adapters, handling modes not starting with "TEXT_GENERATION_", a missing
        vLLM install, 8-bit quantization, missing CUDA devices or devices with compute capability < 7
        (CUDA checks are skipped on macOS), and architectures absent from vLLM's model registry. Unexpected
        errors during the quantization/hardware/architecture checks are logged and treated as "not supported".

        :raises ValueError: if the installed vLLM version is not exactly ``expected_vllm_version``.
        """
        logger.info("Checking if VLLM is supported")
        try:
            from peft import LoraConfig
            adapter_config = LoraConfig.from_pretrained(model_name_or_path)
            logger.info(f"Adapter config with rslora {adapter_config.use_rslora}")
            if adapter_config.use_rslora:
                # Bug-fix in VLLM - SC-219662
                logger.info("RS-LoRa is not supported in VLLM")
                return False
        except ValueError:
            logger.info("Model is not an adapter model")

        if not hf_handling_mode.startswith("TEXT_GENERATION_"):
            logger.info("Handling mode is not TEXT_GENERATION, vLLM not supported")
            return False

        try:
            from vllm.model_executor.models import ModelRegistry
            import vllm
            vllm_version = vllm.__version__
            logger.info("VLLM version: " + vllm_version)
        except ImportError:
            logger.info("VLLM is not installed")
            return False

        if not package_is_exactly(vllm, expected_vllm_version):
            raise ValueError(f"Installed version of 'vllm' (version={vllm_version}) is incompatible, "
                             f"please use version {expected_vllm_version}")

        try:
            quantization_mode = QuantizationMode[model_settings["quantizationMode"]]
            if quantization_mode == QuantizationMode.Q_8BIT:
                logger.info("Quantization mode {mode} not supported by VLLM".format(mode=quantization_mode))
                return False

            transformers_model_config = transformers.PretrainedConfig.from_pretrained(
                base_model_name_or_path if base_model_name_or_path is not None else model_name_or_path
            )

            if not IS_MAC:
                if not torch.cuda.is_available():
                    # vllm does not support CPU inference yet: https://github.com/vllm-project/vllm/pull/1028
                    logger.info("CUDA is not available, vLLM not supported")
                    return False

                if torch.cuda.device_count() == 0:
                    logger.info("No CUDA device found, vLLM not supported")
                    return False

                # Check compute capability
                for i in range(torch.cuda.device_count()):
                    capability_level = torch.cuda.get_device_capability(i)
                    device_name = torch.cuda.get_device_name(i)
                    if capability_level[0] < 7:
                        logger.info(
                            f"CUDA device {i} ({device_name}) has compute capability {capability_level}: VLLM is not supported"
                        )
                        return False
                    else:
                        logger.info(
                            f"CUDA device {i} ({device_name}) has compute capability {capability_level}: VLLM is supported"
                        )

            supported_architectures = ModelRegistry.get_supported_archs()
            architecture = transformers_model_config.architectures[0]
            if architecture not in supported_architectures:
                logger.info(
                    "Model architecture {architecture} not supported by VLLM".format(
                        architecture=architecture
                    )
                )
                return False

        except Exception:
            logger.exception(
                "Error while checking if VLLM is supported, assuming it is not supported"
            )
            return False

        logger.info("VLLM is supported")
        return True

    @staticmethod
    def _resolve_lora_path(lora_path):
        """
        Can be removed when the minimum vLLM requirement is >=0.5.3

        Prior to vLLM 0.5.3, LoRA adapters were only supported if downloaded to the local machine (not HF model IDs).
        This method replicates the behaviour of vLLM 0.5.3, downloading the LoRA adapter from HF if needed
        https://github.com/vllm-project/vllm/pull/6234/files#diff-7c04dc096fc35387b6759bdc036f747fd9d8cf21bb7b9f2f69b2d57492b59ba1R114-R154
        """
        from huggingface_hub import snapshot_download

        if os.path.isabs(lora_path):
            return lora_path

        if lora_path.startswith('~'):
            return os.path.expanduser(lora_path)

        if os.path.exists(lora_path):
            return os.path.abspath(lora_path)

        # If the path doesn't exist locally, assume it's a Hugging Face repo.
        logger.info("Downloading LoRA adapter from HuggingFace, model id: {hf_id}".format(hf_id=lora_path))
        return snapshot_download(repo_id=lora_path)

    async def run_single_async(
            self, request: ProcessSinglePromptCommand
    ) -> AsyncIterator[ProcessSinglePromptResponseText]:
        """Run one DSS prompt through vLLM's OpenAI chat server, yielding DSS responses/chunks.

        :raises FatalException: if the engine is dead, or if the request fails while the engine
            reports it has errored and is no longer running.
        """
        from vllm.engine.async_llm_engine import AsyncEngineDeadError  # engine V0
        from vllm.v1.engine.exceptions import EngineDeadError  # engine V1
        from dataiku.huggingface.vllm_backend.oai_mappers import (
            dss_to_oai_request,
            oai_to_dss_response,
        )
        request_id = request["id"]
        try:
            logger.info("Start prompt {request_id}".format(request_id=request_id))
            assert self.openai_server

            oai_request = dss_to_oai_request(
                request,
                self.chat_template_renderer,
                self.enable_json_constraints_in_prompt,
                self.tools_supported,
            )
            response = await self.openai_server.create_chat_completion(oai_request)
            async for resp_or_chunk in oai_to_dss_response(response):
                yield resp_or_chunk
            logger.info("Done prompt {request_id}".format(request_id=request_id))
        except (EngineDeadError, AsyncEngineDeadError) as err:
            raise FatalException("Fatal exception: {0}".format(str(err))) from err
        except Exception as err:
            # inspired from https://github.com/vllm-project/vllm/blob/5fbbfe9a4c13094ad72ed3d6b4ef208a7ddc0fd7/vllm/entrypoints/launcher.py#L97-L107
            # in some cases, the engine may die while handling a request but the fatal error is not properly propagated
            # so when a query fails, we check on the engine
            if self.engine_client.errored and not self.engine_client.is_running:
                raise FatalException("Engine client failed") from err
            raise