import base64
import json
import time
from datetime import datetime
from typing import Any, Dict, Generator, Optional, Union

from common.backend.models.base import (
    ImageGenerationSettings,
    LLMCompletionSettings,
    LLMStep,
    LLMStepDesc,
    RetrievalSummaryJson,
)
from common.backend.utils.auth_utils import get_auth_user
from common.backend.utils.context_utils import TIME_FORMAT, add_llm_step_trace
from common.backend.utils.llm_utils import get_image_generation, handle_response_trace
from common.backend.utils.picture_utils import detect_image_format
from common.backend.utils.user_profile_utils import (
    get_nbr_images_to_generate,
    get_num_images_user_can_generate,
    set_nbr_images_to_generate,
)
from common.llm_assist.logging import logger
from common.solutions.chains.no_retrieval_chain import NoRetrievalChain
from dataiku.llm.tracing import SpanBuilder
from dataikuapi.dss.llm import DSSLLMImageGenerationQuery, DSSLLMImageGenerationResponse


class ImageGenerationChain:
    """Run an image generation request for a user query and stream the outcome.

    The chain optionally enforces the per-user image quota stored through
    ``user_profile_sql_manager``, executes the image generation query built by
    ``get_image_generation``, records what happened on the optional tracing span
    and formats the final payload through ``NoRetrievalChain.get_as_json``.
    """

    # User-facing messages. These are emitted at runtime, keep them verbatim.
    _QUOTA_REACHED_MESSAGE = "Oops! You’ve reached your image generation limit for this week. Check your next reset time in your Settings."
    _GENERATION_FAILED_MESSAGE = "An error occurred while generating the image. If you changed some image generation settings, please try again with the default settings."
    _UNEXPECTED_ERROR_MESSAGE = "An error occurred while generating the image."

    # Image generation settings that are copied onto the trace span when set.
    _TRACED_SETTING_KEYS = ("image_height", "image_width", "images_to_generate", "image_quality")

    def __init__(
        self,
        completion_settings: LLMCompletionSettings,
        image_generation_settings: ImageGenerationSettings,
        user_query: str,
        user_profile_sql_manager: Any,
        user_profile: Optional[Dict[str, Any]] = None,
        trace: Optional[SpanBuilder] = None,
        include_user_profile_in_prompt: Optional[bool] = False,
    ):
        """Store the collaborators and settings used by ``run_image_generation_query``.

        Args:
            completion_settings: Settings forwarded to ``NoRetrievalChain`` to build the final payload.
            image_generation_settings: Settings passed to ``get_image_generation``; also read for
                tracing and for ``referred_file_path``.
            user_query: Prompt sent to the image generation model.
            user_profile_sql_manager: Object exposing ``get_user_profile(user)``.
            user_profile: Current user profile; when falsy the quota check is skipped.
            trace: Optional tracing span. When ``None`` no tracing is recorded.
            include_user_profile_in_prompt: Forwarded to ``NoRetrievalChain``.
        """
        self.user_query = user_query
        self.user_profile = user_profile
        self.completion_settings = completion_settings
        self.image_generation_settings = image_generation_settings
        self.user_profile_sql_manager = user_profile_sql_manager
        self.trace = trace
        self.include_user_profile_in_prompt = include_user_profile_in_prompt

    def __handle_image_generation_trace(self, image_generation_settings: ImageGenerationSettings) -> None:
        """Copy the truthy image settings onto the trace span attributes (no-op without a span)."""
        if self.trace is None:
            return
        for key in self._TRACED_SETTING_KEYS:
            if value := image_generation_settings.get(key):
                self.trace.attributes[key] = value

    def _record_trace_response(self, response: Dict[str, Any]) -> None:
        """Attach ``response`` to the span outputs and publish the span (no-op without a span)."""
        if self.trace is None:
            return
        self.trace.outputs["response"] = response
        add_llm_step_trace(self.trace.to_dict())

    def _final_answer(self, answer_context: Dict[str, Any]) -> RetrievalSummaryJson:
        """Format ``answer_context`` into the final payload via ``NoRetrievalChain.get_as_json``."""
        return NoRetrievalChain(
            self.completion_settings, include_user_profile_in_prompt=self.include_user_profile_in_prompt
        ).get_as_json(answer_context)

    @staticmethod
    def _extract_model_error_detail(error_message: Any) -> Optional[str]:
        """Return the model's own ``message`` found in ``error_message``, or ``None``.

        Two layouts are understood:
        * ``"<prefix>response: {json}"`` where the JSON object holds ``{"error": {"message": ...}}``;
        * a bare JSON object holding ``{"message": ...}``.

        Anything else (non-string input, invalid JSON, unexpected JSON shape, missing or
        empty message) yields ``None`` instead of raising, so the caller can fall back to
        its generic message.
        """
        if not isinstance(error_message, str) or not error_message:
            return None
        _, marker_found, tail = error_message.partition("response: ")
        payload = tail if marker_found else error_message
        try:
            parsed = json.loads(payload)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None
        if marker_found:
            parsed = parsed.get("error", {}) if isinstance(parsed, dict) else {}
        if not isinstance(parsed, dict):
            return None
        detail = parsed.get("message")
        if detail is None or detail == "":
            return None
        return str(detail)

    def run_image_generation_query(self, max_images_to_generate: int) -> Generator[Union[LLMStepDesc, RetrievalSummaryJson], Any, None]:
        """Generate images for ``self.user_query`` and yield progress steps plus the final payload.

        Args:
            max_images_to_generate: When greater than 0 (and a user profile is set) the user's
                remaining image allowance is checked before calling the model.

        Yields:
            * Quota exhausted: ``{"step": LLMStep.STREAMING_END}`` then the final payload whose
              answer is the quota message; the model is not called.
            * Otherwise: ``{"step": LLMStep.GENERATING_IMAGE}``, then after the model call
              ``{"step": LLMStep.STREAMING_END}`` and the final payload (images on success,
              an error answer when the model reports a failure).
            * If an exception is raised while generating, only the final payload with a generic
              error answer is yielded after the ``GENERATING_IMAGE`` step.
        """
        if self.trace is not None:
            self.trace.begin(int(time.time() * 1000))
            self.trace.inputs["query"] = self.user_query

        # Created up-front so the exception handler below can always write to it,
        # even when the failure happens before any answer was assembled.
        answer_context: Dict[str, Any] = {}

        if max_images_to_generate > 0 and self.user_profile:
            stored_user_profile = self.user_profile_sql_manager.get_user_profile(get_auth_user())
            if stored_user_profile:
                self.user_profile = stored_user_profile
            else:
                logger.debug("New user profile", log_conv_id=True)
                # If the user profile is not found, add a flag to mark the user profile as new
                # to insert a new user profile in the database
                self.user_profile["new_user_profile"] = True
            num_images = get_num_images_user_can_generate(self.user_profile)
            if num_images == 0:
                answer_context["answer"] = self._QUOTA_REACHED_MESSAGE
                yield {"step": LLMStep.STREAMING_END}
                yield self._final_answer(answer_context)
                return
            elif num_images < get_nbr_images_to_generate(self.user_profile):  # type: ignore
                # Fewer images left than requested: lower the request to what remains.
                set_nbr_images_to_generate(self.user_profile, num_images)  # type: ignore
        logger.debug("Running image generation query", log_conv_id=True)

        yield {"step": LLMStep.GENERATING_IMAGE}
        try:
            generation: DSSLLMImageGenerationQuery = get_image_generation(self.image_generation_settings)
            # NOTE(review): the start time is round-tripped through TIME_FORMAT, so the elapsed
            # time logged below is only as precise as that format.
            start_time: str = datetime.now().strftime(TIME_FORMAT)
            self.__handle_image_generation_trace(self.image_generation_settings)
            generation.with_prompt(self.user_query)
            resp: DSSLLMImageGenerationResponse = generation.execute()
            handle_response_trace(resp)
            trace_response: Dict[str, Any] = {"status": "ok" if resp.success else "error"}
            if resp.success:
                logger.debug("Image generation was successful", log_conv_id=True)
                images = []
                traced_images = []
                for image in resp._raw["images"]:
                    image_data = image["data"]
                    image_format = detect_image_format(base64.b64decode(image_data)) or "png"
                    logger.debug(f"Image type: {image_format}", log_conv_id=True)
                    image_b64 = f"data:image/{image_format};base64," + image_data
                    images.append(
                        {
                            "file_data": image_b64,
                            "file_format": image_format,
                            "referred_file_path": self.image_generation_settings.get("referred_file_path", ""),
                        }
                    )
                    # The trace only gets a placeholder, never the inline image bytes.
                    traced_images.append({"file_data": "INLINE_IMAGE", "file_format": image_format})
                trace_response["images"] = traced_images
                answer_context["images"] = images
                answer_context["user_profile"] = self.user_profile
            else:
                raw_response = resp._raw
                logger.error(f"Image generation failed: {raw_response}", log_conv_id=True)
                trace_response["error_response"] = raw_response
                message = self._GENERATION_FAILED_MESSAGE
                # Append the model's own explanation when it can be extracted; an unparsable
                # error message must not turn this handled failure into an exception.
                model_detail = self._extract_model_error_detail(raw_response.get("errorMessage"))
                if model_detail is not None:
                    message += " Here is the model's message: " + model_detail
                answer_context["answer"] = message
            logger.debug(
                f"Time ===> taken by run_image_generation_query: {round((datetime.now() - datetime.strptime(start_time, TIME_FORMAT)).total_seconds(), 2)} secs",
                log_conv_id=True
            )
            self._record_trace_response(trace_response)
            yield {"step": LLMStep.STREAMING_END}
            yield self._final_answer(answer_context)
        except Exception as e:
            # Top-level boundary of the generator: log, trace and still give the user an answer.
            logger.exception(f"Error in image generation: {e}", log_conv_id=True)
            self._record_trace_response({"status": "error", "error_response": f"Error in image generation: {e}"})
            answer_context["answer"] = self._UNEXPECTED_ERROR_MESSAGE
            yield self._final_answer(answer_context)
