import base64
import io
import os
import shutil
import logging

from dataiku.base.file_utils import SafeFile

logger = logging.getLogger(__name__)


class FilesReader(object):
    """
    Base class of the files readers: turns a ``file_info`` into a readable stream.

    Readers are meant to be used as context managers; this base class has nothing to set up
    or tear down, and ``open_file`` has to be provided by subclasses.
    """

    def __enter__(self):
        # Nothing to acquire: the reader itself is the context value.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release. Returning None lets any in-flight exception propagate.
        return None

    def open_file(self, file_info):
        """
        Open the file described by ``file_info``.

        :raises NotImplementedError: always; subclasses must override this method
        """
        raise NotImplementedError()


class Base64FilesReader(FilesReader):
    """
    Files reader for which each ``file_info`` is the file content itself, base64-encoded.
    """

    def open_file(self, file_info):
        """
        Decode ``file_info`` and expose the resulting bytes as an in-memory binary stream.

        :param file_info: base64-encoded file content
        :rtype: io.BytesIO
        """
        raw_bytes = base64.b64decode(file_info)
        stream = io.BytesIO(raw_bytes)
        return stream


class ManagedFolderFilesReader(FilesReader):
    """
    Files reader that reads each file directly from a managed folder, without any local caching.
    """

    def __init__(self, managed_folder):
        """
        :param managed_folder: folder the files are read from
        :type managed_folder: dataiku.core.managed_folder.Folder
        """
        self.managed_folder = managed_folder

    def open_file(self, file_info):
        """
        :param file_info: identifier of the file in the managed folder, passed as-is to
                          ``get_download_stream`` (presumably its path inside the folder)
        :return: whatever ``managed_folder.get_download_stream`` returns for ``file_info``
        """
        folder = self.managed_folder
        return folder.get_download_stream(file_info)


class CachedManagedFolderFilesReader(FilesReader):

    def __init__(self, managed_folder, base_folder_path, init_and_clean_base_folder=True):
        """
        Files reader, relying on a managed folder, that caches every retrieved file locally for further reuse
        without downloading them again.

        :type managed_folder: dataiku.core.managed_folder.Folder
        :type base_folder_path: str
        :param init_and_clean_base_folder: whether the files reader is responsible for the creation and
                                           deletion of the cache folder
        :type init_and_clean_base_folder: bool
        """
        self.managed_folder = managed_folder
        self.base_folder_path = base_folder_path
        self.init_and_clean_base_folder = init_and_clean_base_folder

    def __enter__(self):
        """
        Prepare the local cache folder.

        When ``init_and_clean_base_folder`` is True, an existing folder is re-used, otherwise it is created
        with ``os.mkdir`` (so its parent must already exist). When False, the folder is expected to exist.
        """
        if not self.init_and_clean_base_folder:
            logger.info("Not creating base folder '%s', expects that it exists and is available",
                        self.base_folder_path)
        elif os.path.isdir(self.base_folder_path):
            logger.info("Re-using already existing base folder %s, for caching managed folder %s",
                        self.base_folder_path, self.managed_folder.name)
        else:
            os.mkdir(self.base_folder_path)
            logger.info("Will use folder %s to cache managed folder %s",
                        self.base_folder_path, self.managed_folder.name)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Removes the local cache, only if this reader is responsible for it.

        Removal errors are ignored; exceptions raised inside the ``with`` block are not suppressed.
        """
        if self.init_and_clean_base_folder:
            logger.info("Exiting, deleting cache for managed folder %s", self.managed_folder.name)
            shutil.rmtree(self.base_folder_path, ignore_errors=True)
        else:
            logger.info("Exiting, not deleting cache for managed folder %s", self.managed_folder.name)

    def open_file(self, file_info):
        """
        Open the locally cached copy of ``file_info``, downloading it first if needed.

        :rtype: binary file object (opened with mode 'rb'); the caller is responsible for closing it
        """
        return open(self.get_local_path(file_info), 'rb')

    @staticmethod
    def _flatten_path(file_path):
        """
        Remove any sub-folder from path as the cached managed folder is only one folder.
        Separators (``os.path.sep``) are replaced by "__dku_sep__".
        """
        custom_separator = "__dku_sep__"
        return file_path.replace(os.path.sep, custom_separator)

    def get_local_path(self, file_path):
        """
        Return the local path of the cached copy of ``file_path``, downloading it into the cache if absent.

        :param file_path: path of the file inside the managed folder
        :type file_path: str
        :return: The local path of the cached file
        :rtype: str
        """
        # 'file_path' is user input and may contain forward slashes. normcase() converts them to
        # os.path.sep on Windows (and is a no-op on POSIX), so flattening catches every separator.
        # This normalized form is ONLY used to name the cache entry: normcase also lowercases on
        # Windows, so the download below must keep the path exactly as given.
        normalized_path = os.path.normcase(file_path)
        cached_path = os.path.join(self.base_folder_path, self._flatten_path(normalized_path))

        if not os.path.exists(cached_path):
            # Stream the file in chunks rather than reading it fully into memory, and close the
            # download stream once it has been copied.
            with self.managed_folder.get_download_stream(file_path) as stream:
                with SafeFile(cached_path, "wb") as f:
                    shutil.copyfileobj(stream, f)

        return cached_path


def img_array_to_base64(img_array, img_format):
    """
    Encode an image array as ``img_format`` and return the encoded bytes as a base64 string.

    :param img_array: array of dimension (height, width, color_channels) with values between 0 and 255.
                      It is cast to uint8 without any clipping, so out-of-range values are not saturated.
    :param img_format: PIL format name (JPEG, PNG, ...); make sure it is compatible with the
                       color_channels dimension.
    :rtype: str (image transformed into base64 string)
    """

    # local import to allow unit tests running on builtin env to import this file.
    from PIL import Image

    with io.BytesIO() as f:
        # `compress_level` is used for PNG images, where we accept slightly less compression in exchange
        # for faster encoding (PIL's default value is 6)
        Image.fromarray(img_array.astype('uint8')).save(f, img_format, compress_level=3)
        return base64.b64encode(f.getvalue()).decode("utf8")
