/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.llm.online.NonParallelLLMClient;
import com.dataiku.dip.llm.utils.StreamingChunkEmitter;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.json.JSONObject;

public class LLMMeshStreamClient
implements AutoCloseable {
    private final AuthCtx authCtx;
    private final String projectKey;
    private final StreamingChunkEmitter emitter;
    protected final EnrichedLLMStructuredRef enrichedLLMRef;
    private final LLMClient llmClient;
    private final GuardrailsPipelineSettings guardrailsPipelineSettings;
    private final boolean streamingDisabled;
    private final boolean includeLogs;
    private boolean emulatedStreamingInfoChunk = false;
    private AtomicInteger cacheMissQueries = new AtomicInteger(0);
    private AtomicInteger cacheHitQueries = new AtomicInteger(0);
    private ComputeResourceUsage computeResourceUsage;
    private static DKULogger logger = DKULogger.getLogger((String)"dku.llm.online.stream");

    public LLMMeshStreamClient(AuthCtx authCtx, String projectKey, EnrichedLLMStructuredRef enrichedLLMRef, GuardrailsPipelineSettings guardrailsPipelineSettings, boolean streamingDisabled, boolean useDevKernel, StreamingChunkEmitter emitter) throws Exception {
        this.projectKey = projectKey;
        this.enrichedLLMRef = enrichedLLMRef;
        this.authCtx = authCtx;
        this.guardrailsPipelineSettings = guardrailsPipelineSettings;
        this.streamingDisabled = streamingDisabled;
        this.llmClient = LLMClientFactory.get(this.authCtx, this.projectKey, this.enrichedLLMRef, useDevKernel);
        this.emitter = emitter;
        this.includeLogs = useDevKernel;
    }

    public LLMMeshStreamClient(LLMClient llmClient, AuthCtx authCtx, String projectKey, EnrichedLLMStructuredRef enrichedLLMRef, GuardrailsPipelineSettings guardrailsPipelineSettings, boolean streamingDisabled, StreamingChunkEmitter emitter) throws Exception {
        this.projectKey = projectKey;
        this.enrichedLLMRef = enrichedLLMRef;
        this.authCtx = authCtx;
        this.guardrailsPipelineSettings = guardrailsPipelineSettings;
        this.streamingDisabled = streamingDisabled;
        this.llmClient = llmClient;
        this.emitter = emitter;
        this.includeLogs = false;
    }

    public LLMClient.SimpleCompletionResponseOrError streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings completionSettings) {
        if (this.useEmulatedStreaming()) {
            return this.emulateStreamCompletion(query, completionSettings);
        }
        return this.streamCompletion(query, completionSettings);
    }

    private boolean useEmulatedStreaming() {
        if (this.streamingDisabled) {
            logger.info((Object)"Streaming disabled");
            return true;
        }
        if (!this.llmClient.supportsStream()) {
            logger.info((Object)("Streaming requested but LLM " + this.enrichedLLMRef.id + " does not support it"));
            return true;
        }
        if (GuardrailsPipelineUtils.needsNonStreamedNonParallelProcessing(this.guardrailsPipelineSettings)) {
            logger.info((Object)"Streaming requested but Guardrails don't support it");
            return true;
        }
        return false;
    }

    /*
     * Exception decompiling
     */
    private LLMClient.SimpleCompletionResponseOrError streamCompletion(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings completionSettings) {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Tried to end blocks [2[TRYBLOCK], 1[TRYBLOCK]], but top level block is 16[CASE]
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.processEndingBlocks(Op04StructuredStatement.java:435)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:484)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private LLMClient.SimpleCompletionResponseOrError emulateStreamCompletion(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings completionSettings) {
        logger.info((Object)("Streaming is not supported. " + JSON.json((Object)this.enrichedLLMRef)));
        try (NonParallelLLMClient nplc = new NonParallelLLMClient(this.authCtx, this.llmClient, this.guardrailsPipelineSettings, this.projectKey, null);){
            ComputeResourceUsage cru;
            ArrayList<LLMClient.SingleCompletionQuery> singleCompletionQueries = new ArrayList<LLMClient.SingleCompletionQuery>();
            singleCompletionQueries.add(query);
            this.emitter.setInterruptCallback(nplc);
            if (this.emulatedStreamingInfoChunk) {
                this.sendEmulatedStreamingInfoChunk();
            }
            List<LLMClient.SimpleCompletionResponseOrError> responses = nplc.completeQueries(singleCompletionQueries, completionSettings);
            LLMClient.SimpleCompletionResponseOrError resp = responses.get(0);
            if (this.includeLogs) {
                resp.log = this.llmClient.getKernelLog();
            }
            if (resp.ok) {
                this.streamCompletionResponse(resp);
            }
            if ((cru = nplc.getTotalCRU(ComputeResourceUsage.LLMUsageType.COMPLETION)) != null) {
                this.computeResourceUsage = cru;
            }
            if (this.emitter.isInterrupted()) {
                LLMClient.SimpleCompletionResponseOrError simpleCompletionResponseOrError = LLMClient.SimpleCompletionResponseOrError.fromError(new Exception("Response generation was interrupted."));
                return simpleCompletionResponseOrError;
            }
            LLMClient.SimpleCompletionResponseOrError simpleCompletionResponseOrError = resp;
            return simpleCompletionResponseOrError;
        }
        catch (Exception e) {
            if (!this.emitter.isInterrupted()) return LLMClient.SimpleCompletionResponseOrError.fromError(e);
            return LLMClient.SimpleCompletionResponseOrError.fromError(new Exception("Response generation was interrupted."));
        }
    }

    private void streamCompletionResponse(LLMClient.SimpleCompletionResponseOrError resp) throws Exception {
        this.emitter.initSuccess();
        LLMClient.StreamedCompletionResponseChunk completeChunk = new LLMClient.StreamedCompletionResponseChunk();
        completeChunk.type = "content";
        completeChunk.text = resp.text;
        completeChunk.toolCalls = resp.toolCalls;
        completeChunk.logProbs = resp.logProbs;
        completeChunk.artifacts = resp.artifacts;
        this.emitter.emitCompletionChunk(completeChunk);
        LLMClient.StreamedCompletionResponseFooter footer = new LLMClient.StreamedCompletionResponseFooter();
        footer.finishReason = resp.finishReason;
        footer.trace = resp.trace;
        footer.additionalInformation = resp.additionalInformation;
        footer.promptTokens = resp.promptTokens;
        footer.completionTokens = resp.completionTokens;
        footer.totalTokens = resp.totalTokens;
        footer.tokenCountsAreEstimated = resp.tokenCountsAreEstimated;
        footer.estimatedCost = resp.estimatedCost;
        this.emitter.emitCompletionFooter(footer);
    }

    private void sendEmulatedStreamingInfoChunk() throws Exception {
        JSONObject noStreamJson = new JSONObject();
        if (this.streamingDisabled) {
            noStreamJson.put("text", (Object)"Streaming is disabled. Response time may be longer.");
        } else {
            noStreamJson.put("text", (Object)"Selected model does not support streaming. Response time may be longer.");
        }
        this.emitter.emitEmulateStreamingInfoChunk(noStreamJson.toString());
    }

    public ComputeResourceUsage getTotalCRU() {
        if (this.computeResourceUsage != null) {
            return this.computeResourceUsage;
        }
        ComputeResourceUsage totalCRU = this.llmClient.getTotalCRU(ComputeResourceUsage.LLMUsageType.COMPLETION, this.enrichedLLMRef);
        if (totalCRU != null) {
            totalCRU.llmUsage.cacheMissQueries = this.cacheMissQueries.get();
            totalCRU.llmUsage.cacheHitQueries = this.cacheHitQueries.get();
            totalCRU.llmUsage.totalQueries = this.cacheMissQueries.get() + this.cacheHitQueries.get();
        }
        return totalCRU;
    }

    public AbstractLLMConnection getConnection() {
        return this.llmClient.getConnection();
    }

    public void enableEmulatedStreamingInfoChunk() {
        this.emulatedStreamingInfoChunk = true;
    }

    @Override
    public void close() throws Exception {
        this.llmClient.close();
    }
}

