import json
import re
from typing import Any, Dict, List, Optional, Union

import dataiku
from common.backend.constants import (
    ALTERNATIVE_LLM_KEYS,
    DEFAULT_MAX_LLM_TOKENS,
    DEFAULT_TEMPERATURE,
    LOWEST_TEMPERATURE,
    MAIN_LLM_KEY,
)
from common.backend.models.base import (
    ImageGenerationSettings,
    LLMCompletionSettings,
    LlmHistory,
    MediaSummary,
    UploadChainTypes,
)
from common.backend.utils.config_utils import resolve_webapp_param
from common.backend.utils.context_utils import add_llm_step_trace
from common.backend.utils.dataiku_api import dataiku_api
from common.backend.utils.file_utils import file_path_text_parts, file_path_to_image_parts
from common.backend.utils.json_utils import mask_keys_in_json
from common.backend.utils.picture_utils import get_image_bytes
from common.backend.utils.user_profile_utils import (
    get_image_height,
    get_image_quality,
    get_image_width,
    get_nbr_images_to_generate,
)
from common.llm_assist.logging import logger
from dataikuapi.dss.llm import (
    DSSLLMCompletionQuery,
    DSSLLMCompletionQueryMultipartMessage,
    DSSLLMCompletionResponse,
    DSSLLMImageGenerationQuery,
    DSSLLMImageGenerationResponse,
    DSSLLMStreamedCompletionFooter,
)
from dataikuapi.utils import DataikuException

# Webapp configuration captured once, when this module is imported.
webapp_config: Dict[str, Any] = dataiku_api.webapp_config

# NOTE(review): not referenced in the visible part of this file - presumably a
# webapp config key consumed by the fallback LLM logic elsewhere; confirm.
FALLBACK_LLM_KEY = "ENABLE_FALLBACK_LLM"

def compare_versions(version1: Optional[str], version2: str) -> int:
    """Compare two dot-separated numeric version strings.

    Missing trailing components count as zero, so ``"13.0"`` and ``"13.0.0"``
    are considered identical while ``"13.0"`` is still older than ``"13.0.1"``.

    Args:
        version1: Version to test. If it cannot be parsed (non-numeric parts,
            or not a string at all) it is treated as a dev build, i.e. newer
            than anything.
        version2: Reference version; must be well-formed.

    Returns:
        ``-1`` if ``version1`` is older than ``version2``, ``1`` if it is
        newer (or malformed), ``0`` if both are equivalent.

    Raises:
        ValueError: If ``version2`` contains a non-numeric part.
    """

    def parse_version(version):
        try:
            return [int(part) for part in version.split(".")]
        except (ValueError, AttributeError):
            return None  # Malformed version

    parts1 = parse_version(version1)
    if parts1 is None:
        # If version1 is malformed, treat it as a dev version
        return 1

    # Parse version2 (always well-formed)
    parts2 = [int(part) for part in version2.split(".")]

    # Pad the shorter version with zeros so that 13.0 == 13.0.0 while
    # 13.0 < 13.0.1 still holds.
    width = max(len(parts1), len(parts2))
    parts1 = parts1 + [0] * (width - len(parts1))
    parts2 = parts2 + [0] * (width - len(parts2))

    # Lists compare component by component, left to right.
    if parts1 < parts2:
        return -1  # version1 is older
    if parts1 > parts2:
        return 1  # version1 is newer
    return 0  # Versions are equivalent


# Image generation is supported with the following providers:
#   - OpenAI (DALL-E 3)
#   - Azure OpenAI (DALL-E 3)
#   - Google Vertex (Imagen 1 and Imagen 2)
#   - Stability AI (Stable Image Core, Stable Diffusion 3.0, Stable Diffusion 3.0 Turbo)
#   - Bedrock Titan Image Generator
#   - Bedrock Stable Diffusion XL 1


def get_llm_capabilities(get_fallback: bool = False) -> Dict[str, bool]:
    """Report which optional features the configured LLM can be used with.

    The LLM id is read from the webapp config (``llm_id``) or, when
    ``get_fallback`` is true, obtained from ``get_fallback_id()``. Capabilities
    are derived from the connection type and model name found in the
    colon-separated LLM id, the DSS version, and the ``force_*`` /
    image-generation entries of the webapp config.

    Args:
        get_fallback: Inspect the fallback LLM instead of the main one.

    Returns:
        A dict with the boolean keys ``multi_modal``, ``streaming`` and
        ``image_generation``. All three are ``False`` when no fallback LLM is
        available or when the DSS version is older than 12.5.0.
    """
    settings: Dict[str, str] = dataiku_api.webapp_config
    if not get_fallback:
        llm_id = settings["llm_id"]
    else:
        from common.llm_assist.fallback import get_fallback_id

        llm_id = get_fallback_id()
        if llm_id is None:
            return {"multi_modal": False, "streaming": False, "image_generation": False}

    forced_streaming = bool(settings.get("force_streaming_mode", False))
    forced_multi_modal = bool(settings.get("force_multi_modal_mode", False))
    # Forced modes are the starting point; detection below can only add to them.
    streaming = forced_streaming
    multi_modal = forced_multi_modal
    image_generation = False

    dss_version = dataiku.api_client().get_instance_info().raw.get("dssVersion", "0.0.0")
    if dss_version == "0.0.0":
        logger.warn("Could not retrieve DSS version")

    def dss_at_least(minimum: str) -> bool:
        return compare_versions(dss_version, minimum) >= 0

    # The first and third colon-separated segments of the LLM id carry the
    # connection type and the model name.
    segments = llm_id.split(":")
    if len(segments) >= 3:
        connexion, _, model = segments[:3]

        if not forced_streaming:
            if connexion == "openai":
                streaming = model.startswith("gpt")
            elif connexion == "bedrock":
                streaming = any(prefix in model for prefix in ["amazon.titan-", "anthropic.claude-"])
            elif connexion == "azureopenai":
                streaming = dss_at_least("12.6.2")

        if not forced_multi_modal:
            if connexion == "openai":
                multi_modal = model.startswith("gpt-4")
            elif connexion == "vertex":
                multi_modal = model == "gemini-pro-vision"
            elif connexion == "bedrock":
                multi_modal = model.startswith("anthropic.claude-3") and dss_at_least("13.0.2")

        # Image generation needs DSS >= 13.0.0, the feature switch, and a
        # non-empty image generation LLM id.
        if dss_at_least("13.0.0") and settings.get("enable_image_generation", False):
            image_generation = settings.get("image_generation_llm_id", "") != ""

    if not dss_at_least("12.5.0"):
        return {"multi_modal": False, "streaming": False, "image_generation": False}
    return {
        "multi_modal": multi_modal,
        "streaming": streaming,
        "image_generation": image_generation,
    }


def handle_prompt_media_explanation(system_prompt: str, has_media: bool) -> str:
    """Extend the system prompt with instructions about AI-generated media.

    When the conversation contains generated media, the chat history may hold
    ``generated_media_by_ai`` JSON metadata in assistant messages. The appended
    text tells the model how to treat that metadata and gives a short example.

    Args:
        system_prompt: The base system prompt.
        has_media: Whether the explanation must be appended.

    Returns:
        ``system_prompt`` unchanged when ``has_media`` is false, otherwise the
        prompt followed by the media explanation and its example.
    """
    if not has_media:
        return system_prompt

    logger.debug("Appending media explanation to system prompt")
    # The separator after the fourth message used to be an Arabic comma
    # (U+060C); it is a plain ASCII comma like the other separators now.
    example = """
        -- Start of example --
            [{"role": "user", "content":"hello"},
            {"role": "assistant", "content":"How can I help you today?"},
            {"role": "user", "content":"generate blue circle"},
            {"role": "assistant", "content":'{"generated_media_by_ai": {"images": [{"file_path": "userwx_17117_04RMXS.png", "file_format": "png", "referred_file_path":""}]}}'},
            {"role": "user", "content": "Thank you"},]
            Expected Answer: You are welcome
        -- End of example --
        """
    system_prompt = f"""{system_prompt}. 
            During the current conversation, some media such as images could have been generated by an image generation agent.
            In that case Chat history could include metadata about the media generated by the image generation agent in the form of a generated_media_by_ai JSON object in the assistant message.
            When responding, do not include generated_media_by_ai object in your answers to the user.
            You do not have access to the media itself. If the user asks about the media, you can inform them that you don't have access to it.
            You can ignore the media in the current conversation and continue with your tasks.
            # Example :"""
    return f"{system_prompt}\n{example}"


def append_summaries_to_completion_msg(
    media_summaries: List[MediaSummary], msg: DSSLLMCompletionQueryMultipartMessage
) -> DSSLLMCompletionQueryMultipartMessage:
    """Attach uploaded media to a multipart completion message.

    Summaries whose ``chain_type`` is an image (or a document handled as an
    image) are added through ``file_path_to_image_parts``; short documents are
    added through ``file_path_text_parts``, preceded once by an introductory
    sentence. Summaries with any other chain type are ignored. ``msg.add()`` is
    called once all summaries have been processed.

    This is best effort: any error is logged and not re-raised.

    Args:
        media_summaries: Summaries of the media uploaded by the user.
        msg: The multipart message to enrich.

    Returns:
        The multipart message as it stands at the end of processing. This
        function previously returned ``None`` despite its annotation.
    """
    try:
        logger.debug("Appending media summaries to completion message")
        image_chain_types = (
            UploadChainTypes.IMAGE.value,
            UploadChainTypes.DOCUMENT_AS_IMAGE.value,
        )
        text_intro_added = False
        for summary in media_summaries:
            chain_type: Union[str, None] = summary.get("chain_type")

            if chain_type in image_chain_types:
                msg = file_path_to_image_parts(summary, msg)
            elif chain_type == UploadChainTypes.SHORT_DOCUMENT.value:
                # The introduction is only needed before the first text part.
                if not text_intro_added:
                    msg.with_text("""The user uploaded file(s) with extracted text along with their query. Here is the extracted text:
                    """)
                    text_intro_added = True
                msg = file_path_text_parts(summary, msg)
        msg.add()
    except Exception as e:
        logger.exception(f"Error when creating completion query from media summaries: {e}")
    return msg


def get_llm_completion(completion_settings: LLMCompletionSettings) -> DSSLLMCompletionQuery:
    """Create a completion query for the LLM described by ``completion_settings``.

    Args:
        completion_settings: Must contain ``llm_id``; may contain
            ``max_tokens`` and ``temperature``.

    Returns:
        A new completion query whose ``maxOutputTokens`` setting is the
        requested ``max_tokens`` (or ``DEFAULT_MAX_LLM_TOKENS`` when missing or
        falsy) and whose ``temperature`` setting is set only when a
        non-``None`` temperature was provided.

    Raises:
        Exception: If ``completion_settings`` has no ``llm_id`` key.
    """
    try:
        llm_id = completion_settings["llm_id"]
    except KeyError as err:
        raise Exception("No LLM ID passed to get the completion") from err
    # Replacement of llm._llm_handle.new_completion() to prevent the use of private fields
    completion: DSSLLMCompletionQuery = dataiku_api.default_project.get_llm(llm_id).new_completion()
    completion.settings["maxOutputTokens"] = completion_settings.get("max_tokens") or DEFAULT_MAX_LLM_TOKENS
    # Compare against None explicitly: a temperature of 0 is a valid value and
    # must still be forwarded to the LLM instead of being silently dropped.
    temperature = completion_settings.get("temperature")
    if temperature is not None:
        completion.settings["temperature"] = temperature
        if temperature > LOWEST_TEMPERATURE:
            logger.warn(f"The LLM '{llm_id}' temperature '{temperature}' is '>= {LOWEST_TEMPERATURE if temperature < 1.0 else 1.0 }'.")
    logger.info(f"completion settings for LLM '{llm_id}': {completion.settings})")
    return completion


def extract_response_trace(response: Union[
        DSSLLMCompletionResponse,
        DSSLLMStreamedCompletionFooter,
        DSSLLMImageGenerationResponse,
    ]
) -> Dict[str, Any]:
    """Return the trace carried by an LLM response.

    The trace is looked up first as a ``trace`` attribute, then as a
    ``"trace"`` key when the response is a dict.

    Args:
        response: The LLM response to inspect.

    Returns:
        The trace found on the response, or an empty dict when there is none
        or when reading it fails (the error is logged).
    """
    try:
        if hasattr(response, "trace"):
            return response.trace
        if isinstance(response, dict) and "trace" in response:
            return response["trace"]
        return {}
    except Exception as exc:
        logger.exception(f"Error when handling response trace: {exc}")
        return {}


def handle_response_trace(
    response: Union[
        DSSLLMCompletionResponse,
        DSSLLMStreamedCompletionFooter,
        DSSLLMImageGenerationResponse,
    ]
) -> None:
    """Record the trace of an LLM response, if it has one.

    The trace is obtained with ``extract_response_trace`` and, when non-empty,
    passed to ``add_llm_step_trace``.

    Args:
        response: The LLM response whose trace should be recorded.
    """
    trace = extract_response_trace(response)
    if not trace:
        return
    add_llm_step_trace(trace)

def parse_error_messages(error_as_str: Union[str, DataikuException]) -> str:
    """Extract a human-readable message from an LLM error.

    The error text is scanned for flat ``{...}`` groups (no nested braces).
    Each group is tried as JSON in order of appearance; the first one that
    parses to an object with a ``"message"`` key wins. Groups that are not
    valid JSON are skipped instead of raising, so this helper is safe to call
    from error-handling code.

    Args:
        error_as_str: The error, as a string or as an exception.

    Returns:
        ``" Error message: <message>"`` when a message is found, otherwise an
        empty string.
    """
    for candidate in re.findall(r"\{[^{}]*\}", str(error_as_str)):
        try:
            payload = json.loads(candidate)
        except json.JSONDecodeError:
            # Not JSON (e.g. a "{placeholder}" in the text): try the next group.
            continue
        if isinstance(payload, dict) and "message" in payload:
            return f" Error message: {payload['message']}"
    logger.debug(f"Error message from LLM couldn't be parsed: {error_as_str}")
    return ""


def get_main_llm_completion_settings()-> LLMCompletionSettings:
    """Build the completion settings of the main LLM from the webapp config.

    ``max_tokens`` and ``temperature`` are resolved through
    ``resolve_webapp_param``, with advanced mode driven by the
    ``show_advanced_settings`` config entry.

    Returns:
        The completion settings (LLM id, max tokens, temperature).

    Raises:
        ValueError: If the webapp config has no ``llm_id``.
    """
    main_llm_id = dataiku_api.webapp_config.get("llm_id")
    if not main_llm_id:
        raise ValueError("A Dataiku LLM ID must be provided")

    # `or False` turns an explicit None in the config into False.
    advanced_enabled = dataiku_api.webapp_config.get("show_advanced_settings", False) or False
    resolved_max_tokens: int = resolve_webapp_param(
        "max_llm_tokens",
        default_value=DEFAULT_MAX_LLM_TOKENS,
        advanced_mode_enabled=advanced_enabled,
    )
    resolved_temperature: Optional[float] = resolve_webapp_param(
        "llm_temperature",
        default_value=DEFAULT_TEMPERATURE,
        advanced_mode_enabled=advanced_enabled,
    )
    return LLMCompletionSettings(
        llm_id=main_llm_id,
        max_tokens=resolved_max_tokens,
        temperature=resolved_temperature,
    )

def resolve_llm_id_from_key(llm_id_key: str)-> str:
    """Resolve a webapp config key to the LLM id that should be used.

    Args:
        llm_id_key: Config key naming an LLM.

    Returns:
        The LLM id stored under ``llm_id_key`` when that key is one of
        ``ALTERNATIVE_LLM_KEYS`` and has a non-empty value; the main LLM id
        (stored under ``MAIN_LLM_KEY``) in every other case.
    """
    main_llm_id: str = dataiku_api.webapp_config.get(MAIN_LLM_KEY) # type: ignore
    if llm_id_key not in ALTERNATIVE_LLM_KEYS:
        return main_llm_id
    alternative_llm_id: str = dataiku_api.webapp_config.get(llm_id_key) # type: ignore
    # An unset or empty alternative falls back to the main LLM.
    return alternative_llm_id or main_llm_id

def get_alternative_llm_completion_settings(llm_id_key: str) -> LLMCompletionSettings:
    """Build the completion settings of an alternative (task-specific) LLM.

    When the webapp config holds an LLM id under ``llm_id_key``, the
    temperature and max tokens are resolved from the parameters dedicated to
    that key (``title_llm_id`` and ``json_decision_llm_id`` have their own);
    any other key gets ``DEFAULT_TEMPERATURE`` and ``DEFAULT_MAX_LLM_TOKENS``.
    When no LLM id is stored under ``llm_id_key``, the main LLM settings are
    returned instead.

    Only settings are produced here; no completion query is created (the
    previous implementation built one and discarded it).

    Args:
        llm_id_key: Webapp config key of the alternative LLM.

    Returns:
        The completion settings to use.

    Raises:
        Exception: If neither ``llm_id_key`` nor ``llm_id`` is set in the
            webapp config.
    """
    # Read the live config everywhere, like the other helpers of this module,
    # rather than mixing it with the module-level snapshot.
    config = dataiku_api.webapp_config
    llm_id = config.get(llm_id_key)
    if not llm_id:
        main_llm_id = config.get("llm_id")
        if not main_llm_id:
            raise Exception("No LLM ID found in webapp config")
        logger.info(f"As the LLM '{llm_id_key}' is not set the 'Main LLM' ('{main_llm_id}') will be used")
        return get_main_llm_completion_settings()

    logger.info(f"Using alternative LLM ID: {llm_id}")
    if llm_id_key == "title_llm_id":
        use_advanced_llm_parameters = config.get("use_advanced_title_llm_settings", False) or False
        temperature = resolve_webapp_param("title_llm_temperature", default_value=DEFAULT_TEMPERATURE, advanced_mode_enabled=use_advanced_llm_parameters)
        max_tokens = resolve_webapp_param("max_title_llm_tokens", default_value=DEFAULT_MAX_LLM_TOKENS, advanced_mode_enabled=use_advanced_llm_parameters)
    elif llm_id_key == "json_decision_llm_id":
        use_advanced_llm_parameters = config.get("use_advanced_decision_llm_settings", False) or False
        temperature = resolve_webapp_param("decision_llm_temperature", default_value=DEFAULT_TEMPERATURE, advanced_mode_enabled=use_advanced_llm_parameters)
        max_tokens = resolve_webapp_param("max_decision_llm_tokens", default_value=DEFAULT_MAX_LLM_TOKENS, advanced_mode_enabled=use_advanced_llm_parameters)
        # Warn when the decision LLM temperature is above the lowest setting.
        if isinstance(temperature, (float, int)) and (temperature > LOWEST_TEMPERATURE):
            logger.warn(f"The 'Decisions LLM' temperature is not set to the minimum value ('{LOWEST_TEMPERATURE}'): current value: '{temperature}'. It must be as close to '{LOWEST_TEMPERATURE}' as possible.")
        if max_tokens == DEFAULT_MAX_LLM_TOKENS:
            logger.warn(f"The 'Decisions LLM' max output tokens is set to the minimum allowed value ('{DEFAULT_MAX_LLM_TOKENS}'): It is recommended to set a high value for accurate results.")
    else:
        logger.info(f"The  alternative LLM ID: '{llm_id}' will be used with the default parameters (temperature={DEFAULT_TEMPERATURE}, max_tokens={DEFAULT_MAX_LLM_TOKENS}).")
        temperature = DEFAULT_TEMPERATURE
        max_tokens = DEFAULT_MAX_LLM_TOKENS
    return LLMCompletionSettings(llm_id=llm_id, max_tokens=max_tokens, temperature=temperature)


def get_image_generation(image_generation_settings: ImageGenerationSettings) -> DSSLLMImageGenerationQuery:
    """Create an image generation query configured from the given settings.

    Each optional setting (referred image, height, width, number of images,
    quality) is applied to the query only when it has a truthy value.

    Args:
        image_generation_settings: Must contain ``model_id``; may contain
            ``referred_image``, ``image_height``, ``image_width``,
            ``images_to_generate`` and ``image_quality``.

    Returns:
        The configured image generation query.

    Raises:
        Exception: If ``image_generation_settings`` has no ``model_id`` key.
    """
    try:
        model_id = image_generation_settings["model_id"]
    except KeyError:
        raise Exception("No Image Generation model ID passed to get the generation")

    query: DSSLLMImageGenerationQuery = dataiku_api.default_project.get_llm(model_id).new_images_generation()

    source_image = image_generation_settings.get("referred_image")
    if source_image:
        query.with_original_image(image=source_image)

    requested_height = image_generation_settings.get("image_height")
    if requested_height:
        logger.debug(f"User set image height: {requested_height}", log_conv_id=True)
        query.height = requested_height

    requested_width = image_generation_settings.get("image_width")
    if requested_width:
        logger.debug(f"User set image width: {requested_width}", log_conv_id=True)
        query.width = requested_width

    requested_count = image_generation_settings.get("images_to_generate")
    if requested_count:
        logger.debug(f"User set number of images to generate: {requested_count}", log_conv_id=True)
        query.images_to_generate = requested_count

    requested_quality = image_generation_settings.get("image_quality")
    if requested_quality:
        logger.debug(f"User set quality: {requested_quality}", log_conv_id=True)
        query.quality = requested_quality

    # The referred image is masked in the log output.
    logger.info(f"Image generation settings for LLM '{model_id}': {mask_keys_in_json(image_generation_settings, {'referred_image'})}", log_conv_id=True)
    return query


def get_image_generation_settings(model_id: str, referred_image_path: Optional[str], user_profile: Optional[Dict[str, Any]]) -> ImageGenerationSettings:
    """Assemble image generation settings from a reference image and a user profile.

    Args:
        model_id: Id of the image generation model.
        referred_image_path: Optional path of an image to start from. Its path
            and bytes are stored only when ``get_image_bytes`` returns data.
        user_profile: Optional user profile. Its image preferences are read
            only when ``user_profile["media"]["image"]`` is present and truthy.

    Returns:
        The settings, always holding ``model_id`` plus whichever optional
        values were available.
    """
    settings = ImageGenerationSettings(model_id=model_id)

    referred_bytes = get_image_bytes(referred_image_path) if referred_image_path else None
    if referred_bytes:
        settings["referred_file_path"] = referred_image_path
        settings["referred_image"] = referred_bytes

    media_preferences = user_profile.get("media") if user_profile else None
    if not (media_preferences and media_preferences.get("image")):
        return settings

    # Each preference is copied over only when the profile yields a truthy value.
    preferred_height = get_image_height(user_profile)
    if preferred_height:
        settings["image_height"] = preferred_height
    preferred_width = get_image_width(user_profile)
    if preferred_width:
        settings["image_width"] = preferred_width
    preferred_count = get_nbr_images_to_generate(user_profile)
    if preferred_count:
        settings["images_to_generate"] = preferred_count
    preferred_quality = get_image_quality(user_profile)
    if preferred_quality:
        settings["image_quality"] = preferred_quality
    return settings


def add_history_to_completion(
    completion: DSSLLMCompletionQuery,
    chat_history: List[LlmHistory],
) -> DSSLLMCompletionQuery:
    """Replay the chat history onto a completion query.

    For every history item, a truthy ``input`` is added as a ``user`` message
    and a truthy ``output`` as an ``assistant`` message, in that order.

    Args:
        completion: The completion query to extend.
        chat_history: Past exchanges, oldest first.

    Returns:
        The same completion query, for chaining.
    """
    for turn in chat_history:
        user_text = turn.get("input")
        if user_text:
            completion.with_message(message=user_text, role="user")
        assistant_text = turn.get("output")
        if assistant_text:
            completion.with_message(message=assistant_text, role="assistant")
    return completion


def get_llm_friendly_name(llm_id: str, project_key: str) -> str:
    """Look up the display name of an LLM among those listed by a project.

    Args:
        llm_id: Id of the LLM to look for.
        project_key: Key of the project whose LLMs are listed.

    Returns:
        The ``friendlyName`` of the first listed LLM whose ``id`` matches
        (``llm_id`` itself when that entry has no friendly name), or an empty
        string when no listed LLM matches.
    """
    available_llms: List[Dict[str, str]] = dataiku.api_client().get_project(project_key).list_llms()
    matching = next((entry for entry in available_llms if entry.get('id') == llm_id), None)
    if matching is None:
        return ""
    return matching.get('friendlyName', llm_id)