/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.databricks;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.externalinfras.databricks.DatabricksUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettingsValidator;
import com.dataiku.dip.llm.online.marshall.FinishReasonResponseAdapter;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.codec.binary.Base64;

/**
 * Low-level client for Databricks model serving endpoints.
 *
 * <p>Sends chat-completion and embedding requests to
 * {@code serving-endpoints/<model>/invocations} (relative to the base URI given at construction)
 * and maps the raw JSON responses onto {@link LLMClient.SimpleCompletionResponse} and
 * {@link LLMClient.SimpleEmbeddingResponse}. Malformed responses are reported as
 * {@link IOException}.
 */
public class RawDatabricksLLMClient {
    /** Invocation route of a serving endpoint; {@code %s} is replaced by the endpoint (model) name. */
    private static final String ENDPOINT_FORMAT = "serving-endpoints/%s/invocations";
    private final AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings;
    private final ExternalJSONAPIClient client;
    /**
     * Declares which completion settings are allowed for Databricks chat: max tokens, temperature,
     * top-k, top-p and stop sequences. Presumably {@code validate} rejects any other setting — confirm
     * in CoreCompletionSettingsValidator.
     */
    private static final CoreCompletionSettingsValidator chatCompletionValidator = new CoreCompletionSettingsValidator("Databricks (chat)").allowMaxTokens().allowTemperature().allowTopK().allowTopP().allowStopSequences();
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.databricks-llm.client");

    /**
     * Creates a client bound to one Databricks workspace.
     *
     * @param baseURI            base URI against which the endpoint paths are posted
     * @param authToken          token sent as HTTP Basic credentials with the user name {@code token}
     * @param networkSettings    network settings; {@code queryTimeoutMS} is used as the per-request
     *                           timeout, and the settings are also passed to the 429 retry setup
     * @param proxySettings      proxy settings for the underlying HTTP client
     * @param forceContentLength when {@code true}, sets {@code forceContentLength} on the underlying client
     */
    public RawDatabricksLLMClient(String baseURI, String authToken, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, ProxySettings proxySettings, boolean forceContentLength) {
        this.networkSettings = networkSettings;
        // Basic auth with the literal user name "token" and the token as password.
        String authHeaderValue = "Basic " + Base64.encodeBase64String(("token:" + authToken).getBytes(StandardCharsets.UTF_8));
        this.client = OnlineLLMUtils.getExternalJSONClientWithBuilderCallback(baseURI, null, false, proxySettings, networkSettings, customBuilder -> {
            OnlineLLMUtils.add429RetryStrategy(customBuilder, networkSettings);
            // Registered first so it runs before any other request interceptor on the builder.
            customBuilder.addInterceptorFirst(DatabricksUtils.getRateLimitingHttpInterceptor());
        });
        this.client.addHeader("Authorization", authHeaderValue);
        if (forceContentLength) {
            this.client.forceContentLength = true;
        }
    }

    /** Closes the underlying {@link ExternalJSONAPIClient}. */
    public void close() {
        this.client.close();
    }

    /**
     * Runs a chat completion against the given serving endpoint and returns the first choice.
     *
     * @param model    serving endpoint name, substituted into {@link #ENDPOINT_FORMAT}
     * @param messages conversation to send; each message contributes its role and text
     * @param ccs      completion settings, validated against {@link #chatCompletionValidator}
     * @return text and finish reason of the first choice, plus token usage when the response reports it
     * @throws IOException if the request fails or the response has no usable choice
     */
    public LLMClient.SimpleCompletionResponse chatComplete(String model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) throws IOException {
        chatCompletionValidator.validate(ccs);
        RawChatQuery rawChatQuery = buildChatQuery(messages, ccs);
        String endpoint = String.format(ENDPOINT_FORMAT, model);
        logger.trace(() -> "Raw Databricks LLM chat completion query: " + JSON.json((Object)rawChatQuery));
        RawChatCompletionResponse rawChatResponse = (RawChatCompletionResponse)this.client.postObjectToJSON(endpoint, this.networkSettings.queryTimeoutMS, RawChatCompletionResponse.class, (Object)rawChatQuery);
        logger.trace(() -> "Raw Databricks LLM chat completion response: " + JSON.json((Object)rawChatResponse));
        if (rawChatResponse == null || rawChatResponse.choices == null || rawChatResponse.choices.isEmpty()) {
            throw new IOException("Databricks LLM did not respond with valid chat completion");
        }
        RawChatCompletionChoice firstChoice = rawChatResponse.choices.get(0);
        // A choice without a message would otherwise surface as an unchecked NPE.
        if (firstChoice == null || firstChoice.message == null) {
            throw new IOException("Databricks LLM did not respond with valid chat completion");
        }
        LLMClient.SimpleCompletionResponse ret = new LLMClient.SimpleCompletionResponse();
        ret.text = firstChoice.message.content;
        ret.finishReason = FinishReasonResponseAdapter.adapt(firstChoice.finishReason);
        // Usage is not guaranteed to be present; leave token counts at their defaults if absent.
        if (rawChatResponse.usage != null) {
            ret.promptTokens = rawChatResponse.usage.prompt_tokens;
            ret.completionTokens = rawChatResponse.usage.completion_tokens;
        }
        return ret;
    }

    /**
     * Translates generic chat messages and completion settings into the Databricks chat request body.
     * Settings that are {@code null} leave the corresponding request field {@code null}.
     */
    private static RawChatQuery buildChatQuery(List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) {
        RawChatQuery query = new RawChatQuery();
        for (LLMClient.ChatMessage m : messages) {
            RawChatCompletionMessage chatMessage = new RawChatCompletionMessage();
            chatMessage.role = m.role;
            chatMessage.content = m.getText();
            query.messages.add(chatMessage);
        }
        if (ccs.maxTokens != null) {
            query.max_tokens = ccs.maxTokens;
        }
        if (ccs.temperature != null) {
            query.temperature = ccs.temperature;
        }
        if (ccs.topK != null) {
            query.top_k = ccs.topK;
        }
        if (ccs.topP != null) {
            query.top_p = ccs.topP;
        }
        if (ccs.stopSequences != null) {
            query.stop.addAll(ccs.stopSequences);
        }
        return query;
    }

    /**
     * Computes the embedding of a single text with the given serving endpoint.
     *
     * @param model serving endpoint name, substituted into {@link #ENDPOINT_FORMAT}
     * @param text  text to embed
     * @return the embedding vector, plus the total token count when the response reports usage
     * @throws IOException if the request fails or the response does not contain exactly one embedding
     */
    public LLMClient.SimpleEmbeddingResponse embed(String model, String text) throws IOException {
        String endpoint = String.format(ENDPOINT_FORMAT, model);
        RawEmbeddingQuery rawEmbedQuery = new RawEmbeddingQuery();
        rawEmbedQuery.input = text;
        logger.trace(() -> "Raw Databricks LLM embedding query: " + JSON.json((Object)rawEmbedQuery));
        RawEmbeddingResponse rawEmbedResponse = (RawEmbeddingResponse)this.client.postObjectToJSON(endpoint, this.networkSettings.queryTimeoutMS, RawEmbeddingResponse.class, (Object)rawEmbedQuery);
        logger.trace(() -> "Raw Databricks LLM embedding response: " + JSON.json((Object)rawEmbedResponse));
        // "data": null in the JSON overrides the field initializer, so null must be checked explicitly.
        if (rawEmbedResponse == null || rawEmbedResponse.data == null || rawEmbedResponse.data.size() != 1) {
            throw new IOException("Databricks did not respond with valid embeddings");
        }
        RawEmbeddingObject embeddingObject = rawEmbedResponse.data.get(0);
        if (embeddingObject == null || embeddingObject.embedding == null) {
            throw new IOException("Databricks did not respond with valid embeddings");
        }
        LLMClient.SimpleEmbeddingResponse ret = new LLMClient.SimpleEmbeddingResponse();
        ret.embedding = embeddingObject.embedding;
        // Usage is not guaranteed to be present; leave the token count at its default if absent.
        if (rawEmbedResponse.usage != null) {
            ret.promptTokens = rawEmbedResponse.usage.total_tokens;
        }
        return ret;
    }

    // The nested classes below are wire-format DTOs. Their field names are presumably used verbatim
    // as JSON keys (hence snake_case), except where @SerializedName overrides them. Do not rename.

    /** Request body for a chat completion. */
    private static class RawChatQuery {
        List<RawChatCompletionMessage> messages = new ArrayList<RawChatCompletionMessage>();
        Integer max_tokens;
        Double temperature;
        Integer top_k;
        Double top_p;
        List<String> stop = new ArrayList<String>();

        private RawChatQuery() {
        }
    }

    /** Response body of a chat completion. */
    private static class RawChatCompletionResponse {
        List<RawChatCompletionChoice> choices;
        RawUsageResponse usage;

        private RawChatCompletionResponse() {
        }
    }

    /** One completion choice in a chat response. */
    private static class RawChatCompletionChoice {
        RawChatCompletionMessage message;
        @SerializedName(value="finish_reason")
        String finishReason;

        private RawChatCompletionChoice() {
        }
    }

    /** A chat message, used both in requests and in response choices. */
    private static class RawChatCompletionMessage {
        String role;
        String content;

        private RawChatCompletionMessage() {
        }
    }

    /** Token usage block shared by chat and embedding responses. */
    private static class RawUsageResponse {
        int completion_tokens;
        int prompt_tokens;
        int total_tokens;

        private RawUsageResponse() {
        }
    }

    /** Request body for an embedding. */
    private static class RawEmbeddingQuery {
        String input;

        private RawEmbeddingQuery() {
        }
    }

    /** Response body of an embedding request. */
    private static class RawEmbeddingResponse {
        List<RawEmbeddingObject> data = new ArrayList<RawEmbeddingObject>();
        RawUsageResponse usage;

        private RawEmbeddingResponse() {
        }
    }

    /** One embedding vector in an embedding response. */
    private static class RawEmbeddingObject {
        double[] embedding;

        private RawEmbeddingObject() {
        }
    }
}

