/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.LLMRateLimitingRunnerService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import dev.failsafe.Failsafe;
import dev.failsafe.FailsafeException;
import dev.failsafe.Policy;
import dev.failsafe.RetryPolicy;
import dev.failsafe.RetryPolicyBuilder;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.concurrent.Callable;
import java.util.function.Predicate;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Executes LLM queries, optionally routing them through the
 * {@link LLMRateLimitingRunnerService} under a Failsafe retry policy.
 *
 * <p>Whether rate limiting is active is read once, at construction time, from
 * {@link LLMRateLimitingRunnerService#isFeatureEnabled()}. When it is inactive,
 * {@link #run(Callable)} simply invokes the callable and this class performs no retry
 * of its own.
 */
public class LLMQueryRunner {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.online.utils");

    /** Injected by {@code SpringUtils.autowire} from the constructor, not by a Spring-managed lifecycle. */
    @Autowired
    private LLMRateLimitingRunnerService rateLimitingService;
    private final LLMRateLimitingRunnerService.LLMRateLimiterContext context;
    private final AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings;
    private final boolean rateLimitingEnabled;
    private final Predicate<Throwable> retryCondition;

    /**
     * Convenience constructor that takes the model reference and model from a handle.
     *
     * @param providerId      identifier of the LLM provider, forwarded to the rate limiter context
     * @param modelHandle     handle supplying the model reference and the model
     * @param networkSettings network settings providing the retry count and backoff parameters
     * @param retryCondition  predicate deciding whether a failure should be retried
     */
    public LLMQueryRunner(String providerId, LLMModelHandle modelHandle, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, Predicate<Throwable> retryCondition) {
        this(providerId, modelHandle.getRef(), (LLMModelHandle.Model)modelHandle.getModel(), networkSettings, retryCondition);
    }

    /**
     * Creates a runner and autowires its Spring dependencies.
     *
     * @param providerId      identifier of the LLM provider, forwarded to the rate limiter context
     * @param modelRef        structured reference of the model, forwarded to the rate limiter context
     * @param model           the model, forwarded to the rate limiter context
     * @param networkSettings network settings providing the retry count and backoff parameters
     * @param retryCondition  predicate deciding whether a failure should be retried
     */
    public LLMQueryRunner(String providerId, LLMStructuredRef modelRef, LLMModelHandle.Model model, AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, Predicate<Throwable> retryCondition) {
        SpringUtils.getInstance().autowire(this);
        this.context = new LLMRateLimitingRunnerService.LLMRateLimiterContext(providerId, modelRef, model);
        this.networkSettings = networkSettings;
        this.rateLimitingEnabled = LLMRateLimitingRunnerService.isFeatureEnabled();
        this.retryCondition = retryCondition;
    }

    /**
     * Runs the given callable.
     *
     * <p>With rate limiting disabled the callable is invoked directly. Otherwise every
     * attempt is submitted to the rate-limiting service together with the wait budget
     * that remains since the first submission, and failed attempts are retried according
     * to the policy built from the network settings and the retry condition.
     *
     * @param callable the work to execute
     * @param <R>      result type of the callable
     * @return the value produced by the callable
     * @throws Exception the failure of the last attempt; a cause wrapped by Failsafe in a
     *                   {@link FailsafeException} is unwrapped before being rethrown
     */
    public <R> R run(Callable<R> callable) throws Exception {
        if (!this.rateLimitingEnabled) {
            return callable.call();
        }
        RetryPolicy<R> retryPolicy = LLMQueryRunner.buildRetryPolicy(this.networkSettings, this.retryCondition);
        // The wait budget is shared by all attempts of this call, so it is read once
        // rather than on every retry.
        Duration maxDelay = LLMQueryRunner.getDefaultMaxDelay();
        Instant submitTime = Instant.now();
        try {
            return Failsafe.with(retryPolicy).get(executionContext ->
                    this.rateLimitingService.run(this.context, LLMQueryRunner.computeRemainingDelay(submitTime, maxDelay), callable));
        }
        catch (FailsafeException e) {
            Throwable cause = e.getCause();
            if (cause == null) {
                throw e;
            }
            if (cause instanceof Exception) {
                throw (Exception)cause;
            }
            throw new RuntimeException(cause);
        }
    }

    /**
     * Computes how much of {@code maxDelay} is left since {@code startTime}.
     *
     * <p>The result is negative once the elapsed time exceeds {@code maxDelay}.
     * NOTE(review): how the rate-limiting service treats a negative delay is not visible
     * from this class - confirm before clamping.
     */
    private static Duration computeRemainingDelay(Instant startTime, Duration maxDelay) {
        Duration elapsed = Duration.between(startTime, Instant.now());
        return maxDelay.minus(elapsed);
    }

    /**
     * Reads the {@code dku.llm.rateLimiting.maxDelayInSeconds} application parameter
     * (600 when unset) as a {@link Duration}.
     */
    private static Duration getDefaultMaxDelay() {
        int maxDelaySeconds = DKUApp.getParams().getIntParam("dku.llm.rateLimiting.maxDelayInSeconds", 600);
        return Duration.ofSeconds(maxDelaySeconds);
    }

    /**
     * Returns the network settings the HTTP client should use.
     *
     * <p>When rate limiting is enabled, {@code copyWithoutRetry()} of the configured
     * settings is returned - presumably so that retries are driven only by
     * {@link #run(Callable)} and not repeated at the HTTP layer (verify against
     * {@code copyWithoutRetry}). Otherwise the configured settings are returned as is.
     */
    public AbstractLLMConnection.HTTPBasedLLMNetworkSettings getHttpClientNetworkSettings() {
        if (this.rateLimitingEnabled) {
            return this.networkSettings.copyWithoutRetry();
        }
        return this.networkSettings;
    }

    /**
     * Builds the Failsafe retry policy: retries failures accepted by {@code retryCondition},
     * at most {@code networkSettings.maxRetries} times, with an exponential backoff starting
     * at {@code networkSettings.initialRetryDelayMS} milliseconds and scaled by
     * {@code networkSettings.retryDelayScalingFactor}, plus a jitter factor read from the
     * {@code dku.llm.jitterFactor} application parameter (0.25 when unset). Each retry is logged.
     */
    private static <T> RetryPolicy<T> buildRetryPolicy(AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings, Predicate<Throwable> retryCondition) {
        return RetryPolicy.<T>builder()
                .handleIf(failure -> retryCondition.test(failure))
                .withMaxRetries(networkSettings.maxRetries)
                // Long.MAX_VALUE as the maximum delay: no explicit cap is configured on the backoff.
                .withBackoff(networkSettings.initialRetryDelayMS, Long.MAX_VALUE, ChronoUnit.MILLIS, networkSettings.retryDelayScalingFactor)
                .withJitter(DKUApp.getParams().getDoubleParam("dku.llm.jitterFactor", 0.25))
                .onRetry(event -> {
                    Throwable lastFailure = event.getLastException();
                    logger.infoV("Request to the LLM failed with %s. Retry count=%d/%d", new Object[]{lastFailure, event.getAttemptCount(), networkSettings.maxRetries});
                })
                .build();
    }
}

