# coding: utf-8
from __future__ import unicode_literals
from __future__ import division

import numpy as np

from dataiku.core import doctor_constants
from dataiku.doctor.diagnostics.diagnostics import DiagnosticCallback
from dataiku.doctor.diagnostics.diagnostics import DiagnosticType
from dataiku.base.utils import safe_unicode_str


# Share of predictions (between 0 and 1) that the most frequently predicted class
# must strictly exceed for the "almost always predicts the same class" diagnostic to be raised.
MAX_MOST_COMMON_CLASS_PROPORTION = 0.99


class AbnormalPredictionsDetectionDiagnostic(DiagnosticCallback):
    """ Diagnostic flagging classification models that almost always predict the same class.

    See in the documentation machine-learning/diagnostics.html#abnormal-predictions-detection
    """
    def __init__(self):
        super(AbnormalPredictionsDetectionDiagnostic, self).__init__(DiagnosticType.ML_DIAGNOSTICS_ABNORMAL_PREDICTIONS_DETECTION)

    def on_scoring_end(self, scoring_params=None, transformed_test=None, transformed_train=None, with_sample_weight=False):
        """ Build the list of diagnostic messages for the given scoring parameters.

        Only ``scoring_params`` is used by this diagnostic; the other arguments are accepted
        but ignored (presumably to match the callback signature of the base class).

        :param scoring_params: object exposing ``prediction_type`` and ``preds``, or None
        :param transformed_test: unused
        :param transformed_train: unused
        :param with_sample_weight: unused
        :return: list of diagnostic messages (possibly empty)
        """
        diagnostics = []
        self.check_prediction_diversity(scoring_params, diagnostics)
        return diagnostics

    @staticmethod
    def check_prediction_diversity(scoring_params, diagnostics):
        """ Append a message to ``diagnostics`` when a single class dominates the predictions.

        The check only applies to binary, multiclass and deep hub image classification
        prediction types. Nothing is appended when ``scoring_params`` is None, when the
        prediction type is not a supported one, or when there are no predictions to inspect.

        :param scoring_params: object exposing ``prediction_type`` and ``preds``, or None
        :param diagnostics: list to which the diagnostic message is appended in place
        """
        if scoring_params is None:
            return

        classification_types = {doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS,
                                doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION}
        if scoring_params.prediction_type not in classification_types:
            return

        preds = scoring_params.preds
        # Without predictions there is no "most common class": argmax on an empty array raises
        # ValueError, so bail out instead of crashing the whole diagnostics computation.
        if preds is None:
            return
        nb_predictions = np.size(preds)
        if nb_predictions == 0:
            return

        unique, counts = np.unique(preds, return_counts=True)
        most_common_class_index = counts.argmax()
        most_common_class = unique[most_common_class_index]
        # True division is guaranteed by `from __future__ import division` at the top of the file
        most_common_class_proportion = counts[most_common_class_index] / nb_predictions

        if most_common_class_proportion > MAX_MOST_COMMON_CLASS_PROPORTION:
            diagnostics.append("The model almost always predicts the class '{}'".format(safe_unicode_str(most_common_class)))
