import json
import pandas as pd
import logging

from dataiku.external_ml.proxy_model.common.outputformat import BaseReader

# Module-level logger named after this module; used by the reader below to
# report which output format was detected.
logger = logging.getLogger(__name__)


class DatabricksJSONResultsReader(BaseReader):
    """Read predictions returned by a Databricks endpoint as an already-parsed JSON object.

    Payload shapes handled (``endpoint_output`` is expected to be a ``dict``):

    * ``{"predictions": ["label", "label", ...]}``
    * ``{"predictions": [[proba_0, proba_1, ...], ...]}``
    * ``{"predictions": ["label", ...], "probability": [[proba_0, ...], ...]}``
      (``"probabilities"`` is accepted as an alternative key)
    * ``{"predictions": [value, value, ...]}`` for regression
    """

    NAME = "OUTPUT_DATABRICKS_JSON"

    def __init__(self, prediction_type, value_to_class):
        # value_to_class is handed to BaseReader; it is read back below as
        # self.value_to_class (presumably stored by BaseReader.__init__ — verify).
        super(DatabricksJSONResultsReader, self).__init__(prediction_type, value_to_class)

    def can_read(self, endpoint_output):
        """Return True when ``endpoint_output`` is a dict containing a ``"predictions"`` key.

        :param endpoint_output: the decoded endpoint response.
        :return: True if this reader can handle the payload, False otherwise.
        """
        # A membership test never decodes JSON, so catching JSONDecodeError here
        # could not fire. Guard on the type instead: the read_* methods call
        # ``.get`` on the payload, ``in`` on a str would be a substring match, and
        # ``in`` on None/numbers raises TypeError.
        if not isinstance(endpoint_output, dict):
            logger.info("Predictions are not in JSON 'predictions' format")
            logger.debug("Endpoint output is of type %s, not a JSON object",
                         type(endpoint_output).__name__)
            return False
        if "predictions" in endpoint_output:
            logger.info("Predictions are in JSON 'predictions' format")
            return True
        return False

    def read_binary(self, endpoint_output):
        """Read a binary-classification payload; accepts the same shapes as ``read_multiclass``."""
        return self.read_multiclass(endpoint_output)

    def read_multiclass(self, endpoint_output):
        """Convert a classification payload into a DataFrame, one row per prediction.

        - If a ``"probability"`` (or ``"probabilities"``) list is present, each row
          gets ``proba_<class>`` columns taken from that list (the ``"predictions"``
          labels only drive the row count).
        - Otherwise a row whose prediction is a list is treated as class
          probabilities (``proba_<class>`` columns), and any other value is
          stored in a ``"prediction"`` column.

        :param endpoint_output: dict with at least a ``"predictions"`` list.
        :return: ``pandas.DataFrame`` of the per-row results.
        """
        predictions = endpoint_output.get("predictions")
        keyed_probas = endpoint_output.get("probability", endpoint_output.get("probabilities"))
        results = []
        for row, prediction in enumerate(predictions):
            if keyed_probas:
                # { "predictions": [ "label", ... ], "probability": [[0.1,...],[0.32,...]] }
                # { "predictions": [ "label", ... ], "probabilities": [[0.1,...],[0.32,...]] }
                result = self._probas_to_columns(keyed_probas[row])
            elif isinstance(prediction, list):
                # { "predictions": [[proba_label1, proba_label2, ...], ...] }
                result = self._probas_to_columns(prediction)
            else:
                # { "predictions": [ "label", "label", ... ] }
                result = {"prediction": prediction}
            results.append(result)
        return pd.DataFrame(results)

    def _probas_to_columns(self, probas):
        """Map a positional list of probabilities to ``{"proba_<class>": proba}``."""
        # Position i is resolved to a class label through self.value_to_class.
        return {"proba_{}".format(self.value_to_class[i]): proba
                for i, proba in enumerate(probas)}

    def read_regression(self, endpoint_output):
        """Convert a regression payload into a single-column ``"prediction"`` DataFrame.

        :param endpoint_output: dict with a ``"predictions"`` list of scalars.
        :return: ``pandas.DataFrame`` with one ``"prediction"`` column.
        :raises ValueError: if the first prediction is itself a list, which usually
            means the endpoint returned per-class probabilities.
        """
        predictions = endpoint_output.get("predictions")
        # Only the first element is inspected: a nested list there indicates a
        # classification-shaped output, which cannot be read as regression.
        if len(predictions) > 0 and isinstance(predictions[0], list):
            raise ValueError("Endpoint output was a JSON list of lists, this often happens when selecting the wrong prediction type (for instance, "
                             "regression instead of multiclass)")
        return pd.DataFrame({"prediction": predictions})
