from abc import ABCMeta
from abc import abstractmethod
from collections import OrderedDict
from enum import Enum

import numpy as np
from six import add_metaclass


class FeatureType(Enum):
    """Kind of constraint a FeatureDomain subclass represents (exposed as its ``TYPE`` attribute)."""
    FROZEN = "FROZEN"  # used by FrozenFeatureDomain
    NUMERICAL = "NUMERICAL"  # used by NumericalFeatureDomain
    CATEGORICAL = "CATEGORICAL"  # used by CategoricalFeatureDomain


@add_metaclass(ABCMeta)
class FeatureDomain(object):
    """
    Abstract base for a constraint on the values a single feature may take.

    ``six.add_metaclass`` applies ``ABCMeta`` in a way that works on both Python 2 and 3, so
    subclasses cannot be instantiated until they implement every abstract method.
    """

    # Overridden by each concrete subclass with its FeatureType member.
    TYPE = None

    def __init__(self, feature_name):
        """
        :param feature_name: name or index of the feature this domain constrains
        :type feature_name: str or int
        """
        self.feature_name = feature_name

    @abstractmethod
    def equals(self, feature_domain):
        """
        Checks constraints equality with other feature_domain (ignores feature_name).

        :type feature_domain: FeatureDomain
        :rtype: bool
        """

    @abstractmethod
    def check_validity(self, x):
        """
        Checks if some values respect the constraints.

        NOTE(review): in this file, the numerical and categorical implementations return the
        scalar ``True`` (not an array) when the domain has no constraint.

        :type x: np.ndarray
        :rtype: np.ndarray
        """


class FrozenFeatureDomain(FeatureDomain):
    """Domain of a feature that is pinned to one single reference value."""

    TYPE = FeatureType.FROZEN

    def __init__(self, feature_name, reference):
        """
        :param feature_name: name or index of the constrained feature
        :type feature_name: str or int
        :param reference: the only value the feature is allowed to take
        :type reference: str or float or int
        """
        super(FrozenFeatureDomain, self).__init__(feature_name)
        self.reference = reference

    def equals(self, feature_domain):
        """
        Tell whether ``feature_domain`` is also frozen, and frozen to the same reference value.

        :param FrozenFeatureDomain feature_domain: domain to compare against
        :rtype: bool
        """
        same_kind = isinstance(feature_domain, FrozenFeatureDomain)
        return same_kind and feature_domain.reference == self.reference

    def check_validity(self, x):
        """
        Compare the given values against the reference value.

        :type x: np.ndarray
        :return: result of ``reference == x`` (element-wise for numpy arrays)
        :rtype: np.ndarray
        """
        matches_reference = self.reference == x
        return matches_reference

    def __str__(self):
        template = "FrozenFeatureDomain(feature_name={}, reference={})"
        return template.format(self.feature_name, self.reference)


class NumericalFeatureDomain(FeatureDomain):
    """Domain of a numerical feature, optionally bounded below and/or above (bounds inclusive)."""

    TYPE = FeatureType.NUMERICAL

    def __init__(self, feature_name, min_value=None, max_value=None, is_integer=False, n_digits='auto'):
        """
        :type feature_name: str or int
        :param min_value: inclusive lower bound, or None for no lower bound
        :type min_value: float or int or None
        :param max_value: inclusive upper bound, or None for no upper bound
        :type max_value: float or int or None
        :param bool is_integer: whether the feature is integer-valued; when True, n_digits is forced to 0
        :param int or "auto" or None n_digits: Precision of generated float values.
            Higher precision = potentially similar counterfactuals
            - int: number of digits
            - None: as many digits as possible
            - 'auto': use a heuristic to find n_digits
        :raises ValueError: if both bounds are given and min_value > max_value, or if n_digits is not
            'auto', None or an integer
        """
        super(NumericalFeatureDomain, self).__init__(feature_name)
        self.min_value = min_value
        self.max_value = max_value
        self.is_integer = is_integer
        # Integer features never carry decimal digits, whatever n_digits the caller passed.
        self.n_digits = 0 if is_integer else n_digits
        # Only compare when BOTH bounds are set: with only max_value given, `None > max_value`
        # would raise TypeError on Python 3 instead of accepting the one-sided domain.
        if min_value is not None and max_value is not None and min_value > max_value:
            raise ValueError("First bound must be less than second bound")
        if self.n_digits != 'auto' and self.n_digits is not None and not isinstance(self.n_digits, int):
            raise ValueError("Input n_digits must be 'auto', None or an integer value.")

    def equals(self, feature_domain):
        """
        Compare bounds, integrality and precision with another numerical domain (ignores feature_name).

        :param NumericalFeatureDomain feature_domain: compares the min and max values
        :rtype: bool
        """
        if not isinstance(feature_domain, NumericalFeatureDomain):
            return False
        return (self.min_value == feature_domain.min_value
                and self.max_value == feature_domain.max_value
                and self.is_integer == feature_domain.is_integer
                and self.n_digits == feature_domain.n_digits)

    def check_validity(self, x):
        """
        Checks which values of ``x`` lie within the inclusive [min_value, max_value] bounds.

        Integrality (``is_integer``) is not checked here.

        :type x: np.ndarray
        :return: element-wise boolean mask, or the scalar ``True`` when neither bound is set
        :rtype: np.ndarray or bool
        """
        if self.min_value is None and self.max_value is None:
            return True  # no constraint, all are valid
        elif self.min_value is not None and self.max_value is None:
            return x >= self.min_value
        elif self.min_value is None and self.max_value is not None:
            return x <= self.max_value
        else:
            return (x <= self.max_value) & (x >= self.min_value)

    def __str__(self):
        return "NumericalFeatureDomain(feature_name={}, is_integer={}, min_value={}, max_value={})"\
            .format(self.feature_name, self.is_integer, self.min_value, self.max_value)


class CategoricalFeatureDomain(FeatureDomain):
    """Domain of a categorical feature, optionally restricted to a collection of allowed categories."""

    TYPE = FeatureType.CATEGORICAL

    def __init__(self, feature_name, categories=None):
        """
        :type feature_name: str or int
        :param categories: allowed categories, or None to allow any value
        :type categories: set or None
        """
        super(CategoricalFeatureDomain, self).__init__(feature_name)
        self.categories = categories

    def equals(self, feature_domain):
        """
        Compare the allowed categories with another categorical domain (ignores feature_name).

        NOTE(review): the stored containers are compared with ``==``, so a list and a set holding the
        same categories are NOT considered equal — pass sets consistently.

        :param CategoricalFeatureDomain feature_domain: compares the enabled categories
        :rtype: bool
        """
        if not isinstance(feature_domain, CategoricalFeatureDomain):
            return False
        return self.categories == feature_domain.categories

    def check_validity(self, x):
        """
        Checks which values of ``x`` belong to the allowed categories.

        :type x: np.ndarray
        :return: element-wise boolean mask, or the scalar ``True`` when no categories are set
        :rtype: np.ndarray or bool
        """
        if self.categories is None:
            return True  # no constraint, all are valid
        else:
            # NB: converting to `list` because `set` would be converted to an object
            #     array with one element, rather than an array of the values.
            return np.isin(x, list(self.categories))

    def __str__(self):
        # NB: set() will be serialized as "set([a, b, ...])" in python 2, and as "{a, b, ...}" in python 3
        return "CategoricalFeatureDomain(feature_name={}, categories={})".format(self.feature_name, self.categories)


class FeatureDomains(object):
    """
    Insertion-ordered collection of FeatureDomain constraints, keyed by feature name.

    Only constrained features are stored; ``num_features`` records the total number of features in
    the data, independently of how many domains have been registered.
    """

    def __init__(self, num_features):
        """
        :param num_features: number of features in data (regardless of the constraints)
        :type num_features: int
        """
        self.num_features = num_features
        # Ordered so iteration and __str__ follow the order in which domains were appended.
        self._feature_domains_map = OrderedDict()

    def append(self, feature_domain):
        """
        Register ``feature_domain`` under its feature_name, replacing any domain already stored for that name.

        :type feature_domain: FeatureDomain
        :raises TypeError: if ``feature_domain`` is not a FeatureDomain
        """
        if not isinstance(feature_domain, FeatureDomain):
            raise TypeError("feature_domain should have type: 'FeatureDomain'")
        self._feature_domains_map[feature_domain.feature_name] = feature_domain

    def has_categorical_feature(self):
        """
        :return: True if at least one registered domain is categorical
        :rtype: bool
        """
        return any(domain.TYPE == FeatureType.CATEGORICAL for domain in self._feature_domains_map.values())

    def __len__(self):
        """Total number of features in the data (``num_features``), NOT the number of registered domains."""
        return self.num_features

    def __iter__(self):
        """Iterate over the registered domains in insertion order."""
        return iter(self._feature_domains_map.values())

    def __getitem__(self, feature_name):
        """
        :return: the domain registered for ``feature_name``
        :rtype: FeatureDomain
        :raises ValueError: if no domain is registered for ``feature_name``
        """
        if not self.has_feature(feature_name):
            # Wrap in a 1-tuple: a tuple feature name would otherwise be unpacked by `%`
            # and raise TypeError instead of the intended ValueError.
            raise ValueError("Unknown feature '%s'" % (feature_name,))
        return self._feature_domains_map[feature_name]

    def __str__(self):
        return "FeatureDomains({})".format(", ".join(map(str, self._feature_domains_map.values())))

    def has_feature(self, feature_name):
        """
        :return: True if a domain is registered for ``feature_name``
        :rtype: bool
        """
        return feature_name in self._feature_domains_map
