/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.local;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.llm.LLMSMMgmtService;
import com.dataiku.dip.analysis.ml.llm.SavedLLMModelHandle;
import com.dataiku.dip.analysis.model.llm.LLMModelSnippetData;
import com.dataiku.dip.cluster.Cluster;
import com.dataiku.dip.cluster.ClusterSelector;
import com.dataiku.dip.code.CodeEnvModel;
import com.dataiku.dip.code.CodeEnvResolutionService;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.dao.UsersDAO;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.kernel.DSSKernelUtils;
import com.dataiku.dip.kernel.KernelPool;
import com.dataiku.dip.kernel.KernelPoolThreadFactory;
import com.dataiku.dip.kernel.KernelScalingStrategyBuilder;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.local.HuggingFaceKernelBuilder;
import com.dataiku.dip.llm.local.HuggingFaceKernelClient;
import com.dataiku.dip.llm.local.HuggingFaceLocalClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.logging.MainLoggingConfigurator;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.security.model.GlobalScopePublicAPIKey;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.api.auth.PublicAPIKeysService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.IsolationLevel;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ExceptionUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.Pair;
import com.dataiku.dip.utils.SmartLogTail;
import com.dataiku.dss.shadelib.com.google.common.cache.Cache;
import com.dataiku.dss.shadelib.com.google.common.cache.CacheBuilder;
import com.dataiku.dss.shadelib.com.google.common.util.concurrent.ThreadFactoryBuilder;
import com.google.common.base.Stopwatch;
import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.apache.log4j.NDC;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class HuggingFaceKernelPool {
    /** Used by {@code getKernelConfig} to check that the connection's Python code env exists. */
    @Autowired
    private CodeEnvResolutionService codeEnvResolutionService;
    /** Used by {@code buildKernelDesc} to open a read transaction around the saved-model lookup. */
    @Autowired
    private TransactionService transactionService;
    /** Underlying generic kernel pool; assigned once in the constructor. */
    private final KernelPool<HuggingFaceLocalClient, KernelGroup, KernelDesc> manager;
    /** Executor on which kernel start and kill tasks are run (threads named {@code hf-kernel-pool-startstop-N}). */
    private final ExecutorService executorService = Executors.newCachedThreadPool(new ThreadFactoryBuilder().setNameFormat("hf-kernel-pool-startstop-%d").build());
    /** Builder handed to every {@link HuggingFaceLocalClient} created by the pool. */
    private final HuggingFaceKernelBuilder kernelBuilder = new HuggingFaceKernelBuilder();
    /**
     * Compile-time constants: declared {@code static} so they are class-level constants rather than
     * a copy stored in each instance. The values match the literal defaults used for the
     * {@code dku.llm.hf.hardMaxRequestsPerKernel} / {@code dku.llm.hf.softMaxRequestsPerKernel} params.
     */
    private static final int HARD_MAX_PARALLEL_REQUESTS = 256;
    private static final int SOFT_MAX_PARALLEL_REQUESTS = 64;
    /** Entries expire 60 seconds after being written; used by {@code dedupError} to avoid repeating identical error logs. */
    private final Cache<ImmutableTriple<String, String, String>, Boolean> logCache = CacheBuilder.newBuilder().expireAfterWrite(Duration.ofSeconds(60L)).build();
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.hfpool");

    /**
     * Builds an {@link LLMClient} whose queries for the given model are dispatched through the kernel pool.
     *
     * <p>The kernel description (and therefore the kernel configuration hash) is computed once, when the
     * client is created. Usage and estimated cost of all calls made through the returned client are
     * accumulated in a single usage-data object exposed by {@code getTotalCRU}.
     *
     * @param authCtx auth context stored in the kernel description
     * @param connection the HuggingFace local connection owning the model
     * @param projectKey project key stored in the kernel description
     * @param modelHandle handle on the model to serve
     * @return a client bound to this pool; closing it cancels the futures it is still tracking
     * @throws IOException if the kernel description cannot be built
     */
    public LLMClient getClient(AuthCtx authCtx, final HuggingFaceLocalConnection connection, String projectKey, final LLMModelHandle<HuggingFaceLocalConnection.HFLocalModel> modelHandle) throws IOException {
        final KernelDesc kernelDesc = this.getKernelDesc(authCtx, connection, projectKey, modelHandle);
        final ComputeResourceUsage.InternalLLMUsageData usageData = new ComputeResourceUsage.InternalLLMUsageData();
        return new LLMClient(){
            private final DKUCompletableFuture.FutureCancellationTracker futureCancellationTracker = new DKUCompletableFuture.FutureCancellationTracker();

            @Override
            public void close() {
                this.futureCancellationTracker.cancelAll("LLM client closed");
            }

            @Override
            public boolean supportNativeBatch() {
                return false;
            }

            @Override
            public boolean requiresCostLimiting() {
                return false;
            }

            @Override
            public String getProviderId() {
                return "HuggingFaceLocal";
            }

            @Override
            public HuggingFaceLocalConnection getConnection() {
                return connection;
            }

            @Override
            public int getBatchSize(AbstractLLMConnection.QueryType queryType, LLMStructuredRef llmRef) {
                return 1;
            }

            @Override
            public int getMaxParallelism() {
                return ((HuggingFaceLocalConnection.HFLocalModel)modelHandle.getModel()).getParallelism().orElse(64);
            }

            /** Submits one completion query to the pool and registers the resulting future for cancellation on close. */
            CompletableFuture<LLMClient.SimpleCompletionResponse> complete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
                return this.futureCancellationTracker.track(() -> HuggingFaceKernelPool.this.manager.handle(kernel -> kernel.asyncComplete(query, settings), (Object)kernelDesc, kernelDesc.descHash(), (Object)kernelDesc, kernelDesc.groupHash(), (Object)query));
            }

            /** Submits one embedding query to the pool and registers the resulting future for cancellation on close. */
            CompletableFuture<LLMClient.SimpleEmbeddingResponse> embed(LLMClient.EmbeddingQuery query, LLMClient.EmbeddingSettings settings) {
                return this.futureCancellationTracker.track(() -> HuggingFaceKernelPool.this.manager.handle(kernel -> kernel.asyncEmbed(query, settings), (Object)kernelDesc, kernelDesc.descHash(), (Object)kernelDesc, kernelDesc.groupHash(), (Object)query));
            }

            /**
             * Waits for the usage-accounting stages attached to already-completed source futures.
             *
             * <p>A stage registered with {@code thenAccept} is not guaranteed to have run by the time a
             * thread waiting on the source future is released, so without this wait a batch method could
             * return responses whose {@code estimatedCost} is not yet set and whose usage is not yet
             * recorded. Sources that are not done are skipped so this never blocks longer than the
             * response collection itself did.
             *
             * @param sources futures returned by {@code complete}/{@code embed}, in submission order
             * @param accountingStages the {@code thenAccept} stage attached to each source, same order
             */
            private void awaitUsageAccounting(List<? extends CompletableFuture<?>> sources, List<CompletableFuture<Void>> accountingStages) {
                for (int i = 0; i < accountingStages.size(); ++i) {
                    if (!sources.get(i).isDone()) {
                        continue;
                    }
                    try {
                        accountingStages.get(i).join();
                    }
                    catch (RuntimeException e) {
                        // A failed source also fails its accounting stage; that failure belongs to the
                        // response collection. Only report failures of the accounting callback itself.
                        if (!sources.get(i).isCompletedExceptionally()) {
                            logger.warn((Object)"Failed to record usage data for an LLM response", (Throwable)e);
                        }
                    }
                }
            }

            @Override
            public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
                Stopwatch sw = Stopwatch.createStarted();
                // The proxy fills in cost and usage when the footer is received, before forwarding to the caller's consumer.
                LLMClient.StreamedCompletionResponseConsumerProxy wrappedConsumer = new LLMClient.StreamedCompletionResponseConsumerProxy(consumer, (ExceptionUtils.ThrowingConsumer<LLMClient.StreamedCompletionResponseFooter, Exception>)((ExceptionUtils.ThrowingConsumer)footer -> {
                    footer.estimatedCost = ((HuggingFaceLocalConnection.HFLocalModel)modelHandle.getModel()).getEstimatedCompletionCost(footer.promptTokens, footer.completionTokens);
                    footer.includeInUsageData(usageData, sw.elapsed(TimeUnit.MILLISECONDS));
                }));
                CompletableFuture future = this.futureCancellationTracker.track(() -> HuggingFaceKernelPool.this.manager.handle(kernel -> kernel.asyncStreamComplete(query, settings, wrappedConsumer), (Object)kernelDesc, kernelDesc.descHash(), (Object)kernelDesc, kernelDesc.groupHash(), (Object)query));
                DKUCompletableFuture.collectResponse((CompletableFuture)future);
            }

            @Override
            public boolean supportsStream() {
                return true;
            }

            @Override
            public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) {
                List<CompletableFuture<LLMClient.SimpleCompletionResponse>> futures = new ArrayList<>(queries.size());
                List<CompletableFuture<Void>> accountingStages = new ArrayList<>(queries.size());
                for (LLMClient.SingleCompletionQuery q : queries) {
                    Stopwatch sw = Stopwatch.createStarted();
                    CompletableFuture<LLMClient.SimpleCompletionResponse> future = this.complete(q, settings);
                    accountingStages.add(future.thenAccept(response -> {
                        response.estimatedCost = ((HuggingFaceLocalConnection.HFLocalModel)modelHandle.getModel()).getEstimatedCompletionCost(response.promptTokens, response.completionTokens);
                        response.includeInUsageData(usageData, sw.elapsed(TimeUnit.MILLISECONDS));
                    }));
                    futures.add(future);
                }
                List<LLMClient.SimpleCompletionResponse> responses = DKUCompletableFuture.collectResponsesNoException(futures);
                // Make sure cost and usage have been written before the responses are handed back.
                this.awaitUsageAccounting(futures, accountingStages);
                return responses;
            }

            @Override
            public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) {
                List<CompletableFuture<LLMClient.SimpleEmbeddingResponse>> futures = new ArrayList<>(queries.size());
                List<CompletableFuture<Void>> accountingStages = new ArrayList<>(queries.size());
                for (LLMClient.EmbeddingQuery q : queries) {
                    Stopwatch sw = Stopwatch.createStarted();
                    CompletableFuture<LLMClient.SimpleEmbeddingResponse> future = this.embed(q, settings);
                    accountingStages.add(future.thenAccept(response -> {
                        response.estimatedCost = ((HuggingFaceLocalConnection.HFLocalModel)modelHandle.getModel()).getEstimatedEmbeddingCost(response.promptTokens, q.hasImage() ? 1 : 0);
                        usageData.incrementTotalComputationTimeMS(Long.valueOf(sw.elapsed(TimeUnit.MILLISECONDS)));
                        response.includeInUsageData(usageData);
                    }));
                    futures.add(future);
                }
                List<LLMClient.SimpleEmbeddingResponse> responses = DKUCompletableFuture.collectResponsesNoException(futures);
                // Make sure cost and usage have been written before the responses are handed back.
                this.awaitUsageAccounting(futures, accountingStages);
                return responses;
            }

            @Override
            public LLMClient.ImageGenerationResponse generateImages(LLMClient.ImageGenerationQuery query) throws Exception {
                Stopwatch sw = Stopwatch.createStarted();
                CompletableFuture future = this.futureCancellationTracker.track(() -> HuggingFaceKernelPool.this.manager.handle(kernel -> kernel.asyncGenerateImages(query), (Object)kernelDesc, kernelDesc.descHash(), (Object)kernelDesc, kernelDesc.groupHash(), (Object)query));
                LLMClient.ImageGenerationResponse response = (LLMClient.ImageGenerationResponse)DKUCompletableFuture.collectResponse((CompletableFuture)future);
                response.estimatedCost = ((HuggingFaceLocalConnection.HFLocalModel)modelHandle.getModel()).getEstimatedImageGenerationCost(query);
                usageData.incrementTotalComputationTimeMS(Long.valueOf(sw.elapsed(TimeUnit.MILLISECONDS)));
                usageData.incrementEstimatedCostUSD(Double.valueOf(response.estimatedCost));
                return response;
            }

            @Override
            public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
                ComputeResourceUsage cru = new ComputeResourceUsage();
                cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
                cru.llmUsage.setFromInternal(usageData);
                return cru;
            }

            @Override
            public EnrichedLLMStructuredRef getEnrichedRef() {
                return kernelDesc.modelHandle.getEnrichedRef();
            }

            @Override
            public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
                return HuggingFaceLocalClient.getFormattedPrompt(chatMessages, kernelDesc.modelHandle.getModel().handlingMode);
            }
        };
    }

    /**
     * Resolves the effective container selection: the model's own selection wins unless it is
     * {@code INHERIT}, in which case the connection's selection is used; when both are
     * {@code INHERIT} a fresh {@code NONE} selection is returned.
     */
    private static ContainerExecSelection getContainerExecSelection(ContainerExecSelection modelContainerExecSelection, ContainerExecSelection connectionContainerExecSelection) {
        if (modelContainerExecSelection.containerMode != ContainerExecSelection.ContainerExecMode.INHERIT) {
            return modelContainerExecSelection;
        }
        if (connectionContainerExecSelection.containerMode != ContainerExecSelection.ContainerExecMode.INHERIT) {
            return connectionContainerExecSelection;
        }
        return new ContainerExecSelection(ContainerExecSelection.ContainerExecMode.NONE);
    }

    /**
     * Tells whether the fake LLM server is requested, either through the {@code FAKE_LLM_SERVER}
     * environment variable being exactly {@code "1"} or through the
     * {@code dku.llm.hf.fakeLLMServer} boolean param (default {@code false}).
     */
    private boolean useFakeLLMServer() {
        // Both sources are always read, in this order, before deciding.
        boolean requestedByEnv = "1".equals(System.getenv("FAKE_LLM_SERVER"));
        boolean requestedByParam = ApplicationConfigurator.getParams().getBoolParam("dku.llm.hf.fakeLLMServer", false);
        if (requestedByEnv) {
            return true;
        }
        return requestedByParam;
    }

    /**
     * Assembles the kernel configuration (code env, container conf, cluster, start command and extra
     * environment variables) for serving the given model through the given connection.
     *
     * @throws IOException declared by the original signature; the code-env existence check is
     *         performed before the start command is built
     */
    private HuggingFaceKernelClient.KernelConfig getKernelConfig(HuggingFaceLocalConnection connection, LLMModelHandle<HuggingFaceLocalConnection.HFLocalModel> modelHandle) throws IOException {
        HuggingFaceLocalConnection.HFLocalModel model = modelHandle.getModel();

        // A container conf name is only set when the effective selection names an explicit container.
        ContainerExecSelection effectiveSelection = HuggingFaceKernelPool.getContainerExecSelection(model.containerSelection, connection.params.containerSelection);
        String containerConfName = effectiveSelection.containerMode == ContainerExecSelection.ContainerExecMode.EXPLICIT_CONTAINER ? effectiveSelection.containerConf : null;

        HashMap<String, String> extraEnv = new HashMap<String, String>();
        String cudaDevices = model.cudaVisibleDevices;
        if (!StringUtils.isBlank((String)cudaDevices)) {
            extraEnv.put("CUDA_VISIBLE_DEVICES", cudaDevices);
        }

        // Fall back to the builtin/default Kubernetes cluster when the connection does not name one.
        String clusterId = connection.params.clusterId;
        if (clusterId == null) {
            clusterId = new ClusterSelector().getBuiltinOrDefaultClusterId(Cluster.ClusterArchitecture.KUBERNETES);
        }

        String envName = connection.params.getCodeEnvName();
        this.codeEnvResolutionService.checkEnvExists(CodeEnvModel.EnvLang.PYTHON, envName);

        HuggingFaceKernelClient.StartCommand startCommand = new HuggingFaceKernelClient.StartCommand(model.handlingMode, connection, model.inferenceSettings, HuggingFaceLocalClient.getBatchSize(model, AbstractLLMConnection.QueryType.fromModel(model)), model.getModelCapabilities().supportsImageInputs, this.useFakeLLMServer());
        if (!model.isDKUFineTuned) {
            startCommand.setHuggingFaceModelInfo(model.huggingFaceId);
        } else {
            // Fine-tuned models are addressed through their saved-model coordinates and model folder.
            FullModelId fullModelId = FullModelId.parse(model.getId());
            startCommand.setSavedModelInfo(fullModelId.getSavedModelProjectKey(), fullModelId.getSavedModelID(), fullModelId.getModelFolder().getAbsolutePath(), model.baseModelId);
        }
        return new HuggingFaceKernelClient.KernelConfig(connection.name, envName, containerConfName, clusterId, startCommand, extraEnv);
    }

    /**
     * Builds the kernel description for a model: its kernel configuration, the short hash of that
     * configuration, a compute-resource-usage context keyed by project and encoded LLM id, and the
     * job context current at call time.
     */
    private KernelDesc getKernelDesc(AuthCtx authCtx, HuggingFaceLocalConnection connection, String projectKey, LLMModelHandle<HuggingFaceLocalConnection.HFLocalModel> modelHandle) throws IOException {
        HuggingFaceKernelClient.KernelConfig kernelConfig = this.getKernelConfig(connection, modelHandle);
        ComputeResourceUsageContext usageContext = ComputeResourceUsageContext.forHFModel((String)projectKey, (String)modelHandle.getRef().encodeToId());
        String configHash = kernelConfig.toShortHash();
        JobContext currentJob = JobContext.getCurrentJobContext();
        return new KernelDesc(authCtx, projectKey, connection, modelHandle, kernelConfig, configHash, usageContext, currentJob);
    }

    /**
     * Returns the log tail of the kernel with the given id, or a new empty {@link SmartLogTail}
     * (after logging a warning) when the pool has no logs for that id.
     */
    public SmartLogTail getKernelLogs(String kernelId) {
        return this.manager.getKernelLogs(kernelId).orElseGet(() -> HuggingFaceKernelPool.emptyLogsForUnknownKernel(kernelId));
    }

    /** Warns that logs were requested for an unknown kernel id and supplies an empty log tail. */
    private static SmartLogTail emptyLogsForUnknownKernel(String kernelId) {
        String warning = "Requested logs for kernel ID '" + kernelId + "', but it doesn't exist.";
        logger.warn((Object)warning);
        return new SmartLogTail();
    }

    /**
     * Creates the pool and wires it to an anonymous {@code KernelController} that tells the generic
     * {@link KernelPool} how to create, start, kill and inspect HuggingFace kernels, which limits
     * apply, and which kernels must be kept running for reserved capacity.
     */
    public HuggingFaceKernelPool() {
        this.manager = new KernelPool((KernelPool.KernelController)new KernelPool.KernelController<HuggingFaceLocalClient, KernelGroup, KernelDesc>(){

            /**
             * Instantiates (without starting) a kernel client for the description. Kernels created
             * for reserved capacity do not use the description's auth context but one obtained from
             * {@code getAuthCtxForReservedCapacityKernels} for the kernel's cluster.
             */
            @Nonnull
            public HuggingFaceLocalClient createKernel(KernelDesc kernelDesc) {
                AuthCtx authCtx = kernelDesc.authCtx;
                if (kernelDesc.forReservedCapacity) {
                    authCtx = HuggingFaceKernelPool.this.getAuthCtxForReservedCapacityKernels(kernelDesc.kernelConfig.clusterId);
                }
                return new HuggingFaceLocalClient(authCtx, kernelDesc.projectKey, kernelDesc.connection, kernelDesc.modelHandle, HuggingFaceKernelPool.this.kernelBuilder, kernelDesc.kernelConfig, kernelDesc.forReservedCapacity);
            }

            /**
             * Starts the kernel on the pool's start/stop executor. The NDC logging context is tagged
             * with the kernel id for the duration of the start and always popped afterwards.
             */
            @Nonnull
            public CompletableFuture<Void> startKernel(HuggingFaceLocalClient kernel, KernelDesc kernelDesc) {
                return DKUCompletableFuture.runAsync(() -> {
                    NDC.push((String)("start-hf-kernel: " + kernel.getKernelId()));
                    try {
                        // Installs the usage/job context captured when the description was built.
                        DSSKernelUtils.setKernelContext(kernelDesc.cruContext, kernelDesc.jobContext, logger);
                        kernel.startKernel();
                    }
                    finally {
                        NDC.pop();
                    }
                }, (Executor)HuggingFaceKernelPool.this.executorService);
            }

            /** Reads {@code maxConcurrentKernels} from the HuggingFace local section of the general settings. */
            public int getGlobalMaxKernelCount() {
                return ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN().generativeAISettings.huggingFaceLocalSettings.maxConcurrentKernels;
            }

            /** Reads {@code kernelIdleTTLSeconds} from the HuggingFace local section of the general settings. */
            public int getAutoscaleTimeWindowSeconds() {
                return ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN().generativeAISettings.huggingFaceLocalSettings.kernelIdleTTLSeconds;
            }

            /** Reads the {@code dku.llm.hf.hardMaxRequestsPerKernel} param (default 256); the description is not consulted. */
            public int getHardMaxParallelRequests(KernelDesc kernelDesc) {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.hf.hardMaxRequestsPerKernel", Integer.valueOf(256));
            }

            /** Reads the {@code dku.llm.hf.softMaxRequestsPerKernel} param (default 64); the description is not consulted. */
            public int getSoftMaxParallelRequests(KernelDesc kernelDesc) {
                return ApplicationConfigurator.getParams().getIntParam("dku.llm.hf.softMaxRequestsPerKernel", Integer.valueOf(64));
            }

            /**
             * Closes the kernel on the pool's start/stop executor. Exceptions raised by
             * {@code close()} are logged and not propagated, so the returned future does not fail
             * because of them.
             */
            @Nonnull
            public CompletableFuture<Void> killKernel(HuggingFaceLocalClient kernel) {
                return CompletableFuture.runAsync(() -> {
                    NDC.push((String)("stop-hf-kernel: " + kernel.getKernelId()));
                    try {
                        kernel.close();
                    }
                    catch (Exception e) {
                        logger.error((Object)"Error while closing kernel", (Throwable)e);
                    }
                    finally {
                        NDC.pop();
                    }
                }, HuggingFaceKernelPool.this.executorService);
            }

            public boolean isAlive(HuggingFaceLocalClient kernel) {
                return kernel.isAlive();
            }

            public String getKernelId(HuggingFaceLocalClient kernel) {
                return kernel.getKernelId();
            }

            public SmartLogTail getKernelLog(HuggingFaceLocalClient kernel) {
                return kernel.getKernelLog();
            }

            public String getPodName(HuggingFaceLocalClient kernel) {
                return kernel.getPodName();
            }

            public String getModelId(HuggingFaceLocalClient kernel) {
                return kernel.getModelId();
            }

            /**
             * Lists the kernels that must be kept running for reserved capacity, keyed by description
             * hash, each with its minimum kernel count.
             *
             * <p>Always empty outside the backend process. Otherwise every HuggingFace local
             * connection with {@code enableReserveCapacity} set is scanned, and each enabled, valid
             * custom model with a strictly positive {@code minKernelCount} contributes one entry.
             * Failures are logged (deduplicated) and the affected model or connection is skipped.
             */
            @Nonnull
            public Map<String, Pair<KernelPool.KernelSpec<KernelGroup, KernelDesc>, Integer>> getKernelsWithMinCount() {
                if (ApplicationConfigurator.getProcessType() != MainLoggingConfigurator.ProcessType.BACKEND) {
                    return Collections.emptyMap();
                }
                HashMap<String, Pair<KernelPool.KernelSpec<KernelGroup, KernelDesc>, Integer>> kernelsWithMinCount = new HashMap<String, Pair<KernelPool.KernelSpec<KernelGroup, KernelDesc>, Integer>>();
                try {
                    ConnectionsDAO.get().listUnsafe().forEach((s, dssConnection) -> {
                        if (dssConnection instanceof HuggingFaceLocalConnection && ((HuggingFaceLocalConnection)dssConnection).params.enableReserveCapacity) {
                            // Work on a decrypted deep copy rather than on the listed connection object.
                            HuggingFaceLocalConnection hfConn = (HuggingFaceLocalConnection)JSON.deepCopy((Object)dssConnection);
                            hfConn.ensureDecrypted();
                            // Ids already seen in this connection: only the first model with a given id is considered.
                            HashSet alreadyListed = new HashSet();
                            hfConn.listRawCustomModels().forEach(customHFModel -> {
                                // Labeled block emitted by the decompiler: "break block7" simply ends processing of this model.
                                block7: {
                                    if (alreadyListed.contains(customHFModel.id)) {
                                        logger.infoV("%s overridden by another custom model in connection %s", new Object[]{customHFModel.id, hfConn.name});
                                    } else {
                                        alreadyListed.add(customHFModel.id);
                                        try {
                                            HuggingFaceLocalConnection.HFLocalModel loadedModel;
                                            Integer min = customHFModel.minKernelCount;
                                            // No reserved capacity requested for this model.
                                            if (min == null || min <= 0) break block7;
                                            try {
                                                loadedModel = hfConn.loadRawCustomModel((HuggingFaceLocalConnection.CustomHFLocalModel)customHFModel);
                                            }
                                            catch (Exception e) {
                                                logger.warn((Object)"Ignoring custom model due to underlying error during load", (Throwable)e);
                                                return;
                                            }
                                            if (loadedModel.isValid() && loadedModel.isEnabled()) {
                                                HuggingFaceLocalConnection huggingFaceLocalConnection = hfConn;
                                                Objects.requireNonNull(huggingFaceLocalConnection);
                                                AbstractLLMConnection.ConnectionModelHandle modelHandle = new AbstractLLMConnection.ConnectionModelHandle((AbstractLLMConnection)huggingFaceLocalConnection, (AbstractLLMConnection.BaseModel)loadedModel);
                                                // Reserved-capacity descriptions carry no auth context and use a placeholder project key;
                                                // createKernel substitutes a dedicated auth context for them.
                                                KernelDesc desc = HuggingFaceKernelPool.this.getKernelDesc(null, hfConn, "__DKU_ANY_PROJECT__", modelHandle);
                                                desc.forReservedCapacity = true;
                                                KernelPool.KernelSpec kernelSpec = new KernelPool.KernelSpec((Object)desc, desc.descHash(), (Object)desc, desc.groupHash());
                                                kernelsWithMinCount.put(desc.descHash(), new Pair((Object)kernelSpec, (Object)min));
                                            }
                                        }
                                        catch (Exception e) {
                                            HuggingFaceKernelPool.this.dedupError(String.format("Unable to collect reserved capacity for model %s in HF connection %s", customHFModel.id, hfConn.name), e);
                                        }
                                    }
                                }
                            });
                        }
                    });
                }
                catch (Exception e) {
                    // Best effort: whatever was collected before the failure is still returned.
                    HuggingFaceKernelPool.this.dedupError("Unable to collect reserved capacities for HF models, proceeding without them.", e);
                }
                return kernelsWithMinCount;
            }

            /**
             * Returns the {@code maxKernelCount} configured on the connection's custom model whose id
             * equals the group's model, or {@code null} when the group is a fine-tuned saved model,
             * when no custom model matches, or when the lookup fails (the failure is logged).
             */
            public Integer getMaxKernelCount(KernelGroup kernelGroup) {
                if (LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_HUGGINGFACE_TRANSFORMER.equals((Object)kernelGroup.modelRef.type)) {
                    return null;
                }
                try {
                    HuggingFaceLocalConnection hfConn = ConnectionsDAO.get().getMandatoryConnectionAs(kernelGroup.authCtx, kernelGroup.modelRef.connection, HuggingFaceLocalConnection.class);
                    for (HuggingFaceLocalConnection.CustomHFLocalModel customModel : hfConn.listRawCustomModels()) {
                        if (!Objects.equals(customModel.id, kernelGroup.modelRef.model)) continue;
                        return customModel.maxKernelCount;
                    }
                    return null;
                }
                catch (Exception e) {
                    HuggingFaceKernelPool.this.dedupError(String.format("Unable to collect max capacity for model %s in HF connection %s", kernelGroup.modelRef.model, kernelGroup.modelRef.connection), e);
                    return null;
                }
            }

            /**
             * Rebuilds the description from current settings and reports the kernel as outdated when
             * no description can be built any more or when its hash differs. If the rebuild itself
             * fails, the failure is logged and the kernel is reported as not outdated.
             */
            public boolean isOutdated(KernelDesc kernelDesc) {
                try {
                    KernelDesc updatedKernelDesc = HuggingFaceKernelPool.this.buildKernelDesc(kernelDesc);
                    return updatedKernelDesc == null || !updatedKernelDesc.descHash().equals(kernelDesc.descHash());
                }
                catch (Exception e) {
                    HuggingFaceKernelPool.this.dedupError(String.format("[%s] Failed to assess outdatedness of kernel", kernelDesc.descHash()), e);
                    return false;
                }
            }
        }, new KernelScalingStrategyBuilder(), new KernelPoolThreadFactory("hf"), logger);
    }

    /**
     * Builds the auth context used by reserved-capacity kernels: an in-memory API key named
     * {@code hf-kernel-pool} granted the {@code use} permission on the given cluster.
     *
     * @throws IllegalStateException when called outside the backend process
     * @throws RuntimeException wrapping the {@link IOException} raised while building the auth context
     */
    private DSSAuthCtx getAuthCtxForReservedCapacityKernels(String clusterId) {
        boolean inBackend = ApplicationConfigurator.getProcessType() == MainLoggingConfigurator.ProcessType.BACKEND;
        if (!inBackend) {
            throw new IllegalStateException("Reserved capacity is only allowed in the backend");
        }
        PublicAPIKeysService apiKeys = (PublicAPIKeysService)SpringUtils.getBean(PublicAPIKeysService.class);
        Cluster.PermissionItem useCluster = new Cluster.PermissionItem();
        useCluster.use = true;
        GlobalScopePublicAPIKey poolKey = apiKeys.getInMemoryAPIKey("hf-kernel-pool", "hf-kernel-pool", "", UsersDAO.GroupPermissions.baseGroupPermissionsForUnion(), Map.of(clusterId, useCluster));
        DSSAuthCtx poolAuthCtx;
        try {
            poolAuthCtx = DSSAuthCtx.forAPIKey(poolKey);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
        return poolAuthCtx;
    }

    /**
     * Rebuilds an up-to-date kernel description for the given group from the current connection and
     * saved-model state.
     *
     * @return the fresh description, or {@code null} when the model can no longer be resolved or is
     *         invalid or disabled
     */
    @Nullable
    private KernelDesc buildKernelDesc(KernelGroup kernelGroup) throws Exception {
        boolean fineTunedSavedModel = LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_HUGGINGFACE_TRANSFORMER.equals((Object)kernelGroup.modelRef.type);
        if (fineTunedSavedModel) {
            return this.buildFineTunedKernelDesc(kernelGroup);
        }
        return this.buildConnectionModelKernelDesc(kernelGroup);
    }

    /** Rebuilds the description of a fine-tuned saved model; {@code null} if the saved model, its version or its base model is gone or unusable. */
    @Nullable
    private KernelDesc buildFineTunedKernelDesc(KernelGroup kernelGroup) throws Exception {
        SavedModel savedModel;
        try (Transaction t = this.transactionService.retrieveOrBeginRead(IsolationLevel.YOLO);){
            AnyLoc savedModelLoc = AnyLoc.resolveSmart(kernelGroup.projectKey, kernelGroup.modelRef.savedModelSmartId);
            savedModel = (SavedModel)((SavedModelsDAO)SpringUtils.getBean(SavedModelsDAO.class)).getOrNull(savedModelLoc);
        }
        if (savedModel == null) {
            return null;
        }

        // First version whose id matches the one referenced by the group, if any.
        LLMSMMgmtService.LLMSMVersionHeader versionHeader = null;
        for (LLMSMMgmtService.LLMSMVersionHeader candidate : LLMSMMgmtService.getStatus_NT((SavedModel)savedModel).versions) {
            if (candidate.versionId.equals(kernelGroup.modelRef.savedModelVersionId)) {
                versionHeader = candidate;
                break;
            }
        }
        if (versionHeader == null) {
            return null;
        }

        HuggingFaceLocalConnection hfConn = ConnectionsDAO.get().getMandatoryConnectionAs(kernelGroup.authCtx, kernelGroup.modelRef.connection, HuggingFaceLocalConnection.class);
        hfConn.ensureDecrypted();
        HuggingFaceLocalConnection.HFLocalModel baseModel;
        try {
            baseModel = hfConn.getLLMModelFromSMInfo(((LLMModelSnippetData)versionHeader.snippet).llmSMInfo);
        }
        catch (IllegalArgumentException e) {
            return null;
        }
        SavedLLMModelHandle fineTunedHandle = new SavedLLMModelHandle(kernelGroup.modelRef, (LLMModelSnippetData)versionHeader.snippet, baseModel);
        boolean usable = ((HuggingFaceLocalConnection.HFLocalModel)fineTunedHandle.getModel()).isValid() && ((HuggingFaceLocalConnection.HFLocalModel)fineTunedHandle.getModel()).isEnabled();
        if (!usable) {
            return null;
        }
        return this.getKernelDesc(kernelGroup.authCtx, hfConn, kernelGroup.projectKey, fineTunedHandle);
    }

    /** Rebuilds the description of a model declared on the connection; {@code null} if it is gone, invalid or disabled. */
    @Nullable
    private KernelDesc buildConnectionModelKernelDesc(KernelGroup kernelGroup) throws Exception {
        HuggingFaceLocalConnection hfConn = ConnectionsDAO.get().getMandatoryConnectionAs(kernelGroup.authCtx, kernelGroup.modelRef.connection, HuggingFaceLocalConnection.class);
        hfConn.ensureDecrypted();
        AbstractLLMConnection.ConnectionModelHandle connectionHandle;
        try {
            connectionHandle = hfConn.getLLMModel(kernelGroup.modelRef);
        }
        catch (IllegalArgumentException e) {
            return null;
        }
        boolean usable = ((HuggingFaceLocalConnection.HFLocalModel)connectionHandle.getModel()).isValid() && ((HuggingFaceLocalConnection.HFLocalModel)connectionHandle.getModel()).isEnabled();
        if (!usable) {
            return null;
        }
        return this.getKernelDesc(kernelGroup.authCtx, hfConn, kernelGroup.projectKey, connectionHandle);
    }

    /** Delegates to the underlying kernel pool's {@code killAllRequests()}. */
    public void killAllRequests() {
        this.manager.killAllRequests();
    }

    /**
     * Delegates to the underlying kernel pool's {@code killAllKernels}.
     *
     * @param reason death reason passed through to the pool
     */
    public void killAllKernels(KernelPool.DeathReason reason) {
        this.manager.killAllKernels(reason);
    }

    /**
     * Delegates to the underlying kernel pool's {@code killKernel}.
     *
     * @param kernelId id of the kernel to kill
     * @param reason death reason passed through to the pool
     */
    public void killKernel(String kernelId, KernelPool.DeathReason reason) {
        this.manager.killKernel(kernelId, reason);
    }

    /**
     * Returns the underlying kernel pool's dump.
     *
     * @param full passed through to the pool's {@code dump}
     */
    public KernelPool.PoolDump dump(boolean full) {
        return this.manager.dump(full);
    }

    /**
     * Takes a full dump of the pool and reduces it to the kernels whose configuration names the
     * given connection.
     */
    public HFStatusDump dumpConnection(String connectionName) {
        return HFStatusDump.dumpFrom(this.manager.dump(true), connectionName);
    }

    /**
     * Logs an error unless an identical one is still present in the log cache. Two errors are
     * identical when they share the same message, exception class name and exception message.
     *
     * @param message the message to log
     * @param t the associated exception, or {@code null}
     */
    private void dedupError(String message, @Nullable Throwable t) {
        String errorType = null;
        String errorMessage = null;
        if (t != null) {
            errorType = t.getClass().getName();
            errorMessage = t.getMessage();
        }
        ImmutableTriple<String, String, String> cacheKey = new ImmutableTriple<String, String, String>(message, errorType, errorMessage);
        try {
            // The loader runs, and therefore logs, only when the key is not already cached.
            this.logCache.get(cacheKey, () -> {
                logger.error((Object)message, t);
                return true;
            });
        }
        catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Full description of a kernel: the group it belongs to plus everything needed to create and
     * start it (connection, model handle, kernel configuration and its hash, usage and job context).
     */
    static class KernelDesc
    extends KernelGroup {
        final HuggingFaceLocalConnection connection;
        final LLMModelHandle<HuggingFaceLocalConnection.HFLocalModel> modelHandle;
        final HuggingFaceKernelClient.KernelConfig kernelConfig;
        final String kernelConfigHash;
        final ComputeResourceUsageContext cruContext;
        final JobContext jobContext;

        /** The group's model reference is taken from {@code modelHandle.getRef()}. */
        KernelDesc(AuthCtx authCtx, String projectKey, HuggingFaceLocalConnection connection, LLMModelHandle<HuggingFaceLocalConnection.HFLocalModel> modelHandle, HuggingFaceKernelClient.KernelConfig kernelConfig, String kernelConfigHash, ComputeResourceUsageContext cruContext, JobContext jobContext) {
            super(authCtx, projectKey, modelHandle.getRef());
            this.jobContext = jobContext;
            this.cruContext = cruContext;
            this.kernelConfigHash = kernelConfigHash;
            this.kernelConfig = kernelConfig;
            this.modelHandle = modelHandle;
            this.connection = connection;
        }

        /** Identifies this exact kernel flavour: the group hash and the kernel config hash joined by {@code /}. */
        private String descHash() {
            String groupPart = groupHash();
            return groupPart + "/" + kernelConfigHash;
        }
    }

    /** Identifies the group a kernel belongs to: an auth context, a project key and a model reference. */
    static class KernelGroup {
        @Nullable
        final AuthCtx authCtx;
        final String projectKey;
        final LLMStructuredRef modelRef;
        /** Starts {@code false}; the pool sets it on descriptions built for reserved capacity. */
        protected boolean forReservedCapacity;

        KernelGroup(AuthCtx authCtx, String projectKey, LLMStructuredRef modelRef) {
            this.modelRef = modelRef;
            this.projectKey = projectKey;
            this.authCtx = authCtx;
        }

        /** The group hash is the id of the model reference. */
        public String groupHash() {
            return modelRef.id;
        }
    }

    /** Status of the pool restricted to one HuggingFace connection. */
    public static class HFStatusDump {
        public List<KernelDump> kernels = new ArrayList<KernelDump>();
        public int graveyardTimeoutInS;

        /**
         * Extracts from a pool dump the kernels whose kernel configuration names the given connection.
         *
         * @param dump pool dump to filter
         * @param connectionName name of the connection to keep; compared with {@code equals}
         * @return a new dump carrying the matching kernels and the pool's graveyard timeout
         */
        public static HFStatusDump dumpFrom(KernelPool.PoolDump dump, String connectionName) {
            HFStatusDump filtered = new HFStatusDump();
            filtered.graveyardTimeoutInS = KernelPool.getGraveyardTimeout();
            for (KernelPool.PoolDump.KernelDump source : dump.kernels) {
                String kernelConnection = ((KernelDesc)source.kernelDesc).kernelConfig.hfConnectionName;
                if (connectionName.equals(kernelConnection)) {
                    filtered.kernels.add(HFStatusDump.copyOf(source));
                }
            }
            return filtered;
        }

        /** Copies the reported fields of a pool-level kernel dump into a connection-level one. */
        private static KernelDump copyOf(KernelPool.PoolDump.KernelDump source) {
            KernelDump copy = new KernelDump();
            copy.id = source.id;
            copy.podName = source.podName;
            copy.modelId = source.modelId;
            copy.group = source.group;
            copy.state = source.state;
            copy.startingAtTime = source.startingAtTime;
            copy.readyAtTime = source.readyAtTime;
            copy.sentencedAtTime = source.sentencedAtTime;
            copy.diedAtTime = source.diedAtTime;
            copy.nbActiveRequests = source.nbActiveRequests;
            copy.deathReason = source.deathReason;
            copy.deathError = source.deathError;
            return copy;
        }

        /** Per-kernel entry of an {@link HFStatusDump}. */
        public static class KernelDump {
            public String id;
            public String podName;
            public String modelId;
            public String group;
            public KernelPool.KernelState state;
            public Instant startingAtTime;
            public Instant readyAtTime;
            public Instant sentencedAtTime;
            public Instant diedAtTime;
            public long nbActiveRequests;
            public KernelPool.DeathReason deathReason;
            public String deathError;
        }
    }
}

