/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.projects.importexport;

import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.retrieval.RAGLLMSettings;
import java.util.Map;

/**
 * Helpers that rewrite the connection part of the LLM ids referenced by a {@link RAGLLMSettings}.
 *
 * <p>Each LLM id is decoded with {@link LLMStructuredRef#decodeId(String)}. If its connection has an
 * entry in the replacement map, the connection is swapped and the id is re-encoded with
 * {@link LLMStructuredRef#encodeToId()}.
 *
 * <p>Non-instantiable utility class.
 */
public class RetrievalAugmentedLLMConnectionsUtils {
    private RetrievalAugmentedLLMConnectionsUtils() {
    }

    /**
     * Rewrites, in place, the LLM ids held by {@code ragLLMSettings} whose connection has an entry in
     * {@code replacements}.
     *
     * <p>The exposed {@code llmId} is always considered. The guardrail {@code embeddingModelId} and
     * {@code llmId} are considered only when {@link RAGLLMSettings#hasGuardrailsEnabled()} returns
     * true.
     *
     * @param ragLLMSettings settings to update in place
     * @param replacements map from an original connection name to its replacement; connections with
     *     no entry are left untouched
     */
    public static void remapConnections(RAGLLMSettings ragLLMSettings, Map<String, String> replacements) {
        remapExposedLlm(ragLLMSettings, replacements);
        if (ragLLMSettings.hasGuardrailsEnabled()) {
            remapGuardrailsConnections(ragLLMSettings, replacements);
        }
    }

    /** Remaps the connection of the top-level (exposed) {@code llmId}. */
    private static void remapExposedLlm(RAGLLMSettings settings, Map<String, String> replacements) {
        settings.llmId = remapLlmId(settings.llmId, replacements);
    }

    /** Remaps the connections of the guardrail embedding-model id and completion LLM id. */
    private static void remapGuardrailsConnections(RAGLLMSettings settings, Map<String, String> replacements) {
        // Nothing visible here guarantees that hasGuardrailsEnabled() implies a non-null guardrails
        // object, so guard against it rather than risk an NPE during import.
        if (settings.ragSpecificGuardrails == null) {
            return;
        }
        settings.ragSpecificGuardrails.embeddingModelId =
                remapLlmId(settings.ragSpecificGuardrails.embeddingModelId, replacements);
        settings.ragSpecificGuardrails.llmId =
                remapLlmId(settings.ragSpecificGuardrails.llmId, replacements);
    }

    /**
     * Returns {@code llmId} with its connection replaced according to {@code replacements}.
     *
     * @param llmId encoded LLM id; may be {@code null}
     * @param replacements map from an original connection name to its replacement
     * @return {@code null} if {@code llmId} is {@code null}; the unchanged {@code llmId} if it has no
     *     connection or its connection has no replacement; otherwise the re-encoded id
     */
    private static String remapLlmId(String llmId, Map<String, String> replacements) {
        if (llmId == null) {
            return null;
        }
        LLMStructuredRef ref = LLMStructuredRef.decodeId(llmId);
        if (ref.connection == null) {
            return llmId;
        }
        String newConnection = replacements.get(ref.connection);
        if (newConnection == null) {
            // No mapping for this connection: keep the original id exactly as it was (not re-encoded).
            return llmId;
        }
        ref.setConnection(newConnection);
        return ref.encodeToId();
    }
}

