/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.security.model.ICredentialsService;
import com.dataiku.dip.server.services.ConnectionsTestService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.NotImplementedException;
import com.dataiku.dip.variables.VariablesContext;
import com.google.common.collect.ImmutableMap;
import java.io.IOException;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import javax.annotation.Nonnull;

/**
 * LLM connection of type {@value #connectionType}.
 *
 * <p>Holds the connection parameters (API key, network settings and one
 * enable/disable flag per hardcoded model) and exposes the fixed catalogue of
 * models declared in {@link HardcodedStabilityAIModel}. Models of this
 * connection report themselves as usable for image generation only.
 */
public class StabilityAIConnection
extends AbstractLLMConnection<StabilityAIModel, HardcodedStabilityAIModel, AbstractLLMConnection.CustomModel<StabilityAIModel>> {
    /** Parameters of this connection; returned by {@link #getLLMConnectionParams()}. */
    public StabilityAIConnectionParams params = new StabilityAIConnectionParams();
    /** Type identifier returned by {@link #getType()}. */
    public static final String connectionType = "StabilityAI";
    private static final DKULogger logger = DKULogger.getLogger("dip.connections.stabilityai");

    /**
     * @return the {@link #params} instance of this connection
     */
    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        return this.params;
    }

    /**
     * @return the logger shared by all instances of this connection type
     */
    @Override
    protected DKULogger getLogger() {
        return logger;
    }

    /**
     * No-op: this implementation does not expand anything.
     *
     * @param vc variables context; unused here
     */
    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
    }

    /**
     * Wraps the configured API key into a {@link ICredentialsService.BasicCredential}.
     *
     * <p>The credential is constructed with an empty string as first argument and
     * {@code params.apiKey} as second argument (presumably login and password —
     * confirm against {@code BasicCredential}).
     *
     * @param ctx   resolution context; not read by this implementation
     * @param clazz requested credential class; expected to be assignable from
     *              {@code BasicCredential}
     * @return the credential cast to {@code clazz}
     * @throws ClassCastException if {@code clazz} cannot hold a {@code BasicCredential}
     *                            (an {@link AssertionError} is raised instead when
     *                            assertions are enabled)
     */
    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        // Development-time sanity check only; clazz.cast below still rejects a wrong type when -ea is off.
        assert (clazz.isAssignableFrom(ICredentialsService.BasicCredential.class));
        ICredentialsService.BasicCredential apiKeyCredential = new ICredentialsService.BasicCredential("", this.params.apiKey);
        return clazz.cast(apiKeyCredential);
    }

    /**
     * @return {@link #connectionType}
     */
    @Override
    public String getType() {
        return connectionType;
    }

    /**
     * Replaces {@code params.apiKey} with the result of
     * {@code cryptoService.encryptIfNotEncryptedOrEmpty}.
     *
     * @param cryptoService    service used to encrypt the API key
     * @param securitySettings not read by this implementation
     */
    @Override
    public void encryptFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings securitySettings) {
        this.params.apiKey = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.apiKey);
    }

    /**
     * Replaces {@code params.apiKey} with the result of
     * {@code cryptoService.decryptIfEncrypted}.
     *
     * @param cryptoService service used to decrypt the API key
     */
    @Override
    public void decryptFields(PasswordEncryptionService cryptoService) {
        this.params.apiKey = cryptoService.decryptIfEncrypted(this.params.apiKey);
    }

    /**
     * Tells whether a hardcoded model is enabled by evaluating the model's
     * {@code allowedModel} function against this connection's {@link #params}.
     *
     * @param stabilityAIModel the hardcoded model to check
     * @return the value of the matching {@code allow...} flag
     */
    @Override
    protected boolean isHardcodedModelEnabled(HardcodedStabilityAIModel stabilityAIModel) {
        return stabilityAIModel.allowedModel.apply(this.params);
    }

    /**
     * @return every constant of {@link HardcodedStabilityAIModel}, in declaration order
     */
    @Override
    protected List<HardcodedStabilityAIModel> listRawHardcodedModels() {
        return Arrays.asList(HardcodedStabilityAIModel.values());
    }

    /**
     * Builds the {@link StabilityAIModel} for a hardcoded entry and then passes it to
     * {@code loadDefaultHardcodedModelSettings}.
     *
     * @param hardcodedModel the catalogue entry to materialize
     * @return the newly created model
     */
    @Override
    protected StabilityAIModel loadRawHardcodedModel(HardcodedStabilityAIModel hardcodedModel) {
        StabilityAIModel loaded = hardcodedModel.toModel();
        this.loadDefaultHardcodedModelSettings(hardcodedModel, loaded);
        return loaded;
    }

    /**
     * Not supported for this connection type.
     *
     * @throws NotImplementedException always
     */
    @Override
    public ConnectionsTestService.ConnectionTestResult testConnection(AuthCtx authCtx, ConnectionsTestService connectionsTestService) throws Exception {
        throw new NotImplementedException();
    }

    /**
     * Parameters of a {@link StabilityAIConnection}.
     *
     * <p>Each {@code allow...} flag is read by the {@code allowedModel} function of the
     * corresponding {@link HardcodedStabilityAIModel} constant; all default to {@code true}.
     */
    public static class StabilityAIConnectionParams
    extends AbstractLLMConnection.AbstractLLMConnectionParams {
        /** API key; encrypted/decrypted in place by the owning connection. */
        public String apiKey;
        public int maxParallelism = 1;
        public AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
        public boolean allowStableImageCore = true;
        public boolean allowStableDiffusion30Medium = true;
        public boolean allowStableDiffusion30Large = true;
        public boolean allowStableDiffusion30LargeTurbo = true;
        public boolean allowStableDiffusion35Medium = true;
        public boolean allowStableDiffusion35Large = true;
        public boolean allowStableDiffusion35LargeTurbo = true;
        public boolean allowStableImageUltra = true;
    }

    /**
     * Fixed catalogue of models offered by this connection type. Each constant carries
     * the model id, a display name, and the function that reads its enable flag from
     * {@link StabilityAIConnectionParams}.
     */
    public enum HardcodedStabilityAIModel implements AbstractLLMConnection.IHardcodedConnectionModel<StabilityAIModel>
    {
        STABLE_IMAGE_CORE("stable-image-core", "Stable Image core", p -> p.allowStableImageCore),
        STABLE_DIFFUSION_30_MEDIUM("sd3-medium", "Stable Diffusion 3.0 Medium", p -> p.allowStableDiffusion30Medium),
        STABLE_DIFFUSION_30_LARGE("sd3-large", "Stable Diffusion 3.0 Large", p -> p.allowStableDiffusion30Large),
        STABLE_DIFFUSION_30_LARGE_TURBO("sd3-large-turbo", "Stable Diffusion 3.0 Large Turbo", p -> p.allowStableDiffusion30LargeTurbo),
        STABLE_DIFFUSION_35_MEDIUM("sd3.5-medium", "Stable Diffusion 3.5 Medium", p -> p.allowStableDiffusion35Medium),
        STABLE_DIFFUSION_35_LARGE("sd3.5-large", "Stable Diffusion 3.5 Large", p -> p.allowStableDiffusion35Large),
        STABLE_DIFFUSION_35_LARGE_TURBO("sd3.5-large-turbo", "Stable Diffusion 3.5 Large Turbo", p -> p.allowStableDiffusion35LargeTurbo),
        STABLE_IMAGE_ULTRA("stable-image-ultra", "Stable Image Ultra", p -> p.allowStableImageUltra);

        /** Model identifier copied into {@link StabilityAIModel} by {@link #toModel()}. */
        public final String id;
        /** Human-readable name copied into {@link StabilityAIModel} by {@link #toModel()}. */
        public final String displayName;
        /** Reads this model's enable flag from the connection parameters. */
        final Function<StabilityAIConnectionParams, Boolean> allowedModel;

        private HardcodedStabilityAIModel(String id, String displayName, Function<StabilityAIConnectionParams, Boolean> allowedModel) {
            this.id = id;
            this.displayName = displayName;
            this.allowedModel = allowedModel;
        }

        /**
         * @return a new {@link StabilityAIModel} whose {@code id} and {@code displayName}
         *         are taken from this constant
         */
        @Override
        public StabilityAIModel toModel() {
            StabilityAIModel model = new StabilityAIModel();
            model.id = this.id;
            model.displayName = this.displayName;
            return model;
        }
    }

    /**
     * Model exposed by a {@link StabilityAIConnection}. Usable only for
     * {@code IMAGE_GENERATION}; provides a per-image cost estimate based on a static
     * credits table.
     */
    public static class StabilityAIModel
    extends AbstractLLMConnection.BaseModel {
        /** Price of one credit in USD, applied to every estimate. */
        private static final double USD_PER_CREDIT = 0.01;
        /** Flat credit price used when editing an original image in CONTROLNET_SKETCH mode. */
        private static final double CONTROLNET_SKETCH_CREDITS = 3.0;
        /** Credits charged per generated image, keyed by model id. */
        private static final Map<String, Double> modelsCreditCosts = ImmutableMap.<String, Double>builder()
                .put(HardcodedStabilityAIModel.STABLE_IMAGE_CORE.id, 3.0)
                .put(HardcodedStabilityAIModel.STABLE_DIFFUSION_30_MEDIUM.id, 3.5)
                .put(HardcodedStabilityAIModel.STABLE_DIFFUSION_30_LARGE_TURBO.id, 4.0)
                .put(HardcodedStabilityAIModel.STABLE_DIFFUSION_30_LARGE.id, 6.5)
                .put(HardcodedStabilityAIModel.STABLE_DIFFUSION_35_MEDIUM.id, 3.5)
                .put(HardcodedStabilityAIModel.STABLE_DIFFUSION_35_LARGE_TURBO.id, 4.0)
                .put(HardcodedStabilityAIModel.STABLE_DIFFUSION_35_LARGE.id, 6.5)
                .put(HardcodedStabilityAIModel.STABLE_IMAGE_ULTRA.id, 8.0)
                .build();

        /**
         * @return a freshly constructed {@code ModelCapabilities} with no field set here
         */
        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            return new AbstractLLMConnection.ModelCapabilities();
        }

        /**
         * @param connection name of the connection owning this model
         * @return the structured reference built by
         *         {@code LLMStructuredRef.forStabilityAIConnection} from the connection
         *         and this model's id
         */
        @Override
        LLMStructuredRef asStructuredRef(String connection) {
            return LLMStructuredRef.forStabilityAIConnection(connection, this.getId());
        }

        /**
         * @param purpose the intended usage
         * @return {@code true} only for {@code IMAGE_GENERATION}
         */
        @Override
        public boolean canBeUsedForPurpose(@Nonnull AbstractLLMConnection.LLMUsagePurpose purpose) {
            return purpose == AbstractLLMConnection.LLMUsagePurpose.IMAGE_GENERATION;
        }

        /**
         * Estimates the cost of the query in USD as credits × {@value #USD_PER_CREDIT}.
         *
         * <p>Models absent from the credits table are logged as unknown and estimated at
         * {@code 0.0}. For known models, a query that carries an original image in
         * {@code CONTROLNET_SKETCH} edition mode is priced at a flat
         * {@value #CONTROLNET_SKETCH_CREDITS} credits instead of the model's listed price.
         *
         * @param query the image generation query; must not be {@code null}
         * @return the estimated cost in USD, or {@code 0.0} when pricing is unknown
         */
        @Override
        public double getEstimatedImageGenerationCost(LLMClient.ImageGenerationQuery query) {
            // Single lookup: a null result means the id is not in the table (this also covers a null id).
            Double listedCredits = modelsCreditCosts.get(this.id);
            if (listedCredits == null) {
                logger.warn((Object)("Unknown pricing for StabilityAI model: " + this.id));
                return 0.0;
            }
            boolean isSketchEdit = query.originalImage != null
                    && query.originalImageEditionMode == LLMClient.ImageGenerationEditionMode.CONTROLNET_SKETCH;
            double creditsPerImage = isSketchEdit ? CONTROLNET_SKETCH_CREDITS : listedCredits;
            return creditsPerImage * USD_PER_CREDIT;
        }
    }
}

