from copy import deepcopy
import os
from logging import Logger
from typing import cast

from dataiku.huggingface.types import ProcessSingleCommandModel, ProcessSingleEmbeddingCommandModel, ProcessSinglePromptCommandModel

# Must be called before any import of huggingface_hub, otherwise it has no effect.
# Automatically activates hf_transfer when the package can be imported and
# HF_HUB_ENABLE_HF_TRANSFER is not already set.
# Should be called after connecting the link to the backend (otherwise the log does not show up).
def enable_hf_transfer(logger: Logger) -> None:
    """Set ``HF_HUB_ENABLE_HF_TRANSFER`` depending on whether ``hf_transfer`` is importable.

    If the environment variable is already set, its value is kept untouched.
    Otherwise it is set to ``"1"`` when the ``hf_transfer`` package can be
    imported, and to ``"0"`` when it cannot. The decision is logged at info
    level in every case.

    :param logger: logger receiving an info message describing the decision.
    """
    if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
        # An explicit user/environment choice always wins.
        logger.info(f"HF_HUB_ENABLE_HF_TRANSFER={os.environ['HF_HUB_ENABLE_HF_TRANSFER']} : keeping value set in environment variable")
        return

    try:
        # Only importability matters here; the module itself is not used.
        import hf_transfer  # noqa: F401
    except ImportError:
        # Narrow catch: a bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
        logger.info(f"HF_HUB_ENABLE_HF_TRANSFER={os.environ['HF_HUB_ENABLE_HF_TRANSFER']} : Could not import package 'hf_transfer'. For optimal "
                        f"startup performance, please install it in the code env")
    else:
        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        logger.info(f"HF_HUB_ENABLE_HF_TRANSFER={os.environ['HF_HUB_ENABLE_HF_TRANSFER']} : package found")


def log_hf_debug_info(logger: Logger) -> None:
    """Log the Hugging Face Hub settings read from ``huggingface_hub.constants``.

    Logs ``HF_HOME``, ``HF_HUB_OFFLINE`` and ``HF_HUB_ENABLE_HF_TRANSFER`` at
    info level. This is best-effort: any failure (for example
    ``huggingface_hub`` not being importable) is logged as a warning and
    otherwise ignored.

    :param logger: logger receiving the debug information.
    """
    try:
        from huggingface_hub.constants import HF_HOME, HF_HUB_OFFLINE, HF_HUB_ENABLE_HF_TRANSFER
        logger.info(f"""
HF_HOME: {HF_HOME},
HF_HUB_OFFLINE: {HF_HUB_OFFLINE}
HF_HUB_ENABLE_HF_TRANSFER: {HF_HUB_ENABLE_HF_TRANSFER}
""")
    except Exception as e:
        # Extra logging args are %-formatted into the message, so the message
        # needs a placeholder. Without it, formatting fails with "not all
        # arguments converted during string formatting" and the warning is lost.
        logger.warning("could not get hf debug information: %s", e)

def _truncate_inline_image(image: str) -> str:
    """Return the first 25 characters of an inline image payload followed by "..."."""
    return image[:25] + "..."


def copy_request_for_logging(request: ProcessSingleCommandModel):
    """Return a copy of ``request`` whose inline image payloads are shortened for logging.

    Two request shapes are handled:

    * ``query`` holds an ``inlineImage`` key directly (embedding request): that
      value is truncated.
    * ``query`` holds ``messages`` (prompt request): every message part with a
      non-None ``inlineImage`` has that value truncated.

    Truncation keeps the first 25 characters and appends ``"..."``. All changes
    are made on a deep copy, so ``request`` itself is never mutated.

    :param request: the command to prepare for logging.
    :return: the truncated deep copy. If copying or truncating fails for any
        reason, ``request`` itself is returned unchanged. NOTE(review): in that
        fallback the full image data may end up in logs.
    """
    try:
        copied_request = deepcopy(request)
        if "query" in copied_request and "inlineImage" in copied_request["query"]:
            # Embedding request: a single inline image directly on the query.
            copied_request = cast(ProcessSingleEmbeddingCommandModel, copied_request)
            copied_request["query"]["inlineImage"] = _truncate_inline_image(copied_request["query"]["inlineImage"])
        elif "query" in copied_request and "messages" in copied_request["query"]:
            # Prompt request: images may be attached to individual message parts.
            copied_request = cast(ProcessSinglePromptCommandModel, copied_request)
            for message in copied_request["query"]["messages"]:
                if message.get("parts") is not None:
                    for part in message["parts"]: # type: ignore checks are done just above
                        if part.get("inlineImage") is not None:
                            part["inlineImage"] = _truncate_inline_image(part["inlineImage"]) # type: ignore checks are done just above
        return copied_request
    except Exception:
        # Best effort: logging must never break request handling, so fall back
        # to the original object. Catch Exception rather than using a bare
        # `except:`, so KeyboardInterrupt/SystemExit still propagate.
        return request
