import base64
import logging
import json
import torch

from enum import Enum
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

from compel import Compel, ReturnedEmbeddingsType
from diffusers import DiffusionPipeline

from dataiku.base.async_link import FatalException, call_with_fatal_exception
from dataiku.huggingface.pipeline_batching import ModelPipelineBatching
from dataiku.huggingface.torch_utils import best_supported_dtype
from dataiku.huggingface.types import DeviceStrategy
from dataiku.huggingface.types import ModelSettings
from dataiku.huggingface.types import ProcessSingleImageGenerationCommand
from dataiku.huggingface.types import ProcessSingleImageGenerationResponse

# Module-level logger, named after this module, shared by all pipeline classes below.
logger = logging.getLogger(__name__)

class _DeviceStrategy(str, Enum):
    NONE = 'NONE',
    MODEL_CPU_OFFLOAD = "MODEL_CPU_OFFLOAD"
    SEQUENTIAL_CPU_OFFLOAD = "SEQUENTIAL_CPU_OFFLOAD"

class ModelPipelineImageGenerationDiffusion(ModelPipelineBatching[ProcessSingleImageGenerationCommand, ProcessSingleImageGenerationResponse]):
    """Batching image-generation pipeline built on a diffusers ``DiffusionPipeline``.

    This base class loads the model, applies the configured device strategy and
    VAE memory options, and provides the request/response plumbing (prompt
    assembly, parameter extraction, PNG/base64 encoding). The subclasses in
    this module provide the architecture-specific ``_run_inference``.
    """
    # Defaults below can be overridden per model through `model_settings` (see __init__).
    _device_strategy: DeviceStrategy = "NONE"
    _default_height: Optional[int] = None
    _default_width: Optional[int] = None
    _default_num_inference_steps: Optional[int] = None
    _default_guidance_scale: Optional[float] = None
    _device_map: Optional[str]
    _enable_vae_slicing: bool = False
    _enable_vae_tiling: bool = False
    _model: DiffusionPipeline
    # Weight variant requested from `from_pretrained`; subclasses may override it.
    _model_weight_variant: Optional[str] = "fp16"
    # Evaluated once, when the class body is executed.
    _torch_dtype: torch.dtype = best_supported_dtype()

    def __init__(self, model: str, model_settings: ModelSettings, batch_size: int):
        """Load the diffusion model and prepare it for inference.

        Args:
            model: Model name or path handed to ``DiffusionPipeline.from_pretrained``.
            model_settings: Settings mapping; the keys ``defaultHeight``,
                ``defaultWidth``, ``defaultNumInferenceSteps``,
                ``defaultGuidanceScale``, ``deviceStrategy``,
                ``enableVaeSlicing`` and ``enableVaeTiling`` are read, each
                falling back to the class-level default when absent.
            batch_size: Forwarded to the parent batching pipeline.
        """
        super().__init__(batch_size=batch_size)
        self._default_height = model_settings.get("defaultHeight", self._default_height)
        self._default_width = model_settings.get("defaultWidth", self._default_width)
        self._default_num_inference_steps = model_settings.get("defaultNumInferenceSteps", self._default_num_inference_steps)
        self._default_guidance_scale = model_settings.get("defaultGuidanceScale", self._default_guidance_scale)
        self._device_strategy = model_settings.get("deviceStrategy", self._device_strategy)
        self._enable_vae_slicing = model_settings.get("enableVaeSlicing", self._enable_vae_slicing)
        self._enable_vae_tiling = model_settings.get("enableVaeTiling", self._enable_vae_tiling)
        self._device_map = None # use self._get_device_map() instead when https://github.com/huggingface/accelerate/issues/3000 is fixed in torch 2.5 to allow multi-gpu support
        self._model = DiffusionPipeline.from_pretrained(model,
                                                       torch_dtype=self._torch_dtype,
                                                       variant=self._model_weight_variant,
                                                       use_safetensors=True,
                                                       device_map=self._device_map)

        if self._enable_vae_slicing:
            self._model.vae.enable_slicing() # useful to save some memory and allow larger batch sizes
        if self._enable_vae_tiling:
            self._model.vae.enable_tiling() # useful for saving a large amount of memory and to allow processing larger images
        self._handle_device_loading(self._model)

        self.model_tracking_data["used_engine"] = "diffusers"
        self.model_tracking_data["task"] = "image-generation"

    def _handle_device_loading(self, model: DiffusionPipeline):
        """Apply the configured device strategy to ``model``.

        Enables model or sequential CPU offload when requested; otherwise moves
        the model to the ``cuda`` device, logging a warning first if the
        strategy is not a known value.
        """
        if _DeviceStrategy.MODEL_CPU_OFFLOAD == self._device_strategy:
            logger.info("Enabling CPU model offload")
            # offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance
            model.enable_model_cpu_offload()
        elif _DeviceStrategy.SEQUENTIAL_CPU_OFFLOAD == self._device_strategy:
            logger.info("Enabling CPU sequential offload")
            # offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU, significantly reducing memory usage. Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
            model.enable_sequential_cpu_offload()
        else:
            # Compare by value, not identity: the strategy is held as a plain
            # string (the class default is the literal "NONE"), and a plain
            # string is never the enum member object itself, so an `is not`
            # check would warn even for the valid "NONE" strategy.
            if _DeviceStrategy.NONE != self._device_strategy:
                logger.warning(f"Unknown cpu_offloading_mode: '{self._device_strategy}'")
            # TODO: remove this and rely on device_map when bug is fixed in torch 2.5: https://github.com/huggingface/accelerate/issues/3000
            #       this would allow multi gpu support
            logger.info("Moving model to GPU device")
            model.to("cuda")

    def _get_device_map(self) -> Optional[str]:
        """Return the ``device_map`` matching the configured device strategy.

        Returns ``None`` for the CPU offload strategies and ``"balanced"``
        otherwise. Not called by ``__init__`` at the moment, which sets the
        device map to ``None`` directly.
        """
        device_map: Optional[str]
        if _DeviceStrategy.MODEL_CPU_OFFLOAD == self._device_strategy or _DeviceStrategy.SEQUENTIAL_CPU_OFFLOAD == self._device_strategy:
            device_map = None
        else:
            # Value comparison for the same reason as in `_handle_device_loading`.
            if _DeviceStrategy.NONE != self._device_strategy:
                logger.warning(f"Unknown _device_strategy: '{self._device_strategy}'")
            device_map = "balanced"
        logger.info(f"Device map set to {device_map}")
        return device_map

    @staticmethod
    def create_pipeline(model_name_or_path: str, model_settings: ModelSettings, use_dss_model_cache: bool, batch_size: int) -> 'ModelPipelineImageGenerationDiffusion':
        """Instantiate the pipeline subclass matching the model's ``_class_name``.

        Args:
            model_name_or_path: Model name or path; its config is loaded to
                detect the architecture.
            model_settings: Settings forwarded to the chosen pipeline. For
                Stable Diffusion XL the refiner is read from ``hfRefinerPath``
                when ``use_dss_model_cache`` is true, else from ``refinerId``.
            use_dss_model_cache: Selects which settings key provides the refiner.
            batch_size: Forwarded to the chosen pipeline.

        Returns:
            The created pipeline, with ``model_architecture`` recorded in its
            tracking data. Unknown class names fall back to the Stable
            Diffusion pipeline with a warning.
        """
        config = DiffusionPipeline.load_config(model_name_or_path)
        if isinstance(config, tuple):
            config = config[0]
        config_class_name = config["_class_name"]
        logger.info(f"Config class name: '{config_class_name}'.")
        if config_class_name == "StableDiffusionPipeline":
            pipeline = ModelPipelineImageGenerationStableDiffusion(model_name_or_path, model_settings, batch_size)
        elif config_class_name == "StableDiffusionXLPipeline":
            if use_dss_model_cache:
                refiner = model_settings.get("hfRefinerPath", None)
            else:
                refiner = model_settings.get("refinerId", None)
            pipeline = ModelPipelineImageGenerationStableDiffusionXL(model_name_or_path, refiner, model_settings, batch_size)
        elif config_class_name == "StableDiffusion3Pipeline":
            pipeline = ModelPipelineImageGenerationStableDiffusion3(model_name_or_path, model_settings, batch_size)
        elif config_class_name == "FluxPipeline":
            pipeline = ModelPipelineImageGenerationFlux(model_name_or_path, model_settings, batch_size)
        else:
            logger.warning(f"Unknown config class name: '{config_class_name}'. Using StableDiffusionPipeline.")
            pipeline = ModelPipelineImageGenerationStableDiffusion(model_name_or_path, model_settings, batch_size)
        pipeline.model_tracking_data["model_architecture"] = config_class_name
        return pipeline

    @staticmethod
    def _group_images_per_request(batch: List, images: List, num_images_per_prompt: int) -> List[List[str]]:
        """Split the flat ``images`` list into one consecutive slice per batch entry.

        Entry ``i`` of the result holds the ``num_images_per_prompt`` images
        starting at index ``i * num_images_per_prompt``.
        """
        return [
            images[i * num_images_per_prompt:(i + 1) * num_images_per_prompt]
            for i in range(len(batch))
        ]

    @staticmethod
    def _call_diffusion_pipeline_inference(model: DiffusionPipeline, params: Dict) -> Any:
        """Call ``model`` with ``params`` and return the ``images`` of its output.

        Raises:
            FatalException: When a ``RuntimeError`` mentioning
                ``CUDNN_STATUS_NOT_INITIALIZED`` is raised by the call.
            RuntimeError: Any other runtime error is re-raised unchanged.
        """
        logger.info(f"Running image inference with params '{json.dumps(params, indent=4, default=str)}'")
        try:
            return call_with_fatal_exception(lambda: model.__call__(**params).images, torch.cuda.OutOfMemoryError)
        except RuntimeError as err:
            str_err = str(err)
            # Can be out of memory errors mishandled by some cuda setups: https://stackoverflow.com/a/62073916
            # Matched case-insensitively: the status is reported in upper case
            # ("cuDNN error: CUDNN_STATUS_NOT_INITIALIZED"), which a mixed-case
            # literal would never match.
            if 'CUDNN_STATUS_NOT_INITIALIZED' in str_err.upper(): # might happen if an image too large for the GPU is requested
                raise FatalException("Fatal exception: {0}".format(str_err)) from err
            raise

    @staticmethod
    def _set_param_if_none(pipeline_params: Dict, request_params: Dict, pipeline_key: str, request_key: str):
        """Copy ``request_params[request_key]`` into ``pipeline_params[pipeline_key]`` when it is set."""
        # we only want to set this param if we have it, otherwise we omit it so the `DiffusionPipeline` will use its subclass default value
        value = request_params.get(request_key)
        if value is not None:
            pipeline_params[pipeline_key] = value

    def _are_prompts_weighted(self, prompts: List):
        """Return True when at least one prompt entry carries a non-``None`` ``weight``.

        A missing (``None``) or empty prompt list counts as not weighted.
        """
        return any(prompt.get("weight", None) is not None for prompt in prompts or [])

    def _get_diffusion_pipeline_params(self, request_params: Dict) -> Dict:
        """Build the keyword arguments shared by every diffusion pipeline call.

        ``height``, ``width`` and ``num_images_per_prompt`` are always present;
        ``generator`` is a seeded CPU generator, or ``None`` without a seed;
        ``num_inference_steps`` and ``guidance_scale`` are only included when set.
        """
        seed = request_params["seed"]
        pipeline_params = {
            "height": request_params["height"],
            "width": request_params["width"],
            "num_images_per_prompt": request_params["num_images_per_prompt"],
            "generator": torch.Generator("cpu").manual_seed(seed) if seed is not None else None,
        }

        self._set_param_if_none(pipeline_params, request_params, pipeline_key="num_inference_steps", request_key="num_inference_steps")
        self._set_param_if_none(pipeline_params, request_params, pipeline_key="guidance_scale", request_key="guidance_scale")
        return pipeline_params

    def _concat_prompts(self, gen_prompts: List):
        """Join the ``prompt`` texts of ``gen_prompts`` with single spaces.

        Returns an empty string for a missing or empty list; entries whose
        ``prompt`` is missing or ``None`` contribute an empty string.
        """
        if not gen_prompts:
            return ""
        # `or ""` guards against a key that is present with a None value,
        # which `str.join` would reject.
        return " ".join(gen_prompt.get("prompt") or "" for gen_prompt in gen_prompts)

    def _blend_weighted_prompts(self, gen_prompts: List):
        """Build a compel ``.blend()`` expression from weighted prompt entries.

        Entries with an empty prompt or a weight of ``0.0`` are skipped; a
        missing or ``None`` weight counts as ``1.0``. Returns an empty string
        when no entry remains.
        """
        # See https://huggingface.co/docs/diffusers/en/using-diffusers/weighted_prompts#blending
        prompts: List[str] = []
        weights: List[float] = []
        for gen_prompt in gen_prompts or []:
            text = gen_prompt.get("prompt", None)
            weight = gen_prompt.get("weight")
            if not text or weight == 0.0:
                continue
            prompts.append(text)
            # A key present with a None value must not be rendered as "None".
            weights.append(1.0 if weight is None else weight)
        if not prompts:
            # Nothing to blend: avoid emitting the degenerate `().blend()` expression.
            return ""
        # Expected output string format: ("Stunning sunset over a futuristic city.", "Forest").blend(1.0, 0.8)'
        blended_prompt = f"""({", ".join([f'"{prompt}"' for prompt in prompts])}).blend({", ".join([str(weight) for weight in weights])})"""
        logger.info(f"Blended prompts into: `{blended_prompt}`")
        return blended_prompt

    def _get_prompts_inputs(self, batch: List, params: Dict) -> Tuple[List[str], List[str]]:
        """Turn each batch entry into one positive and one negative prompt string.

        Weighted prompts (per ``params``) are rendered as blend expressions,
        unweighted ones are concatenated.
        """
        build_prompt = self._blend_weighted_prompts if params["prompt_weighted"] else self._concat_prompts
        build_negative_prompt = self._blend_weighted_prompts if params["negative_prompt_weighted"] else self._concat_prompts
        prompts = []
        negative_prompts = []
        for single_img_gen_query in batch:
            prompts.append(build_prompt(single_img_gen_query["prompt_texts"]))
            negative_prompts.append(build_negative_prompt(single_img_gen_query["negative_prompt_texts"]))
        return prompts, negative_prompts

    def _get_inputs(self, requests: List[ProcessSingleImageGenerationCommand]) -> List:
        """Extract the positive and negative prompt entries of each request."""
        return [
            { "prompt_texts": request.get("promptTexts"), "negative_prompt_texts": request.get("negativePromptTexts") }
            for request in requests
        ]

    def _get_params(self, request: ProcessSingleImageGenerationCommand) -> Dict:
        """Extract the generation parameters of ``request``, applying the pipeline defaults.

        When ``fidelity`` is provided, the guidance scale is ``fidelity * 35.0``;
        otherwise the default guidance scale is used.
        """
        fidelity: Optional[float] = request.get("fidelity", None)
        if fidelity is None:
            guidance_scale = self._default_guidance_scale
        else:
            guidance_scale = fidelity * 35.0
        return {
            # if prompts are weighted they will be handled differently in the inference function
            "prompt_weighted": self._are_prompts_weighted(request.get("promptTexts")),
            "negative_prompt_weighted": self._are_prompts_weighted(request.get("negativePromptTexts")),
            "height": request.get("height", self._default_height),
            "width": request.get("width", self._default_width),
            "num_inference_steps": request.get("numInferenceSteps", self._default_num_inference_steps),
            "seed": request.get("seed", None),
            "guidance_scale": guidance_scale,
            "num_images_per_prompt": request.get("numImagesPerPrompt", 1),
        }

    def _parse_response(self, response, request: ProcessSingleImageGenerationCommand) -> ProcessSingleImageGenerationResponse:
        """Encode each image of ``response`` as a base64 PNG string.

        ``request`` is not used.
        """
        images = []
        for image in response:
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            images.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))
        # TODO: add more infos?
        return { "images": images }

class ModelPipelineImageGenerationStableDiffusion(ModelPipelineImageGenerationDiffusion):
    """Stable Diffusion pipeline with compel-based prompt weighting."""
    # See https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__

    def __init__(self, base_model: str, model_settings: ModelSettings, batch_size: int):
        """Load the base model, then create the compel instance used for weighted prompts."""
        super().__init__(base_model, model_settings, batch_size=batch_size)
        logger.info("Initializing compel for weighting")
        self._compel = Compel(tokenizer=self._model.tokenizer, text_encoder=self._model.text_encoder)

    def _run_inference(self, inputs: List, params: Dict) -> List:
        """Generate the images for a batch and group them per request.

        Weighted prompts are passed to the model as compel embeddings,
        unweighted ones as plain text. The positive and negative sides are
        handled independently, so a plain negative prompt is still sent when
        only the positive prompt is weighted.
        """
        base_params = self._get_diffusion_pipeline_params(params)
        prompts, negative_prompts = self._get_prompts_inputs(inputs, params)

        # Base model generation
        if params["prompt_weighted"]:
            base_params["prompt_embeds"] = self._compel(prompts)
        else:
            base_params["prompt"] = prompts
        if params["negative_prompt_weighted"]:
            base_params["negative_prompt_embeds"] = self._compel(negative_prompts)
        else:
            base_params["negative_prompt"] = negative_prompts

        images = self._call_diffusion_pipeline_inference(self._model, base_params)
        return self._group_images_per_request(inputs, images, params["num_images_per_prompt"])

class ModelPipelineImageGenerationStableDiffusionXL(ModelPipelineImageGenerationDiffusion):
    """Stable Diffusion XL pipeline with compel prompt weighting and an optional refiner."""
    # See https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLPipeline.__call__
    _default_refiner_strength: Optional[float] = 0.3
    _refiner: Any

    def __init__(self, base_model: str, refiner_name_or_path: Optional[str], model_settings: ModelSettings, batch_size: int):
        """Load the base model and, when a refiner is given, the refiner pipeline.

        Args:
            base_model: Name or path of the base model.
            refiner_name_or_path: Name or path of the refiner; when falsy, no
                refiner is loaded and the refinement step is skipped.
            model_settings: Settings mapping; ``defaultStrength`` overrides the
                default refiner strength.
            batch_size: Forwarded to the parent pipeline.
        """
        super().__init__(base_model, model_settings, batch_size=batch_size)
        self._default_refiner_strength = model_settings.get("defaultStrength", self._default_refiner_strength)
        logger.info("Initializing compel for weighting")
        self._compel = Compel(
            tokenizer=[self._model.tokenizer, self._model.tokenizer_2],
            text_encoder=[self._model.text_encoder, self._model.text_encoder_2],
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True]
        )

        if refiner_name_or_path:
            logger.info("Initializing refiner")
            # The refiner reuses the base model's second text encoder and VAE instances.
            self._refiner = DiffusionPipeline.from_pretrained(
                refiner_name_or_path,
                text_encoder_2=self._model.text_encoder_2,
                vae=self._model.vae,
                torch_dtype=self._torch_dtype,
                use_safetensors=True,
                variant=self._model_weight_variant,
                device_map=self._device_map
            )
            self._refiner.unet = torch.compile(self._refiner.unet, mode="reduce-overhead", fullgraph=True)
            logger.info("Initializing compel for refiner weighting")
            self._refiner_compel = Compel(
                tokenizer=[self._refiner.tokenizer_2],
                text_encoder=[self._refiner.text_encoder_2],
                returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
                requires_pooled=[True],
            )

            # No need to call the below functions since it's the same vae instance and these optimizations were already enabled in the parent ctor
            # self._refiner.vae.enable_slicing()
            # self._refiner.vae.enable_tiling()
            self._handle_device_loading(self._refiner)
        else:
            logger.info("No refiner to initialize")
            self._refiner = None

    def _get_params(self, request: ProcessSingleImageGenerationCommand) -> Dict:
        """Extend the base parameters with ``refiner_strength``.

        The strength comes from the request's ``refinerStrength`` (falling back
        to the default) when a refiner is loaded, and is ``1`` otherwise.
        """
        return {
            **super()._get_params(request),
            **{
                "refiner_strength": request.get("refinerStrength", self._default_refiner_strength) if self._refiner is not None else 1,
            }
        }

    @staticmethod
    def _build_prompt_params(compel: Any, prompts: List[str], negative_prompts: List[str], prompt_weighted: bool, negative_prompt_weighted: bool) -> Dict:
        """Build the prompt-related pipeline kwargs for one pipeline call.

        Each side (positive / negative) is passed either as plain text or, when
        weighted, as the conditioning and pooled embeddings returned by ``compel``.
        """
        prompt_params: Dict = {}
        if prompt_weighted:
            conditioning, pooled = compel(prompts)
            prompt_params["prompt_embeds"] = conditioning
            prompt_params["pooled_prompt_embeds"] = pooled
        else:
            prompt_params["prompt"] = prompts
        if negative_prompt_weighted:
            negative_conditioning, negative_pooled = compel(negative_prompts)
            prompt_params["negative_prompt_embeds"] = negative_conditioning
            # The pipeline parameter is `negative_pooled_prompt_embeds`; any other
            # key would not deliver the pooled negative embeddings to the pipeline.
            prompt_params["negative_pooled_prompt_embeds"] = negative_pooled
        else:
            prompt_params["negative_prompt"] = negative_prompts
        return prompt_params

    def _run_inference(self, inputs: List, params: Dict) -> List:
        """Generate the images for a batch, optionally refine them, and group them per request.

        The refinement step runs only when a refiner is loaded and
        ``refiner_strength`` is a positive number.
        """
        prompt_weighted = params["prompt_weighted"]
        negative_prompt_weighted = params["negative_prompt_weighted"]
        refiner_strength = params["refiner_strength"]

        core_params = self._get_diffusion_pipeline_params(params)
        prompts, negative_prompts = self._get_prompts_inputs(inputs, params)

        # Base model generation
        base_params = {
            **core_params,
            **self._build_prompt_params(self._compel, prompts, negative_prompts, prompt_weighted, negative_prompt_weighted),
        }
        images = self._call_diffusion_pipeline_inference(self._model, base_params)

        # Optional refinement step
        # See https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#diffusers.StableDiffusionXLImg2ImgPipeline.__call__
        if self._refiner is not None and refiner_strength is not None and refiner_strength > 0:
            refiner_params = {
                **core_params,
                "image": images,
                "strength": refiner_strength,
                **self._build_prompt_params(self._refiner_compel, prompts, negative_prompts, prompt_weighted, negative_prompt_weighted),
            }
            logger.info("Running refinement step")
            images = self._call_diffusion_pipeline_inference(self._refiner, refiner_params)

        return self._group_images_per_request(inputs, images, params["num_images_per_prompt"])

class ModelPipelineImageGenerationStableDiffusion3(ModelPipelineImageGenerationDiffusion):
    # See https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_3#diffusers.StableDiffusion3Pipeline.__call__
    def _run_inference(self, inputs: List, params: Dict) -> List:
        """Generate the images for a batch of plain-text prompts and group them per request.

        Raises:
            Exception: When the positive or negative prompts are weighted.
        """
        weighting_flags = (params["prompt_weighted"], params["negative_prompt_weighted"])
        if any(weighting_flags):
            # TODO: wait for support https://github.com/damian0815/compel/issues/92
            raise Exception('prompt weights are not supported on SD3 yet')

        shared_kwargs = self._get_diffusion_pipeline_params(params)
        positive_texts, negative_texts = self._get_prompts_inputs(inputs, params)

        # Base model generation: shared kwargs first, then the text prompts
        call_kwargs = dict(shared_kwargs, prompt=positive_texts, negative_prompt=negative_texts)
        generated = self._call_diffusion_pipeline_inference(self._model, call_kwargs)

        images_per_prompt = params["num_images_per_prompt"]
        return self._group_images_per_request(inputs, generated, images_per_prompt)

class ModelPipelineImageGenerationFlux(ModelPipelineImageGenerationDiffusion):
    # See https://huggingface.co/docs/diffusers/api/pipelines/flux#diffusers.FluxPipeline.__call__
    # Class-level defaults overriding those of the parent pipeline.
    _default_num_inference_steps = 4
    _default_guidance_scale = 0.
    _default_max_sequence_length: int = 256
    _model_weight_variant: Optional[str] = None

    def __init__(self, model: str, model_settings: ModelSettings, batch_size: int):
        """Load the model, create the compel instance and read ``maxSequenceLength`` from the settings."""
        super().__init__(model, model_settings=model_settings, batch_size=batch_size)
        logger.info("Initializing compel for weighting")
        tokenizers = [self._model.tokenizer, self._model.tokenizer_2]
        text_encoders = [self._model.text_encoder, self._model.text_encoder_2]
        self._compel = Compel(
            tokenizer=tokenizers,
            text_encoder=text_encoders,
            returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
            requires_pooled=[False, True],
        )
        self._default_max_sequence_length = model_settings.get("maxSequenceLength", self._default_max_sequence_length)

    def _get_params(self, request: ProcessSingleImageGenerationCommand) -> Dict:
        """Extend the base request parameters with ``max_sequence_length``."""
        # NOTE(review): this request key is snake_case while the other request keys
        # read by the parent are camelCase — confirm against the request schema.
        return dict(
            super()._get_params(request),
            max_sequence_length=request.get("max_sequence_length", self._default_max_sequence_length),
        )

    def _get_diffusion_pipeline_params(self, request_params: Dict) -> Dict:
        """Extend the shared pipeline kwargs with ``max_sequence_length``."""
        return dict(
            super()._get_diffusion_pipeline_params(request_params),
            max_sequence_length=request_params["max_sequence_length"],
        )

    def _run_inference(self, inputs: List, params: Dict) -> List:
        """Generate the images for a batch and group them per request.

        Only the positive prompts are used: as compel embeddings when weighted,
        as plain text otherwise.
        """
        use_embeddings = params["prompt_weighted"]
        call_kwargs = dict(self._get_diffusion_pipeline_params(params))
        positive_texts, _ = self._get_prompts_inputs(inputs, params)

        # Base model generation
        if use_embeddings:
            conditioning, pooled = self._compel(positive_texts)
            call_kwargs["prompt_embeds"] = conditioning
            call_kwargs["pooled_prompt_embeds"] = pooled
        else:
            call_kwargs["prompt"] = positive_texts

        generated = self._call_diffusion_pipeline_inference(self._model, call_kwargs)
        return self._group_images_per_request(inputs, generated, params["num_images_per_prompt"])
