/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.agents.tools;

import com.dataiku.common.server.SerializedError;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.DSSMetrics;
import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolRunner;
import com.dataiku.dip.agents.tools.AgentToolsDAO;
import com.dataiku.dip.agents.tools.AgentToolsRegistry;
import com.dataiku.dip.agents.tools.inlinepython.InlinePythonTool;
import com.dataiku.dip.agents.tools.mcp.GenericStdioMCPClientTool;
import com.dataiku.dip.agents.tools.vectorstore.VectorStoreQueryTool;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.coremodel.VersionTag;
import com.dataiku.dip.cuspol.CustomFieldsService;
import com.dataiku.dip.cuspol.CustomPolicyHooksRegistry;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.dataflow.FlowGraphService;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.exceptions.ExceptionWithLogTail;
import com.dataiku.dip.exceptions.UnauthorizedException;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.retrieval.RetrievableKnowledge;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.server.llm.AgentToolsController;
import com.dataiku.dip.server.notifications.backend.TaggableObjectChangedEvent;
import com.dataiku.dip.server.services.ITaggingService;
import com.dataiku.dip.server.services.LockablePythonLogsService;
import com.dataiku.dip.server.services.LogsService;
import com.dataiku.dip.server.services.PubSubService;
import com.dataiku.dip.server.services.TaggableObjectDiffService;
import com.dataiku.dip.server.services.TaggableObjectsService;
import com.dataiku.dip.server.services.TaggingService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.transactions.ifaces.RWTransactionRef;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.SmartLogTail;
import com.dataiku.dip.utils.StringTransmogrifier;
import com.google.common.base.Preconditions;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class AgentToolsCRUDService {
    @Autowired
    private AgentToolsDAO dao;
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private TaggableObjectsService taggableObjectsService;
    @Autowired
    private CustomPolicyHooksRegistry customPolicyHooksRegistry;
    @Autowired
    private PubSubService pubSub;
    @Autowired
    private TaggingService taggingService;
    @Autowired
    private TaggableObjectDiffService taggableObjectDiffService;
    @Autowired
    private PubSubService pubSubService;
    @Autowired
    private CustomFieldsService customFieldsService;
    @Autowired
    private FlowGraphService flowGraphService;
    @Autowired
    private AuditTrailService auditTrailService;
    @Autowired
    private PasswordEncryptionService passwordEncryptionService;
    @Autowired
    private LockablePythonLogsService pythonLogsService;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.agents.tools.service");

    public AgentTool getOrNull(String projectKey, String id) throws IOException {
        return (AgentTool)this.dao.getOrNull(projectKey, id);
    }

    public AgentTool getOrNullUnsafe(String projectKey, String id) throws IOException {
        return (AgentTool)this.dao.getOrNullUnsafe(projectKey, id);
    }

    public AgentTool getMandatory(String projectKey, String id) throws IOException {
        return (AgentTool)this.dao.getMandatory(projectKey, id);
    }

    public AgentTool getMandatoryUnsafe(String projectKey, String id) throws IOException {
        return (AgentTool)this.dao.getMandatoryUnsafe(projectKey, id);
    }

    protected String getAgentToolId(String projectKey, AgentToolsController.ProtoAgentTool proto) throws IllegalArgumentException, IOException {
        if (StringUtils.isNotBlank((String)proto.id)) {
            if (this.dao.getOrNull(projectKey, proto.id) != null) {
                throw new IllegalArgumentException("Tool with id " + proto.id + " already exists");
            }
            return proto.id;
        }
        return SecretKeyGenerator.generateSmall();
    }

    public String create(AuthCtx u, String projectKey, AgentToolsController.ProtoAgentTool proto) throws IOException, CodedException, UnauthorizedException {
        String id = this.getAgentToolId(projectKey, proto);
        StringTransmogrifier transmogrifier = new StringTransmogrifier(" ");
        for (AgentTool head : this.dao.list(projectKey)) {
            transmogrifier.addAlreadyTransmogrifiedAcceptDupes(head.name);
        }
        AgentTool at = new AgentTool();
        at.projectKey = projectKey;
        at.id = id;
        at.creationTag = new VersionTag(u.getIdentifier());
        at.versionTag = new VersionTag(u.getIdentifier());
        at.name = transmogrifier.transmogrify(proto.name);
        at.type = proto.type;
        at.quickTestQuery = proto.quickTestQuery;
        AgentToolMeta meta = AgentToolsRegistry.getMeta(at.type);
        at.params = JSON.toJsonObject((Object)meta.newParams());
        if (proto.creationParams != null) {
            for (Map.Entry entry : proto.creationParams.entrySet()) {
                at.params.add((String)entry.getKey(), (JsonElement)entry.getValue());
            }
        }
        this.checkWritePermission(u, at);
        this.customFieldsService.enrichWithDefaultCustomFieldsForTaggableObject(at);
        this.customPolicyHooksRegistry.onPreObjectSave(u, null, at);
        this.dao.save(at);
        JsonObject details = new JsonObject();
        details.addProperty("objectDisplayName", at.name);
        this.pubSub.publishAfterTransaction(new TaggableObjectChangedEvent(ITaggingService.TaggableType.AGENT_TOOL, projectKey, id, u, TaggableObjectChangedEvent.ActionType.AGENT_TOOL_CREATE).withDetails(details));
        return id;
    }

    public String createFromKB(AuthCtx u, String projectKey, AgentToolsController.ProtoAgentTool proto, String knowledgeBankRef) throws Exception {
        proto.creationParams = JF.obj().with("knowledgeBankRef", knowledgeBankRef).get();
        return this.create(u, projectKey, proto);
    }

    public AgentTool save(AuthCtx u, AgentTool at, boolean summaryOnly) throws IOException, CodedException, UnauthorizedException {
        TaggableObjectChangedEvent.ActionType action;
        Preconditions.checkNotNull((Object)at.projectKey);
        Preconditions.checkNotNull((Object)at.id);
        this.checkWritePermission(u, at);
        RWTransactionRef t = TransactionContext.retrieveWrite();
        AgentTool preExisting = (AgentTool)this.dao.getOrNullUnsafe(at.projectKey, at.id);
        this.taggableObjectsService.handleCreationVersionTagOnObjectUpdateNullAllowed(at, preExisting);
        TaggableObjectDiffService.TaggableObjectsDiff diff = new TaggableObjectDiffService.TaggableObjectsDiff();
        JsonObject details = new JsonObject();
        details.addProperty("objectDisplayName", at.name);
        if (at.name != null && preExisting != null && !at.name.equals(preExisting.name)) {
            action = TaggableObjectChangedEvent.ActionType.AGENT_TOOL_RENAME;
            details.addProperty("newName", at.name);
            details.addProperty("oldName", preExisting.name);
        } else {
            action = TaggableObjectChangedEvent.ActionType.AGENT_TOOL_EDIT;
            diff = this.taggableObjectDiffService.diff(preExisting, at, t.getUser().getIdentifier());
        }
        AgentToolMeta meta = AgentToolsRegistry.getMeta(at.type);
        if (at.params == null) {
            at.params = JSON.toJsonObject((Object)meta.newParams());
        }
        meta.encryptPasswords(at, this.passwordEncryptionService);
        this.customPolicyHooksRegistry.onPreObjectSave(t.getUser(), (TaggableObjectsService.TaggableObject)this.dao.getOrNull(at.projectKey, at.id), at);
        this.dao.save(at);
        if (diff.metadataChanged()) {
            this.taggableObjectDiffService.publishAfterTransaction(diff);
        }
        if (!summaryOnly) {
            this.pubSubService.publishAfterTransaction(new TaggableObjectChangedEvent(ITaggingService.TaggableType.AGENT_TOOL, at.projectKey, at.id, t.getUser(), action).withDetails(details));
        }
        this.taggingService.onObjectSaved(at.projectKey, at.tags);
        this.flowGraphService.invalidateCache(at.projectKey);
        return at;
    }

    public String copy(AuthCtx user, AgentTool at, String projectKey) throws IOException, CodedException, UnauthorizedException {
        this.checkWritePermission(user, at);
        String id = SecretKeyGenerator.generateSmall();
        AgentTool copy = (AgentTool)JSON.deepCopy((Object)at);
        copy.projectKey = projectKey;
        copy.id = id;
        copy.versionTag = copy.creationTag = new VersionTag(user.getIdentifier());
        StringTransmogrifier transmogrifier = new StringTransmogrifier(" ");
        for (AgentTool head : this.dao.list(projectKey)) {
            transmogrifier.addAlreadyTransmogrifiedAcceptDupes(head.name);
        }
        copy.name = transmogrifier.transmogrify("Copy of " + at.name);
        this.customPolicyHooksRegistry.onPreObjectSave(user, null, copy);
        this.dao.save(at);
        JsonObject details = new JsonObject();
        details.addProperty("objectDisplayName", copy.name);
        details.addProperty("copy", Boolean.valueOf(true));
        details.addProperty("originalObjectDisplayName", at.name);
        details.addProperty("originalObjectId", at.id);
        this.pubSub.publishAfterTransaction(new TaggableObjectChangedEvent(ITaggingService.TaggableType.AGENT_TOOL, projectKey, id, user, TaggableObjectChangedEvent.ActionType.AGENT_TOOL_CREATE).withDetails(details));
        return id;
    }

    public void delete(AuthCtx liu, String projectKey, String id) throws Exception {
        AgentTool at = (AgentTool)this.dao.getOrNull(projectKey, id);
        if (at == null) {
            logger.debug((Object)("Tried to delete an agent tool that does not exist: " + projectKey + "/" + id));
        } else {
            this.checkWritePermission(liu, at);
            this.customPolicyHooksRegistry.onPreObjectDelete(liu, at);
            JsonObject details = new JsonObject();
            details.addProperty("objectDisplayName", at.name);
            this.dao.delete(projectKey, id);
            this.pubSub.publishAfterTransaction(new TaggableObjectChangedEvent(ITaggingService.TaggableType.AGENT_TOOL, projectKey, id, liu, TaggableObjectChangedEvent.ActionType.AGENT_TOOL_DELETE).withDetails(details));
        }
    }

    public boolean stopDevKernel(AuthCtx authCtx, AgentTool tool, String projectKey) throws Exception {
        AgentToolMeta meta = AgentToolsRegistry.getMeta(tool.type);
        try (AgentToolRunner runner = meta.buildRunner(authCtx, projectKey, tool, true);){
            boolean bl = runner.stopDevKernel();
            return bl;
        }
    }

    public void checkWritePermission(AuthCtx user, AgentTool at) throws UnauthorizedException {
        if (StringUtils.equals((String)at.type, (String)GenericStdioMCPClientTool.META.getType())) {
            GeneralSettingsDAO.GeneralSettings gs = ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN();
            GeneralSettingsDAO.LocalMCPServersRestrictions restrictions = gs.generativeAISettings.agentsToolsSettings.localMCPServersRestrictions;
            if (restrictions == GeneralSettingsDAO.LocalMCPServersRestrictions.FORBIDDEN_FOR_ALL) {
                throw new UnauthorizedException("Local MCP tools are not allowed", "global-permission-denied");
            }
            if (restrictions == GeneralSettingsDAO.LocalMCPServersRestrictions.ALLOWED_FOR_ADMINS && !user.isAdmin()) {
                throw new UnauthorizedException("Local MCP tools are allowed for admins only", "global-permission-denied");
            }
        }
        if (StringUtils.equals((String)at.type, (String)InlinePythonTool.META.getType()) || StringUtils.equals((String)at.type, (String)GenericStdioMCPClientTool.META.getType())) {
            user.failIfNoSafeCode("you are not allowed to write Python code based agent tools");
        }
    }

    public ToolQuickTestResponse test(AuthCtx authCtx, AgentTool tool, String projectKey, String query, String traceName, boolean devKernel) throws Exception {
        ToolQuickTestResponse ret;
        block32: {
            ret = new ToolQuickTestResponse();
            AgentToolRunner.AgentToolInput input = (AgentToolRunner.AgentToolInput)JSON.parse((String)query, AgentToolRunner.AgentToolInput.class);
            GeneralSettingsDAO.GeneralSettings gs = ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN();
            AuditTrailService.EmittableAuditObj auditObj = this.auditTrailService.generic("agent-tool-run");
            auditObj.with("projectKey", projectKey).with("agentToolId", tool.id);
            try (LLMClient.LLMMeshTraceSpan span = LLMClient.LLMMeshTraceSpan.start(traceName);){
                AgentToolMeta meta = AgentToolsRegistry.getMeta(tool.type);
                try (AgentToolRunner runner = meta.buildRunner(authCtx, projectKey, tool, devKernel);){
                    AgentToolRunner.AgentToolOutput output;
                    SmartLogTail slt = new SmartLogTail();
                    slt.appendLine("Initializing tool runner");
                    long beforeInit = System.currentTimeMillis();
                    try (DSSMetrics.TimeCtx tctx = DSSMetrics.timeCtx((String)("dku.agents.tools.invoke.initRunner.byType." + tool.type));){
                        runner.init();
                    }
                    long afterInit = System.currentTimeMillis();
                    long initTimeMS = afterInit - beforeInit;
                    slt.appendLine("Runner initialized in " + initTimeMS + "ms");
                    logger.debug((Object)("Tool input " + JSON.pretty((Object)input)));
                    if (gs.generativeAISettings.agentsToolsSettings.auditToolsInputs) {
                        auditObj.with("toolInput", JSON.toJsonObject((Object)input.input));
                    }
                    auditObj.with("initTimeMS", (Number)initTimeMS);
                    slt.appendLine("Running tool with input " + JSON.pretty((Object)input));
                    try (DSSMetrics.TimeCtx tctx = DSSMetrics.timeCtx((String)("dku.agents.tools.invoke.run.byType." + tool.type));){
                        output = runner.run(input);
                    }
                    logger.debug((Object)("Tool output " + JSON.pretty((Object)output)));
                    long afterRun = System.currentTimeMillis();
                    long runTimeMS = afterRun - afterInit;
                    auditObj.with("outcome", "success").with("runTimeMS", (Number)runTimeMS);
                    if (gs.generativeAISettings.agentsToolsSettings.auditToolsOutputs) {
                        auditObj.with("toolOutput", JSON.toJsonObject((Object)output));
                    }
                    if (output.trace != null) {
                        span.children.add(output.trace);
                    }
                    output.trace = span;
                    ret.response = output;
                    ret.fullTrace = span;
                    ret.response.trace = null;
                    SmartLogTail kernelSlt = runner.getKernelLog();
                    if (kernelSlt != null) {
                        slt.append(kernelSlt);
                    }
                    slt.appendLine("Tool ran in " + runTimeMS + "ms");
                    ret.log = slt;
                }
            }
            catch (Exception e) {
                logger.error((Object)"Test failed", (Throwable)e);
                ret.error = new SerializedError((Throwable)e, true);
                if (!(e instanceof ExceptionWithLogTail)) break block32;
                ret.error.logTail = ((ExceptionWithLogTail)e).getLogTail();
            }
        }
        return ret;
    }

    private JsonObject populateRetrievalColumnsIfNotPresent(JsonObject params, RetrievableKnowledge rk) {
        JsonObject out = params.deepCopy();
        if (out.has("retrievalColumns")) {
            JsonArray cols = out.getAsJsonArray("retrievalColumns");
            boolean hasEmbedding = false;
            for (JsonElement el : cols) {
                if (!"DKU_TEXT_EMBEDDING_COLUMN".equals(el.getAsString())) continue;
                hasEmbedding = true;
                break;
            }
            if (!hasEmbedding) {
                cols.add("DKU_TEXT_EMBEDDING_COLUMN");
            }
            return out;
        }
        JsonArray cols = new JsonArray();
        cols.add("DKU_TEXT_EMBEDDING_COLUMN");
        if (rk.metadataColumnsSchema != null) {
            for (SchemaColumn c2 : rk.metadataColumnsSchema) {
                cols.add(c2.getName());
            }
        }
        out.add("retrievalColumns", (JsonElement)cols);
        return out;
    }

    public ToolQuickTestResponse searchKnowledgeBank(AuthCtx authCtx, RetrievableKnowledge rk, String query, JsonObject params) throws Exception {
        AgentTool tool = new AgentTool();
        tool.setProjectKey(rk.projectKey);
        tool.setId("KB_SEARCH_" + rk.id);
        tool.type = VectorStoreQueryTool.META.getType();
        tool.params = this.populateRetrievalColumnsIfNotPresent(params, rk);
        tool.params.addProperty("knowledgeBankRef", rk.id);
        AgentToolRunner.AgentToolInput toolInput = new AgentToolRunner.AgentToolInput();
        JsonObject toolQuery = new JsonObject();
        toolQuery.addProperty("searchQuery", query);
        toolInput.input = toolQuery;
        return this.test(authCtx, tool, rk.projectKey, JSON.json((Object)toolInput), "DKU_INTERNAL_TOOL_CALL", false);
    }

    public void deleteVirtualKBSearchTool(RetrievableKnowledge rk) throws IOException {
        File logsBaseDir = DKUApp.getFile((String[])new String[]{"agent-tools", rk.projectKey, "KB_SEARCH_" + rk.id});
        DKUFileUtils.deleteDirectory((File)logsBaseDir);
    }

    public List<LogsService.LogDesc> listLogs(String projectKey, String id) throws IOException {
        File logsDir = DKUApp.getFile((String[])new String[]{"agent-tools", projectKey, id, "logs"});
        return this.pythonLogsService.listLogs(logsDir);
    }

    public SmartLogTail getLog(String projectKey, String id, String logName) throws IOException {
        File logFile = DKUApp.getFile((String[])new String[]{"agent-tools", projectKey, id, "logs", logName});
        return this.pythonLogsService.getLog(logFile);
    }

    public void streamLog(String projectKey, String id, String logName, OutputStream os) throws IOException {
        File logFile = DKUApp.getFile((String[])new String[]{"agent-tools", projectKey, id, "logs", logName});
        this.pythonLogsService.streamLog(logFile, os);
    }

    public void deleteLog(String projectKey, String id, String logName) throws IOException {
        File logFile = DKUApp.getFile((String[])new String[]{"agent-tools", projectKey, id, "logs", logName});
        this.pythonLogsService.deleteLog(logFile);
    }

    public void clearLogs(String projectKey, String id) throws IOException {
        File logDir = DKUApp.getFile((String[])new String[]{"agent-tools", projectKey, id, "logs"});
        this.pythonLogsService.clearLogs(logDir);
    }

    public static class ToolQuickTestResponse {
        public SerializedError error;
        public AgentToolRunner.AgentToolOutput response;
        public LLMClient.LLMMeshTraceSpan fullTrace;
        public LLMClient.LLMMeshTraceSpan traceOfPython;
        public SmartLogTail log;
    }
}

