import logging
from typing import AsyncIterator, Union

from vllm.entrypoints.openai.serving_models import OpenAIServingModels

from dataiku.huggingface.vllm_backend.utils import get_vllm_config
from dataiku.base.async_link import FatalException
from dataiku.base.utils import package_is_at_least
from dataiku.huggingface.chat_template import ChatTemplateRenderer
from dataiku.huggingface.pipeline_vllm import ModelPipelineVLLM
from dataiku.huggingface.types import ProcessSinglePromptCommand
from dataiku.huggingface.types import ProcessSinglePromptResponseText
from dataiku.huggingface.types import ToolSettings
from dataiku.huggingface.types import ChatTemplateSettings
from dataiku.huggingface.types import ModelSettings
from dataiku.huggingface.env_collector import extract_weight_quantization_method
from dataiku.huggingface.vllm_backend.model_params import get_model_params_loader_for_completion

logger = logging.getLogger(__name__)


class ModelPipelineTextGenerationVLLM(ModelPipelineVLLM[ProcessSinglePromptCommand, ProcessSinglePromptResponseText]):
    """Text-generation pipeline running on the vLLM engine.

    The model is loaded with the vLLM "generate" runner and prompts are served
    through vLLM's OpenAI chat completion server (``OpenAIServingChat``).
    """

    def __init__(
            self,
            hf_handling_mode,
            model_name_or_path,
            base_model_name_or_path,  # will be None if not an adapter model
            model_settings: ModelSettings,
            use_dss_model_cache,
            with_image_input
    ):
        """Record tracking data, build the vLLM load parameters and load the model.

        :param hf_handling_mode: kept on the instance and later handed to the chat template renderer
        :param model_name_or_path: forwarded to the parent pipeline
        :param base_model_name_or_path: forwarded to the parent pipeline; None if not an adapter model
        :param model_settings: forwarded to the parent pipeline and to the model params loader
        :param use_dss_model_cache: forwarded to the model params loader
        :param with_image_input: forwarded to the model params loader
        """
        super().__init__(model_name_or_path, base_model_name_or_path, model_settings)
        self.hf_handling_mode = hf_handling_mode

        self.used_engine = "vllm"
        self.model_tracking_data["task"] = "text-generation"
        self.model_tracking_data["adapter"] = "lora" if self.lora_path is not None else "none"
        self.model_tracking_data["weights_quantization"] = extract_weight_quantization_method(self.transformers_model_config)

        model_params_loader = get_model_params_loader_for_completion(self.model_to_load, self.transformers_model_config, model_settings, with_image_input, use_dss_model_cache, self.lora_path)
        load_params = {
            "disable_log_stats": False,  # Enable logging of performance metrics
            "model": self.model_to_load,
            "runner": "generate",
            # Loader-provided params come last, so they win over the defaults above on key collision
            **model_params_loader.build_params()
        }
        self.load_model(load_params)

    async def build_openai_server(self, openai_serving_models: OpenAIServingModels):
        """Configure chat template, tools, guided generation and reasoning, then create the vLLM chat server.

        Sets ``self.chat_template_renderer``, ``self.tools_supported``,
        ``self.enable_json_constraints_in_prompt``, ``self.reasoning_supported`` and
        ``self.openai_server``.

        :param openai_serving_models: passed as ``models`` to ``OpenAIServingChat``
        """
        vllm_config = await get_vllm_config(self.engine_client)

        # -------------------------------
        # CHAT TEMPLATE
        # having tool template here does not mean the model necessarily supports tool call (mistral 7b v0.1 does not for example)
        # "or {}" also covers a key that is present with a null value
        chat_template_settings: ChatTemplateSettings = self.model_settings.get("chatTemplateSettings") or {}
        # The custom template is only used when the override flag is set; a missing/empty template means no override
        chat_template_override: str = ""
        if chat_template_settings.get('overrideChatTemplate', False):
            chat_template_override = chat_template_settings.get('chatTemplate') or ""

        logger.info("Loading tokenizer")
        tokenizer = await self.engine_client.get_tokenizer()
        logger.info("Tokenizer loaded")

        self.chat_template_renderer = ChatTemplateRenderer(
            tokenizer=tokenizer,
            hf_handling_mode=self.hf_handling_mode,
            supports_message_parts=True,
            chat_template_override=chat_template_override,
            vllm_config=vllm_config
        )
        # -------------------------------

        # -------------------------------
        # TOOLS
        tool_settings: ToolSettings = self.model_settings.get("toolSettings") or {}
        enable_tools: bool = tool_settings.get('enableTools', False)
        # Make sure empty strings are treated as None
        tool_parser: Union[str, None] = tool_settings.get('toolParser') or None

        if enable_tools:
            if self.chat_template_renderer.supports_tool:
                logger.info("Model chat template supports tools")
            else:
                logger.warning("Model chat template does not seem to support tools (will only work with GPT-OSS models)")
            if tool_parser:
                logger.info(f"Using tool parser {tool_parser}")
            else:
                logger.warning("No tool parser specified (will only work with GPT-OSS models)")
            # Tools stay enabled even when the checks above only warn
            self.tools_supported = True
        else:
            logger.info("Tools disabled for this model")
            self.tools_supported = False
        # -------------------------------

        # -------------------------------
        # GUIDED GENERATION
        # Defaults to True when the setting is absent or null, but an explicit False must be honored
        # (the former "<value> or True" expression could never evaluate to False)
        json_constraints_setting = self.model_settings.get('enableJsonConstraintsInPrompt')
        self.enable_json_constraints_in_prompt = True if json_constraints_setting is None else bool(json_constraints_setting)
        logger.info(f"Json constraints injected in prompt: {self.enable_json_constraints_in_prompt}")
        # -------------------------------

        # -------------------------------
        # REASONING
        reasoning_settings = self.model_settings.get("reasoningSettings") or {}
        parse_reasoning: bool = reasoning_settings.get('parseReasoning', False)
        reasoning_parser: str = reasoning_settings.get('reasoningParser') or ""

        self.reasoning_supported = False
        if not parse_reasoning:
            logger.info("Reasoning disabled for this model")
        elif not reasoning_parser:
            logger.warning("Reasoning enabled without a reasoning parser: it won't work if not a GPT-OSS model")
            self.reasoning_supported = True
        else:
            logger.info("Model supports reasoning")
            logger.info(f"Using reasoning parser {reasoning_parser}")
            self.reasoning_supported = True
        # -------------------------------

        # Create the OpenAI Chat API server of vLLM
        from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

        extra_args = {}
        import vllm
        if package_is_at_least(vllm, "0.11.0"):
            # Required since https://github.com/vllm-project/vllm/pull/25794 (0.11.0) to be able to use custom chat templates
            extra_args['trust_request_chat_template'] = True
            logger.info("vLLM >= 0.11.0 detected, enabling trust_request_chat_template for custom chat templates")

        if self.tools_supported:
            extra_args['enable_auto_tools'] = True
            if tool_parser:
                extra_args['tool_parser'] = tool_parser

        if self.reasoning_supported and reasoning_parser:
            extra_args['reasoning_parser'] = reasoning_parser

        # model_config is only passed to the constructor for vLLM versions older than 0.11.1
        if not package_is_at_least(vllm, "0.11.1"):
            extra_args['model_config'] = vllm_config.model_config

        self.openai_server = OpenAIServingChat(
            engine_client=self.engine_client,
            models=openai_serving_models,
            response_role="assistant",
            request_logger=None,
            chat_template=self.chat_template_renderer.get_chat_template(),
            chat_template_content_format="auto",
            **extra_args,
        )

    async def run_test_query(self):
        """Run a dummy completion request to check that the engine produces output.

        :raises FatalException: if no result is returned or the last result reports zero completion tokens
        """
        from dataiku.huggingface.vllm_backend.oai_mappers import generate_dummy_dss_completion_request

        prompt = "Explain in simple terms what Generative AI is and why prompts matter for it"
        logger.info(
            "Testing completion engine with basic prompt: {input_text}".format(input_text=prompt)
        )
        request = generate_dummy_dss_completion_request(prompt)

        # Drain the response stream; only the last yielded item is inspected
        result = None
        async for result in self.run_single_async(request=request):
            pass

        # Models like OpenAI OSS-GPT won't necessarily return a non-reasoning text response due to the small max token limit
        # => It is more reliable to check presence of completion tokens in general than presence of text output
        if not result or (result.get("usage") or {}).get("completionTokens", 0) == 0:
            raise FatalException(
                "Something went wrong at initialization. Completion engine did not return any result for basic prompt"
            )
        # The text may legitimately be absent (see above), so do not fail a successful test on a missing key
        logger.info("Completion engine test executed successfully: {result}".format(result=result.get("text")))

    async def run_query(
        self, request: ProcessSinglePromptCommand
    ) -> AsyncIterator[ProcessSinglePromptResponseText]:
        """Convert a DSS prompt command to an OpenAI chat request, run it and yield the mapped responses.

        :param request: the DSS command to process; its "id" is used for logging
        :return: an async iterator over the items produced by ``oai_to_dss_completion_response``
        """
        from dataiku.huggingface.vllm_backend.oai_mappers import (
            dss_to_oai_completion_request,
            oai_to_dss_completion_response,
        )
        request_id = request["id"]

        logger.info("Start prompt {request_id}".format(request_id=request_id))
        # The server is created by build_openai_server
        assert self.openai_server

        oai_request = dss_to_oai_completion_request(
            request,
            self.chat_template_renderer,
            self.enable_json_constraints_in_prompt,
            self.tools_supported,
            self.reasoning_supported
        )
        response = await self.openai_server.create_chat_completion(oai_request)
        async for resp_or_chunk in oai_to_dss_completion_response(response):
            yield resp_or_chunk
        logger.info("Done prompt {request_id}".format(request_id=request_id))