/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.SnowflakeConnection;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.ConnectionsTestService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.IsolationLevel;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.variables.VariablesContext;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import javax.annotation.Nonnull;
import org.apache.commons.lang.StringUtils;

/**
 * LLM connection exposing Snowflake Cortex models (hardcoded and custom).
 *
 * <p>This connection does not carry its own credentials: credentials and proxy settings are
 * borrowed from the Snowflake connection named by {@link SnowflakeCortexLLMConnectionParams#snowflakeConnection}.
 */
public class SnowflakeCortexLLMConnection
extends AbstractLLMConnection<SnowflakeCortexLLMModel, HardcodedSnowflakeCortexLLMModel, CustomSnowflakeCortexLLMModel> {
    /** Temperature range advertised for every Cortex model: 0.0 to 2.0 in steps of 0.01. */
    private static final EnrichedLLMStructuredRef.FieldRange TEMPERATURE_RANGE = new EnrichedLLMStructuredRef.FieldRange(0.0, 2.0, 0.01);
    /** No top-k range is advertised for Cortex models (left null on purpose). */
    private static final EnrichedLLMStructuredRef.FieldRange TOP_K_RANGE = null;
    public SnowflakeCortexLLMConnectionParams params = new SnowflakeCortexLLMConnectionParams();
    public static final String SNOWFLAKE_CORTEX_CONNECTION_TYPE = "SnowflakeCortex";
    private static final DKULogger logger = DKULogger.getLogger("dip.connections.snowflake-cortex-llm");

    @Override
    public String getType() {
        return SNOWFLAKE_CORTEX_CONNECTION_TYPE;
    }

    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        return this.params;
    }

    /** Returns every hardcoded model, regardless of whether it is enabled in {@link #params}. */
    @Override
    protected List<HardcodedSnowflakeCortexLLMModel> listRawHardcodedModels() {
        return Arrays.asList(HardcodedSnowflakeCortexLLMModel.values());
    }

    /** Returns the user-declared custom models stored in {@link #params}. */
    @Override
    protected List<CustomSnowflakeCortexLLMModel> listRawCustomModels() {
        return this.params.customModels;
    }

    @Override
    protected SnowflakeCortexLLMModel loadRawHardcodedModel(HardcodedSnowflakeCortexLLMModel hardcodedModel) {
        SnowflakeCortexLLMModel model = hardcodedModel.toModel();
        this.loadDefaultHardcodedModelSettings(hardcodedModel, model);
        return model;
    }

    @Override
    protected SnowflakeCortexLLMModel loadRawCustomModel(CustomSnowflakeCortexLLMModel customModel) {
        SnowflakeCortexLLMModel model = customModel.toModel();
        this.loadDefaultCustomModelSettings(customModel, model);
        return model;
    }

    /** A hardcoded model is enabled when its matching {@code allow*} flag in {@link #params} is set. */
    @Override
    protected boolean isHardcodedModelEnabled(HardcodedSnowflakeCortexLLMModel hardcodedModel) {
        return hardcodedModel.allowedModel.apply(this.params);
    }

    /**
     * Looks up the hardcoded model whose id equals {@code llmRef.model}.
     *
     * @param llmRef structured reference to resolve
     * @return the matching hardcoded model, or empty if the id is not a hardcoded one
     */
    public static Optional<HardcodedSnowflakeCortexLLMModel> getHardcodedModel(LLMStructuredRef llmRef) {
        return Arrays.stream(HardcodedSnowflakeCortexLLMModel.values()).filter(m -> m.id.equals(llmRef.model)).findFirst();
    }

    /** No-op: none of the params declared on this connection are encrypted here. */
    @Override
    public void encryptFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings securitySettings) {
    }

    /** No-op counterpart of {@link #encryptFields}. */
    @Override
    public void decryptFields(PasswordEncryptionService cryptoService) {
    }

    /**
     * Resolves credentials by delegating to the referenced Snowflake connection.
     *
     * <p>The Snowflake connection is loaded inside a read transaction. Its credentials are then
     * resolved using the caller's auth context.
     *
     * @param ctx resolution context; only {@code ctx.authCtx} is forwarded
     * @param clazz requested credential class; must accept
     *     {@link SnowflakeConnection.SerializableSnowflakeCredentials}
     * @return the Snowflake connection's resolved credentials, cast to {@code clazz}
     */
    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        assert clazz.isAssignableFrom(SnowflakeConnection.SerializableSnowflakeCredentials.class)
                : "Unsupported credential class requested: " + clazz;
        SnowflakeConnection sfConn;
        // The transaction only needs to cover the DAO lookup; it is closed before credential resolution.
        try (Transaction t = ((TransactionService) SpringUtils.getBean(TransactionService.class)).retrieveOrBeginRead(IsolationLevel.YOLO)) {
            sfConn = ConnectionsDAO.get().getMandatoryConnectionAs(ctx.authCtx, this.params.snowflakeConnection, SnowflakeConnection.class);
        }
        ConnectionWithBasicCredential.CredentialResolutionContext sfCtx = new ConnectionWithBasicCredential.CredentialResolutionContext(ctx.authCtx, null);
        return clazz.cast(sfConn.getFullyResolvedCredentials(sfCtx, SnowflakeConnection.SerializableSnowflakeCredentials.class));
    }

    /** No-op: this connection performs no variable expansion on its params. */
    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
    }

    @Override
    protected DKULogger getLogger() {
        return logger;
    }

    /**
     * Returns the proxy settings to use for this connection.
     *
     * <p>When a Snowflake connection is referenced, its proxy settings are returned. If that
     * connection cannot be resolved, the error is logged and the global proxy settings are used
     * instead. The global settings are also used when no Snowflake connection is referenced.
     */
    @Override
    @Nonnull
    public ProxySettings getProxySettings() {
        String sfConnName = this.params.snowflakeConnection;
        if (StringUtils.isNotEmpty(sfConnName)) {
            // try-with-resources closes the transaction exactly once on every path.
            try (Transaction t = ((TransactionService) SpringUtils.getBean(TransactionService.class)).retrieveOrBeginRead(IsolationLevel.YOLO)) {
                // Internal admin auth: proxy lookup is not tied to the calling user.
                SnowflakeConnection sfConn = ConnectionsDAO.get().getMandatoryConnectionAs(DSSAuthCtx.internalAdminAuth(), sfConnName, SnowflakeConnection.class);
                return sfConn.getProxySettings();
            }
            catch (Exception e) {
                // Best effort: fall through to the global proxy settings below.
                logger.errorV(e, "Error resolving Snowflake connection '%s'. Will use global proxy settings.", new Object[]{sfConnName});
            }
        }
        return ApplicationConfigurator.getProxySettings();
    }

    @Override
    public ConnectionsTestService.ConnectionTestResult testConnection(AuthCtx authCtx, ConnectionsTestService connectionsTestService) throws Exception {
        return connectionsTestService.testSnowflakeCortexLLM(this, authCtx);
    }

    /** User-editable settings of a Snowflake Cortex LLM connection. */
    public static class SnowflakeCortexLLMConnectionParams
    extends AbstractLLMConnection.AbstractLLMConnectionParams {
        public AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
        /** Name of the Snowflake connection providing credentials and proxy settings. */
        public String snowflakeConnection;
        // Per-model enable flags, read by HardcodedSnowflakeCortexLLMModel.allowedModel.
        public boolean allowLlama33_70BChat = false;
        public boolean allowLlama31_70BChat = false;
        public boolean allowLlama32_3BChat = false;
        public boolean allowLlama2_70BChat = false;
        public boolean allowMixtral8x7B = false;
        public boolean allowMistral_7B = false;
        public boolean allowMistral_Large2 = false;
        public boolean allowMistral_Large = false;
        public boolean allowGemma_7B = true;
        public boolean allowSnowflakeArctic = true;
        public boolean allowDeepSeekR1 = false;
        public boolean allowClaude35Sonnet = false;
        public boolean allowLlama4_Maverick = false;
        public boolean allowSnowflakeArcticEmbedM = true;
        public boolean allowE5BaseV2 = true;
        public boolean allowNVEmbedQA4 = true;
        public int maxParallelism = 8;
        public int queryTimeoutMs = 60000;
        public List<CustomSnowflakeCortexLLMModel> customModels = new ArrayList<CustomSnowflakeCortexLLMModel>();
    }

    /**
     * Built-in Cortex models.
     *
     * <p>The trailing two booleans of each constant are {@code sqlCompatible, restCompatible}.
     * Embedding models also declare {@code embeddingSize, maxTokensLimit}.
     */
    public enum HardcodedSnowflakeCortexLLMModel implements AbstractLLMConnection.IHardcodedConnectionModel<SnowflakeCortexLLMModel>
    {
        LLAMA33_70B_CHAT("snowflake-llama-3.3-70b", "Llama 3.3 70B", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowLlama33_70BChat, true, true),
        LLAMA31_70B_CHAT("llama3.1-70b", "Llama 3.1 70B", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowLlama31_70BChat, true, true),
        LLAMA32_3B_CHAT("llama3.2-3b", "Llama 3.2 3B", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowLlama32_3BChat, true, true),
        LLAMA2_70B_CHAT("llama2-70b-chat", "Llama 2 70B Chat (legacy)", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowLlama2_70BChat, true, false),
        MIXTRAL_8x7B("mixtral-8x7b", "Mixtral-8x7B", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowMixtral8x7B, true, false),
        MISTRAL_7B("mistral-7b", "Mistral", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowMistral_7B, true, true),
        MISTRAL_LARGE2("mistral-large2", "Mistral Large 2", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowMistral_Large2, true, true),
        MISTRAL_LARGE("mistral-large", "Mistral Large (legacy)", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowMistral_Large, true, true),
        GEMMA_7B("gemma-7b", "Gemma 7B", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowGemma_7B, true, false),
        SNOWFLAKE_ARCTIC("snowflake-arctic", "Snowflake Arctic", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowSnowflakeArctic, true, false),
        DEEPSEEK_R1("deepseek-r1", "DeepSeek R1", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowDeepSeekR1, true, true),
        CLAUDE_35_SONNET("claude-3-5-sonnet", "Claude 3.5 Sonnet", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowClaude35Sonnet, true, true),
        LLAMA4_MAVERICK("llama4-maverick", "Llama4 Maverick", SnowflakeCortexLLMModelType.CHAT_COMPLETION, p -> p.allowLlama4_Maverick, true, true),
        SNOWFLAKE_ARCTIC_EMBED_M("snowflake-arctic-embed-m", "Snowflake Arctic Embed M", SnowflakeCortexLLMModelType.TEXT_EMBEDDING, 768, 512, p -> p.allowSnowflakeArcticEmbedM, true, true),
        E5_BASE_V2("e5-base-v2", "E5 Base v2", SnowflakeCortexLLMModelType.TEXT_EMBEDDING, 768, 512, p -> p.allowE5BaseV2, true, true),
        NV_EMBED_QA_4("nv-embed-qa-4", "NV Embed QA 4", SnowflakeCortexLLMModelType.TEXT_EMBEDDING, 1024, 512, p -> p.allowNVEmbedQA4, true, true);

        public final String id;
        public final String displayName;
        public final Integer embeddingSize;
        public final Integer maxTokensLimit;
        public final SnowflakeCortexLLMModelType modelType;
        /** Reads the connection param flag that enables this model. */
        public final Function<SnowflakeCortexLLMConnectionParams, Boolean> allowedModel;
        public final boolean sqlCompatible;
        public final boolean restCompatible;

        /** Constructor for chat models: no embedding size or token limit. */
        private HardcodedSnowflakeCortexLLMModel(String id, String displayName, SnowflakeCortexLLMModelType modelType, Function<SnowflakeCortexLLMConnectionParams, Boolean> allowedModel, boolean sqlCompatible, boolean restCompatible) {
            this(id, displayName, modelType, null, null, allowedModel, sqlCompatible, restCompatible);
        }

        private HardcodedSnowflakeCortexLLMModel(String id, String displayName, SnowflakeCortexLLMModelType modelType, Integer embeddingSize, Integer maxTokensLimit, Function<SnowflakeCortexLLMConnectionParams, Boolean> allowedModel, boolean sqlCompatible, boolean restCompatible) {
            this.id = id;
            this.displayName = displayName;
            this.modelType = modelType;
            this.embeddingSize = embeddingSize;
            this.maxTokensLimit = maxTokensLimit;
            this.allowedModel = allowedModel;
            this.sqlCompatible = sqlCompatible;
            this.restCompatible = restCompatible;
        }

        /** Builds a fresh model instance carrying this constant's descriptive fields. */
        @Override
        public SnowflakeCortexLLMModel toModel() {
            SnowflakeCortexLLMModel model = new SnowflakeCortexLLMModel();
            model.id = this.id;
            model.modelType = this.modelType;
            model.displayName = this.displayName;
            model.embeddingSize = this.embeddingSize;
            model.maxTokensLimit = this.maxTokensLimit;
            model.sqlCompatible = this.sqlCompatible;
            model.restCompatible = this.restCompatible;
            return model;
        }
    }

    /** Runtime view of a Cortex model, built from either a hardcoded or a custom model. */
    public static class SnowflakeCortexLLMModel
    extends AbstractLLMConnection.BaseModel {
        public SnowflakeCortexLLMModelType modelType;
        public boolean sqlCompatible;
        public boolean restCompatible;

        @Override
        public LLMStructuredRef asStructuredRef(String connection) {
            return LLMStructuredRef.forSnowflakeCortexLLMConnection(connection, this.getId());
        }

        /**
         * Returns whether this model's type matches {@code purpose}.
         *
         * <p>Returns false when {@code modelType} is null; {@link #getInvalidityReason} reports
         * that case as "Missing model type".
         */
        @Override
        public boolean canBeUsedForPurpose(@Nonnull AbstractLLMConnection.LLMUsagePurpose purpose) {
            return this.modelType != null && this.modelType.matchesPurpose(purpose);
        }

        /** All Cortex models share the same capabilities: cross-language output, system messages, a fixed temperature range, and no top-k range. */
        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            AbstractLLMConnection.ModelCapabilities capabilities = new AbstractLLMConnection.ModelCapabilities();
            capabilities.canGenerateCrossLanguageOutput = true;
            capabilities.handlesSystemMessage = true;
            capabilities.temperatureRange = TEMPERATURE_RANGE;
            capabilities.topKRange = TOP_K_RANGE;
            return capabilities;
        }

        /**
         * Validates the model definition.
         *
         * @return a human-readable reason when the id is blank, the type is missing, or an
         *     embedding model's size is not 768 or 1024; empty otherwise
         */
        @Override
        public Optional<String> getInvalidityReason() {
            if (StringUtils.isBlank(this.getId())) {
                return Optional.of("Empty model id");
            }
            if (this.modelType == null) {
                return Optional.of("Missing model type");
            }
            if (this.modelType == SnowflakeCortexLLMModelType.TEXT_EMBEDDING && (this.embeddingSize == null || this.embeddingSize != 768 && this.embeddingSize != 1024)) {
                return Optional.of("Embedding size for model " + this.getId() + " must be either 768 or 1024");
            }
            return Optional.empty();
        }
    }

    /** User-declared Cortex model. By default it is SQL-compatible and not REST-compatible. */
    public static class CustomSnowflakeCortexLLMModel
    extends AbstractLLMConnection.CustomModel<SnowflakeCortexLLMModel> {
        public String id;
        public String displayName;
        public SnowflakeCortexLLMModelType modelType;
        public boolean sqlCompatible = true;
        public boolean restCompatible = false;

        /** Builds a model from the shared custom-model fields, then applies the Cortex-specific ones. */
        @Override
        public SnowflakeCortexLLMModel toModel() {
            SnowflakeCortexLLMModel model = new SnowflakeCortexLLMModel();
            // Shared fields first, so the explicit assignments below take precedence.
            model.loadFromCustomModel(this);
            model.id = this.id;
            model.displayName = this.displayName;
            model.modelType = this.modelType;
            model.sqlCompatible = this.sqlCompatible;
            model.restCompatible = this.restCompatible;
            return model;
        }
    }

    /** Kind of Cortex model and the usage purposes it can serve. */
    public enum SnowflakeCortexLLMModelType {
        CHAT_COMPLETION(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET),
        TEXT_EMBEDDING(Set.of(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION));

        public final Set<AbstractLLMConnection.LLMUsagePurpose> matchingPurposes;

        private SnowflakeCortexLLMModelType(Set<AbstractLLMConnection.LLMUsagePurpose> purposes) {
            this.matchingPurposes = purposes;
        }

        /** Returns whether {@code purpose} is one of this type's matching purposes. */
        public boolean matchesPurpose(AbstractLLMConnection.LLMUsagePurpose purpose) {
            return this.matchingPurposes.contains(purpose);
        }
    }
}

