/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.agents.tools.mcp;

import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolParams;
import com.dataiku.dip.agents.tools.AgentToolRunner;
import com.dataiku.dip.agents.tools.utils.JsonSchema;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.DSSConnection;
import com.dataiku.dip.connections.RemoteMCPConnection;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import io.modelcontextprotocol.spec.McpSchema;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Agent tool that exposes the tools of a remote MCP server, reached through a
 * {@link RemoteMCPConnection}, as subtools of a single agent tool.
 *
 * <p>Every tool listed by the server becomes a subtool. A subtool can only be invoked when it is
 * explicitly enabled in {@link Params#subtoolsStateOverride}; subtools with no entry (or a
 * {@code null} entry) are treated as disabled.
 */
public class RemoteMCPClientTool {
    /** Tool metadata: type name, params class, connection bookkeeping and runner factory. */
    public static final AgentToolMeta META = new AgentToolMeta(false){

        @Override
        public String getType() {
            return "RemoteMCPClient";
        }

        @Override
        public Class<? extends AgentToolParams> paramsClass() {
            return Params.class;
        }

        /** This tool declares no dependencies; always returns a new, empty, mutable list. */
        @Override
        public List<SavedModel.AgentDependency> getDependencies(AgentTool tool) {
            return new ArrayList<SavedModel.AgentDependency>();
        }

        /**
         * Returns the single connection used by this tool, or an empty mutable set when the
         * params are absent or no connection name is configured.
         */
        @Override
        public Set<String> listConnectionNames(AgentTool tool) {
            Params p = tool.getParamsCopyAs(Params.class);
            // Same null handling as remapConnections: the params copy may be absent.
            if (p != null && !StringUtils.isBlank(p.connectionName)) {
                return Set.of(p.connectionName);
            }
            return new HashSet<String>();
        }

        /**
         * Rewrites the configured connection name using {@code replacements}.
         *
         * @return true if the params were updated, false if there was nothing to remap
         */
        @Override
        public boolean remapConnections(AgentTool tool, Map<String, String> replacements) {
            Params p = tool.getParamsCopyAs(Params.class);
            if (p == null || StringUtils.isBlank(p.connectionName)) {
                return false;
            }
            String newConnection = replacements.get(p.connectionName);
            if (newConnection == null) {
                return false;
            }
            p.connectionName = newConnection;
            // getParamsCopyAs returns a copy, so the change must be written back explicitly.
            tool.setParams(p);
            return true;
        }

        /**
         * Builds the descriptor by connecting to the remote server and listing its tools. A
         * temporary runner is created and always closed afterwards.
         */
        @Override
        public AgentToolMeta.ToolDescriptor getResultingDescriptor(AuthCtx authCtx, String projectKey, AgentTool tool) throws Exception {
            // NOTE(review): presumably forbidden because this performs remote I/O that should
            // not happen while a transaction is held -- confirm.
            TransactionContext.assertNoAttachedTransaction();
            try (Runner runner = new Runner(authCtx, tool)) {
                runner.init();
                return runner.getDescriptor();
            }
        }

        @Override
        public AgentToolRunner buildRunner(AuthCtx authCtx, String projectKey, AgentTool tool, boolean devKernel) {
            return new Runner(authCtx, tool);
        }
    };
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.genai.agents.tools");

    /**
     * Runner bound to one tool instance. {@link #init()} must be called before
     * {@link #run} or {@link #getDescriptor()}; {@link #close()} releases the MCP client.
     */
    public static class Runner
    implements AgentToolRunner {
        private final AuthCtx authCtx;
        private final AgentTool tool;
        /** Client to the remote server; null until {@link #init()} succeeds and after {@link #close()}. */
        private RemoteMCPConnection.RemoteMCPClient mcpClient;
        @Autowired
        private ConnectionsDAO connectionsDAO;

        public Runner(AuthCtx authCtx, AgentTool tool) {
            this.authCtx = authCtx;
            this.tool = tool;
        }

        /**
         * Autowires this runner, resolves the configured connection and opens an MCP client on it.
         *
         * @throws IllegalArgumentException if no connection is configured, or if the connection
         *     is not a {@link RemoteMCPConnection}
         */
        @Override
        public void init() throws Exception {
            // The runner is created with `new`, so Spring injection (connectionsDAO) is done by hand.
            SpringUtils.getInstance().autowire((Object)this);
            Params params = this.tool.getParamsCopyAs(Params.class);
            if (params == null || StringUtils.isBlank(params.connectionName)) {
                throw new IllegalArgumentException("No connection configured for remote MCP tool");
            }
            DSSConnection connection = this.connectionsDAO.getMandatoryConnection(this.authCtx, params.connectionName);
            if (!(connection instanceof RemoteMCPConnection)) {
                throw new IllegalArgumentException("Invalid connection type for remote MCP tool: " + connection.getType());
            }
            this.mcpClient = ((RemoteMCPConnection)connection).getMCPClient(this.authCtx);
        }

        /**
         * Invokes the subtool named by {@code input.subtoolName} on the remote server.
         *
         * <p>Text items of the result are collected. A single-item result produces
         * {@code {"text": <string>}}. Any other result size produces {@code {"text": [<strings>]}},
         * with non-text items dropped. No sources or artifacts are produced.
         *
         * @throws IllegalArgumentException if the subtool name is blank, or the subtool is not enabled
         * @throws UnsupportedOperationException if the single returned item is not text
         * @throws IllegalStateException if {@link #init()} has not been called
         */
        @Override
        public AgentToolRunner.AgentToolOutput run(AgentToolRunner.AgentToolInput input) throws Exception {
            logger.debug((Object)("Running with input " + JSON.json((Object)input)));
            if (StringUtils.isBlank((String)input.subtoolName)) {
                throw new IllegalArgumentException("Subtool name is required, please set the name of the subtool to use in the 'subtoolName' field.");
            }
            Params params = this.tool.getParamsCopyAs(Params.class);
            if (!isSubtoolEnabled(params, input.subtoolName)) {
                throw new IllegalArgumentException("Subtool \"" + input.subtoolName + "\" is invalid or disabled");
            }
            McpSchema.CallToolResult result = this.requireClient().callTool(input);
            // Treat a missing content list as empty rather than failing with an NPE.
            List<McpSchema.Content> contents = result.content() == null ? List.of() : result.content();
            ArrayList<String> outputs = new ArrayList<String>();
            for (McpSchema.Content content : contents) {
                if (content instanceof McpSchema.TextContent) {
                    outputs.add(((McpSchema.TextContent)content).text());
                } else {
                    logger.debug((Object)("Ignoring non-text MCP content of type " + content.type()));
                }
            }
            AgentToolRunner.AgentToolOutput o = new AgentToolRunner.AgentToolOutput();
            if (contents.size() == 1) {
                if (outputs.isEmpty()) {
                    throw new UnsupportedOperationException("Cannot handle this type of content: " + contents.get(0).type());
                }
                o.output = JF.obj().with("text", (String)outputs.get(0)).get();
            } else {
                o.output = JF.obj().with("text", outputs).get();
            }
            o.sources = List.of();
            o.artifacts = List.of();
            return o;
        }

        /**
         * Lists the tools exposed by the remote server and describes each one as a subtool. A
         * subtool is enabled only when explicitly enabled in the params.
         *
         * @throws IllegalStateException if {@link #init()} has not been called
         */
        public AgentToolMeta.ToolDescriptor getDescriptor() {
            Params params = this.tool.getParamsCopyAs(Params.class);
            AgentToolMeta.ToolDescriptor descriptor = new AgentToolMeta.ToolDescriptor(this.tool.name);
            descriptor.multiple = true;
            descriptor.description = "";
            if (StringUtils.isNotBlank((String)this.tool.additionalDescriptionForLLM)) {
                descriptor.description = this.tool.additionalDescriptionForLLM;
            }
            McpSchema.ListToolsResult tools = this.requireClient().listTools();
            descriptor.subtools = new ArrayList<AgentToolMeta.SubtoolDescriptor>();
            for (McpSchema.Tool subtool : tools.tools()) {
                AgentToolMeta.SubtoolDescriptor st = new AgentToolMeta.SubtoolDescriptor();
                st.name = subtool.name();
                st.description = subtool.description();
                // Round-trip through JSON to convert the MCP schema object into a DSS JsonSchema.
                st.inputSchema = (JsonSchema)JSON.parse((String)JSON.json((Object)subtool.inputSchema()), JsonSchema.class);
                st.enabled = isSubtoolEnabled(params, st.name);
                descriptor.subtools.add(st);
            }
            return descriptor;
        }

        /** Closes the MCP client if one was opened; calling this more than once is harmless. */
        @Override
        public void close() throws Exception {
            RemoteMCPConnection.RemoteMCPClient client = this.mcpClient;
            if (client != null) {
                // Clear first so that a second close() does not close the same client twice.
                this.mcpClient = null;
                client.close();
            }
        }

        /** Returns the open client, failing with a clear message if the runner is not ready. */
        private RemoteMCPConnection.RemoteMCPClient requireClient() {
            if (this.mcpClient == null) {
                throw new IllegalStateException("Remote MCP tool runner is not initialized (init() not called) or already closed");
            }
            return this.mcpClient;
        }

        /**
         * Returns whether {@code subtoolName} is explicitly enabled. Missing params, a missing
         * override map, a missing entry and a {@code null} entry all count as disabled.
         */
        private static boolean isSubtoolEnabled(Params params, String subtoolName) {
            return params != null
                && params.subtoolsStateOverride != null
                && Boolean.TRUE.equals(params.subtoolsStateOverride.get(subtoolName));
        }
    }

    /** Configuration of a remote MCP tool. */
    public static class Params
    implements AgentToolParams {
        /** Name of the DSS connection to use; it must be a {@link RemoteMCPConnection}. */
        public String connectionName;
        /**
         * Per-subtool enabled flags keyed by the server-side tool name. Subtools without a
         * {@code true} entry are treated as disabled.
         */
        public HashMap<String, Boolean> subtoolsStateOverride = new HashMap<>();
    }
}

