from enum import Enum

from scipy import sparse
import scipy
import numpy as np
import logging
import pandas as pd
import sys
from collections import OrderedDict, Counter
from six.moves import xrange

from dataiku.base.utils import safe_exception, safe_unicode_str
from dataiku.doctor import utils
from dataiku.doctor.utils.pandascompat import to_numpy

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

def delete_rows_csr(mat, indices):
    """
    Build a new matrix made of the rows of the CSR sparse matrix ``mat`` that are
    not denoted by ``indices``.
    Taken from http://stackoverflow.com/questions/13077527

    :param mat: matrix to filter; must be a ``scipy.sparse.csr_matrix``
    :param indices: positions of the rows to remove (anything usable to index a 1-D Numpy array)
    :return: the matrix restricted to the rows that are not listed in ``indices``
    :raises ValueError: if ``mat`` is not a ``scipy.sparse.csr_matrix``
    """
    if isinstance(mat, scipy.sparse.csr_matrix):
        # Start from "keep every row", then switch off the rows to remove
        rows_kept = np.ones(mat.shape[0], dtype=bool)
        rows_kept[indices] = False
        return mat[rows_kept]
    raise ValueError("works only for CSR format -- use .tocsr() first")


def array_to_csr(array, unrecorded_value):
    """
    Convert a dense 2-D Numpy array into a Scipy CSR matrix.

    Every cell of ``array`` is recorded in the resulting matrix, except that NaN cells are
    left unrecorded when ``unrecorded_value`` is NaN.

    :param array: dense 2-D array to convert
    :param unrecorded_value: value represented by unrecorded entries (0. or NaN)
    :return: a ``scipy.sparse.csr_matrix`` with the same shape as ``array``
    """
    n_rows, n_cols = array.shape[0], array.shape[1]
    row = np.repeat(np.arange(n_rows), n_cols)  # (0,0,1,1,2,2,...) for example
    col = np.tile(np.arange(n_cols), n_rows)  # (0,1,0,1,0,1, ...) for example
    data = array.reshape(-1)
    if np.isnan(unrecorded_value):
        # NaN's in a dense Numpy array become unrecorded entries in its sparse matrix conversion.
        # This matches the XGBoost behaviour with sparse matrices, where unrecorded entries are treated as missing.
        # Scikit-learn and LightGBM however treat unrecorded entries in sparse matrices as 0's, so for these
        # two libraries, removing NaN's needs to be avoided for dense/sparse compatibility.
        recorded = ~np.isnan(data)
        row, col, data = row[recorded], col[recorded], data[recorded]
    return scipy.sparse.csr_matrix((data, (row, col)), array.shape)


def csr_to_array(csr_matrix, unrecorded_value):
    """
    Convert a Scipy sparse matrix into a dense Numpy array.

    :param csr_matrix: sparse matrix to convert
    :param unrecorded_value: value represented by unrecorded entries (0. or NaN)
    :return: a dense 2-D array with the same shape as ``csr_matrix``. When ``unrecorded_value`` is NaN,
             the array holds NaN everywhere except at the positions reported by ``csr_matrix.nonzero()``;
             otherwise it is the regular ``csr_matrix.toarray()`` result.
    """
    if not np.isnan(unrecorded_value):
        # Scikit-learn and LightGBM dense/sparse matrices compatibility
        return csr_matrix.toarray()

    # XGBoost dense/sparse matrices compatibility.
    # Vectorised scatter of the stored entries: looking the values up one by one with
    # ``csr_matrix[i, j]`` costs a sparse lookup per entry and does not scale.
    coo = csr_matrix.tocoo()
    nonzero_entries = coo.data != 0  # same selection as ``csr_matrix.nonzero()``
    rows = coo.row[nonzero_entries]
    cols = coo.col[nonzero_entries]
    values = coo.data[nonzero_entries]

    X_arr = np.full(csr_matrix.shape, np.nan)
    # Entries sharing the same position must be summed (as ``csr_matrix[i, j]`` does), so the
    # target cells are zeroed first and the values are then accumulated with an unbuffered add.
    X_arr[rows, cols] = 0.
    np.add.at(X_arr, (rows, cols), values)
    return X_arr


class NamedNPArray(object):
    """A Numpy array stored together with the names of its columns."""

    def __init__(self, array, names):
        """
        :param array: the wrapped array
        :param names: names of the columns of ``array``
        """
        self.array = array
        self.names = names

    @property
    def shape(self):
        """Shape of the wrapped array."""
        return self.array.shape

    def to_csr(self, unrecorded_value):
        """Convert the wrapped array to a CSR matrix (see ``array_to_csr``)."""
        return array_to_csr(self.array, unrecorded_value)

    def __repr__(self):
        # Only the dimensions are shown, not the names nor the content
        n_rows = self.array.shape[0]
        n_cols = self.array.shape[1]
        return "NamedNPArray(%s,%s)" % (n_rows, n_cols)


class SparseMatrixWithNames(object):
    """A sparse matrix stored together with the (optional) names of its columns."""

    def __init__(self, matrix, names):
        """
        :param matrix: the wrapped sparse matrix
        :param names: names of the columns of ``matrix``, or None
        :raises Exception: if ``names`` is given and its length differs from the number of columns
        """
        if names is not None:
            n_names = len(names)
            n_cols = matrix.shape[1]
            if n_names != n_cols:
                raise Exception("Invalid matrix: %s names and %s columns" % (n_names, n_cols))
        self.matrix = matrix
        self.names = names

    @property
    def shape(self):
        """Shape of the wrapped matrix."""
        return self.matrix.shape

    def to_csr(self, unrecorded_value):
        """Return the wrapped matrix as a CSR matrix; ``unrecorded_value`` is not used here."""
        return scipy.sparse.csr_matrix(self.matrix)

    def to_array(self, unrecorded_value):
        """Convert the wrapped matrix to a dense array (see ``csr_to_array``)."""
        return csr_to_array(self.matrix, unrecorded_value)

    def __repr__(self):
        # Only the dimensions are shown, not the names nor the content
        n_rows = self.matrix.shape[0]
        n_cols = self.matrix.shape[1]
        return "NamedSM(%s,%s)" % (n_rows, n_cols)


class DataFrameWrapper(object):
    """Thin wrapper around a Pandas DataFrame, exposing ``shape`` and ``to_csr``."""

    def __init__(self, df):
        """
        :param df: the wrapped DataFrame
        """
        self.df = df

    @property
    def shape(self):
        """Shape of the wrapped DataFrame."""
        return self.df.shape

    def to_csr(self, unrecorded_value):
        """Convert the wrapped DataFrame to a CSR matrix, through its Numpy representation."""
        dense_values = to_numpy(self.df)
        return array_to_csr(dense_values, unrecorded_value)

    def __repr__(self):
        # Only the dimensions are shown, not the content
        n_rows = self.df.shape[0]
        n_cols = self.df.shape[1]
        return "DF(%s,%s)" % (n_rows, n_cols)


class DropRowReason(Enum):
    """
    Reasons for which rows can be dropped from a MultiFrame.

    The members are the keys of ``MultiFrame.total_lifetime_rows_dropped_log``.
    """
    # Dropped rows are counted per column name for this reason (see MultiFrame.drop_rows)
    NULL_COLUMN_VALUE = "Null Column Value"
    # The reasons below are each tracked as a single running total
    CLUSTERING_OUTLIERS = "Clustering Outliers"
    NULL_INF_TARGET = "Null or Inf Target"
    NULL_OTHER = "Null Other"


class MultiFrame(object):
    """
    The multiframe agglomerates horizontally several blocks of columns. All blocks
    must have the same number of rows. Each block is named.

    Blocks can be:

    * Pandas DataFrames
    * Numpy arrays
    * Scipy sparse matrices

    The MultiFrame also gives a *single* dataframe builder that allows you to build a dataframe from several
    series.

    The only point of reference for the index in the MultiFrame is its own index. More specifically, all of the
    DataFrames blocks in a MultiFrame have a reset index, independently of the original input DataFrames.

    The unrecorded_value attribute (either 0. or NaN) ensures consistent conversion between dense (Pandas DataFrame and Numpy Array)
    and sparse (Scipy CSR matrix) representations of the data.
    For XGBoost, missing data (NaN) in a dense array and unrecorded data in a sparse matrix are equivalent.
    For Scikit-learn and LightGBM:
     - missing data (NaN) in a dense array cannot be handled as a sparse matrix
     - unrecorded data in a sparse matrix must be converted to 0 (regular Scipy Sparse csr_matrix.todense behaviour)
    """

    def __init__(self):
        """Create an empty MultiFrame: no block, no index, and 0. as unrecorded value."""
        # Names of the blocks, in insertion order
        self.block_orders = []
        # Every block by name, whatever its type
        self.blocks = {}
        # Per-kind registries of blocks, by name
        self.dataframes = {}
        self.arrays = {}
        self.sparses = {}
        # Block name -> whether the block is kept when iterating/exporting
        self.keep = {}
        # Pending DataFrameBuilder instances, by name
        self.df_builders = {}
        self.index = None

        # Rows dropped so far: counted per column name for NULL_COLUMN_VALUE,
        # as a single total for the other reasons
        dropped_rows_log = {DropRowReason.NULL_COLUMN_VALUE: Counter()}
        for totalled_reason in (DropRowReason.CLUSTERING_OUTLIERS,
                                DropRowReason.NULL_INF_TARGET,
                                DropRowReason.NULL_OTHER):
            dropped_rows_log[totalled_reason] = 0
        self.total_lifetime_rows_dropped_log = dropped_rows_log

        self.initial_size = 0
        self.unrecorded_value = 0.

    def __repr__(self):
        """Describe the MultiFrame: a header line, then name, class and content of each block."""
        parts = ["MultiFrame (%d blocks):\n" % (len(self.blocks))]
        for block_name in self.block_orders:
            block = self.blocks[block_name]
            parts.append("Block %s (%s)\n" % (block_name, block.__class__))
            parts.append("----------------------\n")
            parts.append("%s\n" % block)
        return "".join(parts)

    def set_unrecorded_value(self, unrecorded_value):
        """
        Set the value passed to the blocks when converting between dense and sparse representations.

        :param unrecorded_value: new unrecorded value (the class documentation mentions 0. or NaN)
        """
        self.unrecorded_value = unrecorded_value

    def stats(self):
        s = "MultiFrame (%d blocks):\n" % (len(self.blocks))
        for block_name in self.block_orders:
            block = self.blocks[block_name]
            shape = block.shape
            s += "Block %s (%s) -> (%s,%s)" % (block_name, block.__class__, shape[0], shape[1])
            s += "\n"
        return s

    def set_index_from_df(self, df):
        """
        Initialize the MultiFrame index (and its initial size) from the index of ``df``.

        Must only be called while the MultiFrame has no index yet.

        :param pd.DataFrame df: DataFrame whose index is copied
        """
        assert self.index is None
        # use index from dataframe to stay in sync with the target/weight series in case
        # a pipeline is used on a dataframe that already went through a pipeline (for ex
        # in case of scoring recipe or ensembles)
        mf_index = df.index.copy()
        self.index = mf_index
        self.initial_size = mf_index.size
        logger.info("Set MF index len %s", len(mf_index))

    def drop_rows(self, deletion_mask, reason, column_name=None):
        """
        Remove rows from the index and from every block of the MultiFrame, and record how
        many rows were dropped for ``reason``.

        Nothing happens (besides a warning) when the MultiFrame has no index.

        :param deletion_mask: passed to ``utils.series_nonzero`` to get the positions of the rows to drop
                              (presumably a boolean mask over the rows -- to be confirmed against that helper)
        :param DropRowReason reason: why the rows are dropped
        :param column_name: column under which the drop is counted when ``reason`` is
                            ``DropRowReason.NULL_COLUMN_VALUE``
        :raises Exception: if a block has an unknown type
        """
        # NOTE(review): the result is indexed with [0] below, so it looks like a tuple of
        # position arrays (as returned by np.nonzero) -- confirm against utils.series_nonzero
        rows_to_delete = utils.series_nonzero(deletion_mask)

        if self.index is None:
            logger.warning("No index in multiframe, aborting drop")
            return

        n_dropped = rows_to_delete[0].size
        dropped_log = self.total_lifetime_rows_dropped_log
        if reason == DropRowReason.NULL_COLUMN_VALUE:
            # Counter returns 0 for missing item
            dropped_log[reason][column_name] += n_dropped
        else:
            dropped_log[reason] += n_dropped

        logger.info("MultiFrame, dropping rows: %s" % rows_to_delete)
        self.index = pd.Series(self.index).drop(rows_to_delete[0]).values

        for name in self.block_orders:
            blk = self.blocks[name]
            if isinstance(blk, DataFrameWrapper):
                # DataFrame blocks always keep a reset index
                blk.df = blk.df.drop(blk.df.index[rows_to_delete])
                blk.df.reset_index(drop=True, inplace=True)
            elif isinstance(blk, SparseMatrixWithNames):
                blk.matrix = delete_rows_csr(blk.matrix, rows_to_delete)
            elif isinstance(blk, NamedNPArray):
                blk.array = np.delete(blk.array, rows_to_delete, axis=0)
            else:
                raise Exception("Unknown block")

    def append_df(self, name, df, keep=True, copy=False):
        """
        Register a Pandas DataFrame as a new block of the MultiFrame. The DataFrame held by
        the block has a reset index (the only point of reference for the index in the
        MultiFrame is its own index). Unless ``copy`` is set, the index of ``df`` itself
        is reset in place.

        If the MultiFrame has no index yet, a copy of the index of ``df`` becomes its index.

        :param str name: Block name
        :param pd.DataFrame df: DataFrame to append to the MultiFrame
        :param bool keep: Keep the resulting block when iterating/exporting the MultiFrame
                          (default: True)
        :param bool copy: Append a copy of the dataframe (default: False)
        :raises Exception: if the block name is already used, or if ``df`` does not have as
                           many rows as the MultiFrame index
        """
        self._check_not_in_blocks(name)
        if self.index is None:
            # use index from dataframe to stay in sync with the target/weight series in case
            # a pipeline is used on a dataframe that already went through a pipeline (for ex
            # in case of scoring recipe or ensembles)
            self.index = df.index.copy()
        n_rows = df.shape[0]
        if len(self.index) != n_rows:
            raise Exception("Unexpected number of rows, index has %s, df has %s" % (len(self.index), n_rows))

        if not copy:
            # Reset the index in place
            df.reset_index(drop=True, inplace=True)
            block_df = df
        else:
            # Append a copy of the dataframe with a reset index
            block_df = df.reset_index(drop=True)
        dfw = DataFrameWrapper(block_df)

        self.dataframes[name] = dfw
        self.blocks[name] = dfw
        self.keep[name] = keep
        self.block_orders.append(name)

    def append_np_block(self, name, array, col_names):
        """
        Register a Numpy array as a new block of the MultiFrame.

        If the MultiFrame has no index yet, its index becomes 0..n-1, n being the number of rows of ``array``.

        :param str name: Block name
        :param array: array to append to the MultiFrame
        :param col_names: names of the columns of ``array``
        :raises Exception: if the block name is already used, or if ``array`` does not have as
                           many rows as the MultiFrame index
        """
        self._check_not_in_blocks(name)
        n_rows = array.shape[0]
        if self.index is None:
            self.index = np.array(list(xrange(0, n_rows)))
        if len(self.index) != n_rows:
            raise Exception("Unexpected number of rows, index has %s, arra has %s" % (len(self.index), n_rows))

        block = NamedNPArray(array, col_names)
        self.arrays[name] = block
        self.blocks[name] = block
        self.block_orders.append(name)

    def append_sparse(self, name, matrix):
        self._check_not_in_blocks(name)
        if self.index is None:
            self.index = np.array([x for x in xrange(0, matrix.shape[0])])
        if len(self.index) != matrix.shape[0]:
            raise Exception(
                "Unexpected number of rows, index has %s, matrix has %s" % (len(self.index), matrix.shape[0]))

        self.sparses[name] = matrix
        self.blocks[name] = matrix
        self.block_orders.append(name)

    def _check_not_in_blocks(self, name):
        """
        Ensure that no block is registered under ``name`` yet.

        :param name: block name to check
        :raises: the exception built by ``safe_exception`` if ``name`` is already a block name
        """
        if name in self.blocks:
            raise safe_exception(Exception, u"Block {} already exists in multiframe".format(safe_unicode_str(name)))

    def get_block(self, name):
        """
        Return the block registered under ``name``.

        :param name: block name
        :raises KeyError: if there is no such block
        """
        return self.blocks[name]

    def iter_blocks(self, with_keep_info=False):
        """
        Iterate over all the blocks (kept or not), in insertion order.

        :param bool with_keep_info: also yield whether each block is kept (default: False)
        :return: generator of ``(name, block)`` tuples, or of ``(name, block, keep)`` tuples when
                 ``with_keep_info`` is set. Blocks without keep information count as kept.
        """
        for block_name in self.block_orders:
            block = self.blocks[block_name]
            if not with_keep_info:
                yield block_name, block
                continue
            is_kept = self.keep.get(block_name, True)
            yield block_name, block, is_kept

    def iter_dataframes(self):
        """Iterate over the ``(name, block)`` pairs of the blocks registered as dataframes."""
        named_wrappers = self.dataframes.items()
        for entry in named_wrappers:
            block_name, wrapper = entry
            yield block_name, wrapper

    def iter_columns(self):
        """
        Iterate over the columns of the kept blocks, block after block.

        :return: generator of ``(column name, column values)`` tuples. Values are a column slice of
                 the array/matrix for array and sparse blocks, and ``df[col]`` for DataFrame blocks.
        :raises Exception: if a kept block has an unknown type
        """
        for block_name, blk in self.iter_blocks():
            # This block is not kept, so don't iterate on it
            if not self.keep.get(block_name, True):
                continue

            if isinstance(blk, DataFrameWrapper):
                df = blk.df
                for col in df.columns:
                    yield col, df[col]
                continue

            # Array and sparse blocks are walked the same way: one column slice per name
            if isinstance(blk, NamedNPArray):
                values = blk.array
            elif isinstance(blk, SparseMatrixWithNames):
                values = blk.matrix
            else:
                raise Exception("Unknown block type %s" % blk.__class__)
            names = blk.names
            for position in xrange(len(names)):
                yield names[position], values[:, position]

    def col_as_series(self, block, col_name):
        """
        Return a single column of a block.

        :param block: name of the block
        :param col_name: name of the column within the block
        :return: a column slice of the array/matrix for named array and sparse blocks,
                 ``df[col_name]`` for DataFrame blocks, and None for array/sparse blocks
                 without names or for any other kind of block
        :raises KeyError: if there is no block named ``block``
        :raises ValueError: if ``col_name`` is not one of the names of an array/sparse block
        """
        blk = self.blocks[block]

        if isinstance(blk, NamedNPArray) and blk.names is not None:
            col_idx = blk.names.index(col_name)
            return blk.array[:, col_idx]
        elif isinstance(blk, SparseMatrixWithNames) and blk.names is not None:
            col_idx = blk.names.index(col_name)
            # Sparse blocks hold their values in ``matrix``: they have no ``array`` attribute
            return blk.matrix[:, col_idx]
        elif isinstance(blk, DataFrameWrapper):
            return blk.df[col_name]
        return None

    def as_csr_matrix(self):
        """
        Export the kept blocks, stacked horizontally, as a single Scipy CSR matrix.

        Kept blocks without any column are ignored. A MultiFrame without rows gives an
        empty CSR matrix of the MultiFrame shape.

        :raises Exception: if a block (kept or not) does not implement ``to_csr``
        """
        mf_shape = self.shape()
        if mf_shape[0] == 0:
            return scipy.sparse.csr_matrix(mf_shape)

        kept_blocks = []
        for name in self.block_orders:
            blk = self.blocks[name]
            if not hasattr(blk, "to_csr"):
                raise Exception("Block type %s doesn't implement to_csr" % blk.__class__)
            if not self.keep.get(name, True):
                continue
            logger.info("APPEND BLOCK %s shape=%s" % (name, blk.shape))
            if blk.shape[1] != 0:
                kept_blocks.append(blk)

        # we have to do this check because of a bug in scipy...
        if len(kept_blocks) == 1:
            return kept_blocks[0].to_csr(self.unrecorded_value)

        # if the blocks don't contain any `csr_matrix`, then `hstack` will fail with scipy > 1.3.
        # So, we convert individual blocks to `csr_matrix` before hstacking.
        csr_blocks = [block.to_csr(self.unrecorded_value) for block in kept_blocks]
        return scipy.sparse.hstack(csr_blocks).tocsr()

    def as_np_array(self):
        """
        Export the kept blocks, stacked horizontally with ``np.hstack``.

        :raises Exception: if a block (kept or not) has an unknown type
        """
        kept_values = []
        for name in self.block_orders:
            # Every block goes through the conversion, including the ones that are not kept
            block_values = MultiFrame.block_as_np_array(self.blocks[name], self.unrecorded_value)
            if self.keep.get(name, True):
                kept_values.append(block_values)
        return np.hstack(kept_values)

    @staticmethod
    def block_as_np_array(blk, unrecorded_value):
        """
        Return the dense content of a block.

        :param blk: block to convert
        :param unrecorded_value: value represented by the unrecorded entries of sparse blocks
        :return: the wrapped array for array blocks, the result of ``csr_to_array`` for sparse
                 blocks, and the wrapped DataFrame itself for DataFrame blocks
        :raises Exception: if the block has an unknown type
        """
        if isinstance(blk, NamedNPArray):
            return blk.array
        if isinstance(blk, SparseMatrixWithNames):
            return csr_to_array(blk.matrix, unrecorded_value)
        if isinstance(blk, DataFrameWrapper):
            return blk.df
        raise Exception("Unknown block type %s" % blk.__class__)

    def as_dataframe(self):
        """
        Export the kept blocks, concatenated horizontally, as a single Pandas DataFrame.

        Sparse blocks are densified with ``toarray()``, whatever the unrecorded value of the MultiFrame.

        :raises Exception: if a kept block has an unknown type
        """
        frames = []
        for name in self.block_orders:
            if not self.keep.get(name, True):
                continue
            blk = self.blocks[name]
            if isinstance(blk, DataFrameWrapper):
                block_frame = blk.df
            elif isinstance(blk, NamedNPArray):
                block_frame = pd.DataFrame(blk.array, columns=blk.names)
            elif isinstance(blk, SparseMatrixWithNames):
                block_frame = pd.DataFrame(blk.matrix.toarray(), columns=blk.names)
            else:
                raise Exception("Unknown block type %s" % blk.__class__)
            frames.append(block_frame)
        return pd.concat(frames, axis=1)

    def columns(self):
        """
        List the column names of the kept blocks, in block order.

        Sparse blocks without names get ``"<block name>:<column position>"`` names, which are
        stored on the block.

        :raises Exception: if a kept block has an unknown type
        """
        colnames = []
        for blkname in self.block_orders:
            if not self.keep.get(blkname, True):
                continue
            blk = self.blocks[blkname]
            if isinstance(blk, DataFrameWrapper):
                block_names = blk.df.columns
            elif isinstance(blk, NamedNPArray):
                block_names = blk.names
            elif isinstance(blk, SparseMatrixWithNames):
                if blk.names is None:
                    # Name the columns after the block and their position
                    blk.names = ["%s:%s" % (blkname, x) for x in xrange(0, blk.matrix.shape[1])]
                block_names = blk.names
            else:
                raise Exception("Unknown block type %s" % blk.__class__)
            colnames.extend(block_names)
        return colnames

    def nnz(self):
        """
        Count the entries held by the kept blocks: every cell of array and DataFrame blocks,
        and the ``nnz`` of the matrix of sparse blocks. Blocks of any other type are ignored.
        """
        total = 0
        for blkname in self.block_orders:
            if not self.keep.get(blkname, True):
                continue
            blk = self.blocks[blkname]
            if isinstance(blk, SparseMatrixWithNames):
                total += blk.matrix.nnz
                continue
            if isinstance(blk, NamedNPArray):
                dense_shape = blk.array.shape
            elif isinstance(blk, DataFrameWrapper):
                dense_shape = blk.df.shape
            else:
                continue
            total += dense_shape[0] * dense_shape[1]
        return total

    def shape(self):
        """
        Compute the shape of the kept blocks put side by side.

        :return: ``(nrows, ncols)``: the number of rows of the last kept block of a known type
                 (0 if there is none) and the total number of columns of these blocks
        """
        nrows = 0
        ncols = 0
        for blkname in self.block_orders:
            if not self.keep.get(blkname, True):
                continue
            blk = self.blocks[blkname]
            if isinstance(blk, NamedNPArray):
                block_shape = blk.array.shape
            elif isinstance(blk, SparseMatrixWithNames):
                block_shape = blk.matrix.shape
            elif isinstance(blk, DataFrameWrapper):
                block_shape = blk.df.shape
            else:
                # Blocks of any other type do not contribute to the shape
                continue
            nrows = block_shape[0]
            ncols += block_shape[1]
        return nrows, ncols

    def select_columns(self, names):
        """
        Restrict, in place, every block (kept or not) to the columns whose name is in ``names``.

        Array and sparse blocks keep their column order. Sparse blocks without names first get
        ``"<block name>:<column position>"`` names.

        :param names: names of the columns to keep
        :raises Exception: if a block has an unknown type
        """
        names = set(names)
        for blkname in self.block_orders:
            blk = self.blocks[blkname]

            if isinstance(blk, NamedNPArray):
                selected = [(idx, name) for idx, name in enumerate(blk.names) if name in names]
                blk.names = [name for _, name in selected]
                blk.array = blk.array[:, [idx for idx, _ in selected]]

            elif isinstance(blk, SparseMatrixWithNames):
                if blk.names is None:
                    blk.names = ["%s:%s" % (blkname, x) for x in xrange(0, blk.matrix.shape[1])]

                selected = [(idx, name) for idx, name in enumerate(blk.names) if name in names]
                blk.names = [name for _, name in selected]
                blk.matrix = blk.matrix[:, [idx for idx, _ in selected]]

            elif isinstance(blk, DataFrameWrapper):
                cols_kept_list = list(set(blk.df.columns) & names)

                # We observe that `set` does not ensure a deterministic order when converted into a list in py3.
                # Therefore we ensure that this method produces dataframes with the same column order at each run
                # by ordering them lexicographically
                # We do it only for py3 in order not to modify the behaviour of existing and working py2 models.
                if sys.version_info > (3, 0):
                    cols_kept_list = sorted(cols_kept_list)

                blk.df = blk.df[cols_kept_list]

            else:
                raise Exception("Unknown block type %s" % blk.__class__)

    def get_df_builder(self, name):
        """
        Helper for building a dataframe from series: return the DataFrameBuilder registered
        under ``name``, creating it (with ``name`` as prefix) on first access.
        """
        try:
            return self.df_builders[name]
        except KeyError:
            builder = DataFrameBuilder(name)
            self.df_builders[name] = builder
            return builder

    def has_df_builder(self, name):
        """Tell whether a DataFrameBuilder is currently registered under ``name``."""
        return name in self.df_builders

    def flush_df_builder(self, name):
        """
        Turn the DataFrameBuilder registered under ``name`` into a DataFrame block of the
        same name, then forget the builder.

        :raises KeyError: if no builder is registered under ``name``
        """
        builder = self.df_builders[name]
        built_df = builder.to_dataframe()
        self.append_df(name, built_df)
        # The builder is only forgotten once its block has been appended
        del self.df_builders[name]


def is_series_like(series):
    """
    Tell whether ``series`` can be used as a column of a DataFrameBuilder.

    :param series: object to check
    :return: True for a Pandas Series, a Numpy array or a Scipy CSR matrix, False otherwise
    """
    # ``csr_matrix`` is taken from the public ``scipy.sparse`` namespace, like in the rest of
    # this module: ``scipy.sparse.csr`` is a deprecated private module.
    return isinstance(series, (pd.Series, np.ndarray, scipy.sparse.csr_matrix))


class DataFrameBuilder(object):
    """ Collects columns one by one, then creates a
    dataframe out of them, in the order in which they
    were added.
    """

    __slots__ = ('prefix', 'columns',)

    def __init__(self, prefix=""):
        """ constructor

        prefix -- Prefixes the name of the column of
                  the resulting dataframe.
        """
        # TODO put the prefix in to_dataframe
        self.columns = OrderedDict()
        self.prefix = prefix

    def add_column(self, column_name, column_values):
        """ Register a new column.

        column_name -- Name of the column; must not have been added already.
        column_values -- Values of the column; must be series-like
                         (see is_series_like).
        """
        assert column_name not in self.columns
        assert is_series_like(column_values)
        self.columns[column_name] = column_values

    def to_dataframe(self, ):
        """ Create the dataframe. Columns are named "<prefix>:<column name>",
        or just "<prefix>" for a column whose name is None.
        """
        df = pd.DataFrame(self.columns)
        prefixed_names = []
        for col in df.columns:
            if col is None:
                prefixed_names.append(self.prefix)
            else:
                prefixed_names.append(self.prefix + ":" + col)
        df.columns = prefixed_names
        return df
