import numpy as np
import logging
from enum import Enum

from dataiku.doctor.multiframe import SparseMatrixWithNames, NamedNPArray, DataFrameWrapper

# Module-level logger named after this module, used to report which matrix format is chosen and why
logger = logging.getLogger(__name__)


def prepare_multiframe_with_sparse_support(multiframe, algorithm_sparse_support):
    """ Convert a MultiFrame into either a dense ndarray or a CSR sparse matrix

    The dense/sparse decision is delegated to ``algorithm_sparse_support.should_use_csr``.

    :param multiframe: the MultiFrame to convert
    :param algorithm_sparse_support: decides, for a given algorithm and dataset, whether a sparse matrix should be used
    :type algorithm_sparse_support: AlgorithmSparseSupport
    :return: the converted array, and whether that array is sparse
    :rtype: (np.ndarray | scipy.sparse.csr_matrix, bool)
    """
    use_csr = algorithm_sparse_support.should_use_csr(multiframe)
    converted = prepare_multiframe_as_sparse_if_needed(multiframe, use_csr)
    return converted, use_csr

def prepare_multiframe_as_sparse_if_needed(multiframe, as_sparse):
    """ Convert the MultiFrame to a CSR matrix when ``as_sparse`` is truthy, otherwise to a dense ndarray

    :param multiframe: the MultiFrame to convert
    :param as_sparse: whether a sparse matrix should be produced
    :return: the result of ``multiframe.as_csr_matrix()`` or ``multiframe.as_np_array()``
    """
    # Only the selected conversion method is looked up and called
    converter = multiframe.as_csr_matrix if as_sparse else multiframe.as_np_array
    return converter()


class CSRSupport(Enum):
    """ Level of sparse (CSR) matrix support of an algorithm for a given set of modeling params """
    UNSUPPORTED = 0  # the algorithm cannot use CSR matrices at all
    NON_SETTABLE = 1  # CSR can be used, with no user setting to control it
    REQUESTED = 2  # CSR is user-settable, and the user allowed it
    DISABLED = 3  # CSR is user-settable, and the user did not allow it


class AlgorithmSparseSupport(object):
    """ Decides whether an algorithm should be trained on a CSR sparse matrix or on a dense ndarray

    The decision combines:
    - the algorithm's CSR capability (the class-level containers below);
    - the ``allow_sparse_matrices`` flag of its modeling params;
    - the number of cells of the MultiFrame;
    - the fill ratio of the MultiFrame.
    """

    # Algorithms whose CSR usage is controlled by ``allow_sparse_matrices`` in their modeling params.
    # Empty here; presumably filled by subclasses - TODO confirm.
    ALGORITHMS_WITH_SETTABLE_CSR_SUPPORT = {}
    # Algorithms that can use CSR matrices without any user setting; only the size and sparsity
    # heuristics apply to them. Empty here; presumably filled by subclasses - TODO confirm.
    ALGORITHMS_WITH_NON_SETTABLE_CSR_SUPPORT = {}
    GRID_NAMES = {}  # mapping from algorithm to its modeling_params key
    MAX_FILL_RATIO = 0.5  # only use sparse matrices if data is really sparse
    MIN_CELL_COUNT = 50 * 1000 * 1000  # only use sparse matrices if the dataframe has a lot of cells

    def __init__(self, modeling_params):
        """
        :param dict modeling_params: modeling params; must contain the "algorithm" key
        """
        # Local import, presumably to avoid an import cycle with dataiku.doctor.prediction - TODO confirm
        from dataiku.doctor.prediction.common import PredictionAlgorithmNaNSupport
        self.modeling_params = modeling_params
        self.algorithm = self.modeling_params["algorithm"]
        self.nan_support = PredictionAlgorithmNaNSupport(modeling_params)

    def _algorithm_supports_csr(self):
        """ Returns whether this algorithm can use sparse matrices, either at all or with the specific modeling params

        :return: one of CSRSupport enum values
        """
        if self.algorithm in self.ALGORITHMS_WITH_NON_SETTABLE_CSR_SUPPORT:
            return CSRSupport.NON_SETTABLE
        elif self.algorithm in self.ALGORITHMS_WITH_SETTABLE_CSR_SUPPORT:
            if self.should_allow_sparse_matrices():
                return CSRSupport.REQUESTED
            else:
                return CSRSupport.DISABLED
        return CSRSupport.UNSUPPORTED

    def should_allow_sparse_matrices(self):
        """ Returns the ``allow_sparse_matrices`` flag of this algorithm's grid params (False when absent)

        :raises KeyError: if the algorithm has no GRID_NAMES entry, or its grid key is missing from the modeling params
        """
        return self.modeling_params[self.GRID_NAMES[self.algorithm]].get('allow_sparse_matrices', False)

    @staticmethod
    def _count_dense_entries(values, nan_is_unrecorded):
        """ Count the entries of a dense block that would be stored ("recorded") in a CSR matrix

        :param values: dense numeric values of the block (numpy array)
        :param nan_is_unrecorded: True if NaN is the value left out of the CSR matrix, False if it is 0
        :return: (number of recorded entries, number of NaN entries)
        """
        n_nan = np.sum(np.isnan(values))
        if nan_is_unrecorded:
            return values.size - n_nan, n_nan
        return values.size - np.sum(values == 0.), n_nan

    def _is_dataframe_sparse_enough(self, train_X):
        """ Returns whether the MultiFrame's fill ratio is low enough to be worth storing as a CSR matrix

        An entry is "recorded" when it would be stored in the CSR matrix. That is any entry other than
        NaN if the algorithm's unrecorded value is NaN, and any entry other than 0 otherwise. Blocks
        explicitly marked as not kept are ignored.

        :param train_X: the MultiFrame to inspect
        :return: False if the MultiFrame has no rows or no columns, or if it contains NaN while the
                 algorithm's unrecorded value is not NaN. Otherwise, whether the fill ratio is at most
                 MAX_FILL_RATIO.
        """
        (nrows, ncols) = train_X.shape()
        if nrows == 0 or ncols == 0:
            # Algos expecting sparse matrices can handle empty arrays all fine too.
            # Without columns the fill ratio is undefined (it would divide by zero).
            return False
        total_recorded_entries = 0
        for name, block in train_X.blocks.items():
            # Ignore blocks that won't be selected in the feature matrix
            if name in train_X.keep and not train_X.keep[name]:
                continue
            if isinstance(block, SparseMatrixWithNames):
                # Only the explicitly stored entries of an already-sparse block are recorded
                total_recorded_entries += len(block.matrix.data)
                continue
            if isinstance(block, NamedNPArray):
                values = block.array
            elif isinstance(block, DataFrameWrapper):
                values = block.df.values
            else:
                # Other block types do not contribute to the recorded entries
                continue
            nan_is_unrecorded = np.isnan(self.nan_support.unrecorded_value)
            recorded, n_nan = self._count_dense_entries(values, nan_is_unrecorded)
            total_recorded_entries += recorded
            if n_nan > 0 and not nan_is_unrecorded:
                logger.info("Multiframe contains NaN but algorithm does not accept CSR matrix with NaN")
                return False
        # float() avoids integer (floor) division on Python 2, which would always yield 0
        fill_ratio = float(total_recorded_entries) / (nrows * ncols)
        logger.info("prepare multiframe shape=(%s,%s) fill_ratio=%.2f", nrows, ncols, fill_ratio)
        return fill_ratio <= self.MAX_FILL_RATIO

    @classmethod
    def _is_dataframe_large_enough(cls, multiframe):
        """ Returns whether the MultiFrame has strictly more than MIN_CELL_COUNT cells """
        (n_rows, n_cols) = multiframe.shape()
        n_cells = n_rows * n_cols
        return n_cells > cls.MIN_CELL_COUNT

    def should_use_csr(self, multiframe):
        """ Returns whether the MultiFrame should be converted to a CSR matrix for this algorithm

        - Algorithms without CSR support, or with CSR explicitly disabled, never use it.
        - Algorithms with non-settable support only use it for MultiFrames with more than MIN_CELL_COUNT cells.
        - When the user explicitly allowed CSR, there is no size restriction.
        In all remaining cases the MultiFrame must also be sparse enough.

        :param multiframe: the MultiFrame to be converted
        :rtype: bool
        """
        csr_support = self._algorithm_supports_csr()
        if csr_support == CSRSupport.UNSUPPORTED:
            logger.info(u"Algorithm %s doesn't support sparse matrices, using NPA", self.algorithm)
            return False
        if csr_support == CSRSupport.DISABLED:
            logger.info("Sparse matrices support is explicitly disabled, using NPA")
            return False
        if csr_support == CSRSupport.NON_SETTABLE and not self._is_dataframe_large_enough(multiframe):
            logger.info("Multiframe is not large enough, using NPA")
            return False
        # CSRSupport.REQUESTED: when the user allows sparse matrices explicitly, no size restriction
        # https://app.shortcut.com/dataiku/story/156650

        if not self._is_dataframe_sparse_enough(multiframe):
            logger.info("Multiframe is not sparse enough, using NPA")
            return False

        logger.info("Using CSR (sparse matrix)")
        return True
