/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.externalinfras.azureml.AzureMLUtils;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.ISavedModelDeployer;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMCostLimitingService;
import com.dataiku.dip.llm.online.RemoteFineTuningClient;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.Params;
import com.dataiku.dip.utils.SmartLogTail;
import com.dataiku.dss.shadelib.com.google.common.annotations.VisibleForTesting;
import com.dataiku.dss.shadelib.com.google.common.base.MoreObjects;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.Callable;
import javax.annotation.Nullable;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * {@link LLMClient} decorator that applies cost limiting around a wrapped client.
 *
 * <p>Before completion, embedding and image-generation queries, the wrapper asks the
 * {@link LLMCostLimitingService} to check the query (only when the matching
 * {@code canBlockOn*} flag of the context is set). After the wrapped client has processed
 * a query, the estimated cost is reported to the same service. Every other operation is
 * forwarded to the wrapped client as is.
 */
public class LLMClientCostLimitingWrapper implements LLMClient {
    /** Injected by {@code SpringUtils.getInstance().autowire(this)} in the constructor. */
    @Autowired
    private LLMCostLimitingService costLimitingService;
    /** The client that actually performs the queries. */
    private final LLMClient llmClient;
    /** Identifies the LLM, connection, provider, project and user for the cost limiting service. */
    private final LLMCostLimitingService.LLMCostLimitingContext context;

    /**
     * Builds a wrapper whose context only carries the LLM id, the connection name and the
     * provider id. The {@code canBlockOn*} flags, project key, user and groups are not set here.
     *
     * @param wrappedClient  client to delegate to; must not itself be a cost limiting wrapper
     * @param llmModelHandle handle of the model being queried; must not be null
     */
    @VisibleForTesting
    protected LLMClientCostLimitingWrapper(LLMClient wrappedClient, LLMModelHandle<?> llmModelHandle) {
        assert (!(wrappedClient instanceof LLMClientCostLimitingWrapper));
        assert (llmModelHandle != null);
        SpringUtils.getInstance().autowire((Object)this);
        this.llmClient = wrappedClient;
        this.context = new LLMCostLimitingService.LLMCostLimitingContext();
        this.context.llmId = llmModelHandle.getEnrichedRef().id;
        this.context.connectionName = llmModelHandle.getEnrichedRef().connection;
        // The provider id is taken from the wrapped client, not from this wrapper.
        this.context.provider = wrappedClient.getProviderId();
    }

    /**
     * Builds a fully configured wrapper.
     *
     * <p>When the wrapped client has a connection, the {@code canBlockOn*} flags are derived from
     * the model costs, the connection properties and the
     * {@code dku.llm.costLimiting.limitFreeQueries} application property. Without a connection,
     * a warning is logged and all three flags are set to {@code false}.
     *
     * @param wrappedClient  client to delegate to
     * @param authCtx        authentication context used to resolve the user and groups; may be null
     * @param projectKey     project key stored in the cost limiting context
     * @param llmModelHandle handle of the model being queried
     */
    public LLMClientCostLimitingWrapper(LLMClient wrappedClient, AuthCtx authCtx, String projectKey, LLMModelHandle<?> llmModelHandle) {
        this(wrappedClient, llmModelHandle);
        this.context.projectKey = projectKey;
        this.context.user = authCtx != null ? authCtx.getAssociatedDSSUser() : null;
        try {
            if (authCtx != null) {
                this.context.groups = authCtx.getGroups();
            }
        }
        catch (DKUSecurityException e) {
            // Groups stay unset; the wrapper remains usable.
            logger.warn((Object)"Unexpected use of LLM Client without valid authentication context: could not access groups.");
        }
        AbstractLLMConnection llmConnection = this.llmClient.getConnection();
        if (llmConnection != null) {
            // NOTE(review): declared as Object by the decompiler, yet model-specific methods are
            // called on it below; the real static type is presumably the bound of
            // LLMModelHandle's type parameter - confirm against the original sources.
            Object model = llmModelHandle.getModel();
            // Reuse the connection that was just null-checked instead of fetching it again.
            Params connectionProperties = AbstractSQLConnection.CustomDatabaseProperty.toParams(llmConnection.getDkuProperties());
            boolean limitFreeQueries = DKUApp.getProperty("dku.llm.costLimiting.limitFreeQueries", false);
            // Connection properties override the defaults provided by the model.
            boolean nullCostIsFree = connectionProperties.getBoolParam("dku.llm.costLimiting.nullCostIsFree", model.nullCostIsFree());
            boolean hasImageGenerationCost = connectionProperties.getBoolParam("dku.llm.costLimiting.hasImageGenerationCost", model.hasImageGenerationCost());
            // A query type can be blocked when free queries are limited too, or when at least
            // one of its costs is not free.
            this.context.canBlockOnCompletion = limitFreeQueries || !LLMClientCostLimitingWrapper.costIsFree(model.getPromptCost(), nullCostIsFree) || !LLMClientCostLimitingWrapper.costIsFree(model.getCompletionCost(), nullCostIsFree);
            this.context.canBlockOnEmbeddings = limitFreeQueries || !LLMClientCostLimitingWrapper.costIsFree(model.getTextEmbeddingCost(), nullCostIsFree) || !LLMClientCostLimitingWrapper.costIsFree(model.getImageEmbeddingCost(), nullCostIsFree);
            this.context.canBlockOnImageGeneration = limitFreeQueries || hasImageGenerationCost;
        } else {
            logger.warn((Object)"Unexpected use of a LLM Client without LLM Connection in the LLMClientCostLimitingWrapper. This client will not be blocked by cost limiting.");
            this.context.canBlockOnCompletion = false;
            this.context.canBlockOnEmbeddings = false;
            this.context.canBlockOnImageGeneration = false;
        }
    }

    /** Delegates to the wrapped client. */
    @Override
    public boolean supportNativeBatch() {
        return this.llmClient.supportNativeBatch();
    }

    /**
     * Always returns {@code false} for the wrapper itself.
     * NOTE(review): presumably so that an already wrapped client is not wrapped again - confirm.
     */
    @Override
    public boolean requiresCostLimiting() {
        return false;
    }

    /**
     * Always returns {@code null}.
     * NOTE(review): this does not delegate to the wrapped client, although the wrapped client's
     * provider id is stored in the context at construction time - confirm this is intentional.
     */
    @Override
    public String getProviderId() {
        return null;
    }

    /** Delegates to the wrapped client. */
    @Override
    public AbstractLLMConnection getConnection() {
        return this.llmClient.getConnection();
    }

    /** Delegates to the wrapped client. */
    @Override
    public int getMaxParallelism() {
        return this.llmClient.getMaxParallelism();
    }

    /** Delegates to the wrapped client. */
    @Override
    public int getBatchSize(AbstractLLMConnection.QueryType queryType, LLMStructuredRef llmRef) {
        return this.llmClient.getBatchSize(queryType, llmRef);
    }

    /**
     * Runs a batch of completion queries on the wrapped client and reports the summed estimated
     * cost of the responses (null costs count as 0) together with the number of responses.
     *
     * @throws Exception if the pre-check, the wrapped client or the cost reporting fails
     */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws Exception {
        if (this.context.canBlockOnCompletion) {
            this.costLimitingService.checkQuery(this.context);
        }
        return this.handleQueryProcessingError(queries.size(), () -> {
            List<LLMClient.SimpleCompletionResponse> responses = this.llmClient.completeBatch(queries, settings);
            double estimatedCost = responses.stream().mapToDouble(r -> LLMClientCostLimitingWrapper.safeCost(r.estimatedCost)).sum();
            this.costLimitingService.reportCost(this.context, estimatedCost, responses.size());
            return responses;
        });
    }

    /**
     * Runs a batch of embedding queries on the wrapped client and reports the summed estimated
     * cost of the responses (null costs count as 0) together with the number of responses.
     *
     * @throws Exception if the pre-check, the wrapped client or the cost reporting fails
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws Exception {
        if (this.context.canBlockOnEmbeddings) {
            this.costLimitingService.checkQuery(this.context);
        }
        return this.handleQueryProcessingError(queries.size(), () -> {
            List<LLMClient.SimpleEmbeddingResponse> responses = this.llmClient.embedBatch(queries, settings);
            double cost = responses.stream().mapToDouble(r -> LLMClientCostLimitingWrapper.safeCost(r.estimatedCost)).sum();
            this.costLimitingService.reportCost(this.context, cost, responses.size());
            return responses;
        });
    }

    /**
     * Runs a batch of reranking queries on the wrapped client.
     *
     * <p>Unlike completion and embedding, no pre-check is performed and nothing is reported on
     * success; a failure of the wrapped call is still reported through
     * {@link #handleQueryProcessingError(int, Callable)}.
     * NOTE(review): confirm that reranking is deliberately excluded from cost accounting.
     */
    @Override
    public List<LLMClient.SimpleRerankingResponse> rerankBatch(List<LLMClient.RerankingQuery> queries) throws Exception {
        return this.handleQueryProcessingError(queries.size(), () -> {
            List<LLMClient.SimpleRerankingResponse> responses = this.llmClient.rerankBatch(queries);
            return responses;
        });
    }

    /** Delegates to the wrapped client. */
    @Override
    public boolean supportsStream() {
        return this.llmClient.supportsStream();
    }

    /**
     * Streams a completion through the wrapped client. Stream events are forwarded to
     * {@code consumer}; when the stream completes, the footer's estimated cost (0 if null) is
     * reported for one query before the footer is forwarded.
     *
     * @throws Exception if the pre-check, the wrapped client, the consumer or the cost reporting fails
     */
    @Override
    public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, final LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        if (this.context.canBlockOnCompletion) {
            this.costLimitingService.checkQuery(this.context);
        }
        this.handleQueryProcessingError(1, () -> {
            this.llmClient.streamComplete(query, settings, new LLMClient.StreamedCompletionResponseConsumer(){

                @Override
                public void onStreamStarted() throws Exception {
                    consumer.onStreamStarted();
                }

                @Override
                public void onStreamChunk(LLMClient.StreamedCompletionResponseChunk chunk) throws Exception {
                    consumer.onStreamChunk(chunk);
                }

                @Override
                public void onStreamComplete(LLMClient.StreamedCompletionResponseFooter footer) throws Exception {
                    // Cost is reported before the footer reaches the downstream consumer.
                    // NOTE(review): if the downstream consumer throws here and the wrapped client
                    // lets that exception propagate, the error path reports this query a second
                    // time - confirm whether that is acceptable.
                    LLMClientCostLimitingWrapper.this.costLimitingService.reportCost(LLMClientCostLimitingWrapper.this.context, LLMClientCostLimitingWrapper.safeCost(footer.estimatedCost), 1);
                    consumer.onStreamComplete(footer);
                }
            });
            return null;
        });
    }

    /**
     * Generates images through the wrapped client and reports the response's estimated cost
     * (0 if null) for one query.
     *
     * @throws Exception if the pre-check, the wrapped client or the cost reporting fails
     */
    @Override
    public LLMClient.ImageGenerationResponse generateImages(LLMClient.ImageGenerationQuery query) throws Exception {
        if (this.context.canBlockOnImageGeneration) {
            this.costLimitingService.checkQuery(this.context);
        }
        return this.handleQueryProcessingError(1, () -> {
            LLMClient.ImageGenerationResponse response = this.llmClient.generateImages(query);
            this.costLimitingService.reportCost(this.context, LLMClientCostLimitingWrapper.safeCost(response.estimatedCost), 1);
            return response;
        });
    }

    /** Delegates to the wrapped client. */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        return this.llmClient.getTotalCRU(usageType, llmRef);
    }

    /** Delegates to the wrapped client. */
    @Override
    public EnrichedLLMStructuredRef getEnrichedRef() throws Exception {
        return this.llmClient.getEnrichedRef();
    }

    /** Delegates to the wrapped client. */
    @Override
    public RemoteFineTuningClient newFineTuningClient() throws UnsupportedOperationException {
        return this.llmClient.newFineTuningClient();
    }

    /** Delegates to the wrapped client. */
    @Override
    public ISavedModelDeployer newSavedModelDeployer(AuthCtx authCtx) throws UnsupportedOperationException, AzureMLUtils.AzureAuthenticationException, IOException, DKUSecurityException {
        return this.llmClient.newSavedModelDeployer(authCtx);
    }

    /** Delegates to the wrapped client. */
    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        return this.llmClient.getFormattedPrompt(chatMessages);
    }

    /** Delegates to the wrapped client. */
    @Override
    public SmartLogTail getKernelLog() throws Exception {
        return this.llmClient.getKernelLog();
    }

    /** Closes the wrapped client. */
    @Override
    public void close() throws Exception {
        this.llmClient.close();
    }

    @Override
    public String toString() {
        return "LLMClientCostLimitingWrapper{llmClient=" + String.valueOf(this.llmClient) + "}";
    }

    /**
     * Converts a possibly missing cost into a primitive.
     *
     * @param cost estimated cost, may be null
     * @return {@code cost}, or {@code 0.0} when it is null
     */
    private static double safeCost(@Nullable Double cost) {
        return cost != null ? cost : 0.0;
    }

    /**
     * Tells whether a cost counts as free.
     *
     * @param cost           cost to test, may be null
     * @param nullCostIsFree result to return when {@code cost} is null
     * @return {@code nullCostIsFree} for a null cost, otherwise whether the cost equals {@code 0.0}
     */
    @VisibleForTesting
    static boolean costIsFree(@Nullable Double cost, boolean nullCostIsFree) {
        if (cost == null) {
            return nullCostIsFree;
        }
        return cost == 0.0;
    }

    /**
     * Runs {@code processingCall} and makes sure failed queries are still reported.
     *
     * <p>A {@code LLMCostLimitingReportingException} raised by the call is logged and rethrown
     * without further reporting. For any other exception, {@code nbQueries} queries are reported
     * with the exception's estimated cost when it is an {@code LLMException} carrying a positive
     * cost, and with a cost of 0 otherwise; the exception is then rethrown.
     *
     * <p>If that error-path reporting fails as well, the reporting failure propagates and the
     * original processing exception is attached to it as a suppressed exception, so the root
     * cause is not lost.
     *
     * @param nbQueries      number of queries to report if the call fails
     * @param processingCall the work to run
     * @param <R>            result type of the call
     * @return the value returned by {@code processingCall}
     * @throws Exception whatever the call or the error-path reporting throws
     */
    private <R> R handleQueryProcessingError(int nbQueries, Callable<R> processingCall) throws Exception {
        try {
            return processingCall.call();
        }
        catch (LLMCostLimitingService.LLMCostLimitingReportingException e) {
            logger.error((Object)"Cost limiting reporting error", (Throwable)e);
            throw e;
        }
        catch (Exception exception) {
            double errorCost = 0.0;
            if (exception instanceof LLMClient.LLMException) {
                LLMClient.LLMException llmException = (LLMClient.LLMException)exception;
                if (llmException.estimatedCost != null && llmException.estimatedCost > 0.0) {
                    errorCost = llmException.estimatedCost;
                }
            }
            try {
                this.costLimitingService.reportCost(this.context, errorCost, nbQueries);
            }
            catch (Exception reportingFailure) {
                // Same exception propagates as before, but it now carries the processing
                // failure that triggered the report. addSuppressed rejects self-suppression.
                if (reportingFailure != exception) {
                    reportingFailure.addSuppressed(exception);
                }
                throw reportingFailure;
            }
            throw exception;
        }
    }
}

