import logging
from abc import ABCMeta
from collections import namedtuple

import numpy as np
import pandas as pd
from six import add_metaclass

from dataiku.base.utils import safe_unicode_str
from dataiku.core import doctor_constants
from dataiku.doctor.prediction.metric import PRED_TYPE_TO_METRICS

logger = logging.getLogger(__name__)


class SignedMetric(object):
    """
    Metric value bundled with an orientation sign.

    Attributes:
        value: the metric value exactly as given to the constructor
        sign: the multiplier given to the constructor
        value_gib: the product ``value * sign`` ("gib" presumably stands for
            "greater is better" -- to be confirmed against callers)
    """

    def __init__(self, value, sign):
        # The product is computed once at construction time; it is not kept in
        # sync if `value` or `sign` are reassigned afterwards.
        signed_value = value * sign
        self.value = value
        self.sign = sign
        self.value_gib = signed_value


def get_zero_metrics(pred_type):
    """
    Build a metrics dictionary in which every metric registered for the given
    prediction type is set to zero.

    The keys are the field names returned by each metric definition's
    ``get_field(pred_type)``.

    :type pred_type: str
    :rtype: dict[str, float]
    """
    zeroed = {}
    for metric_definition in PRED_TYPE_TO_METRICS[pred_type].values():
        field_name = metric_definition.get_field(pred_type)
        zeroed[field_name] = 0.
    return zeroed


@add_metaclass(ABCMeta)
class PerformanceResults(object):
    """
    Base container for the outcome of a model evaluation: the performance
    dictionary, the predicted dataframe, the raw predictions and the
    prediction statistics, tagged with the prediction type.
    """

    def __init__(self, typ, perf_data, predicted_df, raw_predictions, prediction_statistics):
        self.type = typ
        self.perf = perf_data
        self.predicted_df = predicted_df
        # predictions with mappedValues (see dataiku.doctor.deephub.deephub_params.TargetRemapping),
        # doesn't contain dropped rows
        self.raw_predictions = raw_predictions
        # dict equivalent to `com.dataiku.dip.analysis.model.prediction.ClassificationModelPredictionInfos`
        self.prediction_statistics = prediction_statistics

    def get_predicted_data(self):
        """Return the predicted dataframe given at construction."""
        return self.predicted_df

    def get_metric(self, metric_name):
        """
        Read one metric from ``self.perf["metrics"]`` and wrap it with its sign.

        :type metric_name: str
        :rtype: SignedMetric
        :raises ValueError: if the stored value of the metric is None
        """
        definition = PRED_TYPE_TO_METRICS[self.type][metric_name]
        stored_metrics = self.perf["metrics"]
        raw_value = stored_metrics[definition.get_field(self.type)]
        if raw_value is None:
            raise ValueError("Evaluation metric '" + metric_name + "' should not be None")

        # By flipping the sign of lower_is_better metrics here we can now compute everything as if they were
        # greater_is_better and then change them back to positive in the UI before displaying them:
        # src/main/platypus/static/dataiku/js/ml/mlcharts.js#1821
        if definition.greater_is_better:
            orientation = 1
        else:
            orientation = -1
        return SignedMetric(raw_value, orientation)

    def _to_dict(self):
        """Return the performance dictionary itself (no copy is made)."""
        return self.perf

    def to_dict(self):
        """
        Return the dictionary produced by ``_to_dict`` with a ``"type"`` entry added.

        The entry is written directly into the dictionary returned by
        ``_to_dict``; with the default implementation that is ``self.perf``.
        """
        serialized = self._to_dict()
        # for PolyJSON serialization / deserialization
        serialized["type"] = self.type
        return serialized


class ObjectDetectionPerformanceResults(PerformanceResults):
    """
    Performance results tagged with the
    ``doctor_constants.DEEP_HUB_IMAGE_OBJECT_DETECTION`` prediction type.

    Prediction statistics are always set to an empty dict for this type.
    """

    def __init__(self, perf, predicted_df, raw_predictions):
        prediction_type = doctor_constants.DEEP_HUB_IMAGE_OBJECT_DETECTION
        super(ObjectDetectionPerformanceResults, self).__init__(
            prediction_type, perf, predicted_df, raw_predictions, {}
        )

    @staticmethod
    def empty():
        """
        Build results holding no data: all metrics at zero, an empty "perIOU"
        list, an empty dataframe with the "prediction" and "pairing" columns,
        and an empty array of raw predictions.

        :rtype: ObjectDetectionPerformanceResults
        """
        zero_metrics = get_zero_metrics(doctor_constants.DEEP_HUB_IMAGE_OBJECT_DETECTION)
        blank_perf = {"perIOU": [], "metrics": zero_metrics}
        blank_df = pd.DataFrame(columns=["prediction", "pairing"])
        no_predictions = np.array([])
        return ObjectDetectionPerformanceResults(blank_perf, blank_df, raw_predictions=no_predictions)


class ImageClassificationPerformanceResults(PerformanceResults):
    """
    Performance results tagged with the
    ``doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION`` prediction type.
    """

    def __init__(self, performance_dict, predicted_df, raw_predictions, prediction_statistics):
        prediction_type = doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION
        super(ImageClassificationPerformanceResults, self).__init__(
            prediction_type, performance_dict, predicted_df, raw_predictions, prediction_statistics
        )

    @staticmethod
    def empty(categories):
        """
        Build results holding no data: all metrics at zero, empty per-class
        curve/score/density dictionaries, an empty dataframe with a
        "prediction" column followed by one "proba_<category>" column per
        category, an empty array of raw predictions and empty prediction
        statistics.

        :param categories: iterable of categories, each turned into a column
            name through ``safe_unicode_str``
        :rtype: ImageClassificationPerformanceResults
        """
        blank_perf = {
            "metrics": get_zero_metrics(doctor_constants.DEEP_HUB_IMAGE_CLASSIFICATION),
            "oneVsAllRocAUC": {},
            "oneVsAllRocCurves": {},
            "oneVsAllAveragePrecision": {},
            "oneVsAllPrCurves": {},
            "densityData": {}
        }

        # "prediction" comes first, then the probability columns in the order of `categories`
        column_names = ["prediction"]
        for category in categories:
            column_names.append("proba_{}".format(safe_unicode_str(category)))

        blank_df = pd.DataFrame(columns=column_names)
        no_predictions = np.array([])
        return ImageClassificationPerformanceResults(blank_perf,
                                                     predicted_df=blank_df,
                                                     raw_predictions=no_predictions,
                                                     prediction_statistics={})
