/*
 * Decompiled with CFR 0.152.
 */
package com.customllm.llm;

import com.customllm.llm.CustomApiImplementation;
import com.customllm.llm.GPTAIChatChunkResponseAdapter;
import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.openai.OpenAIMode;
import com.dataiku.dip.llm.online.openai.api.OpenAIChatQuery;
import com.dataiku.dip.llm.online.openai.api.OpenAIChatQueryAdapter;
import com.dataiku.dip.llm.online.openai.api.OpenAIChatResponse;
import com.dataiku.dip.streaming.endpoints.httpsse.SSEDecoder;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import com.google.gson.annotations.SerializedName;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import javax.annotation.Nullable;

/**
 * {@link CustomApiImplementation} that sends chat requests to an OpenAI-style chat-completions
 * endpoint.
 *
 * <p>The object is stateful:
 * <ol>
 *   <li>{@link #addSettingsInObject} stores the completion settings.</li>
 *   <li>{@link #addMessagesInObject} builds the request from those settings and the messages.</li>
 *   <li>{@link #sendPostObject} (blocking) or {@link #streamChatComplete} (server-sent events)
 *       sends that request.</li>
 * </ol>
 *
 * <p>The request lives in mutable, unsynchronized fields. Instances should therefore presumably
 * not be shared across threads.
 */
public class GptApiImplementation implements CustomApiImplementation {
    /** JSON object builder created in the constructor; not read anywhere in this class. */
    JF.ObjectBuilder ob;
    /** Completion settings captured by {@link #addSettingsInObject}. */
    CoreCompletionSettings ccs;
    /** Request built by {@link #addMessagesInObject}; sent by the two send methods. */
    OpenAIChatQuery query;
    /** URL the request is POSTed to. */
    String endpoint;
    /** Model identifier placed in the request by {@link #addMessagesInObject}. */
    String model;
    // NOTE(review): this logger category looks copied from an NVIDIA NIM plugin. Confirm it is intended.
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.customplugin.nvidia.nim");

    /**
     * @param endpointUrl chat-completions URL to POST requests to
     * @param model       model identifier sent with every request
     */
    public GptApiImplementation(String endpointUrl, String model) {
        this.endpoint = endpointUrl;
        this.model = model;
        this.ob = JF.obj();
    }

    /**
     * Stores the completion settings used by the next {@link #addMessagesInObject} call.
     *
     * <p>The {@code model} argument is not used here. The request always uses the model passed
     * to the constructor.
     */
    @Override
    public void addSettingsInObject(String model, Integer maxTokens, Double temperature, Double topP, Integer topK, List<String> stopSequences, LLMClient.ToolChoice toolChoice, List<LLMClient.AbstractTool> tools) {
        this.ccs = new CoreCompletionSettings();
        this.ccs.maxTokens = maxTokens;
        this.ccs.temperature = temperature;
        this.ccs.topP = topP;
        this.ccs.topK = topK;
        this.ccs.stopSequences = stopSequences;
        this.ccs.toolChoice = toolChoice;
        this.ccs.tools = tools;
    }

    /**
     * Builds the request for {@code messages}, using the constructor's model and the settings
     * previously stored by {@link #addSettingsInObject}.
     */
    @Override
    public void addMessagesInObject(List<LLMClient.ChatMessage> messages) {
        this.query = OpenAIChatQueryAdapter.adaptForStreaming(OpenAIMode.OPENAI, this.model, messages, this.ccs, false);
    }

    /**
     * Sends the built request as a single, non-streamed completion.
     *
     * <p>This turns streaming off on the stored request, so a later {@link #streamChatComplete}
     * on the same instance sends the request with streaming disabled.
     *
     * @return the first choice's text, plus token counts when the server reports usage
     * @throws IOException           if the server returns no usable choice or message
     * @throws IllegalStateException if {@link #addMessagesInObject} has not been called
     */
    @Override
    public LLMClient.SimpleCompletionResponse sendPostObject(ExternalJSONAPIClient client, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings) throws IOException {
        OpenAIChatQuery batchQuery = requireQuery();
        batchQuery.stream = false;
        batchQuery.streamOptions = null;
        logger.debug("Batch:" + JSON.pretty(batchQuery));
        logger.debug("Endpoint:" + this.endpoint);
        RawChatCompletionResponse rcr = (RawChatCompletionResponse) client.postObjectToJSON(this.endpoint, networkSettings.queryTimeoutMS, RawChatCompletionResponse.class, batchQuery);
        if (rcr == null || rcr.choices == null || rcr.choices.isEmpty()) {
            throw new IOException("Chat did not respond with valid completion");
        }
        RawChatCompletionChoice first = rcr.choices.get(0);
        if (first == null || first.message == null) {
            throw new IOException("Chat did not respond with valid completion");
        }
        LLMClient.SimpleCompletionResponse ret = new LLMClient.SimpleCompletionResponse();
        ret.text = first.message.content;
        // Some OpenAI-compatible servers omit "usage"; leave the token counts unset in that case.
        if (rcr.usage != null) {
            ret.promptTokens = rcr.usage.prompt_tokens;
            ret.completionTokens = rcr.usage.completion_tokens;
        }
        return ret;
    }

    /**
     * Sends the built request and forwards each decoded server-sent event chunk to
     * {@code consumer}.
     *
     * <p>The stream ends when the body is exhausted or a {@code [DONE]} event arrives. Once any
     * refusal text is seen, later chunks are not forwarded, and a {@link LLMClient.RefusalException}
     * carrying the accumulated refusal text is thrown after the stream ends. Otherwise a footer
     * with the last reported usage and finish reason is passed to
     * {@code consumer.onStreamComplete}.
     *
     * <p>{@code toolChoice} and {@code tools} are not read here. The tool settings come from
     * {@link #addSettingsInObject}.
     *
     * @throws IllegalStateException if {@link #addMessagesInObject} has not been called
     */
    @Override
    public void streamChatComplete(ExternalJSONAPIClient client, LLMClient.StreamedCompletionResponseConsumer consumer, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, LLMClient.ToolChoice toolChoice, List<LLMClient.AbstractTool> tools) throws Exception {
        OpenAIChatQuery streamingQuery = requireQuery();
        logger.debug("Final Request AdamG:" + JSON.pretty(streamingQuery));
        ExternalJSONAPIClient.EntityAndRequest ear = client.postJSONToStreamAndRequest(this.endpoint, networkSettings.queryTimeoutMS, streamingQuery);
        OpenAIChatResponse.Usage usage = null;
        LLMClient.FinishReason finishReason = null;
        StringBuilder refusalBuilder = null;
        // Always close the response body so the underlying connection is released,
        // even when decoding or the consumer throws.
        try (InputStream content = ear.entity.getContent()) {
            SSEDecoder decoder = new SSEDecoder(content);
            consumer.onStreamStarted();
            while (true) {
                SSEDecoder.HTTPSSEEvent event = decoder.next();
                if (logger.isTraceEnabled()) {
                    logger.trace("Received raw event from OpenAI: " + JSON.json(event));
                }
                if (event == null || event.data == null) {
                    logger.info("End of OpenAI stream");
                    break;
                }
                if (event.data.equals("[DONE]")) {
                    logger.info("Received explicit end marker from OpenAI stream");
                    break;
                }
                OpenAIChatChunkResponse response = (OpenAIChatChunkResponse) JSON.parse(event.data, OpenAIChatChunkResponse.class);
                if (response == null) {
                    continue;
                }
                if (response.usage != null) {
                    usage = response.usage;
                }
                // Usage-only chunks may have an empty or absent "choices" array.
                if (response.choices == null || response.choices.isEmpty()) {
                    continue;
                }
                LLMClient.FinishReason reason = GPTAIChatChunkResponseAdapter.extractFinishReason(response);
                if (reason != null) {
                    finishReason = reason;
                }
                String refusalChunk = GPTAIChatChunkResponseAdapter.getRefusal(response);
                if (refusalChunk != null) {
                    if (refusalBuilder == null) {
                        refusalBuilder = new StringBuilder();
                    }
                    refusalBuilder.append(refusalChunk);
                }
                // After a refusal has started, stop forwarding content to the consumer.
                if (refusalBuilder == null) {
                    LLMClient.StreamedCompletionResponseChunk chunk = GPTAIChatChunkResponseAdapter.adapt(response);
                    if (!chunk.isEmpty()) {
                        consumer.onStreamChunk(chunk);
                    }
                }
            }
        }
        if (refusalBuilder != null) {
            LLMClient.RefusalException refusalException = new LLMClient.RefusalException(refusalBuilder.toString());
            if (usage != null) {
                refusalException.completionTokens = usage.completionTokens;
                refusalException.promptTokens = usage.promptTokens;
                refusalException.totalTokens = usage.totalTokens;
            }
            throw refusalException;
        }
        LLMClient.StreamedCompletionResponseFooter footer = new LLMClient.StreamedCompletionResponseFooter();
        if (usage != null) {
            footer.completionTokens = usage.completionTokens;
            footer.promptTokens = usage.promptTokens;
            footer.totalTokens = usage.totalTokens;
        }
        if (finishReason != null) {
            footer.finishReason = finishReason;
        }
        consumer.onStreamComplete(footer);
    }

    /** Returns the built request, or fails clearly if {@link #addMessagesInObject} was never called. */
    private OpenAIChatQuery requireQuery() {
        if (this.query == null) {
            throw new IllegalStateException("No chat query built: call addMessagesInObject() first");
        }
        return this.query;
    }

    /** JSON shape of a non-streamed chat-completion response. */
    private static class RawChatCompletionResponse {
        List<RawChatCompletionChoice> choices;
        RawUsageResponse usage;

        private RawChatCompletionResponse() {
        }
    }

    /** One entry of {@code choices} in a non-streamed response. */
    public static class RawChatCompletionChoice {
        @SerializedName(value="finish_reason")
        public String finishReason;
        public ChoiceMessage message;
        @Nullable
        @SerializedName(value="logprobs")
        public LogProbs logProbs;
    }

    /** Assistant message carried by a non-streamed choice. */
    public static class ChoiceMessage {
        public String role;
        @Nullable
        public String content;
        @Nullable
        @SerializedName(value="tool_calls")
        public List<ToolCall> toolCalls;
        @Nullable
        public String refusal;
    }

    /** Token counts of a non-streamed response; field names match the JSON keys. */
    private static class RawUsageResponse {
        int total_tokens;
        int prompt_tokens;
        int completion_tokens;

        private RawUsageResponse() {
        }
    }

    /**
     * JSON shape of one streamed chunk.
     *
     * <p>Static, so it can be built without an enclosing instance during JSON parsing.
     */
    public static class OpenAIChatChunkResponse {
        public List<Choice> choices;
        @Nullable
        public OpenAIChatResponse.Usage usage;
    }

    /** Function name and raw argument string of a tool call. */
    public static class ToolCallFunction {
        public String name;
        public String arguments;
    }

    /** Complete tool call from a non-streamed message. */
    public static class ToolCall {
        public String id;
        public ToolCallFunction function;
    }

    /** One alternative token with its log-probability. */
    public static class TopLogProbContent {
        public String token;
        @SerializedName(value="logprob")
        public double logProb;
    }

    /** Log-probability of a sampled token, plus its top alternatives. */
    public static class LogProbContent {
        public String token;
        @SerializedName(value="logprob")
        public double logProb;
        @SerializedName(value="top_logprobs")
        public List<TopLogProbContent> topLogProbs;
    }

    /** Container for the {@code logprobs} object of a choice. */
    public static class LogProbs {
        @Nullable
        public List<LogProbContent> content;
    }

    /** Incremental content of one streamed choice. */
    public static class Delta {
        @Nullable
        public String content;
        @Nullable
        @SerializedName(value="tool_calls")
        public List<PartToolCall> toolCalls;
        @Nullable
        public String refusal;
    }

    /** Fragment of a tool-call function; either field may be absent in a given chunk. */
    public static class PartialToolCallFunction {
        @Nullable
        public String name;
        @Nullable
        public String arguments;
    }

    /** Fragment of a tool call received in a streamed chunk. */
    public static class PartToolCall {
        @Nullable
        public Integer index;
        @Nullable
        public String id;
        public PartialToolCallFunction function;
    }

    /** One entry of {@code choices} in a streamed chunk. */
    public static class Choice {
        public Delta delta;
        @Nullable
        @SerializedName(value="finish_reason")
        public String finishReason;
        @Nullable
        @SerializedName(value="logprobs")
        public LogProbs logProbs;
    }
}

