import json
import logging
import pandas as pd
from io import StringIO

from dataiku.external_ml.proxy_model.common.outputformat import BaseReader

logger = logging.getLogger(__name__)


class ProbasFormat:
    """Names of the layouts in which class probabilities can appear in a CSV row."""

    # Probabilities are spread over the CSV columns that follow the first one.
    CSV = "CSV"
    # Probabilities are packed as a JSON array inside the second CSV column.
    JSON = "JSON"


class SagemakerCSVReader(BaseReader):
    """Reader for endpoint responses serialized as header-less, UTF-8 encoded CSV.

    Every CSV row is one prediction and its first column is the predicted value.
    For classification, the columns after the first are read as class
    probabilities, either spread over those columns or packed as a JSON array in
    the second column. The layout is detected on the first row that carries
    probabilities and is then remembered in ``probas_format`` for the lifetime
    of the instance.
    """

    NAME = "OUTPUT_SAGEMAKER_CSV"

    def __init__(self, prediction_type, value_to_class):
        """Forward both arguments to ``BaseReader`` and reset the detected probas layout."""
        super(SagemakerCSVReader, self).__init__(prediction_type, value_to_class)
        # One of the ProbasFormat constants once _get_probas has seen a row; None until then.
        self.probas_format = None

    @staticmethod
    def _parse_csv_output(endpoint_output):
        """Decode the raw response bytes as UTF-8 and parse them as header-less CSV."""
        return pd.read_csv(StringIO(endpoint_output.decode("utf-8")), header=None)

    @staticmethod
    def _has_probas(row):
        """Return True when at least one column after the first holds a value.

        pandas reports an empty CSV field as NaN rather than None, so a plain
        ``is not None`` test would mistake a blank column for a probability.
        """
        return any(column != 0 and not pd.isnull(value) for column, value in row.items())

    def can_read(self, endpoint_output):
        """Return True when ``endpoint_output`` can be decoded and parsed as CSV.

        :param endpoint_output: raw response bytes returned by the endpoint.
        :return: True if parsing succeeds, False if it raises for any reason.
        """
        try:
            self._parse_csv_output(endpoint_output)
        except Exception as read_csv_exception:
            # Deliberately broad: this is a format probe, any failure means "not CSV".
            logger.info("Predictions are not in CSV format")
            logger.debug("CSV Parse exception: %s", read_csv_exception)
            return False
        logger.info("Predictions are in CSV format")
        return True

    def read_classification(self, endpoint_output):
        """Parse a classification response into a DataFrame of predictions.

        :param endpoint_output: raw response bytes, one CSV row per prediction.
        :return: DataFrame with a ``prediction`` column and, when a row carries
            several probabilities, one ``proba_<class>`` column per probability.
            Class names are looked up by position in ``self.value_to_class``
            (presumably set by ``BaseReader.__init__`` -- confirm).
        """
        results = []
        for row in self._parse_csv_output(endpoint_output).to_dict(orient="records"):
            # The label is always exposed as a string, whatever dtype pandas inferred.
            result = {"prediction": str(row[0])}
            # Rows whose extra columns are all empty keep the label only.
            if self._has_probas(row):
                probas = self._get_probas(row)
                if len(probas) == 1:
                    # only proba of first class
                    # NOTE(review): this replaces the label read from the first column
                    # with the probability; kept as-is, confirm this is what the
                    # consumers of the DataFrame expect.
                    result["prediction"] = probas[0]
                else:
                    for i, proba in enumerate(probas):
                        result["proba_{}".format(self.value_to_class[i])] = proba
            results.append(result)
        return pd.DataFrame(results)

    def _get_probas(self, row):
        """Extract the probabilities of one parsed CSV row.

        The first call decides the layout by trying to JSON-decode the second
        column, and stores the outcome in ``self.probas_format``; later calls
        reuse it.

        :param row: mapping of column position to value for a single CSV row.
        :return: the decoded JSON value of the second column, or the list of the
            values of every column after the first.
        """
        assert len(row) >= 2, "the row doesn't have two columns: {}".format(row)
        if self.probas_format is None:
            try:
                # Case where probas are in a JSON-like array as the second column of the CSV row.
                parsed_probas = json.loads(row[1])
            except Exception:
                # Best-effort detection: anything that is not decodable JSON
                # (e.g. a plain number) means one probability per column.
                self.probas_format = ProbasFormat.CSV
            else:
                self.probas_format = ProbasFormat.JSON
                # Reuse the value decoded during detection instead of parsing twice.
                return parsed_probas
        if self.probas_format == ProbasFormat.JSON:
            return json.loads(row[1])
        # Walk the columns by position so the result does not rely on dict ordering.
        return [row[column] for column in sorted(row) if column != 0]

    def read_binary(self, endpoint_output):
        """Parse a binary classification response; same handling as ``read_classification``."""
        return self.read_classification(endpoint_output)

    def read_multiclass(self, endpoint_output):
        """Parse a multiclass response; same handling as ``read_classification``."""
        return self.read_classification(endpoint_output)

    def read_regression(self, endpoint_output):
        """Parse a regression response into a DataFrame.

        :param endpoint_output: raw response bytes, one CSV row per prediction.
        :return: DataFrame whose ``prediction`` column holds the first CSV column.
        """
        rows = self._parse_csv_output(endpoint_output).to_dict(orient="records")
        return pd.DataFrame([{"prediction": row[0]} for row in rows])
