import json
import pandas as pd
import logging

from dataiku.external_ml.proxy_model.common.outputformat import BaseReader

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class AzureMLJSONObjectResultsReader(BaseReader):
    """Read predictions returned as a JSON object with a top-level "Results" key.

    Supported payload shapes::

        {"Results": ["label", "label", ...]}                   # predicted labels
        {"Results": [[p_cls0, p_cls1, ...], [...], ...]}       # class probabilities
        {"Results": [1.5, 2.0, ...]}                           # regression values
    """
    NAME = "OUTPUT_AZUREML_JSON_OBJECT"

    def __init__(self, prediction_type, value_to_class):
        super(AzureMLJSONObjectResultsReader, self).__init__(prediction_type, value_to_class)

    def can_read(self, endpoint_output):
        """Return True if *endpoint_output* is a JSON object containing a "Results" key.

        Invalid JSON, and valid JSON that is not an object (list, number,
        string, null...), are both rejected with False.
        """
        try:
            as_json = json.loads(endpoint_output)
        except json.JSONDecodeError as json_exception:
            logger.info("Predictions are not in JSON object 'results' format")
            logger.debug("JSON Parse exception: {}".format(json_exception))
            return False
        # Only a dict has keys: `"Results" in <list>` is a membership test that
        # gives a false positive for ["Results"], and `in` on a JSON scalar
        # (number, bool, null) raises TypeError.
        if isinstance(as_json, dict) and "Results" in as_json:
            logger.info("Predictions are in JSON object 'results' format")
            return True
        return False

    @staticmethod
    def _load_results(endpoint_output):
        """Parse *endpoint_output* and return the list stored under "Results".

        :raises ValueError: if the output is not valid JSON, is not a JSON
            object, or its "Results" entry is missing or not a JSON list.
        """
        from_json = json.loads(endpoint_output)
        predictions = from_json.get("Results") if isinstance(from_json, dict) else None
        # Reject null/strings/objects up front: len() on None raises an opaque
        # TypeError, and iterating a string would yield one row per character.
        if not isinstance(predictions, list):
            raise ValueError("Endpoint output has no 'Results' list of predictions (got {})".format(
                type(predictions).__name__))
        return predictions

    def read_binary(self, endpoint_output):
        """Read binary predictions; same payload shapes as :meth:`read_multiclass`."""
        return self.read_multiclass(endpoint_output)

    def read_multiclass(self, endpoint_output):
        """Return a DataFrame with one row per entry of "Results".

        A list entry is treated as per-class probabilities. Position ``i`` is
        mapped to the column ``proba_<self.value_to_class[i]>``. Any other entry
        is treated as a predicted label and passed to the DataFrame as-is.

        :raises ValueError: if the payload has no "Results" list.
        """
        predictions = self._load_results(endpoint_output)
        results = []
        for prediction in predictions:
            if isinstance(prediction, list):
                # proba kind: one probability per class, in class-index order
                results.append({"proba_{}".format(self.value_to_class[i]): proba
                                for i, proba in enumerate(prediction)})
            else:
                # label kind: the predicted label itself
                results.append(prediction)
        return pd.DataFrame(results)

    def read_regression(self, endpoint_output):
        """Return a DataFrame with a single "prediction" column built from "Results".

        :raises ValueError: if the payload has no "Results" list, or if it holds
            nested lists, which indicates per-class probabilities rather than
            one value per row.
        """
        predictions = self._load_results(endpoint_output)
        # Nested lists are not flattened: they usually mean the wrong
        # prediction type was selected, so fail loudly instead.
        if any(isinstance(prediction, list) for prediction in predictions):
            raise ValueError("Endpoint output was a JSON list of lists, this often happens when selecting the wrong prediction type (for instance, "
                             "regression instead of multiclass)")
        return pd.DataFrame({"prediction": predictions})
