/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.sagemakergeneric;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.anthropic.api.AnthropicChatChunkResponseAdapter;
import com.dataiku.dip.llm.online.anthropic.api.AnthropicChatQuery;
import com.dataiku.dip.llm.online.anthropic.api.AnthropicChatQueryAdapter;
import com.dataiku.dip.llm.online.anthropic.api.AnthropicChatResponse;
import com.dataiku.dip.llm.online.anthropic.api.AnthropicChatResponseAdapter;
import com.dataiku.dip.llm.online.cohere.RawCohereClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettingsValidator;
import com.dataiku.dip.llm.online.marshall.FinishReasonResponseAdapter;
import com.dataiku.dip.llm.online.mistralai.api.MistralAIChatQuery;
import com.dataiku.dip.llm.online.mistralai.api.MistralAIChatQueryAdapter;
import com.dataiku.dip.llm.online.mistralai.api.MistralAIChatResponse;
import com.dataiku.dip.llm.online.mistralai.api.MistralAIChatResponseAdapter;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericLLMHandling;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericTextCompletionLLMMarshall;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.com.google.common.base.MoreObjects;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;

/**
 * Converts DSS chat messages into the raw JSON request body of a chat model family, and converts
 * that family's raw JSON responses (complete or streamed) back into DSS completion objects.
 *
 * <p>Instances are obtained through {@link #get}. Token usage is read either from the response
 * body or from the {@code X-Amzn-Bedrock-*} response headers / the
 * {@code amazon-bedrock-invocationMetrics} stream field, depending on the family.
 */
public interface GenericChatCompletionLLMMarshall {
    public static final DKULogger logger = DKULogger.getLogger("dku.llm.generic_chat_completion_marshall");

    /**
     * Builds the raw JSON request body for a chat completion.
     *
     * @param messages the conversation to send, in order
     * @param settings completion settings; implementations validate them against the options their
     *     model family supports and may fill in defaults on this object
     * @return the request body to send to the model endpoint
     */
    public JsonObject prepareChatCompletionQuery(List<LLMClient.ChatMessage> messages, CoreCompletionSettings settings);

    /**
     * Parses a complete (non-streamed) chat completion response.
     *
     * @param headers response headers; some families read token usage from them
     * @param response the parsed JSON response body
     * @return the completion, with finish reason and token usage when available
     * @throws IOException if the response cannot be interpreted
     */
    public LLMClient.SimpleCompletionResponse parseChatCompletionResponse(Map<String, String> headers, JsonElement response) throws IOException;

    /**
     * Parses the JSON payload of a single streamed chat completion event.
     *
     * @param chunk the JSON payload of one stream event
     * @param uniqueId identifier forwarded to the chunk adapter by families that need it
     * @return the streamed chunk, with finish reason and token usage when the event carries them
     * @throws IOException if the event cannot be interpreted
     */
    public GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk parseChatCompletionChunk(JsonObject chunk, String uniqueId) throws IOException;

    /**
     * Returns a new marshall for a chat-capable model family.
     *
     * @param family the model family
     * @param params connection parameters; currently unused, kept for API compatibility
     * @return a marshall for {@code family}
     * @throws IllegalArgumentException if {@code family} is null or has no chat-completion marshall
     */
    public static GenericChatCompletionLLMMarshall get(GenericLLMHandling family, AbstractLLMConnection.AbstractLLMConnectionParams params) {
        if (family == null) {
            throw new IllegalArgumentException("Unknown GenericLLMHandling family for chat completion: null");
        }
        switch (family) {
            case ANTHROPIC_CLAUDE_CHAT:
                return new AnthropicClaudeChatLLMMarshall();
            case MISTRAL_AI_CHAT:
                return new MistralAIChatLLMMarshall();
            case COHERE_COMMAND_CHAT:
                return new CohereCommandChatLLMMarshall();
            default:
                // An unchecked exception rather than java.lang.Error, so that regular exception
                // handlers can report the misconfiguration instead of it escaping as a fatal error.
                throw new IllegalArgumentException("Unknown GenericLLMHandling family for chat completion: " + family);
        }
    }

    /**
     * Shared helpers for reading Bedrock token-usage metadata. Only meant to be used by the
     * marshalls nested in this interface.
     */
    final class BedrockUsageMetadata {
        static final String INPUT_TOKEN_COUNT_HEADER = "X-Amzn-Bedrock-Input-Token-Count";
        static final String OUTPUT_TOKEN_COUNT_HEADER = "X-Amzn-Bedrock-Output-Token-Count";
        static final String INVOCATION_METRICS_FIELD = "amazon-bedrock-invocationMetrics";

        private BedrockUsageMetadata() {
        }

        /**
         * Reads an integer token-count response header.
         *
         * <p>The name is matched exactly first, then case-insensitively, because HTTP header names
         * are case-insensitive and the caller's map may have normalized their case. A missing or
         * malformed header is logged and reported as {@code null} rather than failing the whole
         * response: usage counts are metadata, the completion itself is still valid.
         *
         * @param headers the response headers, may be null
         * @param name the header name
         * @return the parsed count, or {@code null} if the header is absent or not an integer
         */
        @Nullable
        static Integer tokenCountHeader(@Nullable Map<String, String> headers, String name) {
            if (headers == null) {
                logger.warn("No response headers available, cannot read " + name);
                return null;
            }
            String value = headers.get(name);
            if (value == null) {
                for (Map.Entry<String, String> entry : headers.entrySet()) {
                    if (name.equalsIgnoreCase(entry.getKey())) {
                        value = entry.getValue();
                        break;
                    }
                }
            }
            if (value == null) {
                logger.warn("Response header " + name + " is missing, token count left unset");
                return null;
            }
            try {
                return Integer.valueOf(value.trim());
            } catch (NumberFormatException e) {
                logger.warn("Response header " + name + " is not an integer (" + value + "), token count left unset");
                return null;
            }
        }

        /**
         * Copies the input/output token counts of a stream event's invocation-metrics object into
         * {@code ret}. Does nothing if the event has no such object; each count is copied only if
         * present.
         */
        static void copyInvocationMetrics(JsonObject event, GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk ret) {
            JsonElement metricsElement = event.get(INVOCATION_METRICS_FIELD);
            if (metricsElement == null || !metricsElement.isJsonObject()) {
                return;
            }
            JsonObject metrics = metricsElement.getAsJsonObject();
            if (metrics.has("inputTokenCount")) {
                ret.promptTokens = metrics.get("inputTokenCount").getAsInt();
            }
            if (metrics.has("outputTokenCount")) {
                ret.completionTokens = metrics.get("outputTokenCount").getAsInt();
            }
        }
    }

    /**
     * Marshall for Anthropic Claude chat models, using the Anthropic chat query/response adapters.
     */
    public static class AnthropicClaudeChatLLMMarshall
    implements GenericChatCompletionLLMMarshall {
        /** Value written to the {@code anthropic_version} field of every request body. */
        private static final String ANTHROPIC_VERSION_BEDROCK = "bedrock-2023-05-31";
        /** max_tokens used when the caller did not set one (Anthropic requires this field). */
        private static final int DEFAULT_MAX_TOKENS = 4096;
        private static final CoreCompletionSettingsValidator validator = new CoreCompletionSettingsValidator("Bedrock Anthropic Claude (chat)").allowMaxTokens().allowTemperature().allowTopK().allowTopP().allowStopSequences().allowTools();

        /**
         * {@inheritDoc}
         *
         * <p>Note: the default max tokens is written back into {@code ccs}, so the caller's
         * settings object is modified.
         */
        @Override
        public JsonObject prepareChatCompletionQuery(List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) {
            validator.validate(ccs);
            ccs.maxTokens = MoreObjects.firstNonNull(ccs.maxTokens, DEFAULT_MAX_TOKENS);
            AnthropicChatQuery query = AnthropicChatQueryAdapter.adapt(messages, ccs);
            JsonObject rawQuery = (JsonObject) JSON.toJsonElement(query);
            // Drop the adapter's "stream" flag and add the Bedrock-specific version field.
            // NOTE(review): presumably streaming is selected by the invocation API, not the body.
            rawQuery.remove("stream");
            rawQuery.addProperty("anthropic_version", ANTHROPIC_VERSION_BEDROCK);
            return rawQuery;
        }

        @Override
        public LLMClient.SimpleCompletionResponse parseChatCompletionResponse(Map<String, String> headers, JsonElement response) {
            AnthropicChatResponse chatResponse = (AnthropicChatResponse) JSON.parse(response, AnthropicChatResponse.class);
            return AnthropicChatResponseAdapter.adapt(chatResponse);
        }

        @Override
        public GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk parseChatCompletionChunk(JsonObject jo, String uniqueId) throws IOException {
            GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk ret = new GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk();
            ret.chunk = AnthropicChatChunkResponseAdapter.adapt(jo, uniqueId);
            LLMClient.FinishReason finishReason = AnthropicChatChunkResponseAdapter.extractFinishReason(jo);
            if (finishReason != null) {
                ret.finishReason = finishReason;
            }
            // Token usage is only taken from the final "message_stop" event.
            if (isEventOfType(jo, "message_stop")) {
                BedrockUsageMetadata.copyInvocationMetrics(jo, ret);
            }
            return ret;
        }

        /** Whether the event's "type" field is present and equals {@code type}. */
        private static boolean isEventOfType(JsonObject event, String type) {
            JsonElement typeElement = event.get("type");
            return typeElement != null && typeElement.isJsonPrimitive() && type.equals(typeElement.getAsString());
        }
    }

    /**
     * Marshall for MistralAI chat models, using the MistralAI chat query/response adapters.
     * Complete responses get their token usage from the Bedrock token-count response headers.
     */
    public static class MistralAIChatLLMMarshall
    implements GenericChatCompletionLLMMarshall {
        private static final CoreCompletionSettingsValidator validator = new CoreCompletionSettingsValidator("Bedrock MistralAI (chat)").allowMaxTokens().allowTemperature().allowTopP().allowTools();

        @Override
        public JsonObject prepareChatCompletionQuery(List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) {
            validator.validate(ccs);
            MistralAIChatQuery query = MistralAIChatQueryAdapter.adapt(messages, ccs);
            return JSON.toJsonElement(query).getAsJsonObject();
        }

        @Override
        public LLMClient.SimpleCompletionResponse parseChatCompletionResponse(Map<String, String> headers, JsonElement response) {
            MistralAIChatResponse chatResponse = (MistralAIChatResponse) JSON.parse(response, MistralAIChatResponse.class);
            LLMClient.SimpleCompletionResponse resp = MistralAIChatResponseAdapter.adapt(chatResponse);
            // Token usage is taken from the response headers; a missing header leaves it unset
            // instead of failing an otherwise valid completion.
            Integer promptTokens = BedrockUsageMetadata.tokenCountHeader(headers, BedrockUsageMetadata.INPUT_TOKEN_COUNT_HEADER);
            if (promptTokens != null) {
                resp.promptTokens = promptTokens;
            }
            Integer completionTokens = BedrockUsageMetadata.tokenCountHeader(headers, BedrockUsageMetadata.OUTPUT_TOKEN_COUNT_HEADER);
            if (completionTokens != null) {
                resp.completionTokens = completionTokens;
            }
            return resp;
        }

        @Override
        public GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk parseChatCompletionChunk(JsonObject jo, String uniqueId) throws IOException {
            // Each stream event has the same shape as a complete response.
            MistralAIChatResponse chunkResponse = (MistralAIChatResponse) JSON.parse((JsonElement) jo, MistralAIChatResponse.class);
            LLMClient.SimpleCompletionResponse scr = MistralAIChatResponseAdapter.adapt(chunkResponse);
            LLMClient.FinishReason finishReason = this.extractFinishReason(chunkResponse);
            GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk ret = new GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk();
            // NOTE(review): assumes the chunk constructor initializes ret.chunk (it is not
            // assigned here, unlike in the Anthropic marshall) — confirm.
            if (scr.text != null) {
                ret.chunk.text = scr.text;
            }
            if (scr.toolCalls != null) {
                ret.chunk.toolCalls = scr.toolCalls;
            }
            if (finishReason != null) {
                ret.finishReason = finishReason;
            }
            BedrockUsageMetadata.copyInvocationMetrics(jo, ret);
            return ret;
        }

        /** Finish reason of the first choice, or {@code null} if there is none. */
        @Nullable
        private LLMClient.FinishReason extractFinishReason(MistralAIChatResponse chunk) {
            if (chunk.choices == null || chunk.choices.isEmpty()) {
                return null;
            }
            MistralAIChatResponse.Choice choice = chunk.choices.get(0);
            if (choice.finishReason == null) {
                return null;
            }
            return FinishReasonResponseAdapter.adapt(choice.finishReason);
        }
    }

    /**
     * Marshall for Cohere Command chat models. Requests are built by {@link RawCohereClient};
     * streaming is not supported.
     */
    public static class CohereCommandChatLLMMarshall
    implements GenericChatCompletionLLMMarshall {
        private static final CoreCompletionSettingsValidator validator = new CoreCompletionSettingsValidator("Bedrock Cohere Command (chat)").allowMaxTokens().allowTemperature().allowTopK().allowTopP().allowFrequencyPenalty().allowPresencePenalty().allowStopSequences();

        @Override
        public JsonObject prepareChatCompletionQuery(List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) {
            validator.validate(ccs);
            JsonObject query = RawCohereClient.buildChatQuery(messages, ccs);
            query.remove("stream");
            return query;
        }

        /**
         * {@inheritDoc}
         *
         * <p>Token usage is read from {@code meta.billed_units} when the body carries both counts,
         * otherwise from the Bedrock token-count response headers.
         *
         * @throws IOException if the body is not a JSON object or has no {@code text} field
         */
        @Override
        public LLMClient.SimpleCompletionResponse parseChatCompletionResponse(Map<String, String> headers, JsonElement response) throws IOException {
            if (response == null || !response.isJsonObject()) {
                throw new IOException("Unexpected Cohere chat response: expected a JSON object");
            }
            JsonObject jo = response.getAsJsonObject();
            JsonElement text = jo.get("text");
            if (text == null || text.isJsonNull()) {
                throw new IOException("Unexpected Cohere chat response: missing 'text' field");
            }
            LLMClient.SimpleCompletionResponse ret = new LLMClient.SimpleCompletionResponse();
            ret.text = text.getAsString();
            JsonElement finishReason = jo.get("finish_reason");
            if (finishReason != null && !finishReason.isJsonNull()) {
                ret.finishReason = FinishReasonResponseAdapter.adapt(finishReason.getAsString());
            }
            JsonObject billedUnits = billedUnits(jo);
            if (billedUnits != null && billedUnits.has("input_tokens") && billedUnits.has("output_tokens")) {
                ret.promptTokens = billedUnits.get("input_tokens").getAsInt();
                ret.completionTokens = billedUnits.get("output_tokens").getAsInt();
            } else {
                // No usable usage block in the body: fall back to the response headers.
                Integer promptTokens = BedrockUsageMetadata.tokenCountHeader(headers, BedrockUsageMetadata.INPUT_TOKEN_COUNT_HEADER);
                if (promptTokens != null) {
                    ret.promptTokens = promptTokens;
                }
                Integer completionTokens = BedrockUsageMetadata.tokenCountHeader(headers, BedrockUsageMetadata.OUTPUT_TOKEN_COUNT_HEADER);
                if (completionTokens != null) {
                    ret.completionTokens = completionTokens;
                }
            }
            return ret;
        }

        /** Returns {@code meta.billed_units} if both levels are JSON objects, else {@code null}. */
        @Nullable
        private static JsonObject billedUnits(JsonObject response) {
            JsonElement meta = response.get("meta");
            if (meta == null || !meta.isJsonObject()) {
                return null;
            }
            JsonElement billed = meta.getAsJsonObject().get("billed_units");
            return billed != null && billed.isJsonObject() ? billed.getAsJsonObject() : null;
        }

        @Override
        public GenericTextCompletionLLMMarshall.EnrichedStreamedCompletionResponseChunk parseChatCompletionChunk(JsonObject jo, String uniqueId) throws IOException {
            throw new IllegalArgumentException("Streaming not supported on Cohere");
        }
    }
}

