import logging
import traceback
import sys

import numpy as np

from dataiku.base.socket_block_link import JavaLink, parse_javalink_args
from dataiku.base.utils import get_json_friendly_error
from dataiku.base.utils import watch_stdin
from dataiku.core import debugging
from dataiku.core.managed_folder import Folder
from dataiku.doctor.deephub.data_augmentation.image_transformer import build_augmentation_transforms_list
from dataiku.doctor.deephub.data_augmentation.image_transformer import AugmentationType

from PIL import Image, ImageOps
import albumentations as A

from dataiku.doctor.deephub.utils.file_utils import img_array_to_base64

logger = logging.getLogger(__name__)


class ImageAugmentationProtocol(object):
    """
    Serve image-augmentation requests received over a JavaLink.

    Each request read from the link names one image inside a managed folder plus the augmentation settings to apply;
    the reply sent back on the link is a list with one entry per requested augmented version.
    """

    def __init__(self, link):
        """
        :type link: dataiku.base.socket_block_link.JavaLink
        """
        self.link = link

    @staticmethod
    def _load_rgb_image(folder, image_path):
        """
        Read an image from a managed folder and decode it into an RGB numpy array.

        The download stream and the PIL file handle are closed as soon as the pixels are decoded, so they are not
        held open while the augmentations run and the reply is sent.

        :param folder: managed folder containing the image
        :type folder: dataiku.core.managed_folder.Folder
        :param str image_path: path of the image inside the folder
        :return: tuple ``(img_array, img_format)``, where ``img_format`` is the format PIL detected on the original
                 file; it is used later to re-encode the augmented versions
        """
        with folder.get_download_stream(image_path) as img_file:
            with Image.open(img_file) as img:
                # Capture the format from the opened file before transposing: the transposed image is a derived
                # in-memory image and may not carry the source format.
                img_format = img.format
                # If the image has an EXIF Orientation tag, transpose it accordingly and remove the orientation data
                oriented = ImageOps.exif_transpose(img)
                # np.array(...) forces the full decode while the file is still open
                img_array = np.array(oriented.convert("RGB"))
        return img_array, img_format

    def _augment_image(self, params):
        """
        Augment an image and send back the result as a list of dictionaries, each corresponding to the java class
        `com.dataiku.dip.analysis.coreservices.DataAugmentationService.AugmentedImage`

        Each entry of the reply is either ``{"imgDataAsBase64": ...}`` or ``{"failed": True}`` when that particular
        augmentation raised. A failure on one version does not prevent the other versions from being produced.
        Errors outside the per-version augmentation step (unknown augmentation type, missing keys, unreadable image,
        encoding failure, ...) propagate to the caller.

        :param params: Params sent by the backend, dictionary corresponding to the java class
               `com.dataiku.dip.analysis.coreservices.DataAugmentationService.DataAugmentationKernelParams`
        """
        logger.info("Received this commands %s", params)
        folder = Folder(params["managedFolderId"])

        augmentation_type_str = params.get("augmentationType")
        augmentation_type = None if augmentation_type_str is None else AugmentationType[augmentation_type_str]

        transforms_list = build_augmentation_transforms_list(params["augmentationParams"], augmentation_type,
                                                             params.get("applyMaxTransform"))
        img_transformer = A.Compose(transforms_list)
        logger.info("Will apply this transform %s", img_transformer)

        img_array, img_format = self._load_rgb_image(folder, params["imagePath"])

        image_versions = []
        for _ in range(params["numAugmentedVersions"]):
            try:
                transformed = img_transformer(image=img_array)
            except Exception:
                # Best effort: report this version as failed and keep producing the remaining ones.
                # Catching Exception (not a bare except) lets KeyboardInterrupt / SystemExit stop the process.
                logger.exception("Failed to augment image")
                image_versions.append({"failed": True})
                continue
            image_versions.append({"imgDataAsBase64": img_array_to_base64(transformed["image"], img_format)})

        self.link.send_json(image_versions)

    def start(self):
        """
        Serve requests read from the link in a loop until an exception interrupts it.

        The loop has no normal exit: it stops only when ``read_json`` or request handling raises (presumably
        ``read_json`` raises when the backend closes the link — TODO confirm). An ``Exception`` is printed, logged
        and sent to the backend as ``{'error': ...}``. The link is closed in every case.
        """
        try:
            while True:
                params = self.link.read_json()
                self._augment_image(params)
        except Exception as e:
            traceback.print_exc()
            traceback.print_stack()
            logger.error(e)
            error = get_json_friendly_error()
            self.link.send_json({'error': error})
        finally:
            self.link.close()


def serve(port, secret, server_cert=None):
    """
    Connect a JavaLink to the backend and serve image-augmentation requests on it.

    This call blocks inside ``ImageAugmentationProtocol.start``, which loops until an exception interrupts it.

    :param port: port to connect the JavaLink to
    :param secret: secret used to authenticate the JavaLink
    :param server_cert: optional server certificate, forwarded unchanged to ``JavaLink``
    """
    java_link = JavaLink(port, secret, server_cert=server_cert)
    java_link.connect()
    ImageAugmentationProtocol(java_link).start()


if __name__ == "__main__":
    # Script entry point: configure logging, install the debugging handler, then serve on the JavaLink
    # described by the command-line arguments.
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s',
    )
    debugging.install_handler()

    # presumably lets the process exit when the parent closes stdin — TODO confirm
    watch_stdin()
    javalink_port, javalink_secret, javalink_server_cert = parse_javalink_args()
    serve(javalink_port, javalink_secret, server_cert=javalink_server_cert)
