import logging

import numpy as np
import pandas as pd

from dataiku.base.utils import safe_unicode_str
from dataiku.doctor.prediction.explanations.engine import SimpleExplanationResult
from dataiku.doctor.prediction.explanations.engine import ExplainingEngine

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class ICEExplainingEngine(ExplainingEngine):
    """ Explaining engine producing one ICE explanation per (row, column to compute).

    For each row to explain and each column to compute, the row is duplicated once per modality of the column,
    the column value being replaced by that modality. The duplicated rows (the "frankenstein") are scored, and
    the explanation of the column is the score of the original row minus the sum of the frankenstein scores
    weighted by the column distribution (NaN scores being ignored in the sum).
    """

    class ICEColumnInfo:
        """ Class to hold information about a column of the dataset to explain.
        :type name: str
        :type index: int
        :type modalities: np.ndarray
        :type distribution: np.ndarray
        """
        def __init__(self, name, index, modalities, distribution):
            self.name = name
            # Position of the column in `columns_dtypes`; it is used as the positional column index
            # when overwriting values in the frankenstein dataframe.
            self.index = index
            # Values substituted into the column when building the frankenstein
            self.modalities = modalities
            # Weights applied to the scores obtained for each modality (same order as `modalities`)
            self.distribution = distribution

    def __init__(self, features_distribution, columns_to_compute, columns_dtypes, score_computer):
        """
        :param features_distribution: for each column to compute, a mapping holding a "scale" array (the
                                      modalities) and a "distribution" array (the weight of each modality)
        :type features_distribution: dict
        :param columns_to_compute: names of the columns for which explanations are computed
        :param columns_dtypes: dtype of every column, indexed by column name
                               (presumably a pd.Series - to be confirmed against callers)
        :param score_computer: callable scoring a dataframe; the returned object exposes a `score` attribute
        :type score_computer: ScoreComputer
        """
        self._columns_to_compute = columns_to_compute
        self._columns_dtypes = columns_dtypes
        # Relies on `_columns_to_compute` and `_columns_dtypes` being set above
        self._columns_info = self._build_columns_info(features_distribution)
        self._score_computer = score_computer
        self._num_columns = self._columns_dtypes.shape[0]
        # Relies on `_columns_info` and `_num_columns` being set above
        self._num_duplicates_per_row = self._get_num_duplicates_row_per_row_explained()

    def get_estimated_peak_number_cells_generated_per_row_explained(self):
        """ Return the number of frankenstein cells built for one explained row, i.e.
        (number of columns) x (number of duplicates of the row).
        :rtype: int
        """
        return self._num_columns * self._num_duplicates_per_row

    def _get_num_duplicates_row_per_row_explained(self):
        """ Compute how many times each explained row is duplicated in the frankenstein: once per modality of
        each column to compute. A summary of the per-column counts is logged.
        :rtype: int
        """
        rows_per_column = [(col_info.name, col_info.modalities.shape[0]) for col_info in self._columns_info]
        num_produced_rows = sum(num_rows for _, num_rows in rows_per_column)
        col_summary_list = [u"{col_name}={num_rows}".format(col_name=safe_unicode_str(col_name),
                                                            num_rows=num_rows)
                            for col_name, num_rows in rows_per_column]
        summary = u"For each row explained, the " \
                  u"explaining engine will produce about {} cells ({})".format(num_produced_rows * self._num_columns,
                                                                               ", ".join(col_summary_list))
        logger.info(summary)
        return num_produced_rows

    def _build_columns_info(self, features_distribution):
        """ Build the ICEColumnInfo of each column to compute, in the order of `columns_to_compute`.
        :type features_distribution: dict
        :rtype: list[ICEExplainingEngine.ICEColumnInfo]
        """
        column_infos = []
        for col_name in self._columns_to_compute:
            col_distribution = features_distribution[col_name]
            # Cast the modalities to the dtype declared for the column
            modalities = col_distribution["scale"].astype(self._columns_dtypes[col_name])
            distribution = col_distribution["distribution"]
            col_idx = self._columns_dtypes.index.get_loc(col_name)
            column_infos.append(ICEExplainingEngine.ICEColumnInfo(col_name, col_idx, modalities, distribution))
        return column_infos

    @staticmethod
    def _fast_repeat_df(df, n_times):
        """ Stack `n_times` copies of `df` on top of each other, with a fresh 0..N-1 index.

        Rows are selected by position rather than by label: label-based selection returns every row matching
        a label, so a dataframe with a non-unique index would yield more than `len(df) * n_times` rows and
        break the positional layout the frankenstein relies on.

        :type df: pd.DataFrame
        :type n_times: int
        :rtype: pd.DataFrame
        """
        positions = np.tile(np.arange(df.shape[0]), n_times)
        return df.iloc[positions].reset_index(drop=True)

    def _create_ice_frankenstein_scores(self, observations_df):
        """
        Build the frankenstein of the observations and score it.

        The frankenstein is made of consecutive blocks of `observations_df.shape[0]` rows, each block being a
        copy of the observations in which one column is set to one of its modalities. Blocks are ordered by
        column (in the order of the columns info), then by modality.

        :type observations_df: pd.DataFrame
        :return the score of the observation and the frankenstein score
        :rtype: (ScoreToExplain, ScoreToExplain)
        """
        observations_score = self._score_computer(observations_df)
        num_observations = observations_df.shape[0]
        frankenstein_df = self._fast_repeat_df(observations_df, self._num_duplicates_per_row)
        curr_index = 0
        for col_info in self._columns_info:
            # Each modality is repeated for all the observations of its block
            col_modalities_repeated = np.repeat(col_info.modalities, num_observations)
            num_replacements = col_modalities_repeated.shape[0]
            frankenstein_df.iloc[curr_index: curr_index + num_replacements, col_info.index] = col_modalities_repeated
            curr_index += num_replacements
        logger.info("Built frankenstein of shape {} for this batch".format(frankenstein_df.shape))

        # For each frankenstein row, position of the observation it was duplicated from
        matching_indices_in_observations = np.tile(np.arange(num_observations), self._num_duplicates_per_row)
        frankenstein_score = self._score_computer(frankenstein_df, observations_score, matching_indices_in_observations)
        return observations_score, frankenstein_score

    def _extract_ice_explanations(self, frankenstein_score, observations_score, observations_df):
        """
        Compute the explanations out of the scores of the frankenstein and of the observations.

        :type frankenstein_score: np.ndarray
        :type observations_df: pd.DataFrame
        :type observations_score: np.ndarray
        :return: one row per observation (same index as `observations_df`), one column per column to compute
        :rtype: pd.DataFrame
        """
        num_observations = observations_df.shape[0]

        # Extract per-column predicted frankensteins and compute ice explanations
        explanations = pd.DataFrame(columns=[col_info.name for col_info in self._columns_info],
                                    index=observations_df.index, dtype=np.float64)

        curr_index = 0
        for col_info in self._columns_info:
            # Compute explanations for one column out of the per-column scored frankenstein

            # number of rows in frankenstein corresponding to the per_column frankenstein
            nb_rows_for_col = col_info.modalities.shape[0] * num_observations

            # shape is (nb_modalities, nb_rows_to_explain)
            col_scores_arr = frankenstein_score[curr_index: curr_index + nb_rows_for_col].reshape(-1, num_observations)
            # Weight the scores of each modality (one row of the array) by the weight of the modality
            weighted_scores_arr = col_scores_arr * col_info.distribution[:, np.newaxis]

            # With overrides declined predictions now the model can output nan predictions, therefore we need to ensure
            # that the sum is computed only on the valid predictions.
            explanations[col_info.name] = observations_score - np.nansum(weighted_scores_arr, axis=0)

            curr_index += nb_rows_for_col

        return explanations

    def explain(self, df):
        """ Compute the explanations of all the rows of `df`.
        :type df: pd.DataFrame
        :return: the explanations dataframe wrapped in a SimpleExplanationResult
        :rtype: SimpleExplanationResult
        """
        observations_score, frankenstein_scores = self._create_ice_frankenstein_scores(df)
        explanations_df = self._extract_ice_explanations(frankenstein_scores.score, observations_score.score, df)
        return SimpleExplanationResult(explanations_df)
