import json

from dataiku.external_ml.proxy_model.sagemaker.inputformat.abstract_sagemaker_writer import AbstractSagemakerWriter


class SagemakerJSONExtendedWriter(AbstractSagemakerWriter):
    """Writer for SageMaker's "extended" JSON inference input format.

    Every row of the input frame becomes one element of the top-level
    ``instances`` list, wrapped as ``{"data": {"features": {"values": row}}}``.
    Format reference:
    https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-inference.html#common-in-formats
    """
    NAME = "INPUT_SAGEMAKER_JSON_EXTENDED"

    def __init__(self, client):
        # No format-specific state; the base writer keeps the client.
        super(SagemakerJSONExtendedWriter, self).__init__(client)

    def build_request_payload(self, input_df):
        """Serialize ``input_df`` into the request body string.

        :param input_df: object exposing ``.values.tolist()`` -- presumably a
            pandas DataFrame (confirm against callers).
        :return: JSON text ``{"instances": [...]}`` with one entry per row,
            in row order.
        """
        rows = input_df.values.tolist()
        instances = list(map(SagemakerJSONExtendedWriter.get_line, rows))
        # NOTE(review): json.dumps defaults to allow_nan=True, so NaN cells are
        # written as bare `NaN`, which is not strict JSON -- confirm the target
        # endpoint accepts it.
        return json.dumps({"instances": instances})

    def write(self, input_df):
        """Send ``input_df`` to the endpoint and return the client's result."""
        # Resolve the client method first, matching the original evaluation order.
        send = self.client.call_endpoint
        body = self.build_request_payload(input_df)
        # ContentType is handed to call_endpoint as an argument; it is *not*
        # an HTTP header in this case.
        return send(body, {"ContentType": "application/json"})

    @staticmethod
    def get_line(row):
        """Wrap one row's feature values in the extended-format envelope."""
        features = {"values": row}
        return {"data": {"features": features}}
