/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.analysis.ml.llm.LLMSavedModelInfo;
import com.dataiku.dip.analysis.model.llm.LLMModelSnippetData;
import com.dataiku.dip.code.DSSInternalCodeEnvsService;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.security.model.ICredentialsService;
import com.dataiku.dip.server.services.ConnectionsTestService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.NotImplementedException;
import com.dataiku.dip.variables.VariablesContext;
import com.dataiku.j2py.annotations.PyModel;
import java.io.IOException;
import java.io.InputStream;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

public class HuggingFaceLocalConnection
extends AbstractLLMConnection<HFLocalModel, AbstractLLMConnection.IHardcodedConnectionModel, CustomHFLocalModel> {
    /**
     * Temperature range advertised by {@link HFLocalModel#getModelCapabilities()}.
     * The constructor arguments are presumably (min, max, step) — confirm against {@code FieldRange}.
     */
    public static final EnrichedLLMStructuredRef.FieldRange TEMPERATURE_RANGE = new EnrichedLLMStructuredRef.FieldRange(0.0, 100.0, 0.1);
    /**
     * Top-k range advertised by {@link HFLocalModel#getModelCapabilities()}.
     * NOTE(review): top-k is usually an integer; if the third argument is the step, 0.1 looks suspicious — confirm.
     */
    public static final EnrichedLLMStructuredRef.FieldRange TOP_K_RANGE = new EnrichedLLMStructuredRef.FieldRange(0.0, 1.0E8, 0.1);
    /** Connection type identifier returned by {@link #getType()}. */
    public static final String connectionType = "HuggingFaceLocal";
    /** Settings of this connection, including the configured custom models and the API key. */
    public HuggingFaceLocalConnectionParams params = new HuggingFaceLocalConnectionParams();
    /** Logger for this connection type, category {@code dip.connections.hflocal}. */
    private static final DKULogger logger = DKULogger.getLogger((String)"dip.connections.hflocal");

    /**
     * Exposes this connection's settings through the generic LLM-connection API.
     *
     * @return the {@link HuggingFaceLocalConnectionParams} held in {@code params}
     */
    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        final HuggingFaceLocalConnectionParams current = this.params;
        return current;
    }

    /**
     * Resolves the model that a saved LLM model was built from.
     * <p>
     * When the saved model flags its original LLM as unreferenced, an ad-hoc model is synthesized from the
     * saved-model metadata. Otherwise the model is looked up among this connection's configured custom models.
     *
     * @param llmSMInfo saved-model information; a {@code null} {@code originalLLMIsUnreferenced} flag is
     *                  treated as {@code false} instead of failing with a NullPointerException on unboxing
     * @return the resolved model
     * @throws IllegalArgumentException if the referenced model cannot be found on this connection
     */
    public HFLocalModel getLLMModelFromSMInfo(LLMSavedModelInfo llmSMInfo) {
        LLMStructuredRef originalModelIdentifier = LLMStructuredRef.decodeId(llmSMInfo.originalLLMId);
        if (Boolean.TRUE.equals(llmSMInfo.originalLLMIsUnreferenced)) {
            return this.getUnregisteredLLMModel(originalModelIdentifier, llmSMInfo);
        }
        return (HFLocalModel)this.getLLMModel(originalModelIdentifier).getModel();
    }

    /**
     * Resolves a model reference against the custom models configured on this connection.
     * <p>
     * Every configured custom model whose id equals {@code llmRef.model} is tried in list order. A model that
     * fails to load is logged and skipped, so a later entry with the same id can still be used.
     *
     * @param llmRef reference whose {@code model} field is matched against configured model ids
     * @return a handle wrapping the first matching model that loads successfully
     * @throws IllegalArgumentException if no matching model could be loaded; when a matching model existed
     *         but failed to load, the last load failure is attached as the cause so it is not lost
     */
    @Override
    public AbstractLLMConnection.ConnectionModelHandle getLLMModel(LLMStructuredRef llmRef) {
        Exception lastLoadFailure = null;
        for (CustomHFLocalModel customModel : this.params.models) {
            if (!Objects.equals(customModel.id, llmRef.model)) continue;
            try {
                return new AbstractLLMConnection.ConnectionModelHandle((AbstractLLMConnection)this, (AbstractLLMConnection.BaseModel)this.loadRawCustomModel(customModel));
            }
            catch (Exception e) {
                logger.warn((Object)"Ignoring custom model due to underlying error during load", (Throwable)e);
                lastLoadFailure = e;
            }
        }
        // Without the cause, a broken model definition would surface only as "could not find".
        throw new IllegalArgumentException(String.format("Could not find the model %s in connection %s", llmRef.getModelNameForAudit(), this.name), lastLoadFailure);
    }

    /**
     * Builds an ad-hoc model for a saved model whose original LLM is not referenced on this connection.
     * <p>
     * Inference settings are taken from a configured custom model with the same id and handling mode when
     * one exists, otherwise from the ad-hoc definition itself. A non-null quantization mode recorded on the
     * saved model always wins.
     *
     * @param llmRef    decoded reference of the original LLM; its {@code model} is used as id and Hugging Face id
     * @param llmSMInfo saved-model information supplying handling mode, quantization and display name
     * @return the synthesized model, marked as fine-tunable
     */
    private HFLocalModel getUnregisteredLLMModel(LLMStructuredRef llmRef, LLMSavedModelInfo llmSMInfo) {
        final CustomHFLocalModel adHoc = new CustomHFLocalModel();
        adHoc.id = llmRef.model;
        adHoc.huggingFaceId = llmRef.model;
        adHoc.handlingMode = llmSMInfo.huggingFaceHandlingMode;
        adHoc.quantizationMode = llmSMInfo.quantizationMode;
        if (llmSMInfo.inputLLMName != null) {
            adHoc.displayName = llmSMInfo.inputLLMName;
        } else {
            adHoc.displayName = llmRef.model;
        }
        adHoc.canBeFineTuned = true;
        final HFLocalModel result = adHoc.toModel();

        // Prefer a registered definition (same id and handling mode) as the source of inference settings.
        CustomHFLocalModel settingsSource = adHoc;
        for (CustomHFLocalModel registered : this.params.models) {
            if (Objects.equals(registered.id, adHoc.id) && registered.handlingMode == adHoc.handlingMode) {
                settingsSource = registered;
                break;
            }
        }
        result.inferenceSettings = this.resolveInferenceSettings(settingsSource);

        final HuggingFaceLocalConnectionParams.QuantizationMode savedQuantization = llmSMInfo.quantizationMode;
        if (savedQuantization != null) {
            result.inferenceSettings.quantizationMode = savedQuantization;
        }
        return result;
    }

    /**
     * No-op for this connection type: the body performs no variable expansion on the connection parameters.
     * NOTE(review): presumably intentional because none of the params are variable-expandable — confirm.
     *
     * @param vc variables context (unused)
     */
    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
    }

    /**
     * Produces this connection's credential: a basic credential with an empty user name and
     * {@code params.apiKey} as the secret.
     *
     * @param ctx   resolution context (unused)
     * @param clazz requested credential type; must be {@code BasicCredential} or a supertype of it
     * @return the credential cast to {@code clazz}
     */
    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        // Only enforced when assertions are enabled; otherwise clazz.cast reports a mismatch.
        assert clazz.isAssignableFrom(ICredentialsService.BasicCredential.class);
        return clazz.cast(new ICredentialsService.BasicCredential("", this.params.apiKey));
    }

    /**
     * @return the constant connection type identifier {@code "HuggingFaceLocal"}
     */
    @Override
    public String getType() {
        return HuggingFaceLocalConnection.connectionType;
    }

    /**
     * Encrypts the API key in place, unless it is already encrypted or empty.
     *
     * @param cryptoService    service performing the encryption
     * @param securitySettings security settings (unused)
     */
    @Override
    public void encryptFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings securitySettings) {
        final String currentKey = this.params.apiKey;
        this.params.apiKey = cryptoService.encryptIfNotEncryptedOrEmpty(currentKey);
    }

    /**
     * Decrypts the API key in place when it is encrypted.
     *
     * @param cryptoService service performing the decryption
     */
    @Override
    public void decryptFields(PasswordEncryptionService cryptoService) {
        final String storedKey = this.params.apiKey;
        this.params.apiKey = cryptoService.decryptIfEncrypted(storedKey);
    }

    /**
     * @return the logger shared by all Hugging Face local connections
     */
    @Override
    protected DKULogger getLogger() {
        return HuggingFaceLocalConnection.logger;
    }

    /**
     * Looks up the {@code parallelism} custom property for a model, scoped by the model's query type and id.
     *
     * @param model model whose parallelism is requested
     * @return the parsed parallelism, or {@code null} when absent or not an integer
     */
    public Integer getParallelism(HFLocalModel model) {
        return this.getCustomPropertyInteger("parallelism", AbstractLLMConnection.QueryType.fromModel(model), model.getId());
    }

    /**
     * Reads a custom property through {@code getCustomProperty} and parses it as an integer.
     * <p>
     * Surrounding whitespace is ignored, which matches the {@code Double.parseDouble} path used for
     * {@code gpuMemoryUtilization} (that method already trims its input). Unparseable values are logged
     * and treated as absent.
     *
     * @param keyPrefix property key prefix, also used in the warning message
     * @param queryType query type passed to the lookup, may be {@code null}
     * @param modelId   model id passed to the lookup, may be {@code null}
     * @return the parsed value, or {@code null} when the property is absent or not a valid integer
     */
    private Integer getCustomPropertyInteger(String keyPrefix, @Nullable AbstractLLMConnection.QueryType queryType, @Nullable String modelId) {
        return this.getCustomProperty(keyPrefix, queryType, modelId).flatMap(property -> {
            try {
                // StringUtils.trim(null) is null, which Integer.parseInt rejects with NumberFormatException.
                return Optional.of(Integer.parseInt(StringUtils.trim((String)property.value)));
            }
            catch (NumberFormatException ex) {
                logger.warn((Object)("Invalid " + keyPrefix + " value: " + property.value));
                return Optional.empty();
            }
        }).orElse(null);
    }

    /**
     * Returns the model-level value when set, otherwise falls back to the integer custom property.
     *
     * @param keyPrefix     custom property key prefix used for the fallback
     * @param queryType     query type for the fallback lookup, may be {@code null}
     * @param modelId       model id for the fallback lookup, may be {@code null}
     * @param modelProperty value from the model definition; takes precedence when non-null
     * @return the effective value, or {@code null} if neither source provides one
     */
    public Integer getPropertyInteger(String keyPrefix, @Nullable AbstractLLMConnection.QueryType queryType, @Nullable String modelId, @Nullable Integer modelProperty) {
        return modelProperty != null
                ? modelProperty
                : this.getCustomPropertyInteger(keyPrefix, queryType, modelId);
    }

    /**
     * Returns the model-level flag when set, otherwise falls back to the boolean custom property.
     *
     * @param keyPrefix     custom property key prefix used for the fallback
     * @param queryType     query type for the fallback lookup, may be {@code null}
     * @param modelId       model id for the fallback lookup, may be {@code null}
     * @param modelProperty value from the model definition; takes precedence when non-null
     * @return the effective flag, or empty if neither source provides one
     */
    public Optional<Boolean> getPropertyFlag(String keyPrefix, @Nullable AbstractLLMConnection.QueryType queryType, @Nullable String modelId, @Nullable Boolean modelProperty) {
        return modelProperty != null
                ? Optional.of(modelProperty)
                : this.getCustomPropertyFlag(keyPrefix, queryType, modelId);
    }

    /**
     * Reads a string custom property, treating null, empty and whitespace-only values as absent.
     *
     * @param keyPrefix property key prefix
     * @param queryType query type passed to the lookup, may be {@code null}
     * @param modelId   model id passed to the lookup, may be {@code null}
     * @return the raw (untrimmed) value, or {@code null} when absent or blank
     */
    private String getCustomPropertyString(String keyPrefix, @Nullable AbstractLLMConnection.QueryType queryType, @Nullable String modelId) {
        return this.getCustomProperty(keyPrefix, queryType, modelId)
                .map(property -> property.value)
                .filter(StringUtils::isNotBlank)
                .orElse(null);
    }

    /**
     * Returns the GPU memory utilization for a model.
     * <p>
     * The model definition's value wins when set. Otherwise the {@code gpuMemoryUtilization} custom property
     * (completion query type, scoped by model id) is parsed as a double. Missing or unparseable property values
     * are logged and treated as absent.
     *
     * @param model custom model definition
     * @return the effective utilization, or {@code null} when not configured or invalid
     */
    public Double getGpuMemoryUtilization(CustomHFLocalModel model) {
        if (model.gpuMemoryUtilization != null) {
            return model.gpuMemoryUtilization;
        }
        return this.getCustomProperty("gpuMemoryUtilization", AbstractLLMConnection.QueryType.completion, model.id).flatMap(property -> {
            // Double.parseDouble(null) throws NullPointerException (not NumberFormatException), which the
            // catch below would miss; handle it explicitly as the integer properties already do.
            if (property.value == null) {
                logger.warn((Object)("Invalid gpuMemoryUtilization value: " + property.value));
                return Optional.empty();
            }
            try {
                return Optional.of(Double.parseDouble(property.value));
            }
            catch (NumberFormatException ex) {
                logger.warn((Object)("Invalid gpuMemoryUtilization value: " + property.value));
                return Optional.empty();
            }
        }).orElse(null);
    }

    /**
     * Finds an available model by id.
     *
     * @param modelId id to look for
     * @return the first available model whose id equals {@code modelId}
     * @throws IllegalArgumentException if no available model has that id
     */
    public HFLocalModel getEnabledDetectionModel(String modelId) {
        for (HFLocalModel candidate : this.listAvailableModels(null)) {
            if (Objects.equals(candidate.id, modelId)) {
                return candidate;
            }
        }
        throw new IllegalArgumentException("Unavailable Hugging Face local model: " + modelId);
    }

    /**
     * Tells whether a model with the given id is among the available models.
     * <p>
     * Uses a null-safe comparison, as {@code getEnabledDetectionModel} does, so a model entry without an
     * id no longer causes a NullPointerException.
     *
     * @param modelId id to look for
     * @return {@code true} if an available model has this id
     */
    public boolean isModelEnabled(String modelId) {
        return this.listAvailableModels(null).stream().anyMatch(model -> Objects.equals(model.id, modelId));
    }

    /**
     * @return the live (not copied) list of custom model definitions configured on this connection
     */
    @Override
    public List<CustomHFLocalModel> listRawCustomModels() {
        final HuggingFaceLocalConnectionParams settings = this.params;
        return settings.models;
    }

    /**
     * Turns a custom model definition into a runtime model.
     * <p>
     * The steps run in order: convert, apply the base-class default settings, resolve inference settings,
     * then look up parallelism.
     *
     * @param rawCustomModel model definition from the connection settings
     * @return the loaded model
     */
    @Override
    public HFLocalModel loadRawCustomModel(CustomHFLocalModel rawCustomModel) {
        final HFLocalModel loaded = rawCustomModel.toModel();
        this.loadDefaultCustomModelSettings(rawCustomModel, loaded);
        final InferenceSettings resolvedSettings = this.resolveInferenceSettings(rawCustomModel);
        loaded.inferenceSettings = resolvedSettings;
        final Integer resolvedParallelism = this.getParallelism(loaded);
        loaded.parallelism = resolvedParallelism;
        return loaded;
    }

    /**
     * Computes the effective inference settings for a custom model definition.
     * <p>
     * Each value comes from one of three places:
     * <ul>
     *   <li>the model definition directly;</li>
     *   <li>a connection custom property (by key, query type and model id);</li>
     *   <li>for the {@code getProperty*} helpers, the model definition when set, else the custom property.</li>
     * </ul>
     *
     * @param customModel model definition to resolve settings for
     * @return a new {@link InferenceSettings} instance
     */
    private InferenceSettings resolveInferenceSettings(CustomHFLocalModel customModel) {
        final String id = customModel.id;
        final AbstractLLMConnection.QueryType completion = AbstractLLMConnection.QueryType.completion;
        final AbstractLLMConnection.QueryType imageGeneration = AbstractLLMConnection.QueryType.imageGeneration;
        final InferenceSettings resolved = new InferenceSettings();

        resolved.engine = this.getCustomPropertyEnum("engine", completion, id, InferenceEngine.AUTO);
        // Keep the InferenceSettings default unless the definition specifies a quantization mode.
        if (customModel.quantizationMode != null) {
            resolved.quantizationMode = customModel.quantizationMode;
        }

        // Image-generation parameters, copied from the definition.
        resolved.refinerId = customModel.refinerId;
        resolved.defaultHeight = customModel.defaultHeight;
        resolved.defaultWidth = customModel.defaultWidth;
        resolved.defaultNumInferenceSteps = customModel.defaultNumInferenceSteps;
        resolved.defaultGuidanceScale = customModel.defaultGuidanceScale;
        resolved.defaultStrength = customModel.defaultStrength;
        resolved.maxSequenceLength = customModel.maxSequenceLength;
        // VAE slicing/tiling are on unless a custom property says otherwise.
        resolved.enableVaeSlicing = this.getCustomPropertyFlag("vae.slicing", imageGeneration, id).orElse(Boolean.TRUE);
        resolved.enableVaeTiling = this.getCustomPropertyFlag("vae.tiling", imageGeneration, id).orElse(Boolean.TRUE);
        resolved.deviceStrategy = this.getCustomPropertyEnum("deviceStrategy", imageGeneration, id, DeviceStrategy.NONE);

        // Text-generation serving parameters.
        resolved.vllmEngine = this.getCustomPropertyString("vllmEngine", completion, id);
        resolved.maxModelLen = this.getPropertyInteger("maxModelLen", completion, id, customModel.maxTokensLimit);
        resolved.kvCacheDType = this.getCustomPropertyString("kvCacheDType", completion, id);
        resolved.enforceEager = this.getPropertyFlag("enforceEager", completion, id, customModel.enforceEager).orElse(null);
        resolved.maxNumSeqs = this.getPropertyInteger("maxNumSeqs", completion, id, customModel.maxNumSeqs);
        resolved.trustRemoteCode = customModel.trustRemoteCode;
        // A blank dtype means "not specified".
        resolved.dtype = StringUtils.isBlank(customModel.dtype) ? null : customModel.dtype;
        resolved.tensorParallelSize = customModel.tensorParallelSize;
        resolved.pipelineParallelSize = customModel.pipelineParallelSize;
        resolved.enableExpertParallelism = customModel.enableExpertParallelism;
        resolved.enablePrefixCaching = this.getCustomPropertyFlag("enablePrefixCaching", completion, id).orElse(null);
        resolved.gpuMemoryUtilization = this.getGpuMemoryUtilization(customModel);
        resolved.enableChunkedPrefill = this.getCustomPropertyFlag("enableChunkedPrefill", completion, id).orElse(null);
        resolved.limitImagesPerPrompt = this.getPropertyInteger("limitImagesPerPrompt", completion, id, customModel.limitImagesPerPrompt);
        resolved.guidedDecodingBackend = this.getCustomPropertyString("guidedDecodingBackend", completion, id);
        resolved.enableJsonConstraintsInPrompt = this.getCustomPropertyFlag("enableJsonConstraintsInPrompt", completion, id).orElse(null);
        resolved.toolSettings = customModel.toolSettings;
        resolved.chatTemplateSettings = customModel.chatTemplateSettings;
        return resolved;
    }

    /**
     * Connection testing is not implemented for this connection type.
     *
     * @param authCtx                 caller context (unused)
     * @param connectionsTestService  test service (unused)
     * @return never returns normally
     * @throws NotImplementedException always
     */
    @Override
    public ConnectionsTestService.ConnectionTestResult testConnection(AuthCtx authCtx, ConnectionsTestService connectionsTestService) throws Exception {
        throw new NotImplementedException();
    }

    /**
     * Settings of a Hugging Face local connection.
     */
    public static class HuggingFaceLocalConnectionParams
    extends AbstractLLMConnection.AbstractLLMConnectionParams {
        /** API key; encrypted/decrypted in place by the enclosing connection. */
        public String apiKey;
        /** Custom model definitions configured on the connection. */
        public List<CustomHFLocalModel> models = new ArrayList<>();
        public String clusterId = null;
        public ContainerExecSelection containerSelection = new ContainerExecSelection(ContainerExecSelection.ContainerExecMode.NONE);
        public boolean useDSSModelCache = true;
        /** Explicit code env name; when blank, {@link #getCodeEnvName()} falls back to the internal one. */
        private String codeEnvName;
        public boolean enableReserveCapacity = true;

        /**
         * @return the configured code env name, or the DSS-internal Hugging Face code env name when the
         *         configured one is null, empty or whitespace-only
         */
        public String getCodeEnvName() {
            return StringUtils.isBlank(this.codeEnvName)
                    ? DSSInternalCodeEnvsService.getCodeEnvName(DSSInternalCodeEnvsService.DSSInternalCodeEnvType.HUGGINGFACE_LOCAL_CODE_ENV)
                    : this.codeEnvName;
        }

        /** Weight quantization applied when loading a model. */
        public static enum QuantizationMode {
            Q_4BIT,
            Q_8BIT,
            NONE
        }
    }

    /**
     * Runtime view of a locally served Hugging Face model.
     * <p>
     * Built by {@link CustomHFLocalModel#toModel()} or derived from another instance through
     * {@link #toFineTunedModel(LLMModelSnippetData)}.
     */
    public static class HFLocalModel
    extends AbstractLLMConnection.BaseModel
    implements LLMModelHandle.FineTuneableModel<HFLocalModel> {
        public String huggingFaceId;
        public HuggingFaceHandlingMode handlingMode;
        public InferenceSettings inferenceSettings;
        public boolean isDKUFineTuned;
        @Nullable
        public String baseModelId;
        public Integer parallelism;
        public boolean supportsImageInputs;
        public ToolSettings toolSettings;
        public ChatTemplateSettings chatTemplateSettings;
        public Integer minKernelCount;
        public Integer maxKernelCount;
        public ContainerExecSelection containerSelection;
        @Nullable
        public String cudaVisibleDevices;

        /**
         * @return the configured parallelism, or empty when it is unset or negative
         */
        public OptionalInt getParallelism() {
            final Integer configured = this.parallelism;
            if (configured != null && configured >= 0) {
                return OptionalInt.of(configured);
            }
            return OptionalInt.empty();
        }

        /**
         * Loads the common custom-model fields, then resets the embedding size set by the base loader.
         */
        @Override
        protected void loadFromCustomModel(AbstractLLMConnection.CustomModel customModel) {
            super.loadFromCustomModel(customModel);
            this.embeddingSize = null;
        }

        /**
         * Describes what this model can do.
         * <p>
         * Prompt-driven and image-input support follow {@link #canBeUsedForPurpose}. The native
         * classification flags depend only on the handling mode.
         */
        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            final HuggingFaceHandlingMode mode = this.handlingMode;
            final AbstractLLMConnection.ModelCapabilities caps = new AbstractLLMConnection.ModelCapabilities();
            caps.promptDriven = this.canBeUsedForPurpose(AbstractLLMConnection.LLMUsagePurpose.GENERIC_COMPLETION);
            caps.handlesSystemMessage = true;
            caps.supportsImageInputs = this.canBeUsedForPurpose(AbstractLLMConnection.LLMUsagePurpose.IMAGE_INPUT);
            caps.customClassificationRequiresHypothesisTemplate = mode == HuggingFaceHandlingMode.ZSC_GENERIC;
            caps.canDoNativeSentimentAnalysis = mode == HuggingFaceHandlingMode.TEXT_CLASSIFICATION_SENTIMENT;
            caps.canDoNativeEmotionAnalysis = mode == HuggingFaceHandlingMode.TEXT_CLASSIFICATION_EMOTIONS;
            caps.temperatureRange = TEMPERATURE_RANGE;
            caps.topKRange = TOP_K_RANGE;
            return caps;
        }

        /**
         * Tells whether the model can serve a purpose.
         * <p>
         * Without a handling mode, nothing is supported. Fine-tuning and image input are answered by dedicated
         * flags. Any other purpose must be listed in the handling mode's purposes.
         */
        @Override
        public boolean canBeUsedForPurpose(@Nonnull AbstractLLMConnection.LLMUsagePurpose purpose) {
            if (this.handlingMode == null) {
                return false;
            }
            switch (purpose) {
                case FINE_TUNING:
                    return this.canBeFineTuned();
                case IMAGE_INPUT:
                    return this.supportsImageInputs;
                default:
                    return this.handlingMode.purposes.contains(purpose);
            }
        }

        /**
         * @return the first configuration problem found, or empty when the model is valid. Fine-tuned models
         *         do not need a Hugging Face id.
         */
        @Override
        public Optional<String> getInvalidityReason() {
            String problem = null;
            if (StringUtils.isBlank(this.getId())) {
                problem = "Empty model id";
            } else if (!this.isDKUFineTuned && StringUtils.isBlank(this.huggingFaceId)) {
                problem = "Missing Hugging Face model id";
            } else if (this.handlingMode == null) {
                problem = "Missing handling mode";
            }
            return Optional.ofNullable(problem);
        }

        /** @return a Hugging Face local reference to this model on the given connection */
        @Override
        public LLMStructuredRef asStructuredRef(String connection) {
            final String modelId = this.id;
            return LLMStructuredRef.forHuggingFaceLocalModel(connection, modelId);
        }

        /**
         * Derives the model representing a fine-tuned version of this one.
         * <p>
         * The derived model keeps this model's handling mode, text-generation inference settings, costs,
         * display name and container selection. It records this model's Hugging Face id as its base.
         * Image-generation inference settings, kernel counts and CUDA devices are not carried over.
         * NOTE(review): confirm that omitting those fields is intended.
         */
        @Override
        public HFLocalModel toFineTunedModel(LLMModelSnippetData snippetData) {
            final HFLocalModel derived = new HFLocalModel();
            derived.id = snippetData.fullModelId;
            derived.baseModelId = this.huggingFaceId;
            derived.isDKUFineTuned = true;
            derived.canBeFineTuned = false;
            derived.handlingMode = this.handlingMode;

            final InferenceSettings from = this.inferenceSettings;
            final InferenceSettings to = new InferenceSettings();
            to.engine = from.engine;
            to.quantizationMode = from.quantizationMode;
            to.vllmEngine = from.vllmEngine;
            to.maxModelLen = from.maxModelLen;
            to.kvCacheDType = from.kvCacheDType;
            to.enforceEager = from.enforceEager;
            to.maxNumSeqs = from.maxNumSeqs;
            to.trustRemoteCode = from.trustRemoteCode;
            to.dtype = from.dtype;
            to.tensorParallelSize = from.tensorParallelSize;
            to.pipelineParallelSize = from.pipelineParallelSize;
            to.enableExpertParallelism = from.enableExpertParallelism;
            to.enablePrefixCaching = from.enablePrefixCaching;
            to.gpuMemoryUtilization = from.gpuMemoryUtilization;
            to.enableChunkedPrefill = from.enableChunkedPrefill;
            to.limitImagesPerPrompt = from.limitImagesPerPrompt;
            to.guidedDecodingBackend = from.guidedDecodingBackend;
            to.enableJsonConstraintsInPrompt = from.enableJsonConstraintsInPrompt;
            to.toolSettings = from.toolSettings;
            to.chatTemplateSettings = from.chatTemplateSettings;
            derived.inferenceSettings = to;

            derived.displayName = this.displayName;
            derived.embeddingSize = this.embeddingSize;
            derived.maxTokensLimit = this.maxTokensLimit;
            derived.promptCost = this.promptCost;
            derived.completionCost = this.completionCost;
            derived.embeddingCost = this.embeddingCost;
            derived.supportsImageInputs = this.supportsImageInputs;
            derived.containerSelection = this.containerSelection;
            return derived;
        }

        /** @return the base model id for DKU fine-tuned models, {@code null} otherwise */
        @Override
        public String getBaseModelId() {
            if (!this.isDKUFineTuned) {
                return null;
            }
            return this.baseModelId;
        }
    }

    /**
     * User-facing definition of a local Hugging Face model, as stored in the connection settings.
     * <p>
     * Serving-related fields are read when the connection resolves inference settings, not copied by
     * {@link #toModel()}. Those fields are: quantization, enforceEager, maxNumSeqs, trustRemoteCode, dtype,
     * parallelism sizes, limitImagesPerPrompt and gpuMemoryUtilization.
     */
    public static class CustomHFLocalModel
    extends AbstractLLMConnection.CustomModel<HFLocalModel> {
        public String id;
        public String huggingFaceId;
        public String displayName;
        public String presetId;
        public boolean enabled = true;
        public HuggingFaceHandlingMode handlingMode;
        public HuggingFaceLocalConnectionParams.QuantizationMode quantizationMode = HuggingFaceLocalConnectionParams.QuantizationMode.NONE;
        public boolean supportsImageInputs;
        @Nullable
        public Boolean enforceEager;
        @Nullable
        public Integer maxNumSeqs;
        @Nullable
        public Boolean trustRemoteCode;
        @Nullable
        public String dtype;
        @Nullable
        public Integer tensorParallelSize;
        @Nullable
        public Integer pipelineParallelSize;
        @Nullable
        public Boolean enableExpertParallelism;
        @Nullable
        public Integer limitImagesPerPrompt;
        @Nullable
        public Double gpuMemoryUtilization;
        public ToolSettings toolSettings = new ToolSettings();
        public ChatTemplateSettings chatTemplateSettings = new ChatTemplateSettings();
        public Integer minKernelCount;
        public Integer maxKernelCount;
        public ContainerExecSelection containerSelection = new ContainerExecSelection(ContainerExecSelection.ContainerExecMode.INHERIT);
        @Nullable
        public String cudaVisibleDevices;

        /**
         * Converts this definition into a runtime model without resolving inference settings.
         * Tool and chat-template settings are shared by reference, not copied.
         */
        @Override
        public HFLocalModel toModel() {
            final HFLocalModel built = new HFLocalModel();
            // Populate the common fields first so the explicit assignments below take precedence.
            built.loadFromCustomModel(this);

            built.id = this.id;
            built.huggingFaceId = this.huggingFaceId;
            built.displayName = this.displayName;
            built.handlingMode = this.handlingMode;
            built.enabled = this.enabled;
            built.canBeFineTuned = this.canBeFineTuned;
            built.refinerId = this.refinerId;
            built.supportsImageInputs = this.supportsImageInputs;

            built.toolSettings = this.toolSettings;
            built.chatTemplateSettings = this.chatTemplateSettings;

            built.minKernelCount = this.minKernelCount;
            built.maxKernelCount = this.maxKernelCount;
            built.containerSelection = this.containerSelection;
            built.cudaVisibleDevices = this.cudaVisibleDevices;
            return built;
        }
    }

    public static enum HuggingFaceHandlingMode {
        TEXT_GENERATION_GENERIC(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_AUTO(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_DEEPSEEK(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_FALCON(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_MISTRAL(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_GEMMA(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_QWEN(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_GPT(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_ZEPHYR(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_LLAMA_2(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_LLAMA_3(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_LLAMA_GUARD(AbstractLLMConnection.LLMUsagePurpose.TOXICITY_DETECTION, AbstractLLMConnection.LLMUsagePurpose.GENERIC_COMPLETION),
        TEXT_GENERATION_DOLLY(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_MPT(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_GENERATION_PHI_3(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET_WITH_FT),
        TEXT_CLASSIFICATION_SENTIMENT(AbstractLLMConnection.LLMUsagePurpose.SENTIMENT_ANALYSIS),
        TEXT_CLASSIFICATION_EMOTIONS(AbstractLLMConnection.LLMUsagePurpose.EMOTION_ANALYSIS),
        TEXT_CLASSIFICATION_TOXICITY(AbstractLLMConnection.LLMUsagePurpose.TOXICITY_DETECTION, AbstractLLMConnection.LLMUsagePurpose.CLASSIFICATION_WITH_OTHER_MODEL_PROVIDED_CLASSES),
        TEXT_CLASSIFICATION_PROMPT_INJECTION(AbstractLLMConnection.LLMUsagePurpose.PROMPT_INJECTION_DETECTION, AbstractLLMConnection.LLMUsagePurpose.CLASSIFICATION_WITH_OTHER_MODEL_PROVIDED_CLASSES),
        TEXT_CLASSIFICATION_OTHER(AbstractLLMConnection.LLMUsagePurpose.CLASSIFICATION_WITH_OTHER_MODEL_PROVIDED_CLASSES),
        SUMMARIZATION_GENERIC(AbstractLLMConnection.LLMUsagePurpose.SUMMARIZATION),
        SUMMARIZATION_ROBERTA(AbstractLLMConnection.LLMUsagePurpose.SUMMARIZATION),
        ZSC_GENERIC(AbstractLLMConnection.LLMUsagePurpose.CLASSIFICATION_WITH_USER_PROVIDED_CLASSES),
        T5(AbstractLLMConnection.PROMPT_DRIVEN_PURPOSE_SET),
        TEXT_EMBEDDING(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION),
        IMAGE_EMBEDDING(AbstractLLMConnection.LLMUsagePurpose.IMAGE_EMBEDDING_EXTRACTION),
        IMAGE_GENERATION_DIFFUSION(AbstractLLMConnection.LLMUsagePurpose.IMAGE_GENERATION);

        public final LinkedHashSet<AbstractLLMConnection.LLMUsagePurpose> purposes;

        private HuggingFaceHandlingMode(LinkedHashSet<AbstractLLMConnection.LLMUsagePurpose> purposes) {
            this.purposes = purposes;
        }

        private HuggingFaceHandlingMode(AbstractLLMConnection.LLMUsagePurpose ... purposes) {
            this(new LinkedHashSet<AbstractLLMConnection.LLMUsagePurpose>(Arrays.asList(purposes)));
        }
    }

    /**
     * Effective inference settings of a model.
     * <p>
     * Filled from a custom model definition and connection custom properties by the connection's
     * {@code resolveInferenceSettings}. {@code hfRefinerPath} is not populated anywhere in this file;
     * presumably it is set elsewhere — verify.
     * <p>
     * Field order and annotations are kept as-is because {@code @PyModel} presumably mirrors this class on the
     * Python side — confirm before reordering.
     */
    @PyModel
    public static class InferenceSettings {
        // Custom property "engine"; AUTO when not set.
        public InferenceEngine engine;
        // NONE unless the model definition (or, for unreferenced saved models, the saved-model info) overrides it.
        public HuggingFaceLocalConnectionParams.QuantizationMode quantizationMode = HuggingFaceLocalConnectionParams.QuantizationMode.NONE;
        // --- Text-generation serving options; null means "not specified". ---
        @Nullable
        public String vllmEngine;
        // Model's maxTokensLimit when set, else custom property "maxModelLen".
        @Nullable
        public Integer maxModelLen;
        @Nullable
        public String kvCacheDType;
        @Nullable
        public Boolean enforceEager;
        @Nullable
        public Integer maxNumSeqs;
        @Nullable
        public Boolean trustRemoteCode;
        // Null when the model definition's dtype is blank.
        @Nullable
        public String dtype;
        @Nullable
        public Integer tensorParallelSize;
        @Nullable
        public Integer pipelineParallelSize;
        @Nullable
        public Boolean enableExpertParallelism;
        @Nullable
        public Boolean enablePrefixCaching;
        @Nullable
        public Double gpuMemoryUtilization;
        @Nullable
        public Boolean enableChunkedPrefill;
        @Nullable
        public Integer limitImagesPerPrompt;
        @Nullable
        public String guidedDecodingBackend;
        @Nullable
        public Boolean enableJsonConstraintsInPrompt;
        @Nullable
        public ToolSettings toolSettings;
        @Nullable
        public ChatTemplateSettings chatTemplateSettings;
        // --- Image-generation options. ---
        @Nullable
        public String refinerId;
        @Nullable
        public String hfRefinerPath;
        @Nullable
        public Integer defaultHeight;
        @Nullable
        public Integer defaultWidth;
        @Nullable
        public Integer defaultNumInferenceSteps;
        @Nullable
        public Float defaultGuidanceScale;
        @Nullable
        public Float defaultStrength;
        @Nullable
        public Integer maxSequenceLength;
        // Resolved to TRUE unless custom property "vae.slicing" says otherwise.
        @Nullable
        public Boolean enableVaeSlicing;
        // Resolved to TRUE unless custom property "vae.tiling" says otherwise.
        @Nullable
        public Boolean enableVaeTiling;
        // Custom property "deviceStrategy"; NONE when not set.
        @Nullable
        public DeviceStrategy deviceStrategy;
    }

    /** Inference backend selection; AUTO is the default used when no "engine" custom property is set. */
    public static enum InferenceEngine {
        AUTO, VLLM, TRANSFORMERS
    }

    /** Device placement strategy for image-generation models; NONE is the default when not configured. */
    @PyModel
    public static enum DeviceStrategy {
        NONE, MODEL_CPU_OFFLOAD, SEQUENTIAL_CPU_OFFLOAD
    }

    /**
     * Tool-calling options of a model. Tools are disabled by default; {@code toolParser} is optional.
     * Fields are package-private, so they are presumably read via reflection (JSON / {@code @PyModel}) — verify.
     */
    @PyModel
    public static class ToolSettings {
        boolean enableTools = false;
        @Nullable
        String toolParser;
    }

    /**
     * Chat template override options of a model. The override is off by default; {@code chatTemplate} holds the
     * replacement template when set. Fields are package-private, so presumably read via reflection — verify.
     */
    @PyModel
    public static class ChatTemplateSettings {
        boolean overrideChatTemplate = false;
        @Nullable
        String chatTemplate;
    }

    /** One selectable value of a preset facet (e.g. a usage purpose: enum name as id, display name as name). */
    public static class HFLocalModelPresetFacetValue {
        public String id;
        public String name;
        public String description;
    }

    /** A facet by which model presets can be filtered, with its possible values. */
    public static class HFLocalModelPresetFacet {
        public String name;
        public String description;
        public List<HFLocalModelPresetFacetValue> values = new ArrayList<HFLocalModelPresetFacetValue>();
    }

    /**
     * A predefined model configuration shipped with DSS.
     * {@code facets} maps a facet key (e.g. "usagePurposes") to the ids of the values this preset has.
     */
    public static class HFLocalModelPreset {
        public String id;
        public CustomHFLocalModel model;
        public String description;
        public boolean includeInNewConnections = false;
        public Map<String, List<String>> facets = new HashMap<String, List<String>>();
    }

    /**
     * Bundled catalogue of Hugging Face model presets and the facets used to browse them.
     */
    public static class HFLocalPresetsConfig {
        public List<HFLocalModelPreset> presets;
        public Map<String, HFLocalModelPresetFacet> facets;

        /**
         * Loads the presets from the {@code hf_presets.json} resource (resolved relative to this class) and
         * enriches them with two derived facets:
         * <ul>
         *   <li>{@code usagePurposes}: every purpose of the preset's handling mode;</li>
         *   <li>{@code mainUsagePurpose}: the first of those purposes.</li>
         * </ul>
         * Each derived facet keeps only the values used by at least one preset. Each preset's model also gets
         * its {@code presetId} set.
         *
         * @return the loaded configuration
         * @throws IOException if the resource is missing, empty, or cannot be read/parsed
         */
        public static HFLocalPresetsConfig load() throws IOException {
            HFLocalPresetsConfig config;
            // try-with-resources: the resource stream was previously never closed.
            try (InputStream modelsInput = HFLocalPresetsConfig.class.getResourceAsStream("hf_presets.json")) {
                if (modelsInput == null) {
                    throw new IOException("Hugging Face presets resource hf_presets.json not found");
                }
                config = (HFLocalPresetsConfig)JSON.parse((InputStream)modelsInput, HFLocalPresetsConfig.class);
            }
            if (config == null) {
                throw new IOException("Hugging Face presets resource hf_presets.json is empty");
            }
            // Tolerate a resource that omits either top-level section.
            if (config.presets == null) {
                config.presets = new ArrayList<HFLocalModelPreset>();
            }
            if (config.facets == null) {
                config.facets = new HashMap<String, HFLocalModelPresetFacet>();
            }

            HFLocalModelPresetFacet usagePurposesFacet = buildUsagePurposeFacet("Usage purposes");
            config.facets.put("usagePurposes", usagePurposesFacet);
            HFLocalModelPresetFacet mainUsagePurposeFacet = buildUsagePurposeFacet("Main usage purpose");
            config.facets.put("mainUsagePurpose", mainUsagePurposeFacet);

            HashSet<String> distinctPurposes = new HashSet<String>();
            HashSet<String> distinctMainPurposes = new HashSet<String>();
            for (HFLocalModelPreset preset : config.presets) {
                preset.model.presetId = preset.id;
                List<String> modelPurposes = preset.model.handlingMode.purposes.stream()
                        .map(AbstractLLMConnection.LLMUsagePurpose::name)
                        .collect(Collectors.toList());
                preset.facets.put("usagePurposes", modelPurposes);
                distinctPurposes.addAll(modelPurposes);
                // The handling mode's purposes are insertion-ordered; the first is the main one.
                List<String> mainUsagePurpose = modelPurposes.stream().limit(1L).collect(Collectors.toList());
                preset.facets.put("mainUsagePurpose", mainUsagePurpose);
                distinctMainPurposes.addAll(mainUsagePurpose);
            }
            // Hide facet values that no preset uses.
            usagePurposesFacet.values.removeIf(v -> !distinctPurposes.contains(v.id));
            mainUsagePurposeFacet.values.removeIf(v -> !distinctMainPurposes.contains(v.id));
            return config;
        }

        /**
         * Builds a facet listing every usage purpose (enum name as id, display name as name).
         * The value list is an {@link ArrayList} so callers can prune it with {@code removeIf}.
         */
        private static HFLocalModelPresetFacet buildUsagePurposeFacet(String name) {
            HFLocalModelPresetFacet facet = new HFLocalModelPresetFacet();
            facet.name = name;
            facet.values = Arrays.stream(AbstractLLMConnection.LLMUsagePurpose.values()).map(p -> {
                HFLocalModelPresetFacetValue facetValue = new HFLocalModelPresetFacetValue();
                facetValue.id = p.name();
                facetValue.name = p.displayName;
                return facetValue;
            }).collect(Collectors.toCollection(ArrayList::new));
            return facet;
        }
    }
}

