/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.retrieval;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.agents.tools.filtering.SimpleFilter;
import com.dataiku.dip.code.CodeEnvSelection;
import com.dataiku.dip.code.CodeEnvSelector;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.dataflow.exec.CodeBasedRecipeDatasetInfoHelper;
import com.dataiku.dip.io.JavaBlockLink;
import com.dataiku.dip.io.SimplePythonKernel;
import com.dataiku.dip.io.SimplePythonKernelFactory;
import com.dataiku.dip.license.LicenseStatusService;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.governance.GuardrailsCodes;
import com.dataiku.dip.llm.governance.GuardrailsPipelineRunner;
import com.dataiku.dip.llm.io.PythonRequestUtils;
import com.dataiku.dip.llm.io.commands.ProcessSinglePromptCommand;
import com.dataiku.dip.llm.langchain.PythonLLMServer;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.retrieval.RAGLLMSettings;
import com.dataiku.dip.llm.retrieval.RetrievableKnowledge;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.licensing.AbstractLicenseFeaturesStatusBuilder;
import com.dataiku.dip.server.services.licensing.LicenseFeaturesStatusBuilder;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.SmartLogTail;
import com.dataiku.dss.shadelib.com.google.common.base.Strings;
import com.google.gson.JsonElement;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.apache.log4j.Logger;

/**
 * Runs a retrieval-augmented generation (RAG) query server for one knowledge bank inside a
 * dedicated Python kernel (module {@code dataiku.llm.rag.rag_query_server}) and forwards
 * completion queries to it.
 *
 * <p>Expected lifecycle: construct, call {@link #init()} once, issue queries through
 * {@link #processAsync} / {@link #streamProcessAsync}, then call {@link #close()}.
 */
public class LangchainBasedRAGServer {
    protected final DSSAuthCtx authCtx;
    protected final String projectKey;
    protected final LLMStructuredRef llmRef;
    protected final RetrievableKnowledge rk;
    /** Code env resolved from the constructor's {@code CodeEnvSelection}. */
    protected final String envName;
    protected final File logBaseDir;
    /** NOTE(review): mutated in place by {@link #init()} when guardrails are restricted. */
    protected final RAGLLMSettings ragSettings;
    protected final LLMClient.CompletionSettings completionSettings;
    private final String containerConfName;
    /** Python kernel hosting the server; {@code null} until {@link #init()} has been called. */
    private SimplePythonKernel kernel;
    private final String kernelId;
    private final boolean devKernel;
    private final LicenseStatusService licenseStatusService = (LicenseStatusService)SpringUtils.getBean(LicenseStatusService.class);
    private static final Logger logger = Logger.getLogger("dku.llm.rag.server");

    /**
     * Captures the server configuration. The Python kernel itself is only created by
     * {@link #init()}.
     *
     * @param authCtx authentication context; must be a {@link DSSAuthCtx} (it is cast, so any
     *     other implementation fails with {@link ClassCastException})
     * @param projectKey key of the project the server runs in
     * @param llmRef reference to the LLM, sent to the server in the start command
     * @param rk knowledge bank to query; its project key and id are part of the kernel id
     * @param ragSettings RAG settings; NOTE(review): {@link #init()} may disable guardrails on
     *     this very instance
     * @param completionSettings default completion settings sent to the server at start
     * @param codeEnvSelection code env selection, resolved to a code env name for the project
     * @param containerConfName container configuration name passed to the kernel factory
     * @param logBaseDir base directory under which the kernel's log directory is created
     * @param devKernel whether this is a development kernel (affects log retention and tailing)
     * @throws IOException declared by this constructor; presumably raised by code env
     *     selection — TODO confirm
     */
    public LangchainBasedRAGServer(AuthCtx authCtx, String projectKey, LLMStructuredRef llmRef, RetrievableKnowledge rk, RAGLLMSettings ragSettings, LLMClient.CompletionSettings completionSettings, CodeEnvSelection codeEnvSelection, String containerConfName, File logBaseDir, boolean devKernel) throws IOException {
        this.authCtx = (DSSAuthCtx)authCtx;
        this.projectKey = projectKey;
        this.llmRef = llmRef;
        this.rk = rk;
        this.ragSettings = ragSettings;
        this.completionSettings = completionSettings;
        this.containerConfName = containerConfName;
        this.envName = new CodeEnvSelector().selectForPythonRecipe(projectKey, codeEnvSelection);
        this.logBaseDir = logBaseDir;
        // Random suffix keeps ids distinct across servers for the same knowledge bank.
        this.kernelId = "rag-" + rk.projectKey + "-" + rk.id + "-" + SecretKeyGenerator.generateSmall();
        this.devKernel = devKernel;
    }

    /**
     * Closes the kernel if one was created. Best-effort: any failure is logged and swallowed.
     */
    public void close() {
        if (this.kernel != null) {
            try {
                this.kernel.close();
            }
            catch (Throwable e) {
                logger.error("Failed to kill kernel", e);
            }
        }
    }

    /** @return {@code true} if the kernel has been created and reports itself alive */
    public boolean isAlive() {
        return this.kernel != null && this.kernel.isAlive();
    }

    /** @return the identifier generated for this server's kernel */
    public String getKernelId() {
        return this.kernelId;
    }

    /** @return a tail of the kernel log, or {@code null} if {@link #init()} has not run yet */
    public SmartLogTail getKernelLog() {
        if (this.kernel == null) {
            return null;
        }
        return this.kernel.getSmartLogTailBuilder().get();
    }

    /**
     * Sends a single, non-streamed prompt to the server. The raw reply is converted by
     * {@link #mapRAGResponse}. When that conversion throws, the returned future presumably
     * completes exceptionally; this depends on the async link implementation.
     *
     * @throws IllegalStateException if {@link #init()} has not been called
     */
    public CompletableFuture<LLMClient.SimpleCompletionResponse> processAsync(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
        return this.requireKernel().getLink().getAsyncLink().asyncSendRequest((Object)new ProcessSinglePromptCommand(query, settings, false), LLMClient.SimpleCompletionResponseOrError.class, this::mapRAGResponse);
    }

    /**
     * Streams a completion for {@code query} to {@code consumer}. The meaning of the returned
     * integer is defined by {@code PythonRequestUtils.asyncStreamRequest}, which is not visible
     * here.
     *
     * @throws IllegalStateException if {@link #init()} has not been called
     */
    public CompletableFuture<Integer> streamProcessAsync(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) {
        return PythonRequestUtils.asyncStreamRequest(this.requireKernel().getAsyncLink(), query, settings, consumer);
    }

    /**
     * Creates, configures and starts the Python kernel, then sends it the {@link StartCommand}.
     *
     * <p>NOTE(review): calling this twice replaces the kernel reference without closing the
     * previous kernel. Confirm callers only call it once.
     *
     * @throws Exception propagated from kernel preparation, configuration, startup or the start
     *     request
     */
    public void init() throws Exception {
        this.kernel = SimplePythonKernelFactory.prepareKernel(this.authCtx, this.projectKey, GeneralSettingsDAO.CGrouppableProcessType.ML_KERNEL, this.envName, "dataiku.llm.rag.rag_query_server", false, this.containerConfName, this.kernelId);
        StartCommand command = this.buildStartCommand();
        this.configureKernelEnvironment();
        this.configureKernelLogging();
        this.kernel.start();
        logger.info("Sending start command to server: " + JSON.json((Object)command));
        JavaBlockLink.AsyncJavaLink link = this.kernel.getLink().getAsyncLink();
        link.request((Object)command, JsonElement.class);
    }

    /** Returns the kernel, failing with a clear message if {@link #init()} has not run yet. */
    private SimplePythonKernel requireKernel() {
        if (this.kernel == null) {
            throw new IllegalStateException("RAG server kernel " + this.kernelId + " is not initialized; call init() first");
        }
        return this.kernel;
    }

    /** Builds the start command, restricting guardrails before the RAG settings are attached. */
    private StartCommand buildStartCommand() throws Exception {
        StartCommand command = new StartCommand();
        command.defaultCompletionSettings = this.completionSettings;
        command.llmRef = this.llmRef;
        command.knowledgeBankFullId = this.rk.getFullId();
        this.restrictGuardrails();
        command.ragSettings = this.ragSettings.withDefaultPrompts();
        if (this.ragSettings.filter != null && this.ragSettings.performFiltering) {
            command.filter = SimpleFilter.fromComplexFilter(this.ragSettings.filter, Optional.empty());
        }
        return command;
    }

    /**
     * Disables guardrails on {@link #ragSettings} in place. If the license does not allow
     * advanced LLM Mesh features, all guardrails are disabled. Otherwise, for a MULTIMODAL
     * retrieval source the text guardrails are disabled; for any other source the multimodal
     * guardrails are disabled.
     */
    private void restrictGuardrails() throws Exception {
        LicenseStatusService.LicensingStatus ls = this.licenseStatusService.getLicensingStatus();
        AbstractLicenseFeaturesStatusBuilder.LicenseFeaturesStatus licenseFeaturesStatus = LicenseFeaturesStatusBuilder.getFeaturesStatus(ls);
        if (!licenseFeaturesStatus.advancedLLMMeshAllowed && this.ragSettings.hasGuardrailsEnabled()) {
            logger.warn("Guardrails enabled but the current license does not allow advanced LLM Mesh features. Disabling guardrails.");
            this.ragSettings.disableGuardrails();
        }
        if (this.ragSettings.hasGuardrailsEnabled()) {
            if (this.ragSettings.retrievalSource == RAGLLMSettings.RetrievalSource.MULTIMODAL) {
                logger.info(String.format("Disabling text guardrails for rag model from KB '%s' as its retrieval content can be multimodal.", this.rk.name));
                this.ragSettings.disableTextGuardrails();
            } else {
                logger.info(String.format("Disabling multimodal guardrails for rag model from KB '%s' as its retrieval content is only embeddings.", this.rk.name));
                this.ragSettings.disableMultimodalGuardrails();
            }
        }
    }

    /**
     * Exposes the knowledge bank's connection info (when it has a connection) to the kernel
     * through the {@code DKU_KB_CONNECTION_INFO} environment variable.
     */
    private void configureKernelEnvironment() throws Exception {
        HashMap<String, String> extraEnv = new HashMap<String, String>();
        if (!Strings.isNullOrEmpty((String)this.rk.connection)) {
            // The connection info comes from an "Unsafe" accessor and goes to the environment
            // only; it is not part of the start command that init() logs.
            extraEnv.put("DKU_KB_CONNECTION_INFO", JSON.json((Object)new CodeBasedRecipeDatasetInfoHelper().getConnectionInfoUnsafe_NT(this.authCtx, this.rk.connection, this.projectKey)));
        }
        this.kernel.withExtraEnv(extraEnv);
    }

    /**
     * Creates the kernel log directory and streams kernel output to a rotating
     * {@code rag.log}. Limits are read from the {@code dku.llm.rag.logs.*} parameters. Log
     * tailing is enabled only for dev kernels.
     */
    private void configureKernelLogging() throws Exception {
        int maxDevKeptFiles = DKUApp.getParams().getIntParam("dku.llm.rag.logs.maxDevKeptFiles", Integer.valueOf(10));
        File logDir = PythonLLMServer.initKernelLogDir(this.logBaseDir, this.kernelId, false, this.devKernel, maxDevKeptFiles);
        int maxRotated = DKUApp.getParams().getIntParam("dku.llm.rag.logs.maxRotatedFiles", Integer.valueOf(2));
        // The parameter is expressed in KB; the log streamer receives bytes.
        long maxSize = DKUApp.getParams().getLongParam("dku.llm.rag.logs.maxFileSizeKB", 1024L) * 1024L;
        int maxDevTailLines = this.devKernel ? DKUApp.getParams().getIntParam("dku.llm.rag.logs.maxDevTailLines", Integer.valueOf(1000)) : 0;
        PythonLLMServer.streamKernelLogsToRotatingFile(this.kernel, logDir, "rag.log", maxRotated, maxSize, maxDevTailLines);
    }

    /**
     * Converts the server's response-or-error into a response.
     *
     * @return the response itself when it is {@code ok}
     * @throws GuardrailsPipelineRunner.LLMUsageEnforcerException when the error code matches
     *     {@code ERR_LLM_RESPONSE_GUARDRAIL} (case-insensitively)
     * @throws Exception for any other error
     */
    private LLMClient.SimpleCompletionResponse mapRAGResponse(LLMClient.SimpleCompletionResponseOrError responseOrError) throws Exception {
        if (responseOrError.ok) {
            return responseOrError;
        }
        GuardrailsCodes guardrailCode = GuardrailsCodes.ERR_LLM_RESPONSE_GUARDRAIL;
        if (guardrailCode.toString().equalsIgnoreCase(responseOrError.errorCode)) {
            // Pass the matched constant instead of GuardrailsCodes.valueOf(errorCode). The
            // check above is case-insensitive, while valueOf is case-sensitive and would throw
            // IllegalArgumentException for a code whose casing differs.
            throw new GuardrailsPipelineRunner.LLMUsageEnforcerException(responseOrError.errorSource, guardrailCode, "Guardrail rejected query: " + responseOrError.errorMessage);
        }
        throw new Exception("Failed processing RAG query: " + responseOrError.errorMessage);
    }

    /**
     * Payload sent once to the Python server after startup; it is serialized with
     * {@code JSON.json}, so field names are part of the wire format.
     */
    public static class StartCommand {
        public final String type = "start";
        public LLMStructuredRef llmRef;
        public String knowledgeBankFullId;
        public RAGLLMSettings ragSettings;
        public LLMClient.CompletionSettings defaultCompletionSettings;
        /** Set only when the RAG settings have a filter and filtering is enabled. */
        public SimpleFilter filter;
        public final String noRetrievalKey = RAGLLMSettings.SearchInputStrategySettings.getNoRetrievalKey();
    }

    /**
     * A filter paired with a tool reference. It is not used within this class; presumably
     * serialized or consumed elsewhere. TODO confirm.
     */
    public static class RagQueryFilter {
        public SimpleFilter filter;
        public String toolRef;
    }
}

