/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.mistralai;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.MistralAIConnection;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettingsValidator;
import com.dataiku.dip.llm.online.mistralai.api.MistralAIChatChunkResponse;
import com.dataiku.dip.llm.online.mistralai.api.MistralAIChatChunkResponseAdapter;
import com.dataiku.dip.llm.online.mistralai.api.MistralAIChatQuery;
import com.dataiku.dip.llm.online.mistralai.api.MistralAIChatQueryAdapter;
import com.dataiku.dip.llm.online.mistralai.api.MistralAIChatResponse;
import com.dataiku.dip.llm.online.mistralai.api.MistralAIChatResponseAdapter;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.streaming.endpoints.httpsse.SSEDecoder;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;

/**
 * Low-level HTTP client for the Mistral AI REST API rooted at {@code https://api.mistral.ai/v1}.
 *
 * <p>Supports streamed and non-streamed chat completions ({@code /chat/completions}) and
 * embeddings ({@code /embeddings}). Every request carries a {@code Bearer} API key. Every
 * request uses {@code queryTimeoutMS} from the supplied network settings. Call
 * {@link #close()} to close the underlying {@link ExternalJSONAPIClient}.
 */
public class RawMistralAIClient {
    private static final String DEFAULT_ENDPOINT_BASE = "https://api.mistral.ai/v1";
    private final AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings;
    private final ExternalJSONAPIClient client;
    /** Restricts the completion settings accepted for chat calls to the ones this client maps. */
    private static final CoreCompletionSettingsValidator chatCompletionValidator = new CoreCompletionSettingsValidator("MistralAI (chat)").allowMaxTokens().allowTemperature().allowTopP().allowJsonMode().allowTools();
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.mistralai.client");

    /**
     * Creates a client bound to the default Mistral AI endpoint.
     *
     * @param apiKey             API key sent as {@code Authorization: Bearer <apiKey>}
     * @param networkSettings    network settings; {@code queryTimeoutMS} is used for every request
     *                           and the settings are also passed to the retry-enabled JSON client
     * @param proxySettings      proxy configuration passed to the underlying JSON client
     * @param forceContentLength when true, sets {@code forceContentLength} on the underlying client
     *                           (NOTE(review): presumably disables chunked request bodies — confirm)
     */
    public RawMistralAIClient(String apiKey, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, ProxySettings proxySettings, boolean forceContentLength) {
        this.networkSettings = networkSettings;
        this.client = OnlineLLMUtils.getExternalJSONClientWithRetryStrategy(DEFAULT_ENDPOINT_BASE, null, false, proxySettings, networkSettings);
        this.client.addHeader("Authorization", "Bearer " + apiKey);
        if (forceContentLength) {
            this.client.forceContentLength = true;
        }
    }

    /** Closes the underlying HTTP/JSON client. */
    public void close() {
        this.client.close();
    }

    /**
     * Runs a streamed chat completion and forwards the results to {@code consumer}.
     *
     * <p>The response is read as server-sent events. The stream ends when the decoder returns no
     * event (or an event without data), or when the explicit {@code [DONE]} marker is received.
     * Each event is parsed as a {@link MistralAIChatChunkResponse}. Non-empty adapted chunks go
     * to {@code consumer.onStreamChunk}. Once the stream ends, a footer carrying the last usage
     * seen, the estimated cost and the last finish reason seen is passed to
     * {@code consumer.onStreamComplete}.
     *
     * <p>The HTTP response body is always closed, including when parsing or the consumer throws.
     *
     * @param consumer receives stream start, chunk and completion callbacks
     * @param model    model whose id is sent and which estimates the completion cost
     * @param messages conversation to complete
     * @param ccs      completion settings, validated against the settings this client supports
     * @throws Exception if validation, the HTTP call, decoding or the consumer fails
     */
    public void streamChatComplete(LLMClient.StreamedCompletionResponseConsumer consumer, MistralAIConnection.MistralAIModel model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) throws Exception {
        chatCompletionValidator.validate(ccs);
        MistralAIChatQuery query = MistralAIChatQueryAdapter.adaptForStream(model.getId(), messages, ccs);
        logger.trace(() -> String.format("MistralAI raw chat completion streaming query: %s", JSON.pretty((Object)query)));
        ExternalJSONAPIClient.EntityAndRequest ear = this.client.postJSONToStreamAndRequest("/chat/completions", this.networkSettings.queryTimeoutMS, (Object)query);

        MistralAIChatResponse.Usage usage = null;
        LLMClient.FinishReason finishReason = null;

        // The response body must be closed on every exit path (EOF, [DONE] marker, or an exception
        // from the parser / consumer); otherwise the underlying HTTP connection is never released.
        try (InputStream content = ear.entity.getContent()) {
            SSEDecoder decoder = new SSEDecoder(content);
            consumer.onStreamStarted();
            while (true) {
                SSEDecoder.HTTPSSEEvent event = decoder.next();
                if (logger.isTraceEnabled()) {
                    logger.trace((Object)("Received raw event from MistralAI: " + JSON.json((Object)event)));
                }
                if (event == null || event.data == null) {
                    logger.info((Object)"End of MistralAI stream");
                    break;
                }
                if ("[DONE]".equals(event.data)) {
                    logger.info((Object)"Received explicit end marker from MistralAI stream");
                    break;
                }

                MistralAIChatChunkResponse response = (MistralAIChatChunkResponse)JSON.parse((String)event.data, MistralAIChatChunkResponse.class);
                logger.trace(() -> String.format("MistralAI raw streamed chat completion response chunk: %s", JSON.pretty((Object)response)));

                // Keep the most recent usage report; it is emitted in the footer once the stream ends.
                if (response.usage != null) {
                    usage = response.usage;
                }
                // A chunk with an explicitly empty choices list carries no content or finish reason.
                if (response.choices != null && response.choices.isEmpty()) {
                    continue;
                }

                LLMClient.FinishReason reason = MistralAIChatChunkResponseAdapter.extractFinishReason(response);
                if (reason != null) {
                    finishReason = reason;
                }

                LLMClient.StreamedCompletionResponseChunk chunk = MistralAIChatChunkResponseAdapter.adapt(response);
                if (!chunk.isEmpty()) {
                    consumer.onStreamChunk(chunk);
                }
            }
        }

        LLMClient.StreamedCompletionResponseFooter footer = new LLMClient.StreamedCompletionResponseFooter();
        if (usage != null) {
            footer.completionTokens = usage.completionTokens;
            footer.promptTokens = usage.promptTokens;
            footer.totalTokens = usage.totalTokens;
            footer.estimatedCost = model.getEstimatedCompletionCost(footer.promptTokens, footer.completionTokens);
        }
        if (finishReason != null) {
            footer.finishReason = finishReason;
        }
        consumer.onStreamComplete(footer);
    }

    /**
     * Runs a single, non-streamed chat completion.
     *
     * @param model    Mistral AI model id
     * @param messages conversation to complete
     * @param ccs      completion settings, validated against the settings this client supports
     * @return the response adapted by {@link MistralAIChatResponseAdapter#adapt}
     * @throws IOException if the HTTP call fails
     */
    public LLMClient.SimpleCompletionResponse chatComplete(String model, List<LLMClient.ChatMessage> messages, CoreCompletionSettings ccs) throws IOException {
        chatCompletionValidator.validate(ccs);
        MistralAIChatQuery query = MistralAIChatQueryAdapter.adapt(model, messages, ccs);
        logger.trace(() -> String.format("MistralAI raw chat completion query: %s", JSON.pretty((Object)query)));
        MistralAIChatResponse response = (MistralAIChatResponse)this.client.postObjectToJSON("/chat/completions", this.networkSettings.queryTimeoutMS, MistralAIChatResponse.class, (Object)query);
        logger.trace(() -> String.format("MistralAI raw chat completion response: %s", JSON.pretty((Object)response)));
        return MistralAIChatResponseAdapter.adapt(response);
    }

    /**
     * Computes embeddings for a batch of input strings in a single request.
     *
     * <p>The API reports one token total for the whole batch. That total is split evenly across
     * inputs (integer division), so each {@code promptTokens} value is an approximation. When the
     * response carries no usage, {@code promptTokens} is left unset.
     *
     * @param model Mistral AI embedding model id
     * @param input texts to embed
     * @return one response per input, in the order returned by the API
     * @throws IOException if the HTTP call fails, or the response is missing or does not contain
     *                     exactly one embedding per input
     */
    public List<LLMClient.SimpleEmbeddingResponse> embed(String model, List<String> input) throws IOException {
        RawEmbeddingQuery rawEmbedQuery = new RawEmbeddingQuery();
        rawEmbedQuery.model = model;
        rawEmbedQuery.input = input;
        if (logger.isTraceEnabled()) {
            logger.trace((Object)("Raw Mistral AI embedding query: " + JSON.json((Object)rawEmbedQuery)));
        }
        RawEmbeddingResponse rawEmbedResponse = (RawEmbeddingResponse)this.client.postObjectToJSON("/embeddings", this.networkSettings.queryTimeoutMS, RawEmbeddingResponse.class, (Object)rawEmbedQuery);
        if (logger.isTraceEnabled()) {
            logger.trace((Object)("Raw Mistral AI embedding response: " + JSON.json((Object)rawEmbedResponse)));
        }
        // Treat a missing body or data list like a size mismatch rather than failing with an NPE.
        if (rawEmbedResponse == null || rawEmbedResponse.data == null || rawEmbedResponse.data.size() != input.size()) {
            throw new IOException("Mistral AI did not respond with valid embeddings");
        }

        // Loop-invariant per-input token share; guarded against a missing usage object
        // and an empty batch (which would otherwise divide by zero).
        boolean hasUsage = rawEmbedResponse.usage != null && !input.isEmpty();
        int tokensPerInput = hasUsage ? rawEmbedResponse.usage.total_tokens / input.size() : 0;

        List<LLMClient.SimpleEmbeddingResponse> batchResponses = new ArrayList<>(rawEmbedResponse.data.size());
        for (RawEmbeddingObject singleResult : rawEmbedResponse.data) {
            if (singleResult == null) {
                throw new IOException("Mistral AI did not respond with valid embeddings");
            }
            LLMClient.SimpleEmbeddingResponse singleResponse = new LLMClient.SimpleEmbeddingResponse();
            singleResponse.embedding = singleResult.embedding;
            if (hasUsage) {
                singleResponse.promptTokens = tokensPerInput;
            }
            batchResponses.add(singleResponse);
        }
        return batchResponses;
    }

    /**
     * Request body for {@code /embeddings}.
     * NOTE(review): fields are presumably serialized by name through {@link JSON} — confirm.
     */
    private static class RawEmbeddingQuery {
        String model;
        List<String> input;

        private RawEmbeddingQuery() {
        }
    }

    /** Response body of {@code /embeddings}: one entry in {@code data} per input, plus usage. */
    private static class RawEmbeddingResponse {
        List<RawEmbeddingObject> data = new ArrayList<RawEmbeddingObject>();
        RawUsageResponse usage;

        private RawEmbeddingResponse() {
        }
    }

    /** A single embedding vector from the {@code /embeddings} response. */
    private static class RawEmbeddingObject {
        double[] embedding;

        private RawEmbeddingObject() {
        }
    }

    /** Token usage reported by {@code /embeddings}; snake_case names mirror the JSON keys. */
    private static class RawUsageResponse {
        int prompt_tokens;
        int total_tokens;

        private RawUsageResponse() {
        }
    }
}

