/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.anthropic;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.AnthropicConnection;
import com.dataiku.dip.connections.ConnectionUtils;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.AbstractLLMClient;
import com.dataiku.dip.llm.online.LLMChatMessageUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMQueryRunner;
import com.dataiku.dip.llm.online.anthropic.RawAnthropicClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.shaker.processors.expr.TokenizedText;
import com.dataiku.dip.utils.DKULogger;
import com.google.common.base.Stopwatch;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

/**
 * {@link LLMClient} implementation that talks to Anthropic models through a {@code RawAnthropicClient}.
 *
 * <p>When the model handle's {@code useChatApi} flag is set, queries go through
 * {@code RawAnthropicClient.chatComplete} with role-normalised chat messages. Otherwise they go through
 * {@code RawAnthropicClient.complete} with a single "Human:/Assistant:" transcript prompt, and token
 * counts are estimated locally.
 *
 * <p>Per-query usage is accumulated into {@code usageData} and reported by {@link #getTotalCRU}.
 * NOTE(review): {@code usageData} is updated without any synchronization in this class. Confirm that
 * {@code completeBatch} is never invoked concurrently on the same instance.
 */
public class AnthropicClient
extends AbstractLLMClient
implements LLMClient {
    /**
     * Multiplier applied to the {@code TokenizedText} size of a string to approximate its token count
     * on the non-chat (legacy completion) path. Counts on that path are always flagged as estimated.
     * NOTE(review): the origin of the 2.5 factor is not documented here; confirm before changing.
     */
    private static final float LEGACY_TOKEN_ESTIMATE_FACTOR = 2.5f;

    private static final DKULogger logger = DKULogger.getLogger("dku.llm.anthropic");

    private final AnthropicConnection connection;
    private final LLMQueryRunner queryRunner;
    private final RawAnthropicClient raw;
    private final AnthropicConnection.AnthropicModel model;
    private final ComputeResourceUsage.InternalLLMUsageData usageData = new ComputeResourceUsage.LLMUsageData();

    /**
     * Creates a client bound to one Anthropic model.
     *
     * @param connection  connection supplying network/proxy settings, parallelism and DSS properties
     * @param modelHandle handle identifying the model; its enriched ref is passed to the superclass
     * @param apiKey      API key handed to the underlying {@code RawAnthropicClient}
     */
    public AnthropicClient(AnthropicConnection connection, LLMModelHandle<AnthropicConnection.AnthropicModel> modelHandle, String apiKey) {
        super(modelHandle.getEnrichedRef());
        this.queryRunner = new LLMQueryRunner(this.getProviderId(), modelHandle, connection.params.networkSettings, OnlineLLMUtils::isRetryableException);
        // Opt-in connection property (defaults to false), forwarded verbatim to the raw HTTP client.
        boolean forceContentLength = ConnectionUtils.getParamsFromProperties(connection.getDkuProperties()).getBoolParam("dku.connection.llm.forceContentLength", false);
        this.raw = new RawAnthropicClient(apiKey, this.queryRunner.getHttpClientNetworkSettings(), connection.getProxySettings(), forceContentLength);
        this.connection = connection;
        this.model = modelHandle.getModel();
    }

    /** Closes the underlying raw Anthropic client. */
    @Override
    public void close() {
        this.raw.close();
    }

    /** Always {@code false}: every query in a batch is sent as its own request. */
    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    /** Always {@code true} for this provider. */
    @Override
    public boolean requiresCostLimiting() {
        return true;
    }

    /** Returns the constant provider identifier {@code "Anthropic"}. */
    @Override
    public String getProviderId() {
        return "Anthropic";
    }

    /** Returns the Anthropic connection this client was built from. */
    @Override
    public AbstractLLMConnection getConnection() {
        return this.connection;
    }

    /** Returns the connection's configured {@code maxParallelism}. */
    @Override
    public int getMaxParallelism() {
        return this.connection.params.maxParallelism;
    }

    /**
     * Formats {@code chatMessages} for this client's model, choosing chat or legacy formatting
     * according to the model's {@code useChatApi} flag.
     */
    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages) {
        return getFormattedPrompt(chatMessages, this.model.useChatApi);
    }

    /**
     * Formats chat messages for either the chat API or the legacy completion API.
     *
     * @param chatMessages the conversation to format
     * @param isChatModel  {@code true} for the chat API, {@code false} for the legacy completion API
     * @return for chat models, the messages with role {@code "tool"} rewritten to {@code "user"} and
     *         adjacent same-role messages collapsed; otherwise a single-element mutable list holding a
     *         {@code "prompt"}-role message whose text is {@link #getFormattedPromptContent}
     */
    public static List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> chatMessages, boolean isChatModel) {
        if (isChatModel) {
            // Tool results are sent as user turns; collapsing then merges any resulting
            // consecutive user messages.
            List<LLMClient.ChatMessage> toolAsUser = LLMChatMessageUtils.convertMessageRole(chatMessages, "tool", "user");
            return LLMChatMessageUtils.collapseAdjacentSameRoleMessages(toolAsUser);
        }
        List<LLMClient.ChatMessage> formattedPromptMessages = new ArrayList<>(1);
        formattedPromptMessages.add(new LLMClient.ChatMessage("prompt", getFormattedPromptContent(chatMessages)));
        return formattedPromptMessages;
    }

    /**
     * Renders chat messages as a legacy "Human:/Assistant:" transcript.
     *
     * <p>{@code "system"} and {@code "user"} messages both become {@code "\n\nHuman: "} turns, and
     * {@code "assistant"} messages become {@code "\n\nAssistant: "} turns. Messages with any other role
     * are skipped. The result always ends with an open {@code "\n\nAssistant: "} turn for the model to
     * complete.
     *
     * @param chatMessages the conversation to render
     * @return the transcript prompt string
     */
    public static String getFormattedPromptContent(List<LLMClient.ChatMessage> chatMessages) {
        StringBuilder sb = new StringBuilder();
        for (LLMClient.ChatMessage message : chatMessages) {
            if ("system".equals(message.role) || "user".equals(message.role)) {
                sb.append("\n\nHuman: ");
                sb.append(message.getText());
            } else if ("assistant".equals(message.role)) {
                sb.append("\n\nAssistant: ");
                sb.append(message.getText());
            }
        }
        sb.append("\n\nAssistant: ");
        return sb.toString();
    }

    /**
     * Runs each query sequentially through the query runner and records its usage.
     *
     * <p>On the non-chat path, prompt and completion token counts are estimated locally and flagged as
     * estimated. On both paths, {@code estimatedCost} is computed from the model's pricing. The elapsed
     * time recorded in the usage data is measured around {@code queryRunner.run}, so it presumably
     * includes any retries that runner performs.
     *
     * @param queries  the queries to run, in order
     * @param settings completion settings, converted once into core settings for all queries
     * @return one response per query, in the same order
     * @throws Exception propagated from the query runner / raw client
     */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws Exception {
        CoreCompletionSettings ccs = this.getCoreCompletionSettings(settings);
        List<LLMClient.SimpleCompletionResponse> ret = new ArrayList<>(queries.size());
        for (LLMClient.SingleCompletionQuery query : queries) {
            Stopwatch stopwatch = Stopwatch.createStarted();
            LLMClient.SimpleCompletionResponse scr = this.queryRunner.run(() -> {
                LLMClient.SimpleCompletionResponse response;
                if (this.model.useChatApi) {
                    List<LLMClient.ChatMessage> chatMessages = this.getFormattedPrompt(query.messages);
                    response = this.raw.chatComplete(this.model.getId(), chatMessages, ccs);
                } else {
                    String prompt = getFormattedPromptContent(query.messages);
                    response = this.raw.complete(this.model.getId(), prompt, ccs);
                    response.promptTokens = estimateTokens(prompt);
                    response.completionTokens = estimateTokens(response.text);
                    response.tokenCountsAreEstimated = true;
                }
                response.estimatedCost = this.model.getEstimatedCompletionCost(response.promptTokens, response.completionTokens);
                return response;
            });
            scr.includeInUsageData(this.usageData, stopwatch.elapsed(TimeUnit.MILLISECONDS));
            ret.add(scr);
        }
        return ret;
    }

    /**
     * Approximates the token count of {@code text} for the legacy completion path.
     *
     * @return {@code 0} for {@code null} text, otherwise the truncated product of
     *         {@link #LEGACY_TOKEN_ESTIMATE_FACTOR} and the {@code TokenizedText} size
     */
    private static int estimateTokens(String text) {
        if (text == null) {
            return 0;
        }
        return (int) (LEGACY_TOKEN_ESTIMATE_FACTOR * new TokenizedText(text).size());
    }

    /**
     * Embeddings are not offered by this client.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws IOException {
        throw new IllegalArgumentException("Embeddings not supported on this LLM");
    }

    /**
     * Builds a compute-resource-usage record for {@code llmRef}, filled from the usage accumulated
     * by this client so far.
     *
     * @param usageType the kind of LLM usage to record
     * @param llmRef    reference whose connection, type and id label the record
     * @return a new usage record
     */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
        cru.llmUsage.setFromInternal(this.usageData);
        return cru;
    }
}

