/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.utils;

import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolsCRUDService;
import com.dataiku.dip.agents.tools.AgentToolsRegistry;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.llm.AgentTrajectoryExtractor;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.util.JsonUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.Pair;
import com.dataiku.dip.utils.polyjson.Mapping;
import com.dataiku.dip.utils.polyjson.PolyJSON;
import com.dataiku.j2ts.annotations.UIModel;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class AgentTrajectoryService {
    private final AgentToolsCRUDService agentToolsCRUDService;
    private final TransactionService transactionService;
    private final ProjectsService projectsService;
    private final SavedModelsDAO savedModelsDAO;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.trajectory");

    @Autowired
    public AgentTrajectoryService(AgentToolsCRUDService agentToolsCRUDService, TransactionService transactionService, ProjectsService projectsService, SavedModelsDAO savedModelsDAO) {
        this.agentToolsCRUDService = agentToolsCRUDService;
        this.transactionService = transactionService;
        this.projectsService = projectsService;
        this.savedModelsDAO = savedModelsDAO;
    }

    public List<EnrichedTrajectory> enrichTrajectories(AuthCtx authCtx, String contextProjectKey, @Nonnull LLMStructuredRef agentRef, List<AgentTrajectoryExtractor.RawAgentTrajectory> rawTrajectories) {
        GuardrailsPipelineSettings llmGuardrails;
        Pair projectAndLLMId;
        AgentTool tool;
        String projectKey;
        ArrayList<EnrichedTrajectory> result = new ArrayList<EnrichedTrajectory>();
        HashSet<Pair> projectAndToolIdToFetch = new HashSet<Pair>();
        HashSet<Pair> projectAndAgentRefToFetch = new HashSet<Pair>();
        HashSet<LLMStructuredRef> llmRefToFetch = new HashSet<LLMStructuredRef>();
        for (AgentTrajectoryExtractor.RawAgentTrajectory rowTrajectory : rawTrajectories) {
            for (AgentTrajectoryExtractor.TrajectoryElement element : rowTrajectory.agentLoop()) {
                try {
                    if (element instanceof AgentTrajectoryExtractor.ToolCallTrajectoryElement) {
                        AgentTrajectoryExtractor.ToolCallTrajectoryElement toolCall = (AgentTrajectoryExtractor.ToolCallTrajectoryElement)element;
                        if (toolCall.attributes == null || toolCall.attributes.getAsJsonPrimitive("toolId") == null) continue;
                        projectKey = toolCall.attributes.get("toolProjectKey").getAsString();
                        String toolId = toolCall.attributes.get("toolId").getAsString();
                        projectAndToolIdToFetch.add(new Pair((Object)projectKey, (Object)toolId));
                        continue;
                    }
                    if (!(element instanceof AgentTrajectoryExtractor.GuardrailTrajectoryElement)) continue;
                    AgentTrajectoryExtractor.GuardrailTrajectoryElement guardrail = (AgentTrajectoryExtractor.GuardrailTrajectoryElement)element;
                    if (guardrail.llmId == null) continue;
                    String llmId = guardrail.llmId;
                    LLMStructuredRef llmRef = LLMStructuredRef.decodeId(llmId);
                    if (llmRef.isProjectBound()) {
                        String projectKey2 = AnyLoc.resolveSmart(contextProjectKey, llmRef.savedModelSmartId).getProjectKey();
                        projectAndAgentRefToFetch.add(new Pair((Object)projectKey2, (Object)llmRef));
                        continue;
                    }
                    llmRefToFetch.add(llmRef);
                }
                catch (Exception e) {
                    logger.error((Object)("Can't parse trajectory element \"" + element.name + "\", it won't be present in the enriched trajectory"), (Throwable)e);
                }
            }
        }
        HashMap<Pair, AgentTool> toolByProjectAndToolId = new HashMap<Pair, AgentTool>();
        HashMap<Pair, GuardrailsPipelineSettings> guardrailsByProjectAndAgentId = new HashMap<Pair, GuardrailsPipelineSettings>();
        HashMap<String, GuardrailsPipelineSettings> guardrailsByLlmId = new HashMap<String, GuardrailsPipelineSettings>();
        try (Transaction t = this.transactionService.beginRead();){
            for (Pair projectAndToolId : projectAndToolIdToFetch) {
                try {
                    this.projectsService.checkPerm(authCtx, (String)projectAndToolId.first, Privileges.ProjectLevelPrivilegeType.READ_CONF);
                    tool = this.agentToolsCRUDService.getOrNullUnsafe((String)projectAndToolId.first, (String)projectAndToolId.second);
                    toolByProjectAndToolId.put(projectAndToolId, tool);
                }
                catch (Exception e) {
                    logger.error((Object)("Can't read tool \"" + (String)projectAndToolId.second + "\", it won't be present in the enriched trajectory"), (Throwable)e);
                }
            }
            for (Pair projectAndAgentRef : projectAndAgentRefToFetch) {
                projectKey = (String)projectAndAgentRef.first;
                try {
                    this.projectsService.checkPerm(authCtx, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
                    GuardrailsPipelineSettings guardrails = GuardrailsPipelineUtils.getConnectionAndLLMLevelSettings(authCtx, (String)projectAndAgentRef.first, (LLMStructuredRef)projectAndAgentRef.second);
                    projectAndLLMId = new Pair((Object)((String)projectAndAgentRef.first), (Object)((LLMStructuredRef)projectAndAgentRef.second).id);
                    guardrailsByProjectAndAgentId.put(projectAndLLMId, guardrails);
                }
                catch (Exception e) {
                    logger.error((Object)("Can't read agent \"" + ((LLMStructuredRef)projectAndAgentRef.second).id + "\", it won't be present in the enriched trajectory"), (Throwable)e);
                }
            }
            for (LLMStructuredRef llmRef : llmRefToFetch) {
                try {
                    GuardrailsPipelineSettings guardrails = GuardrailsPipelineUtils.getConnectionAndLLMLevelSettings(authCtx, null, llmRef);
                    guardrailsByLlmId.put(llmRef.id, guardrails);
                }
                catch (Exception e) {
                    logger.error((Object)("Can't read llm \"" + llmRef.id + "\", it won't be present in the enriched trajectory"), (Throwable)e);
                }
            }
            if (agentRef.savedModelVersionId == null) {
                try {
                    SavedModel sm = (SavedModel)this.savedModelsDAO.getOrNullUnsafe(AnyLoc.resolveSmart(contextProjectKey, agentRef.savedModelSmartId));
                    agentRef.savedModelVersionId = sm.activeVersion;
                }
                catch (Exception e) {
                    logger.error((Object)("Can't read agent \"" + agentRef.id + "\", it won't be present in the enriched trajectory"), (Throwable)e);
                }
            }
        }
        HashMap<Pair, AgentToolMeta.ToolDescriptor> descriptorByProjectAndToolId = new HashMap<Pair, AgentToolMeta.ToolDescriptor>();
        for (Pair projectAndToolId : projectAndToolIdToFetch) {
            if (!toolByProjectAndToolId.containsKey(projectAndToolId)) continue;
            try {
                tool = (AgentTool)toolByProjectAndToolId.get(projectAndToolId);
                AgentToolMeta meta = AgentToolsRegistry.getMeta(tool.type);
                AgentToolMeta.ToolDescriptor descriptor = meta.getResultingDescriptor(authCtx, (String)projectAndToolId.first, tool);
                descriptorByProjectAndToolId.put(projectAndToolId, descriptor);
            }
            catch (Exception e) {
                logger.error((Object)("Can't read tool \"" + (String)projectAndToolId.second + "\" descriptor, it won't be present in the enriched trajectory"), (Throwable)e);
            }
        }
        HashMap<Pair, QueryAndResponseGuardrailDefinitions> guardrailDefinitionsByProjectAndAgentId = new HashMap<Pair, QueryAndResponseGuardrailDefinitions>();
        HashMap<String, QueryAndResponseGuardrailDefinitions> guardrailDefinitionsByLlmId = new HashMap<String, QueryAndResponseGuardrailDefinitions>();
        for (Map.Entry guardrailByProjectAndAgentId : guardrailsByProjectAndAgentId.entrySet()) {
            projectAndLLMId = (Pair)guardrailByProjectAndAgentId.getKey();
            llmGuardrails = (GuardrailsPipelineSettings)guardrailByProjectAndAgentId.getValue();
            guardrailDefinitionsByProjectAndAgentId.put(projectAndLLMId, QueryAndResponseGuardrailDefinitions.from(llmGuardrails));
        }
        for (Map.Entry guardrailByLlmId : guardrailsByLlmId.entrySet()) {
            String agentId = (String)guardrailByLlmId.getKey();
            llmGuardrails = (GuardrailsPipelineSettings)guardrailByLlmId.getValue();
            guardrailDefinitionsByLlmId.put(agentId, QueryAndResponseGuardrailDefinitions.from(llmGuardrails));
        }
        for (AgentTrajectoryExtractor.RawAgentTrajectory rawTrajectory : rawTrajectories) {
            EnrichedTrajectory enrichedTrajectory = new EnrichedTrajectory(agentRef.id, agentRef.savedModelSmartId, agentRef.savedModelVersionId, rawTrajectory.begin(), rawTrajectory.end(), rawTrajectory.duration(), (JsonElement)rawTrajectory.inputs(), (JsonElement)rawTrajectory.outputs(), rawTrajectory.error(), new ArrayList<EnrichedTrajectory.Element>());
            for (AgentTrajectoryExtractor.TrajectoryElement element : rawTrajectory.agentLoop()) {
                QueryAndResponseGuardrailDefinitions guardrailDefinitions;
                if (element instanceof AgentTrajectoryExtractor.ToolCallTrajectoryElement) {
                    AgentToolMeta.ToolDescriptor descriptor;
                    String toolId;
                    AgentTrajectoryExtractor.ToolCallTrajectoryElement toolCall = (AgentTrajectoryExtractor.ToolCallTrajectoryElement)element;
                    if (toolCall.attributes == null || toolCall.attributes.getAsJsonPrimitive("toolId") == null) {
                        enrichedTrajectory.agentLoop.add(new EnrichedTrajectory.ToolCall(toolCall, toolCall.name));
                        continue;
                    }
                    String projectKey3 = toolCall.attributes.get("toolProjectKey").getAsString();
                    Pair projectAndToolId = new Pair((Object)projectKey3, (Object)(toolId = toolCall.attributes.get("toolId").getAsString()));
                    AgentTool tool2 = toolByProjectAndToolId.getOrDefault(projectAndToolId, null);
                    if (tool2 == null || (descriptor = (AgentToolMeta.ToolDescriptor)descriptorByProjectAndToolId.getOrDefault(projectAndToolId, null)) == null) continue;
                    enrichedTrajectory.agentLoop.add(new EnrichedTrajectory.ToolCall(toolCall, tool2, descriptor));
                    continue;
                }
                if (!(element instanceof AgentTrajectoryExtractor.GuardrailTrajectoryElement)) continue;
                AgentTrajectoryExtractor.GuardrailTrajectoryElement guardrail = (AgentTrajectoryExtractor.GuardrailTrajectoryElement)element;
                String llmId = guardrail.llmId;
                LLMStructuredRef llmRef = LLMStructuredRef.decodeId(llmId);
                if (llmRef.isProjectBound()) {
                    String projectKey4 = AnyLoc.resolveSmart(contextProjectKey, llmRef.savedModelSmartId).getProjectKey();
                    Pair projectAndToolId = new Pair((Object)projectKey4, (Object)llmId);
                    guardrailDefinitions = guardrailDefinitionsByProjectAndAgentId.getOrDefault(projectAndToolId, null);
                } else {
                    guardrailDefinitions = guardrailDefinitionsByLlmId.getOrDefault(llmId, null);
                }
                if (guardrailDefinitions != null) {
                    enrichedTrajectory.agentLoop.add(new EnrichedTrajectory.TriggeredGuardrail(guardrail, guardrailDefinitions));
                    continue;
                }
                enrichedTrajectory.agentLoop.add(new EnrichedTrajectory.TriggeredGuardrail(guardrail));
            }
            result.add(enrichedTrajectory);
        }
        return result;
    }

    public LLMClient.SimpleCompletionResponseOrError enrichResponseWithTrajectoryIfNeeded(AuthCtx authCtx, String contextProjectKey, LLMClient.SimpleCompletionResponseOrError result, LLMClient.LLMMeshTraceSpan trace, LLMStructuredRef llmRef) {
        AgentTrajectoryExtractor.RawAgentTrajectory rawTrajectory;
        if (!llmRef.id.startsWith("agent:") || trace == null) {
            return result;
        }
        try {
            rawTrajectory = AgentTrajectoryExtractor.toTrajectory(result, trace);
        }
        catch (Exception e) {
            logger.warnV("Error while extracting agent trajectory from trace: %s", new Object[]{e.getMessage()});
            return result;
        }
        EnrichedTrajectory trajectory = this.enrichTrajectories(authCtx, contextProjectKey, llmRef, List.of(rawTrajectory)).get(0);
        if (result.additionalInformation == null) {
            result.additionalInformation = new JsonObject();
        }
        result.additionalInformation.add("trajectory", JSON.toJsonElement((Object)trajectory));
        return result;
    }

    private record QueryAndResponseGuardrailDefinitions(List<GuardrailsPipelineSettings.GuardrailsPipelineElement> llmEnabledQueryGuardrails, List<GuardrailsPipelineSettings.GuardrailsPipelineElement> llmEnabledResponseGuardrails) {
        public static QueryAndResponseGuardrailDefinitions from(GuardrailsPipelineSettings llmGuardrails) {
            ArrayList<GuardrailsPipelineSettings.GuardrailsPipelineElement> llmEnabledQueryGuardrails = new ArrayList<GuardrailsPipelineSettings.GuardrailsPipelineElement>();
            ArrayList<GuardrailsPipelineSettings.GuardrailsPipelineElement> llmEnabledResponseGuardrails = new ArrayList<GuardrailsPipelineSettings.GuardrailsPipelineElement>();
            for (GuardrailsPipelineSettings.GuardrailsPipelineElement guardrailElement : llmGuardrails.guardrails) {
                if (!guardrailElement.enabled) continue;
                if (JsonUtils.getAsBoolean(guardrailElement.params, false, "filterQueries")) {
                    llmEnabledQueryGuardrails.add(guardrailElement);
                }
                if (!JsonUtils.getAsBoolean(guardrailElement.params, false, "filterResponses")) continue;
                llmEnabledResponseGuardrails.add(guardrailElement);
            }
            return new QueryAndResponseGuardrailDefinitions(llmEnabledQueryGuardrails, llmEnabledResponseGuardrails);
        }
    }

    @UIModel
    public static class EnrichedTrajectory {
        @Nonnull
        public String agentId;
        @Nonnull
        public String agentSavedModelSmartId;
        @Nonnull
        public String agentSavedModelVersion;
        @Nonnull
        public String begin;
        @Nonnull
        public String end;
        @Nonnull
        public Long durationMs;
        @Nonnull
        public JsonElement input;
        @Nullable
        public JsonElement output;
        @Nullable
        public AgentTrajectoryExtractor.TrajectoryError error;
        @Nonnull
        public List<Element> agentLoop;

        public EnrichedTrajectory() {
            this.agentLoop = new ArrayList<Element>();
        }

        public EnrichedTrajectory(@Nonnull String agentId, @Nonnull String agentSavedModelSmartId, @Nonnull String agentSavedModelVersion, @Nonnull String begin, @Nonnull String end, @Nonnull Long durationMs, @Nonnull JsonElement input, @Nullable JsonElement output, @Nullable AgentTrajectoryExtractor.TrajectoryError error, @Nonnull List<Element> agentLoop) {
            this.agentId = agentId;
            this.agentSavedModelSmartId = agentSavedModelSmartId;
            this.agentSavedModelVersion = agentSavedModelVersion;
            this.begin = begin;
            this.end = end;
            this.durationMs = durationMs;
            this.input = input;
            this.output = output;
            this.error = error;
            this.agentLoop = agentLoop;
        }

        public static class TriggeredGuardrail
        extends Element {
            @Nullable
            public String llmId;
            @Nonnull
            public AgentTrajectoryExtractor.GuardrailEnforcementType guardrailEnforcementType;
            @Nonnull
            public AgentTrajectoryExtractor.TrajectoryLevel trajectoryLevel;
            @Nullable
            public GuardrailsPipelineSettings.GuardrailsPipelineElement guardrailDefinition;
            @Nonnull
            public JsonObject inputs;
            @Nonnull
            public JsonObject outputs;

            protected TriggeredGuardrail() {
            }

            TriggeredGuardrail(AgentTrajectoryExtractor.GuardrailTrajectoryElement guardrail, QueryAndResponseGuardrailDefinitions guardrailDefinitions) {
                this.fillFromGuardrailTrajectoryElement(guardrail);
                String[] guardrailRef = guardrail.name.split("_");
                try {
                    int guardrailIndex = Integer.parseInt(guardrailRef[2]);
                    this.guardrailDefinition = guardrail.guardrailEnforcementType == AgentTrajectoryExtractor.GuardrailEnforcementType.QUERY ? guardrailDefinitions.llmEnabledQueryGuardrails.get(guardrailIndex) : guardrailDefinitions.llmEnabledResponseGuardrails.get(guardrailIndex);
                }
                catch (IndexOutOfBoundsException | NumberFormatException e) {
                    logger.warnV((Throwable)e, "Error while getting definition of type %s guardrail named %s.", new Object[]{guardrail.guardrailEnforcementType.name(), guardrail.name});
                    this.guardrailDefinition = null;
                }
            }

            public TriggeredGuardrail(AgentTrajectoryExtractor.GuardrailTrajectoryElement guardrail) {
                this.fillFromGuardrailTrajectoryElement(guardrail);
            }

            private void fillFromGuardrailTrajectoryElement(AgentTrajectoryExtractor.GuardrailTrajectoryElement guardrail) {
                this.begin = guardrail.begin;
                this.end = guardrail.end;
                this.durationMs = guardrail.duration;
                this.llmId = guardrail.llmId;
                this.guardrailEnforcementType = guardrail.guardrailEnforcementType;
                this.trajectoryLevel = guardrail.guardrailLevel;
                this.inputs = guardrail.inputs;
                this.outputs = guardrail.outputs;
            }
        }

        public static class ToolCall
        extends Element {
            @Nullable
            public String toolId;
            @Nonnull
            public String toolName;
            @Nonnull
            public String toolType;
            @Nullable
            public String subToolName;
            @Nullable
            public String toolDescriptorName;
            @Nullable
            public String toolDescriptorDescription;
            @Nullable
            public JsonElement inputContent;
            @Nullable
            public JsonElement outputContent;
            @Nullable
            public String error;

            protected ToolCall() {
            }

            ToolCall(AgentTrajectoryExtractor.ToolCallTrajectoryElement toolCall, AgentTool tool, AgentToolMeta.ToolDescriptor descriptor) {
                this.fillFromToolCallTrajectoryElement(toolCall);
                this.toolId = tool.id;
                this.toolName = tool.name;
                this.toolType = tool.type;
                this.toolDescriptorName = descriptor.name;
                this.toolDescriptorDescription = descriptor.description;
                this.subToolName = JsonUtils.getOrNullStr(toolCall.attributes, "subtoolName");
            }

            ToolCall(AgentTrajectoryExtractor.ToolCallTrajectoryElement toolCall, String toolCallName) {
                this.fillFromToolCallTrajectoryElement(toolCall);
                this.toolName = toolCallName;
            }

            private void fillFromToolCallTrajectoryElement(AgentTrajectoryExtractor.ToolCallTrajectoryElement toolCall) {
                JsonElement input;
                this.inputContent = toolCall.inputs;
                if (toolCall.inputs != null && toolCall.inputs.isJsonObject() && (input = toolCall.inputs.get("input")) != null) {
                    try {
                        this.inputContent = JsonParser.parseString((String)input.getAsString());
                    }
                    catch (Exception e) {
                        this.inputContent = input;
                    }
                }
                this.outputContent = toolCall.outputs;
                if (toolCall.outputs != null) {
                    JsonElement error = toolCall.outputs.get("error");
                    if (error != null && error.isJsonPrimitive()) {
                        this.error = error.getAsString();
                        this.outputContent = null;
                    } else {
                        JsonElement output = toolCall.outputs.get("output");
                        if (output != null && output.isJsonObject()) {
                            JsonElement content = output.getAsJsonObject().get("content");
                            if (content != null) {
                                try {
                                    this.outputContent = JsonParser.parseString((String)content.getAsString());
                                }
                                catch (Exception e) {
                                    this.outputContent = content;
                                }
                            } else {
                                this.outputContent = output;
                            }
                        } else if (output != null) {
                            this.outputContent = output;
                        }
                    }
                }
                this.begin = toolCall.begin;
                this.end = toolCall.end;
                this.durationMs = toolCall.duration;
            }
        }

        @PolyJSON(value={@Mapping(value=ToolCall.class, type="TOOLCALL"), @Mapping(value=TriggeredGuardrail.class, type="GUARDRAIL")})
        public static abstract class Element {
            String begin;
            String end;
            Long durationMs;
        }
    }
}

