import json
import logging

import pandas as pd
from dataiku.external_ml.proxy_model.common.outputformat import BaseReader

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class SagemakerArrayAsStringReader(BaseReader):
    """Read endpoint outputs that are a single line of comma-separated values.

    Accepted payloads look like ``[1,0,1]``, ``[0.2, 0.7]`` or ``0.2,0.7,0.1``.
    Each value becomes one row of a single-column DataFrame named ``prediction``.
    """
    NAME = "OUTPUT_SAGEMAKER_ARRAY_AS_STRING"

    def __init__(self, prediction_type, value_to_class):
        super(SagemakerArrayAsStringReader, self).__init__(prediction_type, value_to_class)

    def can_read(self, endpoint_output):
        """Tell whether ``endpoint_output`` looks like a one-line list of values.

        :param endpoint_output: raw payload returned by the endpoint (decoded as UTF-8)
        :return: True if the decoded text is wrapped in square brackets, or if it
            contains commas, no newline, and is not valid JSON; False otherwise,
            including when decoding or inspection raises.
        """
        # Probe method: any failure simply means "this reader does not apply".
        try:
            predictions_dec = endpoint_output.decode("utf-8")
            if predictions_dec.startswith("[") and predictions_dec.endswith("]"):
                logger.info("Predictions are in a one line array")
                return True
            # Multi-line output is treated as CSV, not as a one-line array.
            if "\n" in predictions_dec:
                return False
            # Without commas there is nothing to split on.
            if "," not in predictions_dec:
                return False
            # Commas but no square brackets: make sure we are not dealing with JSON first.
            try:
                json.loads(predictions_dec)
            except ValueError:
                # Not json, not CSV, but commas in it... should be ok
                logger.info("Predictions are in a one line string, with commas-delimited values.")
                return True
            return False
        except Exception as read_str_exception:
            logger.info("Predictions are not in a one-line string format")
            logger.debug("Parse exception: {}".format(read_str_exception))
            return False

    @staticmethod
    def _split_values(endpoint_output):
        """Decode ``endpoint_output`` and split it into a list of string tokens.

        Surrounding whitespace and square brackets are removed and each
        comma-separated token is stripped of whitespace. A token wrapped in a pair of
        double quotes (as in ``["a", "b"]``) is unquoted so it can match a plain class
        label. An empty payload such as ``[]`` yields an empty list rather than a
        single empty token.
        """
        content = endpoint_output.decode("utf-8").strip().strip("][").strip()
        if not content:
            return []
        values = []
        for token in content.split(","):
            token = token.strip()
            if len(token) >= 2 and token.startswith('"') and token.endswith('"'):
                token = token[1:-1]
            values.append(token)
        return values

    def _read_classification(self, endpoint_output):
        """Parse classification predictions into a ``prediction`` column.

        Values are converted with ``pd.to_numeric(downcast="integer")``. Integral
        values become the smallest integer dtype, while other numeric values stay
        floating point. If any value is not numeric, the column is kept as strings.

        :param endpoint_output: raw payload returned by the endpoint
        :return: pandas DataFrame with a single ``prediction`` column
        """
        res_df = pd.DataFrame({"prediction": self._split_values(endpoint_output)})
        try:
            res_df["prediction"] = pd.to_numeric(res_df["prediction"], downcast="integer")
        except (ValueError, TypeError):
            logger.info("Predictions were not integers, casting as string instead.")
            res_df["prediction"] = res_df["prediction"].astype(str)
        return res_df

    def read_binary(self, endpoint_output):
        """Read binary-classification predictions (same parsing as multiclass)."""
        return self._read_classification(endpoint_output)

    def read_multiclass(self, endpoint_output):
        """Read multiclass-classification predictions (same parsing as binary)."""
        return self._read_classification(endpoint_output)

    def read_regression(self, endpoint_output):
        """Parse regression predictions into a float ``prediction`` column.

        If some value cannot be converted to float, a warning is logged and the
        column is left as strings (best-effort, as before).

        :param endpoint_output: raw payload returned by the endpoint
        :return: pandas DataFrame with a single ``prediction`` column
        """
        res_df = pd.DataFrame({"prediction": self._split_values(endpoint_output)})
        try:
            res_df["prediction"] = res_df["prediction"].astype(float)
        except (ValueError, TypeError) as cast_exception:
            logger.warning("Could not convert regression predictions to float, keeping them as strings.")
            logger.debug("Cast exception: {}".format(cast_exception))
        return res_df
