/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.futures.FutureProgress;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.cache.ILLMCacheService;
import com.dataiku.dip.llm.governance.GuardrailRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.llm.online.LLMMeshClient;
import com.dataiku.dip.llm.online.LLMTracingUtils;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.google.common.collect.Lists;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import java.lang.invoke.LambdaMetafactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * {@link LLMMeshClient} that sends completion and embedding queries to a single
 * underlying {@link LLMClient} strictly one at a time, in the calling thread.
 *
 * <p>For each embedding query the visible code runs the configured guardrails
 * pipeline, optionally consults the LLM cache, calls the LLM with a one-element
 * batch and records the work in LLM Mesh trace spans. The completion path is
 * presumably similar, but its per-query method was not recovered (see below).
 * Cache hits and misses are counted per instance and reported through
 * {@link #getTotalCRU}.
 *
 * <p>NOTE(review): this is CFR decompiler output and does not compile as-is.
 * {@code processSingleQuery} failed to decompile; its body only throws.
 * {@code embedSingleQuery} still contains unresolved CFR control-flow markers
 * ({@code ** break}, {@code ** GOTO}, {@code ** default}), untyped locals and a
 * raw {@code LambdaMetafactory} call. Treat the control flow of those two methods
 * as approximate, and recover it from the original source before editing.
 */
public class NonParallelLLMClient
implements LLMMeshClient {
    /** LLM cache service; injected by the {@code SpringUtils.autowire(this)} call in each constructor. */
    @Autowired
    private ILLMCacheService cacheService;
    /** Underlying client that performs the actual LLM calls; closed by {@link #close()}. */
    private final LLMClient llmClient;
    /** Guardrails configuration; a new {@link GuardrailsPipelineRunner} is built from it for every embedding query. */
    private final GuardrailsPipelineSettings guardrailsPipelineSettings;
    /** Enriched reference of the underlying client, captured once at construction. */
    private final EnrichedLLMStructuredRef enrichedRef;
    /** LLM identifier passed to the cache service and to the tracing helpers. */
    private final String llmId;
    /** Security context passed to the client factory, the guardrails runner and the cache service. */
    private final AuthCtx authCtx;
    /** Project key passed to the client factory, the guardrails runner and the cache service. */
    private final String contextProjectKey;
    // Per-instance cache statistics, reported by getTotalCRU().
    // In the visible code, embedSingleQuery increments them:
    //   - cacheHitQueries: a successful cached response was reused;
    //   - cacheMissQueries: the query went on to the LLM.
    // Presumably processSingleQuery updates them for completions too (its body was
    // not recovered) -- TODO confirm.
    // These are plain, unsynchronized ints.
    // NOTE(review): not safe if one instance were shared across threads.
    public int cacheHitQueries;
    public int cacheMissQueries;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.online.nonparallel");

    /**
     * Creates a client whose underlying LLM is resolved through {@link LLMClientFactory}.
     *
     * @param authCtx security context used to resolve the LLM and for guardrails / cache calls
     * @param llmStructuredRef reference of the LLM to use; its {@code id} is used as the LLM id
     * @param guardrailsPipelineSettings guardrails configuration applied to queries
     * @param contextProjectKey project key used to resolve the LLM and for guardrails / cache calls
     * @param usedDataset not used by this constructor
     * @throws Exception as declared; presumably propagated from LLM client resolution
     */
    public NonParallelLLMClient(AuthCtx authCtx, LLMStructuredRef llmStructuredRef, GuardrailsPipelineSettings guardrailsPipelineSettings, String contextProjectKey, AnyLoc usedDataset) throws Exception {
        SpringUtils.getInstance().autowire((Object)this);
        this.llmClient = LLMClientFactory.get(authCtx, contextProjectKey, llmStructuredRef);
        this.authCtx = authCtx;
        this.contextProjectKey = contextProjectKey;
        this.llmId = llmStructuredRef.id;
        this.enrichedRef = this.llmClient.getEnrichedRef();
        this.guardrailsPipelineSettings = guardrailsPipelineSettings;
    }

    /**
     * Creates a client that wraps an already-created {@link LLMClient}.
     * The LLM id is taken from {@code client.getEnrichedRef().id}.
     *
     * @param authCtx security context used for guardrails / cache calls
     * @param client the underlying client to delegate LLM calls to; closed by {@link #close()}
     * @param guardrailsPipelineSettings guardrails configuration applied to queries
     * @param contextProjectKey project key used for guardrails / cache calls
     * @param usedDataset not used by this constructor
     * @throws Exception as declared
     */
    public NonParallelLLMClient(AuthCtx authCtx, LLMClient client, GuardrailsPipelineSettings guardrailsPipelineSettings, String contextProjectKey, AnyLoc usedDataset) throws Exception {
        SpringUtils.getInstance().autowire((Object)this);
        this.llmClient = client;
        this.authCtx = authCtx;
        this.contextProjectKey = contextProjectKey;
        this.llmId = client.getEnrichedRef().id;
        this.enrichedRef = this.llmClient.getEnrichedRef();
        this.guardrailsPipelineSettings = guardrailsPipelineSettings;
    }

    /**
     * Runs the given completion queries one after another and returns one
     * response-or-error per query, in input order.
     *
     * <p>Caching is enabled only when the underlying connection exists, has LLM
     * connection params, and has {@code cachingEnabled} set. After each query, the
     * current {@link FutureProgress} state is incremented by 1.0.
     *
     * <p>NOTE(review): in this decompiled form, {@code processSingleQuery} always
     * throws {@link IllegalStateException}, so this method fails on the first query.
     *
     * @param queries completion queries, processed sequentially
     * @param settings completion settings applied to every query
     * @return a new list with one entry per input query
     */
    @Override
    public List<LLMClient.SimpleCompletionResponseOrError> completeQueries(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) {
        logger.debug((Object)("Submitting " + queries.size() + " prompt to LLM " + String.valueOf(this.llmClient)));
        ArrayList<LLMClient.SimpleCompletionResponseOrError> ret = new ArrayList<LLMClient.SimpleCompletionResponseOrError>();
        FutureProgressState fps = FutureProgress.getState();
        // Null-safe check of the connection-level completion caching flag.
        boolean useCache = this.llmClient.getConnection() != null && this.llmClient.getConnection().getLLMConnectionParams() != null && this.llmClient.getConnection().getLLMConnectionParams().cachingEnabled;
        for (LLMClient.SingleCompletionQuery query : queries) {
            ret.add(this.processSingleQuery(query, settings, useCache));
            try {
                fps.increment(1.0);
            }
            catch (InterruptedException e) {
                // Restore the interrupt flag but do not stop: the remaining queries are still processed.
                Thread.currentThread().interrupt();
            }
        }
        return ret;
    }

    /**
     * Processes a single completion query.
     *
     * <p>Exception decompiling: CFR could not recover this method (stack trace below).
     * In this file it therefore always throws
     * {@code IllegalStateException("Decompilation failed")}.
     *
     * <p>Presumably, by analogy with {@code embedSingleQuery}, the original:
     * <ul>
     *   <li>applied guardrails;</li>
     *   <li>consulted the cache when {@code useCache} is set;</li>
     *   <li>otherwise called {@code performActualCompletionCall}, which has no
     *       other caller in this file.</li>
     * </ul>
     * TODO confirm against the original source.
     */
    private LLMClient.SimpleCompletionResponseOrError processSingleQuery(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, boolean useCache) {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Tried to end blocks [2[TRYBLOCK], 1[TRYBLOCK], 0[TRYBLOCK], 3[TRYBLOCK]], but top level block is 22[CASE]
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.processEndingBlocks(Op04StructuredStatement.java:435)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:484)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Sends one completion query to the underlying LLM inside a
     * {@code DKU_LLM_MESH_LLM_CALL} trace span and returns its single result.
     *
     * <p>LLM failures are not thrown. Any {@link Exception} from {@code completeBatch}
     * is logged at warn level and converted into an error response via
     * {@code SimpleCompletionResponseOrError.fromError}.
     *
     * <p>On every path, the call span is closed and attached to {@code globalSpan}.
     * When a response exists:
     * <ul>
     *   <li>its usage metadata and output are recorded on the call span;</li>
     *   <li>its own trace, if non-null, is nested under the call span.</li>
     * </ul>
     *
     * <p>Decompiler notes. CFR reported: "Removed try catching itself - possible
     * behaviour change", "Enabled force condition propagation", and "Lifted jumps to
     * return sites".
     * <ul>
     *   <li>The bookkeeping code appears three times: after the try, inside the catch,
     *       and inside the {@code catch (Throwable)}. These copies are CFR's inlining of
     *       what was presumably a single {@code finally} block followed by
     *       {@code return resp.get(0)}.</li>
     *   <li>The {@code resp == null} branches cannot be taken on the normal paths,
     *       because {@code resp} is assigned in both the try and the catch.</li>
     *   <li>An empty list from {@code completeBatch} makes {@code resp.get(0)} throw
     *       {@link IndexOutOfBoundsException}.</li>
     *   <li>The local {@code result} is unused.</li>
     * </ul>
     *
     * @param query the completion query, sent as a one-element batch
     * @param settings completion settings forwarded to the LLM
     * @param globalSpan parent span that receives the LLM call span as an observation
     * @return the first response-or-error produced for the query
     */
    private LLMClient.SimpleCompletionResponseOrError performActualCompletionCall(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.LLMMeshTraceSpan globalSpan) {
        LLMClient.SimpleCompletionResponseOrError result = null;
        logger.info((Object)"Sending a single completion query to LLM");
        List<LLMClient.SimpleCompletionResponseOrError> resp = null;
        LLMClient.LLMMeshTraceSpan llmCallSpan = LLMClient.LLMMeshTraceSpan.start("DKU_LLM_MESH_LLM_CALL");
        LLMTracingUtils.addIdentifiersAndSetCompletionInput(llmCallSpan, this.llmId, this.llmClient, query, settings);
        try {
            List<LLMClient.SimpleCompletionResponse> rawResponses = this.llmClient.completeBatch(Lists.newArrayList((Object[])new LLMClient.SingleCompletionQuery[]{query}), settings);
            logger.info((Object)"Got successful single completion answer from LLM");
            resp = rawResponses.stream().map(LLMClient.SimpleCompletionResponseOrError::fromSuccess).collect(Collectors.toList());
            llmCallSpan.close();
        }
        catch (Exception e) {
            try {
                // LLM errors become a single error response instead of propagating.
                logger.warn((Object)"Got error from LLM while retrieving single completion answer", (Throwable)e);
                resp = Collections.nCopies(1, LLMClient.SimpleCompletionResponseOrError.fromError(e));
                llmCallSpan.close();
            }
            catch (Throwable throwable) {
                // CFR rendering of the finally path taken when the error handling itself throws:
                // same span bookkeeping, then rethrow.
                llmCallSpan.close();
                globalSpan.addObservation(llmCallSpan);
                if (resp == null) throw throwable;
                if (resp.isEmpty()) throw throwable;
                LLMClient.SimpleCompletionResponseOrError r = resp.get(0);
                LLMTracingUtils.addUsageMetadataAndSetCompletionOutput(llmCallSpan, r);
                if (r.trace == null) throw throwable;
                llmCallSpan.addObservation(r.trace);
                throw throwable;
            }
            // Error path: attach the call span and record the error response on it.
            globalSpan.addObservation(llmCallSpan);
            if (resp == null) return resp.get(0);
            if (resp.isEmpty()) return resp.get(0);
            LLMClient.SimpleCompletionResponseOrError r = resp.get(0);
            LLMTracingUtils.addUsageMetadataAndSetCompletionOutput(llmCallSpan, r);
            if (r.trace == null) return resp.get(0);
            llmCallSpan.addObservation(r.trace);
            return resp.get(0);
        }
        // Success path. Note that llmCallSpan was already closed above, before usage
        // metadata and output are recorded on it.
        globalSpan.addObservation(llmCallSpan);
        if (resp == null) return resp.get(0);
        if (resp.isEmpty()) return resp.get(0);
        LLMClient.SimpleCompletionResponseOrError r = resp.get(0);
        LLMTracingUtils.addUsageMetadataAndSetCompletionOutput(llmCallSpan, r);
        if (r.trace == null) return resp.get(0);
        llmCallSpan.addObservation(r.trace);
        return resp.get(0);
    }

    /**
     * Embeds the given queries one after another and returns one
     * response-or-error per query, in input order.
     *
     * <p>Caching is enabled only when the underlying connection exists, has LLM
     * connection params, and has {@code embeddingsCachingEnabled} set. After each
     * query, the current {@link FutureProgress} state is incremented by 1.0.
     *
     * @param queries embedding queries, processed sequentially
     * @param settings embedding settings applied to every query
     * @return a new list with one entry per input query
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponseOrError> embedQueries(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) {
        logger.info((Object)("Submitting " + queries.size() + " embeddings to LLM " + String.valueOf(this.llmClient)));
        FutureProgressState fps = FutureProgress.getState();
        // Null-safe check of the connection-level embedding caching flag.
        boolean useCache = this.llmClient.getConnection() != null && this.llmClient.getConnection().getLLMConnectionParams() != null && this.llmClient.getConnection().getLLMConnectionParams().embeddingsCachingEnabled;
        ArrayList<LLMClient.SimpleEmbeddingResponseOrError> ret = new ArrayList<LLMClient.SimpleEmbeddingResponseOrError>();
        for (LLMClient.EmbeddingQuery query : queries) {
            ret.add(this.embedSingleQuery(query, settings, useCache));
            try {
                fps.increment(1.0);
            }
            catch (InterruptedException e) {
                // Restore the interrupt flag but do not stop: the remaining queries are still processed.
                Thread.currentThread().interrupt();
            }
        }
        return ret;
    }

    /**
     * Embeds a single query, in five steps:
     * <ol>
     *   <li>Open a {@code DKU_LLM_MESH_EMBEDDING_QUERY} span.</li>
     *   <li>Run the guardrails pipeline on the query inside a
     *       {@code DKU_LLM_MESH_QUERY_ENFORCEMENT} child span. Depending on the
     *       action, this throws, or rewrites the query and possibly captures audit
     *       data.</li>
     *   <li>When {@code useCache} is set, look the query up in the cache. A
     *       successful cached response is returned directly and counted as a hit.</li>
     *   <li>Otherwise, count a miss and call the LLM with a one-element batch inside
     *       a {@code DKU_LLM_MESH_LLM_CALL} span. LLM errors become error responses.</li>
     *   <li>Store successful responses in the cache (when enabled), attach the
     *       guardrails audit data and the global span, and return the result.</li>
     * </ol>
     *
     * <p>Errors during guardrails/cache preprocessing are logged and returned as an
     * error response rather than thrown.
     *
     * <p>Decompiler notes. CFR reported: "Removed try catching itself - possible
     * behaviour change" and "Unable to fully structure code".
     * <ul>
     *   <li>The {@code ** ...} lines, the {@code lblNN:} labels and the untyped locals
     *       are CFR residue.</li>
     *   <li>The null-checked {@code close()} calls in {@code finally} blocks look like
     *       try-with-resources on {@code globalSpan}, {@code guardrailSpan} and the
     *       guardrails runner. TODO confirm against the original source.</li>
     * </ul>
     *
     * @param query the embedding query; may be modified in place by the guardrails
     *        pipeline before the cache lookup and the LLM call
     * @param settings embedding settings forwarded to guardrails, cache and LLM
     * @param useCache whether to consult and populate the LLM cache
     * @return the response-or-error for the query, with {@code trace} set to the global span
     */
    private LLMClient.SimpleEmbeddingResponseOrError embedSingleQuery(LLMClient.EmbeddingQuery query, LLMClient.EmbeddingSettings settings, boolean useCache) {
        result = null;
        guardrailAuditData = new JsonArray();
        globalSpan = LLMClient.LLMMeshTraceSpan.start("DKU_LLM_MESH_EMBEDDING_QUERY");
        try {
            block43: {
                block42: {
                    LLMTracingUtils.addLLMIdentifiers(globalSpan, this.llmId, this.llmClient);
                    LLMTracingUtils.setEmbeddingInput(globalSpan, query, settings);
                    guardrailsPipelineRunner = new GuardrailsPipelineRunner(this.authCtx, this.contextProjectKey, this.guardrailsPipelineSettings);
                    guardrailSpan = globalSpan.withChildSpan("DKU_LLM_MESH_QUERY_ENFORCEMENT");
                    try {
                        guardrailResponse = guardrailsPipelineRunner.processEmbeddingQuery(new GuardrailRunner.GuardrailContext(), query, settings, guardrailSpan);
                        // CFR synthetic switch map. The case labels are switch-map indices, not enum
                        // ordinals, so the QueryGuardrailAction constant names cannot be recovered here.
                        switch (1.$SwitchMap$com$dataiku$dip$llm$governance$GuardrailRunner$QueryGuardrailAction[guardrailResponse.action.ordinal()]) {
                            case 1: {
                                // Guardrails rejected the query: throw the exception built from the response.
                                throw guardrailResponse.toException();
                            }
                            case 2: {
                                // The query is rewritten in place from the guardrails response.
                                GuardrailsPipelineUtils.updateEmbeddingQueryFromGuardrailsResponse(query, guardrailResponse);
                                ** break;
lbl19:
                                // 1 sources

                                break;
                            }
                            case 3: {
                                // The query is rewritten, and the guardrails audit data is kept for the result.
                                GuardrailsPipelineUtils.updateEmbeddingQueryFromGuardrailsResponse(query, guardrailResponse);
                                guardrailAuditData = guardrailResponse.auditData;
                                break;
                            }
                            ** default:
lbl25:
                            // 1 sources

                            break;
                        }
                    }
                    finally {
                        if (guardrailSpan != null) {
                            guardrailSpan.close();
                        }
                    }
                    // The cache lookup uses the query as possibly rewritten by the guardrails above.
                    cacheResult = null;
                    if (useCache) {
                        cacheResult = this.cacheService.get(this.authCtx, this.llmId, this.contextProjectKey, query, settings);
                    }
                    cached = null;
                    // No usable hit (caching off, cache miss, or a cached error response):
                    // leave block42 and continue on the miss path.
                    if (cacheResult == null || !cacheResult.cacheHit || !((LLMClient.SimpleEmbeddingResponseOrError)cacheResult.result).ok) break block42;
                    NonParallelLLMClient.logger.debug((Object)"Got cache hit");
                    globalSpan.withChildEvent("DKU_LLM_MESH_CACHE_HIT");
                    ++this.cacheHitQueries;
                    // NOTE(review): this mutates the object obtained from the cache in place
                    // (fromCache, estimatedCost, trace). If the cache service hands out shared
                    // instances, this alters the stored entry -- confirm.
                    // NOTE(review): guardrailsAuditData is not set on this cache-hit path,
                    // unlike the miss path below.
                    cached = (LLMClient.SimpleEmbeddingResponseOrError)cacheResult.result;
                    cached.fromCache = true;
                    cached.estimatedCost = 0.0;
                    globalSpan.attributes.add("embeddingResponse", (JsonElement)JF.obj().with("ok", Boolean.valueOf(true)).with("promptTokens", (Number)((LLMClient.SimpleEmbeddingResponseOrError)cacheResult.result).promptTokens).with("tokenCountsAreEstimated", ((LLMClient.SimpleEmbeddingResponseOrError)cacheResult.result).tokenCountsAreEstimated).with("estimatedCost", (Number)0).get());
                    cached.trace = globalSpan;
                    var10_19 = cached;
                    guardrailsPipelineRunner.close();
                    return var10_19;
                }
                try {
                    // Cache miss (or caching disabled): count it and fall through to the LLM call.
                    ++this.cacheMissQueries;
                    break block43;
                    // CFR residue below. The rethrow and the finally closing the guardrails
                    // runner presumably come from a try-with-resources around it.
                    {
                        catch (Throwable cacheResult) {
                            throw cacheResult;
                        }
                    }
                    finally {
                        guardrailsPipelineRunner.close();
                    }
                }
                catch (Exception e) {
                    // Preprocessing failure: logged and returned as an error response, not rethrown.
                    // NOTE(review): judging by its message, this handler presumably covers the whole
                    // guardrails/cache region (including the case-1 throw), not only the try it is
                    // syntactically attached to. CFR's nesting here is unreliable.
                    // As written, globalSpan is closed here AND again by the enclosing finally;
                    // the original presumably closed it once.
                    NonParallelLLMClient.logger.warn((Object)"Got error from LLM preprocessing while retrieving single embedding", (Throwable)e);
                    cacheResult = LLMClient.SimpleEmbeddingResponseOrError.fromError(e);
                    if (globalSpan != null) {
                        globalSpan.close();
                    }
                    return cacheResult;
                }
            }
            // Actual LLM call: a one-element batch inside its own DKU_LLM_MESH_LLM_CALL span.
            NonParallelLLMClient.logger.info((Object)"Sending a single embedding query to LLM");
            resp = null;
            llmMeshTraceSpan = LLMClient.LLMMeshTraceSpan.start("DKU_LLM_MESH_LLM_CALL");
            LLMTracingUtils.addLLMIdentifiers(llmMeshTraceSpan, this.llmId, this.llmClient);
            LLMTracingUtils.setEmbeddingInput(llmMeshTraceSpan, query, settings);
            try {
                rawResponses = this.llmClient.embedBatch(Lists.newArrayList((Object[])new LLMClient.EmbeddingQuery[]{query}), settings);
                NonParallelLLMClient.logger.info((Object)"Got successful single embedding answer from LLM");
                // CFR could not resugar this call. It stands for the method reference
                // LLMClient.SimpleEmbeddingResponseOrError::fromSuccess (as in performActualCompletionCall).
                resp = rawResponses.stream().map((Function<LLMClient.SimpleEmbeddingResponse, LLMClient.SimpleEmbeddingResponseOrError>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;)Ljava/lang/Object;, fromSuccess(com.dataiku.dip.llm.online.LLMClient$SimpleEmbeddingResponse ), (Lcom/dataiku/dip/llm/online/LLMClient$SimpleEmbeddingResponse;)Lcom/dataiku/dip/llm/online/LLMClient$SimpleEmbeddingResponseOrError;)()).collect(Collectors.toList());
                llmMeshTraceSpan.close();
            }
            catch (Exception e) {
                try {
                    // LLM errors become a single error response instead of propagating.
                    NonParallelLLMClient.logger.warn((Object)"Got error from LLM while retrieving single embedding answer", (Throwable)e);
                    resp = Collections.nCopies(1, LLMClient.SimpleEmbeddingResponseOrError.fromError(e));
                    llmMeshTraceSpan.close();
                }
                catch (Throwable var11_20) {
                    // CFR rendering of the finally path taken when the error handling itself throws.
                    llmMeshTraceSpan.close();
                    globalSpan.addObservation(llmMeshTraceSpan);
                    if (resp != null && !resp.isEmpty()) {
                        r = resp.get(0);
                        LLMTracingUtils.addUsageMetadataAndSetEmbeddingOutput(llmMeshTraceSpan, r);
                        if (r.trace != null) {
                            llmMeshTraceSpan.addObservation(r.trace);
                        }
                    }
                    throw var11_20;
                }
                globalSpan.addObservation(llmMeshTraceSpan);
                if (resp != null && !resp.isEmpty()) {
                    r = resp.get(0);
                    LLMTracingUtils.addUsageMetadataAndSetEmbeddingOutput(llmMeshTraceSpan, r);
                    if (r.trace != null) {
                        llmMeshTraceSpan.addObservation(r.trace);
                    } else {
                        // CFR jump to a label it did not emit; presumably the post-call code below.
                        ** GOTO lbl103
                    }
                } else {
                    ** GOTO lbl103
                }
            }
            // Attach the (already closed) LLM call span to the global span, and record usage on it.
            globalSpan.addObservation(llmMeshTraceSpan);
            if (resp != null && !resp.isEmpty()) {
                r = resp.get(0);
                LLMTracingUtils.addUsageMetadataAndSetEmbeddingOutput(llmMeshTraceSpan, r);
                if (r.trace != null) {
                    llmMeshTraceSpan.addObservation(r.trace);
                }
            }
            // NOTE(review): throws IndexOutOfBoundsException if embedBatch returned an empty list.
            result = resp.get(0);
            // Only successful responses are cached, and only when caching is enabled.
            if (result.ok && useCache) {
                this.cacheService.put(this.authCtx, this.llmId, this.contextProjectKey, query, settings, result);
            }
            result.guardrailsAuditData = guardrailAuditData;
            LLMTracingUtils.setEmbeddingOutput(llmMeshTraceSpan, result);
            result.trace = globalSpan;
        }
        finally {
            if (globalSpan != null) {
                globalSpan.close();
            }
        }
        return result;
    }

    /**
     * Returns the compute-resource usage reported by the underlying client for
     * {@code usageType}, with its cache statistics overwritten by this instance's
     * {@code cacheHitQueries} / {@code cacheMissQueries} counters.
     * {@code totalQueries} is set to their sum.
     *
     * <p>NOTE(review): assumes {@code llmUsage} is non-null whenever the returned
     * usage object is non-null.
     *
     * @param usageType usage type forwarded to the underlying client
     * @return the underlying client's usage object (mutated in place), or {@code null}
     *         if the underlying client returned {@code null}
     */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType) {
        ComputeResourceUsage totalCRU = this.llmClient.getTotalCRU(usageType, this.enrichedRef);
        if (totalCRU != null) {
            totalCRU.llmUsage.cacheMissQueries = this.cacheMissQueries;
            totalCRU.llmUsage.cacheHitQueries = this.cacheHitQueries;
            totalCRU.llmUsage.totalQueries = this.cacheHitQueries + this.cacheMissQueries;
        }
        return totalCRU;
    }

    /** Returns the enriched reference captured from the underlying client at construction. */
    @Override
    public EnrichedLLMStructuredRef getEnrichedRef() {
        return this.enrichedRef;
    }

    /** Returns the underlying client's connection (may be {@code null}; callers above null-check it). */
    @Override
    public AbstractLLMConnection getConnection() {
        return this.llmClient.getConnection();
    }

    /**
     * Closes the underlying {@link LLMClient}.
     *
     * @throws Exception propagated from the underlying client's {@code close()}
     */
    @Override
    public void close() throws Exception {
        this.llmClient.close();
    }
}

