/*
 * Decompiled with CFR 0.152.
 */
package com.titanml.llm;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.custom.PluginSettingsResolver;
import com.dataiku.dip.llm.custom.CustomLLMClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dss.shadelib.org.apache.http.client.RedirectStrategy;
import com.dataiku.dss.shadelib.org.apache.http.impl.client.HttpClientBuilder;
import com.dataiku.dss.shadelib.org.apache.http.impl.client.LaxRedirectStrategy;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.JsonParser;
import com.google.gson.JsonPrimitive;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;

/**
 * Dataiku custom LLM client backed by a TitanML inference server reached at the plugin's
 * {@code endpoint_url} setting.
 *
 * <p>Completions are posted to the server's {@code generate} endpoint and embeddings to its
 * {@code embed} endpoint, both tagged with the configured consumer group. When the
 * {@code chatTemplate} setting is true, the chat messages are expected to form a JSON array; they
 * are first rendered through {@code chat_template/<readerID>}, where the reader ID is looked up from
 * the {@code status} endpoint at the start of every {@link #completeBatch} call.
 *
 * <p>Settings read: {@code endpoint_url} (required), {@code consumer_group} (defaults to
 * {@code primary}), {@code chatTemplate}, {@code jsonSchema}, {@code regexScheme}.
 */
public class TitanMLLLMConnector extends CustomLLMClient {
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.titanml");
    /** Consumer group used when {@code consumer_group} is absent or empty in the settings. */
    private static final String DEFAULT_CONSUMER_GROUP = "primary";

    PluginSettingsResolver.ResolvedSettings resolvedSettings;
    /** Reader whose chat template is applied; refreshed by {@link #completeBatch} when templating is on. */
    private String readerID;
    /** Consumer group resolved once in {@link #init}; non-null after init. */
    private String consumerGroup;
    /** HTTP client bound to {@code endpoint_url}; created in {@link #init}. */
    private ExternalJSONAPIClient client;

    /**
     * Registers a header on the underlying HTTP client (presumably sent with subsequent requests).
     * Must be called after {@link #init}, which creates the client.
     */
    public void setHeaders(String key, String value) {
        this.client.addHeader(key, value);
    }

    /**
     * Stores the plugin settings, builds the HTTP client and resolves the consumer group.
     *
     * @param settings resolved plugin settings; {@code config} must contain {@code endpoint_url}
     * @throws IllegalArgumentException if {@code endpoint_url} is missing or empty
     */
    public void init(PluginSettingsResolver.ResolvedSettings settings) {
        logger.info("Initializing TitanMLLLMConnector-----------------------------------");
        this.resolvedSettings = settings;
        String endpointUrl = this.getConfigString("endpoint_url");
        if (endpointUrl == null || endpointUrl.isEmpty()) {
            throw new IllegalArgumentException("The 'endpoint_url' setting is required");
        }
        Consumer<HttpClientBuilder> customizeBuilderCallback = builder -> {
            // Lax strategy: also follow redirects for POST requests (generate/embed/chat_template).
            builder.setRedirectStrategy((RedirectStrategy)new LaxRedirectStrategy());
            // Install the shared HTTP 429 retry strategy using default network settings.
            AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
            OnlineLLMUtils.add429RetryStrategy(builder, networkSettings);
        };
        this.client = new ExternalJSONAPIClient(endpointUrl, null, true, null, customizeBuilderCallback);

        String configuredGroup = this.getConfigString("consumer_group");
        if (configuredGroup == null || configuredGroup.isEmpty()) {
            logger.info("No consumer group was specified, defaulting to 'primary'");
            configuredGroup = DEFAULT_CONSUMER_GROUP;
        }
        this.consumerGroup = configuredGroup;
    }

    /** Requests are processed one at a time ({@link #completeBatch} also mutates {@link #readerID}). */
    public int getMaxParallelism() {
        return 1;
    }

    /**
     * Generates one completion per query by calling the {@code generate} endpoint sequentially.
     *
     * @param completionQueries queries to complete; an empty list returns an empty result without
     *     contacting the server
     * @return one response per query, in input order
     * @throws IOException if the reader lookup or chat-template call fails, if a {@code generate}
     *     call fails, or if its response lacks a {@code text} field
     * @throws IllegalArgumentException if chat templating is enabled and the messages are not a valid
     *     JSON array
     */
    public synchronized List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.CompletionQuery> completionQueries) throws IOException {
        List<LLMClient.SimpleCompletionResponse> ret = new ArrayList<>(completionQueries.size());
        if (completionQueries.isEmpty()) {
            return ret;
        }
        if (this.isChatTemplateEnabled()) {
            // Re-resolve on every batch so a restarted/reassigned reader is picked up; never reuse a stale ID.
            this.readerID = this.findReaderIdForConsumerGroup();
            logger.info("Found readerID for template: " + this.readerID);
        }
        for (LLMClient.CompletionQuery completionQuery : completionQueries) {
            JsonObject jsonObject = this.getGenerationJsonObject(completionQuery);
            logger.info("Sending JSON object for processing: " + jsonObject);
            try {
                JsonObject response = (JsonObject)this.client.postObjectToJSON("generate", JsonObject.class, (Object)jsonObject);
                logger.info("Logging JSON response: " + response);
                JsonElement text = response == null ? null : response.get("text");
                if (text == null || text.isJsonNull()) {
                    throw new IOException("Response from 'generate' has no 'text' field: " + response);
                }
                LLMClient.SimpleCompletionResponse queryResult = new LLMClient.SimpleCompletionResponse();
                queryResult.text = text.getAsString();
                ret.add(queryResult);
            }
            catch (Exception e) {
                throw new IOException("Error during communication with API", e);
            }
        }
        return ret;
    }

    /**
     * Queries the {@code status} endpoint and returns the ID of the first live reader serving
     * {@link #consumerGroup}.
     *
     * @throws IOException if the status call fails, the response has no {@code live_readers}
     *     object, or no live reader belongs to the configured consumer group
     */
    private String findReaderIdForConsumerGroup() throws IOException {
        JsonObject response = (JsonObject)this.client.getToJSON("status", JsonObject.class, new Object[0]);
        logger.info("Received JSON response: " + response);
        JsonElement liveReadersEl = response == null ? null : response.get("live_readers");
        if (liveReadersEl == null || !liveReadersEl.isJsonObject()) {
            throw new IOException("Status response has no 'live_readers' object: " + response);
        }
        JsonObject liveReaders = liveReadersEl.getAsJsonObject();
        for (String key : liveReaders.keySet()) {
            JsonElement readerEl = liveReaders.get(key);
            if (readerEl != null && readerEl.isJsonObject() && this.readerServesConsumerGroup(readerEl.getAsJsonObject())) {
                return key;
            }
        }
        throw new IOException("No live reader found for consumer group '" + this.consumerGroup + "'");
    }

    /**
     * Returns whether a reader status object lists {@link #consumerGroup}, either in a single
     * {@code consumer_group} field (checked first) or in a {@code consumer_groups} array.
     */
    private boolean readerServesConsumerGroup(JsonObject reader) {
        JsonElement single = reader.get("consumer_group");
        if (single != null) {
            return single.isJsonPrimitive() && this.consumerGroup.equals(single.getAsString());
        }
        JsonElement multiple = reader.get("consumer_groups");
        if (multiple != null) {
            if (!multiple.isJsonArray()) {
                return false;
            }
            for (JsonElement element : multiple.getAsJsonArray()) {
                if (element.isJsonPrimitive() && element.getAsJsonPrimitive().isString()
                        && this.consumerGroup.equals(element.getAsString())) {
                    return true;
                }
            }
            return false;
        }
        logger.error("Neither consumer_group nor consumer_groups field found in reader status object.");
        return false;
    }

    /**
     * Builds the {@code generate} request body for one query: the (optionally templated) prompt,
     * the consumer group, optional structured-output constraints and sampling settings.
     *
     * @throws IOException if the chat-template endpoint fails
     * @throws IllegalArgumentException if chat templating is on and the messages are not a JSON array
     */
    private JsonObject getGenerationJsonObject(LLMClient.CompletionQuery completionQuery) throws IOException {
        String joinedMessages = completionQuery.messages.stream().map(LLMClient.ChatMessage::getTextEvenIfNotTextOnly).collect(Collectors.joining("\n\n"));
        // Use the server-rendered prompt when templating is on; otherwise send the raw messages.
        String completePrompt = this.isChatTemplateEnabled() ? this.applyChatTemplate(joinedMessages) : joinedMessages;
        logger.info("Prompt constructed: " + completePrompt);

        JsonObject jsonObject = new JsonObject();
        JsonArray prompts = new JsonArray(1);
        prompts.add(completePrompt);
        jsonObject.add("text", prompts);
        jsonObject.add("consumer_group", new JsonPrimitive(this.consumerGroup));

        String jsonSchema = this.getConfigString("jsonSchema");
        if (jsonSchema != null && !jsonSchema.isEmpty()) {
            jsonObject.add("json_schema", JsonParser.parseString(jsonSchema));
        }
        String regex = this.getConfigString("regexScheme");
        if (regex != null && !regex.isEmpty()) {
            jsonObject.add("regex_string", new JsonPrimitive(regex));
        }
        addIfPresent(jsonObject, "sampling_temperature", (Number)completionQuery.settings.temperature);
        addIfPresent(jsonObject, "sampling_topp", (Number)completionQuery.settings.topP);
        addIfPresent(jsonObject, "sampling_topk", (Number)completionQuery.settings.topK);
        addIfPresent(jsonObject, "max_new_tokens", (Number)completionQuery.settings.maxOutputTokens);
        return jsonObject;
    }

    /**
     * Renders a JSON array of chat messages into a single prompt using the chat template of
     * {@link #readerID}.
     *
     * @param messagesJson the user messages, expected to be a JSON array
     * @return the first entry of the {@code messages} array returned by the template endpoint
     * @throws IllegalArgumentException if {@code messagesJson} is not a JSON array or contains a
     *     null entry (which lenient parsing produces for a trailing comma)
     * @throws IOException if the template call fails or returns no messages
     */
    private String applyChatTemplate(String messagesJson) throws IOException {
        JsonElement messages;
        try {
            messages = JsonParser.parseString(messagesJson);
            if (!messages.isJsonArray()) {
                throw new JsonParseException("Expected a JSON array of messages");
            }
            for (JsonElement element : messages.getAsJsonArray()) {
                if (element.isJsonNull()) {
                    logger.error("Trailing comma detected in list of messages. This will prevent JSON from parsing.");
                    throw new JsonParseException("Trailing comma");
                }
            }
        }
        catch (JsonParseException e) {
            logger.error("Invalid JSON was input for messages");
            throw new IllegalArgumentException("Invalid JSON was input for messages", e);
        }
        JsonArray inputsArray = new JsonArray();
        inputsArray.add(messages);
        JsonObject templatePayload = new JsonObject();
        templatePayload.add("inputs", inputsArray);
        logger.info("Template payload: " + templatePayload);

        JsonObject response;
        try {
            response = (JsonObject)this.client.postObjectToJSON("chat_template/" + this.readerID, JsonObject.class, (Object)templatePayload);
        }
        catch (IOException e) {
            logger.error("Chat template endpoint failed");
            throw new IOException("Chat template endpoint failed for reader " + this.readerID, e);
        }
        logger.info("Logging Prompt template response: " + response);
        JsonElement templated = response == null ? null : response.get("messages");
        if (templated == null || !templated.isJsonArray() || templated.getAsJsonArray().size() == 0) {
            throw new IOException("Chat template response has no 'messages': " + response);
        }
        String completePrompt = templated.getAsJsonArray().get(0).getAsString();
        logger.info("TEMPLATED PROMPT:" + completePrompt);
        return completePrompt;
    }

    /**
     * Computes one embedding per query by calling the {@code embed} endpoint sequentially.
     *
     * @return one embedding per query, in input order; each is the first vector of the response's
     *     {@code result} array
     * @throws IOException if a call fails or its response has no non-empty {@code result} array
     */
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries) throws IOException {
        List<LLMClient.SimpleEmbeddingResponse> ret = new ArrayList<>(queries.size());
        for (LLMClient.EmbeddingQuery embeddingQuery : queries) {
            JsonObject jsonObject = this.getEmbeddingsJsonObject(embeddingQuery);
            JsonObject response = (JsonObject)this.client.postObjectToJSON("embed", JsonObject.class, (Object)jsonObject);
            logger.info("Logging JSON response: " + response);
            JsonElement result = response == null ? null : response.get("result");
            if (result == null || !result.isJsonArray() || result.getAsJsonArray().size() == 0) {
                throw new IOException("Embedding response has no 'result' array: " + response);
            }
            LLMClient.SimpleEmbeddingResponse queryResult = new LLMClient.SimpleEmbeddingResponse();
            queryResult.embedding = convertJsonArrayToDoubleArray(result.getAsJsonArray().get(0).getAsJsonArray());
            ret.add(queryResult);
        }
        return ret;
    }

    /** Converts a JSON array of numbers into a {@code double[]} of the same length and order. */
    private static double[] convertJsonArrayToDoubleArray(JsonArray jsonArray) {
        double[] result = new double[jsonArray.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = jsonArray.get(i).getAsDouble();
        }
        return result;
    }

    /** Builds the {@code embed} request body: the query text as a one-element array plus the consumer group. */
    private JsonObject getEmbeddingsJsonObject(LLMClient.EmbeddingQuery embeddingQuery) {
        JsonArray prompts = new JsonArray(1);
        prompts.add(new JsonPrimitive(embeddingQuery.text));
        JsonObject jsonObject = new JsonObject();
        jsonObject.add("text", prompts);
        jsonObject.add("consumer_group", new JsonPrimitive(this.consumerGroup));
        return jsonObject;
    }

    /** Compute-resource usage is not reported by this connector; always returns {@code null}. */
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, PromptStudio.LLMStructuredRef llmRef) {
        return null;
    }

    /** Returns the {@code chatTemplate} setting, treating an absent or JSON-null value as false. */
    private boolean isChatTemplateEnabled() {
        JsonElement element = this.resolvedSettings.config.get("chatTemplate");
        return element != null && !element.isJsonNull() && element.getAsBoolean();
    }

    /** Returns a setting as a string, or {@code null} when it is absent or JSON null. */
    private String getConfigString(String key) {
        JsonElement element = this.resolvedSettings.config.get(key);
        if (element == null || element.isJsonNull()) {
            return null;
        }
        return element.getAsString();
    }

    /** Adds {@code value} under {@code key} when it is non-null; leaves {@code target} untouched otherwise. */
    private static void addIfPresent(JsonObject target, String key, Number value) {
        if (value != null) {
            target.add(key, new JsonPrimitive(value));
        }
    }
}

