/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.futures.FutureProgress;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.cache.ILLMCacheService;
import com.dataiku.dip.llm.governance.GuardrailRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.online.CompletionRecipeLLMMeshClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.llm.online.LLMMeshClient;
import com.dataiku.dip.llm.online.LLMTracingUtils;
import com.dataiku.dip.llm.online.utils.BatchProcessor;
import com.dataiku.dip.llm.online.utils.QueryProcessor;
import com.dataiku.dip.llm.utils.AgentTrajectoryService;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.util.JsonUtils;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonArray;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.springframework.beans.factory.annotation.Autowired;

public class ParallelLLMClient
implements LLMMeshClient,
CompletionRecipeLLMMeshClient {
    // LLM response cache; queried in the batchers' pre-processing steps and filled in their post-processing steps.
    // Presumably injected by the SpringUtils autowire call made at the start of the constructor.
    @Autowired
    private ILLMCacheService cacheService;
    // Used at the end of completion post-processing to enrich the response with trajectory data if needed.
    @Autowired
    private AgentTrajectoryService agentTrajectoryService;
    // Underlying client obtained from LLMClientFactory; performs the actual batched LLM calls.
    private final LLMClient llmClient;
    // Enriched reference of the LLM, as reported by llmClient.getEnrichedRef().
    private final EnrichedLLMStructuredRef llmRef;
    // Guardrails settings received by the constructor (also handed to guardrailsPipelineRunner).
    private final GuardrailsPipelineSettings guardrailsPipelineSettings;
    // Auth context passed to the cache, the guardrails runner and the trajectory service.
    private final AuthCtx authCtx;
    // Project key passed alongside authCtx to the same collaborators.
    private final String contextProjectKey;
    // Runs the guardrails pipeline on queries and on responses.
    private final GuardrailsPipelineRunner guardrailsPipelineRunner;
    // Incremented on every cache hit in pre-processing; reported by getTotalCRU().
    public AtomicInteger cacheHitQueries = new AtomicInteger();
    // Incremented whenever pre-processing does not end in a usable cache hit (including when caching is disabled).
    public AtomicInteger cacheMissQueries = new AtomicInteger();
    // Fixed-size pool (size = parallelism) given to every BatchProcessor as its first constructor argument.
    private final ExecutorService llmPool;
    // Fixed-size pool sized from the "dku.llm.guardrails.parallelism" parameter; passed as executor to the
    // completion and embedding BatchProcessors.
    private final ExecutorService guardrailsPool;
    // min(client max parallelism, requested max parallelism), but never less than 1.
    private final int parallelism;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.online.parallel");

    /**
     * Creates a client that dispatches LLM queries over two fixed-size thread pools: one for
     * the LLM calls and one for guardrails work.
     *
     * <p>The LLM pool size is the smaller of the client's own maximum parallelism and
     * {@code maxParallelism}, with a floor of 1. The guardrails pool size comes from the
     * {@code dku.llm.guardrails.parallelism} parameter (default 5).
     *
     * @param authCtx                    auth context used to resolve the LLM and run guardrails
     * @param llmStructuredRef           reference of the LLM to use
     * @param guardrailsPipelineSettings settings for the guardrails pipeline runner
     * @param contextProjectKey          project key the queries are issued from
     * @param usedDataset                not read by this constructor
     * @param maxParallelism             upper bound requested by the caller for the LLM pool size
     * @throws Exception if the LLM client or the guardrails runner cannot be created
     */
    public ParallelLLMClient(AuthCtx authCtx, LLMStructuredRef llmStructuredRef, GuardrailsPipelineSettings guardrailsPipelineSettings, String contextProjectKey, AnyLoc usedDataset, int maxParallelism) throws Exception {
        SpringUtils.getInstance().autowire((Object)this);

        // Plain parameter captures.
        this.authCtx = authCtx;
        this.contextProjectKey = contextProjectKey;
        this.guardrailsPipelineSettings = guardrailsPipelineSettings;

        // Resolve the underlying client and the enriched reference it exposes.
        this.llmClient = LLMClientFactory.get(authCtx, contextProjectKey, llmStructuredRef);
        this.llmRef = this.llmClient.getEnrichedRef();

        // LLM pool: bounded by both the client and the caller, never below one thread.
        int clientLimit = this.llmClient.getMaxParallelism();
        int boundedParallelism = Math.min(clientLimit, maxParallelism);
        this.parallelism = Math.max(1, boundedParallelism);
        this.llmPool = Executors.newFixedThreadPool(this.parallelism);

        this.guardrailsPipelineRunner = new GuardrailsPipelineRunner(authCtx, contextProjectKey, guardrailsPipelineSettings);

        // Guardrails pool: sized from configuration.
        int guardrailsThreads = ApplicationConfigurator.getParams().getIntParam("dku.llm.guardrails.parallelism", Integer.valueOf(5));
        logger.debug((Object)("Using LLM parallelism factor: " + this.parallelism + ", guardrails parallelism factor: " + guardrailsThreads));
        this.guardrailsPool = Executors.newFixedThreadPool(guardrailsThreads);
    }

    /**
     * Runs all completion queries through the completion batcher and waits for every response.
     *
     * <p>Each query is wrapped with its own trace holder; once all responses are collected,
     * the response at index {@code i} receives the global span of the query at index {@code i}.
     *
     * @param queries  completion queries to execute
     * @param settings completion settings shared by all queries
     * @return one response-or-error per submitted query, in submission order
     * @throws Exception if collecting the responses fails
     */
    @Override
    public List<LLMClient.SimpleCompletionResponseOrError> completeQueries(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws Exception {
        logger.info((Object)("Submitting " + queries.size() + " prompt to LLM " + this.llmRef.id));

        List<SingleCompletionQueryWithTrace> tracedQueries = queries.stream()
                .map(single -> new SingleCompletionQueryWithTrace(single))
                .collect(Collectors.toList());

        BatchProcessor<SingleCompletionQueryWithTrace, LLMClient.SimpleCompletionResponseOrError> batcher = this.completeQueriesAsyncBatcher(settings);
        List<CompletableFuture<LLMClient.SimpleCompletionResponseOrError>> pending = batcher.submitAllAndFlush(tracedQueries);
        logger.info((Object)"All queries submitted");

        List<LLMClient.SimpleCompletionResponseOrError> answers = DKUCompletableFuture.collectResponses(pending);

        // Attach each query's global trace span to the response at the same position.
        int position = 0;
        for (SingleCompletionQueryWithTrace traced : tracedQueries) {
            answers.get(position).trace = traced.globalSpan;
            ++position;
        }

        logger.info((Object)("Done receiving the completion responses. (" + answers.size() + " texts)"));
        return answers;
    }

    /**
     * Creates a streamer on top of a fresh completion batcher.
     *
     * <p>Both size arguments of the streamer are set to {@code batchSize * (parallelism + 1)}.
     * Signalling completion first flushes the batcher, then delegates to the streamer's own
     * {@code done()}.
     *
     * @param settings completion settings shared by all streamed queries
     * @return a streamer backed by a new completion batch processor
     */
    @Override
    public CompletionRecipeLLMMeshClient.CompletionsStreamer completeQueriesAsyncStream(LLMClient.CompletionSettings settings) {
        final BatchProcessor<SingleCompletionQueryWithTrace, LLMClient.SimpleCompletionResponseOrError> batcher = this.completeQueriesAsyncBatcher(settings);
        final int capacity = (this.parallelism + 1) * batcher.batchSize;
        return new CompletionRecipeLLMMeshClient.CompletionsStreamer(batcher, capacity, capacity){

            @Override
            public CompletableFuture<Void> done() {
                // Push out any partially filled batch before completing the stream.
                return batcher.flush().thenCompose(flushed -> super.done());
            }
        };
    }

    /**
     * Builds the {@link BatchProcessor} that runs completion queries through three stages:
     * per-query pre-processing (tracing, query guardrails, cache lookup), batched LLM calls,
     * and per-query post-processing (cache store, response guardrails, tracing, progress).
     *
     * <p>The processor captures the {@code FutureProgressState} returned by
     * {@code FutureProgress.getState()} at call time, and whether completion caching is enabled
     * on the connection parameters (false when there is no connection or no parameters).
     *
     * <p>NOTE(review): this body is raw CFR decompiler output. The {@code ** break;} statements,
     * the {@code lblNN:} labels, the undeclared locals ({@code guardrailSpan},
     * {@code guardrailResponse}, {@code cacheResult}, {@code cached}, {@code guardrailResp}) and
     * the {@code 5.$SwitchMap$...} expressions are not valid Java; the switch structure must be
     * restored from the original sources before this can compile.
     *
     * @param settings completion settings applied to every query handled by the processor
     * @return a new batch processor bound to {@code llmPool}, using {@code guardrailsPool} for
     *         both executor arguments (presumably the pre- and post-processing executors — TODO confirm)
     */
    public BatchProcessor<SingleCompletionQueryWithTrace, LLMClient.SimpleCompletionResponseOrError> completeQueriesAsyncBatcher(final LLMClient.CompletionSettings settings) {
        final FutureProgressState fps = FutureProgress.getState();
        // Caching is on only if the client has a connection with params and cachingEnabled is set.
        final boolean useCache = this.llmClient.getConnection() != null && this.llmClient.getConnection().getLLMConnectionParams() != null && this.llmClient.getConnection().getLLMConnectionParams().cachingEnabled;
        int batchSize = this.llmClient.getBatchSize(AbstractLLMConnection.QueryType.completion, this.llmRef);
        logger.info((Object)("Using batchSize=" + batchSize));
        return new BatchProcessor<SingleCompletionQueryWithTrace, LLMClient.SimpleCompletionResponseOrError>(this.llmPool, batchSize, (Executor)this.guardrailsPool, (Executor)this.guardrailsPool){

            /*
             * Unable to fully structure code
             *
             * Starts the query's global span, runs the query guardrails inside a child span,
             * then looks the query up in the cache when caching is enabled.
             * Returns a present Optional (bypassing the LLM call) on a successful cache hit or
             * on any exception; returns an empty Optional otherwise.
             */
            @Override
            protected Optional<LLMClient.SimpleCompletionResponseOrError> preProcessQueryOrBypassProcessing(SingleCompletionQueryWithTrace query) {
                query.globalSpan.name = "DKU_LLM_MESH_COMPLETION_QUERY";
                query.globalSpan.start();
                LLMTracingUtils.addIdentifiersAndSetCompletionInput(query.globalSpan, ParallelLLMClient.this.llmRef.id, ParallelLLMClient.this.llmClient, query.query, settings);
                try {
                    guardrailSpan = query.globalSpan.withChildSpan("DKU_LLM_MESH_QUERY_ENFORCEMENT");
                    try {
                        guardrailResponse = ParallelLLMClient.this.guardrailsPipelineRunner.processCompletionQuery(query.guardrailContext, query.query, settings, guardrailSpan);
                        // Switch on the guardrail action through a compiler-generated ordinal map;
                        // which enum constant each case number stands for is not visible here.
                        switch (5.$SwitchMap$com$dataiku$dip$llm$governance$GuardrailRunner$QueryGuardrailAction[guardrailResponse.action.ordinal()]) {
                            case 1: {
                                // The guardrail outcome is turned into an exception (handled by the outer catch).
                                throw guardrailResponse.toException();
                            }
                            case 2: {
                                // The query is updated from the guardrail response; no audit data recorded.
                                GuardrailsPipelineUtils.updateCompletionQueryFromGuardrailsResponse(query.query, guardrailResponse);
                                ** break;
lbl14:
                                // 1 sources

                                break;
                            }
                            case 3: {
                                // The query is updated and the guardrail audit data is kept on the query.
                                GuardrailsPipelineUtils.updateCompletionQueryFromGuardrailsResponse(query.query, guardrailResponse);
                                query.guardrailAuditData = guardrailResponse.auditData;
                                ** break;
lbl19:
                                // 1 sources

                                break;
                            }
                            case 4: {
                                throw new Exception("Unreachable");
                            }
                            ** default:
lbl23:
                            // 1 sources

                            break;
                        }
                    }
                    finally {
                        // Always close the enforcement span, whether or not the guardrails threw.
                        if (guardrailSpan != null) {
                            guardrailSpan.close();
                        }
                    }
                    cacheResult = null;
                    if (useCache) {
                        cacheResult = ParallelLLMClient.this.cacheService.get(ParallelLLMClient.this.authCtx, ParallelLLMClient.this.llmRef.id, ParallelLLMClient.this.contextProjectKey, query.query, settings);
                    }
                    cached = null;
                    // Only a hit holding a successful response is served from the cache.
                    if (cacheResult != null && cacheResult.cacheHit && ((LLMClient.SimpleCompletionResponseOrError)cacheResult.result).ok) {
                        ParallelLLMClient.logger.debug((Object)"Got cache hit");
                        query.globalSpan.withChildEvent("DKU_LLM_MESH_CACHE_HIT");
                        ParallelLLMClient.this.cacheHitQueries.incrementAndGet();
                        cached = (LLMClient.SimpleCompletionResponseOrError)cacheResult.result;
                        cached.fromCache = true;
                        cached.estimatedCost = 0.0;
                    } else {
                        // Also reached when caching is disabled (cacheResult stays null).
                        ParallelLLMClient.this.cacheMissQueries.incrementAndGet();
                    }
                    return Optional.ofNullable(cached);
                }
                catch (Exception e) {
                    // Guardrail or cache failures become an error response for this query only.
                    ParallelLLMClient.logger.warn((Object)"Got error from LLM preprocessing", (Throwable)e);
                    return Optional.of(LLMClient.SimpleCompletionResponseOrError.fromError(e));
                }
            }

            /*
             * WARNING - Removed try catching itself - possible behaviour change.
             *
             * Sends one batch of queries to the LLM client and returns one response-or-error per
             * query. A copy of the "DKU_LLM_MESH_LLM_CALL" span is attached to every query's
             * global span. The loop duplicated in the catch (Throwable) block is presumably the
             * decompiled form of a finally block in the original sources — TODO confirm.
             */
            @Override
            protected List<LLMClient.SimpleCompletionResponseOrError> processBatch(List<SingleCompletionQueryWithTrace> queries) {
                logger.info((Object)("Sending a batch of " + queries.size() + " completion queries to LLM " + ParallelLLMClient.this.llmRef.id));
                List<LLMClient.SimpleCompletionResponseOrError> resp = null;
                try {
                    List<LLMClient.SingleCompletionQuery> baseQueries = queries.stream().map(q -> q.query).collect(Collectors.toList());
                    LLMClient.LLMMeshTraceSpan span = LLMClient.LLMMeshTraceSpan.start("DKU_LLM_MESH_LLM_CALL");
                    LLMTracingUtils.addLLMIdentifiers(span, ParallelLLMClient.this.llmRef.id, ParallelLLMClient.this.llmClient);
                    try {
                        List<LLMClient.SimpleCompletionResponse> rawResponses = ParallelLLMClient.this.llmClient.completeBatch(baseQueries, settings);
                        logger.info((Object)("Got successful completion batch answer from LLM " + ParallelLLMClient.this.llmRef.id));
                        resp = rawResponses.stream().map(LLMClient.SimpleCompletionResponseOrError::fromSuccess).collect(Collectors.toList());
                        span.close();
                    }
                    catch (Throwable throwable) {
                        // Failure path: close the call span, attach a copy to every query, then rethrow.
                        span.close();
                        for (int i = 0; i < queries.size(); ++i) {
                            SingleCompletionQueryWithTrace qwt = queries.get(i);
                            LLMClient.LLMMeshTraceSpan qwtSpan = (LLMClient.LLMMeshTraceSpan)JSON.deepCopy((Object)span);
                            qwt.globalSpan.addObservation(qwtSpan);
                            if (resp == null || resp.size() <= i) continue;
                            LLMClient.SimpleCompletionResponseOrError r = resp.get(i);
                            if (queries.size() == 1) {
                                assert (i == 0);
                                LLMTracingUtils.setCompletionInput(qwtSpan, qwt.query, settings);
                                LLMTracingUtils.addUsageMetadataAndSetCompletionOutput(qwtSpan, r);
                            }
                            if (r.trace == null) continue;
                            qwtSpan.addObservation(r.trace);
                        }
                        // Exceptions are converted by the outer catch; other Throwables escape this method.
                        throw throwable;
                    }
                    // Success path: each query gets its own deep copy of the call span.
                    for (int i = 0; i < queries.size(); ++i) {
                        SingleCompletionQueryWithTrace qwt = queries.get(i);
                        LLMClient.LLMMeshTraceSpan qwtSpan = (LLMClient.LLMMeshTraceSpan)JSON.deepCopy((Object)span);
                        qwt.globalSpan.addObservation(qwtSpan);
                        // Tolerate fewer responses than queries.
                        if (resp == null || resp.size() <= i) continue;
                        LLMClient.SimpleCompletionResponseOrError r = resp.get(i);
                        // Input/output details are only recorded on the call span for single-query batches.
                        if (queries.size() == 1) {
                            assert (i == 0);
                            LLMTracingUtils.setCompletionInput(qwtSpan, qwt.query, settings);
                            LLMTracingUtils.addUsageMetadataAndSetCompletionOutput(qwtSpan, r);
                        }
                        if (r.trace == null) continue;
                        qwtSpan.addObservation(r.trace);
                    }
                }
                catch (Exception e) {
                    logger.warn((Object)"Got error from LLM while retrieving batch completion", (Throwable)e);
                    // NOTE(review): nCopies yields the SAME error instance for every query of the batch,
                    // while postProcess and completeQueries later write per-query fields
                    // (guardrailsAuditData, trace) on it — the last writer wins for the whole batch.
                    resp = Collections.nCopies(queries.size(), LLMClient.SimpleCompletionResponseOrError.fromError(e));
                }
                return resp;
            }

            /*
             * WARNING - Removed try catching itself - possible behaviour change.
             * Unable to fully structure code
             *
             * For successful results: stores the result in the cache (only when caching is
             * enabled and the LLM call was not bypassed), then runs the response guardrails in a
             * child span. For every result: copies the audit data onto the result, records the
             * output on the global span and closes it. Progress is incremented in all cases.
             */
            @Override
            protected LLMClient.SimpleCompletionResponseOrError postProcess(SingleCompletionQueryWithTrace query, LLMClient.SimpleCompletionResponseOrError result, boolean processingWasBypassed) {
                try {
                    if (result.ok) {
                        // The cache is written before the response guardrails run on the result.
                        if (useCache && !processingWasBypassed) {
                            ParallelLLMClient.this.cacheService.put(ParallelLLMClient.this.authCtx, ParallelLLMClient.this.llmRef.id, ParallelLLMClient.this.contextProjectKey, query.query, settings, result);
                        }
                        guardrailSpan = query.globalSpan.withChildSpan("DKU_LLM_MESH_RESPONSE_ENFORCEMENT");
                        try {
                            guardrailResp = ParallelLLMClient.this.guardrailsPipelineRunner.processCompletionResponse(query.guardrailContext, query.query, result, settings, guardrailSpan);
                            // Compiler-generated ordinal map again; the enum constants behind the case
                            // numbers are not visible here.
                            switch (5.$SwitchMap$com$dataiku$dip$llm$governance$GuardrailRunner$ResponseGuardrailAction[guardrailResp.action.ordinal()]) {
                                case 1: {
                                    throw guardrailResp.toException();
                                }
                                case 2: {
                                    // Response updated from the guardrail response; no audit data appended.
                                    GuardrailsPipelineUtils.updateCompletionResponseFromGuardrailsResponse(result, guardrailResp);
                                    ** break;
lbl14:
                                    // 1 sources

                                    break;
                                }
                                case 3: {
                                    // Response updated and audit data appended to what the query stage recorded.
                                    GuardrailsPipelineUtils.updateCompletionResponseFromGuardrailsResponse(result, guardrailResp);
                                    query.guardrailAuditData = JsonUtils.appendIntoJsonArray(query.guardrailAuditData, guardrailResp.auditData);
                                    ** break;
lbl19:
                                    // 1 sources

                                    break;
                                }
                                case 4: {
                                    // Same as case 3, and the span is flagged with responseFromGuardrail=true.
                                    GuardrailsPipelineUtils.updateCompletionResponseFromGuardrailsResponse(result, guardrailResp);
                                    query.guardrailAuditData = JsonUtils.appendIntoJsonArray(query.guardrailAuditData, guardrailResp.auditData);
                                    guardrailSpan.attributes.addProperty("responseFromGuardrail", "true");
                                    ** break;
lbl25:
                                    // 1 sources

                                    break;
                                }
                                case 5: {
                                    // NOTE(review): an Error is not caught by the catch (Exception) below.
                                    throw new Error("unreachable");
                                }
                                ** default:
lbl29:
                                // 1 sources

                                break;
                            }
                        }
                        finally {
                            if (guardrailSpan != null) {
                                guardrailSpan.close();
                            }
                        }
                    }
                    result.guardrailsAuditData = query.guardrailAuditData;
                    LLMTracingUtils.setCompletionOutput(query.globalSpan, result);
                    query.globalSpan.close();
                }
                catch (Exception e) {
                    // NOTE(review): when an exception lands here, the statements above that set the
                    // output on and close query.globalSpan have been skipped.
                    ParallelLLMClient.logger.warn((Object)"Got error from LLM postprocessing", (Throwable)e);
                    result = LLMClient.SimpleCompletionResponseOrError.fromError(e);
                }
                finally {
                    // One progress tick per query, success or failure.
                    try {
                        fps.increment(1.0);
                    }
                    catch (InterruptedException e) {
                        // Restore the interrupt flag instead of swallowing the interruption.
                        Thread.currentThread().interrupt();
                    }
                }
                return ParallelLLMClient.this.agentTrajectoryService.enrichResponseWithTrajectoryIfNeeded(ParallelLLMClient.this.authCtx, ParallelLLMClient.this.contextProjectKey, result, query.globalSpan, ParallelLLMClient.this.llmRef);
            }
        };
    }

    /**
     * Runs all embedding queries through the embedding batcher and waits for every response.
     *
     * <p>Each query is wrapped with its own trace holder; once all responses are collected,
     * the response at index {@code i} receives the global span of the query at index {@code i}.
     *
     * @param queries  embedding queries to execute
     * @param settings embedding settings shared by all queries
     * @return one response-or-error per submitted query, in submission order
     * @throws Exception if collecting the responses fails
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponseOrError> embedQueries(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws Exception {
        logger.info((Object)("Submitting " + queries.size() + " embeddings to LLM " + this.llmRef.id));

        List<SingleEmbeddingQueryWithTrace> tracedQueries = queries.stream()
                .map(single -> new SingleEmbeddingQueryWithTrace(single))
                .collect(Collectors.toList());

        BatchProcessor<SingleEmbeddingQueryWithTrace, LLMClient.SimpleEmbeddingResponseOrError> batcher = this.embedQueriesAsyncBatcher(settings);
        List<CompletableFuture<LLMClient.SimpleEmbeddingResponseOrError>> pending = batcher.submitAllAndFlush(tracedQueries);
        logger.info((Object)"All embedding queries submitted, waiting for responses");

        List<LLMClient.SimpleEmbeddingResponseOrError> answers = DKUCompletableFuture.collectResponses(pending);

        // Attach each query's global trace span to the response at the same position.
        int position = 0;
        for (SingleEmbeddingQueryWithTrace traced : tracedQueries) {
            answers.get(position).trace = traced.globalSpan;
            ++position;
        }

        logger.info((Object)("Done receiving the full embedding response. (" + answers.size() + " embeddings received)"));
        return answers;
    }

    /**
     * Returns the underlying client's resource usage for this LLM, completed with the cache
     * counters tracked by this instance.
     *
     * <p>The total query count is set to cache hits plus cache misses.
     *
     * @param usageType kind of LLM usage to report
     * @return the usage reported by the underlying client with cache statistics filled in,
     *         or {@code null} when the client reports none
     */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType) {
        ComputeResourceUsage usage = this.llmClient.getTotalCRU(usageType, this.llmRef);
        if (usage == null) {
            return usage;
        }
        usage.llmUsage.cacheHitQueries = this.cacheHitQueries.get();
        usage.llmUsage.cacheMissQueries = this.cacheMissQueries.get();
        usage.llmUsage.totalQueries = usage.llmUsage.cacheHitQueries + usage.llmUsage.cacheMissQueries;
        return usage;
    }

    /**
     * Builds the {@link BatchProcessor} that runs embedding queries through three stages:
     * per-query pre-processing (tracing, query guardrails, cache lookup), batched LLM calls,
     * and per-query post-processing (cache store, tracing, progress).
     *
     * <p>The processor captures the {@code FutureProgressState} returned by
     * {@code FutureProgress.getState()} at call time, and whether embeddings caching is enabled
     * on the connection parameters (false when there is no connection or no parameters).
     * The batch size is looked up for the query type chosen by {@code getEmbeddingsQueryType()}.
     *
     * <p>NOTE(review): this body is raw CFR decompiler output. The {@code ** break;} statements,
     * the {@code lblNN:} labels, the undeclared locals ({@code guardrailSpan},
     * {@code guardrailResponse}, {@code cacheResult}, {@code cached}) and the
     * {@code 5.$SwitchMap$...} expression are not valid Java; the switch structure must be
     * restored from the original sources before this can compile.
     *
     * @param settings embedding settings applied to every query handled by the processor
     * @return a new batch processor bound to {@code llmPool}; its third argument is
     *         {@code guardrailsPool} and its fourth runs tasks on the calling thread
     *         ({@code Runnable::run}) — presumably the pre- and post-processing executors, TODO confirm
     */
    private BatchProcessor<SingleEmbeddingQueryWithTrace, LLMClient.SimpleEmbeddingResponseOrError> embedQueriesAsyncBatcher(final LLMClient.EmbeddingSettings settings) {
        final FutureProgressState fps = FutureProgress.getState();
        // Caching is on only if the client has a connection with params and embeddingsCachingEnabled is set.
        final boolean useCache = this.llmClient.getConnection() != null && this.llmClient.getConnection().getLLMConnectionParams() != null && this.llmClient.getConnection().getLLMConnectionParams().embeddingsCachingEnabled;
        int batchSize = this.llmClient.getBatchSize(this.getEmbeddingsQueryType(), this.llmRef);
        logger.debug((Object)("Using batchSize=" + batchSize + " to process embedding queries"));
        return new BatchProcessor<SingleEmbeddingQueryWithTrace, LLMClient.SimpleEmbeddingResponseOrError>(this.llmPool, batchSize, (Executor)this.guardrailsPool, Runnable::run){

            /*
             * Unable to fully structure code
             *
             * Starts the query's global span, runs the query guardrails inside a child span,
             * then looks the query up in the cache when caching is enabled.
             * Returns a present Optional (bypassing the LLM call) on a successful cache hit or
             * on any exception; returns an empty Optional otherwise.
             */
            @Override
            protected Optional<LLMClient.SimpleEmbeddingResponseOrError> preProcessQueryOrBypassProcessing(SingleEmbeddingQueryWithTrace query) {
                query.globalSpan.name = "DKU_LLM_MESH_EMBEDDING_QUERY";
                query.globalSpan.start();
                LLMTracingUtils.addLLMIdentifiers(query.globalSpan, ParallelLLMClient.this.llmRef.id, ParallelLLMClient.this.llmClient);
                LLMTracingUtils.setEmbeddingInput(query.globalSpan, query.query, settings);
                try {
                    guardrailSpan = query.globalSpan.withChildSpan("DKU_LLM_MESH_QUERY_ENFORCEMENT");
                    try {
                        guardrailResponse = ParallelLLMClient.this.guardrailsPipelineRunner.processEmbeddingQuery(query.guardrailContext, query.query, settings, guardrailSpan);
                        // Switch on the guardrail action through a compiler-generated ordinal map;
                        // which enum constant each case number stands for is not visible here.
                        switch (5.$SwitchMap$com$dataiku$dip$llm$governance$GuardrailRunner$QueryGuardrailAction[guardrailResponse.action.ordinal()]) {
                            case 1: {
                                // The guardrail outcome is turned into an exception (handled by the outer catch).
                                throw guardrailResponse.toException();
                            }
                            case 2: {
                                // The query is updated from the guardrail response; no audit data recorded.
                                GuardrailsPipelineUtils.updateEmbeddingQueryFromGuardrailsResponse(query.query, guardrailResponse);
                                ** break;
lbl15:
                                // 1 sources

                                break;
                            }
                            case 3: {
                                // The query is updated and the guardrail audit data is kept on the query.
                                GuardrailsPipelineUtils.updateEmbeddingQueryFromGuardrailsResponse(query.query, guardrailResponse);
                                query.guardrailAuditData = guardrailResponse.auditData;
                                break;
                            }
                            ** default:
lbl21:
                            // 1 sources

                            break;
                        }
                    }
                    finally {
                        // Always close the enforcement span, whether or not the guardrails threw.
                        if (guardrailSpan != null) {
                            guardrailSpan.close();
                        }
                    }
                    cacheResult = null;
                    if (useCache) {
                        cacheResult = ParallelLLMClient.this.cacheService.get(ParallelLLMClient.this.authCtx, ParallelLLMClient.this.llmRef.id, ParallelLLMClient.this.contextProjectKey, query.query, settings);
                    }
                    cached = null;
                    // Only a hit holding a successful response is served from the cache.
                    if (cacheResult != null && cacheResult.cacheHit && ((LLMClient.SimpleEmbeddingResponseOrError)cacheResult.result).ok) {
                        ParallelLLMClient.logger.debug((Object)"Got cache hit");
                        query.globalSpan.withChildEvent("DKU_LLM_MESH_CACHE_HIT");
                        ParallelLLMClient.this.cacheHitQueries.incrementAndGet();
                        cached = (LLMClient.SimpleEmbeddingResponseOrError)cacheResult.result;
                        cached.fromCache = true;
                        cached.estimatedCost = 0.0;
                    } else {
                        // Also reached when caching is disabled (cacheResult stays null).
                        ParallelLLMClient.this.cacheMissQueries.incrementAndGet();
                    }
                    return Optional.ofNullable(cached);
                }
                catch (Exception e) {
                    // Guardrail or cache failures become an error response for this query only.
                    ParallelLLMClient.logger.warn((Object)"Got error from LLM preprocessing while retrieving batch embedding", (Throwable)e);
                    return Optional.of(LLMClient.SimpleEmbeddingResponseOrError.fromError(e));
                }
            }

            /*
             * WARNING - Removed try catching itself - possible behaviour change.
             *
             * Sends one batch of queries to the LLM client and returns one response-or-error per
             * query. A copy of the "DKU_LLM_MESH_LLM_CALL" span is attached to every query's
             * global span. The loop duplicated in the catch (Throwable) block is presumably the
             * decompiled form of a finally block in the original sources — TODO confirm.
             */
            @Override
            protected List<LLMClient.SimpleEmbeddingResponseOrError> processBatch(List<SingleEmbeddingQueryWithTrace> queries) {
                logger.info((Object)("Sending a batch of " + queries.size() + " embedding queries to LLM " + ParallelLLMClient.this.llmRef.id));
                List<LLMClient.SimpleEmbeddingResponseOrError> resp = null;
                try {
                    List<LLMClient.EmbeddingQuery> baseQueries = queries.stream().map(q -> q.query).collect(Collectors.toList());
                    LLMClient.LLMMeshTraceSpan span = LLMClient.LLMMeshTraceSpan.start("DKU_LLM_MESH_LLM_CALL");
                    LLMTracingUtils.addLLMIdentifiers(span, ParallelLLMClient.this.llmRef.id, ParallelLLMClient.this.llmClient);
                    try {
                        List<LLMClient.SimpleEmbeddingResponse> rawResponses = ParallelLLMClient.this.llmClient.embedBatch(baseQueries, settings);
                        logger.info((Object)("Got successful embeddings batch answer from LLM (size=" + rawResponses.size() + ")"));
                        resp = rawResponses.stream().map(LLMClient.SimpleEmbeddingResponseOrError::fromSuccess).collect(Collectors.toList());
                        span.close();
                    }
                    catch (Throwable throwable) {
                        // Failure path: close the call span, attach a copy to every query, then rethrow.
                        span.close();
                        for (int i = 0; i < queries.size(); ++i) {
                            SingleEmbeddingQueryWithTrace qwt = queries.get(i);
                            LLMClient.LLMMeshTraceSpan qwtSpan = (LLMClient.LLMMeshTraceSpan)JSON.deepCopy((Object)span);
                            qwt.globalSpan.addObservation(qwtSpan);
                            if (resp == null || resp.size() <= i) continue;
                            LLMClient.SimpleEmbeddingResponseOrError r = resp.get(i);
                            if (queries.size() == 1) {
                                assert (i == 0);
                                LLMTracingUtils.setEmbeddingInput(qwtSpan, qwt.query, settings);
                                LLMTracingUtils.addUsageMetadataAndSetEmbeddingOutput(qwtSpan, r);
                            }
                            if (r.trace == null) continue;
                            qwtSpan.addObservation(r.trace);
                        }
                        // Exceptions are converted by the outer catch; other Throwables escape this method.
                        throw throwable;
                    }
                    // Success path: each query gets its own deep copy of the call span.
                    for (int i = 0; i < queries.size(); ++i) {
                        SingleEmbeddingQueryWithTrace qwt = queries.get(i);
                        LLMClient.LLMMeshTraceSpan qwtSpan = (LLMClient.LLMMeshTraceSpan)JSON.deepCopy((Object)span);
                        qwt.globalSpan.addObservation(qwtSpan);
                        // Tolerate fewer responses than queries.
                        if (resp == null || resp.size() <= i) continue;
                        LLMClient.SimpleEmbeddingResponseOrError r = resp.get(i);
                        // Input/output details are only recorded on the call span for single-query batches.
                        if (queries.size() == 1) {
                            assert (i == 0);
                            LLMTracingUtils.setEmbeddingInput(qwtSpan, qwt.query, settings);
                            LLMTracingUtils.addUsageMetadataAndSetEmbeddingOutput(qwtSpan, r);
                        }
                        if (r.trace == null) continue;
                        qwtSpan.addObservation(r.trace);
                    }
                }
                catch (Exception e) {
                    logger.warn((Object)"Got error from LLM while retrieving batch embedding", (Throwable)e);
                    // NOTE(review): nCopies yields the SAME error instance for every query of the batch,
                    // while postProcess and embedQueries later write per-query fields
                    // (guardrailsAuditData, trace) on it — the last writer wins for the whole batch.
                    resp = Collections.nCopies(queries.size(), LLMClient.SimpleEmbeddingResponseOrError.fromError(e));
                }
                return resp;
            }

            /**
             * Stores successful, non-bypassed results in the cache when caching is enabled, copies
             * the query's guardrail audit data onto the result, records the output on the global
             * span, closes it and increments the progress by one.
             *
             * NOTE(review): only InterruptedException is caught here; any runtime exception thrown
             * by the cache or the tracing calls propagates to the caller and skips the progress
             * increment, unlike the completion post-processing which uses a finally block.
             */
            @Override
            protected LLMClient.SimpleEmbeddingResponseOrError postProcess(SingleEmbeddingQueryWithTrace query, LLMClient.SimpleEmbeddingResponseOrError result, boolean processingWasBypassed) {
                try {
                    if (result.ok && useCache && !processingWasBypassed) {
                        ParallelLLMClient.this.cacheService.put(ParallelLLMClient.this.authCtx, ParallelLLMClient.this.llmRef.id, ParallelLLMClient.this.contextProjectKey, query.query, settings, result);
                    }
                    result.guardrailsAuditData = query.guardrailAuditData;
                    LLMTracingUtils.setEmbeddingOutput(query.globalSpan, result);
                    query.globalSpan.close();
                    fps.increment(1.0);
                }
                catch (InterruptedException e) {
                    // Restore the interrupt flag instead of swallowing the interruption.
                    Thread.currentThread().interrupt();
                }
                return result;
            }
        };
    }

    /**
     * Determines which embedding query type applies to the configured LLM.
     *
     * <p>Text embedding is chosen when no model can be resolved (no connection, or the
     * connection yields no model) or when the model supports text embedding extraction.
     * Image embedding is chosen when the model only supports image embedding extraction.
     *
     * @return the text or image embedding query type
     * @throws RuntimeException         if looking up the model fails
     * @throws IllegalArgumentException if the model supports neither kind of embedding
     */
    private AbstractLLMConnection.QueryType getEmbeddingsQueryType() {
        AbstractLLMConnection conn = this.llmClient.getConnection();
        LLMModelHandle.Model resolved = null;
        if (conn != null) {
            try {
                resolved = conn.getLLMModel(this.llmRef).getModel();
            }
            catch (Exception e) {
                throw new RuntimeException("Unexpected exception while getting model for llmId: " + String.valueOf(this.llmRef), e);
            }
        }

        boolean textCapable = resolved == null || resolved.canBeUsedForPurpose(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION);
        if (!textCapable) {
            // A model is known and it cannot embed text: fall back to images or fail.
            if (resolved.canBeUsedForPurpose(AbstractLLMConnection.LLMUsagePurpose.IMAGE_EMBEDDING_EXTRACTION)) {
                return AbstractLLMConnection.QueryType.imageEmbedding;
            }
            throw new IllegalArgumentException("This model doesn't support embeddings: " + resolved.getId());
        }
        return AbstractLLMConnection.QueryType.textEmbedding;
    }

    /**
     * Runs all reranking queries through the reranking batcher and waits for every response.
     *
     * @param queries  reranking queries to execute
     * @param settings reranking settings shared by all queries
     * @return one response-or-error per submitted query, in submission order
     * @throws Exception if collecting the responses fails
     */
    @Override
    public List<LLMClient.SingleRerankingResponseOrError> rerankQueries(List<LLMClient.RerankingQuery> queries, LLMClient.RerankingSettings settings) throws Exception {
        // NOTE(review): this message prints the client object, whereas the completion and
        // embedding entry points print this.llmRef.id — confirm which one is intended.
        logger.info((Object)("Submitting " + queries.size() + " rerankings to LLM " + String.valueOf(this.llmClient)));

        List<SingleRerankingQueryWithTrace> tracedQueries = queries.stream()
                .map(single -> new SingleRerankingQueryWithTrace(single))
                .collect(Collectors.toList());

        BatchProcessor<SingleRerankingQueryWithTrace, LLMClient.SingleRerankingResponseOrError> batcher = this.rerankQueriesAsyncBatcher(settings);
        List<CompletableFuture<LLMClient.SingleRerankingResponseOrError>> pending = batcher.submitAllAndFlush(tracedQueries);
        logger.info((Object)"All queries submitted");

        List<LLMClient.SingleRerankingResponseOrError> answers = DKUCompletableFuture.collectResponses(pending);
        logger.info((Object)("Done receiving the reranking responses. (" + answers.size() + " texts)"));
        return answers;
    }

    /**
     * Builds the batch processor that runs reranking queries against the underlying LLM client.
     *
     * <p>The returned processor handles each query in three steps:
     * <ol>
     *   <li>{@code preProcessQueryOrBypassProcessing}: starts the query's trace span and, when
     *       reranking caching is enabled on the connection, looks the query up in the cache.
     *       A successful cache hit (or a failure during this step) bypasses the LLM call.</li>
     *   <li>{@code processBatch}: sends a batch of queries to {@code llmClient.rerankBatch}.</li>
     *   <li>{@code postProcess}: stores fresh successful results in the cache, records the output
     *       on the trace span and closes the span.</li>
     * </ol>
     *
     * @param settings the reranking settings used for the cache lookups, the LLM call and tracing
     * @return a new batch processor bound to this client's LLM pool
     */
    private BatchProcessor<SingleRerankingQueryWithTrace, LLMClient.SingleRerankingResponseOrError> rerankQueriesAsyncBatcher(final LLMClient.RerankingSettings settings) {
        final FutureProgressState fps = FutureProgress.getState();
        // Caching is only used when the client has a connection whose params enable it for rerankings.
        final boolean useCache = this.llmClient.getConnection() != null && this.llmClient.getConnection().getLLMConnectionParams() != null && this.llmClient.getConnection().getLLMConnectionParams().rerankingsCachingEnabled;
        int batchSize = this.llmClient.getBatchSize(AbstractLLMConnection.QueryType.reranking, this.llmRef);
        logger.debug((Object)("Using batchSize=" + batchSize + " to process reranking queries"));
        // NOTE(review): the meaning of the two trailing null constructor arguments is not visible here.
        return new BatchProcessor<SingleRerankingQueryWithTrace, LLMClient.SingleRerankingResponseOrError>(this.llmPool, batchSize, null, null){

            /**
             * Starts the trace span of the query and tries to answer it from the cache.
             *
             * @return the cached response on a successful cache hit, an error response if this
             *         step throws, or an empty Optional when the query must go to the LLM
             */
            @Override
            protected Optional<LLMClient.SingleRerankingResponseOrError> preProcessQueryOrBypassProcessing(SingleRerankingQueryWithTrace query) {
                query.globalSpan.name = "DKU_LLM_MESH_RERANKING_QUERY";
                query.globalSpan.start();
                LLMTracingUtils.addLLMIdentifiers(query.globalSpan, ParallelLLMClient.this.llmRef.id, ParallelLLMClient.this.llmClient);
                LLMTracingUtils.setRerankingInput(query.globalSpan, query.query, settings);
                try {
                    ILLMCacheService.QueryCacheResult<LLMClient.SingleRerankingResponseOrError> cacheResult = null;
                    if (useCache) {
                        cacheResult = ParallelLLMClient.this.cacheService.get(ParallelLLMClient.this.authCtx, ParallelLLMClient.this.llmRef.id, ParallelLLMClient.this.contextProjectKey, query.query, settings);
                    }
                    LLMClient.SingleRerankingResponseOrError cached = null;
                    // Only a cached response that is itself "ok" counts as a hit.
                    if (cacheResult != null && cacheResult.cacheHit && ((LLMClient.SingleRerankingResponseOrError)cacheResult.result).ok) {
                        logger.debug((Object)"Got cache hit");
                        query.globalSpan.withChildEvent("DKU_LLM_MESH_CACHE_HIT");
                        ParallelLLMClient.this.cacheHitQueries.incrementAndGet();
                        cached = (LLMClient.SingleRerankingResponseOrError)cacheResult.result;
                        cached.fromCache = true;
                        cached.estimatedCost = 0.0;
                    } else {
                        // Also reached when caching is disabled (cacheResult stays null).
                        ParallelLLMClient.this.cacheMissQueries.incrementAndGet();
                    }
                    return Optional.ofNullable(cached);
                }
                catch (Exception e) {
                    logger.warn((Object)"Got error from LLM preprocessing while retrieving batch reranking", (Throwable)e);
                    return Optional.of(LLMClient.SingleRerankingResponseOrError.fromError(e));
                }
            }

            /**
             * Sends one batch of reranking queries to the LLM.
             *
             * <p>If the LLM call throws, every query of the batch receives the same error
             * response object. Progress is incremented once per batch, whatever the outcome.
             *
             * <p>NOTE: the decompiler (CFR) reported "Removed try catching itself - possible
             * behaviour change" for this method, so the original control flow may differ slightly.
             *
             * <p>NOTE(review): progress is incremented here once per batch and again once per
             * query in {@code postProcess} - confirm that this double accounting is intended.
             *
             * @param queries the queries of the batch
             * @return one response-or-error per query of the batch
             */
            @Override
            protected List<LLMClient.SingleRerankingResponseOrError> processBatch(List<SingleRerankingQueryWithTrace> queries) {
                logger.info((Object)("Sending a batch of " + queries.size() + " reranking queries to LLM"));
                List<LLMClient.SingleRerankingResponseOrError> resp = null;
                try {
                    List<LLMClient.RerankingQuery> baseQueries = queries.stream().map(q -> q.query).toList();
                    List<LLMClient.SingleRerankingResponse> rawResponses = ParallelLLMClient.this.llmClient.rerankBatch(baseQueries, settings);
                    // Fixed: this message previously said "completion batch" on the reranking path.
                    logger.info((Object)"Got successful reranking batch answer from LLM");
                    resp = rawResponses.stream().map(LLMClient.SingleRerankingResponseOrError::fromSuccess).toList();
                }
                catch (Exception e) {
                    logger.warn((Object)"Got error from LLM while retrieving batch reranking", (Throwable)e);
                    resp = Collections.nCopies(queries.size(), LLMClient.SingleRerankingResponseOrError.fromError(e));
                }
                finally {
                    try {
                        fps.increment(1.0);
                    }
                    catch (InterruptedException e) {
                        // Restore the interrupt flag; the batch result is still returned.
                        Thread.currentThread().interrupt();
                    }
                }
                return resp;
            }

            /**
             * Finalizes a query: caches a fresh successful result, records the output on the
             * trace span, closes the span and increments the progress.
             *
             * <p>The span is closed in a {@code finally} block so that it does not stay open when
             * the cache write or the trace update throws.
             *
             * @param query                 the query being finalized
             * @param result                the response (or error) obtained for the query
             * @param processingWasBypassed true when the result did not come from {@code processBatch}
             * @return {@code result}, unchanged
             */
            @Override
            protected LLMClient.SingleRerankingResponseOrError postProcess(SingleRerankingQueryWithTrace query, LLMClient.SingleRerankingResponseOrError result, boolean processingWasBypassed) {
                try {
                    try {
                        // Results that bypassed processing (cache hits, preprocessing errors) are not written back.
                        if (result.ok && useCache && !processingWasBypassed) {
                            ParallelLLMClient.this.cacheService.put(ParallelLLMClient.this.authCtx, ParallelLLMClient.this.llmRef.id, ParallelLLMClient.this.contextProjectKey, query.query, settings, result);
                        }
                        LLMTracingUtils.setRerankingOutput(query.globalSpan, result);
                    }
                    finally {
                        query.globalSpan.close();
                    }
                    fps.increment(1.0);
                }
                catch (InterruptedException e) {
                    // Restore the interrupt flag; the result is still returned to the caller.
                    Thread.currentThread().interrupt();
                }
                return result;
            }
        };
    }

    /**
     * Returns the enriched LLM reference held by this client.
     *
     * @return the value of the {@code llmRef} field
     */
    @Override
    public EnrichedLLMStructuredRef getEnrichedRef() {
        final EnrichedLLMStructuredRef enrichedRef = this.llmRef;
        return enrichedRef;
    }

    /**
     * Returns the connection of the wrapped LLM client.
     *
     * @return whatever {@code llmClient.getConnection()} returns
     */
    @Override
    public AbstractLLMConnection getConnection() {
        final AbstractLLMConnection connection = this.llmClient.getConnection();
        return connection;
    }

    /**
     * Releases the resources held by this client.
     *
     * <p>Steps run in this order: both pools are shut down, then the guardrails pipeline runner
     * is closed, then the wrapped LLM client is closed. The runner and the LLM client are closed
     * even when an earlier step throws; if several steps throw, the exception of the last failing
     * step is the one that propagates.
     *
     * @throws Exception if any of the steps fails
     */
    @Override
    public void close() throws Exception {
        try {
            try {
                this.llmPool.shutdown();
                this.guardrailsPool.shutdown();
            }
            finally {
                // Runs whether or not the pool shutdowns succeeded.
                this.guardrailsPipelineRunner.close();
            }
        }
        finally {
            // Always last, regardless of failures above.
            this.llmClient.close();
        }
    }

    /**
     * Holder that pairs a single completion query with its per-query state: a trace span,
     * a guardrail context and guardrail audit data.
     */
    public static class SingleCompletionQueryWithTrace {
        /** The wrapped completion query. */
        LLMClient.SingleCompletionQuery query;
        /** Trace span of the query, obtained from {@code LLMMeshTraceSpan.newNotStarted()}. */
        LLMClient.LLMMeshTraceSpan globalSpan;
        /** Guardrail context of the query, created empty at construction. */
        GuardrailRunner.GuardrailContext guardrailContext;
        /** Guardrail audit data; not set at construction. */
        JsonArray guardrailAuditData;

        /**
         * Wraps the given query with a new not-started span and a new guardrail context.
         *
         * @param query the completion query to wrap
         */
        public SingleCompletionQueryWithTrace(LLMClient.SingleCompletionQuery query) {
            this.globalSpan = LLMClient.LLMMeshTraceSpan.newNotStarted();
            this.guardrailContext = new GuardrailRunner.GuardrailContext();
            this.query = query;
        }
    }

    /**
     * Holder that pairs a single embedding query with its per-query state: a guardrail context,
     * a trace span and guardrail audit data.
     */
    public static class SingleEmbeddingQueryWithTrace {
        /** The wrapped embedding query. */
        LLMClient.EmbeddingQuery query;
        /** Guardrail context of the query, created empty at construction. */
        GuardrailRunner.GuardrailContext guardrailContext;
        /** Trace span of the query, obtained from {@code LLMMeshTraceSpan.newNotStarted()}. */
        LLMClient.LLMMeshTraceSpan globalSpan;
        /** Guardrail audit data; not set at construction. */
        JsonArray guardrailAuditData;

        /**
         * Wraps the given query with a new guardrail context and a new not-started span.
         *
         * @param query the embedding query to wrap
         */
        public SingleEmbeddingQueryWithTrace(LLMClient.EmbeddingQuery query) {
            this.guardrailContext = new GuardrailRunner.GuardrailContext();
            this.globalSpan = LLMClient.LLMMeshTraceSpan.newNotStarted();
            this.query = query;
        }
    }

    /**
     * Completions streamer backed by a {@link BatchProcessor}.
     *
     * <p>It keeps a reference to the batch processor so that {@link #done()} can flush it
     * before delegating to the parent implementation.
     */
    public static class BatchProcessorCompletionsStreamer
    extends CompletionRecipeLLMMeshClient.CompletionsStreamer {
        /** The batch processor that is flushed when the streamer is done. */
        private final BatchProcessor<SingleCompletionQueryWithTrace, LLMClient.SimpleCompletionResponseOrError> batchProcessor;

        /**
         * Creates the streamer on top of the given batch processor and logs the queue sizes.
         *
         * @param batchProcessor    the processor that handles the queries
         * @param incomingQueueSize forwarded to the parent constructor
         * @param outgoingQueueSize forwarded to the parent constructor
         */
        public BatchProcessorCompletionsStreamer(BatchProcessor<SingleCompletionQueryWithTrace, LLMClient.SimpleCompletionResponseOrError> batchProcessor, int incomingQueueSize, int outgoingQueueSize) {
            super(
                    (QueryProcessor<SingleCompletionQueryWithTrace, LLMClient.SimpleCompletionResponseOrError>)batchProcessor,
                    incomingQueueSize,
                    outgoingQueueSize);
            logger.infoV(
                    "Using incomingQueueSize=%d, outgoingQueueSize=%d",
                    new Object[]{incomingQueueSize, outgoingQueueSize});
            this.batchProcessor = batchProcessor;
        }

        /**
         * Flushes the batch processor, then chains to the parent {@code done()}.
         *
         * @return a future composed from the flush followed by the parent's {@code done()}
         */
        @Override
        public CompletableFuture<Void> done() {
            // The value produced by flush() is not used; only its completion matters.
            return this.batchProcessor
                    .flush()
                    .thenCompose(ignored -> super.done());
        }
    }

    /**
     * Minimal holder around a single reranking query, without any tracing state.
     *
     * <p>NOTE(review): no usage of this class is visible in this part of the file; the
     * reranking path shown here uses {@link SingleRerankingQueryWithTrace} - confirm whether
     * this class is still needed.
     */
    public static class SingleRerankingQuery {
        /** The wrapped reranking query. */
        LLMClient.RerankingQuery query;

        /**
         * Wraps the given reranking query.
         *
         * @param query the reranking query to wrap
         */
        public SingleRerankingQuery(LLMClient.RerankingQuery query) {
            this.query = query;
        }
    }

    /**
     * Holder that pairs a single reranking query with the trace span that follows it through
     * the reranking batch processor.
     */
    public static class SingleRerankingQueryWithTrace {
        /** The wrapped reranking query. */
        LLMClient.RerankingQuery query;
        /** Trace span of the query, obtained from {@code LLMMeshTraceSpan.newNotStarted()}. */
        LLMClient.LLMMeshTraceSpan globalSpan;

        /**
         * Wraps the given query with a new not-started trace span.
         *
         * @param query the reranking query to wrap
         */
        public SingleRerankingQueryWithTrace(LLMClient.RerankingQuery query) {
            this.globalSpan = LLMClient.LLMMeshTraceSpan.newNotStarted();
            this.query = query;
        }
    }
}

