/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.llm.online.LLMRateLimitingRunnerService;
import com.dataiku.dip.llm.online.LLMRateLimitingSettingsService;
import com.dataiku.dip.server.notifications.DSSEvent;
import com.dataiku.dip.server.notifications.backend.LLMRateLimitersSpecsChanged;
import com.dataiku.dip.server.services.IPubSubService;
import com.dataiku.dip.utils.Pair;
import com.dataiku.dss.shadelib.com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import javax.annotation.PostConstruct;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Builds the LLM rate-limiter specs from the applicable rate-limiting settings and resolves which
 * of them apply to a given request context.
 *
 * <p>Two kinds of specs are configured per provider:
 * <ul>
 *   <li>per-purpose specs: provider-wide defaults for completion, embedding and image generation,
 *       with {@code model == null};</li>
 *   <li>per-model specs: one per configured model, with {@code purpose == null}.</li>
 * </ul>
 *
 * <p>When a context matches at least one per-model spec, only those apply. Otherwise each matching
 * per-purpose spec is turned into a model-specific "derived" spec. Derived specs are remembered by
 * parent so they can be reported once they stop being valid.
 *
 * <p>Specs are rebuilt at startup and on every {@code rate-limiting-settings-changed} pub-sub
 * event. Specs that are no longer valid are published in a {@link LLMRateLimitersSpecsChanged}
 * event. Mutable state is read and written under the {@code this} monitor, except through the
 * test-only accessor.
 */
@Service
public class LLMRateLimitersSpecsService {
    @Autowired
    private LLMRateLimitingSettingsService llmRateLimitingSettingsService;
    @Autowired
    private IPubSubService pubSubService;

    /** Current per-purpose specs ({@code model == null}); replaced wholesale on each update. */
    private HashSet<LLMRateLimiterSpec> perPurposeSpecs = new HashSet<>();
    /** Current per-model specs ({@code purpose == null}); replaced wholesale on each update. */
    private HashSet<LLMRateLimiterSpec> perModelSpecs = new HashSet<>();
    /** Derived specs handed out by {@link #getApplicableRateLimiterSpecs}, keyed by their per-purpose parent. */
    private final Map<LLMRateLimiterSpec, HashSet<LLMRateLimiterSpec>> derivedSpecsByPerPurposeSpec = new HashMap<>();

    @VisibleForTesting
    Map<LLMRateLimiterSpec, HashSet<LLMRateLimiterSpec>> getDerivedRateLimiterSpecsByPerPurposeSpec() {
        return this.derivedSpecsByPerPurposeSpec;
    }

    /**
     * Loads the initial specs and subscribes to settings changes so the specs are rebuilt whenever
     * a {@code rate-limiting-settings-changed} event is received.
     *
     * @throws IOException declared for the container's lifecycle contract
     */
    @PostConstruct
    public void init() throws IOException {
        this.updateRateLimiterSpecs();
        this.pubSubService.subscribe("rate-limiting-settings-changed", evt -> this.updateRateLimiterSpecs());
    }

    /**
     * Rebuilds the per-purpose and per-model specs from the current applicable settings, forgets
     * derived specs that are no longer valid, and publishes a {@link LLMRateLimitersSpecsChanged}
     * event listing the deprecated specs, if there are any.
     */
    @VisibleForTesting
    void updateRateLimiterSpecs() {
        Map<String, GeneralSettingsDAO.RateLimitingProviderSettings> applicableSettings = this.llmRateLimitingSettingsService.getApplicableSettings();
        HashSet<LLMRateLimiterSpec> newPerPurposeSpecs = new HashSet<>();
        HashSet<LLMRateLimiterSpec> newPerModelSpecs = new HashSet<>();
        for (Map.Entry<String, GeneralSettingsDAO.RateLimitingProviderSettings> providerEntry : applicableSettings.entrySet()) {
            String providerId = providerEntry.getKey();
            GeneralSettingsDAO.RateLimitingProviderSettings providerSettings = providerEntry.getValue();
            if (providerSettings.completionDefault != null) {
                newPerPurposeSpecs.add(LLMRateLimiterSpec.of(providerId, LLMRateLimitingRunnerService.RateLimitingPurpose.GENERIC_COMPLETION, providerSettings.completionDefault));
            }
            if (providerSettings.embeddingDefault != null) {
                newPerPurposeSpecs.add(LLMRateLimiterSpec.of(providerId, LLMRateLimitingRunnerService.RateLimitingPurpose.EMBEDDING_EXTRACTION, providerSettings.embeddingDefault));
            }
            if (providerSettings.imageGenerationDefault != null) {
                newPerPurposeSpecs.add(LLMRateLimiterSpec.of(providerId, LLMRateLimitingRunnerService.RateLimitingPurpose.IMAGE_GENERATION, providerSettings.imageGenerationDefault));
            }
            // Guarded like the *Default fields above: a provider without per-model configs yields none
            if (providerSettings.perModelConfigs != null) {
                for (Map.Entry<String, GeneralSettingsDAO.RateLimitingConfig> modelEntry : providerSettings.perModelConfigs.entrySet()) {
                    newPerModelSpecs.add(LLMRateLimiterSpec.of(providerId, modelEntry.getKey(), modelEntry.getValue()));
                }
            }
        }
        List<LLMRateLimiterSpec> deprecatedSpecs;
        synchronized (this) {
            // Computed against the OLD state, so it must run before the fields are swapped
            deprecatedSpecs = this.getDeprecatedSpecs(newPerPurposeSpecs, newPerModelSpecs);
            this.cleanDeprecatedDerivedSpecsEntries(newPerPurposeSpecs, newPerModelSpecs);
            this.perPurposeSpecs = newPerPurposeSpecs;
            this.perModelSpecs = newPerModelSpecs;
        }
        // Published after releasing the lock so the pub-sub call is not made while holding it
        if (!deprecatedSpecs.isEmpty()) {
            this.pubSubService.publish((DSSEvent)new LLMRateLimitersSpecsChanged(deprecatedSpecs));
        }
    }

    /**
     * Lists the specs invalidated by switching from the current configuration to the new one:
     * <ul>
     *   <li>derived specs of removed per-purpose specs;</li>
     *   <li>removed per-model specs;</li>
     *   <li>derived specs whose model now has its own, different, per-model spec.</li>
     * </ul>
     * Each spec is listed once. A spec that is still part of the new per-model configuration is
     * never listed.
     */
    private synchronized List<LLMRateLimiterSpec> getDeprecatedSpecs(HashSet<LLMRateLimiterSpec> newPerPurposeSpecs, HashSet<LLMRateLimiterSpec> newPerModelSpecs) {
        // Insertion-ordered set: a derived spec can be reached through several of the paths below
        LinkedHashSet<LLMRateLimiterSpec> deprecated = new LinkedHashSet<>();
        for (LLMRateLimiterSpec perPurposeSpec : this.perPurposeSpecs) {
            if (newPerPurposeSpecs.contains(perPurposeSpec)) {
                continue;
            }
            HashSet<LLMRateLimiterSpec> derivedSpecs = this.derivedSpecsByPerPurposeSpec.get(perPurposeSpec);
            if (derivedSpecs != null) {
                deprecated.addAll(derivedSpecs);
            }
        }
        for (LLMRateLimiterSpec perModelSpec : this.perModelSpecs) {
            if (!newPerModelSpecs.contains(perModelSpec)) {
                deprecated.add(perModelSpec);
            }
        }
        deprecated.addAll(this.getDeprecatedDerivedSpecs(newPerModelSpecs));
        // A derived spec identical to a live per-model spec describes a limiter that is still in use
        deprecated.removeAll(newPerModelSpecs);
        return new ArrayList<>(deprecated);
    }

    /**
     * Drops the derived-spec bookkeeping that {@link #getDeprecatedSpecs} has just reported.
     * This covers the entries of removed per-purpose specs and the derived specs shadowed by a
     * new, different per-model spec. Without this, those derived specs would be reported again on
     * every later settings change.
     */
    private synchronized void cleanDeprecatedDerivedSpecsEntries(HashSet<LLMRateLimiterSpec> newPerPurposeSpecs, HashSet<LLMRateLimiterSpec> newPerModelSpecs) {
        for (LLMRateLimiterSpec perPurposeSpec : this.perPurposeSpecs) {
            if (!newPerPurposeSpecs.contains(perPurposeSpec)) {
                this.derivedSpecsByPerPurposeSpec.remove(perPurposeSpec);
            }
        }
        Map<String, HashSet<String>> newModelsByProvider = LLMRateLimitersSpecsService.modelsByProvider(newPerModelSpecs);
        for (HashSet<LLMRateLimiterSpec> derivedSpecs : this.derivedSpecsByPerPurposeSpec.values()) {
            derivedSpecs.removeIf(derivedSpec -> LLMRateLimitersSpecsService.isShadowedByPerModelSpec(derivedSpec, newModelsByProvider, newPerModelSpecs));
        }
    }

    /**
     * Returns the specs applying to {@code context}.
     *
     * <p>If any per-model spec matches, only the matching per-model specs are returned. Otherwise
     * every matching per-purpose spec is turned into a model-specific derived spec, and the
     * derived spec is recorded.
     */
    public synchronized List<LLMRateLimiterSpec> getApplicableRateLimiterSpecs(LLMRateLimitingRunnerService.LLMRateLimiterContext context) {
        List<LLMRateLimiterSpec> perModelMatches = LLMRateLimitersSpecsService.getMatchingRateLimiters(context, this.perModelSpecs);
        if (!perModelMatches.isEmpty()) {
            return perModelMatches;
        }
        List<LLMRateLimiterSpec> derivedSpecs = new ArrayList<>();
        for (LLMRateLimiterSpec perPurposeSpec : LLMRateLimitersSpecsService.getMatchingRateLimiters(context, this.perPurposeSpecs)) {
            derivedSpecs.add(this.getDerivedSpec(perPurposeSpec, context));
        }
        return derivedSpecs;
    }

    /** Builds the spec derived from {@code perPurposeSpec} for {@code context.model} and records it under its parent. */
    private synchronized LLMRateLimiterSpec getDerivedSpec(LLMRateLimiterSpec perPurposeSpec, LLMRateLimitingRunnerService.LLMRateLimiterContext context) {
        LLMRateLimiterSpec derivedSpec = LLMRateLimiterSpec.getDerivedSpec(perPurposeSpec, context.model);
        this.derivedSpecsByPerPurposeSpec.computeIfAbsent(perPurposeSpec, parent -> new HashSet<>()).add(derivedSpec);
        return derivedSpec;
    }

    /**
     * Returns the recorded derived specs whose (provider, model) now has an explicit per-model
     * spec that differs from the derived one.
     */
    private synchronized List<LLMRateLimiterSpec> getDeprecatedDerivedSpecs(HashSet<LLMRateLimiterSpec> newModelSpecs) {
        Map<String, HashSet<String>> newModelsByProvider = LLMRateLimitersSpecsService.modelsByProvider(newModelSpecs);
        ArrayList<LLMRateLimiterSpec> deprecatedSpecs = new ArrayList<>();
        for (HashSet<LLMRateLimiterSpec> derivedSpecs : this.derivedSpecsByPerPurposeSpec.values()) {
            for (LLMRateLimiterSpec derivedSpec : derivedSpecs) {
                if (LLMRateLimitersSpecsService.isShadowedByPerModelSpec(derivedSpec, newModelsByProvider, newModelSpecs)) {
                    deprecatedSpecs.add(derivedSpec);
                }
            }
        }
        return deprecatedSpecs;
    }

    /** Indexes the models covered by {@code modelSpecs}, grouped by provider. */
    private static Map<String, HashSet<String>> modelsByProvider(HashSet<LLMRateLimiterSpec> modelSpecs) {
        Map<String, HashSet<String>> modelsByProvider = new HashMap<>();
        for (LLMRateLimiterSpec modelSpec : modelSpecs) {
            modelsByProvider.computeIfAbsent(modelSpec.provider, provider -> new HashSet<>()).add(modelSpec.model);
        }
        return modelsByProvider;
    }

    /**
     * Returns whether a derived spec's (provider, model) has a per-model spec in the new
     * configuration, and that configuration does not contain the derived spec itself.
     */
    private static boolean isShadowedByPerModelSpec(LLMRateLimiterSpec derivedSpec, Map<String, HashSet<String>> newModelsByProvider, HashSet<LLMRateLimiterSpec> newModelSpecs) {
        HashSet<String> models = newModelsByProvider.get(derivedSpec.provider);
        return models != null && models.contains(derivedSpec.model) && !newModelSpecs.contains(derivedSpec);
    }

    private static List<LLMRateLimiterSpec> getMatchingRateLimiters(LLMRateLimitingRunnerService.LLMRateLimiterContext context, HashSet<LLMRateLimiterSpec> specs) {
        return specs.stream().filter(spec -> LLMRateLimitersSpecsService.matches(context, spec)).collect(Collectors.toList());
    }

    /**
     * Matches a spec against a context by provider.
     * A per-purpose spec ({@code model == null}) additionally requires the same purpose.
     * A per-model spec additionally requires the same model.
     */
    private static boolean matches(LLMRateLimitingRunnerService.LLMRateLimiterContext context, LLMRateLimiterSpec spec) {
        if (!Objects.equals(spec.provider, context.provider)) {
            return false;
        }
        if (spec.model == null) {
            return spec.purpose == context.purpose;
        }
        return spec.model.equals(context.model);
    }

    /**
     * A single rate-limit definition: at most {@code maxExecutions} executions per {@code period}
     * for a provider, scoped either to a purpose ({@code model == null}) or to a model.
     * Instances are used as hash keys, so they must not be mutated once stored.
     */
    public static class LLMRateLimiterSpec {
        public String provider;
        public String model;
        public long maxExecutions;
        public Duration period = Duration.ofMinutes(1L);
        public LLMRateLimitingRunnerService.RateLimitingPurpose purpose;

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            LLMRateLimiterSpec that = (LLMRateLimiterSpec)o;
            return Objects.equals(this.provider, that.provider) && Objects.equals(this.model, that.model) && this.maxExecutions == that.maxExecutions && Objects.equals(this.period, that.period) && this.purpose == that.purpose;
        }

        @Override
        public String toString() {
            return "LLMRateLimiterSpec{provider=" + this.provider
                    + (this.model != null ? ", model=" + this.model : "")
                    + (this.purpose != null ? ", purpose=" + this.purpose : "")
                    + ", maxExecutions=" + this.maxExecutions
                    + ", duration=" + this.period + "}";
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.provider, this.model, this.maxExecutions, this.period, this.purpose);
        }

        /** Creates a per-purpose (provider-wide) spec. */
        public static LLMRateLimiterSpec of(String provider, LLMRateLimitingRunnerService.RateLimitingPurpose purpose, GeneralSettingsDAO.RateLimitingConfig config) {
            return LLMRateLimiterSpec.of(provider, null, purpose, config);
        }

        /** Creates a per-model spec. */
        public static LLMRateLimiterSpec of(String provider, String model, GeneralSettingsDAO.RateLimitingConfig config) {
            return LLMRateLimiterSpec.of(provider, model, null, config);
        }

        /** Creates a spec allowing {@code config.requestsPerMinute} executions per (default one-minute) period. */
        private static LLMRateLimiterSpec of(String provider, String model, LLMRateLimitingRunnerService.RateLimitingPurpose purpose, GeneralSettingsDAO.RateLimitingConfig config) {
            LLMRateLimiterSpec llmSpec = new LLMRateLimiterSpec();
            llmSpec.provider = provider;
            llmSpec.model = model;
            llmSpec.purpose = purpose;
            llmSpec.maxExecutions = config.requestsPerMinute;
            return llmSpec;
        }

        /**
         * Creates the model-scoped copy of a per-purpose spec.
         * The copy has the same provider, limit and period, the given model, and no purpose.
         */
        public static LLMRateLimiterSpec getDerivedSpec(LLMRateLimiterSpec that, String model) {
            LLMRateLimiterSpec spec = new LLMRateLimiterSpec();
            spec.provider = that.provider;
            spec.maxExecutions = that.maxExecutions;
            // period participates in equals/hashCode, so the derived spec must carry the parent's value
            spec.period = that.period;
            spec.purpose = null;
            spec.model = model;
            return spec;
        }
    }
}

