/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.local;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.io.KubernetesSimplePythonKernel;
import com.dataiku.dip.io.SimplePythonKernel;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.io.PythonRequestUtils;
import com.dataiku.dip.llm.io.commands.ProcessSingleEmbeddingCommand;
import com.dataiku.dip.llm.io.commands.ProcessSingleImageGenerationCommand;
import com.dataiku.dip.llm.io.commands.ProcessSinglePromptCommand;
import com.dataiku.dip.llm.io.commands.ProcessSingleRerankingCommand;
import com.dataiku.dip.llm.local.HuggingFaceKernelBuilder;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.reports.IReflectedEventsService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.DKUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.SmartLogTail;
import com.dataiku.dip.utils.polyjson.Mapping;
import com.dataiku.dip.utils.polyjson.PolyJSON;
import com.dataiku.dss.shadelib.org.apache.commons.codec.digest.DigestUtils;
import com.dataiku.dss.shadelib.org.apache.commons.lang3.exception.ExceptionUtils;
import com.dataiku.j2py.annotations.PyModel;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.annotation.Nullable;

public class HuggingFaceKernelClient
implements AutoCloseable {
    /** vLLM version string shared within the package. */
    static final String VLLM_VERSION = "0.14.1";
    /** Creates and starts the Python kernel (see {@link #startKernel()}). */
    private final HuggingFaceKernelBuilder kernelBuilder = new HuggingFaceKernelBuilder();
    /** Model obtained from the handle passed to the constructor. */
    private final HuggingFaceLocalConnection.HFLocalModel model;
    /** Auth context forwarded to the kernel builder when creating the kernel. */
    private final AuthCtx authCtx;
    /** Project key forwarded to the kernel builder when creating the kernel. */
    private final String projectKey;
    /** Kernel configuration forwarded to the kernel builder for creation and start. */
    private final KernelConfig kernelConfig;
    /** Identifier of this client's kernel: {@code "llm-hf-"} followed by a generated suffix. */
    private final String kernelId;
    /** Single-thread executor on which {@link #startKernel()} runs kernel creation and start; shut down by {@link #close()}. */
    private final ExecutorService startThread = Executors.newSingleThreadExecutor();
    /** Reported as the {@code reserved_capacity} property of the "hf-kernel-started" event. */
    private final boolean forReservedCapacity;
    /** The kernel; stays {@code null} until the kernel builder has successfully started it. */
    private SimplePythonKernel kernel;
    /** Log tail builder; starts as a fresh builder and is replaced by the kernel's own builder once the kernel is created. */
    private DKUtils.SmartLogTailBuilder smartLogTailBuilder = new DKUtils.SmartLogTailBuilder();
    /** Pod name, assigned only when the created kernel is a {@link KubernetesSimplePythonKernel}; {@code null} otherwise. */
    private String podName;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.huggingface.kernel.client");

    /**
     * Builds a client for the model behind {@code modelHandle}.
     * <p>
     * Only stores its arguments and generates the kernel id; the kernel itself is
     * neither created nor started until {@link #startKernel()} is called.
     *
     * @param authCtx             auth context later passed to the kernel builder
     * @param projectKey          project key later passed to the kernel builder
     * @param modelHandle         handle from which the model is read via {@code getModel()}
     * @param kernelConfig        configuration later passed to the kernel builder
     * @param forReservedCapacity flag reported in the kernel-started tracking event
     */
    public HuggingFaceKernelClient(AuthCtx authCtx, String projectKey, LLMModelHandle<HuggingFaceLocalConnection.HFLocalModel> modelHandle, KernelConfig kernelConfig, boolean forReservedCapacity) {
        // Computed values first (same relative order as before), plain copies after.
        this.kernelId = "llm-hf-" + SecretKeyGenerator.generateSmall();
        this.model = modelHandle.getModel();

        this.forReservedCapacity = forReservedCapacity;
        this.kernelConfig = kernelConfig;
        this.projectKey = projectKey;
        this.authCtx = authCtx;
    }

    /**
     * Creates and starts the kernel on the dedicated start thread, blocking until it is done.
     * <p>
     * On success, tracking data is collected from the kernel and published as an
     * {@code "hf-kernel-started"} reflected event. On failure, the root-cause stack trace is
     * appended to the kernel log tail, this client is closed, and an {@link IOException} is thrown.
     *
     * @throws IOException if kernel creation/start failed or the calling thread was interrupted while
     *                     waiting; the cause is the task's exception (or the {@link InterruptedException}),
     *                     and any failure raised while closing the client is attached as suppressed
     */
    public void startKernel() throws Exception {
        try {
            this.startThread.submit(() -> {
                SimplePythonKernel tempLocalKernel = this.kernelBuilder.createKernel(this.authCtx, this.projectKey, this.kernelConfig, this.kernelId);
                // Expose the kernel's log tail (and pod name) before starting, so they are available even if start fails.
                this.smartLogTailBuilder = tempLocalKernel.getSmartLogTailBuilder();
                this.smartLogTailBuilder.setMaxLines(500);
                if (tempLocalKernel instanceof KubernetesSimplePythonKernel) {
                    KubernetesSimplePythonKernel kubeKernel = (KubernetesSimplePythonKernel)tempLocalKernel;
                    this.podName = kubeKernel.getPodName();
                }
                // NOTE(review): if this call throws, this.kernel is never assigned and close() will not
                // close tempLocalKernel -- confirm the builder cleans up after a failed start.
                this.kernelBuilder.startKernel(tempLocalKernel, this.kernelConfig);
                this.kernel = tempLocalKernel;
                JsonObject trackingData = this.collectTrackingData();
                trackingData.addProperty("exec", this.kernel.getType());
                trackingData.addProperty("reserved_capacity", Boolean.valueOf(this.forReservedCapacity));
                IReflectedEventsService.ReflectedEvent event = new IReflectedEventsService.ReflectedEvent("hf-kernel-started", trackingData);
                ((IReflectedEventsService)SpringUtils.getBean(IReflectedEventsService.class)).publish(event);
                return null;
            }).get();
        }
        catch (InterruptedException | ExecutionException e) {
            for (String s : ExceptionUtils.getRootCauseStackTrace((Throwable)e)) {
                this.smartLogTailBuilder.appendLine(s);
            }
            IOException failure = new IOException("Failed to start HuggingFace LLM", e instanceof ExecutionException ? e.getCause() : e);
            boolean interrupted = e instanceof InterruptedException;
            try {
                this.close();
            }
            catch (Exception closeFailure) {
                // Do not let a cleanup problem hide the reason the start failed.
                interrupted |= closeFailure instanceof InterruptedException;
                failure.addSuppressed(closeFailure);
            }
            finally {
                // Restore the interrupt flag only after close(): setting it earlier would make
                // close()'s awaitTermination fail immediately.
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
            throw failure;
        }
    }

    /**
     * Stops the start thread and closes the kernel, if one was started.
     * <p>
     * The start thread is asked to stop immediately, then awaited for up to
     * {@code dku.hflocalclient.startThreadTimeout} seconds (default 120); a timeout is only logged.
     * The kernel is closed in a {@code finally} block so that it is released even when waiting for
     * the start thread is interrupted or otherwise fails. Errors from closing the kernel are logged
     * and not rethrown.
     *
     * @throws Exception if shutting down or awaiting the start thread fails
     *                   (e.g. {@link InterruptedException}); the kernel has been closed by then
     */
    @Override
    public void close() throws Exception {
        try {
            this.startThread.shutdownNow();
            if (!this.startThread.awaitTermination(DKUApp.getParams().getIntParam("dku.hflocalclient.startThreadTimeout", Integer.valueOf(120)), TimeUnit.SECONDS)) {
                logger.warn((Object)"Timeout while waiting for kernel starting thread to terminate, closing HF local client anyway but resources may have not yet been freed");
            }
        }
        finally {
            if (this.kernel != null) {
                try {
                    this.kernel.close();
                }
                catch (Throwable e) {
                    // Best effort: a kernel that fails to close must not abort the rest of the shutdown.
                    logger.error((Object)"Failed to kill kernel", e);
                }
            }
        }
    }

    /**
     * Tells whether a kernel has been started and reports itself alive.
     *
     * @return {@code false} when no kernel is set, otherwise the kernel's own {@code isAlive()} result
     */
    public boolean isAlive() {
        if (this.kernel == null) {
            return false;
        }
        return this.kernel.isAlive();
    }

    /**
     * @return the kernel identifier generated when this client was constructed
     */
    public String getKernelId() {
        return kernelId;
    }

    /**
     * @return the id reported by the model this client was created for
     */
    public String getModelId() {
        HuggingFaceLocalConnection.HFLocalModel servedModel = this.model;
        return servedModel.getId();
    }

    /**
     * @return the pod name recorded during kernel start, or {@code null} if none was recorded
     *         (it is only set for Kubernetes kernels)
     */
    public String getPodName() {
        return podName;
    }

    /**
     * @return the log tail produced by the current log tail builder
     */
    public SmartLogTail getKernelLog() {
        DKUtils.SmartLogTailBuilder currentBuilder = this.smartLogTailBuilder;
        return currentBuilder.get();
    }

    /**
     * Issues a streamed completion request over the kernel's async link.
     * Delegates to {@code PythonRequestUtils.asyncStreamRequest}.
     *
     * @param query    the completion query
     * @param settings the completion settings
     * @param consumer receiver handed to the streaming helper
     * @return the future returned by the streaming helper
     */
    public CompletableFuture<Integer> asyncStreamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) {
        SimplePythonKernel activeKernel = this.kernel;
        return PythonRequestUtils.asyncStreamRequest(activeKernel.getAsyncLink(), query, settings, consumer);
    }

    /**
     * Sends a {@link CollectTrackingData} command to the kernel and returns the tracking data of the reply.
     *
     * @return the {@code trackingData} object of the kernel's response
     */
    public JsonObject collectTrackingData() {
        CollectTrackingData request = new CollectTrackingData();
        CollectTrackingDataResponse reply = (CollectTrackingDataResponse)this.kernel.getAsyncLink().request((Object)request, CollectTrackingDataResponse.class);
        return reply.trackingData;
    }

    /**
     * Sends a {@link GetUsedEngine} command to the kernel and returns the engine named in the reply.
     *
     * @return the {@code usedEngine} value of the kernel's response
     */
    public String getUsedEngine() {
        GetUsedEngine request = new GetUsedEngine();
        GetUsedEngineResponse reply = (GetUsedEngineResponse)this.kernel.getAsyncLink().request((Object)request, GetUsedEngineResponse.class);
        return reply.usedEngine;
    }

    /**
     * Sends a single completion request to the kernel asynchronously.
     * The kernel's reply is converted with {@code ProcessSinglePromptResponse.toSimpleCompletionResponse()}.
     *
     * @param query    the completion query
     * @param settings the completion settings
     * @return a future of the converted completion response
     */
    public CompletableFuture<LLMClient.SimpleCompletionResponse> asyncComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
        SimplePythonKernel activeKernel = this.kernel;
        return activeKernel.getAsyncLink().asyncSendRequest(
                (Object)new ProcessSinglePromptCommand(query, settings, false),
                ProcessSinglePromptResponse.class,
                ProcessSinglePromptResponse::toSimpleCompletionResponse);
    }

    /**
     * Runs several completion queries with the same settings and waits for all of them.
     * <p>
     * Each query is submitted through {@link #asyncComplete}; the resulting futures are gathered
     * with {@code DKUCompletableFuture.collectResponses}.
     *
     * @param queries  the queries to run
     * @param settings the settings applied to every query
     * @return the responses gathered from the submitted futures
     * @throws Exception if gathering the responses fails
     */
    public List<LLMClient.SimpleCompletionResponse> complete(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws Exception {
        // Fully typed list: avoids the raw-type/unchecked conversion of the previous version.
        List<CompletableFuture<LLMClient.SimpleCompletionResponse>> futures = queries.stream()
                .map(query -> this.asyncComplete(query, settings))
                .collect(Collectors.toList());
        return DKUCompletableFuture.collectResponses(futures);
    }

    /**
     * Sends a single embedding request to the kernel asynchronously.
     * The request is issued as a streamed request; the future is built from its {@code last()} element.
     *
     * @param query    the embedding query
     * @param settings the embedding settings
     * @return a future of the embedding response
     */
    public CompletableFuture<LLMClient.SimpleEmbeddingResponse> asyncEmbed(LLMClient.EmbeddingQuery query, LLMClient.EmbeddingSettings settings) {
        ProcessSingleEmbeddingCommand embeddingCommand = new ProcessSingleEmbeddingCommand(query, settings);
        SimplePythonKernel activeKernel = this.kernel;
        return activeKernel.getAsyncLink()
                .asyncStreamRequest((Object)embeddingCommand, LLMClient.SimpleEmbeddingResponse.class)
                .last()
                .toFuture();
    }

    /**
     * Runs several embedding queries with the same settings and waits for all of them.
     * <p>
     * Each query is submitted through {@link #asyncEmbed}; the resulting futures are gathered
     * with {@code DKUCompletableFuture.collectResponses}.
     *
     * @param queries  the queries to embed
     * @param settings the settings applied to every query
     * @return the responses gathered from the submitted futures
     * @throws Exception if gathering the responses fails
     */
    public List<LLMClient.SimpleEmbeddingResponse> embed(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws Exception {
        // Fully typed list: avoids the raw-type/unchecked conversion of the previous version.
        List<CompletableFuture<LLMClient.SimpleEmbeddingResponse>> futures = queries.stream()
                .map(query -> this.asyncEmbed(query, settings))
                .collect(Collectors.toList());
        return DKUCompletableFuture.collectResponses(futures);
    }

    /**
     * Sends a single reranking request to the kernel asynchronously.
     * The command is built from {@code query} alone; the kernel's reply is returned as-is.
     *
     * @param query    the reranking query
     * @param settings not used by this method
     * @return a future of the reranking response
     */
    public CompletableFuture<LLMClient.SingleRerankingResponse> asyncRerank(LLMClient.RerankingQuery query, LLMClient.RerankingSettings settings) {
        ProcessSingleRerankingCommand rerankCommand = ProcessSingleRerankingCommand.fromRerankingQuery(query);
        SimplePythonKernel activeKernel = this.kernel;
        return activeKernel.getAsyncLink().asyncSendRequest(
                (Object)rerankCommand,
                LLMClient.SingleRerankingResponseOrError.class,
                reply -> reply);
    }

    /**
     * Sends a single image generation request to the kernel asynchronously.
     * The kernel's reply is converted with {@code ImageGenerationKernelResponse.toPublicApiResponse()}.
     *
     * @param query the image generation query
     * @return a future of the converted image generation response
     */
    public CompletableFuture<LLMClient.ImageGenerationResponse> asyncGenerateImages(LLMClient.ImageGenerationQuery query) {
        ProcessSingleImageGenerationCommand imageCommand = new ProcessSingleImageGenerationCommand(query);
        SimplePythonKernel activeKernel = this.kernel;
        return activeKernel.getAsyncLink().asyncSendRequest(
                (Object)imageCommand,
                ImageGenerationKernelResponse.class,
                ImageGenerationKernelResponse::toPublicApiResponse);
    }

    /**
     * Blocking variant of {@link #asyncGenerateImages}: submits the query and waits for the result.
     *
     * @param query the image generation query
     * @return the image generation response
     * @throws Exception if waiting on the future fails
     */
    public LLMClient.ImageGenerationResponse generateImages(LLMClient.ImageGenerationQuery query) throws Exception {
        CompletableFuture<LLMClient.ImageGenerationResponse> pending = this.asyncGenerateImages(query);
        return pending.get();
    }

    /**
     * Immutable description of how a kernel is to be created and started.
     * <p>
     * The whole object is serialized by {@link #toShortHash()}, so the set and declaration order of
     * the fields is deliberately left untouched (presumably both influence the serialized form).
     */
    public static class KernelConfig {
        public final String hfConnectionName;
        public final String codeEnvName;
        public final String containerConfName;
        public final String clusterId;
        @Nullable
        public final Map<String, String> extraEnv;
        public final StartCommand startCommand;

        /**
         * Stores every argument in the field of the same name.
         *
         * @param hfConnectionName  name of the HuggingFace connection
         * @param codeEnvName       name of the code env
         * @param containerConfName name of the container configuration
         * @param clusterId         id of the cluster
         * @param startCommand      start command for the kernel
         * @param extraEnv          extra environment entries, may be {@code null}
         */
        public KernelConfig(String hfConnectionName, String codeEnvName, String containerConfName, String clusterId, StartCommand startCommand, @Nullable Map<String, String> extraEnv) {
            this.startCommand = startCommand;
            this.extraEnv = extraEnv;
            this.clusterId = clusterId;
            this.containerConfName = containerConfName;
            this.codeEnvName = codeEnvName;
            this.hfConnectionName = hfConnectionName;
        }

        /**
         * Builds a short identifier for this configuration.
         *
         * @return the first 10 hex characters of the SHA-256 of this object's JSON form,
         *         followed by {@code "/"} and the start command's {@code hfModelName}
         */
        public String toShortHash() {
            String serialized = (String)JSON.json((Object)this);
            String digestPrefix = DigestUtils.sha256Hex((String)serialized).substring(0, 10);
            return digestPrefix + "/" + this.startCommand.hfModelName;
        }
    }

    /**
     * Kernel command with no payload, sent by {@code collectTrackingData()}.
     * Mapped to the type {@code "collect-env"} in the {@code HuggingFaceKernelCommand} PolyJSON declaration.
     */
    @PyModel
    public static class CollectTrackingData
    extends HuggingFaceKernelCommand {
    }

    /** Reply type requested for a {@code CollectTrackingData} command. */
    @PyModel
    public static class CollectTrackingDataResponse {
        /** Tracking data returned by the kernel; handed back to callers of {@code collectTrackingData()}. */
        public JsonObject trackingData;
    }

    /**
     * Kernel command with no payload, sent by {@code getUsedEngine()}.
     * Mapped to the type {@code "get-used-engine"} in the {@code HuggingFaceKernelCommand} PolyJSON declaration.
     */
    @PyModel
    public static class GetUsedEngine
    extends HuggingFaceKernelCommand {
    }

    /** Reply type requested for a {@code GetUsedEngine} command. */
    @PyModel
    public static class GetUsedEngineResponse {
        /** Engine reported by the kernel; handed back to callers of {@code getUsedEngine()}. */
        public String usedEngine;
    }

    /**
     * Kernel reply to a single prompt command. Every field is optional; only the ones that are
     * present are copied into the public completion response.
     */
    @PyModel
    private static class ProcessSinglePromptResponse {
        @Nullable
        String text;
        @Nullable
        ZeroShotClassificationResponse classification;
        @Nullable
        UsageData usage;
        @Nullable
        List<LLMClient.DetailedLogProb> logProbs;
        @Nullable
        LLMClient.FinishReason finishReason;
        @Nullable
        List<LLMClient.Artifact> artifacts;
        @Nullable
        public List<LLMClient.AbstractToolCall> toolCalls;

        private ProcessSinglePromptResponse() {
        }

        /**
         * Converts this kernel reply into a {@code SimpleCompletionResponse}.
         * <ul>
         *   <li>usage: prompt and completion token counts are copied, the total is their sum;</li>
         *   <li>log probs are copied only together with a non-null text;</li>
         *   <li>classification: the first label becomes the predicted class and one entry per label
         *       is added to the class probabilities (with a score only when one exists at that index);
         *       when labels are null or empty, the text is set to {@code "Missing labels"}.</li>
         * </ul>
         *
         * @return the populated response
         */
        private LLMClient.SimpleCompletionResponse toSimpleCompletionResponse() {
            LLMClient.SimpleCompletionResponse converted = new LLMClient.SimpleCompletionResponse();

            if (this.usage != null) {
                int promptCount = this.usage.promptTokens;
                int completionCount = this.usage.completionTokens;
                converted.promptTokens = promptCount;
                converted.completionTokens = completionCount;
                converted.totalTokens = promptCount + completionCount;
            }
            if (this.finishReason != null) {
                converted.finishReason = this.finishReason;
            }
            if (this.text != null) {
                converted.text = this.text;
                converted.logProbs = this.logProbs;
            }

            // Must run after the text block: a classification without labels overwrites the text.
            if (this.classification != null) {
                List<String> labels = this.classification.labels;
                List<Double> scores = this.classification.scores;
                if (labels == null || labels.isEmpty()) {
                    converted.text = "Missing labels";
                } else {
                    converted.predictedClass = labels.get(0);
                    converted.predictedClassProbas = new ArrayList<LLMClient.PredictedClassProba>();
                    int position = 0;
                    for (String label : labels) {
                        LLMClient.PredictedClassProba entry = new LLMClient.PredictedClassProba();
                        entry.className = label;
                        // Scores may be absent or shorter than labels; such entries keep their default proba.
                        if (scores != null && position < scores.size()) {
                            entry.proba = scores.get(position);
                        }
                        converted.predictedClassProbas.add(entry);
                        ++position;
                    }
                }
            }

            if (this.toolCalls != null) {
                converted.toolCalls = this.toolCalls;
            }
            if (this.artifacts != null) {
                converted.artifacts = this.artifacts;
            }
            return converted;
        }
    }

    /** Kernel reply to an image generation command: a list of images carried as strings. */
    @PyModel
    private static class ImageGenerationKernelResponse {
        List<String> images = new ArrayList<String>();

        private ImageGenerationKernelResponse() {
        }

        /**
         * Converts this kernel reply into the public response type, wrapping each image string
         * in an {@code ImageGenerationImage} and keeping the original order.
         *
         * @return the public image generation response
         */
        private LLMClient.ImageGenerationResponse toPublicApiResponse() {
            LLMClient.ImageGenerationResponse result = new LLMClient.ImageGenerationResponse();
            for (String encodedImage : this.images) {
                LLMClient.ImageGenerationImage wrapped = new LLMClient.ImageGenerationImage(encodedImage);
                result.images.add(wrapped);
            }
            return result;
        }
    }

    /** Token usage figures carried in the {@code usage} field of {@code ProcessSinglePromptResponse}. */
    private static class UsageData {
        /** Copied to the completion response's {@code promptTokens}. */
        int promptTokens;
        /** Copied to the completion response's {@code completionTokens}. */
        int completionTokens;

        private UsageData() {
        }
    }

    /** Classification result carried in the {@code classification} field of {@code ProcessSinglePromptResponse}. */
    private static class ZeroShotClassificationResponse {
        /** Class labels; the first one is used as the predicted class. May be null or empty. */
        public List<String> labels;
        /** Scores matched to {@link #labels} by index; may be null or shorter than the label list. */
        public List<Double> scores;

        private ZeroShotClassificationResponse() {
        }
    }

    /**
     * Kernel command describing which model to load and how.
     * Mapped to the type {@code "start"} in the {@code HuggingFaceKernelCommand} PolyJSON declaration.
     * <p>
     * Field declaration order is left untouched because this object is serialized as part of
     * {@code KernelConfig.toShortHash()}.
     */
    @PyModel
    public static class StartCommand
    extends HuggingFaceKernelCommand {
        // Reference the shared constant instead of repeating the literal, so the two cannot drift apart.
        public final String vllmVersion = VLLM_VERSION;
        /** API key taken from the connection parameters; may be {@code null}. */
        @Nullable
        public String hfApiKey;
        /** Taken from the connection parameters. */
        public boolean useDSSModelCache;
        /** Set by {@link #setSavedModelInfo} or {@link #setHuggingFaceModelInfo}. */
        public ModelOrigin modelOrigin;
        /** Set by {@link #setHuggingFaceModelInfo}. */
        public String hfModelName;
        /** Set by {@link #setModelPath}. */
        public String hfModelPath;
        /** Set by {@link #setSavedModelInfo}. */
        public String savedModelVersionPath;
        /** Set by {@link #setSavedModelInfo}. */
        public String savedModelProjectKey;
        /** Set by {@link #setSavedModelInfo}. */
        public String savedModelId;
        /** Set by {@link #setSavedModelInfo}. */
        public String baseModelName;
        /** Set by {@link #setModelPath}. */
        public String baseModelPath;
        public HuggingFaceLocalConnection.HuggingFaceHandlingMode hfHandlingMode;
        public HuggingFaceLocalConnection.InferenceSettings modelSettings;
        public boolean supportsImageInputs;
        public Integer batchSize;
        boolean fakeLLMServer;

        /**
         * Creates a start command; model origin, names and paths are filled in afterwards through the setters.
         *
         * @param hfHandlingMode      handling mode, stored as-is
         * @param connection          connection whose {@code params.apiKey} and
         *                            {@code params.useDSSModelCache} are copied
         * @param modelSettings       inference settings, stored as-is
         * @param batchSize           batch size, stored as-is
         * @param supportsImageInputs stored as-is
         * @param fakeLLMServer       stored as-is
         */
        public StartCommand(HuggingFaceLocalConnection.HuggingFaceHandlingMode hfHandlingMode, HuggingFaceLocalConnection connection, HuggingFaceLocalConnection.InferenceSettings modelSettings, Integer batchSize, boolean supportsImageInputs, boolean fakeLLMServer) {
            this.hfHandlingMode = hfHandlingMode;
            this.hfApiKey = connection.params.apiKey;
            this.useDSSModelCache = connection.params.useDSSModelCache;
            this.modelSettings = modelSettings;
            this.batchSize = batchSize;
            this.supportsImageInputs = supportsImageInputs;
            this.fakeLLMServer = fakeLLMServer;
        }

        /**
         * Records the model path and the base model path.
         *
         * @param modelPath     stored in {@code hfModelPath}
         * @param baseModelPath stored in {@code baseModelPath}
         */
        public void setModelPath(String modelPath, String baseModelPath) {
            this.hfModelPath = modelPath;
            this.baseModelPath = baseModelPath;
        }

        /**
         * Marks the model origin as {@link ModelOrigin#SAVED_MODEL_VERSION} and records the saved model details.
         *
         * @param projectKey       stored in {@code savedModelProjectKey}
         * @param id               stored in {@code savedModelId}
         * @param modelVersionPath stored in {@code savedModelVersionPath}
         * @param baseModelName    stored in {@code baseModelName}
         */
        public void setSavedModelInfo(String projectKey, String id, String modelVersionPath, String baseModelName) {
            this.modelOrigin = ModelOrigin.SAVED_MODEL_VERSION;
            this.savedModelProjectKey = projectKey;
            this.savedModelId = id;
            this.savedModelVersionPath = modelVersionPath;
            this.baseModelName = baseModelName;
        }

        /**
         * Marks the model origin as {@link ModelOrigin#HUGGINGFACE_MODEL} and records the model name.
         *
         * @param modelName stored in {@code hfModelName}
         */
        public void setHuggingFaceModelInfo(String modelName) {
            this.modelOrigin = ModelOrigin.HUGGINGFACE_MODEL;
            this.hfModelName = modelName;
        }

        // NOTE(review): private no-arg constructor, presumably for deserialization -- confirm.
        private StartCommand() {
        }
    }

    /**
     * Base type of the commands declared in this class. The PolyJSON mapping below assigns a type
     * name to each concrete command: {@code "start"}, {@code "collect-env"} and {@code "get-used-engine"}.
     */
    @PolyJSON(value={@Mapping(value=StartCommand.class, type="start"), @Mapping(value=CollectTrackingData.class, type="collect-env"), @Mapping(value=GetUsedEngine.class, type="get-used-engine")})
    @PyModel
    public static abstract class HuggingFaceKernelCommand {
    }

    /** Origin of the model named in a {@code StartCommand}. */
    public static enum ModelOrigin {
        // Set by StartCommand.setSavedModelInfo.
        SAVED_MODEL_VERSION,
        // Set by StartCommand.setHuggingFaceModelInfo.
        HUGGINGFACE_MODEL;

    }
}

