/*
 * Decompiled with CFR 0.152.
 */
package com.customllm.llm;

import com.customllm.llm.CustomApiImplementation;
import com.customllm.llm.GptApiImplementation;
import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.custom.PluginSettingsResolver;
import com.dataiku.dip.llm.custom.CustomLLMClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpGet;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpPost;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpRequestBase;
import com.dataiku.dss.shadelib.org.apache.http.impl.client.HttpClientBuilder;
import com.google.gson.JsonElement;
import java.io.IOException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.List;

/**
 * Dataiku custom LLM client for NVIDIA NIM endpoints.
 *
 * <p>Chat completions (blocking and streamed) are delegated to a {@link GptApiImplementation}
 * built for the configured endpoint. Embeddings are requested directly by posting a JSON
 * payload with {@code input}, {@code model} and {@code input_type} fields. Every HTTP request
 * goes through a single {@link ExternalJSONAPIClient}, which adds a JSON content type and,
 * when an API key is configured, a {@code Bearer} Authorization header.
 */
public class NvidiaNIMPlugin
extends CustomLLMClient {
    private String endpointUrl;
    private String model;
    private PluginSettingsResolver.ResolvedSettings rs;
    private String inputType;
    private ExternalJSONAPIClient client;
    // Guards usageData. embedBatch is not synchronized, and getTotalCRU reads the counters,
    // so updates and reads must share a lock. A dedicated lock (rather than `this`) keeps
    // getTotalCRU from waiting on a synchronized completeBatch that is doing network I/O.
    private final Object usageLock = new Object();
    private final ComputeResourceUsage.InternalLLMUsageData usageData = new ComputeResourceUsage.LLMUsageData();
    private final AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
    private int maxParallel = 1;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.customplugin.nvidia.nim");

    /**
     * Builds the Authorization header value from the {@code apikeys.api_key} plugin setting.
     *
     * @return {@code "Bearer <key>"} with surrounding whitespace trimmed from the key, or
     *         {@code null} when no non-blank key is configured
     */
    private String getAccessToken() {
        JsonElement apiKeysElement = this.rs.config.get("apikeys");
        if (apiKeysElement == null || !apiKeysElement.isJsonObject()) {
            return null;
        }
        JsonElement apiKeyElement = apiKeysElement.getAsJsonObject().get("api_key");
        if (apiKeyElement == null || apiKeyElement.isJsonNull()) {
            return null;
        }
        String apiKey = apiKeyElement.getAsString().trim();
        // A blank key would produce a malformed "Bearer " header; treat it as absent.
        return apiKey.isEmpty() ? null : "Bearer " + apiKey;
    }

    /**
     * Adds the JSON content type and, when an API key is configured, the Bearer token.
     * The Authorization header is skipped when there is no key, so that a header with
     * a null value is never sent.
     *
     * @param request the outgoing request to decorate
     */
    private void addJsonAndAuthHeaders(HttpRequestBase request) {
        request.addHeader("Content-Type", "application/json");
        String accessToken = this.getAccessToken();
        if (accessToken != null) {
            request.addHeader("Authorization", accessToken);
        }
    }

    /**
     * Reads the plugin settings and builds the HTTP client used for all requests.
     *
     * <p>Required settings: {@code endpoint_url}, {@code model}, {@code maxParallelism},
     * {@code networkTimeout}, {@code maxRetries}, {@code firstRetryDelay},
     * {@code retryDelayScale}. {@code inputType} is optional; it defaults to an empty
     * string, which {@link #embed} replaces with {@code "query"}.
     *
     * @param settings resolved plugin settings
     */
    public void init(PluginSettingsResolver.ResolvedSettings settings) {
        this.rs = settings;
        this.endpointUrl = settings.config.get("endpoint_url").getAsString();
        this.model = settings.config.get("model").getAsString();
        JsonElement inputTypeElement = settings.config.get("inputType");
        this.inputType = inputTypeElement == null || inputTypeElement.isJsonNull() ? "" : inputTypeElement.getAsString();
        // Never report fewer than one parallel slot to the host.
        this.maxParallel = Math.max(1, settings.config.get("maxParallelism").getAsNumber().intValue());
        this.networkSettings.queryTimeoutMS = settings.config.get("networkTimeout").getAsNumber().intValue();
        this.networkSettings.maxRetries = settings.config.get("maxRetries").getAsNumber().intValue();
        this.networkSettings.initialRetryDelayMS = settings.config.get("firstRetryDelay").getAsNumber().longValue();
        this.networkSettings.retryDelayScalingFactor = settings.config.get("retryDelayScale").getAsNumber().doubleValue();
        this.client = new ExternalJSONAPIClient(this.endpointUrl, null, true, ApplicationConfigurator.getProxySettings(), OnlineLLMUtils.getLLMResponseRetryStrategy((AbstractLLMConnection.HTTPBasedLLMNetworkSettings)this.networkSettings), builder -> OnlineLLMUtils.add429RetryStrategy((HttpClientBuilder)builder, (AbstractLLMConnection.HTTPBasedLLMNetworkSettings)this.networkSettings)){

            protected HttpGet newGet(String path) {
                HttpGet get = new HttpGet(path);
                this.setAdditionalHeadersInRequest((HttpRequestBase)get);
                NvidiaNIMPlugin.this.addJsonAndAuthHeaders(get);
                return get;
            }

            protected HttpPost newPost(String path) {
                HttpPost post = new HttpPost(path);
                this.setAdditionalHeadersInRequest((HttpRequestBase)post);
                NvidiaNIMPlugin.this.addJsonAndAuthHeaders(post);
                return post;
            }
        };
    }

    /**
     * @return the configured maximum parallelism (at least 1)
     */
    public int getMaxParallelism() {
        return this.maxParallel;
    }

    /**
     * Runs each completion query sequentially against the configured model and records usage.
     *
     * <p>NOTE(review): this method is {@code synchronized}, so concurrent callers are
     * serialized regardless of {@link #getMaxParallelism()}. It is kept as-is because the
     * thread-safety of the delegated API implementation is not visible here.
     *
     * @param completionQueries queries to complete, in order
     * @return one response per query, in the same order
     * @throws IOException if a request fails
     */
    public synchronized List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.CompletionQuery> completionQueries) throws IOException {
        ArrayList<LLMClient.SimpleCompletionResponse> ret = new ArrayList<LLMClient.SimpleCompletionResponse>(completionQueries.size());
        for (LLMClient.CompletionQuery query : completionQueries) {
            long before = System.currentTimeMillis();
            logger.info((Object)"Chat Complete.");
            LLMClient.SimpleCompletionResponse scr = this.chatComplete(this.model, query.messages, query.settings.maxOutputTokens, query.settings.temperature, query.settings.topP, query.settings.topK, query.settings.stopSequences, query.settings.toolChoice, query.settings.tools);
            this.recordUsage(System.currentTimeMillis() - before, scr.promptTokens, scr.completionTokens);
            ret.add(scr);
        }
        return ret;
    }

    /**
     * Embeds each query's text with the configured model and input type, recording usage.
     *
     * @param queries texts to embed, in order
     * @param settings embedding settings (not used by this implementation)
     * @return one embedding response per query, in the same order
     * @throws IOException if a request fails or the response is malformed
     */
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws IOException {
        ArrayList<LLMClient.SimpleEmbeddingResponse> ret = new ArrayList<LLMClient.SimpleEmbeddingResponse>(queries.size());
        for (LLMClient.EmbeddingQuery query : queries) {
            long before = System.currentTimeMillis();
            logger.info((Object)"Chat Embed.");
            LLMClient.SimpleEmbeddingResponse scr = this.embed(this.model, query.text, this.inputType);
            this.recordUsage(System.currentTimeMillis() - before, scr.promptTokens, null);
            ret.add(scr);
        }
        return ret;
    }

    /**
     * Adds one call's elapsed time and token counts to the running usage totals.
     * Token counts may be null when the endpoint omits usage information. Null counts are
     * skipped rather than causing a NullPointerException.
     *
     * @param elapsedMS wall-clock duration of the call in milliseconds
     * @param promptTokens prompt tokens reported for the call, or null
     * @param completionTokens completion tokens reported for the call, or null
     */
    private void recordUsage(long elapsedMS, Number promptTokens, Number completionTokens) {
        synchronized (this.usageLock) {
            this.usageData.totalComputationTimeMS += elapsedMS;
            if (promptTokens != null) {
                this.usageData.totalPromptTokens += promptTokens.longValue();
            }
            if (completionTokens != null) {
                this.usageData.totalCompletionTokens += completionTokens.longValue();
            }
        }
    }

    /**
     * Builds a compute-resource-usage report from the usage accumulated so far.
     *
     * @param usageType kind of LLM usage to report
     * @param llmRef reference to the LLM (connection, type, id)
     * @return a new usage report populated from the accumulated counters
     */
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, PromptStudio.LLMStructuredRef llmRef) {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
        synchronized (this.usageLock) {
            cru.llmUsage.setFromInternal(this.usageData);
        }
        return cru;
    }

    /**
     * Creates the request builder used for chat completions against the configured endpoint.
     *
     * @param model model name to target
     * @return a fresh {@link GptApiImplementation} for this endpoint and model
     * @throws IOException if the implementation cannot be created
     */
    private CustomApiImplementation getCustomApiImplementation(String model) throws IOException {
        return new GptApiImplementation(this.endpointUrl, model);
    }

    /**
     * Performs one blocking chat completion.
     *
     * @return the completion response produced by the API implementation
     * @throws IOException if the request fails
     */
    public LLMClient.SimpleCompletionResponse chatComplete(String model, List<LLMClient.ChatMessage> messages, Integer maxTokens, Double temperature, Double topP, Integer topK, List<String> stopSequences, LLMClient.ToolChoice toolChoice, List<LLMClient.AbstractTool> tools) throws IOException {
        CustomApiImplementation api = this.getCustomApiImplementation(model);
        api.addSettingsInObject(model, maxTokens, temperature, topP, topK, stopSequences, toolChoice, tools);
        api.addMessagesInObject(messages);
        return api.sendPostObject(this.client, this.networkSettings);
    }

    /**
     * Requests an embedding for a single text.
     *
     * @param model model name to target
     * @param text text to embed
     * @param inputType value for {@code input_type}; {@code "query"} is used when null or blank
     * @return the embedding, with {@code promptTokens} set from {@code usage.total_tokens}
     *         when the response includes usage
     * @throws IOException if the request fails or the response does not contain exactly
     *         one embedding vector
     */
    public LLMClient.SimpleEmbeddingResponse embed(String model, String text, String inputType) throws IOException {
        String effectiveInputType = inputType == null || inputType.trim().isEmpty() ? "query" : inputType;
        JF.ObjectBuilder ob = JF.obj().with("input", text).with("model", model).with("input_type", effectiveInputType);
        logger.info((Object)"Raw embedding query");
        OpenAIEmbeddingResponse rcr = (OpenAIEmbeddingResponse)this.client.postObjectToJSON(this.endpointUrl, this.networkSettings.queryTimeoutMS, OpenAIEmbeddingResponse.class, (Object)ob.get());
        logger.info((Object)"Raw embedding response");
        int resultCount = rcr == null || rcr.data == null ? 0 : rcr.data.size();
        if (resultCount != 1) {
            throw new IOException("Chat did not respond with valid embeddings (expected 1 result, got " + resultCount + ")");
        }
        OpenAIEmbeddingResult result = rcr.data.get(0);
        if (result == null || result.embedding == null) {
            throw new IOException("Chat did not respond with valid embeddings (missing embedding vector)");
        }
        LLMClient.SimpleEmbeddingResponse ret = new LLMClient.SimpleEmbeddingResponse();
        ret.embedding = result.embedding;
        // Usage is optional in the response; leave promptTokens unset when it is absent.
        if (rcr.usage != null) {
            ret.promptTokens = rcr.usage.total_tokens;
        }
        return ret;
    }

    /**
     * @return {@code true}: this client supports streamed completions
     */
    public boolean supportsStream() {
        return true;
    }

    /**
     * Streams a completion for {@code query} to {@code consumer} using the configured model.
     *
     * @throws Exception propagated from the streaming implementation
     */
    public void streamComplete(LLMClient.CompletionQuery query, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        this.streamChatComplete(consumer, this.model, query.messages, query.settings.maxOutputTokens, query.settings.temperature, query.settings.topP, query.settings.topK, query.settings.stopSequences, query.settings.toolChoice, query.settings.tools);
    }

    /**
     * Builds a chat request and streams its response chunks to {@code consumer}.
     *
     * @throws Exception propagated from the streaming implementation
     */
    public void streamChatComplete(LLMClient.StreamedCompletionResponseConsumer consumer, String model, List<LLMClient.ChatMessage> messages, Integer maxTokens, Double temperature, Double topP, Integer topK, List<String> stopSequences, LLMClient.ToolChoice toolChoice, List<LLMClient.AbstractTool> tools) throws Exception {
        CustomApiImplementation api = this.getCustomApiImplementation(model);
        api.addSettingsInObject(model, maxTokens, temperature, topP, topK, stopSequences, toolChoice, tools);
        api.addMessagesInObject(messages);
        api.streamChatComplete(this.client, consumer, this.networkSettings, toolChoice, tools);
    }

    /**
     * Embeds each query using the default (null) embedding settings.
     *
     * @see #embedBatch(List, LLMClient.EmbeddingSettings)
     */
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries) throws IOException {
        return this.embedBatch(queries, null);
    }

    /** JSON shape of an embedding response: a {@code data} list plus optional {@code usage}. */
    private static class OpenAIEmbeddingResponse {
        List<OpenAIEmbeddingResult> data = new ArrayList<OpenAIEmbeddingResult>();
        RawUsageResponse usage;

        private OpenAIEmbeddingResponse() {
        }
    }

    /** One element of {@code data} in an embedding response. */
    private static class OpenAIEmbeddingResult {
        double[] embedding;

        private OpenAIEmbeddingResult() {
        }
    }

    /** Token accounting block of a response ({@code usage}). */
    private static class RawUsageResponse {
        int total_tokens;
        int prompt_tokens;
        int completion_tokens;

        private RawUsageResponse() {
        }
    }
}

