import json
import logging

import pandas as pd

from dataiku.external_ml.proxy_model.common.outputformat import BaseReader
from dataiku.external_ml.proxy_model.common.utils import decode_base64
from dataiku.external_ml.utils import convert_prediction_response_to_dataframe


class DeployAnywhereReader(BaseReader):
    """Reader for endpoint outputs identified as ``OUTPUT_DEPLOY_ANYWHERE_JSON``.

    All prediction types are handled identically: the endpoint output is
    parsed as JSON and handed to ``convert_prediction_response_to_dataframe``.
    """

    NAME = "OUTPUT_DEPLOY_ANYWHERE_JSON"

    def __init__(self, prediction_type, value_to_class):
        """Forward ``prediction_type`` and ``value_to_class`` to ``BaseReader``."""
        super(DeployAnywhereReader, self).__init__(prediction_type, value_to_class)

    def can_read(self, endpoint_output):
        """Return True when ``self.read(endpoint_output)`` completes.

        Any exception raised while reading propagates to the caller.
        """
        # NOTE(review): this method never returns False; an unreadable output
        # surfaces as an exception. Confirm callers rely on that behaviour.
        # ``read`` is not defined in this class; presumably inherited from
        # BaseReader -- verify.
        self.read(endpoint_output)
        return True

    def read_binary(self, endpoint_output):
        """Read a binary classification output (same parsing as other types)."""
        return self._read_all(endpoint_output)

    def read_multiclass(self, endpoint_output):
        """Read a multiclass classification output (same parsing as other types)."""
        return self._read_all(endpoint_output)

    def read_regression(self, endpoint_output):
        """Read a regression output (same parsing as other types)."""
        return self._read_all(endpoint_output)

    def _read_all(self, endpoint_output):
        """Parse ``endpoint_output`` as JSON and convert it to predictions.

        :param endpoint_output: JSON document accepted by ``json.loads``
        :return: result of ``convert_prediction_response_to_dataframe``
        """
        from_json = json.loads(endpoint_output)
        return convert_prediction_response_to_dataframe(from_json)

    def decode_output_logs(self, logs_output_data):
        """Decode logged endpoint outputs and stack their predictions.

        Every entry of ``logs_output_data`` is decoded with ``decode_base64``
        and read with ``self.read``. Entries that raise while being read, or
        that yield an empty frame, are logged and skipped.

        :param logs_output_data: object exposing ``apply`` and iterable
            afterwards (presumably a pandas Series -- confirm against caller)
        :return: concatenation of the per-entry prediction frames; each row is
            labelled with the pair (entry position, row position in entry)
        :raises ValueError: when no entry yielded a prediction frame
        """
        decoded_logs_output_data = logs_output_data.apply(decode_base64)
        list_dfs = []
        # ``index`` is the position in iteration order, not the original label.
        for index, raw_row in enumerate(decoded_logs_output_data):
            try:
                inference_query_predictions_df = self.read(raw_row).reset_index(drop=True)
                if inference_query_predictions_df.empty:
                    # Arguments are passed lazily so logging does the formatting.
                    logging.warning("A empty prediction logs row is found : %s", raw_row)
                    continue  # wrong input format leading to no prediction
                # Label rows before dropping NaNs so surviving rows keep their
                # original position within the entry.
                inference_query_predictions_df.index = [(index, i) for i in range(len(inference_query_predictions_df))]
                # NOTE(review): a frame emptied by dropna() is still appended;
                # the emptiness check above runs before NaN rows are removed.
                list_dfs.append(inference_query_predictions_df.dropna())
            except Exception as e:
                # Best effort: one unreadable entry must not abort the others.
                logging.info("Logs output data cannot be read for row %s", raw_row)
                logging.debug("Logs output reading exception: %s", e)
                continue

        if not list_dfs:
            raise ValueError("No prediction data in decoded logs")

        return pd.concat(list_dfs, ignore_index=False)
