from abc import ABCMeta
import logging

from six import add_metaclass

from dataiku.doctor.exploration.emu.algorithms import ActiveSphere
from dataiku.doctor.exploration.emu.generators import BaseGenerator
from dataiku.doctor.exploration.emu.handlers import NumericalDistributionHandler
from dataiku.doctor.exploration.emu.handlers import CategoryTargetEncodingHandler

logger = logging.getLogger(__name__)


@add_metaclass(ABCMeta)
class BaseCFGenerator(BaseGenerator):
    """
    Abstract base for counterfactual generators that rely on Feature Sampling.

    Concrete subclasses choose the algorithm and the per-feature handlers; this
    base only records the desired target class and the clustering flag, and
    wraps the algorithm's output with warnings when too few counterfactuals
    come back.
    """

    def __init__(self, model, target=None, with_clustering=True):
        """
        Initialise the generator.

        :param sklearn.base.BaseEstimator model: a trained model, forwarded to
            the parent ``BaseGenerator`` constructor
        :param None or int target: class the counterfactuals should be
            predicted as
        :param bool with_clustering: whether the clustering step is activated
        """
        super(BaseCFGenerator, self).__init__(model)
        self.target = target
        self.with_clustering = with_clustering

    def generate_counterfactuals(self, reference, n=10):
        """
        Produce up to ``n`` counterfactual explanations for ``reference``.

        Delegates to ``self.algorithm.generate_counterfactuals`` and logs a
        warning when the result is empty or holds fewer than ``n`` rows.

        :param pd.DataFrame reference: the point to explain (single-row df)
        :param int n: upper bound on the number of counterfactuals returned
        :return: the algorithm's result, unchanged (an object exposing
            ``shape`` -- presumably a pd.DataFrame; TODO confirm)
        """
        results = self.algorithm.generate_counterfactuals(reference, n)
        found = results.shape[0]

        # Emit at most one diagnostic: an empty result takes precedence over
        # a merely short one (and is reported even when n == 0).
        warning = None
        if found == 0:
            warning = "No counterfactual has been found. Try relaxing the constraints."
        elif found < n:
            warning = "Could not find the required number of counterfactuals. Try relaxing the constraints."

        if warning is not None:
            logger.warning(warning)
        return results


class ActiveSphereCFGenerator(BaseCFGenerator):
    """
    Counterfactual explanation generator, based on Active-Spheres.

    Combines the ``ActiveSphere`` algorithm with a
    ``CategoryTargetEncodingHandler`` for each categorical feature and a
    ``NumericalDistributionHandler`` for each numerical feature.
    """
    def _get_new_algorithm(self):
        """
        Build the algorithm that will generate the counterfactual points.

        ``feature_domains``, ``handlers``, ``model``, ``distance``, ``target``
        and ``with_clustering`` are read from the instance; all but the last
        two are presumably initialised by ``BaseGenerator`` (not visible here).

        :return: the algorithm that will generate the points
        :rtype: ActiveSphere
        """
        return ActiveSphere(feature_domains=self.feature_domains,
                            handlers=self.handlers,
                            model=self.model,
                            measure_distance=self.distance,
                            target=self.target,
                            with_clustering=self.with_clustering)

    def _get_new_categorical_handler(self, feature_domain):
        """
        Build the handler for one categorical feature.

        :param CategoricalFeatureDomain feature_domain: domain of the feature
        :return: a handler configured with this generator's ``target``
        :rtype: CategoryTargetEncodingHandler
        """
        return CategoryTargetEncodingHandler(feature_domain=feature_domain,
                                             target=self.target)

    def _get_new_numerical_handler(self, feature_domain):
        """
        Build the handler for one numerical feature.

        :param NumericalFeatureDomain feature_domain: domain of the feature
        :return: a handler configured with this generator's ``distance_name``
            (presumably set by ``BaseGenerator`` -- TODO confirm)
        :rtype: NumericalDistributionHandler
        """
        return NumericalDistributionHandler(feature_domain=feature_domain,
                                            distance_name=self.distance_name)
