import logging

from dataiku.base import dku_pickle
from dataiku.doctor.utils import dku_write_mode_for_pickling

logger = logging.getLogger(name=__name__)

# File name, relative to the model folder, under which the registry below is pickled
# by save_current_custom_objects() and read back by load_custom_objects().
CUSTOM_OBJECTS_PKL_NAME = "keras_model_custom_objects.pkl"

# Module-level registry mapping object name -> object, filled by register_object().
_custom_objects = dict()


def register_object(object_name, object_val):
    """Store ``object_val`` in the module registry under ``object_name``.

    The registry is what ``save_current_custom_objects`` later pickles. An
    entry already registered under the same name is replaced without warning.

    :param object_name: key under which the object is registered.
    :param object_val: the object to register.
    """
    message = "registering new custom_object: {}".format(object_name)
    logger.info(message)
    _custom_objects.update({object_name: object_val})


def save_current_custom_objects(model_folder_context):
    """Pickle the current custom-objects registry into a model folder.

    Writes the module-level registry (populated via ``register_object``) to
    ``CUSTOM_OBJECTS_PKL_NAME`` inside ``model_folder_context``. When nothing
    has been registered, the function returns without writing anything. In
    that case any file left there by an earlier save is not touched.

    :param model_folder_context: folder context to write to. It must expose
        ``get_file_path_to_write(name)``, used below as a context manager that
        yields a path suitable for ``open``.
    """
    if not _custom_objects:
        logger.info("No custom objects, not saving them")
        return

    # list() so the log shows the names rather than a "dict_keys([...])" repr on Python 3.
    logger.info("Saving custom objects: {}".format(list(_custom_objects.keys())))

    with model_folder_context.get_file_path_to_write(CUSTOM_OBJECTS_PKL_NAME) as pkl_file_path:
        # dku_write_mode_for_pickling() supplies the file mode for pickle output
        # (presumably "wb" vs "w" depending on the Python version; confirm).
        with open(pkl_file_path, dku_write_mode_for_pickling()) as pkl_file:
            dku_pickle.dump(_custom_objects, pkl_file)


def load_custom_objects(model_folder_context):
    """Load the pickled custom-objects registry from a model folder, if present.

    :param model_folder_context: folder context to read from. It must expose
        ``isfile(name)`` and ``get_file_path_to_read(name)``; the latter is
        used below as a context manager yielding a path suitable for ``open``.
    :return: the unpickled object, or ``None`` when the folder has no
        ``CUSTOM_OBJECTS_PKL_NAME`` file. When the file was written by
        ``save_current_custom_objects``, this is a dict of name -> object.
    """
    logger.info("Attempting to load custom_objects from {}".format(model_folder_context))

    if not model_folder_context.isfile(CUSTOM_OBJECTS_PKL_NAME):
        logger.info("No custom objects found, not loading them")
        return None

    # NOTE(review): dku_pickle.load presumably wraps pickle, and unpickling can
    # execute arbitrary code. This assumes the model folder comes from a trusted
    # source (e.g. one written by save_current_custom_objects).
    with model_folder_context.get_file_path_to_read(CUSTOM_OBJECTS_PKL_NAME) as pkl_path:
        with open(pkl_path, "rb") as pkl_file:
            custom_objects = dku_pickle.load(pkl_file)

    # list() so the log shows the names rather than a "dict_keys([...])" repr on Python 3.
    logger.info("Custom objects loaded: {}".format(list(custom_objects.keys())))
    return custom_objects
