/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.server.a2a;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMAuditHelper;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.llm.online.LLMMeshStreamClient;
import com.dataiku.dip.llm.utils.IStreamingChunkEmitter;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsageReportingService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.a2a.DSSA2AUser;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadeliba2a.io.a2a.server.agentexecution.AgentExecutor;
import com.dataiku.dss.shadeliba2a.io.a2a.server.agentexecution.RequestContext;
import com.dataiku.dss.shadeliba2a.io.a2a.server.auth.User;
import com.dataiku.dss.shadeliba2a.io.a2a.server.events.EventQueue;
import com.dataiku.dss.shadeliba2a.io.a2a.server.tasks.TaskUpdater;
import com.dataiku.dss.shadeliba2a.io.a2a.spec.Artifact;
import com.dataiku.dss.shadeliba2a.io.a2a.spec.DataPart;
import com.dataiku.dss.shadeliba2a.io.a2a.spec.Event;
import com.dataiku.dss.shadeliba2a.io.a2a.spec.JSONRPCError;
import com.dataiku.dss.shadeliba2a.io.a2a.spec.Message;
import com.dataiku.dss.shadeliba2a.io.a2a.spec.TaskArtifactUpdateEvent;
import com.dataiku.dss.shadeliba2a.io.a2a.spec.TextPart;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.google.gson.JsonObject;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.mutable.MutableBoolean;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * A2A {@link AgentExecutor} exposing the DSS agent {@code agent:<agentId>} of {@link #projectKey} over the
 * A2A protocol.
 *
 * <p>Each request is streamed through {@link LLMMeshStreamClient}. Answer text is forwarded as chunks of the
 * {@code "raw_answer_text"} artifact, and LLM artifacts become standalone A2A artifacts. The conversation
 * history is kept in a process-wide cache keyed by (A2A context id, user name), so follow-up messages in the
 * same context see the earlier turns.
 */
public class DSSAgentExecutor
implements AgentExecutor {
    @Autowired
    private ComputeResourceUsageReportingService cruReportingService;
    @Autowired
    private AuditTrailService auditTrailService;
    /** Idle time (hours) after which a cached conversation is evicted; DSS param, default 12. */
    static long contextsCacheExpirationInHours = DKUApp.getParams().getLongParam("dku.a2a.contextsCacheExpirationInHours", 12L);
    /**
     * Conversation history per (context, user), shared by all executor instances (at most 1000 entries).
     * Stored states are never mutated after being put: each turn works on a copy and puts it back when done.
     */
    private static final Cache<ContextCacheKey, DSSAgentContextState> contextsCache = CacheBuilder.newBuilder().maximumSize(1000L).expireAfterAccess(contextsCacheExpirationInHours, TimeUnit.HOURS).build();
    /** Artifact id under which the streamed answer text is sent. */
    private static final String RAW_TEXT_ARTIFACT_ID = "raw_answer_text";
    public final String projectKey;
    public String agentId;
    private static final DKULogger logger = DKULogger.getLogger("dku.server.a2a.DSSAgentExecutor");

    /**
     * @param projectKey project holding the agent
     * @param agentId    id of the agent, resolved as LLM ref {@code agent:<agentId>}
     */
    public DSSAgentExecutor(String projectKey, String agentId) {
        this.projectKey = projectKey;
        this.agentId = agentId;
        SpringUtils.getInstance().autowire((Object)this);
    }

    /**
     * Runs one conversation turn. It submits the task, streams the agent's answer as A2A events, and
     * completes the task when the stream footer arrives.
     *
     * <p>The cached history for this (context, user) is copied before use. It is only written back when the
     * footer is received, so a failed turn leaves the cached history unchanged.
     *
     * @throws JSONRPCError on failure; non-JSON-RPC failures are logged and wrapped in a {@code -32000} error
     */
    public void execute(final RequestContext context, final EventQueue queue) throws JSONRPCError {
        logger.info((Object)("Starting execution of DSSAgentExecutor for context " + context.getContextId()));
        try {
            final User user = context.getCallContext().getUser();
            // Explicit check instead of `assert`: asserts are usually disabled, and an AssertionError would
            // escape the catch (Exception) below.
            if (!(user instanceof DSSA2AUser)) {
                throw new IllegalStateException("Unexpected A2A user type: " + (user == null ? "null" : user.getClass().getName()));
            }
            AuthCtx authCtx = ((DSSA2AUser)user).getAuthCtx();
            final TaskUpdater updater = new TaskUpdater(context, queue);
            updater.submit();
            List<TextPart> thinkingParts = List.of(new TextPart("Thinking ..."));
            updater.startWork(updater.newAgentMessage(thinkingParts, Collections.emptyMap()));
            String requestId = "agent:" + this.agentId;
            LLMStructuredRef llmRef = LLMStructuredRef.decodeId(requestId);
            try (LLMClient llmClient = LLMClientFactory.get(authCtx, this.projectKey, llmRef)) {
                llmClient.throwIfBadImageAuditFolder();
                GuardrailsPipelineSettings connectionGuardrailsPipelineSettings = GuardrailsPipelineUtils.getConnectionAndLLMLevelSettings(authCtx, this.projectKey, llmRef);
                EnrichedLLMStructuredRef enrichedLLMRef = llmClient.getEnrichedRef();
                final ContextCacheKey cacheKey = new ContextCacheKey(context.getContextId(), user.getUsername());
                logger.info((Object)("Creating or loading conversation for context " + context.getContextId() + " for user " + user.getUsername()));
                DSSAgentContextState cachedState = contextsCache.getIfPresent(cacheKey);
                // Conversation contents may be large and sensitive: keep them out of INFO logs.
                logger.debug((Object)("Starting stream from conversation: " + JSON.log((Object)cachedState)));
                // Work on a private copy so that a failed turn (or a concurrent turn on the same context)
                // never leaves partial messages in the cached history.
                final DSSAgentContextState contextState = cachedState != null ? cachedState.copyForNewTurn() : new DSSAgentContextState();
                LLMClient.SingleCompletionQuery scq = new LLMClient.SingleCompletionQuery();
                scq.messages.addAll(contextState.messages);
                scq.context = (JsonObject)JSON.deepCopy((Object)contextState.context);
                String input = context.getUserInput("text");
                LLMClient.ChatMessage inputChatMessage = new LLMClient.ChatMessage("user", input);
                contextState.messages.add(inputChatMessage);
                scq.messages.add(inputChatMessage);
                final StringBuilder receivedTextBuffer = new StringBuilder();
                // True once a real answer-text chunk was sent; the first one restarts the artifact (append=false).
                final MutableBoolean firstTextChunkSent = new MutableBoolean(false);
                // True once any event (text or info) was sent on RAW_TEXT_ARTIFACT_ID.
                final MutableBoolean rawTextArtifactStarted = new MutableBoolean(false);
                IStreamingChunkEmitter emitter = new IStreamingChunkEmitter(){

                    @Override
                    public void initSuccess() {
                    }

                    @Override
                    public void emitCompletionChunk(LLMClient.StreamedCompletionResponseChunk chunk) throws Exception {
                        logger.trace((Object)("emitCompletionChunk " + JSON.log((Object)chunk)));
                        if (!StringUtils.isEmpty(chunk.text)) {
                            receivedTextBuffer.append(chunk.text);
                            enqueueRawTextChunk(updater, queue, chunk.text, firstTextChunkSent.booleanValue());
                            firstTextChunkSent.setValue(true);
                            rawTextArtifactStarted.setValue(true);
                        }
                        if (chunk.memoryFragment != null) {
                            LLMClient.ChatMessage fragmentMessage = new LLMClient.ChatMessage();
                            fragmentMessage.role = "memoryFragment";
                            fragmentMessage.memoryFragment = chunk.memoryFragment;
                            contextState.messages.add(fragmentMessage);
                        }
                        if (chunk.artifacts != null) {
                            for (LLMClient.Artifact artifact : chunk.artifacts) {
                                queue.enqueueEvent((Event)toArtifactUpdateEvent(updater, artifact));
                            }
                        }
                    }

                    @Override
                    public void emitCompletionFooter(LLMClient.StreamedCompletionResponseFooter footer) throws Exception {
                        LLMClient.ChatMessage assistantMessage = new LLMClient.ChatMessage();
                        assistantMessage.role = "assistant";
                        assistantMessage.setTextOnly(receivedTextBuffer.toString());
                        contextState.messages.add(assistantMessage);
                        logger.debug((Object)("Storing context state: " + JSON.log((Object)contextState)));
                        // Persist before signalling completion: a client may send its next message as soon
                        // as it sees the task complete, and that turn must see this one.
                        contextsCache.put(cacheKey, contextState);
                        List<TextPart> doneParts = List.of(new TextPart("DONE"));
                        Message agentMessage = updater.newAgentMessage(doneParts, Collections.emptyMap());
                        updater.complete(agentMessage);
                    }

                    @Override
                    public void emitEmulateStreamingInfoChunk(String message) throws Exception {
                        // Append to the raw-text artifact, but never with append=true before it exists.
                        // The first real text chunk still restarts the artifact, as it always did.
                        enqueueRawTextChunk(updater, queue, message, rawTextArtifactStarted.booleanValue());
                        rawTextArtifactStarted.setValue(true);
                    }

                    @Override
                    public void setInterruptCallback(AutoCloseable autoCloseable) {
                    }

                    @Override
                    public boolean isInterrupted() {
                        return false;
                    }
                };
                try (LLMMeshStreamClient streamClient = new LLMMeshStreamClient(llmClient, authCtx, this.projectKey, enrichedLLMRef, connectionGuardrailsPipelineSettings, emitter)) {
                    LLMClient.SimpleCompletionResponseOrError scre = streamClient.streamComplete(scq, new LLMClient.CompletionSettings());
                    ComputeResourceUsage cru = streamClient.getTotalCRU();
                    if (cru != null) {
                        this.cruReportingService.reportComplete(cru);
                    }
                    LLMAuditHelper.emitLLMCompletionAuditFromBackendIfNeeded(this.auditTrailService, (LLMStructuredRef)enrichedLLMRef, (AbstractLLMConnection)llmClient.getConnection(), scq, scre);
                    if (!scre.ok) {
                        if (scre.errorType == LLMClient.LLMErrorType.REFUSAL) {
                            throw new LLMClient.RefusalException(scre.errorMessage);
                        }
                        throw new IllegalStateException(scre.errorMessage);
                    }
                }
            }
        }
        catch (JSONRPCError e) {
            logger.error((Object)"Agent execution failed", (Throwable)e);
            throw e;
        }
        catch (Exception e) {
            logger.error((Object)"Agent execution failed", (Throwable)e);
            throw new JSONRPCError(Integer.valueOf(-32000), "Agent execution failed: " + e.getMessage(), null);
        }
    }

    /** Enqueues {@code text} as a chunk of the raw-text artifact, appending to it when {@code append} is true. */
    private static void enqueueRawTextChunk(TaskUpdater updater, EventQueue queue, String text, boolean append) {
        List<TextPart> parts = List.of(new TextPart(text));
        Artifact a2aArtifact = new Artifact.Builder().artifactId(RAW_TEXT_ARTIFACT_ID).name("messageStream").parts(parts).build();
        TaskArtifactUpdateEvent event = new TaskArtifactUpdateEvent.Builder().taskId(updater.getTaskId()).contextId(updater.getContextId()).artifact(a2aArtifact).append(Boolean.valueOf(append)).build();
        queue.enqueueEvent((Event)event);
    }

    /** Wraps an LLM artifact into a single, final A2A artifact update event with a fresh random id. */
    private static TaskArtifactUpdateEvent toArtifactUpdateEvent(TaskUpdater updater, LLMClient.Artifact artifact) {
        String artifactName = !StringUtils.isEmpty(artifact.name) ? artifact.name : artifact.type;
        Artifact a2aArtifact = new Artifact.Builder().artifactId(UUID.randomUUID().toString()).name(artifactName).description(artifact.description).parts(toA2AParts(artifact)).build();
        return new TaskArtifactUpdateEvent.Builder().taskId(updater.getTaskId()).contextId(updater.getContextId()).artifact(a2aArtifact).lastChunk(Boolean.valueOf(true)).build();
    }

    /**
     * Translates the parts of an LLM artifact into A2A parts. TEXT becomes a text part and RECORDS becomes
     * a data part holding the JSON-serialized records under {@code "data"}. Any other (or missing) type becomes
     * a text part with the JSON dump of the item.
     */
    private static ArrayList<Object> toA2AParts(LLMClient.Artifact artifact) {
        ArrayList<Object> a2aParts = new ArrayList<Object>();
        if (artifact.parts == null) {
            return a2aParts;
        }
        for (LLMClient.SourceItem part : artifact.parts) {
            // A null type would make the String switch throw; route it to the fallback instead.
            String partType = part.type == null ? "" : part.type;
            switch (partType) {
                case "TEXT" -> a2aParts.add(new TextPart(part.text));
                case "RECORDS" -> {
                    HashMap<String, String> recordsMap = new HashMap<String, String>();
                    recordsMap.put("data", JSON.json((Object)part.records));
                    a2aParts.add(new DataPart(recordsMap));
                }
                default -> a2aParts.add(new TextPart("Untranslated artifact: " + JSON.json((Object)part)));
            }
        }
        return a2aParts;
    }

    /** Marks the task of {@code context} as cancelled. */
    public void cancel(RequestContext context, EventQueue queue) throws JSONRPCError {
        TaskUpdater updater = new TaskUpdater(context, queue);
        updater.cancel();
    }

    /** Cache key: one conversation per (A2A context id, user name). */
    public record ContextCacheKey(String contextId, String userName) {
    }

    /** Per-conversation state kept between turns. */
    private static class DSSAgentContextState {
        public JsonObject data = new JsonObject();
        public List<LLMClient.ChatMessage> messages = new ArrayList<LLMClient.ChatMessage>();
        public JsonObject context = new JsonObject();

        private DSSAgentContextState() {
        }

        /**
         * Returns a copy with its own message list, so a turn can extend it without touching this (cached)
         * instance. {@code data} and {@code context} are shared because this class never mutates them in
         * place; {@code context} is deep-copied before it is handed to the query.
         */
        private DSSAgentContextState copyForNewTurn() {
            DSSAgentContextState copy = new DSSAgentContextState();
            copy.data = this.data;
            copy.messages = new ArrayList<LLMClient.ChatMessage>(this.messages);
            copy.context = this.context;
            return copy;
        }
    }
}

