/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.utils;

import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolsCRUDService;
import com.dataiku.dip.agents.tools.AgentToolsRegistry;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.RawAgentTrajectoryExtractor;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.utils.EnrichedTrajectory;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.Pair;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import javax.annotation.Nonnull;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Turns raw agent trajectories (extracted from LLM Mesh traces) into
 * {@link EnrichedTrajectory} objects by attaching, to each tool call, the
 * {@link AgentTool} and the tool descriptor it refers to.
 *
 * <p>Enrichment is best-effort: a tool call whose attributes cannot be parsed, whose
 * tool cannot be read (missing tool, failed permission check, ...) or whose descriptor
 * cannot be computed is logged and left out of the enriched agent loop.
 */
@Service
public class AgentTrajectoryService {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.trajectory");

    /** Attribute of a tool-call element holding the project key of the called tool. */
    private static final String ATTR_TOOL_PROJECT_KEY = "toolProjectKey";
    /** Attribute of a tool-call element holding the id of the called tool. */
    private static final String ATTR_TOOL_ID = "toolId";

    private final AgentToolsCRUDService agentToolsCRUDService;
    private final TransactionService transactionService;
    private final ProjectsService projectsService;
    private final SavedModelsDAO savedModelsDAO;

    @Autowired
    public AgentTrajectoryService(AgentToolsCRUDService agentToolsCRUDService, TransactionService transactionService, ProjectsService projectsService, SavedModelsDAO savedModelsDAO) {
        this.agentToolsCRUDService = agentToolsCRUDService;
        this.transactionService = transactionService;
        this.projectsService = projectsService;
        this.savedModelsDAO = savedModelsDAO;
    }

    /**
     * Builds one {@link EnrichedTrajectory} per raw trajectory, in the same order.
     *
     * <p>Side effect: when {@code agentRef.savedModelVersionId} is {@code null}, it is
     * set to the active version of the saved model designated by
     * {@code agentRef.savedModelSmartId} (resolved relative to {@code contextProjectKey}),
     * if that saved model can be read.
     *
     * @param authCtx           caller identity; {@code READ_CONF} is checked on the project of every referenced tool
     * @param contextProjectKey project key used to resolve {@code agentRef.savedModelSmartId}
     * @param agentRef          reference to the agent that produced the trajectories
     * @param rawTrajectories   trajectories to enrich
     * @return the enriched trajectories, one per element of {@code rawTrajectories}
     */
    public List<EnrichedTrajectory> enrichTrajectories(AuthCtx authCtx, String contextProjectKey, @Nonnull LLMStructuredRef agentRef, List<RawAgentTrajectoryExtractor.RawAgentTrajectory> rawTrajectories) {
        HashSet<Pair> toolKeys = AgentTrajectoryService.collectToolKeys(rawTrajectories);

        // Tools and the agent's saved model are read under a single read transaction.
        HashMap<Pair, AgentTool> toolByKey = new HashMap<Pair, AgentTool>();
        try (Transaction t = this.transactionService.beginRead()) {
            for (Pair toolKey : toolKeys) {
                String projectKey = (String)toolKey.first;
                String toolId = (String)toolKey.second;
                try {
                    this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
                    AgentTool tool = this.agentToolsCRUDService.getOrNullUnsafe(projectKey, toolId);
                    if (tool == null) {
                        // Report the missing tool explicitly rather than through a later NullPointerException.
                        logger.warnV("Can't find tool \"%s\" in project \"%s\", it won't be present in the enriched trajectory", new Object[]{toolId, projectKey});
                        continue;
                    }
                    toolByKey.put(toolKey, tool);
                }
                catch (Exception e) {
                    logger.error((Object)("Can't read tool \"" + toolId + "\", it won't be present in the enriched trajectory"), (Throwable)e);
                }
            }
            if (agentRef.savedModelVersionId == null) {
                // NOTE(review): the saved model is read through an "unsafe" accessor with no
                // permission check visible here; only its active version id is used.
                try {
                    SavedModel sm = (SavedModel)this.savedModelsDAO.getOrNullUnsafe(AnyLoc.resolveSmart(contextProjectKey, agentRef.savedModelSmartId));
                    if (sm == null) {
                        logger.warnV("Can't find agent \"%s\", its active version won't be present in the enriched trajectory", new Object[]{agentRef.id});
                    } else {
                        agentRef.savedModelVersionId = sm.activeVersion;
                    }
                }
                catch (Exception e) {
                    logger.errorV((Throwable)e, "Can't read agent \"%s\", it won't be present in the enriched trajectory", new Object[]{agentRef.id});
                }
            }
        }

        // Descriptors are computed once the read transaction has been closed.
        HashMap<Pair, AgentToolMeta.ToolDescriptor> descriptorByKey = new HashMap<Pair, AgentToolMeta.ToolDescriptor>();
        for (Pair toolKey : toolKeys) {
            AgentTool tool = toolByKey.get(toolKey);
            if (tool == null) {
                continue;
            }
            try {
                AgentToolMeta meta = AgentToolsRegistry.getMeta(tool.type);
                AgentToolMeta.ToolDescriptor descriptor = meta.getResultingDescriptor(authCtx, (String)toolKey.first, tool);
                descriptorByKey.put(toolKey, descriptor);
            }
            catch (Exception e) {
                logger.error((Object)("Can't read tool \"" + (String)toolKey.second + "\" descriptor, it won't be present in the enriched trajectory"), (Throwable)e);
            }
        }

        List<EnrichedTrajectory> result = new ArrayList<EnrichedTrajectory>();
        for (RawAgentTrajectoryExtractor.RawAgentTrajectory rawTrajectory : rawTrajectories) {
            ArrayList<EnrichedTrajectory.Element> agentLoop = AgentTrajectoryService.buildAgentLoop(rawTrajectory, toolByKey, descriptorByKey);
            result.add(new EnrichedTrajectory(agentRef.id, agentRef.savedModelSmartId, agentRef.savedModelVersionId, rawTrajectory.begin(), rawTrajectory.end(), rawTrajectory.duration(), (JsonElement)rawTrajectory.inputs(), (JsonElement)rawTrajectory.outputs(), rawTrajectory.error(), agentLoop));
        }
        return result;
    }

    /**
     * Adds the enriched trajectory to {@code result.additionalInformation} under the
     * {@code "trajectory"} key when {@code llmRef} designates a saved-model agent and a
     * trace is available. Otherwise, or when the raw trajectory cannot be extracted from
     * the trace, {@code result} is returned untouched.
     *
     * @return the same {@code result} instance that was passed in
     */
    public LLMClient.SimpleCompletionResponseOrError enrichResponseWithTrajectoryIfNeeded(AuthCtx authCtx, String contextProjectKey, LLMClient.SimpleCompletionResponseOrError result, LLMClient.LLMMeshTraceSpan trace, LLMStructuredRef llmRef) {
        if (llmRef.type != LLMStructuredRef.LLMType.SAVED_MODEL_AGENT || trace == null) {
            return result;
        }
        RawAgentTrajectoryExtractor.RawAgentTrajectory rawTrajectory;
        try {
            rawTrajectory = RawAgentTrajectoryExtractor.toTrajectory(result, trace);
        }
        catch (Exception e) {
            logger.warnV("Error while extracting agent trajectory from trace: %s", new Object[]{e.getMessage()});
            return result;
        }
        // enrichTrajectories returns exactly one entry per raw trajectory, so get(0) is safe here.
        EnrichedTrajectory trajectory = this.enrichTrajectories(authCtx, contextProjectKey, llmRef, List.of(rawTrajectory)).get(0);
        if (result.additionalInformation == null) {
            result.additionalInformation = new JsonObject();
        }
        result.additionalInformation.add("trajectory", JSON.toJsonElement((Object)trajectory));
        return result;
    }

    /**
     * Builds the enriched trajectory of a trace alone, using a blank response flagged
     * {@code ok} as the completion result handed to the extractor.
     */
    public EnrichedTrajectory trajectoryFromTrace(AuthCtx authCtx, String contextProjectKey, LLMClient.LLMMeshTraceSpan trace, LLMStructuredRef llmRef) {
        LLMClient.SimpleCompletionResponseOrError screForTrajectory = LLMClient.SimpleCompletionResponseOrError.blank();
        screForTrajectory.ok = true;
        RawAgentTrajectoryExtractor.RawAgentTrajectory rawTrajectory = RawAgentTrajectoryExtractor.toTrajectory(screForTrajectory, trace);
        return this.enrichTrajectories(authCtx, contextProjectKey, llmRef, List.of(rawTrajectory)).get(0);
    }

    /** Collects the distinct (project key, tool id) pairs referenced by the tool calls of the given trajectories. */
    private static HashSet<Pair> collectToolKeys(List<RawAgentTrajectoryExtractor.RawAgentTrajectory> rawTrajectories) {
        HashSet<Pair> toolKeys = new HashSet<Pair>();
        for (RawAgentTrajectoryExtractor.RawAgentTrajectory rawTrajectory : rawTrajectories) {
            for (RawAgentTrajectoryExtractor.TrajectoryElement element : rawTrajectory.agentLoop()) {
                if (!(element instanceof RawAgentTrajectoryExtractor.ToolCallTrajectoryElement)) {
                    continue;
                }
                try {
                    Pair toolKey = AgentTrajectoryService.toolKeyOrNull((RawAgentTrajectoryExtractor.ToolCallTrajectoryElement)element);
                    if (toolKey != null) {
                        toolKeys.add(toolKey);
                    }
                }
                catch (Exception e) {
                    logger.errorV((Throwable)e, "Can't parse trajectory element \"%s\", it won't be present in the enriched trajectory", new Object[]{element.name});
                }
            }
        }
        return toolKeys;
    }

    /**
     * Extracts the (project key, tool id) pair of a tool call.
     *
     * @return the pair, or {@code null} when the call has no attributes or no {@code toolId} attribute
     * @throws RuntimeException when {@code toolId} is present but the attributes are malformed
     *                          (for instance {@code toolProjectKey} is missing)
     */
    private static Pair toolKeyOrNull(RawAgentTrajectoryExtractor.ToolCallTrajectoryElement toolCall) {
        if (toolCall.attributes == null || toolCall.attributes.getAsJsonPrimitive(ATTR_TOOL_ID) == null) {
            return null;
        }
        String projectKey = toolCall.attributes.get(ATTR_TOOL_PROJECT_KEY).getAsString();
        String toolId = toolCall.attributes.get(ATTR_TOOL_ID).getAsString();
        return new Pair((Object)projectKey, (Object)toolId);
    }

    /**
     * Builds the enriched agent loop of one trajectory. Only tool calls and guardrail
     * elements are kept; a tool call that references a tool is kept only when both the
     * tool and its descriptor were resolved.
     */
    private static ArrayList<EnrichedTrajectory.Element> buildAgentLoop(RawAgentTrajectoryExtractor.RawAgentTrajectory rawTrajectory, HashMap<Pair, AgentTool> toolByKey, HashMap<Pair, AgentToolMeta.ToolDescriptor> descriptorByKey) {
        ArrayList<EnrichedTrajectory.Element> agentLoop = new ArrayList<EnrichedTrajectory.Element>();
        for (RawAgentTrajectoryExtractor.TrajectoryElement element : rawTrajectory.agentLoop()) {
            if (element instanceof RawAgentTrajectoryExtractor.ToolCallTrajectoryElement) {
                RawAgentTrajectoryExtractor.ToolCallTrajectoryElement toolCall = (RawAgentTrajectoryExtractor.ToolCallTrajectoryElement)element;
                Pair toolKey;
                try {
                    toolKey = AgentTrajectoryService.toolKeyOrNull(toolCall);
                }
                catch (Exception alreadyReported) {
                    // Malformed attributes were already logged while collecting tool keys.
                    // Drop this element instead of failing the whole enrichment.
                    continue;
                }
                if (toolKey == null) {
                    // No tool reference on this call: keep it, identified by its name only.
                    agentLoop.add(new EnrichedTrajectory.ToolCall(toolCall, toolCall.name));
                    continue;
                }
                AgentTool tool = toolByKey.get(toolKey);
                AgentToolMeta.ToolDescriptor descriptor = descriptorByKey.get(toolKey);
                if (tool != null && descriptor != null) {
                    agentLoop.add(new EnrichedTrajectory.ToolCall(toolCall, tool, descriptor));
                }
            } else if (element instanceof RawAgentTrajectoryExtractor.GuardrailTrajectoryElement) {
                agentLoop.add(new EnrichedTrajectory.TriggeredGuardrail((RawAgentTrajectoryExtractor.GuardrailTrajectoryElement)element));
            }
        }
        return agentLoop;
    }
}

