/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.langchain;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.kernel.DSSKernelUtils;
import com.dataiku.dip.kernel.KernelPool;
import com.dataiku.dip.llm.langchain.DevKernelDesc;
import com.dataiku.dip.llm.langchain.PythonLLMServer;
import com.dataiku.dip.llm.langchain.PythonLLMServerAPI;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.IsolationLevel;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.Params;
import com.dataiku.dip.utils.SmartLogTail;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import javax.annotation.Nonnull;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Pool of {@link PythonLLMServer} kernels used to run Python-coded LLMs / agents defined by saved models.
 *
 * <p>Callers obtain a lightweight {@link PythonLLMServerAPI} handle through {@link #getServerAPI}; every
 * request made through that handle is dispatched by the underlying {@link KernelPool} to a kernel whose
 * pool key matches the handle's configuration. Kernel limits and timeouts are read from the
 * {@code dku.llm.python.*} application params, optionally overridden by the agent's own properties.
 */
@Service
public class PythonLLMServerKernelPool {
    @Autowired
    private SavedModelsDAO savedModelsDAO;
    @Autowired
    private TransactionService transactionService;
    /** Generic pool that creates, reuses and kills the Python LLM server kernels. */
    private final KernelPool<PythonLLMServer, KernelDesc, KernelDesc> manager;
    /** Runs kernel start and kernel close operations asynchronously, off the caller's thread. */
    private final ExecutorService executorService = Executors.newCachedThreadPool(new ThreadFactoryBuilder().setNameFormat("python-llm-poolmgr-%d").build());
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.python.pool");

    /**
     * Clears all pooled kernels that belong to the given saved model, with death reason OUTDATED.
     *
     * @param projectKey   project owning the saved model
     * @param savedModelId id of the saved model whose kernels must be dropped
     */
    public void invalidateKernels(String projectKey, String savedModelId) {
        this.manager.clearKernels(kd -> projectKey.equals(kd.projectKey) && savedModelId.equals(kd.savedModelId), KernelPool.DeathReason.OUTDATED);
    }

    /**
     * Force-stops the dev kernel matching the given user / saved model / version, if one is running.
     *
     * @param authCtx             identity the dev kernel was started for
     * @param projectKey          project owning the saved model
     * @param savedModelId        saved model id
     * @param savedModelVersionId saved model version id
     * @return {@code false} when no matching dev kernel exists, otherwise the result of the pool's force-stop
     */
    public boolean stopDevKernel(DSSAuthCtx authCtx, String projectKey, String savedModelId, String savedModelVersionId) {
        // Only the fields used by KernelDesc#isSameDevKernel / getDevKernelKey need to be filled in.
        KernelDesc probe = new KernelDesc();
        probe.authCtx = authCtx;
        probe.projectKey = projectKey;
        probe.savedModelId = savedModelId;
        probe.savedModelVersionId = savedModelVersionId;
        probe.isDevKernel = true;
        Optional<String> kernelId = this.manager.getKernelIdFiltered(kd -> kd.isSameDevKernel(probe));
        return kernelId
                .map(id -> this.manager.forceStopKernel(id, KernelPool.DeathReason.USER_REQUEST))
                .orElse(false);
    }

    /**
     * Returns a handle whose requests are executed on pooled Python LLM server kernels built from the
     * given configuration.
     *
     * <p>The pool key combines the caller identity, project, saved model and version with a SHA-1 of every
     * other input that influences how a kernel is created, so that requests are only routed to kernels
     * started with an identical configuration.
     *
     * @return a handle; calling {@code close()} on it releases nothing, since kernels are owned by the pool
     */
    public PythonLLMServerAPI getServerAPI(DSSAuthCtx authCtx, String projectKey, String savedModelId, SavedModel.SavedModelInlineVersion savedModelVersion, String pyClazz, String code, JsonObject config, JsonObject pluginConfig, String envName, String containerConfName, String pluginId, String libFolder, boolean loadPythonLibs, String clusterId, SavedModel.AgentSettings settings, boolean devKernel) {
        final KernelDesc kernelDesc = new KernelDesc();
        kernelDesc.authCtx = authCtx;
        kernelDesc.projectKey = projectKey;
        kernelDesc.savedModelId = savedModelId;
        kernelDesc.isDevKernel = devKernel;
        kernelDesc.cruContext = ComputeResourceUsageContext.forPythonTool((String)projectKey);
        kernelDesc.savedModelVersionId = savedModelVersion.versionId;
        // Recorded so that isOutdated() can detect that the version was modified after the kernel started.
        kernelDesc.savedModelVersionVersionNumber = savedModelVersion.versionTag != null ? savedModelVersion.versionTag.versionNumber : 0L;
        kernelDesc.pyClazz = pyClazz;
        kernelDesc.code = code;
        kernelDesc.config = config;
        kernelDesc.pluginConfig = pluginConfig;
        kernelDesc.envName = envName;
        kernelDesc.containerConfName = containerConfName;
        kernelDesc.pluginId = pluginId;
        kernelDesc.libFolder = libFolder;
        kernelDesc.loadPythonLibs = loadPythonLibs;
        kernelDesc.clusterId = clusterId;
        kernelDesc.settings = settings;
        // Every input that affects kernel creation must be part of the fingerprint, otherwise a kernel
        // started with one configuration could serve requests meant for another (loadPythonLibs is passed
        // to the PythonLLMServer constructor, so it belongs here too).
        String[] fingerprintParts = new String[]{pyClazz, code, JSON.json((Object)config), JSON.json((Object)pluginConfig), envName, containerConfName, pluginId, libFolder, String.valueOf(loadPythonLibs), clusterId, JSON.json((Object)settings), "" + devKernel};
        String fingerprint = DigestUtils.sha1Hex((String)StringUtils.join((Object[])fingerprintParts, (String)"__DKU__"));
        kernelDesc.poolKey = authCtx.getIdentifier() + "-" + projectKey + "-" + savedModelId + "-" + savedModelVersion.versionId + "-" + fingerprint;
        return new PythonLLMServerAPI(){

            @Override
            public CompletableFuture<LLMClient.SimpleCompletionResponse> processAsync(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
                return PythonLLMServerKernelPool.this.manager.handle(kernel -> kernel.processAsync(query, settings), (Object)kernelDesc, kernelDesc.poolKey, (Object)query);
            }

            @Override
            public CompletableFuture<Integer> streamProcessAsync(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) {
                return PythonLLMServerKernelPool.this.manager.handle(kernel -> kernel.streamProcessAsync(query, settings, consumer), (Object)kernelDesc, kernelDesc.poolKey, (Object)query);
            }

            /**
             * Returns the log of the matching dev kernel. Never throws: any failure (production kernel,
             * kernel not found, logs missing, ...) is reported as a one-line log instead.
             */
            @Override
            public SmartLogTail getKernelLog() {
                try {
                    if (!kernelDesc.isDevKernel) {
                        throw new IllegalArgumentException("Production kernels can't return logs");
                    }
                    String kernelID = (String)PythonLLMServerKernelPool.this.manager.getKernelIdFiltered(kd -> kd.isSameDevKernel(kernelDesc)).orElseThrow(() -> new IllegalArgumentException("Dev kernel not found: " + kernelDesc.getDevKernelKey()));
                    return (SmartLogTail)PythonLLMServerKernelPool.this.manager.getKernelLogs(kernelID).orElseThrow(() -> new IllegalArgumentException("Logs not found for dev kernel: " + kernelDesc.getDevKernelKey()));
                }
                catch (Exception e) {
                    // Some exceptions carry no message; fall back to toString() so the log is never empty/null.
                    String reason = e.getMessage() != null ? e.getMessage() : e.toString();
                    SmartLogTail fakeSLT = new SmartLogTail();
                    fakeSLT.appendLine(reason);
                    return fakeSLT;
                }
            }

            @Override
            public void close() throws IOException {
                // Nothing to release: kernels belong to the pool, not to this handle.
            }
        };
    }

    /** Returns a snapshot of the pool state, as produced by {@link KernelPool#dump}. */
    public KernelPool.PoolDump dump(boolean full) {
        return this.manager.dump(full);
    }

    /**
     * Builds the pool with a controller that creates, starts, sizes, checks and kills
     * {@link PythonLLMServer} kernels.
     */
    public PythonLLMServerKernelPool() {
        this.manager = new KernelPool((KernelPool.KernelController)new KernelPool.KernelController<PythonLLMServer, KernelDesc, KernelDesc>(){

            /** Instantiates (without starting) a server, logging under the saved model version's "logs" folder. */
            @Nonnull
            public PythonLLMServer createKernel(KernelDesc kernelDesc) {
                File logBaseDir = DKUApp.getFile((String[])new String[]{"saved_models", kernelDesc.projectKey, kernelDesc.savedModelId, "versions", kernelDesc.savedModelVersionId, "logs"});
                // NOTE(review): the meaning of the literal `false` flag is defined by PythonLLMServer's constructor — confirm.
                return new PythonLLMServer(kernelDesc.authCtx, kernelDesc.projectKey, kernelDesc.savedModelId, kernelDesc.savedModelVersionId, kernelDesc.pyClazz, kernelDesc.code, kernelDesc.envName, kernelDesc.containerConfName, kernelDesc.pluginId, kernelDesc.libFolder, logBaseDir, kernelDesc.config, kernelDesc.pluginConfig, false, kernelDesc.isDevKernel, kernelDesc.loadPythonLibs);
            }

            /** Starts the kernel on the pool's executor after installing the compute-resource/job context on that thread. */
            @Nonnull
            public CompletableFuture<Void> startKernel(PythonLLMServer kernel, KernelDesc kernelDesc) {
                return DKUCompletableFuture.runAsync(() -> {
                    DSSKernelUtils.setKernelContext(kernelDesc.cruContext, kernelDesc.jobContext, logger);
                    kernel.start();
                }, (Executor)PythonLLMServerKernelPool.this.executorService);
            }

            /** Queued request timeout: {@code dku.llm.python.queuedRequestTimeoutInS} (default 1800 s), in nanoseconds. */
            public Long getQueuedRequestTimeoutInNs() {
                return ApplicationConfigurator.getParams().getLongParam("dku.llm.python.queuedRequestTimeoutInS", 1800L) * 1000L * 1000L * 1000L;
            }

            /** Instance-wide kernel cap: {@code dku.llm.python.maxKernels} (default 50). */
            public int getGlobalMaxKernelCount() {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.python.maxKernels", Integer.valueOf(50));
            }

            /** Per-agent kernel cap, computed from the agent's own properties and the global cap. */
            public Integer getMaxKernelCount(KernelDesc kernelDesc) {
                Params localParams = AbstractSQLConnection.CustomDatabaseProperty.toParams(kernelDesc.settings.dkuProperties);
                return DSSKernelUtils.getMaxKernelCount(localParams, "dku.llm.python.maxKernelsPerAgent", "dku.llm.python.maxKernelProportionPerAgent", kernelDesc.settings.singleInstance, this.getGlobalMaxKernelCount());
            }

            /** Autoscale window: {@code dku.llm.python.autoscaleWindowSeconds} (default 600). */
            public int getAutoscaleTimeWindowSeconds(KernelDesc kernelDesc) {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.python.autoscaleWindowSeconds", Integer.valueOf(600));
            }

            /** Hard per-kernel request limit, falling back to the agent's maxParallelRequestsPerProcess. */
            public int getHardMaxParallelRequests(KernelDesc kernelDesc) {
                Params localParams = AbstractSQLConnection.CustomDatabaseProperty.toParams(kernelDesc.settings.dkuProperties);
                return DSSKernelUtils.getIntParamWithFallback(localParams, "dku.llm.python.hardMaxRequestsPerKernel", kernelDesc.settings.maxParallelRequestsPerProcess);
            }

            /** Soft per-kernel request limit, falling back to the agent's maxParallelRequestsPerProcess. */
            public int getSoftMaxParallelRequests(KernelDesc kernelDesc) {
                Params localParams = AbstractSQLConnection.CustomDatabaseProperty.toParams(kernelDesc.settings.dkuProperties);
                return DSSKernelUtils.getIntParamWithFallback(localParams, "dku.llm.python.softMaxRequestsPerKernel", kernelDesc.settings.maxParallelRequestsPerProcess);
            }

            /** Closes the kernel on the pool's executor; close failures are logged, never propagated. */
            @Nonnull
            public CompletableFuture<Void> killKernel(PythonLLMServer kernel) {
                return CompletableFuture.runAsync(() -> {
                    try {
                        kernel.close();
                    }
                    catch (Exception e) {
                        logger.error((Object)"Error while closing kernel", (Throwable)e);
                    }
                }, PythonLLMServerKernelPool.this.executorService);
            }

            public boolean isAlive(PythonLLMServer kernel) {
                return kernel.isAlive();
            }

            public String getKernelId(PythonLLMServer kernel) {
                return kernel.getKernelId();
            }

            public SmartLogTail getKernelLog(PythonLLMServer kernel) {
                return kernel.getKernelLog();
            }

            /**
             * A kernel is outdated when its saved model was deleted, its version no longer exists, or the
             * version's number differs from the one recorded when the kernel was described. Kernels without
             * a version id, and lookups that fail with an IOException, are treated as up to date.
             */
            public boolean isOutdated(KernelDesc kernelDesc) {
                if (kernelDesc.savedModelVersionId == null) {
                    return false;
                }
                try (Transaction t = PythonLLMServerKernelPool.this.transactionService.retrieveOrBeginRead(IsolationLevel.YOLO)) {
                    SavedModel sm = (SavedModel)PythonLLMServerKernelPool.this.savedModelsDAO.getOrNullUnsafe(kernelDesc.projectKey, kernelDesc.savedModelId);
                    if (sm == null) {
                        return true;
                    }
                    Optional<SavedModel.SavedModelInlineVersion> smiv = sm.getVersion(kernelDesc.savedModelVersionId);
                    if (smiv.isEmpty()) {
                        return true;
                    }
                    SavedModel.SavedModelInlineVersion version = smiv.get();
                    return version.versionTag != null && version.versionTag.versionNumber != kernelDesc.savedModelVersionVersionNumber;
                }
                catch (IOException e) {
                    logger.error((Object)String.format("Failed to check if agent %s.%s version %s is outdated", kernelDesc.projectKey, kernelDesc.savedModelId, kernelDesc.savedModelVersionId), (Throwable)e);
                    return false;
                }
            }
        }, "python-llm", logger);
    }

    /** Delegates to {@link KernelPool#killAllRequests()}. */
    public void killAllRequests() {
        this.manager.killAllRequests();
    }

    /** Kills every pooled kernel with the given death reason. */
    public void killAllKernels(KernelPool.DeathReason reason) {
        this.manager.killAllKernels(reason);
    }

    /** Everything needed to create a kernel and route requests to it. */
    static class KernelDesc
    extends DevKernelDesc {
        DSSAuthCtx authCtx;
        String projectKey;
        String savedModelId;
        String savedModelVersionId;
        /** Version number at description time (0 when the version has no tag); compared in isOutdated. */
        long savedModelVersionVersionNumber;
        String pyClazz;
        String code;
        JsonObject config;
        JsonObject pluginConfig;
        String envName;
        String containerConfName;
        String pluginId;
        String libFolder;
        boolean loadPythonLibs;
        String clusterId;
        /** Identity prefix + SHA-1 fingerprint of the kernel configuration, passed to KernelPool#handle. */
        String poolKey;
        SavedModel.AgentSettings settings;
        ComputeResourceUsageContext cruContext;
        /** Not assigned anywhere in this class; forwarded as-is (possibly null) to DSSKernelUtils.setKernelContext. */
        JobContext jobContext;

        KernelDesc() {
        }

        /** Identifies a dev kernel by user, project, saved model and version (dev kernels only). */
        @Override
        public String getDevKernelKey() {
            assert (this.isDevKernel);
            return this.authCtx.getIdentifier() + "-" + this.projectKey + "-" + this.savedModelId + "-" + this.savedModelVersionId;
        }
    }
}

