/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.utils;

import com.dataiku.dip.llm.online.utils.QueryProcessor;
import com.dataiku.dip.utils.DKULogger;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;

/**
 * A {@link QueryProcessor} that buffers submitted queries and hands them to a batch function in
 * groups of at most {@code bufferSize}, running each batch on the supplied executor.
 *
 * <p>A batch is dispatched as soon as {@code bufferSize} queries are buffered; a partial batch is
 * only dispatched when {@link #flush()} is called. Each future returned by {@link #submit} is
 * completed with the result at the same position in the list returned by the batch function, or
 * exceptionally if the batch function throws, returns {@code null}, or returns a list whose size
 * differs from the number of queries in the batch.
 *
 * <p>{@link #submit} and {@link #flush} are {@code synchronized} on this instance; the batch
 * function itself is invoked on the executor's threads.
 *
 * @param <Q> query type
 * @param <R> result type
 */
public class BatchBufferProcessor<Q, R> implements QueryProcessor<Q, R> {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.online.utils.bbp");

    /** Queries waiting to be grouped into a batch, in submission order. */
    private final BlockingQueue<BufferedTask> queriesBuffer = new LinkedBlockingQueue<>();
    private final ExecutorService executorService;
    private final int bufferSize;
    private final Function<List<Q>, List<R>> processBatchFunction;

    /**
     * @param bufferSize number of queries that triggers dispatch of a batch; must be strictly positive
     * @param executorService executor on which batches are processed
     * @param processBatchFunction maps a list of queries to a list of results with the same size and order
     * @throws IllegalArgumentException if {@code bufferSize <= 0}
     * @throws NullPointerException if {@code executorService} or {@code processBatchFunction} is null
     */
    public BatchBufferProcessor(int bufferSize, @Nonnull ExecutorService executorService, @Nonnull Function<List<Q>, List<R>> processBatchFunction) {
        // With bufferSize == 0 nothing would ever be drained (futures would hang forever);
        // negative values would fail later when sizing the batch list.
        Preconditions.checkArgument(bufferSize > 0, "Buffer size must be strictly positive, got %s", bufferSize);
        this.bufferSize = bufferSize;
        this.executorService = Preconditions.checkNotNull(executorService, "Mandatory executor service");
        this.processBatchFunction = Preconditions.checkNotNull(processBatchFunction, "Mandatory batch processing function");
    }

    /**
     * Buffers {@code query} and dispatches a batch if the buffer has reached {@code bufferSize}.
     *
     * @param query the query to process
     * @return a future completed once the batch containing {@code query} has been processed
     * @throws RejectedExecutionException if a batch had to be dispatched and the executor rejected
     *     it; the futures of that batch are completed exceptionally before the exception propagates
     */
    @Override
    public synchronized CompletableFuture<R> submit(Q query) {
        CompletableFuture<R> future = new CompletableFuture<>();
        this.queriesBuffer.add(new BufferedTask(query, future));
        this.submitBatchIfComplete();
        return future;
    }

    /** Dispatches every buffered query and returns how many were dispatched. */
    private int submitRemainingQueriesBatch() {
        // submit() dispatches as soon as bufferSize queries are buffered, so normally at most one
        // partial batch remains; the loop still copes with more by splitting into full-size batches.
        assert (this.queriesBuffer.size() <= this.bufferSize);
        int dispatched = 0;
        List<BufferedTask> batch;
        while (!(batch = this.drainBatch()).isEmpty()) {
            this.submitBatch(batch);
            dispatched += batch.size();
        }
        return dispatched;
    }

    /**
     * Dispatches all currently buffered queries, even if fewer than {@code bufferSize} remain.
     *
     * @return number of queries dispatched by this call ({@code 0} if the buffer was empty)
     * @throws RejectedExecutionException if the executor rejected a batch; the futures of that batch
     *     are completed exceptionally before the exception propagates
     */
    public synchronized int flush() {
        return this.submitRemainingQueriesBatch();
    }

    /** Dispatches one batch if at least {@code bufferSize} queries are buffered. */
    private void submitBatchIfComplete() {
        if (this.queriesBuffer.size() >= this.bufferSize) {
            this.submitBatch(this.drainBatch());
        }
    }

    /** Removes and returns up to {@code bufferSize} buffered tasks, in submission order. */
    private List<BufferedTask> drainBatch() {
        // Cap the pre-size to what is actually buffered so a very large bufferSize does not
        // allocate a huge backing array.
        List<BufferedTask> batch = new ArrayList<>(Math.min(this.bufferSize, this.queriesBuffer.size()));
        this.queriesBuffer.drainTo(batch, this.bufferSize);
        return batch;
    }

    /** Hands a non-empty batch to the executor; empty batches are ignored. */
    private void submitBatch(List<BufferedTask> bufferedTasks) {
        if (bufferedTasks.isEmpty()) {
            return;
        }
        try {
            this.executorService.submit(() -> this.processBatch(bufferedTasks));
        } catch (RejectedExecutionException e) {
            // The tasks were already removed from the buffer: fail their futures so that callers
            // (including earlier submitters in this batch) are not left waiting forever.
            this.completeFuturesWithError(bufferedTasks, e);
            throw e;
        }
    }

    /** Runs the batch function (on an executor thread) and completes every future of the batch. */
    private void processBatch(List<BufferedTask> bufferedTasks) {
        try {
            List<Q> queries = bufferedTasks.stream().map(task -> task.query).collect(Collectors.toList());
            List<R> results = this.processBatchFunction.apply(queries);
            // Validate before completing anything, so a bad result list fails the whole batch
            // instead of completing some futures and leaving the rest to an IndexOutOfBounds.
            if (results == null || results.size() != bufferedTasks.size()) {
                throw new IllegalStateException("Batch function returned "
                        + (results == null ? "null" : results.size() + " results")
                        + " for " + bufferedTasks.size() + " queries");
            }
            this.completeFuturesWithResults(bufferedTasks, results);
        } catch (Throwable e) {
            this.completeFuturesWithError(bufferedTasks, e);
        }
    }

    /** Completes the i-th task's future with the i-th result; both lists must have the same size. */
    private void completeFuturesWithResults(List<BufferedTask> bufferedTasks, List<R> results) {
        logger.info("Batch of " + bufferedTasks.size() + " completed");
        for (int i = 0; i < bufferedTasks.size(); ++i) {
            bufferedTasks.get(i).future.complete(results.get(i));
        }
    }

    /** Completes every task's future exceptionally with {@code e}. */
    private void completeFuturesWithError(List<BufferedTask> bufferedTasks, Throwable e) {
        logger.warn("Batch of " + bufferedTasks.size() + " completed in error", e);
        for (BufferedTask task : bufferedTasks) {
            task.future.completeExceptionally(e);
        }
    }

    /** A buffered query paired with the future handed back to its submitter. */
    class BufferedTask {
        private final Q query;
        private final CompletableFuture<R> future;

        BufferedTask(Q query, CompletableFuture<R> future) {
            this.query = query;
            this.future = future;
        }
    }
}

