import logging
import os
from contextlib import contextmanager

from dataiku.base.folder_context import build_model_cache_folder_context, build_local_hf_folder_context

# Do not import huggingface or transformers here, otherwise it would break hf_transfer and transformers_offline
# see https://app.shortcut.com/dataiku/story/190851/make-sure-transformers-offline-is-set-early-enough

logger = logging.getLogger(__name__)


@contextmanager
def _model_path_or_name_manager(model_path, model_name):
    """
    Yields what a model should be loaded from: a local path when one is given, the model name otherwise.

    :param model_path: path handed to build_model_cache_folder_context, or None when there is no such path
    :param model_name: model name yielded unchanged when model_path is None
    :yield: the "model" sub-folder of the readable folder path, or model_name
    """
    if model_path is not None:
        # The readable folder path is only used inside this `with` block, which stays open
        # for as long as the caller stays inside its own `with` statement.
        with build_model_cache_folder_context(model_path).get_folder_path_to_read() as cache_folder_path:
            logger.info("Path provided, loading model from path {}".format(cache_folder_path))
            yield os.path.join(cache_folder_path, "model")
        return

    logger.info("No path provided for model, loading from model name directly: {}".format(model_name))
    yield model_name


@contextmanager
def _saved_model_path_manager(saved_model_version_path, saved_model_project_key, saved_model_id):
    """
    Yields the readable folder path of a saved model version.

    The three arguments are forwarded as-is to build_local_hf_folder_context.

    :yield: the folder path obtained from the folder context's get_folder_path_to_read()
    """
    hf_folder_context = build_local_hf_folder_context(
        saved_model_version_path,
        saved_model_project_key,
        saved_model_id,
    )
    with hf_folder_context.get_folder_path_to_read() as readable_folder_path:
        logger.info("Saved model provided, loading model from saved model path {}".format(readable_folder_path))
        yield readable_folder_path


def _get_model_path_or_name_context_manager(command):
    """
    Picks the context manager resolving the main model, based on the command's "modelOrigin".

    :param command: mapping holding "modelOrigin" and the keys specific to that origin
    :return: a context manager yielding the model path or name
    :raises ValueError: if "modelOrigin" is neither "SAVED_MODEL_VERSION" nor "HUGGINGFACE_MODEL"
    """
    origin = command.get("modelOrigin")

    if origin == "SAVED_MODEL_VERSION":
        version_path = command["savedModelVersionPath"]
        project_key = command["savedModelProjectKey"]
        model_id = command["savedModelId"]
        return _saved_model_path_manager(version_path, project_key, model_id)

    if origin == "HUGGINGFACE_MODEL":
        # "hfModelPath" is optional (None when absent), "hfModelName" is mandatory
        hf_path = command.get("hfModelPath")
        hf_name = command["hfModelName"]
        return _model_path_or_name_manager(hf_path, hf_name)

    raise ValueError("Unknown model origin '%s', cannot instantiate pipeline" % origin)


@contextmanager
def model_and_base_model_name_or_path_manager(command, model_settings):
    """
    Resolves the model, the base model and the refiner model, yielding either a HuggingFace ID or a path for each.

    base_model_name_or_path is not None if this model is an adapter model, i.e. the command carries a
    'baseModelPath' or the model's PEFT config names a base model.
    refiner_name_or_path is not None only when the model is not an adapter model and the model settings
    carry a "refinerId".

    :param command: The start command from Java
    :param model_settings: mapping read for the optional "refinerId" and "hfRefinerPath" entries
    :yield: model_name_or_path, base_model_name_or_path, refiner_name_or_path
    """
    with _get_model_path_or_name_context_manager(command) as model_name_or_path:
        logger.info("Checking if we need to load a base model for a LoRA adapter")
        try:
            from peft import PeftConfig  # import inside method, so it still works even if `peft` is not in the code-env
            config = PeftConfig.from_pretrained(model_name_or_path)
            config_base_model = config.base_model_name_or_path
        except Exception:
            # Best-effort probe: any failure (peft not installed, no adapter config for this model, ...)
            # means "no base model found in a PEFT config". Catching Exception instead of using a bare
            # `except:` lets KeyboardInterrupt and SystemExit propagate.
            logger.debug("Could not read a PEFT config for the model", exc_info=True)
            config_base_model = None

        base_model_path = command.get('baseModelPath')
        if base_model_path is None and config_base_model is None:
            logger.info("This is not a LoRA model, no need to load a base model")
            # at the moment for image gen (with or without refiner) we do not support LoRA adapter
            refiner_name = model_settings.get("refinerId")
            refiner_path = model_settings.get("hfRefinerPath")
            if refiner_name is None:
                yield model_name_or_path, None, None
            else:
                with _model_path_or_name_manager(refiner_path, refiner_name) as refiner_name_or_path:
                    yield model_name_or_path, None, refiner_name_or_path
        else:
            logger.info("This is a LoRA model, loading base model for LoRA adapter")
            with _model_path_or_name_manager(base_model_path, config_base_model) as base_model_name_or_path:
                yield model_name_or_path, base_model_name_or_path, None
