/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.analysis.ml.prediction.overrides.ReadOnlyColumnFactory;
import com.dataiku.dip.analysis.ml.prediction.overrides.ReadOnlyRowObservation;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.dataflow.exec.filter.FilterDescUtils;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.expressions.Expression;
import com.dataiku.dip.llm.online.LLMCostLimitingService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.notifications.backend.GeneralSettingsChangedEvent;
import com.dataiku.dip.server.services.IPubSubService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.server.services.licensing.LicenseEnforcementService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.com.google.common.annotations.VisibleForTesting;
import com.dataiku.dss.shadelib.com.google.common.base.MoreObjects;
import com.dataiku.scoring.util.RawObservation;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import javax.annotation.PostConstruct;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * In-memory view of the LLM cost-limiting quotas defined in the general settings.
 *
 * <p>The settings are loaded once at startup ({@link #init()}) and reloaded whenever a
 * {@code "general-settings-changed"} event shows that the cost-limiting part of the general
 * settings changed. The custom quotas and the fallback quota are always replaced together, under
 * this object's monitor ({@link #updateSettings}); every reader takes both references under that
 * same monitor so it never observes a half-applied update.
 *
 * <p>Custom quotas are only taken into account by the license-aware methods when the license
 * allows advanced LLM Mesh features; the fallback quota is always available and is the one
 * applied when no custom quota matches a call.
 */
@Service
public class LLMCostLimitingQuotasRepository {
    /** Reserved id under which {@link #getQuota(String)} resolves the fallback quota. */
    private static final String FALLBACK_QUOTA_ID = "DKU-FALLBACK-QUOTA";

    @Autowired
    private TransactionService transactionService;
    @Autowired
    private GeneralSettingsDAO generalSettingsDAO;
    @Autowired
    private IPubSubService pubSubService;
    private final LicenseEnforcementService licenseEnforcementService;
    // Both fields are guarded by "this" and are only ever replaced as a pair by updateSettings().
    private List<GeneralSettingsDAO.CustomLLMCostLimitingQuota> customQuotas;
    private GeneralSettingsDAO.FallbackLLMCostLimitingQuota fallbackQuota;

    /** Immutable pair of quota references captured atomically under the monitor. */
    private record QuotaSnapshot(List<GeneralSettingsDAO.CustomLLMCostLimitingQuota> customQuotas,
                                 GeneralSettingsDAO.FallbackLLMCostLimitingQuota fallbackQuota) {
    }

    public LLMCostLimitingQuotasRepository(@Autowired LicenseEnforcementService licenseEnforcementService) {
        this.licenseEnforcementService = licenseEnforcementService;
    }

    /**
     * Loads the current cost-limiting settings inside a read transaction, then subscribes to
     * {@code "general-settings-changed"} events so the quotas are reloaded whenever the new
     * cost-limiting settings are not JSON-equal to the previous ones.
     *
     * @throws IOException propagated from opening the transaction or reading the general settings
     */
    @PostConstruct
    public void init() throws IOException {
        try (Transaction t = this.transactionService.beginRead()) {
            this.updateSettings(this.generalSettingsDAO.read().generativeAISettings.costLimitingSettings);
        }
        this.pubSubService.subscribe("general-settings-changed", evt -> {
            GeneralSettingsChangedEvent settingsChangedEvent = (GeneralSettingsChangedEvent) evt;
            boolean costLimitingSettingsChanged = !JSON.jsonEquals(
                    (Object) settingsChangedEvent.previousSettings.generativeAISettings.costLimitingSettings,
                    (Object) settingsChangedEvent.newSettings.generativeAISettings.costLimitingSettings);
            if (costLimitingSettingsChanged) {
                this.updateSettings(settingsChangedEvent.newSettings.generativeAISettings.costLimitingSettings);
            }
        });
    }

    /**
     * Replaces the custom quotas and the fallback quota with those of {@code newSettings} in one
     * atomic step. The list is stored as-is, not copied.
     */
    @VisibleForTesting
    synchronized void updateSettings(GeneralSettingsDAO.LLMCostLimitingSettings newSettings) {
        this.customQuotas = newSettings.quotas;
        this.fallbackQuota = newSettings.fallbackQuota;
    }

    /** Captures both quota references consistently; callers then work lock-free on the result. */
    private synchronized QuotaSnapshot snapshot() {
        return new QuotaSnapshot(this.customQuotas, this.fallbackQuota);
    }

    /**
     * Looks up a quota by id. The reserved id {@code "DKU-FALLBACK-QUOTA"} designates the fallback
     * quota; any other id is searched among the custom quotas, regardless of the license.
     *
     * @param quotaId id of the quota to find
     * @return the matching quota, or {@code null} when no custom quota has this id
     */
    @Nullable
    public synchronized GeneralSettingsDAO.LLMCostLimitingQuota getQuota(String quotaId) {
        if (FALLBACK_QUOTA_ID.equals(quotaId)) {
            return this.fallbackQuota;
        }
        return this.customQuotas.stream().filter(q -> Objects.equals(quotaId, q.getId())).findFirst().orElse(null);
    }

    /**
     * Returns the currently configured custom quotas. This is the internal list, not a copy;
     * callers must not modify it.
     */
    public synchronized List<GeneralSettingsDAO.CustomLLMCostLimitingQuota> getCustomQuotas() {
        return this.customQuotas;
    }

    /**
     * Returns a new mutable list holding every custom quota followed by the fallback quota,
     * regardless of the license.
     */
    public synchronized List<GeneralSettingsDAO.LLMCostLimitingQuota> getAllQuotas() {
        List<GeneralSettingsDAO.LLMCostLimitingQuota> allQuotas = new ArrayList<>(this.customQuotas);
        // The fallback quota is an LLMCostLimitingQuota in its own right: no cast to the custom type
        // (such a cross-cast would fail with a ClassCastException).
        allQuotas.add(this.fallbackQuota);
        return allQuotas;
    }

    /**
     * Returns the quotas visible to {@code authCtx}, as computed by
     * {@link #getAuthorizedSettings(AuthCtx)}: the visible custom quotas first, then the fallback
     * quota when it is visible.
     */
    public synchronized List<GeneralSettingsDAO.LLMCostLimitingQuota> getAuthorizedQuotas(AuthCtx authCtx) throws DKUSecurityException {
        GeneralSettingsDAO.LLMCostLimitingSettings authorizedSettings = this.getAuthorizedSettings(authCtx);
        List<GeneralSettingsDAO.LLMCostLimitingQuota> authorizedQuotas = new ArrayList<>(authorizedSettings.quotas);
        if (authorizedSettings.fallbackQuota != null) {
            authorizedQuotas.add(authorizedSettings.fallbackQuota);
        }
        return authorizedQuotas;
    }

    /**
     * Maps the id of every quota (custom and fallback) to its cost limit.
     *
     * @throws IllegalStateException if two quotas share the same id ({@link Collectors#toMap} semantics)
     */
    public synchronized Map<String, Double> getCostLimitByQuotaId() {
        return this.getAllQuotas().stream()
                .collect(Collectors.toMap(GeneralSettingsDAO.LLMCostLimitingQuota::getId, quota -> quota.costLimit));
    }

    /** Returns, in order, the custom quotas whose filter matches {@code context}. */
    private static List<GeneralSettingsDAO.CustomLLMCostLimitingQuota> getMatchingCustomQuotas(
            List<GeneralSettingsDAO.CustomLLMCostLimitingQuota> customQuotas,
            LLMCostLimitingService.LLMCostLimitingContext context) {
        return customQuotas.stream().filter(quota -> matches(context, quota)).toList();
    }

    /** Whether the license allows custom quotas (advanced LLM Mesh feature). */
    private boolean areCustomQuotaEnabled() {
        return this.licenseEnforcementService.getFeaturesStatus().advancedLLMMeshAllowed;
    }

    /**
     * Returns a new mutable list of the quotas usable under the current license: the custom quotas
     * when the license allows them, followed by the fallback quota.
     */
    public List<GeneralSettingsDAO.LLMCostLimitingQuota> getAvailableQuotas() {
        QuotaSnapshot quotas = this.snapshot();
        List<GeneralSettingsDAO.LLMCostLimitingQuota> availableQuotas = new ArrayList<>();
        if (this.areCustomQuotaEnabled()) {
            availableQuotas.addAll(quotas.customQuotas());
        }
        availableQuotas.add(quotas.fallbackQuota());
        return availableQuotas;
    }

    /**
     * Builds a filtered copy of the cost-limiting settings for {@code authCtx}.
     *
     * <p>A quota is kept when the caller is an admin or one of its permissions names the caller or
     * one of the caller's groups. Kept quotas the caller may not read in full are replaced by
     * their basic-access view. Custom quotas are only considered when the license allows them;
     * the returned fallback quota is {@code null} when the caller cannot see it.
     *
     * @param authCtx the caller's authentication context
     * @return a new settings object; the stored settings are not modified
     */
    public GeneralSettingsDAO.LLMCostLimitingSettings getAuthorizedSettings(AuthCtx authCtx) throws DKUSecurityException {
        QuotaSnapshot quotas = this.snapshot();
        Set<String> groups = authCtx.getGroups();
        boolean isAdmin = authCtx.isAdmin();
        String user = authCtx.getAssociatedDSSUser();

        GeneralSettingsDAO.LLMCostLimitingSettings authorizedSettings = new GeneralSettingsDAO.LLMCostLimitingSettings();
        if (this.areCustomQuotaEnabled()) {
            authorizedSettings.quotas = quotas.customQuotas().stream()
                    .filter(quota -> matchesQuotaPermission(isAdmin, user, groups, quota))
                    .map(quota -> {
                        if (canReadFullData(isAdmin, user, groups, quota)) {
                            return quota;
                        }
                        return GeneralSettingsDAO.CustomLLMCostLimitingQuota.asBasicAccessQuota(quota);
                    })
                    .toList();
        }

        GeneralSettingsDAO.FallbackLLMCostLimitingQuota fallback = quotas.fallbackQuota();
        if (!matchesQuotaPermission(isAdmin, user, groups, fallback)) {
            authorizedSettings.fallbackQuota = null;
        } else if (canReadFullData(isAdmin, user, groups, fallback)) {
            authorizedSettings.fallbackQuota = fallback;
        } else {
            authorizedSettings.fallbackQuota = GeneralSettingsDAO.FallbackLLMCostLimitingQuota.asBasicAccessQuota(fallback);
        }
        return authorizedSettings;
    }

    /**
     * Returns the quotas that apply to an LLM call. These are all custom quotas whose filter matches
     * {@code context}, when the license allows custom quotas and at least one matches. Otherwise the
     * result is a single-element list holding the fallback quota.
     */
    public List<GeneralSettingsDAO.LLMCostLimitingQuota> getApplicableQuotas(LLMCostLimitingService.LLMCostLimitingContext context) {
        QuotaSnapshot quotas = this.snapshot();
        if (this.areCustomQuotaEnabled()) {
            List<GeneralSettingsDAO.CustomLLMCostLimitingQuota> matchingQuotas =
                    getMatchingCustomQuotas(quotas.customQuotas(), context);
            if (!matchingQuotas.isEmpty()) {
                return new ArrayList<>(matchingQuotas);
            }
        }
        return Arrays.asList(quotas.fallbackQuota());
    }

    /**
     * Evaluates the quota's filter against the call context. A quota without a filter, or whose
     * filter is disabled, matches every call.
     *
     * <p>The filter is compiled as a GREL expression over these columns:
     * <ul>
     *   <li>{@code projectKey} and {@code project};</li>
     *   <li>{@code user} and {@code userLogin};</li>
     *   <li>{@code groups};</li>
     *   <li>{@code provider};</li>
     *   <li>{@code connection} and {@code connectionName};</li>
     *   <li>{@code llmId}.</li>
     * </ul>
     *
     * @throws RuntimeException wrapping any failure to build the filter expression
     */
    private static boolean matches(LLMCostLimitingService.LLMCostLimitingContext context, GeneralSettingsDAO.CustomLLMCostLimitingQuota quota) {
        if (quota.filter == null || !quota.filter.enabled) {
            return true;
        }
        RawObservation observation = new RawObservation(buildObservationValues(context));
        Expression expression;
        try {
            expression = new Expression(FilterDescUtils.getGrelExpression(quota.filter));
            expression.setColumnFactory((ColumnFactory) new ReadOnlyColumnFactory(observation.keys().toArray(new String[0])));
        } catch (Exception e) {
            throw new RuntimeException("Failed to build the filter of LLM cost limiting quota " + quota.getId(), e);
        }
        return expression.isTrueish(new ReadOnlyRowObservation(observation));
    }

    /**
     * Builds the column values a quota filter is evaluated against. Null context values are exposed
     * as empty strings. Null groups are exposed as an empty list, like every other missing value,
     * instead of failing in {@code List.copyOf}.
     * Some columns are aliases exposing the same value under two names (e.g. {@code project} for
     * {@code projectKey}).
     */
    private static Map<String, Object> buildObservationValues(LLMCostLimitingService.LLMCostLimitingContext context) {
        Object projectKey = MoreObjects.firstNonNull(context.projectKey, "");
        Object user = MoreObjects.firstNonNull(context.user, "");
        Object provider = MoreObjects.firstNonNull(context.provider, "");
        Object connectionName = MoreObjects.firstNonNull(context.connectionName, "");
        Object llmId = MoreObjects.firstNonNull(context.llmId, "");
        Object groups = context.groups == null ? List.of() : List.copyOf(context.groups);
        Map<String, Object> values = Map.of(
                "projectKey", projectKey,
                "project", projectKey,
                "user", user,
                "groups", groups,
                "userLogin", user,
                "provider", provider,
                "connection", connectionName,
                "connectionName", connectionName,
                "llmId", llmId);
        return values;
    }

    /**
     * A permission applies when it names {@code user} or one of {@code groups}. A null user never
     * matches a user permission; it can still match through a group.
     */
    @NotNull
    private static Predicate<GeneralSettingsDAO.LLMCostLimitingPermission> getCostLimitingPermissionPredicate(String user, Set<String> groups) {
        return permission -> (user != null && user.equals(permission.user))
                || (permission.group != null && groups.contains(permission.group));
    }

    /**
     * Whether the caller may see {@code quota} at all: always for admins, otherwise when one of the
     * quota's permissions applies to the user or their groups. A quota without permissions is
     * hidden from non-admins.
     */
    @VisibleForTesting
    static boolean matchesQuotaPermission(boolean isAdmin, String user, Set<String> groups, GeneralSettingsDAO.LLMCostLimitingQuota quota) {
        if (isAdmin) {
            return true;
        }
        if (quota.permissions == null) {
            return false;
        }
        return quota.permissions.stream().anyMatch(getCostLimitingPermissionPredicate(user, groups));
    }

    /**
     * Whether the caller may read {@code quota} in full: always for admins, otherwise when one of
     * the applicable permissions has {@code readFullData} set.
     */
    @VisibleForTesting
    static boolean canReadFullData(boolean isAdmin, String user, Set<String> groups, GeneralSettingsDAO.LLMCostLimitingQuota quota) {
        if (isAdmin) {
            return true;
        }
        if (quota.permissions == null) {
            return false;
        }
        return quota.permissions.stream()
                .filter(getCostLimitingPermissionPredicate(user, groups))
                .anyMatch(permission -> permission.readFullData);
    }
}

