from abc import ABCMeta
from abc import abstractmethod
import logging

import numpy as np
import pandas as pd
from six import add_metaclass

from dataiku.doctor.exploration.emu.feature_domains import FeatureType
from dataiku.doctor.exploration.emu.handlers import DistanceName
from dataiku.doctor.exploration.emu.handlers import FrozenFeatureHandler
from dataiku.doctor.exploration.emu.handlers import BaseNumericalFeatureHandler
from dataiku.doctor.exploration.emu.handlers import BaseCategoryFeatureHandler

# Module-level logger named after this module; used to report handler fitting decisions
logger = logging.getLogger(__name__)


@add_metaclass(ABCMeta)
class BaseGenerator(object):
    """
    Base class for generators. Works with CF explanations and outcome optim.

    Subclasses supply the per-type feature handlers and the search algorithm.
    This class fits one handler per feature domain, reuses already-fitted
    handlers across successive calls to `fit` when their domain is unchanged,
    and combines the handlers' per-feature distances into a point distance.
    """
    def __init__(self, model):
        """
        :param sklearn.base.BaseEstimator model: A trained model.
        """
        self.model = model
        # Created by `fit` through `_get_new_algorithm`
        self.algorithm = None
        # feature_domains used by the last call to `fit`; compared against on the
        # next call to decide which handlers can be reused
        self.feature_domains = None
        # DistanceName.GOWER or DistanceName.EUCLIDEAN, chosen by `fit`
        self.distance_name = None
        # feature name -> fitted handler
        self.handlers = {}

    def distance(self, x, xt):
        """
        Distance between the first row of the df `x` and the numpy array `xt`.

        Each fitted handler computes the distance for its own feature. The
        per-feature distances are combined as the L2 norm (euclidean) or as
        their mean (gower), depending on `self.distance_name` set by `fit`.

        :param pd.DataFrame x: df that contains one row representing a point
        :param np.ndarray xt: numpy array representing a point, with values in
            the same order as the columns of `x`
        :rtype: float
        :raises NotImplementedError: if `distance_name` is neither euclidean nor
            gower (e.g. `fit` has not been called yet)
        """
        # Give `xt` the same column labels as `x` so values can be looked up by name
        xt_df = pd.DataFrame(np.atleast_2d(xt), columns=x.columns)
        # Extract both rows once instead of on every loop iteration
        x_row = x.iloc[0]
        xt_row = xt_df.iloc[0]
        # NOTE(review): sized by the number of columns of `x`, not by the number of
        # handlers; slots without a handler stay at 0 (which lowers the gower mean).
        # Presumably `x` only contains handled features — confirm.
        feature_distances = np.zeros((x.shape[1],))
        for i, (feature_name, handler) in enumerate(self.handlers.items()):
            feature_distances[i] = handler.distance(x_row[feature_name], xt_row[feature_name])

        if self.distance_name == DistanceName.EUCLIDEAN:
            return np.sqrt(np.sum(feature_distances ** 2))
        elif self.distance_name == DistanceName.GOWER:
            return np.mean(feature_distances)
        else:
            raise NotImplementedError("Supported distances are euclidean or gower.")

    @abstractmethod
    def _get_new_numerical_handler(self, feature_domain):
        """
        :type feature_domain: NumericalFeatureDomain
        :rtype: BaseNumericalFeatureHandler
        """

    @abstractmethod
    def _get_new_categorical_handler(self, feature_domain):
        """
        :type feature_domain: CategoricalFeatureDomain
        :rtype: BaseCategoryFeatureHandler
        """

    @abstractmethod
    def _get_new_algorithm(self):
        """
        :return: the algorithm that will generate the points
        :rtype: BaseGrowingSphere or EvolutionaryStrategy
        """

    def fit(self, X, y, feature_domains):
        """
        Fit one handler per feature domain, then create a new algorithm.

        Handlers from a previous call are reused when their feature domain is
        unchanged, re-created and re-fitted when it has changed, and dropped
        when their feature is no longer present in `feature_domains`.

        :param X: training features, indexed by feature name (presumably a
            pd.DataFrame — each column's `.values` is passed to a handler)
        :param y: target values, passed as-is to each handler's `fit`
        :param feature_domains: iterable container of feature domains exposing
            `has_categorical_feature`, `has_feature` and item access by name
        :raises NotImplementedError: if a feature domain has an unknown type
        """
        # Gower handles mixed feature types; plain euclidean is used otherwise
        if feature_domains.has_categorical_feature():
            self.distance_name = DistanceName.GOWER
        else:
            self.distance_name = DistanceName.EUCLIDEAN

        # Init and fit handlers
        for feature_domain in feature_domains:
            feature_name = feature_domain.feature_name
            if self.feature_domains is not None and self.feature_domains.has_feature(feature_name):
                if self.feature_domains[feature_name].equals(feature_domain):
                    logger.info(u"Handler for '{}' already fitted, not re-fitting it".format(feature_name))
                    continue
                else:
                    logger.info(u"Parameters for '{}' have changed, refitting handler".format(feature_name))
            else:
                logger.info(u"Fitting new handler for feature '{}'".format(feature_name))

            if feature_domain.TYPE == FeatureType.NUMERICAL:
                handler = self._get_new_numerical_handler(feature_domain)
            elif feature_domain.TYPE == FeatureType.CATEGORICAL:
                handler = self._get_new_categorical_handler(feature_domain)
            elif feature_domain.TYPE == FeatureType.FROZEN:
                handler = FrozenFeatureHandler(feature_domain)
            else:
                raise NotImplementedError("Unknown feature type '{}'".format(feature_domain.TYPE))

            # Register the handler only once it has been fitted successfully
            handler.fit(X[feature_name].values, y)
            self.handlers[feature_name] = handler

        # Remove outdated handlers. Collect names first: deleting keys while
        # iterating the dict (even through a lazy `filter`) raises
        # "RuntimeError: dictionary changed size during iteration" on Python 3.
        outdated_feature_names = [name for name in self.handlers if not feature_domains.has_feature(name)]
        for feature_name in outdated_feature_names:
            logger.info(u"Feature '{}' is not in feature_domains anymore, removing existing handler".format(feature_name))
            del self.handlers[feature_name]

        self.feature_domains = feature_domains

        # Init algorithm
        self.algorithm = self._get_new_algorithm()
