/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.agents.tools.ml;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolParams;
import com.dataiku.dip.agents.tools.AgentToolRunner;
import com.dataiku.dip.agents.tools.utils.JsonSchema;
import com.dataiku.dip.agents.tools.utils.JsonSchemaElement;
import com.dataiku.dip.apideployer.deployments.APIServiceDeploymentsService;
import com.dataiku.dip.apideployer.deployments.QueryableEndpoint;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.nodeclients.APIDeployerClientProxyUser;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.openapi.OpenAPIHelper;
import com.dataiku.dip.server.openapi.SwaggerSchemaParser;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dss.shadelib.io.swagger.models.Swagger;
import com.dataiku.lambda.client.BaseLambdaAPIClient;
import com.dataiku.lambda.model.studioconfig.ApiEndpointQuery;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

public class ApiEndpointTool {
    public static final AgentToolMeta META = new AgentToolMeta(false){

        @Override
        public String getType() {
            return "ApiEndpoint";
        }

        @Override
        public Class<? extends AgentToolParams> paramsClass() {
            return Params.class;
        }

        @Override
        public List<SavedModel.AgentDependency> getDependencies(AgentTool tool) {
            return new ArrayList<SavedModel.AgentDependency>();
        }

        @Override
        public AgentToolMeta.ToolDescriptor getResultingDescriptor(AuthCtx authCtx, String projectKey, AgentTool tool) throws IOException, DKUSecurityException {
            Swagger openAPISwagger;
            SwaggerSchemaParser parser;
            JsonSchemaElement inputSchema;
            TransactionContext.assertNoAttachedTransaction();
            Params p = tool.getParamsCopyAs(Params.class);
            GeneralSettingsDAO.GeneralSettings gs = ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN();
            gs.deployerClientSettings.check();
            GeneralSettingsDAO.DeployerRef deployerRef = gs.deployerClientSettings.getRef();
            AgentToolMeta.ToolDescriptor td = new AgentToolMeta.ToolDescriptor(tool.name);
            logger.info((Object)"Get resulting descriptor of API endpoint %s of deployment %s".formatted(p.endpointId, p.deploymentId));
            td.description = "Calls the API endpoint";
            if (p.deploymentId == null || p.endpointId == null) {
                td.description = td.description + "(error: no endpoint selected).";
            } else {
                td.description = td.description + ": " + p.deploymentId + " > " + p.endpointId + " \n\n ";
                td.description = td.description + "Provide the data to send to the API Endpoint as a single JSON dictionary called \"data\", with the payload necessary for the input of the endpoint.\n ";
            }
            JsonSchemaElement data = JsonSchemaElement.object("The body of the query made to the endpoint.");
            QueryableEndpoint endpoint = this.getEndpointQueryableThroughDeployer(authCtx, deployerRef, p.deploymentId, p.endpointId);
            if (!endpoint.hasOpenAPIDocumentation || endpoint.deploymentOpenApiDocContent == null) {
                throw new IllegalArgumentException("Your published package must have an OpenAPI documentation.");
            }
            if (StringUtils.isNotBlank((String)endpoint.description)) {
                td.description = td.description + endpoint.description + " \n ";
            }
            if ((inputSchema = (parser = new SwaggerSchemaParser(openAPISwagger = OpenAPIHelper.getSwagger(endpoint.deploymentOpenApiDocContent))).parseInputSchema(p.endpointId, endpoint.type)) != null && !inputSchema.properties.isEmpty()) {
                data.properties.putAll(inputSchema.properties);
            } else {
                logger.debugV("No input schema found for endpoint %s", new Object[]{p.endpointId});
            }
            if (StringUtils.isNotBlank((String)tool.additionalDescriptionForLLM)) {
                td.description = td.description + "\n\n" + tool.additionalDescriptionForLLM;
            }
            td.inputSchema = JsonSchema.newObject("https://dataiku.com/agents/tools/ml/api-endpoint/query", "Input query for the API endpoint");
            td.inputSchema.properties.put("data", data);
            return td;
        }

        @Override
        public AgentToolRunner buildRunner(AuthCtx authCtx, String projectKey, AgentTool tool, boolean devKernel) {
            return new Runner(authCtx, projectKey, tool.getParamsCopyAs(Params.class));
        }

        private QueryableEndpoint getEndpointQueryableThroughDeployer(AuthCtx authCtx, GeneralSettingsDAO.DeployerRef deployerRef, String deploymentId, String endpointId) throws IOException, DKUSecurityException {
            APIServiceDeploymentsService deploymentService = (APIServiceDeploymentsService)SpringUtils.getBean(APIServiceDeploymentsService.class);
            if (deployerRef.mode == GeneralSettingsDAO.DeployerMode.LOCAL) {
                return deploymentService.getEndpointQueryableThroughDeployer_Check(authCtx, deploymentId, endpointId);
            }
            try (APIDeployerClientProxyUser client = new APIDeployerClientProxyUser(deployerRef, authCtx);){
                QueryableEndpoint queryableEndpoint = client.getEndpointQueryableThroughDeployer(deploymentId, endpointId);
                return queryableEndpoint;
            }
        }
    };
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.agents.tools.ml.apiendpoint");

    public static class Runner
    implements AgentToolRunner {
        private final Params params;
        private final AuthCtx authCtx;
        private final String projectKey;
        private final GeneralSettingsDAO.DeployerRef deployerRef;
        @Autowired
        private APIServiceDeploymentsService apiServiceDeploymentsService;

        public Runner(AuthCtx authCtx, String projectKey, Params p) {
            this.authCtx = authCtx;
            this.projectKey = projectKey;
            this.params = p;
            GeneralSettingsDAO.GeneralSettings gs = ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN();
            gs.deployerClientSettings.check();
            this.deployerRef = gs.deployerClientSettings.getRef();
        }

        @Override
        public void init() throws Exception {
            SpringUtils.getInstance().autowire((Object)this);
            if (StringUtils.isBlank((String)this.params.deploymentId) || StringUtils.isBlank((String)this.params.endpointId)) {
                throw new IllegalArgumentException("No api endpoint specified in tool.");
            }
            if (this.deployerRef == null || this.deployerRef.mode == null || GeneralSettingsDAO.DeployerMode.DISABLED.equals((Object)this.deployerRef.mode)) {
                throw new IllegalArgumentException("Deployer is disabled.");
            }
        }

        @Override
        public AgentToolRunner.AgentToolOutput run(AgentToolRunner.AgentToolInput input) throws Exception {
            BaseLambdaAPIClient.ResponseOrError responseOrError;
            block14: {
                JsonObject data = this.safeReadObjectArgument(input, "data");
                ApiEndpointQuery query = new ApiEndpointQuery("query", data);
                try {
                    if (this.deployerRef.mode == GeneralSettingsDAO.DeployerMode.LOCAL) {
                        responseOrError = (BaseLambdaAPIClient.ResponseOrError)this.apiServiceDeploymentsService.runQueries_NT_Check((String)this.params.deploymentId, (String)this.params.endpointId, (AuthCtx)this.authCtx, List.of(query), null, (boolean)false, (boolean)true).responses.stream().findFirst().orElseThrow();
                        break block14;
                    }
                    try (APIDeployerClientProxyUser client = new APIDeployerClientProxyUser(this.deployerRef, this.authCtx);){
                        responseOrError = (BaseLambdaAPIClient.ResponseOrError)client.runQueries((String)this.params.deploymentId, (String)this.params.endpointId, List.of(query)).responses.stream().findFirst().orElseThrow();
                    }
                }
                catch (Exception e) {
                    throw new IllegalArgumentException("An error happened during the execution of queries", e);
                }
            }
            AgentToolRunner.AgentToolOutput ret = new AgentToolRunner.AgentToolOutput();
            if (responseOrError.error != null) {
                ret.error = responseOrError.error.detailedMessage;
            } else if (responseOrError.response != null) {
                if (responseOrError.response instanceof JsonElement) {
                    JsonObject responseInJson = ((JsonElement)responseOrError.response).getAsJsonObject();
                    if (responseInJson.has("result") && responseInJson.get("result").isJsonObject()) {
                        responseInJson = responseInJson.getAsJsonObject("result");
                    }
                    ret.output = responseInJson;
                } else {
                    JF.ObjectBuilder ob = JF.obj();
                    ob.with("result", responseOrError.response.toString());
                    ret.output = ob.get();
                }
            }
            return ret;
        }

        @Override
        public void close() throws Exception {
        }
    }

    public static class Params
    implements AgentToolParams {
        public String deploymentId;
        public String endpointId;
    }
}

