/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.agents.tools.llmmesh;

import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolParams;
import com.dataiku.dip.agents.tools.AgentToolRunner;
import com.dataiku.dip.agents.tools.utils.JsonSchema;
import com.dataiku.dip.agents.tools.utils.JsonSchemaElement;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.exceptions.ForbiddenObjectException;
import com.dataiku.dip.futures.FutureProgress;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMAuditHelper;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMMeshClient;
import com.dataiku.dip.llm.online.LLMMeshClientFactory;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsageReportingService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.ITaggingService;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import com.google.common.collect.Lists;
import com.google.gson.JsonElement;
import com.google.gson.JsonSyntaxException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

public class LLMMeshLLMQueryTool {
    /**
     * Tool metadata for the {@code "LLMMeshLLMQuery"} agent tool type.
     *
     * <p>Exposes the tool's parameter class, its saved-model and connection
     * dependencies (both derived from the configured {@link Params#llmId}),
     * the descriptor presented to the calling LLM, and the factory for the
     * {@link Runner} that actually performs the query.
     */
    public static final AgentToolMeta META = new AgentToolMeta(false){

        /** @return the type identifier of this tool. */
        @Override
        public String getType() {
            return "LLMMeshLLMQuery";
        }

        /** @return the class used to hold this tool's parameters. */
        @Override
        public Class<? extends AgentToolParams> paramsClass() {
            return Params.class;
        }

        /**
         * Lists the saved model this tool depends on, if the configured LLM id
         * points at one.
         *
         * @param tool the tool whose parameters are inspected
         * @return a mutable list holding at most one saved-model dependency
         */
        @Override
        public List<SavedModel.AgentDependency> getDependencies(AgentTool tool) {
            LLMStructuredRef ref = this.decodeLlmRefOrNull(tool.getParamsCopyAs(Params.class));
            if (ref == null || ref.savedModelSmartId == null) {
                return new ArrayList<>();
            }
            return Lists.newArrayList(new SavedModel.AgentDependency(ITaggingService.TaggableType.SAVED_MODEL, ref.savedModelSmartId));
        }

        /**
         * Verifies that the saved model referenced by the tool (if any) is
         * available in the tool's project.
         *
         * <p>When no LLM is configured, or the configured id does not reference
         * a saved model, there is nothing to check and the method returns.
         *
         * @param authCtx the caller's authentication context (not used here)
         * @param tool the tool whose dependency is checked
         * @throws IOException propagated from the availability check
         * @throws ForbiddenObjectException propagated from the availability check
         */
        @Override
        public void checkAccessDependency(AuthCtx authCtx, AgentTool tool) throws IOException, ForbiddenObjectException {
            Params params = tool.getParamsCopyAs(Params.class);
            if (params == null || StringUtils.isBlank(params.llmId)) {
                logger.warn((Object)"No LLM selected. Skipping access check to dependency.");
                return;
            }
            // Same null-tolerance as getDependencies/listConnectionNames: an id
            // that cannot be decoded carries no saved-model dependency to check.
            LLMStructuredRef llmRef = LLMStructuredRef.decodeId(params.llmId);
            if (llmRef == null || llmRef.savedModelSmartId == null) {
                return;
            }
            AnyLoc llmLoc = AnyLoc.resolveSmart(tool.projectKey, llmRef.savedModelSmartId);
            ((ProjectsService)SpringUtils.getBean(ProjectsService.class)).failIfLocNotAvailableInProject(ITaggingService.TaggableType.SAVED_MODEL, llmLoc, tool.projectKey);
        }

        /**
         * Lists the connection used by the configured LLM, if any.
         *
         * @param tool the tool whose parameters are inspected
         * @return a set holding the connection name, or an empty set
         */
        @Override
        public Set<String> listConnectionNames(AgentTool tool) {
            LLMStructuredRef ref = this.decodeLlmRefOrNull(tool.getParamsCopyAs(Params.class));
            if (ref == null || ref.connection == null) {
                return new HashSet<>();
            }
            return Set.of(ref.connection);
        }

        /**
         * Rewrites the connection embedded in the configured LLM id according
         * to the given mapping and stores the updated parameters on the tool.
         *
         * @param tool the tool to update in place
         * @param replacements map from current connection name to new name
         * @return {@code true} if the tool's parameters were modified
         */
        @Override
        public boolean remapConnections(AgentTool tool, Map<String, String> replacements) {
            if (replacements == null) {
                return false;
            }
            Params p = tool.getParamsCopyAs(Params.class);
            LLMStructuredRef ref = this.decodeLlmRefOrNull(p);
            if (ref == null || ref.connection == null) {
                return false;
            }
            String newConnection = replacements.get(ref.connection);
            if (newConnection == null) {
                return false;
            }
            ref.setConnection(newConnection);
            p.llmId = ref.encodeToId();
            tool.setParams(p);
            return true;
        }

        /**
         * Builds the descriptor (name, description and input schema) exposed
         * for this tool. The schema declares a single string property,
         * {@code question}.
         *
         * @param authCtx the caller's authentication context (not used here)
         * @param projectKey the project key (not used here)
         * @param tool the tool being described
         * @return the tool descriptor
         * @throws IOException declared by the overridden method
         */
        @Override
        public AgentToolMeta.ToolDescriptor getResultingDescriptor(AuthCtx authCtx, String projectKey, AgentTool tool) throws IOException {
            TransactionContext.assertNoAttachedTransaction();
            AgentToolMeta.ToolDescriptor td = new AgentToolMeta.ToolDescriptor(tool.name);
            td.description = "Asks a question to an agent.";
            if (StringUtils.isNotBlank(tool.additionalDescriptionForLLM)) {
                td.description = td.description + "\n\n" + tool.additionalDescriptionForLLM;
            }
            td.inputSchema = JsonSchema.newObject("https://dataiku.com/agents/tools/llm-mesh/llm/query", "Query an LLM using the LLM Mesh");
            td.inputSchema.properties.put("question", JsonSchemaElement.string("the question to ask"));
            return td;
        }

        /**
         * Creates the runner for this tool.
         *
         * @param authCtx the authentication context the query runs under
         * @param projectKey the project in which the tool is being run
         * @param tool the tool definition; its own project key is passed as the source project
         * @param devKernel not used here
         * @return a new {@link Runner}
         * @throws CodedException declared by the overridden method
         */
        @Override
        public AgentToolRunner buildRunner(AuthCtx authCtx, String projectKey, AgentTool tool, boolean devKernel) throws CodedException {
            return new Runner(authCtx, projectKey, tool.projectKey, tool.getParamsCopyAs(Params.class));
        }

        /**
         * Decodes the LLM reference configured in the given parameters.
         *
         * @param p the tool parameters, possibly {@code null}
         * @return the decoded reference, or {@code null} when the parameters
         *         are missing, the id is blank, or decoding yields {@code null}
         */
        private LLMStructuredRef decodeLlmRefOrNull(Params p) {
            if (p == null || StringUtils.isBlank(p.llmId)) {
                return null;
            }
            return LLMStructuredRef.decodeId(p.llmId);
        }
    };
    /** Logger shared by the tool metadata and its runner. */
    private static final DKULogger logger = DKULogger.getLogger("dku.agents.tools.llm");

    /**
     * Executes the LLM query tool: sends the caller's question (optionally
     * preceded by a system prompt) to the configured LLM through the LLM Mesh
     * client and converts the single completion into a tool output.
     *
     * <p>{@link #init()} must be called before {@link #run} so that the
     * {@code @Autowired} services are injected.
     */
    public static class Runner implements AgentToolRunner {
        @Autowired
        private ComputeResourceUsageReportingService cruReportingService;
        @Autowired
        private AuditTrailService auditTrailService;
        private final AuthCtx authCtx;
        private final String contextProjectKey;
        private final String sourceProjectKey;
        private final Params params;

        /**
         * @param authCtx the authentication context the query runs under
         * @param contextProjectKey the project in which the tool is being run
         * @param sourceProjectKey the project that owns the tool definition
         * @param p the tool parameters
         */
        public Runner(AuthCtx authCtx, String contextProjectKey, String sourceProjectKey, Params p) {
            this.authCtx = authCtx;
            this.contextProjectKey = contextProjectKey;
            this.sourceProjectKey = sourceProjectKey;
            this.params = p;
        }

        /** Injects the {@code @Autowired} services into this instance. */
        @Override
        public void init() throws IOException {
            SpringUtils.getInstance().autowire((Object)this);
        }

        /**
         * Decodes the configured LLM id. When the tool runs in a project other
         * than the one that defines it, a saved-model reference is resolved
         * against the source project and replaced by its full name, and the
         * reference id is re-encoded accordingly.
         *
         * @return the decoded (and possibly re-qualified) LLM reference
         * @throws IllegalArgumentException if no LLM is configured or the
         *         configured id cannot be decoded
         */
        private LLMStructuredRef getLLMRef() {
            if (this.params == null || StringUtils.isBlank(this.params.llmId)) {
                throw new IllegalArgumentException("No LLM/Agent selected");
            }
            LLMStructuredRef llmRef = LLMStructuredRef.decodeId(this.params.llmId);
            if (llmRef == null) {
                throw new IllegalArgumentException("Cannot decode LLM/Agent id: " + this.params.llmId);
            }
            boolean crossProject = !Objects.equals(this.contextProjectKey, this.sourceProjectKey);
            if (crossProject && llmRef.savedModelSmartId != null) {
                AnyLoc llmLoc = AnyLoc.resolveSmart(this.sourceProjectKey, llmRef.savedModelSmartId);
                llmRef.savedModelSmartId = llmLoc.getFullName();
                llmRef.id = llmRef.encodeToId();
            }
            return llmRef;
        }

        /**
         * Runs a single completion query built from the {@code question}
         * argument of the input.
         *
         * @param input the tool input; must carry a {@code question} argument
         * @return the tool output holding the response text and, depending on
         *         the parameters, artifacts and sources
         * @throws IllegalArgumentException if no usable LLM is configured
         * @throws IllegalStateException if the client does not return exactly one response
         * @throws RuntimeException if the LLM reports a failed completion
         * @throws Exception propagated from the LLM Mesh client
         */
        @Override
        public AgentToolRunner.AgentToolOutput run(AgentToolRunner.AgentToolInput input) throws Exception {
            logger.debug((Object)("Running with input " + JSON.json((Object)input)));
            String question = this.safeReadStringArgument(input, "question");
            LLMStructuredRef llmRef = this.getLLMRef();
            GuardrailsPipelineSettings connectionGuardrailsPipelineSettings = GuardrailsPipelineUtils.getConnectionAndLLMLevelSettings(this.authCtx, this.sourceProjectKey, llmRef);
            // No usage-time guardrails for this tool; kept as a typed local so
            // the merge call resolves exactly as before.
            GuardrailsPipelineSettings usageTimeGuardrailsPipelineSettings = null;
            GuardrailsPipelineSettings guardrailsPipelineSettings = GuardrailsPipelineUtils.mergeEnforcementSettings(connectionGuardrailsPipelineSettings, usageTimeGuardrailsPipelineSettings);
            try (LLMMeshClient llmMeshClient = LLMMeshClientFactory.get(this.authCtx, this.contextProjectKey, llmRef, guardrailsPipelineSettings, null, 1)) {
                EnrichedLLMStructuredRef enrichedLLMRef = llmMeshClient.getEnrichedRef();
                LLMClient.SingleCompletionQuery query = this.buildQuery(input, question);

                List<LLMClient.SimpleCompletionResponseOrError> responses;
                try (FutureProgress.AutocloseableFutureProgressState ignored = FutureProgress.pushAutoCloseableState("Querying LLM", 1.0, FutureProgressState.StateUnit.RECORDS)) {
                    responses = llmMeshClient.completeQueries(Lists.newArrayList(query), this.params.completionSettings);
                }

                // Report usage before validating the response so that consumed
                // resources are accounted for even when the query failed.
                ComputeResourceUsage cru = llmMeshClient.getTotalCRU(ComputeResourceUsage.LLMUsageType.COMPLETION);
                if (cru != null) {
                    this.cruReportingService.reportComplete(cru);
                }

                // Explicit check: a bare assert is skipped unless the JVM runs with -ea.
                if (responses == null || responses.size() != 1) {
                    throw new IllegalStateException("Expected exactly one LLM response, got " + (responses == null ? "none" : String.valueOf(responses.size())));
                }
                LLMClient.SimpleCompletionResponseOrError resp = responses.get(0);
                if (!resp.ok) {
                    throw new RuntimeException("LLM query failed: " + resp.errorMessage);
                }

                AgentToolRunner.AgentToolOutput o = this.buildOutput(resp);
                LLMAuditHelper.emitLLMCompletionAuditFromBackendIfNeeded(this.auditTrailService, enrichedLLMRef, llmMeshClient.getConnection(), query, resp);
                return o;
            }
        }

        /** Assembles the completion query: optional context, optional system prompt, then the user question. */
        private LLMClient.SingleCompletionQuery buildQuery(AgentToolRunner.AgentToolInput input, String question) {
            LLMClient.SingleCompletionQuery query = new LLMClient.SingleCompletionQuery();
            // Boolean.TRUE.equals tolerates a null flag instead of failing on unboxing.
            if (Boolean.TRUE.equals(this.params.forwardContext)) {
                query.context = input.context;
            }
            if (StringUtils.isNotBlank(this.params.systemPromptPrepend)) {
                query.messages.add(new LLMClient.ChatMessage("system", this.params.systemPromptPrepend));
            }
            query.messages.add(new LLMClient.ChatMessage("user", question));
            return query;
        }

        /** Converts a successful completion into the tool output, attaching artifacts and sources when enabled. */
        private AgentToolRunner.AgentToolOutput buildOutput(LLMClient.SimpleCompletionResponseOrError resp) {
            AgentToolRunner.AgentToolOutput o = new AgentToolRunner.AgentToolOutput();
            o.output = JF.obj().with("response", resp.text).get();
            if (Boolean.TRUE.equals(this.params.returnArtifacts) && resp.artifacts != null) {
                o.artifacts = resp.artifacts;
            }
            if (Boolean.TRUE.equals(this.params.returnSources) && resp.additionalInformation != null) {
                JsonElement sourcesObject = resp.additionalInformation.get("sources");
                if (sourcesObject != null && sourcesObject.isJsonArray()) {
                    ArrayList<AgentToolRunner.Source> parsedSources = new ArrayList<>();
                    for (JsonElement jsonSource : sourcesObject.getAsJsonArray()) {
                        try {
                            parsedSources.add((AgentToolRunner.Source)JSON.parse(jsonSource.toString(), AgentToolRunner.Source.class));
                        }
                        catch (JsonSyntaxException e) {
                            // Best effort: a malformed source is skipped, not fatal.
                            logger.warn((Object)"Failed to parse source from LLM call", (Throwable)e);
                        }
                    }
                    o.sources = parsedSources;
                }
            }
            o.trace = resp.trace;
            return o;
        }

        /** Nothing to release: the LLM Mesh client is closed at the end of each {@link #run}. */
        @Override
        public void close() throws Exception {
        }
    }

    /**
     * User-configurable parameters of the LLM query tool.
     */
    public static class Params implements AgentToolParams {
        /** Encoded id of the LLM to query; decoded with {@code LLMStructuredRef.decodeId}. */
        public String llmId;

        /** When not blank, sent as a {@code "system"} message ahead of the user question. */
        public String systemPromptPrepend;

        /** Completion settings handed to the LLM Mesh client with the query. */
        public LLMClient.CompletionSettings completionSettings = new LLMClient.CompletionSettings();

        /** When true, the tool input's context is copied onto the query. */
        public Boolean forwardContext = Boolean.FALSE;

        /** When true, artifacts present on the response are copied to the tool output. */
        public Boolean returnArtifacts = Boolean.FALSE;

        /** When true, the {@code "sources"} array of the response's additional information is parsed into the tool output. */
        public Boolean returnSources = Boolean.FALSE;
    }
}

