# Citation: 
# Berrios, W., Mittal, G., Thrush, T., Kiela, D., & Singh, A. (2023). Towards Language Models 
# That Can See: Computer Vision Through the LENS of Natural Language. arXiv preprint arXiv:2306.16410.
# The code is from the GitHub repository: https://github.com/ContextualAI/lens
# Minor modifications have been applied and comments have been added.

import datetime
import os

import torch
from torch.distributed import init_process_group
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

# Prefer CUDA whenever torch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    default_device = "cuda"
else:
    default_device = "cpu"

# Mapping of CLIP checkpoint identifiers to the short abbreviations used for them.
MAP_CLIP_NAME = {
    # OpenAI CLIP checkpoints
    "openai/clip-vit-large-patch14": "ViT-L-14",
    "openai/clip-vit-base-patch16": "ViT-B-16",
    "openai/clip-vit-base-patch32": "ViT-B-32",
    # LAION-trained checkpoints, referenced with an "hf-hub:" prefix
    "hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K": "laion-ViT-H-14-2B",
    "hf-hub:laion/CLIP-ViT-bigG-14-laion2B-39B-b160k": "laion-ViT-bigG-14-2B",
}


def ddp_setup():
    """
    Initialise the torch.distributed process group and bind this process to a GPU.

    Uses the NCCL backend with a process-group timeout of 180000 seconds
    (50 hours), then selects the CUDA device whose index is given by the
    ``LOCAL_RANK`` environment variable. Raises ``KeyError`` if ``LOCAL_RANK``
    is not set.
    """
    nccl_timeout = datetime.timedelta(seconds=180000)
    init_process_group(backend="nccl", timeout=nccl_timeout)

    # LOCAL_RANK is read only after the process group is initialised, as before.
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)


def create_sampler(dataset, distributed=False):
    """
    Build an index sampler for ``dataset``.

    Args:
        dataset (torch.utils.data.Dataset): The dataset to draw indices from.
        distributed (bool, optional): If True, return a ``DistributedSampler``
            with shuffling disabled, which partitions the indices across
            processes. Otherwise, return a ``SequentialSampler`` that yields
            indices in order. Defaults to False.

    Returns:
        torch.utils.data.Sampler: The sampler instance.
    """
    data_utils = torch.utils.data
    if not distributed:
        return data_utils.SequentialSampler(dataset)
    return data_utils.DistributedSampler(dataset, shuffle=False)


def create_dataloader(dataset, sampler, batch_size=8, num_workers=0):
    """
    Wrap ``dataset`` in a DataLoader that draws indices from ``sampler``.

    Shuffling is off, so batches follow the sampler's order. The final,
    possibly smaller, batch is kept (``drop_last=False``). Pinned host memory
    is requested (``pin_memory=True``).

    Args:
        dataset (torch.utils.data.Dataset): The dataset to load from.
        sampler (torch.utils.data.Sampler): Sampler that decides the index order.
        batch_size (int, optional): Samples per batch. Defaults to 8.
        num_workers (int, optional): Number of loader subprocesses; 0 loads
            in the main process. Defaults to 0.

    Returns:
        torch.utils.data.DataLoader: The configured data loader.
    """
    loader_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
        "sampler": sampler,
        "shuffle": False,
        "drop_last": False,
    }
    return torch.utils.data.DataLoader(dataset, **loader_options)


def is_main_process():
    """
    Check if the current process is the main (rank-0) process.

    The global rank is read from the ``RANK`` environment variable. When
    ``RANK`` is not set, as in a plain single-process run, the process is
    treated as the main process instead of raising ``KeyError``.

    Returns:
        bool: True if the current process is the main process, False otherwise.
    """
    return int(os.environ.get("RANK", "0")) == 0


def _load_pretrained_with_fallback(version, primary_cls, fallback_cls, **load_kwargs):
    """Load ``version`` with ``primary_cls``; if that fails, retry with ``fallback_cls``."""
    try:
        return primary_cls.from_pretrained(version, **load_kwargs)
    except Exception:
        # Catch Exception rather than using a bare ``except:`` so that
        # KeyboardInterrupt / SystemExit still propagate. If the fallback also
        # fails, Python chains the first error as context, so both are reported.
        return fallback_cls.from_pretrained(version, **load_kwargs)


def get_llm_model(version, load_8bit, device_map=None):
    """
    Load a Large Language Model (LLM) and put it in evaluation mode.

    The checkpoint may be either decoder-only or encoder-decoder. One
    ``transformers`` auto class is tried first, and the other is used as a
    fallback if loading fails.

    Args:
        version (str): Model identifier or path passed to ``from_pretrained``.
        load_8bit (bool): If True, load with ``load_in_8bit=True`` and
            ``torch_dtype=torch.float16``, trying ``AutoModelForCausalLM``
            first. If False, load with default settings, trying
            ``AutoModelForSeq2SeqLM`` first.
        device_map (optional): Target device (e.g. a device string or index).
            In 8-bit mode it is passed as ``device_map={"": device_map}`` so the
            whole model is placed on it. Otherwise the loaded model is moved
            with ``.to(device_map)``. Defaults to None. NOTE(review): with
            ``load_8bit=True`` a None value is forwarded as ``{"": None}``;
            confirm callers always pass a device in that mode.

    Returns:
        torch.nn.Module: The loaded model in eval mode.

    Raises:
        Exception: Whatever the fallback ``from_pretrained`` call raises when
            neither model class can load ``version``, or whatever ``.to()``
            raises when moving the model.
    """
    if load_8bit:
        primary_cls, fallback_cls = AutoModelForCausalLM, AutoModelForSeq2SeqLM
        load_kwargs = {
            "load_in_8bit": True,
            "torch_dtype": torch.float16,
            "device_map": {"": device_map},
        }
    else:
        primary_cls, fallback_cls = AutoModelForSeq2SeqLM, AutoModelForCausalLM
        load_kwargs = {}

    model = _load_pretrained_with_fallback(
        version, primary_cls, fallback_cls, **load_kwargs
    )

    if not load_8bit:
        # The device move happens outside the fallback try, so a failure here
        # (e.g. out of memory) surfaces directly instead of triggering a
        # misleading load attempt with the other model class.
        model = model.to(device_map)

    return model.eval()


def create_prompt_sample(
    samples,
    idx,
    tags_col="tags",
    attributes_col="attributes",
    caption_col="caption",
    intensive_captions_col="intensive_captions",
    question_col="questions",
    question_prompt=None,
    num_intensive_captions=50,
    mode="all",
):
    """
    Create a text prompt for one sample based on the specified mode.

    Args:
        samples (dict): Mapping from column name to a per-sample sequence;
            ``samples[col][idx]`` is read for each column the chosen mode uses.
        idx (int): Index of the sample within each column.
        tags_col (str, optional): Column whose entries are iterables of tag
            strings (used by "vision" and "hm"). Defaults to "tags".
        attributes_col (str, optional): Column whose entries are iterables of
            attribute strings (used by "vision" and "hm"). Defaults to "attributes".
        caption_col (str, optional): Column whose entries are caption strings
            (used by "hm"). Defaults to "caption".
        intensive_captions_col (str, optional): Column whose entries are
            sequences of caption strings (used by "vqa" and "all").
            Defaults to "intensive_captions".
        question_col (str, optional): Column name for questions. Currently not
            used by any mode. Defaults to "questions".
        question_prompt (str, optional): Currently not used by any mode.
            Defaults to None.
        num_intensive_captions (int, optional): Maximum number of intensive
            captions included in "vqa" and "all" modes. Defaults to 50.
        mode (str, optional): One of "vqa", "vision", "hm" or "all".
            Defaults to "all".

    Returns:
        str: Generated prompt based on the specified mode and sample information.

    Raises:
        ValueError: If ``mode`` is not one of the supported values. This is a
            subclass of Exception, so existing ``except Exception`` handlers
            still catch it.
    """
    if mode == "vqa":
        captions = samples[intensive_captions_col][idx][:num_intensive_captions]
        parts = ["Image:\n", "Captions:", ".".join(captions)]

    elif mode == "vision":
        parts = [
            "Tag: ",
            ",".join(samples[tags_col][idx]),
            "\nAttributes: ",
            ",".join(samples[attributes_col][idx]),
        ]

    elif mode == "hm":
        parts = [
            "Image:\n",
            "Caption:",
            samples[caption_col][idx],
            "\nAttributes:",
            ",".join(samples[attributes_col][idx]),
            "\nTags:",
            # The Tags line must read tags_col; reading attributes_col here
            # would just repeat the attributes under the "Tags:" label.
            ",".join(samples[tags_col][idx]),
            "\nQuestion: Is the image hateful or not-hateful?",
            "\nShort Answer:",
        ]

    elif mode == "all":
        captions = samples[intensive_captions_col][idx][:num_intensive_captions]
        # Captions after the first are each prefixed with "\n-".
        parts = ["\n-".join(captions)]

    else:
        raise ValueError(f"Mode not available: {mode!r}")

    return "".join(parts)
