/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.snowflakecortex;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.connections.ConnectionUtils;
import com.dataiku.dip.connections.SnowflakeConnection;
import com.dataiku.dip.connections.SnowflakeCortexLLMConnection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.exceptions.UnauthorizedException;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettingsValidator;
import com.dataiku.dip.llm.online.snowflakecortex.AbstractRawSnowflakeCortexLLMClient;
import com.dataiku.dip.llm.online.snowflakecortex.SnowflakeCompletionRESTQuery;
import com.dataiku.dip.llm.online.snowflakecortex.SnowflakeCompletionRESTQueryAdapter;
import com.dataiku.dip.llm.online.snowflakecortex.SnowflakeCompletionRESTResponseAdapter;
import com.dataiku.dip.llm.online.snowflakecortex.SnowflakeCortexRESTChunkResponseAdapter;
import com.dataiku.dip.llm.online.snowflakecortex.SnowflakeCortexRESTCompletionNonStreamedResponse;
import com.dataiku.dip.llm.online.snowflakecortex.SnowflakeCortexRESTStreamedCompletionChunkResponse;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.streaming.endpoints.httpsse.SSEDecoder;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang.StringUtils;

/**
 * Snowflake Cortex LLM client backed by the Cortex REST API ({@code <host>/api/v2/cortex}) of the Snowflake
 * account configured on a {@link SnowflakeConnection}. Supports plain and streamed chat completion, and embeddings.
 */
public class RawSnowflakeCortexLLMRESTClient
extends AbstractRawSnowflakeCortexLLMClient {
    /** HTTP client rooted at {@code <host>/api/v2/cortex}, carrying the auth and content headers set in the constructor. */
    final ExternalJSONAPIClient client;
    /** Per-request timeout in milliseconds, taken from the Cortex LLM connection. */
    final int queryTimeoutMs;
    /** Completion settings this client allows: max tokens, temperature, top-p and tools. */
    static final CoreCompletionSettingsValidator completionValidator = new CoreCompletionSettingsValidator("Snowflake Cortex (REST API)").allowMaxTokens().allowTemperature().allowTopP().allowTools();
    public static final String BASE_PATH = "/api/v2/cortex";
    public static final String COMPLETE_PATH = "/inference:complete";
    public static final String EMBED_PATH = "/inference:embed";
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.snowflakecortex-llm.client");

    /**
     * Builds the REST client for the Snowflake account behind {@code sfConnection}.
     *
     * @param authCtx auth context used to obtain the Snowflake auth token
     * @param connection Cortex LLM connection providing proxy settings, network/retry settings and the query timeout
     * @param sfConnection Snowflake connection providing host, auth type, auth token and DSS properties
     * @throws UnauthorizedException propagated from obtaining the Snowflake auth token
     * @throws IllegalArgumentException if the Snowflake connection uses an auth type other than OAuth app or key pair
     */
    public RawSnowflakeCortexLLMRESTClient(AuthCtx authCtx, SnowflakeCortexLLMConnection connection, SnowflakeConnection sfConnection) throws UnauthorizedException {
        // Resolve the token type and fetch the token before building the HTTP client.
        // A rejected auth type or a token failure then cannot leave an unclosed client behind.
        String tokenType = switch (sfConnection.params.authType) {
            case OAUTH2_APP -> "OAUTH";
            case KEY_PAIR -> "KEYPAIR_JWT";
            default -> throw new IllegalArgumentException("Snowflake Cortex REST API requires OAuth or key-pair authentication on the Snowflake connection, got: " + sfConnection.params.authType);
        };
        String token = sfConnection.getAuthToken(authCtx);
        boolean trustAllSSLCertificates = ConnectionUtils.getParamsFromProperties(sfConnection.getDkuProperties()).getBoolParam("dku.connection.llm.trustAllSSLCertificates", false);
        // Strip trailing slashes so that "host/" does not yield "host//api/v2/cortex".
        String endpointBase = StringUtils.stripEnd(sfConnection.params.host, "/") + BASE_PATH;
        if (!endpointBase.startsWith("http://") && !endpointBase.startsWith("https://")) {
            endpointBase = "https://" + endpointBase;
        }
        this.client = new ExternalJSONAPIClient(endpointBase, null, trustAllSSLCertificates, connection.getProxySettings(), OnlineLLMUtils.getLLMResponseRetryStrategy(connection.params.networkSettings), builder -> OnlineLLMUtils.add429RetryStrategy(builder, connection.params.networkSettings));
        this.client.addHeader("User-Agent", DKUApp.getDataikuApplicationString());
        this.client.addHeader("Authorization", "Bearer " + token);
        this.client.addHeader("Content-Type", "application/json");
        // Both JSON (non-streamed) and SSE (streamed) responses come from the same endpoint.
        this.client.addHeader("Accept", "application/json, text/event-stream");
        this.client.addHeader("X-Snowflake-Authorization-Token-Type", tokenType);
        this.client.forceContentLength = true;
        this.queryTimeoutMs = connection.params.queryTimeoutMs;
    }

    /**
     * Runs a non-streamed chat completion against {@value #COMPLETE_PATH}.
     * Any failure is logged together with the query, then rethrown unchanged.
     */
    @Override
    public LLMClient.SimpleCompletionResponse chatComplete(LLMModelHandle.Model model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) throws IOException, SQLException, DKUSecurityException, InterruptedException {
        completionValidator.validate(ccs);
        SnowflakeCompletionRESTQuery query = SnowflakeCompletionRESTQueryAdapter.adapt(model.getId(), messages, ccs);
        logger.trace(() -> String.format("Snowflake Cortex REST raw completion non streamed query: %s", JSON.prettyLog((Object)query)));
        try {
            SnowflakeCortexRESTCompletionNonStreamedResponse response = (SnowflakeCortexRESTCompletionNonStreamedResponse)this.client.postObjectToJSON(COMPLETE_PATH, this.queryTimeoutMs, SnowflakeCortexRESTCompletionNonStreamedResponse.class, (Object)query);
            logger.trace(() -> String.format("Snowflake Cortex REST raw completion non streamed response: %s", JSON.prettyLog((Object)response)));
            return SnowflakeCompletionRESTResponseAdapter.adapt(response);
        }
        catch (Exception e) {
            logger.errorV((Throwable)e, "Error while executing Snowflake Cortex REST raw completion non streamed query: %s", new Object[]{JSON.prettyLog((Object)query)});
            throw e;
        }
    }

    /**
     * Runs a streamed chat completion against {@value #COMPLETE_PATH}, decoding the SSE response.
     * Chunks are forwarded to {@code consumer} as the adapter marks them ready. The still-buffered chunk is
     * flushed at the end, followed by a footer with usage (when reported) and the first finish reason seen.
     * Any failure is logged together with the query, then rethrown unchanged.
     */
    @Override
    public void streamChatComplete(LLMClient.StreamedCompletionResponseConsumer consumer, LLMModelHandle.Model model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) throws Exception {
        completionValidator.validate(ccs);
        SnowflakeCompletionRESTQuery query = SnowflakeCompletionRESTQueryAdapter.adaptForStreaming(model.getId(), messages, ccs);
        logger.trace(() -> String.format("Snowflake Cortex REST raw completion streaming query: %s", JSON.prettyLog((Object)query)));
        try {
            // NOTE(review): the response content stream is never explicitly closed here. Confirm that
            // ExternalJSONAPIClient / SSEDecoder release the underlying HTTP connection.
            ExternalJSONAPIClient.EntityAndRequest ear = this.client.postJSONToStreamAndRequest(COMPLETE_PATH, this.queryTimeoutMs, (Object)query);
            SSEDecoder decoder = new SSEDecoder(ear.entity.getContent());
            consumer.onStreamStarted();
            LLMClient.FinishReason finishReason = null;
            SnowflakeCortexRESTStreamedCompletionChunkResponse.Usage usage = null;
            LLMClient.StreamedCompletionResponseChunk chunk = new LLMClient.StreamedCompletionResponseChunk();
            while (true) {
                SSEDecoder.HTTPSSEEvent event = decoder.next();
                logger.trace(() -> "Received raw event from Snowflake Cortex: " + JSON.log((Object)event));
                if (event == null || event.data == null) {
                    logger.info((Object)"End of Snowflake Cortex stream");
                    break;
                }
                SnowflakeCortexRESTStreamedCompletionChunkResponse response = (SnowflakeCortexRESTStreamedCompletionChunkResponse)JSON.parse((String)event.data, SnowflakeCortexRESTStreamedCompletionChunkResponse.class);
                logger.trace(() -> String.format("Snowflake Cortex streamed chat completion response chunk: %s", JSON.prettyLog((Object)response)));
                // Keep the last usage actually reported; a chunk without usage must not erase an earlier one.
                if (response.usage != null) {
                    usage = response.usage;
                }
                if (CollectionUtils.isEmpty(response.choices)) continue;
                if (response.choices.size() > 1) {
                    logger.warn((Object)"Received unexpected response with multiple choices. Only the first one will be processed.");
                }
                if (isEmptyTerminalChunk(response)) {
                    // The pending chunk is flushed after the loop.
                    logger.info((Object)"End of Snowflake Cortex stream");
                    break;
                }
                SnowflakeCortexRESTChunkResponseAdapter.ReadyAndCurrentChunks state = SnowflakeCortexRESTChunkResponseAdapter.adapt(response, chunk);
                chunk = state.currentChunk;
                if (finishReason == null && state.finishReason != null) {
                    finishReason = state.finishReason;
                }
                if (state.readyChunk == null) continue;
                consumer.onStreamChunk(state.readyChunk);
            }
            // The adapter may hand back a null current chunk, hence the null check.
            if (chunk != null && !chunk.isEmpty()) {
                consumer.onStreamChunk(chunk);
            }
            LLMClient.StreamedCompletionResponseFooter footer = new LLMClient.StreamedCompletionResponseFooter();
            // The finish reason does not depend on token accounting, so report it even without usage.
            footer.finishReason = finishReason;
            if (usage != null) {
                footer.completionTokens = usage.completionTokens;
                footer.promptTokens = usage.promptTokens;
                footer.totalTokens = usage.totalTokens;
                footer.estimatedCost = model.getEstimatedCompletionCost(usage.promptTokens, usage.completionTokens);
            }
            consumer.onStreamComplete(footer);
        }
        catch (Exception e) {
            logger.errorV((Throwable)e, "Error while executing Snowflake Cortex REST raw completion streamed query: %s", new Object[]{JSON.prettyLog((Object)query)});
            throw e;
        }
    }

    /**
     * Detects the single-choice chunk that ends the stream: no content list, no text content, and not a tool-use
     * delta. Null content lists are treated as empty.
     */
    private static boolean isEmptyTerminalChunk(SnowflakeCortexRESTStreamedCompletionChunkResponse response) {
        if (response.choices.size() != 1) {
            return false;
        }
        var delta = response.choices.get(0).delta;
        boolean noContentList = delta.contentList == null || delta.contentList.isEmpty();
        return noContentList && StringUtils.isEmpty((String)delta.content) && !StringUtils.equals((String)delta.type, (String)"tool_use");
    }

    /**
     * Embeds {@code batchTexts} with {@code model} via {@value #EMBED_PATH}.
     * {@code embeddingSize} is not sent to the API by this client.
     * The batch-level total token count is split evenly (integer division) across the inputs; it is 0 when
     * no usage is reported.
     *
     * @throws IOException if the response carries no embeddings, or an entry carries no vector
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embed(String model, Integer embeddingSize, List<String> batchTexts) throws IOException, SQLException, DKUSecurityException, InterruptedException {
        SnowflakeCortexEmbeddingsRESTQuery query = new SnowflakeCortexEmbeddingsRESTQuery();
        query.model = model;
        query.text = batchTexts;
        logger.trace(() -> String.format("Snowflake Cortex REST embedding query: %s", JSON.prettyLog((Object)query)));
        SnowflakeCortexRESTEmbeddingsResponse response = (SnowflakeCortexRESTEmbeddingsResponse)this.client.postObjectToJSON(EMBED_PATH, this.queryTimeoutMs, SnowflakeCortexRESTEmbeddingsResponse.class, (Object)query);
        // Trace before validating so that malformed responses are visible too.
        logger.trace(() -> String.format("Snowflake Cortex REST raw embedding response: %s", JSON.prettyLog((Object)response)));
        if (response == null || CollectionUtils.isEmpty(response.embeddings)) {
            throw new IOException("Snowflake Cortex did not respond with valid embeddings");
        }
        int promptTokensPerText = (response.usage == null || batchTexts == null || batchTexts.isEmpty())
                ? 0
                : response.usage.totalTokens / batchTexts.size();
        List<LLMClient.SimpleEmbeddingResponse> batchResponses = new ArrayList<>(response.embeddings.size());
        for (SnowflakeCortexRESTEmbeddingsResponse.Embedding embedding : response.embeddings) {
            if (embedding == null || embedding.embedding == null || embedding.embedding.length == 0) {
                throw new IOException("Snowflake Cortex did not respond with valid embeddings");
            }
            LLMClient.SimpleEmbeddingResponse singleResponse = new LLMClient.SimpleEmbeddingResponse();
            // Each entry carries a nested array; only its first row is used as the vector.
            singleResponse.embedding = embedding.embedding[0];
            singleResponse.promptTokens = promptTokensPerText;
            batchResponses.add(singleResponse);
        }
        return batchResponses;
    }

    /** Closes the underlying HTTP client. */
    @Override
    public void close() {
        this.client.close();
    }

    /** Request body for {@value #EMBED_PATH}: the texts to embed and the model name. */
    public static class SnowflakeCortexEmbeddingsRESTQuery {
        public List<String> text;
        public String model;
    }

    /** Response body of {@value #EMBED_PATH}; the embeddings list is read from the JSON {@code data} field. */
    public static class SnowflakeCortexRESTEmbeddingsResponse {
        public String model;
        @SerializedName(value="data")
        public List<Embedding> embeddings;
        /** Token usage for the whole batch; may be absent. */
        public Usage usage;

        /** Token counters, mapped from the snake_case JSON fields. */
        public static class Usage {
            @SerializedName(value="prompt_tokens")
            public int promptTokens;
            @SerializedName(value="completion_tokens")
            public int completionTokens;
            @SerializedName(value="total_tokens")
            public int totalTokens;
        }

        /** One embedding entry; {@code embed()} uses only the first row of the nested array. */
        public static class Embedding {
            public double[][] embedding;
        }
    }
}

