from enum import Enum
from collections import OrderedDict

import albumentations as A

from dataiku.doctor.deephub.data_augmentation.crop import DkuRandomCrop


class AugmentationType(Enum):
    """
    Families of image augmentation that can be requested individually
    (used as the `augmentation_type` argument of `build_augmentation_transforms_list`).
    """
    AFFINE = "AFFINE"  # horizontal/vertical flips and rotation
    COLOR = "COLOR"  # color jitter
    CROP = "CROP"  # random crop


class ImageTransformsLists(object):
    """
    Simple holder of two transforms lists:
      * one containing the augmentation transforms (if any)
      * one without them

    Keeping both is needed when an augmentation may push the useful data out of the image
    (for instance, remove all the boxes for object detection).
    """

    def __init__(self, with_augmentation_transforms_list, without_augmentation_transforms_list):
        """
        Store both lists as-is (no copy is made).

        :param with_augmentation_transforms_list: transforms including the augmentation ones
        :type with_augmentation_transforms_list: list
        :param without_augmentation_transforms_list: transforms excluding the augmentation ones
        :type without_augmentation_transforms_list: list
        """
        self.without_augmentation_transforms_list = without_augmentation_transforms_list
        self.with_augmentation_transforms_list = with_augmentation_transforms_list


def color_augmentation_transform(augmentation_params, apply_max_transform=False):
    """
    Build the color augmentation transforms from the "colorJitter" entry of the augmentation params.

    :param augmentation_params: augmentation settings; only the "colorJitter" entry is read here
    :type augmentation_params: dict
    :param apply_max_transform: if True, the jitter is always applied (p=1) and each setting is pinned
                                to a degenerate (value, value) range instead of being passed as a scalar
    :type apply_max_transform: bool
    :return: a list with a single `A.ColorJitter`, or an empty list if color jitter is missing or disabled
    :rtype: list
    """
    jitter_params = augmentation_params.get("colorJitter", {})
    if not jitter_params.get("enabled", False):
        return []

    if apply_max_transform:
        # Brightness and contrast are expressed as factors around 1, hence the offset; hue is used as-is
        max_brightness = 1 + jitter_params["brightness"]
        max_hue = jitter_params["hue"]
        max_contrast = 1 + jitter_params["contrast"]
        jitter_kwargs = {
            "p": 1.,
            "brightness": (max_brightness, max_brightness),
            "hue": (max_hue, max_hue),
            "contrast": (max_contrast, max_contrast),
        }
    else:
        jitter_kwargs = {
            "p": jitter_params["probability"],
            "brightness": jitter_params["brightness"],
            "hue": jitter_params["hue"],
            "contrast": jitter_params["contrast"],
        }

    # NOTE(review): `saturation` is never passed, so A.ColorJitter falls back to its own default
    # for it - confirm this is intended.
    return [A.ColorJitter(**jitter_kwargs)]


def affine_augmentation_transform(augmentation_params, apply_max_transform=False):
    """
    Build the affine augmentation transforms (flips and rotation) from the "affine" entry
    of the augmentation params.

    Each of the "horizontalFlip", "verticalFlip" and "rotate" sub-entries is optional: a missing
    sub-entry, or one without an "enabled" flag, is treated as disabled.

    :param augmentation_params: augmentation settings; only the "affine" entry is read here
    :type augmentation_params: dict
    :param apply_max_transform: if True, every enabled transform is always applied (p=1) and the
                                rotation limit is pinned to the degenerate (maxRotation, maxRotation) range
    :type apply_max_transform: bool
    :return: the enabled transforms, in the order horizontal flip, vertical flip, rotation
             (empty list if none is enabled)
    :rtype: list
    """
    transforms_list = []
    affine_params = augmentation_params.get("affine", {})

    horizontal_flip_params = affine_params.get("horizontalFlip", {})
    if horizontal_flip_params.get("enabled", False):
        p = 1. if apply_max_transform else horizontal_flip_params["probability"]
        transforms_list.append(A.HorizontalFlip(p=p))

    vertical_flip_params = affine_params.get("verticalFlip", {})
    if vertical_flip_params.get("enabled", False):
        p = 1. if apply_max_transform else vertical_flip_params["probability"]
        transforms_list.append(A.VerticalFlip(p=p))

    rotate_params = affine_params.get("rotate", {})
    # Use the same tolerant lookup as the flips: indexing "enabled" directly raised a KeyError
    # whenever the "rotate" (or the whole "affine") entry was absent.
    if rotate_params.get("enabled", False):
        if apply_max_transform:
            p = 1
            limit = (rotate_params["maxRotation"], rotate_params["maxRotation"])
        else:
            p = rotate_params["probability"]
            limit = rotate_params["maxRotation"]

        transforms_list.append(A.Rotate(limit=limit, p=p))

    return transforms_list


def crop_augmentation_transform(augmentation_params, apply_max_transform=False):
    """
    Build the crop augmentation transforms from the "crop" entry of the augmentation params.

    :param augmentation_params: augmentation settings; only the "crop" entry is read here
    :type augmentation_params: dict
    :param apply_max_transform: if True, the crop is always applied (p=1) and the kept ratio is pinned
                                to "minKeptRatio" (min and max bounds are both set to it); otherwise the
                                kept ratio bounds are ("minKeptRatio", 1.)
    :type apply_max_transform: bool
    :return: a list with a single `DkuRandomCrop`, or an empty list if cropping is missing or disabled
    :rtype: list
    """
    crop_params = augmentation_params.get("crop", {})
    if not crop_params.get("enabled", False):
        return []

    if apply_max_transform:
        probability = 1
        lowest_kept_ratio = crop_params["minKeptRatio"]
        # Collapse the range onto its lower bound
        highest_kept_ratio = lowest_kept_ratio
    else:
        probability = crop_params["probability"]
        lowest_kept_ratio = crop_params["minKeptRatio"]
        highest_kept_ratio = 1.

    random_crop = DkuRandomCrop(min_kept_ratio=lowest_kept_ratio,
                                max_kept_ratio=highest_kept_ratio,
                                preserve_aspect_ratio=crop_params["preserveAspectRatio"],
                                p=probability)
    return [random_crop]


def build_augmentation_transforms_list(augmentation_params, augmentation_type=None, apply_max_transform=False):
    """
    Build the list of augmentation transforms for one augmentation type, or for all of them.

    :param augmentation_params: augmentation settings, or None when there is no augmentation
    :type augmentation_params: dict | None
    :param augmentation_type: the single type to build transforms for; None means all types,
                              in the order affine, color, crop
    :type augmentation_type: AugmentationType | None
    :param apply_max_transform: forwarded as-is to the per-type builders
    :type apply_max_transform: bool
    :return: the transforms (empty list when `augmentation_params` is None)
    :rtype: list
    """
    if augmentation_params is None:
        return []

    # Insertion order is the order in which the types are chained when all of them are requested
    builders_by_type = OrderedDict()
    builders_by_type[AugmentationType.AFFINE] = affine_augmentation_transform
    builders_by_type[AugmentationType.COLOR] = color_augmentation_transform
    builders_by_type[AugmentationType.CROP] = crop_augmentation_transform

    if augmentation_type is None:
        # All types requested: chain the output of every builder
        all_transforms = []
        for build in builders_by_type.values():
            all_transforms += build(augmentation_params, apply_max_transform=apply_max_transform)
        return all_transforms

    # Single type requested
    build = builders_by_type[augmentation_type]
    return build(augmentation_params, apply_max_transform=apply_max_transform)


def build_transforms_lists(augmentation_params=None, model_specific_transforms=None):
    """
    Build the transforms lists with and without augmentation.

    :param augmentation_params: augmentation settings, or None when there is no augmentation
    :type augmentation_params: dict | None
    :param model_specific_transforms: transforms appended at the end of both lists (skipped if None or empty)
    :type model_specific_transforms: list | None
    :return: holder of the list with augmentation transforms and of the list without them
    :rtype: ImageTransformsLists
    """
    augmented_transforms = build_augmentation_transforms_list(augmentation_params,
                                                              augmentation_type=None,
                                                              apply_max_transform=False)
    plain_transforms = []

    if model_specific_transforms:
        # Model transforms always come last, after any augmentation
        augmented_transforms += model_specific_transforms
        plain_transforms += model_specific_transforms

    return ImageTransformsLists(augmented_transforms, plain_transforms)


def build_image_classification_transforms_lists(model_image_input_size, model_normalization, augmentation_params=None):
    """
    Build the transforms lists for image classification: the optional augmentation transforms,
    followed by a resize and a normalization.

    :param model_image_input_size: keyword arguments forwarded to `A.Resize`
    :type model_image_input_size: dict
    :param model_normalization: keyword arguments forwarded to `A.Normalize`
    :type model_normalization: dict
    :param augmentation_params: augmentation settings, or None when there is no augmentation
    :type augmentation_params: dict | None
    :return: the result of `build_transforms_lists` for these transforms
    """
    resize = A.Resize(**model_image_input_size)
    normalize = A.Normalize(**model_normalization)

    return build_transforms_lists(augmentation_params, model_specific_transforms=[resize, normalize])
