/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.agents.tools.mcp;

import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolParams;
import com.dataiku.dip.agents.tools.AgentToolRunner;
import com.dataiku.dip.agents.tools.utils.JsonSchema;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.DSSConnection;
import com.dataiku.dip.connections.RemoteMCPConnection;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import io.modelcontextprotocol.spec.McpSchema;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang.StringEscapeUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

public class RemoteMCPClientTool {
    public static final AgentToolMeta META = new AgentToolMeta(false){

        @Override
        public String getType() {
            return "RemoteMCPClient";
        }

        @Override
        public Class<? extends AgentToolParams> paramsClass() {
            return Params.class;
        }

        @Override
        public List<SavedModel.AgentDependency> getDependencies(AgentTool tool) {
            return new ArrayList<SavedModel.AgentDependency>();
        }

        @Override
        public Set<String> listConnectionNames(AgentTool tool) {
            Params p = tool.getParamsCopyAs(Params.class);
            if (!StringUtils.isBlank((String)p.connectionName)) {
                return Set.of(p.connectionName);
            }
            return new HashSet<String>();
        }

        @Override
        public boolean remapConnections(AgentTool tool, Map<String, String> replacements) {
            Params p = tool.getParamsCopyAs(Params.class);
            if (p == null || StringUtils.isBlank((String)p.connectionName)) {
                return false;
            }
            String newConnection = replacements.get(p.connectionName);
            if (newConnection == null) {
                return false;
            }
            p.connectionName = newConnection;
            tool.setParams(p);
            return true;
        }

        @Override
        public AgentToolMeta.ToolDescriptor getResultingDescriptor(AuthCtx authCtx, String projectKey, AgentTool tool) throws Exception {
            TransactionContext.assertNoAttachedTransaction();
            try (Runner runner = new Runner(authCtx, tool);){
                runner.init();
                AgentToolMeta.ToolDescriptor toolDescriptor = runner.getDescriptor();
                return toolDescriptor;
            }
        }

        @Override
        public AgentToolMeta.ToolCallDescription getToolCallDescription_NT(AuthCtx authCtx, String projectKey, AgentTool tool, LLMClient.FunctionTool descriptor, AgentToolRunner.AgentToolInput input) {
            TransactionContext.assertNoAttachedTransaction();
            Params params = tool.getParamsCopyAs(Params.class);
            Object description = String.format("I'm about to call subtool <b>%s</b> of remote MCP server defined in connection <b>%s</b> with the following input.%n", input.subtoolName, params.connectionName);
            description = (String)description + "\n";
            description = (String)description + "Do you want to proceed?";
            return new AgentToolMeta.ToolCallDescription((String)description);
        }

        @Override
        public AgentToolRunner buildRunner(AuthCtx authCtx, String projectKey, AgentTool tool, boolean devKernel) {
            return new Runner(authCtx, tool);
        }
    };
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.genai.agents.tools");

    public static class Runner
    implements AgentToolRunner {
        private final AuthCtx authCtx;
        private final AgentTool tool;
        private RemoteMCPConnection.RemoteMCPClient mcpClient;
        @Autowired
        private ConnectionsDAO connectionsDAO;

        public Runner(AuthCtx authCtx, AgentTool tool) {
            this.authCtx = authCtx;
            this.tool = tool;
        }

        @Override
        public void init() throws Exception {
            SpringUtils.getInstance().autowire((Object)this);
            Params params = this.tool.getParamsCopyAs(Params.class);
            DSSConnection c2 = this.connectionsDAO.getMandatoryConnection(this.authCtx, params.connectionName);
            if (!(c2 instanceof RemoteMCPConnection)) {
                throw new IllegalArgumentException("Invalid connection type for remote MCP tool: " + c2.getType());
            }
            RemoteMCPConnection connection = (RemoteMCPConnection)c2;
            this.mcpClient = connection.getMCPClient(this.authCtx);
        }

        @Override
        public AgentToolRunner.AgentToolOutput run(AgentToolRunner.AgentToolInput input) throws Exception {
            logger.debug((Object)("Running with input " + JSON.log((Object)input)));
            if (StringUtils.isBlank((String)input.subtoolName)) {
                throw new IllegalArgumentException("Subtool name is required, please set the name of the subtool to use in the 'subtoolName' field.");
            }
            Params params = this.tool.getParamsCopyAs(Params.class);
            if (!params.subtoolsStateOverride.getOrDefault(input.subtoolName, false).booleanValue()) {
                throw new IllegalArgumentException("Subtool \"" + input.subtoolName + "\" is invalid or disabled");
            }
            McpSchema.CallToolResult result = this.mcpClient.callTool(input);
            ArrayList<LLMClient.ToolOutputPart> parts = new ArrayList<LLMClient.ToolOutputPart>();
            ArrayList<LLMClient.SourceItem> artifactParts = new ArrayList<LLMClient.SourceItem>();
            Boolean structured_content_found = false;
            if (result.structuredContent() != null) {
                structured_content_found = true;
                LLMClient.ToolOutputPart part = new LLMClient.ToolOutputPart();
                part.withText(JSON.json((Object)result.structuredContent()));
                parts.add(part);
            }
            for (McpSchema.Content content : result.content()) {
                LLMClient.ToolOutputPart part;
                if (content instanceof McpSchema.TextContent) {
                    McpSchema.TextContent textContent = (McpSchema.TextContent)content;
                    if (!structured_content_found.booleanValue()) {
                        part = new LLMClient.ToolOutputPart();
                        part.withText(textContent.text());
                        parts.add(part);
                        continue;
                    }
                }
                if (content instanceof McpSchema.ImageContent) {
                    McpSchema.ImageContent imageContent = (McpSchema.ImageContent)content;
                    if (params.imageHandlingMode == AgentToolRunner.AgentToolOutputImageHandlingMode.ONLY_ADD_AS_ARTIFACT || params.imageHandlingMode == AgentToolRunner.AgentToolOutputImageHandlingMode.ADD_AS_ARTIFACT_AND_SEND_TO_LLM) {
                        LLMClient.SourceItem artifactPart = new LLMClient.SourceItem();
                        artifactPart.type = "DATA_INLINE";
                        artifactPart.dataBase64 = imageContent.data();
                        artifactPart.mimeType = imageContent.mimeType();
                        artifactParts.add(artifactPart);
                    }
                    if (params.imageHandlingMode == AgentToolRunner.AgentToolOutputImageHandlingMode.ONLY_SEND_TO_LLM || params.imageHandlingMode == AgentToolRunner.AgentToolOutputImageHandlingMode.ADD_AS_ARTIFACT_AND_SEND_TO_LLM) {
                        part = new LLMClient.ToolOutputPart();
                        part.withInlineImage(imageContent.data(), imageContent.mimeType());
                        parts.add(part);
                    }
                    if (params.imageHandlingMode != AgentToolRunner.AgentToolOutputImageHandlingMode.IGNORE) continue;
                    logger.debug((Object)"Ignoring image part");
                    continue;
                }
                logger.warn((Object)("Ignoring unsupported response content " + content.type()));
            }
            AgentToolRunner.AgentToolOutput o = new AgentToolRunner.AgentToolOutput();
            if (parts.isEmpty()) {
                throw new UnsupportedOperationException("No supported content returned by tool call, this might be solved by using another image handling mode.");
            }
            if (parts.size() == 1 && ((LLMClient.ToolOutputPart)parts.get((int)0)).type == LLMClient.ChatMessagePartType.TEXT) {
                o.output = JF.obj().with("text", parts.get((int)0).text).get();
            } else {
                o.parts = parts;
            }
            if (!artifactParts.isEmpty()) {
                LLMClient.Artifact artifact = new LLMClient.Artifact();
                artifact.type = "DATA_INLINE";
                artifact.name = "Images from remote MCP tool " + this.tool.getDisplayName();
                artifact.parts = artifactParts;
                o.artifacts = List.of(artifact);
            }
            return o;
        }

        public AgentToolMeta.ToolDescriptor getDescriptor() {
            Params params = this.tool.getParamsCopyAs(Params.class);
            AgentToolMeta.ToolDescriptor descriptor = new AgentToolMeta.ToolDescriptor(this.tool.name);
            descriptor.multiple = true;
            descriptor.description = "";
            if (StringUtils.isNotBlank((String)this.tool.additionalDescriptionForLLM)) {
                descriptor.description = this.tool.additionalDescriptionForLLM;
            }
            McpSchema.ListToolsResult tools = this.mcpClient.listTools();
            descriptor.subtools = new ArrayList<AgentToolMeta.SubtoolDescriptor>();
            for (McpSchema.Tool subtool : tools.tools()) {
                AgentToolMeta.SubtoolDescriptor st2 = new AgentToolMeta.SubtoolDescriptor();
                st2.name = subtool.name();
                st2.description = subtool.description();
                st2.inputSchema = (JsonSchema)JSON.parse((String)JSON.json((Object)subtool.inputSchema()), JsonSchema.class);
                st2.enabled = params.subtoolsStateOverride.getOrDefault(st2.name, false);
                descriptor.subtools.add(st2);
            }
            return descriptor;
        }

        @Override
        public void close() throws Exception {
            if (this.mcpClient != null) {
                this.mcpClient.close();
            }
        }
    }

    public static class Params
    implements AgentToolParams {
        public String connectionName;
        public HashMap<String, Boolean> subtoolsStateOverride = new HashMap();
        public AgentToolRunner.AgentToolOutputImageHandlingMode imageHandlingMode = AgentToolRunner.AgentToolOutputImageHandlingMode.ONLY_ADD_AS_ARTIFACT;
    }
}

