# coding: utf-8
from __future__ import unicode_literals

import logging

from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.multiframe import DropRowReason

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Ratio of dropped rows (dropped / initial size) that must be strictly exceeded
# for the dropped-rows diagnostic message to be emitted.
DROPPED_ROWS_THRESHOLD = 0.5  # to be kept consistent with DROPPED_ROW_WARNING_THRESHOLD in predsettings.js


class ModelCheckDiagnostic(diagnostics.DiagnosticCallback):
    """ Diagnostic reporting when feature handling dropped a large share of the rows.

    See in the documentation machine-learning/diagnostics.html#modeling-parameters
    """
    def __init__(self):
        super(ModelCheckDiagnostic, self).__init__(diagnostics.DiagnosticType.ML_DIAGNOSTICS_MODEL_CHECK)

    def on_preprocess_train_dataset_end(self, multiframe=None):
        """ Run the dropped-rows check on ``multiframe``, labelled as the "train" dataset. """
        return self.check_feature_handling_dropped_row_percentage(multiframe, "train")

    def on_preprocess_test_dataset_end(self, multiframe=None):
        """ Run the dropped-rows check on ``multiframe``, labelled as the "test" dataset. """
        return self.check_feature_handling_dropped_row_percentage(multiframe, "test")

    def on_kfold_step_preprocess_global_end(self, multiframe=None):
        """ Run the dropped-rows check on ``multiframe``, labelled as the "input" dataset. """
        return self.check_feature_handling_dropped_row_percentage(multiframe, "input")

    @staticmethod
    def _format_dropping_columns(columns):
        """ Build the human-readable list of column names used in the diagnostic message.

        At most the first three names are spelled out (each wrapped in single quotes);
        any remaining ones are summarized as "1 other" / "N others".

        :param columns: non-empty list of column names, most impactful first
        :return: e.g. "'a'", "'a' and 'b'", "'a', 'b' and 'c'", "'a', 'b', 'c' and 2 others"
        """
        quoted = ["'" + column + "'" for column in columns[:3]]
        n_others = len(columns) - len(quoted)

        if n_others > 0:
            others = "1 other" if n_others == 1 else str(n_others) + " others"
            return ", ".join(quoted) + " and " + others
        if len(quoted) == 1:
            return quoted[0]
        return ", ".join(quoted[:-1]) + " and " + quoted[-1]

    @staticmethod
    def check_feature_handling_dropped_row_percentage(multiframe, action_type):
        """ Check whether too many rows were dropped because of null values in some columns.

        Logs the initial size of the multiframe and, for each column with a non-zero
        count, the number of rows dropped on that column. When the ratio of dropped rows
        over the initial size is strictly greater than DROPPED_ROWS_THRESHOLD, a single
        message naming the columns (sorted by decreasing drop count) is returned.

        :param multiframe: object exposing ``initial_size`` and
            ``total_lifetime_rows_dropped_log``; may be None, in which case nothing is checked
        :param action_type: label of the dataset used in the messages ("train", "test", "input")
        :return: list of diagnostic messages (empty when nothing is to be reported)
        """
        diagnostic_messages = []

        # The callbacks declare ``multiframe=None`` as default: without a multiframe
        # (or with an empty one, which would divide by zero below) there is nothing to report.
        if multiframe is None or multiframe.initial_size == 0:
            return diagnostic_messages

        # NOTE(review): assumed to be a collections.Counter-like mapping column -> dropped row
        # count, since ``most_common()`` is called on it below - confirm against the multiframe.
        column_dropped_info = multiframe.total_lifetime_rows_dropped_log[DropRowReason.NULL_COLUMN_VALUE]
        total_dropped = sum(column_dropped_info.values())

        logger.info('{} multiframe initial size: {}'.format(action_type, multiframe.initial_size))

        for column, count in column_dropped_info.items():
            if count != 0:
                logger.info('{} rows dropped due to "drop null rows" selected on column "{}"'.format(count, column))

        ratio_dropped = float(total_dropped) / multiframe.initial_size
        if ratio_dropped > DROPPED_ROWS_THRESHOLD:
            # Only columns that actually dropped rows, biggest contributors first.
            columns_sorted = [col for col, nb in column_dropped_info.most_common() if nb > 0]
            column_string = ModelCheckDiagnostic._format_dropping_columns(columns_sorted)

            diagnostic_messages.append("{:.2f}% of the {} dataset was dropped due to feature handling on {}".format(ratio_dropped * 100, action_type, column_string))

        return diagnostic_messages
