/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;

/**
 * Static helpers for transforming lists of LLM chat messages and for building lightweight JSON /
 * text summaries of completion, embedding and reranking queries.
 *
 * <p>The list-transforming methods do not modify the messages they are given: when a message has
 * to change, a copy made with the {@link LLMClient.ChatMessage} copy constructor is changed instead.
 * Some methods still return lists that share unchanged message instances with their input.
 */
public class LLMChatMessageUtils {
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.chat_message_utils");

    /**
     * Merges runs of consecutive messages that have the same role into a single message.
     *
     * <p>For each later message of a run, its parts (or its text, turned into a text part), tool
     * calls, tool outputs, tool validation requests and tool validation responses are appended to
     * the first message of the run. Messages with role {@code memoryFragment} are never merged.
     *
     * @param messages the messages to collapse; not modified
     * @return a new list of copied (and possibly merged) messages
     */
    public static List<LLMClient.ChatMessage> collapseAdjacentSameRoleMessages(List<LLMClient.ChatMessage> messages) {
        ArrayList<LLMClient.ChatMessage> newMessages = new ArrayList<LLMClient.ChatMessage>();
        LLMClient.ChatMessage currentMessage = null;
        for (LLMClient.ChatMessage message : messages) {
            if (currentMessage == null) {
                currentMessage = new LLMClient.ChatMessage(message);
                continue;
            }
            if (!currentMessage.role.equals(message.role) || currentMessage.role.equals("memoryFragment")) {
                newMessages.add(currentMessage);
                currentMessage = new LLMClient.ChatMessage(message);
                continue;
            }
            LLMChatMessageUtils.mergeInto(currentMessage, message);
        }
        if (currentMessage != null) {
            newMessages.add(currentMessage);
        }
        return newMessages;
    }

    /**
     * Appends the content of {@code source} to {@code target}, which must be a copy owned by the
     * caller (never a message from the caller's input list).
     *
     * <p>Each list field that changes gets a freshly allocated list rather than an in-place append.
     * The copy constructor that produced {@code target} may share, or wrap immutably, the lists of
     * the message it copied. Appending in place could then modify the caller's input or throw.
     */
    private static void mergeInto(LLMClient.ChatMessage target, LLMClient.ChatMessage source) {
        List<LLMClient.ChatMessagePart> sourceParts = LLMChatMessageUtils.chatMessageAsParts(source);
        if (!sourceParts.isEmpty()) {
            // If the target has no parts yet, its text (if any) becomes its first part.
            List<LLMClient.ChatMessagePart> targetParts = target.parts != null ? target.parts : LLMChatMessageUtils.chatMessageAsParts(target);
            target.parts = LLMChatMessageUtils.concat(targetParts, sourceParts);
        }
        if (source.toolCalls != null) {
            target.toolCalls = LLMChatMessageUtils.concat(target.toolCalls, source.toolCalls);
        }
        if (source.toolOutputs != null) {
            target.toolOutputs = LLMChatMessageUtils.concat(target.toolOutputs, source.toolOutputs);
        }
        if (source.toolValidationRequests != null) {
            target.toolValidationRequests = LLMChatMessageUtils.concat(target.toolValidationRequests, source.toolValidationRequests);
        }
        if (source.toolValidationResponses != null) {
            target.toolValidationResponses = LLMChatMessageUtils.concat(target.toolValidationResponses, source.toolValidationResponses);
        }
    }

    /**
     * Returns a new mutable list with the elements of {@code first} (skipped when null) followed by
     * the elements of {@code second}.
     */
    private static <T> ArrayList<T> concat(List<? extends T> first, List<? extends T> second) {
        int firstSize = first == null ? 0 : first.size();
        ArrayList<T> result = new ArrayList<T>(firstSize + second.size());
        if (first != null) {
            result.addAll(first);
        }
        result.addAll(second);
        return result;
    }

    /**
     * Returns copies of all messages, with every message whose role equals {@code fromRole}
     * switched to {@code toRole}.
     *
     * @param messages the messages to convert; not modified
     * @param fromRole the role to replace
     * @param toRole the replacement role
     * @return a new list of copied messages
     */
    public static List<LLMClient.ChatMessage> convertMessageRole(List<LLMClient.ChatMessage> messages, String fromRole, String toRole) {
        return messages.stream().map(LLMClient.ChatMessage::new).map(message -> {
            if (message.role.equals(fromRole)) {
                message.role = toRole;
            }
            return message;
        }).collect(Collectors.toList());
    }

    /**
     * Returns copies of all messages. Every {@code system} message that comes after the first
     * non-system message is turned into a {@code user} message.
     *
     * <p>All system messages in the leading run (before any non-system message) keep their role.
     *
     * @param messages the messages to convert; not modified
     * @return a new list of copied messages
     */
    public static List<LLMClient.ChatMessage> convertExtraSystemMessageToUser(List<LLMClient.ChatMessage> messages) {
        ArrayList<LLMClient.ChatMessage> newMessages = new ArrayList<LLMClient.ChatMessage>();
        boolean inLeadingSystemRun = true;
        for (LLMClient.ChatMessage message : messages) {
            LLMClient.ChatMessage newMessage = new LLMClient.ChatMessage(message);
            if (!message.role.equals("system")) {
                inLeadingSystemRun = false;
            } else if (!inLeadingSystemRun) {
                newMessage.role = "user";
            }
            newMessages.add(newMessage);
        }
        return newMessages;
    }

    /**
     * If every message is text-only, returns new messages built from each message's role and text.
     * Otherwise returns a new list holding the original message instances, unchanged.
     *
     * @param messages the messages to convert; not modified
     * @return a new list of messages
     */
    public static List<LLMClient.ChatMessage> convertPartsToContentMessagesIfPossible(List<LLMClient.ChatMessage> messages) {
        boolean messagesAreTextOnly = messages.stream().allMatch(message -> message.isTextOnly());
        return messages.stream().map(message -> messagesAreTextOnly ? new LLMClient.ChatMessage(message.role, message.getText()) : message).collect(Collectors.toList());
    }

    /**
     * Returns the message content as an unmodifiable list of parts.
     *
     * <p>If the message has parts, they are copied. Otherwise its text becomes a single text part.
     * If there are no parts and no text, the list is empty.
     */
    private static List<LLMClient.ChatMessagePart> chatMessageAsParts(LLMClient.ChatMessage message) {
        if (message.parts != null) {
            return List.copyOf(message.parts);
        }
        String text = message.getText();
        if (text == null) {
            return Collections.emptyList();
        }
        LLMClient.ChatMessagePart part = new LLMClient.ChatMessagePart();
        part.text = text;
        return List.of(part);
    }

    /**
     * Builds a reduced JSON representation of a completion query: {@code {"messages": [...]}}.
     *
     * <p>Text-only messages are serialized in full. Other messages are summarized as their parts,
     * role, tool outputs and tool calls, as follows:
     * <ul>
     *   <li>inline image data is replaced by the placeholder {@code <inline-image>};</li>
     *   <li>parts of any type other than TEXT, IMAGE_INLINE and IMAGE_URI become empty objects.</li>
     * </ul>
     *
     * @param query the completion query to summarize; not modified
     * @return the JSON summary
     */
    public static JsonObject completionQueryToLightJsonObject(LLMClient.SingleCompletionQuery query) {
        JsonObject jsonQuery = new JsonObject();
        JsonArray jsonMessages = new JsonArray();
        jsonQuery.add("messages", jsonMessages);
        for (LLMClient.ChatMessage message : query.messages) {
            if (message.isTextOnly()) {
                jsonMessages.add((JsonElement)JSON.toJsonObject((Object)message, (String[])new String[0]));
                continue;
            }
            JsonObject jsonMessage = new JsonObject();
            jsonMessages.add(jsonMessage);
            JsonArray jsonParts = new JsonArray();
            jsonMessage.add("parts", jsonParts);
            // chatMessageAsParts treats null parts as a normal state. Nothing visible here
            // guarantees that a message which is not text-only has parts, so guard against null.
            if (message.parts != null) {
                for (LLMClient.ChatMessagePart part : message.parts) {
                    JsonObject jsonPart = new JsonObject();
                    jsonParts.add(jsonPart);
                    switch (part.type) {
                        case TEXT: {
                            jsonPart.addProperty("text", part.text);
                            break;
                        }
                        case IMAGE_INLINE: {
                            jsonPart.addProperty("image", "<inline-image>");
                            jsonPart.addProperty("imageMimeType", part.imageMimeType);
                            break;
                        }
                        case IMAGE_URI: {
                            jsonPart.addProperty("imageUrl", part.imageUrl);
                            break;
                        }
                    }
                }
            }
            jsonMessage.addProperty("role", message.role);
            jsonMessage.add("toolOutputs", (JsonElement)JSON.toJsonObject(message.toolOutputs, (String[])new String[0]));
            jsonMessage.add("toolCalls", (JsonElement)JSON.toJsonObject(message.toolCalls, (String[])new String[0]));
        }
        return jsonQuery;
    }

    /**
     * Returns the non-blank texts of the query's messages, in message order.
     *
     * <p>Text is taken from {@code getTextEvenIfNotTextOnly()}, so it is collected from every
     * message, not only text-only ones.
     *
     * @param query the completion query
     * @return a new list of non-blank message texts
     */
    public static List<String> completionQueryToStringsOfTexts(LLMClient.SingleCompletionQuery query) {
        ArrayList<String> ret = new ArrayList<String>();
        for (LLMClient.ChatMessage message : query.messages) {
            String txt = message.getTextEvenIfNotTextOnly();
            if (StringUtils.isBlank(txt)) continue;
            ret.add(txt);
        }
        return ret;
    }

    /**
     * Builds a reduced JSON representation of an embedding query.
     *
     * <p>The result contains the text (if any) and, if the query has an image, the placeholder
     * {@code <inline-image>} instead of the image data.
     *
     * @param query the embedding query
     * @return the JSON summary
     */
    public static JsonObject embeddingQueryToLightJsonObject(LLMClient.EmbeddingQuery query) {
        JsonObject jsonQuery = new JsonObject();
        if (query.hasText()) {
            jsonQuery.addProperty("text", query.text);
        }
        if (query.hasImage()) {
            jsonQuery.addProperty("image", "<inline-image>");
        }
        return jsonQuery;
    }

    /**
     * Builds a JSON representation of a reranking query: the full serialized query under the key
     * {@code "query"}.
     *
     * @param query the reranking query
     * @return the JSON representation
     * @throws UnsupportedOperationException if any of the query parts is not text
     */
    public static JsonObject rerankingQueryToLightJsonObject(LLMClient.RerankingQuery query) {
        if (query.queryParts.stream().anyMatch(part -> !part.isText())) {
            throw new UnsupportedOperationException("Only text supported for reranking");
        }
        JsonObject jsonQuery = new JsonObject();
        jsonQuery.add("query", (JsonElement)JSON.toJsonObject((Object)query, (String[])new String[0]));
        return jsonQuery;
    }

    /**
     * Fails if any tool output of the given messages carries parts.
     *
     * @param chatMessages the messages to inspect
     * @throws UnsupportedOperationException for the first tool output that has a non-empty parts
     *     list; the message lists the part count, the distinct part types and the tool call id
     */
    public static void throwIfUnsupportedToolOutputParts(List<LLMClient.ChatMessage> chatMessages) throws UnsupportedOperationException {
        for (LLMClient.ChatMessage chatMessage : chatMessages) {
            if (chatMessage.toolOutputs == null) continue;
            for (LLMClient.ToolOutput toolOutput : chatMessage.toolOutputs) {
                if (toolOutput.parts == null || toolOutput.parts.isEmpty()) continue;
                HashSet<String> partTypes = new HashSet<String>();
                for (LLMClient.ToolOutputPart part : toolOutput.parts) {
                    partTypes.add(part.type.toString());
                }
                throw new UnsupportedOperationException(toolOutput.parts.size() + " part(s) with types " + partTypes + " from tool output " + toolOutput.callId + " are not supported by this LLM");
            }
        }
    }

    /**
     * Inserts, right after each {@code tool} message that has tool outputs, a {@code user} message
     * made of the parts extracted from those outputs (see {@link #getPartsFromToolOutput}).
     *
     * <p>The inserted message starts with an introductory text part. No message is inserted when no
     * parts are extracted. Existing messages are kept as the same instances.
     *
     * @param authCtx authentication context used to load referenced images
     * @param chatMessages the messages to process; not modified
     * @param includeImages whether image parts should be forwarded
     * @return {@code chatMessages} itself when it is empty, otherwise a new list
     */
    public static List<LLMClient.ChatMessage> addToolOutputPartsAsMultiPartUserMessages(AuthCtx authCtx, List<LLMClient.ChatMessage> chatMessages, boolean includeImages) {
        if (chatMessages.isEmpty()) {
            return chatMessages;
        }
        ArrayList<LLMClient.ChatMessage> chatMessagesOutput = new ArrayList<LLMClient.ChatMessage>(chatMessages.size());
        for (LLMClient.ChatMessage chatMessage : chatMessages) {
            chatMessagesOutput.add(chatMessage);
            if (!"tool".equals(chatMessage.role) || chatMessage.toolOutputs == null || chatMessage.toolOutputs.isEmpty()) continue;
            ArrayList<LLMClient.ChatMessagePart> parts = new ArrayList<LLMClient.ChatMessagePart>();
            for (LLMClient.ToolOutput toolOutput : chatMessage.toolOutputs) {
                parts.addAll(LLMChatMessageUtils.getPartsFromToolOutput(authCtx, toolOutput, includeImages));
            }
            if (parts.isEmpty()) continue;
            logger.debug((Object)"Adding tool output parts as multi-part user message");
            parts.add(0, new LLMClient.ChatMessagePart().withText("Here are text and images returned by tool calls"));
            chatMessagesOutput.add(new LLMClient.ChatMessage("user", parts));
        }
        return chatMessagesOutput;
    }

    /**
     * Converts the parts of a tool output into chat message parts.
     *
     * <ul>
     *   <li>TEXT parts become a text part prefixed with the tool call id.</li>
     *   <li>IMAGE_INLINE and IMAGE_URI parts become a caption text part (with a 1-based index that
     *       counts only these image parts) followed by a copy of the image part.</li>
     *   <li>IMAGE_REF parts are loaded through {@code toInlineImage(authCtx)} and become a caption
     *       text part followed by a copy of the loaded image part.</li>
     * </ul>
     * Images are skipped, with a warning, when {@code includeImages} is false, when their data
     * can't be loaded, or when they have no content.
     *
     * @param authCtx authentication context used to load referenced images
     * @param toolOutput the tool output to convert
     * @param includeImages whether image parts should be forwarded
     * @return a new, possibly empty, list of parts
     */
    public static List<LLMClient.ChatMessagePart> getPartsFromToolOutput(AuthCtx authCtx, LLMClient.ToolOutput toolOutput, boolean includeImages) {
        ArrayList<LLMClient.ChatMessagePart> parts = new ArrayList<LLMClient.ChatMessagePart>();
        if (toolOutput.parts == null) {
            return parts;
        }
        int index = 1;
        for (LLMClient.ToolOutputPart part : toolOutput.parts) {
            switch (part.type) {
                case TEXT: {
                    parts.add(new LLMClient.ChatMessagePart().withText("Text returned by tool call " + toolOutput.callId + ": " + part.text));
                    break;
                }
                case IMAGE_INLINE: 
                case IMAGE_URI: {
                    LLMClient.ChatMessagePart imagePart = new LLMClient.ChatMessagePart(part);
                    if (!includeImages) {
                        logger.warn((Object)("Ignoring image from tool output " + toolOutput.callId + " as it's not supported by this LLM"));
                        break;
                    }
                    if (!imagePart.containsImageData()) {
                        logger.warn((Object)("Ignoring image from tool output " + toolOutput.callId + " because its data can't be loaded"));
                        break;
                    }
                    parts.add(new LLMClient.ChatMessagePart().withText("The next image, at index " + index++ + ", was returned by tool call " + toolOutput.callId));
                    parts.add(imagePart);
                    break;
                }
                case IMAGE_REF: {
                    LLMClient.ChatMessagePart inlineImagePart;
                    if (!includeImages) {
                        logger.warn((Object)("Ignoring image from tool output " + toolOutput.callId + " as it's not supported by this LLM"));
                        break;
                    }
                    try {
                        inlineImagePart = part.toInlineImage(authCtx);
                    }
                    catch (Exception e) {
                        // Best effort: a failing image is skipped (and logged) rather than failing the whole conversion.
                        logger.warn((Object)("failed to load image data from image ref part at " + part.folderId + part.path), (Throwable)e);
                        break;
                    }
                    if (!inlineImagePart.containsImageData()) {
                        logger.warn((Object)("ignoring image ref part without content at " + part.folderId + part.path));
                        break;
                    }
                    parts.add(new LLMClient.ChatMessagePart().withText("The next image from file " + part.path + " in folder " + part.folderId + ", was returned by tool call " + toolOutput.callId));
                    parts.add(new LLMClient.ChatMessagePart(inlineImagePart));
                    break;
                }
            }
        }
        return parts;
    }
}

