/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.bedrock;

import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.BedrockConnection;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.bedrock.converse.ConverseAPILLMMarshall;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericChatCompletionLLMMarshall;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericEmbeddingLLMMarshall;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericImageGenerationLLMMarshall;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericLLMHandling;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericTextCompletionLLMMarshall;
import com.dataiku.dip.recipes.nlp.finetuning.FineTuningRecipePayloadParams;
import com.dataiku.dip.security.aws.AWSClientBrokerService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.core.SdkBytes;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.http.SdkHttpClient;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.regions.Region;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrock.BedrockClient;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrock.BedrockClientBuilder;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrock.model.BedrockException;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrock.model.CreateModelCustomizationJobRequest;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrock.model.CreateModelCustomizationJobResponse;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrock.model.CreateProvisionedModelThroughputResponse;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrock.model.GetModelCustomizationJobResponse;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrock.model.GetProvisionedModelThroughputResponse;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrock.model.ListFoundationModelsResponse;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrock.model.ModelCustomization;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrock.model.Validator;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClientBuilder;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientBuilder;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.BedrockRuntimeResponse;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ConverseRequest;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ConverseResponse;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.PayloadPart;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

public class RawBedrockClient {
    // Fine-tuning configuration handed to the constructor; read for the bucket name and the customization-job role ARN.
    private final BedrockConnection.BedrockFineTuningSettings fineTuningSettings;
    // Synchronous Bedrock client used for customization jobs, foundation-model listing and provisioned throughput.
    private final BedrockClient awsClient;
    // Synchronous runtime client used for the non-streaming InvokeModel and Converse calls.
    private final BedrockRuntimeClient awsRuntimeClient;
    // Asynchronous runtime client used by the two streaming entry points.
    private final BedrockRuntimeAsyncClient awsRuntimeAsyncClient;
    // Shared logger for this class and its nested chunk handlers.
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.bedrock.client");
    // NOTE(review): not referenced anywhere in the visible part of this class.
    private static final int MAX_JSON_SAMPLE_LOG_LENGTH = 120;

    /**
     * Returns the bucket name carried by the fine-tuning settings this client was built with.
     *
     * @return the {@code bucketName} field of the fine-tuning settings
     */
    public String getFineTuningBucketName() {
        BedrockConnection.BedrockFineTuningSettings settings = this.fineTuningSettings;
        return settings.bucketName;
    }

    /**
     * Builds the three AWS clients (Bedrock, Bedrock runtime, Bedrock runtime async) for one region.
     *
     * <p>Each client gets its own freshly requested HTTP client builder and its own retry strategy
     * instance; all three share the same credentials provider.
     *
     * @param region              AWS region name, resolved through {@code Region.of}
     * @param credentialsProvider credentials used by every client
     * @param networkSettings     supplies the query timeout (passed twice to the HTTP builder factory) and the retry strategy
     * @param proxySettings       proxy configuration forwarded to the HTTP builder factory
     * @param fineTuningSettings  fine-tuning configuration kept for later fine-tuning calls
     */
    public RawBedrockClient(String region, AwsCredentialsProvider credentialsProvider, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, @Nonnull ProxySettings proxySettings, BedrockConnection.BedrockFineTuningSettings fineTuningSettings) {
        Region awsRegion = Region.of(region);
        AWSClientBrokerService broker = (AWSClientBrokerService)SpringUtils.getBean(AWSClientBrokerService.class);

        // Bedrock client
        SdkHttpClient.Builder controlHttp = broker.getHttpClientBuilder(proxySettings, networkSettings.queryTimeoutMS, networkSettings.queryTimeoutMS);
        BedrockClientBuilder controlBuilder = (BedrockClientBuilder)broker.createBedrockClientBuilder(controlHttp, networkSettings.createRetryStrategy(), awsRegion).credentialsProvider(credentialsProvider);
        this.awsClient = (BedrockClient)controlBuilder.build();

        // Synchronous runtime client
        SdkHttpClient.Builder runtimeHttp = broker.getHttpClientBuilder(proxySettings, networkSettings.queryTimeoutMS, networkSettings.queryTimeoutMS);
        BedrockRuntimeClientBuilder runtimeBuilder = (BedrockRuntimeClientBuilder)broker.createBedrockRuntimeClientBuilder(runtimeHttp, networkSettings.createRetryStrategy(), awsRegion).credentialsProvider(credentialsProvider);
        this.awsRuntimeClient = (BedrockRuntimeClient)runtimeBuilder.build();

        // Asynchronous runtime client
        SdkAsyncHttpClient.Builder asyncHttp = broker.getAsyncHttpClientBuilder(proxySettings, networkSettings.queryTimeoutMS, networkSettings.queryTimeoutMS);
        BedrockRuntimeAsyncClientBuilder asyncBuilder = (BedrockRuntimeAsyncClientBuilder)broker.createBedrockRuntimeAsyncClientBuilder(asyncHttp, networkSettings.createRetryStrategy(), awsRegion).credentialsProvider(credentialsProvider);
        this.awsRuntimeAsyncClient = (BedrockRuntimeAsyncClient)asyncBuilder.build();

        this.fineTuningSettings = fineTuningSettings;
    }

    /**
     * Closes the three underlying AWS clients.
     *
     * <p>The closes are chained through {@code finally} blocks so that a failure while closing
     * one client does not prevent the remaining clients from being closed.
     */
    public void close() {
        try {
            this.awsClient.close();
        }
        finally {
            try {
                this.awsRuntimeClient.close();
            }
            finally {
                this.awsRuntimeAsyncClient.close();
            }
        }
    }

    /**
     * Sends a synchronous InvokeModel call whose body is the JSON serialization of {@code requestBody}.
     *
     * @param requestBody JSON payload to send
     * @param modelId     identifier of the model to invoke
     * @return the raw InvokeModel response
     */
    private InvokeModelResponse executeInvokeRequest(JsonObject requestBody, String modelId) {
        return this.awsRuntimeClient.invokeModel(builder -> {
            SdkBytes payload = SdkBytes.fromUtf8String((String)JSON.json((Object)requestBody));
            builder.body(payload).modelId(modelId);
        });
    }

    /**
     * Decodes the response body as UTF-8 text and parses it into a {@link JsonObject}.
     *
     * @param response InvokeModel response to read
     * @return the parsed body
     */
    private static JsonObject getBodyFromInvokeResponse(InvokeModelResponse response) {
        String bodyText = response.body().asUtf8String();
        return (JsonObject)JSON.parse((String)bodyText, JsonObject.class);
    }

    /**
     * Extracts the two Bedrock token-count HTTP headers from a runtime response.
     *
     * <p>The returned map holds the first value of {@code X-Amzn-Bedrock-Input-Token-Count} and
     * {@code X-Amzn-Bedrock-Output-Token-Count}, keyed by those exact header names.
     *
     * @param response runtime response whose HTTP headers are read
     * @return a map with exactly the two token-count entries
     * @throws IllegalStateException if either header is absent or has no value (previously this
     *         surfaced as a message-less NullPointerException / IndexOutOfBoundsException)
     */
    private static Map<String, String> getTokenCountHeaders(BedrockRuntimeResponse response) {
        Map<String, List<String>> responseHeaders = response.sdkHttpResponse().headers();
        Map<String, String> headers = new HashMap<String, String>();
        String[] tokenCountKeys = new String[]{"X-Amzn-Bedrock-Input-Token-Count", "X-Amzn-Bedrock-Output-Token-Count"};
        for (String key : tokenCountKeys) {
            List<String> values = responseHeaders.get(key);
            if (values == null || values.isEmpty()) {
                // Fail with the header name so the problem can be diagnosed from the message alone.
                throw new IllegalStateException("Bedrock response is missing the '" + key + "' header");
            }
            headers.put(key, values.get(0));
        }
        return headers;
    }

    /**
     * Runs a non-streaming chat completion through the InvokeModel API.
     *
     * <p>The query is built and the answer parsed by the chat marshall selected for {@code handling};
     * the token-count response headers are handed to the marshall together with the parsed body.
     *
     * @param modelId  identifier of the model to invoke
     * @param handling selects the marshall used to build the query and parse the answer
     * @param messages chat messages to send
     * @param ccs      completion settings forwarded to the marshall
     * @return the completion parsed by the marshall
     * @throws IOException declared by the method signature
     */
    public LLMClient.SimpleCompletionResponse chatCompleteInvokeAPI(String modelId, GenericLLMHandling handling, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) throws IOException {
        final GenericChatCompletionLLMMarshall chatMarshall = GenericChatCompletionLLMMarshall.get(handling, null);
        final JsonObject query = chatMarshall.prepareChatCompletionQuery(messages, ccs);
        logger.trace(() -> String.format("Bedrock raw chat query: %s", JSON.pretty((Object)query)));

        final InvokeModelResponse rawResponse = this.executeInvokeRequest(query, modelId);
        final JsonObject answer = RawBedrockClient.getBodyFromInvokeResponse(rawResponse);
        logger.trace(() -> String.format("Bedrock raw chat response: %s", JSON.pretty((Object)answer)));

        Map<String, String> tokenCounts = RawBedrockClient.getTokenCountHeaders((BedrockRuntimeResponse)rawResponse);
        return chatMarshall.parseChatCompletionResponse(tokenCounts, (JsonElement)answer);
    }

    /**
     * Runs a non-streaming text completion through the InvokeModel API.
     *
     * <p>The raw request is logged lazily at TRACE level, like every other raw payload in this
     * class. It was previously pretty-printed eagerly and written at INFO level, which put the
     * full prompt into regular logs.
     *
     * @param modelId  identifier of the model to invoke
     * @param handling selects the marshall used to build the query and parse the answer
     * @param prompt   prompt text; also handed back to the marshall when parsing the answer
     * @param ccs      completion settings forwarded to the marshall
     * @return the completion parsed by the marshall
     * @throws IOException declared by the method signature
     */
    public LLMClient.SimpleCompletionResponse completeInvokeAPI(String modelId, GenericLLMHandling handling, String prompt, CoreCompletionSettings ccs) throws IOException {
        GenericTextCompletionLLMMarshall marshall = GenericTextCompletionLLMMarshall.get(handling, null);
        JsonObject requestBody = marshall.prepareTextCompletionQuery(prompt, ccs);
        logger.trace(() -> String.format("Bedrock raw completion request: %s", JSON.pretty((Object)requestBody)));
        InvokeModelResponse response = this.executeInvokeRequest(requestBody, modelId);
        JsonObject responseBody = RawBedrockClient.getBodyFromInvokeResponse(response);
        logger.trace(() -> String.format("Bedrock raw completion response: %s", JSON.pretty((Object)responseBody)));
        return marshall.parseTextCompletionResponse((JsonElement)responseBody, prompt);
    }

    /**
     * Runs a non-streaming chat completion through the Converse API.
     *
     * @param modelId  model identifier stamped onto the request produced by the marshall
     * @param handling used to construct the Converse marshall
     * @param messages chat messages to send
     * @param ccs      completion settings forwarded to the marshall
     * @return the completion parsed by the marshall from the Converse response
     */
    public LLMClient.SimpleCompletionResponse chatCompleteConverseAPI(String modelId, GenericLLMHandling handling, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) {
        final ConverseAPILLMMarshall converseMarshall = new ConverseAPILLMMarshall(handling);
        // The marshall produces a request without a model; add it here.
        final ConverseRequest converseRequest = (ConverseRequest)converseMarshall.prepareChatCompletionQuery(messages, ccs).toBuilder().modelId(modelId).build();
        logger.trace(() -> String.format("Bedrock Converse raw chat query: %s", JSON.pretty((Object)converseRequest)));

        final ConverseResponse converseResponse = this.awsRuntimeClient.converse(converseRequest);
        logger.trace(() -> String.format("Bedrock Converse raw chat response: %s", JSON.pretty((Object)converseResponse.toString())));

        return converseMarshall.parseChatCompletionResponse(converseResponse);
    }

    /**
     * Computes embeddings through the InvokeModel API.
     *
     * @param modelId  identifier of the model to invoke
     * @param handling selects the embedding marshall
     * @param query    embedding query turned into the request body by the marshall
     * @return the embedding response parsed by the marshall
     * @throws IOException              declared by the method signature
     * @throws IllegalArgumentException declared by the method signature
     */
    public LLMClient.SimpleEmbeddingResponse embed(String modelId, GenericLLMHandling handling, LLMClient.EmbeddingQuery query) throws IOException, IllegalArgumentException {
        GenericEmbeddingLLMMarshall embeddingMarshall = GenericEmbeddingLLMMarshall.get(handling, null);
        JsonObject payload = embeddingMarshall.prepareInputsEmbedding(query);
        JsonObject answer = RawBedrockClient.getBodyFromInvokeResponse(this.executeInvokeRequest(payload, modelId));
        return embeddingMarshall.parseEmbeddingResponse((JsonElement)answer);
    }

    /**
     * Generates images through the InvokeModel API.
     *
     * @param handling selects the image-generation marshall
     * @param modelId  identifier of the model to invoke
     * @param query    image-generation query; also handed back to the marshall when parsing the answer
     * @return the image-generation response parsed by the marshall
     * @throws IOException declared by the method signature
     */
    public LLMClient.ImageGenerationResponse generateImage(GenericLLMHandling handling, String modelId, LLMClient.ImageGenerationQuery query) throws IOException {
        GenericImageGenerationLLMMarshall imageMarshall = GenericImageGenerationLLMMarshall.get(handling);
        JsonObject payload = imageMarshall.prepareImageGenerationQuery(query);
        // Only serialize the payload for logging when TRACE is actually on.
        if (logger.isTraceEnabled()) {
            String traceLine = "Raw image generation query on model " + modelId + ": " + JSON.json((Object)payload);
            logger.trace((Object)traceLine);
        }
        InvokeModelResponse rawResponse = this.executeInvokeRequest(payload, modelId);
        JsonObject answer = RawBedrockClient.getBodyFromInvokeResponse(rawResponse);
        return imageMarshall.parseImageGenerationResponse((JsonElement)answer, query);
    }

    /**
     * Runs a streaming completion through the InvokeModelWithResponseStream API and blocks until
     * the stream has finished.
     *
     * <p>For chat models ({@code handling.isChatModel}) the request is built from {@code messages};
     * otherwise it is built from {@code prompt}. Each received chunk is passed to an
     * {@code InvokeStreamChunkHandler}, which forwards parsed content to {@code consumer}.
     *
     * <p>Failure handling: an {@link ExecutionException} is unwrapped and its cause rethrown only
     * when that cause is an {@link Exception}. A cause that is an {@code Error} or is missing
     * used to trigger a ClassCastException / NullPointerException in the unwrap itself; in that
     * case the ExecutionException is now rethrown as-is. If the wait is interrupted, the thread's
     * interrupt flag is restored before the exception propagates.
     *
     * @param consumer receiver of stream events
     * @param model    Bedrock model; provides the inference model id and is passed to the chunk handler
     * @param handling selects chat vs. text-completion marshalling
     * @param messages chat messages; must be non-null for chat models
     * @param prompt   prompt text; must be non-null for completion models
     * @param ccs      completion settings forwarded to the marshall
     * @throws IllegalArgumentException if the input required by the model kind is null
     * @throws Exception                any failure raised while streaming or while handling a chunk
     */
    public void streamCompleteInvokeAPI(LLMClient.StreamedCompletionResponseConsumer consumer, BedrockConnection.BedrockModel model, GenericLLMHandling handling, @Nullable List<LLMClient.ChatMessage> messages, @Nullable String prompt, CoreCompletionSettings ccs) throws Exception {
        InvokeStreamChunkHandler chunkHandler;
        JsonObject requestBody;
        String modelId = model.getInferenceModelId();
        if (handling.isChatModel) {
            if (messages == null) {
                throw new IllegalArgumentException("The provided messages must not be null for chat models");
            }
            GenericChatCompletionLLMMarshall chatMarshall = GenericChatCompletionLLMMarshall.get(handling, null);
            requestBody = chatMarshall.prepareChatCompletionQuery(messages, ccs);
            logger.trace(() -> String.format("Bedrock raw chat streaming query: %s", JSON.pretty((Object)requestBody)));
            chunkHandler = new InvokeStreamChunkHandler(chatMarshall, consumer, model);
        } else {
            if (prompt == null) {
                throw new IllegalArgumentException("The provided prompt must not be null for completion models");
            }
            GenericTextCompletionLLMMarshall completionMarshall = GenericTextCompletionLLMMarshall.get(handling, null);
            requestBody = completionMarshall.prepareTextCompletionQuery(prompt, ccs);
            logger.trace(() -> String.format("Bedrock raw completion streaming query: %s", JSON.pretty((Object)requestBody)));
            chunkHandler = new InvokeStreamChunkHandler(completionMarshall, consumer, model);
        }
        InvokeModelWithResponseStreamRequest request = (InvokeModelWithResponseStreamRequest)InvokeModelWithResponseStreamRequest.builder().body(SdkBytes.fromUtf8String((String)JSON.json((Object)requestBody))).modelId(modelId).build();
        InvokeModelWithResponseStreamResponseHandler responseStreamHandler = InvokeModelWithResponseStreamResponseHandler.builder().subscriber(InvokeModelWithResponseStreamResponseHandler.Visitor.builder().onChunk(chunkHandler::onChunk).build()).build();
        try {
            // Block until the whole stream has been consumed.
            this.awsRuntimeAsyncClient.invokeModelWithResponseStream(request, responseStreamHandler).get();
            // Surface a chunk-handling failure even when the stream itself completed normally.
            chunkHandler.throwChunkParsingExceptionIfAny();
        }
        catch (Exception e) {
            logger.error((Object)String.format("Can't call invoke streaming '%s': %s", modelId, e.getMessage()));
            if (e instanceof InterruptedException) {
                // Keep the interruption visible to callers further up the stack.
                Thread.currentThread().interrupt();
            }
            if (e instanceof ExecutionException && e.getCause() instanceof Exception) {
                throw (Exception)e.getCause();
            }
            throw e;
        }
    }

    /**
     * Runs a streaming chat completion through the ConverseStream API and blocks until the stream
     * has finished.
     *
     * <p>Regular events go to {@code ConverseStreamChunkHandler.onDefaultChunk}; the metadata
     * event goes to {@code onMetadataChunk}.
     *
     * <p>Failure handling: an {@link ExecutionException} is unwrapped and its cause rethrown only
     * when that cause is an {@link Exception}. A cause that is an {@code Error} or is missing
     * used to trigger a ClassCastException / NullPointerException in the unwrap itself; in that
     * case the ExecutionException is now rethrown as-is. If the wait is interrupted, the thread's
     * interrupt flag is restored before the exception propagates.
     *
     * @param consumer receiver of stream events
     * @param model    Bedrock model; provides the inference model id and is passed to the chunk handler
     * @param handling used to construct the Converse marshall
     * @param messages chat messages forwarded to the marshall
     * @param ccs      completion settings forwarded to the marshall
     * @throws Exception any failure raised while streaming or while handling a chunk
     */
    public void streamCompleteConverseAPI(LLMClient.StreamedCompletionResponseConsumer consumer, BedrockConnection.BedrockModel model, GenericLLMHandling handling, @Nullable List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) throws Exception {
        ConverseAPILLMMarshall marshall = new ConverseAPILLMMarshall(handling);
        ConverseStreamRequest request = marshall.prepareChatCompletionStreamingQuery(messages, ccs);
        String modelId = model.getInferenceModelId();
        // The marshall produces a request without a model; add it here.
        request = (ConverseStreamRequest)request.toBuilder().modelId(modelId).build();
        ConverseStreamChunkHandler chunkHandler = new ConverseStreamChunkHandler(marshall, consumer, model);
        ConverseStreamResponseHandler streamHandler = ConverseStreamResponseHandler.builder().subscriber(ConverseStreamResponseHandler.Visitor.builder().onDefault(chunkHandler::onDefaultChunk).onMetadata(chunkHandler::onMetadataChunk).build()).build();
        try {
            // Block until the whole stream has been consumed.
            this.awsRuntimeAsyncClient.converseStream(request, streamHandler).get();
            // Surface a chunk-handling failure even when the stream itself completed normally.
            chunkHandler.throwChunkParsingExceptionIfAny();
        }
        catch (Exception e) {
            logger.error((Object)String.format("Can't call converse streaming with model '%s': %s", modelId, e.getMessage()));
            if (e instanceof InterruptedException) {
                // Keep the interruption visible to callers further up the stack.
                Thread.currentThread().interrupt();
            }
            if (e instanceof ExecutionException && e.getCause() instanceof Exception) {
                throw (Exception)e.getCause();
            }
            throw e;
        }
    }

    /**
     * Submits a model customization job and returns its ARN.
     *
     * <p>The role ARN comes from the fine-tuning settings. Unless {@code useDefaults} is set, the
     * non-null hyperparameters are sent as strings under the keys {@code epochCount},
     * {@code batchSize} and {@code learningRate}; otherwise an empty hyperparameter map is sent.
     *
     * @param baseModelIdentifier  base model to customize
     * @param customModelName      name of the resulting custom model
     * @param jobName              name of the customization job
     * @param trainingDatasetURI   S3 URI of the training data
     * @param validationDatasetURI optional S3 URI of validation data
     * @param outputDataURI        S3 URI for the job output
     * @param hyperparameters      hyperparameter values, read only when {@code useDefaults} is false
     * @param useDefaults          when true, no hyperparameter is sent
     * @return the ARN of the created job
     */
    public String fineTuneStart(String baseModelIdentifier, String customModelName, String jobName, String trainingDatasetURI, Optional<String> validationDatasetURI, String outputDataURI, FineTuningRecipePayloadParams.FineTuningHyperparameters hyperparameters, boolean useDefaults) {
        logger.debug((Object)("fineTuneStart: " + baseModelIdentifier + " " + customModelName));

        CreateModelCustomizationJobRequest.Builder jobBuilder = CreateModelCustomizationJobRequest.builder()
                .roleArn(this.fineTuningSettings.roleARN)
                .baseModelIdentifier(baseModelIdentifier)
                .customModelName(customModelName)
                .jobName(jobName)
                .trainingDataConfig(training -> training.s3Uri(trainingDatasetURI))
                .outputDataConfig(output -> output.s3Uri(outputDataURI));

        Map<String, String> overrides = new HashMap<String, String>();
        if (!useDefaults) {
            if (hyperparameters.nbEpochs != null) {
                overrides.put("epochCount", hyperparameters.nbEpochs.toString());
            }
            if (hyperparameters.remoteHyperparameters.batchSize != null) {
                overrides.put("batchSize", hyperparameters.remoteHyperparameters.batchSize.toString());
            }
            if (hyperparameters.remoteHyperparameters.learningRateMultiplier != null) {
                overrides.put("learningRate", hyperparameters.remoteHyperparameters.learningRateMultiplier.toString());
            }
        }
        jobBuilder.hyperParameters(overrides);

        validationDatasetURI.ifPresent(uri -> {
            Validator validator = (Validator)Validator.builder().s3Uri(uri).build();
            jobBuilder.validationDataConfig(validation -> validation.validators(new Validator[]{validator}));
        });

        CreateModelCustomizationJobRequest jobRequest = (CreateModelCustomizationJobRequest)jobBuilder.build();
        logger.debug((Object)("fineTuneStart raw request: " + jobRequest));
        return this.awsClient.createModelCustomizationJob(jobRequest).jobArn();
    }

    /**
     * Fetches a model customization job and returns a JSON summary of it.
     *
     * <p>The summary holds {@code jobName}, {@code status}, {@code failureMessage},
     * {@code outputModelArn}, {@code outputModelName}, {@code hyperParameters} and
     * {@code outputDataConfig}, in that order.
     *
     * @param ftJobId identifier of the customization job
     * @return the JSON summary
     * @throws IOException declared by the method signature
     */
    public JsonObject fineTuneGet(String ftJobId) throws IOException {
        logger.debug((Object)("fineTuneGet: " + ftJobId));
        GetModelCustomizationJobResponse job = this.awsClient.getModelCustomizationJob(lookup -> lookup.jobIdentifier(ftJobId));

        JsonObject summary = new JsonObject();
        // Scalar attributes
        summary.addProperty("jobName", job.jobName());
        summary.addProperty("status", job.statusAsString());
        summary.addProperty("failureMessage", job.failureMessage());
        summary.addProperty("outputModelArn", job.outputModelArn());
        summary.addProperty("outputModelName", job.outputModelName());
        // Structured attributes
        JsonElement hyperParameters = (JsonElement)JSON.toJsonObject((Object)job.hyperParameters());
        summary.add("hyperParameters", hyperParameters);
        JsonElement outputDataConfig = (JsonElement)JSON.toJsonObject((Object)job.outputDataConfig());
        summary.add("outputDataConfig", outputDataConfig);
        return summary;
    }

    /**
     * Asks Bedrock to stop a model customization job.
     *
     * @param ftJobId identifier of the customization job to stop
     */
    public void fineTuneCancel(String ftJobId) {
        String trace = "fineTuneCancel: " + ftJobId;
        logger.debug((Object)trace);
        this.awsClient.stopModelCustomizationJob(stop -> stop.jobIdentifier(ftJobId));
    }

    /**
     * Deletes a custom (fine-tuned) model.
     *
     * @param ftModelId identifier of the custom model to delete
     * @throws BedrockException rethrown after being logged when the deletion fails
     */
    public void deleteFinetunedModel(String ftModelId) {
        logger.debug((Object)("deleteFinetunedModel: " + ftModelId));
        try {
            this.awsClient.deleteCustomModel(deletion -> deletion.modelIdentifier(ftModelId));
        }
        catch (BedrockException deletionFailure) {
            String message = "Could not delete finetuned model: " + ftModelId;
            logger.error((Object)message, (Throwable)deletionFailure);
            throw deletionFailure;
        }
    }

    /**
     * Lists the foundation models filtered by the {@code FINE_TUNING} customization type.
     *
     * @return the model summaries converted to a JSON array
     */
    public JsonArray listFoundationModelsForFineTuning() {
        logger.debug((Object)"listFoundationModelsForFineTuning");
        ListFoundationModelsResponse listing = this.awsClient.listFoundationModels(filter -> filter.byCustomizationType(ModelCustomization.FINE_TUNING));
        return JSON.toJsonArray((Object)listing.modelSummaries());
    }

    /**
     * Looks up a provisioned model throughput and copies it into a {@link BedrockProvisionedThroughput}.
     *
     * <p>Enum and timestamp attributes are stored through {@code toString()} and only when non-null;
     * the two unit counts are copied only when non-null.
     *
     * @param provisionedModelId identifier of the provisioned throughput
     * @return the copied description, or an empty Optional when Bedrock answers with HTTP 404
     * @throws BedrockException rethrown after being logged for any status other than 404
     */
    public Optional<BedrockProvisionedThroughput> getProvisionedThroughput_NT(String provisionedModelId) {
        try {
            GetProvisionedModelThroughputResponse remote = this.awsClient.getProvisionedModelThroughput(lookup -> lookup.provisionedModelId(provisionedModelId));
            BedrockProvisionedThroughput result = new BedrockProvisionedThroughput();

            // Plain string attributes are copied as-is (null included).
            result.desiredModelArn = remote.desiredModelArn();
            result.failureMessage = remote.failureMessage();
            result.foundationModelArn = remote.foundationModelArn();
            result.modelArn = remote.modelArn();
            result.provisionedModelArn = remote.provisionedModelArn();
            result.provisionedModelName = remote.provisionedModelName();

            // Attributes that need a conversion are only set when present.
            if (remote.commitmentDuration() != null) {
                result.commitmentDuration = remote.commitmentDuration().toString();
            }
            if (remote.commitmentExpirationTime() != null) {
                result.commitmentExpirationTime = remote.commitmentExpirationTime().toString();
            }
            if (remote.creationTime() != null) {
                result.creationTime = remote.creationTime().toString();
            }
            if (remote.lastModifiedTime() != null) {
                result.lastModifiedTime = remote.lastModifiedTime().toString();
            }
            if (remote.status() != null) {
                result.status = remote.status().toString();
            }
            if (remote.desiredModelUnits() != null) {
                result.desiredModelUnits = remote.desiredModelUnits();
            }
            if (remote.modelUnits() != null) {
                result.modelUnits = remote.modelUnits();
            }
            return Optional.of(result);
        }
        catch (BedrockException lookupFailure) {
            if (lookupFailure.statusCode() != 404) {
                logger.error((Object)("Could not get provisioned throughput " + provisionedModelId), (Throwable)lookupFailure);
                throw lookupFailure;
            }
            logger.warn((Object)("Provisioned throughput " + provisionedModelId + " not found. Ignoring it."));
            return Optional.empty();
        }
    }

    /**
     * Creates a provisioned throughput of one model unit for a fine-tuned model.
     *
     * @param provisionedModelId name given to the provisioned model
     * @param fineTunedModelId   identifier of the model to provision
     * @return the ARN of the created provisioned model
     */
    public String createProvisionedThroughput_NT(String provisionedModelId, String fineTunedModelId) {
        return this.awsClient.createProvisionedModelThroughput(creation -> {
            creation.modelId(fineTunedModelId);
            creation.modelUnits(Integer.valueOf(1));
            creation.provisionedModelName(provisionedModelId);
        }).provisionedModelArn();
    }

    /**
     * Deletes a provisioned model throughput.
     *
     * <p>An HTTP 404 answer means there is nothing to delete: it is logged as a warning and the
     * method returns normally. Previously the 404 branch logged "Ignoring it." but then fell
     * through to the error log and rethrew the exception.
     *
     * @param provisionedModelId identifier of the provisioned throughput to delete
     * @throws BedrockException rethrown after being logged for any status other than 404
     */
    public void deleteProvisionedThroughput_NT(String provisionedModelId) {
        try {
            this.awsClient.deleteProvisionedModelThroughput(request -> request.provisionedModelId(provisionedModelId));
        }
        catch (BedrockException e) {
            if (e.statusCode() == 404) {
                logger.warn((Object)("Provisioned throughput " + provisionedModelId + " not found. Ignoring it."));
                // Already gone: treat the deletion as done, matching the warning above.
                return;
            }
            logger.error((Object)("Could not delete provisioned throughput " + provisionedModelId), (Throwable)e);
            throw e;
        }
    }

    /**
     * Chunk handler for the InvokeModelWithResponseStream API.
     *
     * <p>Exactly one of the two marshalls is set, depending on which constructor was used; it
     * decides how each raw chunk is parsed.
     */
    private static class InvokeStreamChunkHandler
    extends AbstractStreamChunkHandler {
        private final GenericTextCompletionLLMMarshall completionMarshall;
        private final GenericChatCompletionLLMMarshall chatMarshall;

        /** Creates a handler that parses chunks with a text-completion marshall (must be non-null). */
        InvokeStreamChunkHandler(GenericTextCompletionLLMMarshall completionMarshall, LLMClient.StreamedCompletionResponseConsumer consumer, BedrockConnection.BedrockModel model) throws Exception {
            super(consumer, model);
            this.completionMarshall = Objects.requireNonNull(completionMarshall);
            this.chatMarshall = null;
        }

        /** Creates a handler that parses chunks with a chat marshall (must be non-null). */
        InvokeStreamChunkHandler(GenericChatCompletionLLMMarshall chatMarshall, LLMClient.StreamedCompletionResponseConsumer consumer, BedrockConnection.BedrockModel model) throws Exception {
            super(consumer, model);
            this.completionMarshall = null;
            this.chatMarshall = Objects.requireNonNull(chatMarshall);
        }

        /**
         * Parses one payload part and dispatches it: remembers a finish reason, forwards non-empty
         * content to the consumer, and builds the footer once a prompt-token count is present.
         * Any failure is recorded on the handler and rethrown wrapped in a RuntimeException.
         */
        public void onChunk(PayloadPart response) {
            try {
                JsonObject payload = (JsonObject)JSON.parse((String)response.bytes().asUtf8String(), JsonObject.class);
                logger.trace(() -> String.format("Bedrock Invoke raw streamed response chunk: %s", JSON.pretty((Object)payload)));

                GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk parsed;
                if (this.chatMarshall == null) {
                    assert (this.completionMarshall != null);
                    parsed = this.completionMarshall.parseTextCompletionChunk(payload);
                } else {
                    parsed = this.chatMarshall.parseChatCompletionChunk(payload);
                }

                if (parsed.finishReason != null) {
                    this.finishReason = parsed.finishReason;
                }
                boolean hasContent = !parsed.chunk.isEmpty();
                if (hasContent) {
                    this.consumeChunk(parsed.chunk);
                }
                boolean hasUsage = parsed.promptTokens != null;
                if (hasUsage) {
                    this.buildFooter(parsed);
                }
            }
            catch (Exception failure) {
                this.setChunkParsingException(failure);
                throw new RuntimeException(failure);
            }
        }
    }

    /**
     * Chunk handler for the ConverseStream API: regular events carry content and finish reason,
     * the metadata event triggers the footer.
     */
    private static class ConverseStreamChunkHandler
    extends AbstractStreamChunkHandler {
        ConverseAPILLMMarshall marshall;

        public ConverseStreamChunkHandler(ConverseAPILLMMarshall marshall, LLMClient.StreamedCompletionResponseConsumer consumer, BedrockConnection.BedrockModel model) throws Exception {
            super(consumer, model);
            this.marshall = marshall;
        }

        /**
         * Handles a regular stream event: remembers a finish reason and forwards non-empty content
         * to the consumer. Any failure is recorded on the handler and rethrown wrapped in a
         * RuntimeException.
         */
        public void onDefaultChunk(ConverseStreamOutput chunk) {
            try {
                logger.trace(() -> String.format("Bedrock Converse raw streamed response chunk: %s", JSON.pretty((Object)chunk)));
                GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk parsed = this.marshall.parseChatCompletionChunk(chunk);
                if (parsed.finishReason != null) {
                    this.finishReason = parsed.finishReason;
                }
                boolean hasContent = !parsed.chunk.isEmpty();
                if (hasContent) {
                    this.consumeChunk(parsed.chunk);
                }
            }
            catch (Exception failure) {
                this.setChunkParsingException(failure);
                throw new RuntimeException(failure);
            }
        }

        /**
         * Handles the metadata event by parsing it and building the stream footer from it.
         * Any failure is recorded on the handler and rethrown wrapped in a RuntimeException.
         */
        public void onMetadataChunk(ConverseStreamMetadataEvent chunk) {
            try {
                this.buildFooter(this.marshall.parseChatCompletionChunk((ConverseStreamOutput)chunk));
            }
            catch (Exception failure) {
                this.setChunkParsingException(failure);
                throw new RuntimeException(failure);
            }
        }
    }

    /**
     * Plain field holder describing a provisioned model throughput.
     *
     * <p>Within this file it is filled by {@code getProvisionedThroughput_NT} from the Bedrock
     * response. Enum and timestamp values are stored as their {@code toString()} form and stay
     * {@code null} when the response has no value; the two {@code int} unit counts are only
     * assigned when the response value is non-null, so they keep their default of 0 otherwise.
     */
    public static class BedrockProvisionedThroughput {
        public String commitmentDuration;
        public String commitmentExpirationTime;
        public String creationTime;
        public String desiredModelArn;
        // Stays 0 when the response carries no desired unit count.
        public int desiredModelUnits;
        public String failureMessage;
        public String foundationModelArn;
        public String lastModifiedTime;
        public String modelArn;
        // Stays 0 when the response carries no unit count.
        public int modelUnits;
        public String provisionedModelArn;
        public String provisionedModelName;
        public String status;
    }

    /**
     * Shared plumbing for the streaming chunk handlers: forwards events to the consumer, builds
     * the footer, and keeps the first exception raised while handling a chunk.
     */
    private static abstract class AbstractStreamChunkHandler {
        final LLMClient.StreamedCompletionResponseConsumer consumer;
        final BedrockConnection.BedrockModel model;
        LLMClient.FinishReason finishReason;
        protected Exception hookException;

        /** Stores the collaborators and immediately notifies the consumer that the stream started. */
        protected AbstractStreamChunkHandler(LLMClient.StreamedCompletionResponseConsumer consumer, BedrockConnection.BedrockModel model) throws Exception {
            this.consumer = consumer;
            this.model = model;
            this.consumer.onStreamStarted();
        }

        /** Forwards one content chunk to the consumer. */
        protected void consumeChunk(LLMClient.StreamedCompletionResponseChunk chunk) throws Exception {
            this.consumer.onStreamChunk(chunk);
        }

        /** Logs the footer and tells the consumer the stream is complete. */
        protected void completeStream(LLMClient.StreamedCompletionResponseFooter footer) throws Exception {
            String footerJson = JSON.json((Object)footer);
            logger.info((Object)("Footer: " + footerJson));
            this.consumer.onStreamComplete(footer);
        }

        /**
         * Assembles the footer (finish reason seen so far plus, when a parsed chunk is given, its
         * token counts, their sum and the model's estimated cost) and completes the stream with it.
         */
        protected void buildFooter(GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk parsedChunk) throws Exception {
            LLMClient.StreamedCompletionResponseFooter summary = new LLMClient.StreamedCompletionResponseFooter();
            if (this.finishReason != null) {
                summary.finishReason = this.finishReason;
            }
            if (parsedChunk != null) {
                summary.promptTokens = parsedChunk.promptTokens;
                summary.completionTokens = parsedChunk.completionTokens;
                summary.totalTokens = summary.promptTokens + summary.completionTokens;
                summary.estimatedCost = this.model.getEstimatedCompletionCost(summary.promptTokens, summary.completionTokens);
            }
            this.completeStream(summary);
        }

        /** Records the exception unless an earlier one is already stored. */
        protected void setChunkParsingException(Exception e) {
            if (this.hookException != null) {
                return;
            }
            this.hookException = e;
        }

        /** Rethrows the recorded chunk-handling exception, if there is one. */
        public void throwChunkParsingExceptionIfAny() throws Exception {
            Exception recorded = this.hookException;
            if (recorded != null) {
                throw recorded;
            }
        }
    }
}

