/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.cluster.ClusterSelector;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.LLMRateLimitersSpecsService;
import com.dataiku.dip.llm.online.LLMRateLimitingSettingsService;
import com.dataiku.dip.logging.MainLoggingConfigurator;
import com.dataiku.dip.server.services.IPubSubService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dss.shadelib.com.google.api.client.util.Sleeper;
import com.dataiku.dss.shadelib.com.google.common.annotations.VisibleForTesting;
import com.dataiku.dss.shadelib.com.google.common.cache.CacheBuilder;
import com.dataiku.dss.shadelib.com.google.common.cache.CacheLoader;
import com.dataiku.dss.shadelib.com.google.common.cache.LoadingCache;
import dev.failsafe.RateLimitExceededException;
import dev.failsafe.RateLimiter;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.annotation.PostConstruct;
import org.apache.commons.lang.time.DurationFormatUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class LLMRateLimitingRunnerService {
    @Autowired
    private LLMRateLimitingSettingsService llmRateLimitingSettingsService;
    @Autowired
    private LLMRateLimitersSpecsService llmRateLimitersSpecsService;
    @Autowired
    private IPubSubService pubSubService;
    @VisibleForTesting
    protected Sleeper sleeper = Sleeper.DEFAULT;
    private boolean shouldLogBackendCentralisationWarning;
    private final LoadingCache<LLMRateLimiterContext, List<RateLimiter<Object>>> applicableRateLimiterCache = CacheBuilder.newBuilder().expireAfterAccess(1L, TimeUnit.HOURS).build((CacheLoader)new CacheLoader<LLMRateLimiterContext, List<RateLimiter<Object>>>(){

        public List<RateLimiter<Object>> load(LLMRateLimiterContext context) {
            return LLMRateLimitingRunnerService.this.resolveContext(context);
        }
    });
    private final ConcurrentHashMap<LLMRateLimitersSpecsService.LLMRateLimiterSpec, RateLimiter<Object>> rateLimiters = new ConcurrentHashMap();
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.llm.ratelimiting");

    @PostConstruct
    public void init() {
        this.pubSubService.subscribe("rate-limiters-specs-changed", evt -> {
            if (!evt.deprecatedRateLimiters.isEmpty()) {
                for (LLMRateLimitersSpecsService.LLMRateLimiterSpec rateLimiterSpec : evt.deprecatedRateLimiters) {
                    logger.info((Object)("Removing deprecated rate limiter: " + String.valueOf(rateLimiterSpec)));
                    this.rateLimiters.remove(rateLimiterSpec);
                }
                this.applicableRateLimiterCache.invalidateAll();
            }
        });
        this.shouldLogBackendCentralisationWarning = ClusterSelector.getContext() != MainLoggingConfigurator.ProcessType.BACKEND;
    }

    private RateLimiter<Object> getRateLimiterFromSpec(LLMRateLimitersSpecsService.LLMRateLimiterSpec spec) {
        return this.rateLimiters.computeIfAbsent(spec, cfg -> {
            logger.info((Object)("Initializing new rate limiter for spec: " + String.valueOf(spec)));
            if (this.llmRateLimitingSettingsService.isSmoothProvider(cfg.provider)) {
                return RateLimiter.smoothBuilder((long)cfg.maxExecutions, (Duration)cfg.period).build();
            }
            return RateLimiter.burstyBuilder((long)cfg.maxExecutions, (Duration)cfg.period).build();
        });
    }

    public static boolean isFeatureEnabled() {
        return DKUApp.getParams().getBoolParam("dku.llm.rateLimiting.enabled", true);
    }

    private List<RateLimiter<Object>> resolveContext(LLMRateLimiterContext context) {
        return this.llmRateLimitersSpecsService.getApplicableRateLimiterSpecs(context).stream().map(this::getRateLimiterFromSpec).collect(Collectors.toList());
    }

    public <R> R run(LLMRateLimiterContext queryContext, Duration maxDelay, Callable<R> callable) throws Exception {
        return this.run(queryContext, 1, maxDelay, callable);
    }

    public <R> R run(LLMRateLimiterContext queryContext, int nbPermits, Duration maxDelay, Callable<R> callable) throws Exception {
        List rateLimiters;
        if (this.shouldLogBackendCentralisationWarning) {
            this.logBackendCentralisationWarningIdNeeded();
        }
        if ((rateLimiters = (List)this.applicableRateLimiterCache.getUnchecked((Object)queryContext)).isEmpty()) {
            return callable.call();
        }
        assert (rateLimiters.size() == 1);
        Duration delay = ((RateLimiter)rateLimiters.get(0)).reservePermits(nbPermits);
        if (delay.compareTo(maxDelay) > 0) {
            logger.warn((Object)("rate limit exceeded (requested_delay: " + LLMRateLimitingRunnerService.formatDelay(delay) + " > max_delay: " + LLMRateLimitingRunnerService.formatDelay(maxDelay) + ")"));
            throw new RateLimitExceededException((RateLimiter)rateLimiters.get(0));
        }
        if (delay.toMillis() > 0L) {
            logger.info((Object)("waiting for " + LLMRateLimitingRunnerService.formatDelay(delay) + " before running query"));
            this.sleeper.sleep(delay.toMillis());
        }
        return callable.call();
    }

    @VisibleForTesting
    static String formatDelay(Duration delay) {
        Object readableDelay = DurationFormatUtils.formatDurationWords((long)delay.toMillis(), (boolean)true, (boolean)true);
        if (((String)readableDelay).equals("0 seconds")) {
            readableDelay = "0." + String.format("%03d", delay.toMillis()) + " seconds";
        }
        return readableDelay;
    }

    private synchronized void logBackendCentralisationWarningIdNeeded() {
        if (this.shouldLogBackendCentralisationWarning) {
            logger.warn((Object)"The LLM backend centralisation seems disabled. The rate limiting will do best effort, only taking in account your current job queries.");
            this.shouldLogBackendCentralisationWarning = false;
        }
    }

    public static class LLMRateLimiterContext {
        public final String provider;
        public final String model;
        public final RateLimitingPurpose purpose;

        public LLMRateLimiterContext(String providerId, LLMStructuredRef modelRef, LLMModelHandle.Model model) {
            this(providerId, modelRef.getModelNameForAudit(), LLMRateLimiterContext.getModelPurpose(model));
        }

        public LLMRateLimiterContext(String providerId, String modelName, RateLimitingPurpose purpose) {
            this.provider = providerId;
            this.model = modelName;
            this.purpose = purpose;
        }

        private static RateLimitingPurpose getModelPurpose(LLMModelHandle.Model model) {
            if (model.canBeUsedForPurpose(AbstractLLMConnection.LLMUsagePurpose.IMAGE_GENERATION)) {
                return RateLimitingPurpose.IMAGE_GENERATION;
            }
            if (model.canBeUsedForPurpose(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION) || model.canBeUsedForPurpose(AbstractLLMConnection.LLMUsagePurpose.IMAGE_EMBEDDING_EXTRACTION)) {
                return RateLimitingPurpose.EMBEDDING_EXTRACTION;
            }
            return RateLimitingPurpose.GENERIC_COMPLETION;
        }

        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            LLMRateLimiterContext that = (LLMRateLimiterContext)o;
            return Objects.equals(this.provider, that.provider) && Objects.equals(this.model, that.model) && Objects.equals((Object)this.purpose, (Object)that.purpose);
        }

        public String toString() {
            return "LLMRateLimiterContext{provider=" + this.provider + ", model=" + this.model + ", purpose=" + String.valueOf((Object)this.purpose) + "}";
        }

        public int hashCode() {
            return Objects.hash(new Object[]{this.provider, this.model, this.purpose});
        }
    }

    public static enum RateLimitingPurpose {
        GENERIC_COMPLETION,
        EMBEDDING_EXTRACTION,
        IMAGE_GENERATION;

    }
}

