from dataiku.core import doctor_constants


class Metric(object):
    """Descriptor for one evaluation metric.

    :param name: identifier of the metric (e.g. ``"ROC_AUC"``)
    :param field: default field name returned by :meth:`get_field`
    :param display_name: human-readable label of the metric
    :param zero_to_one: boolean flag stored as-is (presumably: the metric
        value lies in [0, 1] -- confirm against consumers)
    :param greater_is_better: boolean flag stored as-is (presumably: a
        larger value means a better model -- confirm against consumers)
    :param prediction_types: prediction types this metric applies to
    """

    def __init__(self, name, field, display_name, zero_to_one, greater_is_better, prediction_types):
        # Identity of the metric
        self.name = name
        self.display_name = display_name
        # Kept private: callers are expected to go through get_field(),
        # which may substitute a prediction-type-specific field name.
        self._field = field
        # Descriptive flags
        self.zero_to_one = zero_to_one
        self.greater_is_better = greater_is_better
        # Applicability
        self.prediction_types = prediction_types

    def get_field(self, prediction_type):
        """Return the field name of this metric for ``prediction_type``.

        For the ``MULTICLASS`` and ``DEEP_HUB_IMAGE_CLASSIFICATION`` prediction
        types, the ``ROC_AUC`` and ``CALIBRATION_LOSS`` metrics use dedicated
        field names; every other combination yields the default field.

        :param prediction_type: one of the ``doctor_constants`` prediction types
        :return: the field name as a string
        """
        multiclass_like = {
            doctor_constants.MULTICLASS,
            doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION,
        }
        if prediction_type not in multiclass_like:
            return self._field

        if self.name == "ROC_AUC":
            return "mrocAUC"
        if self.name == "CALIBRATION_LOSS":
            return "mcalibrationLoss"
        return self._field


# Classification metrics (CUSTOM is additionally declared for regression).
ACCURACY = Metric(
    "ACCURACY", "accuracy", "Accuracy",
    zero_to_one=True, greater_is_better=True,
    prediction_types=[
        doctor_constants.BINARY_CLASSIFICATION,
        doctor_constants.MULTICLASS,
        doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION,
    ],
)
PRECISION = Metric(
    "PRECISION", "precision", "Precision",
    zero_to_one=True, greater_is_better=True,
    prediction_types=[
        doctor_constants.BINARY_CLASSIFICATION,
        doctor_constants.MULTICLASS,
        doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION,
    ],
)
RECALL = Metric(
    "RECALL", "recall", "Recall",
    zero_to_one=True, greater_is_better=True,
    prediction_types=[
        doctor_constants.BINARY_CLASSIFICATION,
        doctor_constants.MULTICLASS,
        doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION,
    ],
)
F1 = Metric(
    "F1", "f1", "F1 Score",
    zero_to_one=True, greater_is_better=True,
    prediction_types=[
        doctor_constants.BINARY_CLASSIFICATION,
        doctor_constants.MULTICLASS,
        doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION,
    ],
)
COST_MATRIX = Metric(
    "COST_MATRIX", "cmg", "Cost Matrix Gain",
    zero_to_one=False, greater_is_better=True,
    prediction_types=[doctor_constants.BINARY_CLASSIFICATION],
)
LOG_LOSS = Metric(
    "LOG_LOSS", "logLoss", "Log Loss",
    zero_to_one=False, greater_is_better=False,
    prediction_types=[
        doctor_constants.BINARY_CLASSIFICATION,
        doctor_constants.MULTICLASS,
        doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION,
    ],
)
ROC_AUC = Metric(
    "ROC_AUC", "auc", "AUC",
    zero_to_one=True, greater_is_better=True,
    prediction_types=[
        doctor_constants.BINARY_CLASSIFICATION,
        doctor_constants.MULTICLASS,
        doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION,
    ],
)
CUMULATIVE_LIFT = Metric(
    "CUMULATIVE_LIFT", "lift", "Lift",
    zero_to_one=False, greater_is_better=True,
    prediction_types=[doctor_constants.BINARY_CLASSIFICATION],
)
CUSTOM = Metric(
    "CUSTOM", "customScore", "Custom Score",
    zero_to_one=False, greater_is_better=False,
    prediction_types=[
        doctor_constants.BINARY_CLASSIFICATION,
        doctor_constants.REGRESSION,
        doctor_constants.MULTICLASS,
    ],
)
CALIBRATION_LOSS = Metric(
    "CALIBRATION_LOSS", "calibrationLoss", "Calibration Loss",
    zero_to_one=True, greater_is_better=False,
    prediction_types=[
        doctor_constants.BINARY_CLASSIFICATION,
        doctor_constants.MULTICLASS,
    ],
)
AVERAGE_PRECISION = Metric(
    "AVERAGE_PRECISION", "averagePrecision", "Average Precision",
    zero_to_one=True, greater_is_better=True,
    prediction_types=[
        doctor_constants.BINARY_CLASSIFICATION,
        doctor_constants.MULTICLASS,
        doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION,
    ],
)

# Regression metrics. MAPE, MAE, MSE and RMSE are also declared for
# time-series forecasting.
EVS = Metric(
    "EVS", "evs", "Explained Var.",
    zero_to_one=True, greater_is_better=True,
    prediction_types=[doctor_constants.REGRESSION],
)
MAPE = Metric(
    "MAPE", "mape", "MAPE",
    zero_to_one=False, greater_is_better=False,
    prediction_types=[doctor_constants.REGRESSION, doctor_constants.TIMESERIES_FORECAST],
)
MAE = Metric(
    "MAE", "mae", "MAE",
    zero_to_one=False, greater_is_better=False,
    prediction_types=[doctor_constants.REGRESSION, doctor_constants.TIMESERIES_FORECAST],
)
MSE = Metric(
    "MSE", "mse", "MSE",
    zero_to_one=False, greater_is_better=False,
    prediction_types=[doctor_constants.REGRESSION, doctor_constants.TIMESERIES_FORECAST],
)
RMSE = Metric(
    "RMSE", "rmse", "RMSE",
    zero_to_one=False, greater_is_better=False,
    prediction_types=[doctor_constants.REGRESSION, doctor_constants.TIMESERIES_FORECAST],
)
RMSLE = Metric(
    "RMSLE", "rmsle", "RMSLE",
    zero_to_one=False, greater_is_better=False,
    prediction_types=[doctor_constants.REGRESSION],
)
R2 = Metric(
    "R2", "r2", "R2 Score",
    zero_to_one=True, greater_is_better=True,
    prediction_types=[doctor_constants.REGRESSION],
)
# Pearson correlation is a score, not an error: a higher correlation between
# predictions and targets means a better fit, so greater_is_better must be
# True (it was previously declared False, like the error metrics above).
# zero_to_one stays False since a correlation coefficient can be negative.
PEARSON = Metric(
    "PEARSON", "pearson", "Correlation",
    zero_to_one=False, greater_is_better=True,
    prediction_types=[doctor_constants.REGRESSION],
)

# Metrics declared only for time-series forecasting.
MASE = Metric(
    "MASE", "mase", "MASE",
    zero_to_one=False, greater_is_better=False,
    prediction_types=[doctor_constants.TIMESERIES_FORECAST],
)
# TODO @timeseries: harmonize param names (snake vs camel cases)
MEAN_ABSOLUTE_QUANTILE_LOSS = Metric(
    "MEAN_ABSOLUTE_QUANTILE_LOSS", "meanAbsoluteQuantileLoss", "Mean Absolute Quantile Loss",
    zero_to_one=False, greater_is_better=False,
    prediction_types=[doctor_constants.TIMESERIES_FORECAST],
)
MEAN_WEIGHTED_QUANTILE_LOSS = Metric(
    "MEAN_WEIGHTED_QUANTILE_LOSS", "meanWeightedQuantileLoss", "Mean Weighted Quantile Loss",
    zero_to_one=False, greater_is_better=False,
    prediction_types=[doctor_constants.TIMESERIES_FORECAST],
)
MSIS = Metric(
    "MSIS", "msis", "MSIS",
    zero_to_one=False, greater_is_better=False,
    prediction_types=[doctor_constants.TIMESERIES_FORECAST],
)
ND = Metric(
    "ND", "nd", "Normalized Deviation",
    zero_to_one=False, greater_is_better=False,
    prediction_types=[doctor_constants.TIMESERIES_FORECAST],
)
# NOTE(review): SMAPE is the only error metric in this group flagged
# zero_to_one=True -- confirm this matches the range of the computed value.
SMAPE = Metric(
    "SMAPE", "smape", "sMAPE",
    zero_to_one=True, greater_is_better=False,
    prediction_types=[doctor_constants.TIMESERIES_FORECAST],
)

# Metrics declared only for deep-hub image object detection.
AVERAGE_PRECISION_IOU50 = Metric(
    "AVERAGE_PRECISION_IOU50", "averagePrecisionIOU50", "Average Precision (IoU=0.50)",
    zero_to_one=True, greater_is_better=True,
    prediction_types=[doctor_constants.DEEP_HUB_IMAGE_OBJECT_DETECTION],
)
AVERAGE_PRECISION_IOU75 = Metric(
    "AVERAGE_PRECISION_IOU75", "averagePrecisionIOU75", "Average Precision (IoU=0.75)",
    zero_to_one=True, greater_is_better=True,
    prediction_types=[doctor_constants.DEEP_HUB_IMAGE_OBJECT_DETECTION],
)
AVERAGE_PRECISION_ALL_IOU = Metric(
    "AVERAGE_PRECISION_ALL_IOU", "averagePrecisionAllIOU", "Average Precision (all IoUs)",
    zero_to_one=True, greater_is_better=True,
    prediction_types=[doctor_constants.DEEP_HUB_IMAGE_OBJECT_DETECTION],
)


# Per-prediction-type registries. They start empty and are filled by the
# registration loop at the end of this module. Despite the
# ``*_NAME_TO_FIELD_NAME`` suffix, that loop stores the Metric object itself
# under each metric name, not a field-name string.
BINARY_METRICS_NAME_TO_FIELD_NAME = dict()

REGRESSION_METRICS_NAME_TO_FIELD_NAME = dict()

MULTICLASS_METRICS_NAME_TO_FIELD_NAME = dict()

DEEP_HUB_IMAGE_CLASSIFICATION_METRICS_NAME_TO_FIELD_NAME = dict()


# NB: for object detection, the thresholdOptimizationMetric values are not
# added here, as they are only used to optimize the confidence score
# threshold, not the model itself.
DEEP_HUB_IMAGE_OBJECT_DETECTION_METRICS_NAME_TO_FIELD_NAME = dict()


TIMESERIES_METRICS_NAME_TO_FIELD_NAME = dict()

# Metric name -> display name, for every metric listed in METRICS.
METRICS_NAMES = dict()

# Every metric known to this module. The order is significant only in that it
# determines the insertion order of the registries built from this list.
METRICS = [
    # Classification
    ACCURACY,
    PRECISION,
    RECALL,
    F1,
    COST_MATRIX,
    LOG_LOSS,
    ROC_AUC,
    AVERAGE_PRECISION,
    CUMULATIVE_LIFT,
    CUSTOM,
    # Regression
    EVS,
    MAPE,
    MAE,
    MSE,
    RMSE,
    RMSLE,
    R2,
    PEARSON,
    # Classification (listed after the regression metrics)
    CALIBRATION_LOSS,
    # Time-series forecasting
    MASE,
    MEAN_ABSOLUTE_QUANTILE_LOSS,
    MEAN_WEIGHTED_QUANTILE_LOSS,
    MSIS,
    ND,
    SMAPE,
    # Object detection
    AVERAGE_PRECISION_IOU50,
    AVERAGE_PRECISION_IOU75,
    AVERAGE_PRECISION_ALL_IOU,
]

# Prediction type -> registry of the metrics available for that type.
# The values are the very same dict objects as the module-level
# ``*_METRICS_NAME_TO_FIELD_NAME`` registries, so filling one fills the other.
PRED_TYPE_TO_METRICS = {
    doctor_constants.BINARY_CLASSIFICATION:
        BINARY_METRICS_NAME_TO_FIELD_NAME,
    doctor_constants.MULTICLASS:
        MULTICLASS_METRICS_NAME_TO_FIELD_NAME,
    doctor_constants.REGRESSION:
        REGRESSION_METRICS_NAME_TO_FIELD_NAME,
    doctor_constants.TIMESERIES_FORECAST:
        TIMESERIES_METRICS_NAME_TO_FIELD_NAME,
    doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION:
        DEEP_HUB_IMAGE_CLASSIFICATION_METRICS_NAME_TO_FIELD_NAME,
    doctor_constants.DEEP_HUB_IMAGE_OBJECT_DETECTION:
        DEEP_HUB_IMAGE_OBJECT_DETECTION_METRICS_NAME_TO_FIELD_NAME,
}

# Registration pass, executed once at import time: record each metric's
# display name and file the Metric object under every prediction type it
# declares. A prediction type missing from PRED_TYPE_TO_METRICS raises KeyError.
for metric in METRICS:
    METRICS_NAMES[metric.name] = metric.display_name
    for pred_type in metric.prediction_types:
        PRED_TYPE_TO_METRICS[pred_type][metric.name] = metric
