/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.bedrock.converse;

import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.FinishReasonResponseAdapter;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericTextCompletionLLMMarshall;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDelta;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStart;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamMetadataEvent;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.MessageStopEvent;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ReasoningContentBlockDelta;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.StopReason;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.TokenUsage;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlockDelta;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlockStart;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.util.List;
import org.apache.commons.lang.StringUtils;

/**
 * Converts streamed AWS Bedrock Converse events into DSS streamed completion chunks.
 *
 * <p>Two entry points are provided: one for typed AWS SDK v2 events ({@link ConverseStreamOutput})
 * and one for raw JSON event payloads keyed by event-type name (e.g. {@code "contentBlockDelta"}).
 * Both map:
 * <ul>
 *   <li>text deltas to {@code chunk.text},</li>
 *   <li>tool-use starts/deltas to {@code chunk.toolCalls},</li>
 *   <li>message-stop events to {@code finishReason},</li>
 *   <li>metadata events to prompt/completion token counts.</li>
 * </ul>
 * Only the SDK entry point handles reasoning-content deltas.
 */
public class ConverseChatChunkResponseAdapter {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.converse.api");

    /**
     * Adapts one typed Converse stream event.
     *
     * @param eventType the SDK event type of {@code response}; selects which subtype cast is applied
     * @param response the event payload; must be the subtype matching {@code eventType}
     * @param uniqueId id given to any REASONING artifact produced from this event
     * @return a non-null enriched chunk; for unhandled event types no field is set
     */
    public static GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk adapt(ConverseStreamOutput.EventType eventType, ConverseStreamOutput response, String uniqueId) {
        GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk enrichedChunk = new GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk();
        switch (eventType) {
            case MESSAGE_START:
            case CONTENT_BLOCK_STOP: {
                logger.trace(() -> String.format("Skipping response chunk with type: %s", eventType));
                enrichedChunk.chunk = LLMClient.StreamedCompletionResponseChunk.empty();
                return enrichedChunk;
            }
            case CONTENT_BLOCK_START: {
                enrichedChunk.chunk = adaptContentBlockStart((ContentBlockStartEvent) response);
                return enrichedChunk;
            }
            case CONTENT_BLOCK_DELTA: {
                enrichedChunk.chunk = adaptContentBlockDelta((ContentBlockDeltaEvent) response, uniqueId);
                return enrichedChunk;
            }
            case MESSAGE_STOP: {
                enrichedChunk.finishReason = extractFinishReason((MessageStopEvent) response);
                return enrichedChunk;
            }
            case METADATA: {
                applyUsage(enrichedChunk, extractUsage((ConverseStreamMetadataEvent) response));
                return enrichedChunk;
            }
        }
        logger.warn((Object) String.format("Unknown response chunk: %s", JSON.prettyLog((Object) response)));
        return enrichedChunk;
    }

    /**
     * Turns a content-block start into a tool-call header (id + function name) when it opens a
     * tool-use block; any other block start yields an empty chunk.
     */
    private static LLMClient.StreamedCompletionResponseChunk adaptContentBlockStart(ContentBlockStartEvent response) {
        ContentBlockStart start = response.start();
        int index = response.contentBlockIndex();
        if (start != null && start.toolUse() != null) {
            ToolUseBlockStart toolUseBlockStart = start.toolUse();
            LLMClient.FunctionToolCall ftc = prebuildToolCall(index);
            ftc.id = toolUseBlockStart.toolUseId();
            ftc.function.name = toolUseBlockStart.name();
            LLMClient.StreamedCompletionResponseChunk chunk = new LLMClient.StreamedCompletionResponseChunk();
            chunk.toolCalls = List.of(ftc);
            return chunk;
        }
        logger.info((Object) "Skipping response chunk with type contentBlockStart and no toolUse");
        return LLMClient.StreamedCompletionResponseChunk.empty();
    }

    /**
     * Converts a content-block delta into one of:
     * <ul>
     *   <li>a text chunk,</li>
     *   <li>a tool-call argument fragment,</li>
     *   <li>a reasoning artifact (TEXT deltas) and/or memory fragment (non-empty SIGNATURE deltas).</li>
     * </ul>
     * Unrecognised deltas are logged and yield an empty chunk.
     */
    private static LLMClient.StreamedCompletionResponseChunk adaptContentBlockDelta(ContentBlockDeltaEvent response, String uniqueId) {
        ContentBlockDelta delta = response.delta();
        int index = response.contentBlockIndex();
        if (delta != null && delta.text() != null) {
            LLMClient.StreamedCompletionResponseChunk chunk = new LLMClient.StreamedCompletionResponseChunk();
            chunk.text = delta.text();
            return chunk;
        }
        if (delta != null && delta.toolUse() != null) {
            ToolUseBlockDelta toolUseBlockDelta = delta.toolUse();
            // Arguments arrive as partial JSON fragments; the tool call's index lets them be stitched together.
            LLMClient.FunctionToolCall ftc = prebuildToolCall(index);
            ftc.function.arguments = toolUseBlockDelta.input();
            LLMClient.StreamedCompletionResponseChunk chunk = new LLMClient.StreamedCompletionResponseChunk();
            chunk.toolCalls = List.of(ftc);
            return chunk;
        }
        if (delta != null && delta.reasoningContent() != null) {
            ReasoningContentBlockDelta reasoningContentBlockDelta = delta.reasoningContent();
            LLMClient.StreamedCompletionResponseChunk chunk = new LLMClient.StreamedCompletionResponseChunk();
            String reasoningText = reasoningContentBlockDelta.text();
            String signature = reasoningContentBlockDelta.signature();
            if (reasoningContentBlockDelta.type() == ReasoningContentBlockDelta.Type.TEXT && reasoningText != null) {
                LLMClient.Artifact artifact = new LLMClient.Artifact();
                artifact.type = "REASONING";
                artifact.id = uniqueId;
                LLMClient.SourceItem sourceItem = new LLMClient.SourceItem();
                sourceItem.type = "TEXT";
                sourceItem.index = index;
                sourceItem.text = reasoningText;
                artifact.parts = List.of(sourceItem);
                chunk.artifacts = List.of(artifact);
            }
            if (reasoningContentBlockDelta.type() == ReasoningContentBlockDelta.Type.SIGNATURE && StringUtils.isNotEmpty(signature)) {
                // The serialized signature delta is kept as a memory fragment.
                chunk.memoryFragment = new LLMClient.MemoryFragment();
                chunk.memoryFragment.llmReasoning = JSON.toJsonObject((Object) reasoningContentBlockDelta, new String[0]);
            }
            return chunk;
        }
        logger.info((Object) String.format("Unknown response chunk: %s", JSON.prettyLog((Object) response)));
        return LLMClient.StreamedCompletionResponseChunk.empty();
    }

    /**
     * Maps the stop reason of a message-stop event, or returns {@code null} when absent.
     *
     * <p>Uses the raw wire string rather than {@code StopReason.toString()}. For values unknown to
     * the SDK, the enum is {@code UNKNOWN_TO_SDK_VERSION}, whose {@code toString()} is the literal
     * {@code "null"}. The raw string also matches what the JSON entry point passes on.
     */
    private static LLMClient.FinishReason extractFinishReason(MessageStopEvent response) {
        String reason = response.stopReasonAsString();
        if (reason == null) {
            return null;
        }
        return FinishReasonResponseAdapter.adapt(reason);
    }

    /** Creates a tool call with the given block index and an empty function descriptor. */
    private static LLMClient.FunctionToolCall prebuildToolCall(int index) {
        LLMClient.FunctionToolCall ftc = new LLMClient.FunctionToolCall();
        ftc.index = index;
        ftc.function = new LLMClient.FunctionToolCallInfo();
        return ftc;
    }

    /**
     * Copies SDK token usage into a JSON object with {@code inputTokens}/{@code outputTokens}
     * members. Returns {@code null} when the event carries no usage.
     */
    private static JsonObject extractUsage(ConverseStreamMetadataEvent response) {
        TokenUsage usage = response.usage();
        if (usage == null) {
            return null;
        }
        JsonObject usageJson = new JsonObject();
        usageJson.addProperty("inputTokens", usage.inputTokens());
        usageJson.addProperty("outputTokens", usage.outputTokens());
        return usageJson;
    }

    /**
     * Copies token counts from {@code usage} onto {@code enrichedChunk}.
     *
     * <p>Counts that are missing or JSON null are left unset instead of throwing. Usage is optional
     * metadata and must not abort the stream.
     */
    private static void applyUsage(GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk enrichedChunk, JsonObject usage) {
        if (usage == null) {
            return;
        }
        if (hasNonNull(usage, "inputTokens")) {
            enrichedChunk.promptTokens = usage.get("inputTokens").getAsInt();
        }
        if (hasNonNull(usage, "outputTokens")) {
            enrichedChunk.completionTokens = usage.get("outputTokens").getAsInt();
        }
    }

    /** Returns whether {@code object} has a member {@code name} that is present and not JSON null. */
    private static boolean hasNonNull(JsonObject object, String name) {
        JsonElement element = object.get(name);
        return element != null && !element.isJsonNull();
    }

    /**
     * Adapts one raw JSON Converse stream event.
     *
     * @param eventType the event name (e.g. {@code "contentBlockDelta"})
     * @param response the event payload object
     * @return a non-null enriched chunk; for unhandled event types no field is set
     */
    public static GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk adapt(String eventType, JsonObject response) {
        GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk enrichedChunk = new GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk();
        switch (eventType) {
            case "messageStart":
            case "contentBlockStop": {
                // Emitted for every message/block: trace (lazily formatted), as in the SDK entry point.
                logger.trace(() -> String.format("Skipping response chunk with type: %s", eventType));
                enrichedChunk.chunk = LLMClient.StreamedCompletionResponseChunk.empty();
                return enrichedChunk;
            }
            case "contentBlockStart": {
                enrichedChunk.chunk = adaptContentBlockStart(response);
                return enrichedChunk;
            }
            case "contentBlockDelta": {
                enrichedChunk.chunk = adaptContentBlockDelta(response);
                return enrichedChunk;
            }
            case "messageStop": {
                enrichedChunk.finishReason = extractFinishReason(response);
                return enrichedChunk;
            }
            case "metadata": {
                applyUsage(enrichedChunk, extractUsage(response));
                return enrichedChunk;
            }
        }
        logger.warn((Object) String.format("Unknown response chunk: %s", JSON.prettyLog((Object) response)));
        return enrichedChunk;
    }

    /**
     * JSON variant of the block-start adapter.
     *
     * <p>Text block starts are skipped. Tool-use starts become a tool-call header. Anything else,
     * including a missing {@code start} member, is logged and yields an empty chunk.
     */
    private static LLMClient.StreamedCompletionResponseChunk adaptContentBlockStart(JsonObject response) {
        JsonObject start = (JsonObject) JSON.parse(response.get("start"), JsonObject.class);
        if (start != null && start.has("text")) {
            logger.info((Object) "Skipping response chunk with type: (contentBlockStart, text)");
            return LLMClient.StreamedCompletionResponseChunk.empty();
        }
        if (start != null && start.has("toolUse")) {
            JsonObject toolUse = (JsonObject) JSON.parse(start.get("toolUse"), JsonObject.class);
            LLMClient.FunctionToolCall ftc = prebuildToolCall(response.get("contentBlockIndex").getAsInt());
            ftc.id = toolUse.get("toolUseId").getAsString();
            ftc.function.name = toolUse.get("name").getAsString();
            LLMClient.StreamedCompletionResponseChunk chunk = new LLMClient.StreamedCompletionResponseChunk();
            chunk.toolCalls = List.of(ftc);
            return chunk;
        }
        logger.info((Object) String.format("Unknown response chunk: %s", JSON.prettyLog((Object) response)));
        return LLMClient.StreamedCompletionResponseChunk.empty();
    }

    /**
     * JSON variant of the block-delta adapter.
     *
     * <p>Handles text and tool-use argument deltas. Anything else, including a missing
     * {@code delta} member, is logged and yields an empty chunk.
     */
    private static LLMClient.StreamedCompletionResponseChunk adaptContentBlockDelta(JsonObject response) {
        JsonObject delta = (JsonObject) JSON.parse(response.get("delta"), JsonObject.class);
        if (delta != null && delta.has("text")) {
            LLMClient.StreamedCompletionResponseChunk chunk = new LLMClient.StreamedCompletionResponseChunk();
            chunk.text = delta.get("text").getAsString();
            return chunk;
        }
        if (delta != null && delta.has("toolUse")) {
            JsonObject toolUse = (JsonObject) JSON.parse(delta.get("toolUse"), JsonObject.class);
            LLMClient.FunctionToolCall ftc = prebuildToolCall(response.get("contentBlockIndex").getAsInt());
            ftc.function.arguments = toolUse.get("input").getAsString();
            LLMClient.StreamedCompletionResponseChunk chunk = new LLMClient.StreamedCompletionResponseChunk();
            chunk.toolCalls = List.of(ftc);
            return chunk;
        }
        logger.info((Object) String.format("Unknown response chunk: %s", JSON.prettyLog((Object) response)));
        return LLMClient.StreamedCompletionResponseChunk.empty();
    }

    /**
     * Maps the {@code stopReason} member.
     *
     * <p>Returns {@code null} when it is absent or JSON null. The original code called
     * {@code getAsString()} first, which threw before its null check could run.
     */
    private static LLMClient.FinishReason extractFinishReason(JsonObject response) {
        if (!hasNonNull(response, "stopReason")) {
            return null;
        }
        return FinishReasonResponseAdapter.adapt(response.get("stopReason").getAsString());
    }

    /**
     * Returns the {@code usage} member as an object, or whatever {@code JSON.parse} yields for an
     * absent member. Callers treat {@code null} as "no usage".
     */
    private static JsonObject extractUsage(JsonObject response) {
        return (JsonObject) JSON.parse(response.get("usage"), JsonObject.class);
    }
}

