/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.stabilityai;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionUtils;
import com.dataiku.dip.connections.StabilityAIConnection;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.AbstractLLMClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMQueryRunner;
import com.dataiku.dip.llm.online.stabilityai.RawStabilityAIClient;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKULogger;
import java.io.IOException;
import java.util.List;

/**
 * LLM client backed by a StabilityAI connection.
 *
 * <p>Only image generation is supported. Chat prompt formatting, completion,
 * embedding and reranking all throw {@link IllegalArgumentException}.
 * Each successful image generation adds its elapsed time and estimated cost
 * to an accumulated usage record, which {@link #getTotalCRU} reports.
 */
public class StabilityAIClient extends AbstractLLMClient implements LLMClient {
    private final StabilityAIConnection connection;
    private final LLMQueryRunner queryRunner;
    private final RawStabilityAIClient raw;
    private final StabilityAIConnection.StabilityAIModel model;
    // Usage accumulated over this client's lifetime. NOTE(review): this class
    // does not synchronize updates. If generateImages is called concurrently
    // (getMaxParallelism suggests parallel use), confirm that
    // InternalLLMUsageData increments are thread-safe.
    private final ComputeResourceUsage.InternalLLMUsageData usageData = new ComputeResourceUsage.LLMUsageData();
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.stabilityai");

    /**
     * Builds a client for the given connection and model.
     *
     * <p>The query runner is configured with the connection's network settings and
     * uses {@code OnlineLLMUtils::isRetryableException} as its retry predicate.
     * Set the connection property {@code dku.connection.llm.forceContentLength}
     * (default {@code false}) to enable the raw HTTP client's forced
     * Content-Length behavior.
     *
     * @param authCtx     authentication context (not used by this constructor)
     * @param connection  the StabilityAI connection holding the API key and settings
     * @param modelHandle handle providing the enriched model ref and the model itself
     */
    public StabilityAIClient(AuthCtx authCtx, StabilityAIConnection connection, LLMModelHandle<StabilityAIConnection.StabilityAIModel> modelHandle) {
        super(modelHandle.getEnrichedRef());
        this.connection = connection;
        this.queryRunner = new LLMQueryRunner(this.getProviderId(), modelHandle, connection.params.networkSettings, OnlineLLMUtils::isRetryableException);
        boolean forceContentLength = ConnectionUtils.getParamsFromProperties(connection.getDkuProperties()).getBoolParam("dku.connection.llm.forceContentLength", false);
        this.raw = new RawStabilityAIClient(connection.params.apiKey, this.queryRunner.getHttpClientNetworkSettings(), connection.getProxySettings(), forceContentLength);
        this.model = modelHandle.getModel();
    }

    /** Closes the underlying raw HTTP client. */
    @Override
    public void close() {
        this.raw.close();
    }

    /** @return always {@code false}: no native batch API is used. */
    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    /** @return always {@code true}. */
    @Override
    public boolean requiresCostLimiting() {
        return true;
    }

    /** @return the constant provider id {@code "StabilityAI"}. */
    @Override
    public String getProviderId() {
        return "StabilityAI";
    }

    /** @return the StabilityAI connection this client was built with. */
    @Override
    public AbstractLLMConnection getConnection() {
        return this.connection;
    }

    /** @return the {@code maxParallelism} value from the connection parameters. */
    @Override
    public int getMaxParallelism() {
        return this.connection.params.maxParallelism;
    }

    /** Not supported by this provider. @throws IllegalArgumentException always */
    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        throw new IllegalArgumentException("not on stabilityai");
    }

    /** Not supported by this provider. @throws IllegalArgumentException always */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws IOException {
        throw new IllegalArgumentException("not on stabilityai");
    }

    /** Not supported by this provider. @throws IllegalArgumentException always */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws IOException {
        throw new IllegalArgumentException("not on stabilityai");
    }

    /** Not supported by this provider. @throws IllegalArgumentException always */
    @Override
    public List<LLMClient.SingleRerankingResponse> rerankBatch(List<LLMClient.RerankingQuery> queries, LLMClient.RerankingSettings settings) throws Exception {
        throw new IllegalArgumentException("Rerankings not supported on this LLM");
    }

    /**
     * Generates images for {@code query} through the query runner and records usage.
     *
     * <p>The elapsed time covers the whole {@code queryRunner.run(...)} call. It
     * therefore includes any retries the runner may perform (presumably driven by
     * the retry predicate set in the constructor; TODO confirm). The response's
     * {@code estimatedCost} is set from the model's cost estimate for the query,
     * not from a provider-reported value. Usage is recorded only when the call
     * succeeds, because an exception propagates before the increments.
     *
     * @param query the image generation query
     * @return the provider response, with {@code estimatedCost} filled in
     * @throws Exception whatever the query runner / raw client throws
     */
    @Override
    public LLMClient.ImageGenerationResponse generateImages(LLMClient.ImageGenerationQuery query) throws Exception {
        // Use the monotonic clock for durations. currentTimeMillis() follows the
        // wall clock and can jump (e.g. NTP adjustments), corrupting the measurement.
        long startNanos = System.nanoTime();
        LLMClient.ImageGenerationResponse resp = this.queryRunner.run(() -> this.raw.generateImages(this.model.getId(), query));
        long computeTimeMS = (System.nanoTime() - startNanos) / 1_000_000L;
        resp.estimatedCost = this.model.getEstimatedImageGenerationCost(query);
        this.usageData.incrementTotalComputationTimeMS(Long.valueOf(computeTimeMS));
        this.usageData.incrementEstimatedCostUSD(Double.valueOf(resp.estimatedCost));
        return resp;
    }

    /**
     * Builds a new compute-resource-usage record from the usage accumulated so far.
     *
     * @param usageType the LLM usage type to report
     * @param llmRef    the structured ref whose connection, type and id label the record
     * @return a freshly created {@link ComputeResourceUsage} populated from this client's usage data
     */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
        cru.llmUsage.setFromInternal(this.usageData);
        return cru;
    }
}

