/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.retrieval;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.code.CodeEnvSelection;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMRefEnricherService;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.retrieval.LangchainBasedRAGServer;
import com.dataiku.dip.llm.retrieval.RAGLLMSettings;
import com.dataiku.dip.llm.retrieval.RetrievableKnowledge;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.SmartLogTail;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

/**
 * {@link LLMClient} implementation that forwards completion requests to a
 * {@link LangchainBasedRAGServer} created for one retrieval-augmented LLM reference.
 *
 * <p>Embeddings and reranking are rejected, and the synchronous streaming entry points
 * ({@link #supportsStream()} and {@code streamComplete}) always throw; streaming is exposed
 * through {@link #asyncStreamComplete} instead.
 */
public class LangchainBasedRAGClient
implements LLMClient {
    private final RAGLLMSettings settings;
    private final LLMStructuredRef llmRef;
    private final LangchainBasedRAGServer server;
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.langchain");

    /**
     * Builds the client and the server it delegates to. The server is only constructed here;
     * it is initialised by {@link #start()}.
     *
     * <p>The log directory handed to the server is obtained from {@code DKUApp.getFile} with the
     * path segments {@code saved_models/<projectKey>/<saved model id>/versions/<version>/logs},
     * where the saved model id comes from resolving {@code llmRef.savedModelSmartId} against
     * {@code projectKey} and the version falls back to {@code "v1"} when
     * {@code llmRef.savedModelVersionId} is {@code null}.
     *
     * @param authCtx            authentication context passed to the server
     * @param llmRef             reference to the retrieval-augmented LLM
     * @param projectKey         project key used to resolve the saved model and passed to the server
     * @param rk                 retrievable knowledge passed to the server
     * @param settings           RAG settings; also used by {@link #getFormattedPrompt(List)}
     * @param completionSettings completion settings passed to the server
     * @param selection          code environment selection passed to the server
     * @param containerConfName  container configuration name passed to the server
     * @param devKernel          flag passed through to the server
     * @throws Exception if resolving the location or constructing the server fails
     */
    public LangchainBasedRAGClient(AuthCtx authCtx, LLMStructuredRef llmRef, String projectKey, RetrievableKnowledge rk, RAGLLMSettings settings, LLMClient.CompletionSettings completionSettings, CodeEnvSelection selection, String containerConfName, boolean devKernel) throws Exception {
        this.llmRef = llmRef;
        this.settings = settings;
        AnyLoc loc = AnyLoc.resolveSmart(projectKey, llmRef.savedModelSmartId);
        // Fall back to "v1" when the reference does not carry an explicit version.
        String version = llmRef.savedModelVersionId != null ? llmRef.savedModelVersionId : "v1";
        File logBaseDir = DKUApp.getFile(new String[]{"saved_models", projectKey, loc.getId(), "versions", version, "logs"});
        this.server = new LangchainBasedRAGServer(authCtx, projectKey, llmRef, rk, settings, completionSettings, selection, containerConfName, logBaseDir, devKernel);
    }

    /**
     * Initialises the underlying server unless it already reports itself as alive.
     *
     * @throws Exception if server initialisation fails
     */
    public void start() throws Exception {
        if (!this.isAlive()) {
            this.server.init();
        }
    }

    /** Closes the underlying server. */
    @Override
    public void close() {
        this.server.close();
    }

    /**
     * @return whether the underlying server reports itself as alive
     */
    public boolean isAlive() {
        return this.server.isAlive();
    }

    /**
     * @return the kernel log tail exposed by the underlying server
     */
    @Override
    public SmartLogTail getKernelLog() {
        return this.server.getKernelLog();
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean supportNativeBatch() {
        return false;
    }

    /**
     * @return always {@code false}
     */
    @Override
    public boolean requiresCostLimiting() {
        return false;
    }

    /**
     * @return always {@code null}
     */
    @Override
    public String getProviderId() {
        return null;
    }

    /**
     * @return always {@code null}
     */
    @Override
    public AbstractLLMConnection getConnection() {
        return null;
    }

    /**
     * @return always {@code 1}
     */
    @Override
    public int getMaxParallelism() {
        return 1;
    }

    /**
     * Submits every query to the server asynchronously, then waits for all responses via
     * {@code DKUCompletableFuture.collectResponses}.
     *
     * @param queries  queries to complete
     * @param settings completion settings applied to every query
     * @return the collected responses
     * @throws IOException if waiting for the responses is interrupted or fails for any other
     *                     reason; the original exception is attached as the cause. When the
     *                     failure is an {@link InterruptedException}, the current thread's
     *                     interrupt flag is restored before throwing.
     */
    @Override
    public List<LLMClient.SimpleCompletionResponse> completeBatch(List<LLMClient.SingleCompletionQuery> queries, LLMClient.CompletionSettings settings) throws IOException {
        List<CompletableFuture<LLMClient.SimpleCompletionResponse>> futures = queries.stream()
                .map(q -> this.server.processAsync(q, settings))
                .collect(Collectors.toList());
        try {
            return DKUCompletableFuture.collectResponses(futures);
        }
        catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Restore the interrupt flag so callers further up the stack can observe it.
                Thread.currentThread().interrupt();
                throw new IOException("Interrupted while waiting for completions", e);
            }
            // Not an interruption: report it as a failure rather than mislabelling it.
            throw new IOException("Failed while waiting for completions", e);
        }
    }

    /**
     * Submits a single query to the server without waiting for the result.
     *
     * @param query    query to complete
     * @param settings completion settings
     * @return the future returned by the server for this query
     */
    public CompletableFuture<LLMClient.SimpleCompletionResponse> asyncComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
        return this.server.processAsync(query, settings);
    }

    /**
     * Not expected to be called on this client: always throws {@link Error}.
     * NOTE(review): presumably callers are expected to use {@link #asyncStreamComplete}
     * instead - confirm against call sites.
     */
    @Override
    public boolean supportsStream() {
        throw new Error("unreachable");
    }

    /**
     * Not expected to be called on this client: always throws {@link Error}.
     * Streaming is available through {@link #asyncStreamComplete}.
     */
    @Override
    public void streamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        throw new Error("unreachable");
    }

    /**
     * Submits a single streamed query to the server without waiting for the result.
     *
     * @param query    query to complete
     * @param settings completion settings
     * @param consumer consumer handed to the server for the streamed response
     * @return the future returned by the server for this streamed query
     */
    public CompletableFuture<Integer> asyncStreamComplete(LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings, LLMClient.StreamedCompletionResponseConsumer consumer) {
        return this.server.streamProcessAsync(query, settings, consumer);
    }

    /**
     * @return always {@code null}
     */
    @Override
    public ComputeResourceUsage getTotalCRU(ComputeResourceUsage.LLMUsageType usageType, LLMStructuredRef llmRef) {
        return null;
    }

    /**
     * Embeddings are not supported by this client.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public List<LLMClient.SimpleEmbeddingResponse> embedBatch(List<LLMClient.EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws IOException {
        throw new IllegalArgumentException("Embeddings not supported on this LLM");
    }

    /**
     * Reranking is not supported by this client.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public List<LLMClient.SimpleRerankingResponse> rerankBatch(List<LLMClient.RerankingQuery> queries) throws Exception {
        throw new IllegalArgumentException("Reranking not supported on this LLM");
    }

    /**
     * Looks up the {@link LLMRefEnricherService} bean and asks it for the enriched reference of
     * this retrieval-augmented LLM, using the server's auth context and project key.
     *
     * @return the enriched reference
     * @throws Exception if the enrichment fails
     */
    @Override
    public EnrichedLLMStructuredRef getEnrichedRef() throws Exception {
        LLMRefEnricherService enricher = (LLMRefEnricherService)SpringUtils.getBean(LLMRefEnricherService.class);
        return enricher.getEnrichedLLMRefFromRetrievalAugmentedLLM(this.server.authCtx, this.server.projectKey, this.llmRef);
    }

    /**
     * Formats {@code messages} using this client's RAG settings.
     *
     * @see #getFormattedPrompt(RAGLLMSettings, List)
     */
    @Override
    public List<LLMClient.ChatMessage> getFormattedPrompt(List<LLMClient.ChatMessage> messages) {
        return LangchainBasedRAGClient.getFormattedPrompt(this.settings, messages);
    }

    /**
     * Returns a copy of {@code messages} with two messages inserted immediately before the last
     * one: a {@code "system"} message carrying {@code settings.contextMessage}, followed by a
     * {@code "user"} message carrying the literal placeholder <code>{{rag sources}}</code>.
     *
     * <p>When {@code messages} is empty, the result contains just those two messages, in that
     * order. The input list is never modified.
     *
     * @param settings RAG settings providing the context message
     * @param messages conversation to augment
     * @return a new list containing the augmented conversation
     */
    public static List<LLMClient.ChatMessage> getFormattedPrompt(RAGLLMSettings settings, List<LLMClient.ChatMessage> messages) {
        List<LLMClient.ChatMessage> newMessages = new ArrayList<LLMClient.ChatMessage>(messages);
        // Insert just before the last message; clamp to 0 so an empty conversation does not
        // produce a negative index.
        int insertAt = Math.max(0, newMessages.size() - 1);
        newMessages.add(insertAt, new LLMClient.ChatMessage("system", settings.contextMessage));
        newMessages.add(insertAt + 1, new LLMClient.ChatMessage("user", "{{rag sources}}"));
        return newMessages;
    }

    /**
     * @return the kernel id reported by the underlying server
     */
    public String getKernelId() {
        return this.server.getKernelId();
    }
}

