import logging

import numpy as np
import pandas as pd
from dataiku.external_ml.proxy_model.common.outputformat import BaseReader

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class VertexAIDefaultReader(BaseReader):
    """Parse predictions returned in the Vertex AI default output format.

    Each element of ``predictions`` is expected to be a mapping shaped as:

    * classification: ``{"classes": [...], "scores": [...]}``, or, when the
      model does not output probabilities (custom Deploy Anywhere Vertex
      field), ``{"prediction": <class>}``;
    * regression: ``{"value": <number>}``.
    """

    NAME = "OUTPUT_VERTEX_DEFAULT"

    def __init__(self, prediction_type, value_to_class):
        super(VertexAIDefaultReader, self).__init__(prediction_type, value_to_class)

    def can_read(self, predictions):
        """Return whether ``predictions`` can be parsed by this reader.

        Parsing is attempted with ``self.read``; any failure is logged and
        reported as ``False`` rather than propagated, so this method always
        answers with a boolean.
        """
        try:
            self.read(predictions)
        except Exception as read_exception:
            # Probing a payload of unknown shape: any parsing error means "not this format".
            logger.info("Predictions are not in Vertex AI default format")
            logger.debug("Read exception: %s", read_exception)
            return False
        return True

    def _read_classification(self, predictions):
        """Build a DataFrame from classification predictions.

        :param predictions: iterable of prediction mappings (see class docstring)
        :return: DataFrame with a ``prediction`` column and, when scores are
            provided, one ``proba_<class>`` column per class
        :raises ValueError: if a prediction's ``classes`` and ``scores`` lists
            do not have the same length
        """
        results = []
        for prediction in predictions:
            if "classes" in prediction and "scores" in prediction:
                classes = prediction["classes"]
                scores = prediction["scores"]
                # Without this check zip() silently drops probabilities and
                # argmax(scores) may point past the end of classes.
                if len(classes) != len(scores):
                    raise ValueError(
                        "Mismatched lengths in Vertex AI prediction: {} classes but {} scores".format(
                            len(classes), len(scores)
                        )
                    )
                # The predicted class is the one with the highest score.
                result = {"prediction": classes[np.argmax(scores)]}
                for class_value, score in zip(classes, scores):
                    result["proba_{}".format(class_value)] = score
            else:
                # "prediction" is not part of the Vertex AI specification: it is a
                # custom field added in Deploy Anywhere Vertex to handle
                # classification without probabilities.
                result = {"prediction": prediction["prediction"]}
            results.append(result)
        return pd.DataFrame(results)

    def read_binary(self, predictions):
        """Parse binary classification predictions (same layout as multiclass)."""
        return self._read_classification(predictions)

    def read_multiclass(self, predictions):
        """Parse multiclass classification predictions."""
        return self._read_classification(predictions)

    def read_regression(self, predictions):
        """Parse regression predictions, taking each prediction's ``value`` field.

        :return: DataFrame with a single ``prediction`` column
        """
        results = [{"prediction": prediction["value"]} for prediction in predictions]
        return pd.DataFrame(results)
