/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.vertex.api;

import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.vertex.api.GeminiQuery;
import com.dataiku.dip.llm.utils.json_schema.JSONSchemaCompatibilityEnhancer;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.com.google.common.base.Strings;
import com.google.gson.JsonObject;
import com.google.gson.JsonSyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang.StringUtils;

/**
 * Converts a DSS chat completion request (a list of {@link LLMClient.ChatMessage} plus
 * {@link CoreCompletionSettings}) into a {@link GeminiQuery} payload for Vertex Gemini models.
 *
 * <p>Static utility class; not instantiable.
 */
public class GeminiQueryAdapter {
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.vertex.client");

    /** Part types that carry no text; a message made only of these gets a placeholder text part. */
    private static final Set<LLMClient.ChatMessagePartType> IMAGE_PART_TYPES = Stream.of(
            LLMClient.ChatMessagePartType.IMAGE_INLINE,
            LLMClient.ChatMessagePartType.IMAGE_URI).collect(Collectors.toSet());

    private GeminiQueryAdapter() {
    }

    /**
     * Builds a Gemini query from chat messages and completion settings.
     *
     * <p>System messages become {@code systemInstruction}; if several are present only the last one
     * is kept (a warning is logged). All other messages are appended to {@code contents} in order.
     *
     * @param model    model identifier; not read by this method
     * @param messages conversation to convert; not modified
     * @param ccs      completion settings; only non-null settings are copied onto the query
     * @return the populated query
     * @throws IllegalArgumentException on unsupported roles, tool types, tool calls or image URIs
     */
    public static GeminiQuery adapt(String model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) {
        GeminiQuery geminiQuery = new GeminiQuery();
        applyGenerationSettings(geminiQuery, ccs);
        applyToolSettings(geminiQuery, ccs);
        for (LLMClient.ChatMessage message : messages) {
            if ("system".equals(message.role)) {
                if (geminiQuery.systemInstruction != null) {
                    logger.warn((Object)("Multiple system prompts are not supported for Gemini models, using only the last one: " + JSON.pretty((Object)messages)));
                }
                geminiQuery.systemInstruction = chatMessageToGeminiMessage(message);
            } else {
                geminiQuery.contents.add(chatMessageToGeminiMessage(message));
            }
        }
        return geminiQuery;
    }

    /** Copies sampling, length, stop and response-format settings into {@code generationConfig}. */
    private static void applyGenerationSettings(GeminiQuery geminiQuery, CoreCompletionSettings ccs) {
        if (ccs.maxTokens != null) {
            geminiQuery.generationConfig.maxOutputTokens = ccs.maxTokens;
        }
        if (ccs.temperature != null) {
            geminiQuery.generationConfig.temperature = ccs.temperature;
        }
        if (ccs.topK != null) {
            geminiQuery.generationConfig.topK = ccs.topK;
        }
        if (ccs.topP != null) {
            geminiQuery.generationConfig.topP = ccs.topP;
        }
        if (ccs.frequencyPenalty != null) {
            geminiQuery.generationConfig.frequencyPenalty = ccs.frequencyPenalty;
        }
        if (ccs.presencePenalty != null) {
            geminiQuery.generationConfig.presencePenalty = ccs.presencePenalty;
        }
        if (ccs.stopSequences != null && !ccs.stopSequences.isEmpty()) {
            // Defensive copy so later changes to the settings do not leak into the query.
            geminiQuery.generationConfig.stopSequences = new ArrayList<String>(ccs.stopSequences);
        }
        if (ccs.responseFormat instanceof LLMClient.ResponseFormatJson) {
            LLMClient.ResponseFormatJson responseFormat = (LLMClient.ResponseFormatJson)ccs.responseFormat;
            geminiQuery.generationConfig.responseMimeType = "application/json";
            if (responseFormat.schema != null) {
                geminiQuery.generationConfig.responseSchema = JSONSchemaCompatibilityEnhancer.enhance(responseFormat.schema, schemaProvider(responseFormat.compatible));
            }
        }
    }

    /** Copies the tool choice and the function tools into {@code toolConfig} / {@code tools}. */
    private static void applyToolSettings(GeminiQuery geminiQuery, CoreCompletionSettings ccs) {
        if (ccs.toolChoice != null) {
            geminiQuery.toolConfig = toToolConfig(ccs.toolChoice);
        }
        if (ccs.tools != null && !ccs.tools.isEmpty()) {
            // All function declarations are grouped into a single Gemini tool entry.
            GeminiQuery.FunctionTool ft = new GeminiQuery.FunctionTool();
            ft.functionDeclarations = ccs.tools.stream().map(GeminiQueryAdapter::toFunctionDeclaration).collect(Collectors.toList());
            geminiQuery.tools.add(ft);
        }
    }

    /**
     * Chooses how a JSON schema is post-processed: Gemini-specific rewriting unless the caller
     * explicitly disabled it ({@code compatible == false}); {@code null} counts as enabled.
     */
    private static JSONSchemaCompatibilityEnhancer.Provider schemaProvider(Boolean compatible) {
        return Boolean.FALSE.equals(compatible)
                ? JSONSchemaCompatibilityEnhancer.Provider.PASSTHROUGH
                : JSONSchemaCompatibilityEnhancer.Provider.GEMINI;
    }

    /**
     * Converts a DSS tool definition into a Gemini function declaration.
     *
     * @throws IllegalArgumentException if the tool is not a {@link LLMClient.FunctionTool}
     */
    private static GeminiQuery.FunctionDeclaration toFunctionDeclaration(LLMClient.AbstractTool tool) {
        if (!(tool instanceof LLMClient.FunctionTool)) {
            throw new IllegalArgumentException(String.format("Unknown tool: %s", tool.getClass().getSimpleName()));
        }
        LLMClient.FunctionTool fDesc = (LLMClient.FunctionTool)tool;
        GeminiQuery.FunctionDeclaration ft = new GeminiQuery.FunctionDeclaration();
        ft.name = fDesc.function.name;
        ft.description = fDesc.function.description;
        ft.parameters = JSONSchemaCompatibilityEnhancer.enhance(fDesc.function.getParameters(), schemaProvider(fDesc.function.compatible));
        return ft;
    }

    /**
     * Maps a DSS chat message to a Gemini message.
     *
     * <p>Role mapping: user → "user", assistant → "model", system → "system", tool → "user"
     * (carrying function responses).
     *
     * @throws IllegalArgumentException on any other role, including {@code null}
     */
    private static GeminiQuery.Message chatMessageToGeminiMessage(LLMClient.ChatMessage chatMessage) {
        String role = chatMessage.role;
        // Literal-first comparisons so a null role ends in the explicit error below, not an NPE.
        if ("user".equals(role)) {
            GeminiQuery.Message message = adaptMessage(chatMessage);
            message.role = "user";
            return message;
        }
        if ("assistant".equals(role)) {
            GeminiQuery.Message message = adaptMessage(chatMessage);
            message.role = "model";
            return message;
        }
        if ("system".equals(role)) {
            return adaptSystemMessage(chatMessage);
        }
        if ("tool".equals(role)) {
            return adaptToolMessage(chatMessage);
        }
        throw new IllegalArgumentException("Unknown message role for Gemini model: " + role);
    }

    /**
     * Builds the parts of a Gemini message (role left unset).
     *
     * <p>If the message has tool calls, only those are emitted; otherwise either a single text
     * part or the converted multimodal parts.
     *
     * @throws IllegalArgumentException if the message is neither text-only nor has parts
     */
    private static GeminiQuery.Message adaptMessage(LLMClient.ChatMessage chatMessage) {
        GeminiQuery.Message message = new GeminiQuery.Message();
        if (chatMessage.toolCalls != null) {
            message.parts = chatMessage.toolCalls.stream().map(GeminiQueryAdapter::adaptToolCall).collect(Collectors.toList());
        } else if (chatMessage.isTextOnly()) {
            GeminiQuery.MessagePart part = new GeminiQuery.MessagePart();
            part.text = chatMessage.getText();
            message.parts.add(part);
        } else {
            // Explicit check: a Java `assert` is usually disabled at runtime and would end in an NPE.
            if (chatMessage.parts == null) {
                throw new IllegalArgumentException("Chat message has no text, tool calls or parts");
            }
            message.parts = adaptMessageParts(chatMessage.parts);
        }
        return message;
    }

    /**
     * Converts a DSS tool call into a Gemini {@code functionCall} part; the JSON-string arguments
     * are parsed into a {@link JsonObject}.
     *
     * @throws IllegalArgumentException if the call is not a {@link LLMClient.FunctionToolCall}
     */
    private static GeminiQuery.MessagePart adaptToolCall(LLMClient.AbstractToolCall atc) {
        if (!(atc instanceof LLMClient.FunctionToolCall)) {
            throw new IllegalArgumentException("AbstractToolCall should be of type FunctionToolCall");
        }
        LLMClient.FunctionToolCall ftc = (LLMClient.FunctionToolCall)atc;
        GeminiQuery.FunctionCall call = new GeminiQuery.FunctionCall();
        call.name = ftc.function.name;
        call.args = (JsonObject)JSON.parse((String)ftc.function.arguments, JsonObject.class);
        GeminiQuery.MessagePart part = new GeminiQuery.MessagePart();
        part.functionCall = call;
        return part;
    }

    /**
     * Converts a system message; it must be text-only and without tool calls.
     *
     * @throws IllegalArgumentException otherwise
     */
    private static GeminiQuery.Message adaptSystemMessage(LLMClient.ChatMessage message) {
        if (!message.isTextOnly() || message.toolCalls != null) {
            throw new IllegalArgumentException("Chat message with role: system must be text-only");
        }
        GeminiQuery.Message geminiMessage = adaptMessage(message);
        geminiMessage.role = "system";
        return geminiMessage;
    }

    /**
     * Converts multimodal parts into Gemini parts.
     *
     * <p>Empty or null TEXT parts are sent as the text {@code "image:"}. Inline images default to
     * {@code image/jpeg} when no MIME type is given.
     *
     * @throws IllegalArgumentException for image URIs, which are not supported here
     */
    private static List<GeminiQuery.MessagePart> adaptMessageParts(List<LLMClient.ChatMessagePart> chatParts) {
        List<LLMClient.ChatMessagePart> partsWithText = addTextPartIfNotPresent(chatParts);
        List<GeminiQuery.MessagePart> geminiParts = new ArrayList<GeminiQuery.MessagePart>();
        for (LLMClient.ChatMessagePart part : partsWithText) {
            GeminiQuery.MessagePart geminiPart = new GeminiQuery.MessagePart();
            switch (part.type) {
                case TEXT: {
                    geminiPart.text = Strings.isNullOrEmpty(part.text) ? "image:" : part.text;
                    break;
                }
                case IMAGE_INLINE: {
                    GeminiQuery.InlineData inlineData = new GeminiQuery.InlineData();
                    inlineData.mimeType = StringUtils.isBlank(part.imageMimeType) ? "image/jpeg" : part.imageMimeType;
                    inlineData.data = part.inlineImage;
                    geminiPart.inlineData = inlineData;
                    break;
                }
                case IMAGE_URI: {
                    throw new IllegalArgumentException("Image URIs not supported for Vertex models. Use inline images instead.");
                }
            }
            geminiParts.add(geminiPart);
        }
        return geminiParts;
    }

    /**
     * Returns the parts with an empty placeholder TEXT part appended when every part is an image
     * (or the list is empty). Otherwise the input list is returned unchanged.
     *
     * <p>The input is never modified. It belongs to the caller's {@code ChatMessage} and may be
     * unmodifiable.
     */
    private static List<LLMClient.ChatMessagePart> addTextPartIfNotPresent(List<LLMClient.ChatMessagePart> chatParts) {
        boolean imagesOnly = chatParts.stream().allMatch(p -> IMAGE_PART_TYPES.contains(p.type));
        if (!imagesOnly) {
            return chatParts;
        }
        List<LLMClient.ChatMessagePart> withText = new ArrayList<LLMClient.ChatMessagePart>(chatParts);
        LLMClient.ChatMessagePart dummyTextPart = new LLMClient.ChatMessagePart();
        dummyTextPart.type = LLMClient.ChatMessagePartType.TEXT;
        withText.add(dummyTextPart);
        return withText;
    }

    /**
     * Converts a tool-result message into a "user" message of {@code functionResponse} parts.
     *
     * @throws IllegalArgumentException if the message is not text-only or has no tool outputs
     */
    private static GeminiQuery.Message adaptToolMessage(LLMClient.ChatMessage message) {
        if (!message.isTextOnly()) {
            throw new IllegalArgumentException(String.format("Chat message with role: %s must be text-only", message.role));
        }
        if (message.toolOutputs == null) {
            throw new IllegalArgumentException(String.format("Chat message with role: %s must have tool outputs", message.role));
        }
        GeminiQuery.Message geminiMessage = new GeminiQuery.Message();
        geminiMessage.role = "user";
        geminiMessage.parts = message.toolOutputs.stream().map(GeminiQueryAdapter::toFunctionResponsePart).collect(Collectors.toList());
        return geminiMessage;
    }

    /**
     * Converts one tool output into a {@code functionResponse} part.
     *
     * <p>If the output parses as a JSON object it is sent as-is. Otherwise (a parse error, or no
     * parsed value) it is wrapped as {@code {"result": <output>}}, so the response is always an
     * object.
     */
    private static GeminiQuery.MessagePart toFunctionResponsePart(LLMClient.ToolOutput toolOutput) {
        GeminiQuery.FunctionResponse resp = new GeminiQuery.FunctionResponse();
        // NOTE(review): Gemini matches responses by function name; this relies on callId carrying it — confirm.
        resp.name = toolOutput.callId;
        JsonObject response;
        try {
            response = (JsonObject)JSON.parse((String)toolOutput.output, JsonObject.class);
        }
        catch (JsonSyntaxException e) {
            // Not a JSON object (e.g. plain text): fall through to the wrapper below.
            response = null;
        }
        if (response == null) {
            // Also covers inputs the parser maps to null (presumably null/empty output) — never send a null response.
            response = new JsonObject();
            response.addProperty("result", toolOutput.output);
        }
        resp.response = response;
        GeminiQuery.MessagePart part = new GeminiQuery.MessagePart();
        part.functionResponse = resp;
        return part;
    }

    /**
     * Maps a DSS tool choice onto Gemini's function-calling mode.
     *
     * <ul>
     *   <li>none → NONE</li>
     *   <li>required → ANY</li>
     *   <li>auto → AUTO</li>
     *   <li>named → ANY restricted to that one function</li>
     * </ul>
     *
     * <p>Any other subtype leaves the mode unset.
     */
    private static GeminiQuery.ToolConfig toToolConfig(LLMClient.ToolChoice choice) {
        GeminiQuery.FunctionCallingConfig functionCallingConfig = new GeminiQuery.FunctionCallingConfig();
        if (choice instanceof LLMClient.NoneToolChoice) {
            functionCallingConfig.mode = GeminiQuery.ToolMode.NONE;
        } else if (choice instanceof LLMClient.RequiredToolChoice) {
            functionCallingConfig.mode = GeminiQuery.ToolMode.ANY;
        } else if (choice instanceof LLMClient.AutoToolChoice) {
            functionCallingConfig.mode = GeminiQuery.ToolMode.AUTO;
        } else if (choice instanceof LLMClient.NamedToolChoice) {
            LLMClient.NamedToolChoice ntc = (LLMClient.NamedToolChoice)choice;
            functionCallingConfig.mode = GeminiQuery.ToolMode.ANY;
            functionCallingConfig.allowedFunctions = Arrays.asList(ntc.name);
        }
        GeminiQuery.ToolConfig toolConfig = new GeminiQuery.ToolConfig();
        toolConfig.functionCallingConfig = functionCallingConfig;
        return toolConfig;
    }
}

