# coding: utf-8
from __future__ import unicode_literals
import logging

from dataiku.doctor.diagnostics import dataset_sanity_check
from dataiku.doctor.diagnostics import model_check
from dataiku.doctor.diagnostics import leakage
from dataiku.doctor.diagnostics import overfit
from dataiku.doctor.diagnostics import ml_assertions
from dataiku.doctor.diagnostics import modeling_parameters
from dataiku.doctor.diagnostics import abnormal_predictions_detection
from dataiku.doctor.diagnostics.diagnostics import DiagnosticType
from dataiku.doctor.diagnostics.diagnostics import register

logger = logging.getLogger(__name__)


def register_prediction_callbacks(core_params):
    """ Register the default diagnostics callbacks used for prediction

    The diagnostics settings are looked up in core_params (see _get_settings)
    and handed to register together with the callbacks.
    """
    diagnostics_settings = _get_settings(core_params)
    prediction_callbacks = [
        dataset_sanity_check.DatasetSanityCheckDiagnostic(),
        model_check.ClassifierAccuracyCheckDiagnostic(),
        model_check.RegressionR2CheckDiagnostic(),
        leakage.LeakageDiagnostic(),
        overfit.TreeOverfitDiagnostic(),
        ml_assertions.MLAssertionsDiagnostic(),
        modeling_parameters.ModelCheckDiagnostic(),
        abnormal_predictions_detection.AbnormalPredictionsDetectionDiagnostic(),
    ]
    register(prediction_callbacks, diagnostics_settings)


def register_keras_callbacks(core_params):
    """ Register the diagnostics callbacks used for Keras models

    Registers the dataset sanity check, leakage and abnormal predictions
    detection diagnostics, with the settings looked up in core_params.
    """
    diagnostics_settings = _get_settings(core_params)
    keras_callbacks = [
        dataset_sanity_check.DatasetSanityCheckDiagnostic(),
        leakage.LeakageDiagnostic(),
        abnormal_predictions_detection.AbnormalPredictionsDetectionDiagnostic(),
    ]
    register(keras_callbacks, diagnostics_settings)


def register_deephub_callbacks(core_params):
    """ Register the diagnostics callbacks used for deephub models

    Registers the dataset sanity check, leakage and abnormal predictions
    detection diagnostics, with the settings looked up in core_params.
    """
    diagnostics_settings = _get_settings(core_params)
    deephub_callbacks = [
        dataset_sanity_check.DatasetSanityCheckDiagnostic(),
        leakage.LeakageDiagnostic(),
        abnormal_predictions_detection.AbnormalPredictionsDetectionDiagnostic(),
    ]
    register(deephub_callbacks, diagnostics_settings)


def register_forecasting_callbacks(core_params):
    """ Register the diagnostics callbacks used for forecasting

    Registers the dataset sanity check and modeling parameters diagnostics,
    with the settings looked up in core_params.
    """
    diagnostics_settings = _get_settings(core_params)
    forecasting_callbacks = [
        dataset_sanity_check.DatasetSanityCheckDiagnostic(),
        modeling_parameters.ModelCheckDiagnostic(),
    ]
    register(forecasting_callbacks, diagnostics_settings)


def register_forecasting_scoring_callbacks():
    """ Enable diagnostics for the forecasting scoring recipe.

    The settings are hard-coded: diagnostics are globally enabled, and so is
    each of the diagnostic types listed below.
    """
    enabled_types = (
        DiagnosticType.ML_DIAGNOSTICS_SCORING_DATASET_SANITY_CHECKS,
        DiagnosticType.ML_DIAGNOSTICS_TIMESERIES_RESAMPLING_CHECKS,
        DiagnosticType.ML_DIAGNOSTICS_REPRODUCIBILITY,
    )
    scoring_settings = {
        'enabled': True,
        'settings': [{'type': diagnostic_type.value, 'enabled': True}
                     for diagnostic_type in enabled_types],
    }
    # No callbacks are registered here: the forecasting scoring recipe only
    # creates its diagnostics through add_or_update.
    register([], scoring_settings)


def register_forecasting_evaluation_callbacks():
    """ Enable diagnostics for the forecasting evaluation recipe.

    The settings are hard-coded: diagnostics are globally enabled, and so is
    each of the diagnostic types listed below.
    """
    enabled_types = (
        DiagnosticType.ML_DIAGNOSTICS_EVALUATION_DATASET_SANITY_CHECKS,
        DiagnosticType.ML_DIAGNOSTICS_TIMESERIES_RESAMPLING_CHECKS,
    )
    evaluation_settings = {
        'enabled': True,
        'settings': [{'type': diagnostic_type.value, 'enabled': True}
                     for diagnostic_type in enabled_types],
    }
    # No callbacks are registered here: the forecasting evaluation recipe only
    # creates its diagnostics through add_or_update.
    register([], evaluation_settings)


def register_evaluation_callbacks():
    """ Enable diagnostics and register callbacks for evaluation.

    The settings are hard-coded and always enable the dataset sanity checks,
    modeling parameters, runtime and LLM evaluation computation error
    diagnostic types. Only the dataset sanity check callback is registered.
    """
    enabled_types = (
        DiagnosticType.ML_DIAGNOSTICS_DATASET_SANITY_CHECKS,
        DiagnosticType.ML_DIAGNOSTICS_MODELING_PARAMETERS,
        DiagnosticType.ML_DIAGNOSTICS_RUNTIME,
        DiagnosticType.LLM_EVALUATION_COMPUTATION_ERROR,
    )
    evaluation_settings = {
        'enabled': True,
        'settings': [{'type': diagnostic_type.value, 'enabled': True}
                     for diagnostic_type in enabled_types],
    }
    evaluation_callbacks = [dataset_sanity_check.DatasetSanityCheckDiagnostic()]
    register(evaluation_callbacks, evaluation_settings)


def register_causal_predictions_callbacks(core_params):
    """ Register the diagnostics callbacks used for causal predictions

    Only the dataset sanity check diagnostic is registered, with the settings
    looked up in core_params.
    """
    diagnostics_settings = _get_settings(core_params)
    causal_callbacks = [dataset_sanity_check.DatasetSanityCheckDiagnostic()]
    register(causal_callbacks, diagnostics_settings)


def register_clustering_callbacks(core_params):
    """ Register the default diagnostics callbacks used for clustering

    Registers the dataset sanity check and modeling parameters diagnostics,
    with the settings looked up in core_params.
    """
    diagnostics_settings = _get_settings(core_params)
    clustering_callbacks = [
        dataset_sanity_check.DatasetSanityCheckDiagnostic(),
        modeling_parameters.ModelCheckDiagnostic(),
    ]
    register(clustering_callbacks, diagnostics_settings)


def _get_settings(core_params):
    if "diagnosticsSettings" not in core_params:
        logger.info("no 'diagnosticsSettings' found in core_params")
    settings = core_params.get("diagnosticsSettings", {})
    return settings

