/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.lambda.client;

import com.dataiku.common.server.SerializedError;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.apideployer.DeployerUtils;
import com.dataiku.dip.apideployer.datamodel.actual.APIServiceDeploymentHeavyStatus;
import com.dataiku.dip.apideployer.datamodel.config.AbstractFullyManagedAPIDeploymentInfra;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.externalinfras.sagemaker.SageMakerUtils;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.core.SdkBytes;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeClient;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointResponse;
import com.dataiku.lambda.client.BaseLambdaAPIClient;
import com.dataiku.lambda.model.serverconfig.LambdaEndpointConfig;
import com.dataiku.lambda.model.studioconfig.ApiEndpointQuery;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.log4j.Logger;

/**
 * Client that submits API-endpoint queries to an AWS SageMaker endpoint through the SageMaker
 * runtime API, collecting one {@link BaseLambdaAPIClient.ResponseOrError} per submitted query.
 */
public class SageMakerLambdaAPIClient
implements BaseLambdaAPIClient {
    private static final String ACCEPT = "application/json";
    private static final String CONTENT_TYPE = "application/json";
    /** Query property that is moved out of the record and into the top-level SageMaker payload. */
    private static final String EXPLANATIONS_PROPERTY = "explanations";
    /** Prefix of the SageMaker inference id; the query name is appended to it. */
    private static final String DKU_MISC_INFERENCE_TYPE = "dku_misc_";
    private static final Logger logger = Logger.getLogger("dku.lambda.test.sagemaker");

    /**
     * Converts a query into the payload shape expected by the SageMaker endpoint for the given
     * endpoint type.
     *
     * <p>Prediction-like endpoints (standard, causal, clustering, custom prediction) get their
     * record wrapped into a {@code {"items": [...], "explanations": ...}} object. Python-function
     * endpoints receive the query unchanged (the same instance is returned).
     *
     * @param query the query to convert; for prediction types it is not modified (a deep copy is used)
     * @param endpointType type of the target endpoint
     * @return the query in SageMaker format
     * @throws IllegalArgumentException if {@code endpointType} is {@code null} or not supported
     */
    public static ApiEndpointQuery updateQueryToSageMakerExpectedFormat(ApiEndpointQuery query, LambdaEndpointConfig.EndpointType endpointType) throws IllegalArgumentException {
        if (endpointType == null) {
            // A switch on a null enum would throw NullPointerException; report it like any other
            // unsupported type so callers only have to handle the declared exception.
            throw new IllegalArgumentException("Unsupported endpoint type null");
        }
        switch (endpointType) {
            case STD_PREDICTION:
            case STD_CAUSAL_PREDICTION:
            case STD_CLUSTERING:
            case CUSTOM_PREDICTION: {
                return SageMakerLambdaAPIClient.createSageMakerPredictionQuery(query);
            }
            case PY_FUNCTION: {
                return query;
            }
        }
        throw new IllegalArgumentException("Unsupported endpoint type " + endpointType);
    }

    /**
     * Builds a prediction payload from {@code query}.
     *
     * <p>The query body becomes the single element of {@code items}. If it contains an
     * {@code explanations} property, that property is removed from the item and placed at the
     * top level of the payload instead. The input query is not mutated.
     */
    private static ApiEndpointQuery createSageMakerPredictionQuery(ApiEndpointQuery query) {
        ApiEndpointQuery tqCopy = (ApiEndpointQuery) JSON.deepCopy((Object) query);
        SageMakerPredictionQuery sageMakerPredictionQuery = new SageMakerPredictionQuery();
        if (tqCopy.q.has(EXPLANATIONS_PROPERTY)) {
            sageMakerPredictionQuery.explanations = tqCopy.q.get(EXPLANATIONS_PROPERTY);
            tqCopy.q.remove(EXPLANATIONS_PROPERTY);
        }
        sageMakerPredictionQuery.items.add(tqCopy.q);
        JsonObject updatedQuery = JSON.toJsonObject((Object) sageMakerPredictionQuery);
        return new ApiEndpointQuery(query.name, updatedQuery);
    }

    /**
     * Sends every query to a SageMaker endpoint and collects the responses.
     *
     * <p>A failure on one query is logged and recorded as that query's error; it does not stop
     * the remaining queries. Each request uses an inference id of {@code "dku_misc_" + query name}.
     *
     * @param authCtx security context used to log into the SageMaker runtime
     * @param infra infrastructure whose {@code authConnection} is used for the login
     * @param awsRegion AWS region of the endpoint
     * @param sageMakerEndpointName name of the SageMaker endpoint to invoke
     * @param endpoint not used by this implementation
     * @param queries queries to send; each body is sent as-is (no format conversion happens here)
     * @param forTest not used by this implementation
     * @return one response-or-error entry per query, in iteration order
     * @throws IOException if logging into the SageMaker runtime fails
     * @throws DKUSecurityException if logging into the SageMaker runtime is not permitted
     */
    public static BaseLambdaAPIClient.ApiEndpointResponses runQueries_NT(AuthCtx authCtx, AbstractFullyManagedAPIDeploymentInfra infra, String awsRegion, String sageMakerEndpointName, APIServiceDeploymentHeavyStatus.EndpointSummary endpoint, Collection<ApiEndpointQuery> queries, boolean forTest) throws IOException, DKUSecurityException {
        BaseLambdaAPIClient.ApiEndpointResponses apiEndpointResponses = new BaseLambdaAPIClient.ApiEndpointResponses();
        // NOTE(review): the runtime client is never closed here. If loginSageMakerRuntime_NT builds a
        // fresh client on each call (rather than returning a cached one), it should be closed once
        // all queries are done — confirm.
        SageMakerRuntimeClient sageMakerRuntime = SageMakerUtils.loginSageMakerRuntime_NT(authCtx, infra.authConnection, awsRegion, DeployerUtils.getInfraConnectTimeout(), DeployerUtils.getInfraSocketTimeout());
        for (ApiEndpointQuery tq : queries) {
            BaseLambdaAPIClient.ResponseOrError roe = new BaseLambdaAPIClient.ResponseOrError();
            roe.query = tq;
            JsonObject sageMakerQuery = tq.q;
            // Declared outside the try so the failure log can reuse it. Re-serializing inside the
            // catch could throw again and abort the whole loop.
            String sageMakerQueryJSON = null;
            try {
                sageMakerQueryJSON = JSON.json((Object) sageMakerQuery);
                if (logger.isDebugEnabled()) {
                    logger.debug("Submit query with name `" + tq.name + "` and body `" + sageMakerQueryJSON + "` to SageMaker endpoint " + sageMakerEndpointName);
                }
                SdkBytes body = SdkBytes.fromUtf8String(sageMakerQueryJSON);
                InvokeEndpointRequest invokeEndpointRequest = (InvokeEndpointRequest) InvokeEndpointRequest.builder()
                        .accept(ACCEPT)
                        .contentType(CONTENT_TYPE)
                        .endpointName(sageMakerEndpointName)
                        .inferenceId(DKU_MISC_INFERENCE_TYPE + tq.name)
                        .body(body)
                        .build();
                InvokeEndpointResponse invokeEndpointResult = SageMakerUtils.invokeEndpoint_NT(sageMakerRuntime, invokeEndpointRequest);
                String response = invokeEndpointResult.body().asUtf8String();
                roe.response = JSON.parse(response, JsonObject.class);
            }
            catch (Exception e) {
                String queryDescription = sageMakerQueryJSON != null ? sageMakerQueryJSON : "with name `" + tq.name + "`";
                logger.warn("Query " + queryDescription + " sent to SageMaker endpoint " + sageMakerEndpointName + " failed.", e);
                // NOTE(review): the second flag comes from DKUApp while the others come from
                // ApplicationConfigurator — confirm this mix is intended.
                roe.error = new SerializedError((Throwable) e, !ApplicationConfigurator.hideErrorStacks(), !DKUApp.hideErrorStacks(), !ApplicationConfigurator.hideLogTails());
            }
            apiEndpointResponses.responses.add(roe);
        }
        return apiEndpointResponses;
    }

    /**
     * Prediction payload sent to SageMaker. It is serialized by field name, so the field names
     * {@code items} and {@code explanations} are the JSON keys and must not be renamed.
     */
    private static class SageMakerPredictionQuery {
        /** Records to score; a single record when built by createSageMakerPredictionQuery. */
        List<JsonElement> items = new ArrayList<JsonElement>();
        /** Explanation settings moved out of the record; stays null when the query has none. */
        JsonElement explanations;

        private SageMakerPredictionQuery() {
        }
    }
}

