/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMAllocationTagsUtils;
import com.dataiku.dip.llm.LLMAuditHelper;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMMeshStreamClient;
import com.dataiku.dip.llm.promptstudio.PromptResponse;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.llm.promptstudio.PromptStudiosCRUDService;
import com.dataiku.dip.llm.utils.StreamingChunkEmitter;
import com.dataiku.dip.recipes.nlp.common.LLMCompletionSettings;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsageReportingService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.streaming.endpoints.httpsse.MiniSSEEmitter;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonObject;
import java.util.ArrayList;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Runs one turn of a Prompt Studio chat against an LLM Mesh model and records the resulting
 * message(s) in the conversation held by {@link PromptChat#messages}.
 *
 * <p>The conversation is a tree: every {@code ConversationMessage} points to its parent through
 * {@code parentId}, and the branch sent to the LLM is found by walking up from
 * {@link PromptChat#lastUserMessage}. The constructor wires the {@code @Autowired} service fields
 * itself through {@link SpringUtils}.
 */
public class PromptChatClient {
    @Autowired
    private AuditTrailService auditTrailService;
    @Autowired
    private ComputeResourceUsageReportingService cruReportingService;
    @Autowired
    private PromptStudiosCRUDService promptStudiosCrudService;
    private final AuthCtx authCtx;
    private final PromptChat promptChat;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.promptchat.client");

    /**
     * @param authCtx identity on whose behalf the LLM is called; its identifier is stamped on the user message
     * @param promptChat chat state; mutated in place (see {@link #streamChatResponse})
     */
    public PromptChatClient(AuthCtx authCtx, PromptChat promptChat) {
        this.authCtx = authCtx;
        this.promptChat = promptChat;
        SpringUtils.getInstance().autowire((Object)this);
    }

    /**
     * Sends the conversation branch ending at {@code promptChat.lastUserMessage} to the LLM and
     * streams the answer to {@code emitter}.
     *
     * <p>Side effects on {@code promptChat}:
     * <ul>
     *   <li>sets {@code runBy} on the last user message;</li>
     *   <li>inserts that message into {@code messages} if its id is not already present;</li>
     *   <li>inserts the response message(s) into {@code messages}.</li>
     * </ul>
     * The response message(s) are an optional memory-fragment message followed by either an
     * assistant message or a tool-validation-requests message.
     *
     * <p>Exceptions raised while opening the stream client or running the completion are not
     * rethrown. They are converted into an error response and recorded on the returned message.
     * A failure while reporting compute resource usage is logged and does not discard an
     * otherwise successful completion.
     *
     * @param emitter SSE emitter receiving the streamed chunks
     * @return the full message map plus the id of the last message added
     * @throws IllegalArgumentException if the parent chain of the last user message references a
     *     missing message or loops back on itself
     * @throws Exception if preparing the request (e.g. resolving guardrails settings) fails
     */
    public PromptResponse.SingleInputPromptResponse streamChatResponse(MiniSSEEmitter emitter) throws Exception {
        this.promptChat.lastUserMessage.runBy = this.authCtx.getIdentifier();
        if (!this.promptChat.messages.containsKey(this.promptChat.lastUserMessage.id)) {
            this.addNewChatMessage(this.promptChat.lastUserMessage);
        }
        LLMClient.SingleCompletionQuery query = this.getStreamableCompletionRequestFromPrompt();
        // Connection/LLM-level guardrails are merged with the ones configured on this chat.
        GuardrailsPipelineSettings connectionGuardrailsPipelineSettings = GuardrailsPipelineUtils.getConnectionAndLLMLevelSettings(this.authCtx, this.promptChat.projectKey, this.promptChat.enrichedLLMRef);
        GuardrailsPipelineSettings guardrailsPipelineSettings = GuardrailsPipelineUtils.mergeEnforcementSettings(connectionGuardrailsPipelineSettings, this.promptChat.guardrailsPipelineSettings);
        // Stays null if the stream client could not be opened; passed as-is to the audit helper below.
        AbstractLLMConnection connection = null;
        LLMClient.CompletionSettings completionSettings = this.promptChat.llmSettings.toFullSettings();
        logger.info((Object)("Sending single completion query to LLM " + this.promptChat.enrichedLLMRef.id + ": " + JSON.log((Object)query.getSafeForLoggingCopy())));
        StreamingChunkEmitter streamingChunkEmitter = new StreamingChunkEmitter(emitter);
        LLMClient.SimpleCompletionResponseOrError scre;
        try (LLMMeshStreamClient streamClient = new LLMMeshStreamClient(this.authCtx, this.promptChat.projectKey, this.promptChat.enrichedLLMRef, guardrailsPipelineSettings, this.promptChat.streamingDisabled, this.promptChat.useDevKernel, streamingChunkEmitter)) {
            streamClient.enableEmulatedStreamingInfoChunk();
            connection = streamClient.getConnection();
            LLMAuditHelper.emitToolValidationAuditsIfNeeded(this.auditTrailService, this.promptChat.enrichedLLMRef, connection, query);
            scre = streamClient.streamComplete(query, completionSettings);
            // Best-effort: must not overwrite a successful completion with an error.
            this.reportComputeResourceUsage(streamClient, query);
        }
        catch (Exception e) {
            scre = LLMClient.SimpleCompletionResponseOrError.fromError(e);
        }
        LLMAuditHelper.emitLLMCompletionAuditFromBackendIfNeeded(this.auditTrailService, this.promptChat.enrichedLLMRef, connection, query, scre);
        if (scre.toolValidationRequests != null && !scre.toolValidationRequests.isEmpty()) {
            PromptStudio.ConversationMessage toolValidationRequestsMessage = this.buildToolValidationRequestsMessage(scre);
            return this.buildSingleInputPromptResponse(toolValidationRequestsMessage, scre);
        }
        PromptStudio.ConversationMessage assistantMessage = this.buildAssistantMessage(scre);
        logger.info((Object)("Returning messages  " + JSON.log(this.promptChat.messages)));
        return this.buildSingleInputPromptResponse(assistantMessage, scre);
    }

    /**
     * Reports the stream client's total compute resource usage, tagged with the query's
     * allocation tags. Failures are logged rather than propagated.
     */
    private void reportComputeResourceUsage(LLMMeshStreamClient streamClient, LLMClient.SingleCompletionQuery query) {
        try {
            ComputeResourceUsage cru = streamClient.getTotalCRU();
            if (cru != null) {
                LLMAllocationTagsUtils.addAllocationTagsToCRU(query, cru);
                this.cruReportingService.reportComplete(cru);
            }
        }
        catch (Exception e) {
            logger.warn((Object)("Failed to report compute resource usage for LLM " + this.promptChat.enrichedLLMRef.id), (Throwable)e);
        }
    }

    /**
     * Inserts {@code newMessage} into the conversation map.
     *
     * <p>When the map is empty, a fresh placeholder {@code ConversationMessage} is inserted first and
     * becomes the parent of {@code newMessage}. {@code newMessage.version} is set to the number of
     * messages already in the map that share its {@code parentId}, counted before insertion.
     */
    private void addNewChatMessage(PromptStudio.ConversationMessage newMessage) {
        if (this.promptChat.messages.isEmpty()) {
            PromptStudio.ConversationMessage parentMessage = new PromptStudio.ConversationMessage();
            this.promptChat.messages.put(parentMessage.id, parentMessage);
            newMessage.parentId = parentMessage.id;
        }
        newMessage.version = (int)this.promptChat.messages.values().stream().filter(message -> Objects.equals(message.parentId, newMessage.parentId)).count();
        this.promptChat.messages.put(newMessage.id, newMessage);
    }

    /**
     * Builds the completion query for the branch ending at the last user message.
     *
     * <p>Message order in the query:
     * <ol>
     *   <li>an optional "system" message, added only when {@code systemMessage} is non-blank;</li>
     *   <li>each ancestor's chat message, oldest first (ancestors with a null message are skipped);</li>
     *   <li>the last user message.</li>
     * </ol>
     * The chat's {@code context} is attached to the query.
     *
     * @throws IllegalArgumentException if a {@code parentId} in the chain is not a key of
     *     {@code messages}, or if the chain loops back on itself
     */
    private LLMClient.SingleCompletionQuery getStreamableCompletionRequestFromPrompt() {
        LLMClient.SingleCompletionQuery query = new LLMClient.SingleCompletionQuery();
        query.messages = new ArrayList<LLMClient.ChatMessage>();
        query.messages.add(this.promptChat.lastUserMessage.message);
        Map<String, PromptStudio.ConversationMessage> messages = this.promptChat.messages;
        String parentId = this.promptChat.lastUserMessage.parentId;
        int hops = 0;
        while (parentId != null) {
            PromptStudio.ConversationMessage parentMessage = messages.get(parentId);
            if (parentMessage == null) {
                throw new IllegalArgumentException("Conversation message " + parentId + " not found in chat messages");
            }
            // Each hop visits a message of the map, so more hops than entries means a cycle.
            if (++hops > messages.size()) {
                throw new IllegalArgumentException("Cycle detected in conversation messages at message " + parentId);
            }
            if (parentMessage.message != null) {
                query.messages.add(0, parentMessage.message);
            }
            parentId = parentMessage.parentId;
        }
        if (this.promptChat.systemMessage != null && !this.promptChat.systemMessage.isBlank()) {
            LLMClient.ChatMessage systemChatMessage = new LLMClient.ChatMessage();
            systemChatMessage.role = "system";
            systemChatMessage.setTextOnly(this.promptChat.systemMessage);
            query.messages.add(0, systemChatMessage);
        }
        query.context = this.promptChat.context;
        return query;
    }

    /**
     * Creates and inserts the "assistant" message holding the LLM text answer.
     *
     * <p>The message is flagged as an error when the completion failed or returned empty text.
     */
    private PromptStudio.ConversationMessage buildAssistantMessage(LLMClient.SimpleCompletionResponseOrError scre) {
        String parentId = this.resolveResponseParentId(scre);
        PromptStudio.ConversationMessage assistantMessage = new PromptStudio.ConversationMessage();
        assistantMessage.parentId = parentId;
        assistantMessage.message = new LLMClient.ChatMessage("assistant", scre.text);
        if (!scre.ok) {
            assistantMessage.error = true;
            assistantMessage.llmError = scre.errorMessage;
        } else if (StringUtils.isEmpty((String)scre.text)) {
            assistantMessage.error = true;
            assistantMessage.llmError = "LLM response is empty.";
        }
        this.fillResponseMessageFields(assistantMessage, scre);
        assistantMessage.context = this.promptChat.context;
        this.addNewChatMessage(assistantMessage);
        return assistantMessage;
    }

    /**
     * Creates and inserts a "toolValidationRequests" message carrying the tool calls that the LLM
     * asked to have validated.
     */
    private PromptStudio.ConversationMessage buildToolValidationRequestsMessage(LLMClient.SimpleCompletionResponseOrError scre) {
        String parentId = this.resolveResponseParentId(scre);
        PromptStudio.ConversationMessage toolValidationRequestsMessage = new PromptStudio.ConversationMessage();
        toolValidationRequestsMessage.parentId = parentId;
        toolValidationRequestsMessage.message = new LLMClient.ChatMessage("toolValidationRequests", scre.text);
        toolValidationRequestsMessage.message.role = "toolValidationRequests";
        toolValidationRequestsMessage.message.toolValidationRequests = scre.toolValidationRequests;
        if (!scre.ok) {
            toolValidationRequestsMessage.error = true;
            toolValidationRequestsMessage.llmError = scre.errorMessage;
        } else if (scre.toolValidationRequests == null || scre.toolValidationRequests.isEmpty()) {
            // Defensive: streamChatResponse only calls this when requests are present.
            toolValidationRequestsMessage.error = true;
            toolValidationRequestsMessage.llmError = "LLM response is empty.";
        }
        // NOTE(review): unlike buildAssistantMessage, the chat context is not copied here — confirm intended.
        this.fillResponseMessageFields(toolValidationRequestsMessage, scre);
        this.addNewChatMessage(toolValidationRequestsMessage);
        return toolValidationRequestsMessage;
    }

    /**
     * Returns the parent id for an LLM response message.
     *
     * <p>If the completion returned a memory fragment, a memory-fragment message is inserted first and
     * its id is returned. Otherwise the last user message's id is returned.
     */
    private String resolveResponseParentId(LLMClient.SimpleCompletionResponseOrError scre) {
        if (scre.memoryFragment != null) {
            return this.buildMemoryFragmentMessage(scre).id;
        }
        return this.promptChat.lastUserMessage.id;
    }

    /**
     * Copies onto {@code responseMessage} the fields shared by every LLM response message.
     *
     * <p>These come from the chat settings and from the completion result (trace, artifacts,
     * sources, additional information).
     */
    private void fillResponseMessageFields(PromptStudio.ConversationMessage responseMessage, LLMClient.SimpleCompletionResponseOrError scre) {
        responseMessage.completionSettings = this.promptChat.llmSettings;
        responseMessage.llmStructuredRef = this.promptChat.enrichedLLMRef;
        responseMessage.systemMessage = this.promptChat.systemMessage;
        responseMessage.fullTrace = scre.trace;
        responseMessage.artifacts = scre.artifacts;
        responseMessage.sources = scre.sources;
        responseMessage.additionalInformation = scre.additionalInformation;
    }

    /**
     * Creates and inserts a "memoryFragment" message, child of the last user message, holding the
     * memory fragment returned by the completion.
     */
    private PromptStudio.ConversationMessage buildMemoryFragmentMessage(LLMClient.SimpleCompletionResponseOrError scre) {
        PromptStudio.ConversationMessage memoryFragmentMessage = new PromptStudio.ConversationMessage();
        memoryFragmentMessage.parentId = this.promptChat.lastUserMessage.id;
        memoryFragmentMessage.message = new LLMClient.ChatMessage();
        memoryFragmentMessage.message.role = "memoryFragment";
        memoryFragmentMessage.message.memoryFragment = scre.memoryFragment;
        memoryFragmentMessage.completionSettings = this.promptChat.llmSettings;
        memoryFragmentMessage.llmStructuredRef = this.promptChat.enrichedLLMRef;
        memoryFragmentMessage.systemMessage = this.promptChat.systemMessage;
        this.addNewChatMessage(memoryFragmentMessage);
        return memoryFragmentMessage;
    }

    /**
     * Wraps the conversation into a response.
     *
     * <p>{@code lastMessage} becomes the response's last message. The CRUD service fills the
     * remaining completion details, and the completion's context upsert is forwarded.
     */
    private PromptResponse.SingleInputPromptResponse buildSingleInputPromptResponse(PromptStudio.ConversationMessage lastMessage, LLMClient.SimpleCompletionResponseOrError scre) {
        PromptResponse.SingleInputPromptResponse sipr = new PromptResponse.SingleInputPromptResponse();
        sipr.chatMessages = this.promptChat.messages;
        sipr.lastMessageId = lastMessage.id;
        this.promptStudiosCrudService.fillSingleInputPromptResponse(sipr, scre, true);
        sipr.contextUpsert = scre.contextUpsert;
        return sipr;
    }

    /** Mutable chat state for one turn; {@link PromptChatClient} updates it in place. */
    public static class PromptChat {
        /** Project passed to guardrails resolution and to the LLM Mesh stream client. */
        public String projectKey;
        /** Completion settings; expanded with {@code toFullSettings()} and copied onto response messages. */
        public LLMCompletionSettings llmSettings;
        /** LLM to call. */
        public EnrichedLLMStructuredRef enrichedLLMRef;
        /** Chat-level guardrails, merged with connection/LLM-level settings before the call. */
        public GuardrailsPipelineSettings guardrailsPipelineSettings;
        /** Conversation messages keyed by message id, linked to their parent through {@code parentId}. */
        public Map<String, PromptStudio.ConversationMessage> messages;
        /** User message to answer; its ancestor chain forms the history sent to the LLM. */
        public PromptStudio.ConversationMessage lastUserMessage;
        /** Optional system prompt, sent first when non-blank. */
        public String systemMessage;
        /** Context attached to the query and copied onto the assistant message. */
        public JsonObject context;
        /** Forwarded to the LLM Mesh stream client. */
        public boolean streamingDisabled;
        /** Forwarded to the LLM Mesh stream client. */
        public boolean useDevKernel;
    }
}

