/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.activity.UsageSummaryModel;
import com.dataiku.dip.analysis.model.llm.LLMModelSnippetData;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.connections.ConnectionWithAzureAuthCredentials;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.connections.OpenAIConnection;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMModelHandle;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.openai.OpenAIChatAPI;
import com.dataiku.dip.llm.online.openai.OpenAIImageHandling;
import com.dataiku.dip.llm.online.openai.OpenAIPricing;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.server.services.ConnectionsTestService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.Params;
import com.dataiku.dip.variables.VariablesContext;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

/**
 * LLM connection to an Azure OpenAI resource.
 *
 * <p>The models exposed by this connection come from three sources:
 * <ul>
 *   <li>the deployments listed in {@link AzureOpenAIConnectionParams#availableDeployments};</li>
 *   <li>the user-declared {@link AzureOpenAIConnectionParams#customModels};</li>
 *   <li>the {@link HardcodedAzureOpenAIModel} entries (fine-tunable base models), each enabled by
 *       its own {@code allow*Finetuning} flag in the connection params.</li>
 * </ul>
 */
public class AzureOpenAIConnection
extends AbstractLLMConnection<AzureOpenAIModel, HardcodedAzureOpenAIModel, AbstractLLMConnection.CustomModel>
implements ConnectionWithAzureAuthCredentials {
    private static final String DEFAULT_OAUTH_SCOPE = "https://cognitiveservices.azure.com/.default";
    private static final String DEFAULT_OAUTH_USER_IMPERSONATION_SCOPE = "https://cognitiveservices.azure.com/user_impersonation offline_access";
    public AzureOpenAIConnectionParams params = new AzureOpenAIConnectionParams();
    public static final String connectionType = "AzureOpenAI";
    private static final DKULogger logger = DKULogger.getLogger("dku.connections.azureopenai");

    /** Returns this connection's params, viewed as generic LLM connection params. */
    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        return this.params;
    }

    /** Returns the logger dedicated to Azure OpenAI connections. */
    @Override
    protected DKULogger getLogger() {
        return logger;
    }

    /** Returns this connection's params, viewed as (not yet resolved) Azure auth params. */
    @Override
    public ConnectionWithAzureAuthCredentials.IAzureAuthParams getAzureAuth2NonResolvedParams() {
        return this.params;
    }

    /**
     * Lists the raw custom model definitions of this connection.
     *
     * @return a new list holding the configured deployments first, then the user-declared custom models
     */
    @Override
    protected List<AbstractLLMConnection.CustomModel> listRawCustomModels() {
        List<AbstractLLMConnection.CustomModel> customModels =
                new ArrayList<>(this.params.availableDeployments.size() + this.params.customModels.size());
        customModels.addAll(this.params.availableDeployments);
        customModels.addAll(this.params.customModels);
        return customModels;
    }

    /**
     * Converts a raw custom model definition (as returned by {@link #listRawCustomModels()}) to a model
     * and applies the default custom-model settings to it.
     *
     * @param rawCustomModel either an {@link AzureOpenAIDeployment} or a {@link CustomAzureOpenAIModel}
     * @return the loaded model
     * @throws IllegalStateException if {@code rawCustomModel} is of any other type
     */
    @Override
    protected AzureOpenAIModel loadRawCustomModel(AbstractLLMConnection.CustomModel rawCustomModel) {
        AzureOpenAIModel model;
        if (rawCustomModel instanceof AzureOpenAIDeployment) {
            model = ((AzureOpenAIDeployment) rawCustomModel).toModel();
        } else if (rawCustomModel instanceof CustomAzureOpenAIModel) {
            model = ((CustomAzureOpenAIModel) rawCustomModel).toModel();
        } else {
            throw new IllegalStateException("Unexpected custom model: " + String.valueOf(rawCustomModel));
        }
        this.loadDefaultCustomModelSettings(rawCustomModel, model);
        return model;
    }

    /** Returns the proxy settings configured on this connection. */
    @Override
    public ProxySettings getProxySettingsFromConnection() {
        return this.getProxySettings();
    }

    /** Returns the default OAuth scope for the Azure Cognitive Services resource ({@code .default}). */
    @Override
    public String getDefaultAuthScope() {
        return DEFAULT_OAUTH_SCOPE;
    }

    /** Returns the default OAuth scope used for per-user impersonation (includes {@code offline_access}). */
    @Override
    public String getDefaultAuthUserImpersonationScope() {
        return DEFAULT_OAUTH_USER_IMPERSONATION_SCOPE;
    }

    /** Returns the connection's DKU properties as a {@link Params} bag. */
    @Override
    public Params getDkuPropertiesAsParams() {
        return AbstractSQLConnection.CustomDatabaseProperty.toParams(this.getDkuProperties());
    }

    /**
     * Credentials must be resolved on the backend when refresh-token rotation is enabled, or when
     * the parent class requires it.
     */
    @Override
    public boolean mustResolveOnBackend() {
        return this.hasRefreshTokenRotation() || super.mustResolveOnBackend();
    }

    /**
     * Returns the OAuth scope to request.
     *
     * @return the {@code dku.endpoint.oauth_scope} DKU property if set, else {@link #DEFAULT_OAUTH_SCOPE}
     */
    public String getOauthScope() {
        return this.getDkuPropertiesAsParams().getParam("dku.endpoint.oauth_scope", DEFAULT_OAUTH_SCOPE);
    }

    /** No-op: this connection type performs no variable expansion at DAO level. */
    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
    }

    /**
     * Resolves the Azure auth credentials of this connection.
     *
     * @param clazz expected to accept {@link ConnectionWithAzureAuthCredentials.SerializableAzureAuthCredentials}
     *              (checked with an assertion only)
     */
    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        assert (clazz.isAssignableFrom(ConnectionWithAzureAuthCredentials.SerializableAzureAuthCredentials.class));
        ConnectionWithAzureAuthCredentials.SerializableAzureAuthCredentials creds = this.getFullyResolvedAzureAuthCredentials(ctx);
        return clazz.cast(creds);
    }

    /** A hardcoded model is enabled when its {@code allowedModel} predicate accepts this connection's params. */
    @Override
    protected boolean isHardcodedModelEnabled(HardcodedAzureOpenAIModel hardcodedModel) {
        return hardcodedModel.allowedModel.apply(this.params);
    }

    /** Lists all hardcoded models, in enum declaration order. */
    @Override
    protected List<HardcodedAzureOpenAIModel> listRawHardcodedModels() {
        return Arrays.asList(HardcodedAzureOpenAIModel.values());
    }

    /** Converts a hardcoded model to a model and applies the default hardcoded-model settings to it. */
    @Override
    protected AzureOpenAIModel loadRawHardcodedModel(HardcodedAzureOpenAIModel hardcodedModel) {
        AzureOpenAIModel model = hardcodedModel.toModel();
        this.loadDefaultHardcodedModelSettings(hardcodedModel, model);
        return model;
    }

    /** The connection test is skipped for refs that do not point to a deployment. */
    @Override
    public boolean ignoreConnectionTest(LLMStructuredRef llmRef) {
        return llmRef.deployment == null;
    }

    /**
     * Builds an identifier prefixed by the model kind, so that a deployment and a model sharing the
     * same id do not collide.
     */
    @Override
    protected String generateUniqueModelIdentifier(AzureOpenAIModel model) {
        return (model.isDeployment() ? "deployment:" : "model:") + model.getId();
    }

    @Override
    public String getType() {
        return connectionType;
    }

    /**
     * Encrypts the API key, the app secret and the custom headers in place. This is only done when
     * the security settings require secret keys to be secured.
     */
    @Override
    public void encryptFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings securitySettings) {
        if (securitySettings.secureSecretKeys) {
            this.params.apiKey = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.apiKey);
            this.params.appSecret = cryptoService.encryptIfNotEncryptedOrEmpty(this.params.appSecret);
            this.params.encryptProperties(this.params.customHeaders, cryptoService);
        }
    }

    /** Decrypts the API key, the app secret and the custom headers in place (if they are encrypted). */
    @Override
    public void decryptFields(PasswordEncryptionService cryptoService) {
        this.params.apiKey = cryptoService.decryptIfEncrypted(this.params.apiKey);
        this.params.appSecret = cryptoService.decryptIfEncrypted(this.params.appSecret);
        this.params.decryptProperties(this.params.customHeaders, cryptoService);
    }

    /**
     * Records the deployments of this connection in the global usage summary.
     *
     * <p>Image-generation deployments are not reported. A deployment with a blank underlying model
     * name is counted as a custom model; otherwise its underlying model name is recorded as a
     * standard model. Embedding deployments go to the text-embedding counters and everything else
     * goes to the completion counters.
     */
    @Override
    public void fillModelsForGlobalSummaryReport(UsageSummaryModel.LLMConnectionSummary lcs) {
        for (AzureOpenAIDeployment depl : this.params.availableDeployments) {
            if (depl.deploymentType == OpenAIConnection.OpenAIModelType.IMAGE_GENERATION) {
                continue;
            }
            boolean isEmbedding = depl.deploymentType == OpenAIConnection.OpenAIModelType.TEXT_EMBEDDING_EXTRACTION;
            if (StringUtils.isBlank(depl.underlyingModelName)) {
                // Unknown underlying model: only counted, as a custom model
                if (isEmbedding) {
                    ++lcs.enabledTextEmbeddingCustomModels;
                } else {
                    ++lcs.enabledCompletionCustomModels;
                }
            } else if (isEmbedding) {
                lcs.enabledTextEmbeddingStandardModels.add(depl.underlyingModelName);
            } else {
                lcs.enabledCompletionStandardModels.add(depl.underlyingModelName);
            }
        }
    }

    /** Adds the resource name / URL to the parent's consistency-checkable values. */
    @Override
    public Map<String, Object> getConsistencyCheckables() {
        Map<String, Object> consistencyCheckables = super.getConsistencyCheckables();
        consistencyCheckables.put("Resource name / URL", this.params.resourceName);
        return consistencyCheckables;
    }

    /** Delegates the connection test to {@link ConnectionsTestService#testAzureOpenAI}. */
    @Override
    public ConnectionsTestService.ConnectionTestResult testConnection(AuthCtx authCtx, ConnectionsTestService connectionsTestService) throws Exception {
        return connectionsTestService.testAzureOpenAI(this, authCtx);
    }

    /**
     * Settings of an Azure OpenAI connection. These also serve as the non-resolved Azure auth
     * parameters of the connection.
     */
    public static class AzureOpenAIConnectionParams
    extends AbstractLLMConnection.AbstractLLMConnectionParams
    implements ConnectionWithAzureAuthCredentials.IAzureAuthParams {
        public AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
        public String resourceName;
        public List<AzureOpenAIDeployment> availableDeployments = new ArrayList<AzureOpenAIDeployment>();
        public List<CustomAzureOpenAIModel> customModels = new ArrayList<CustomAzureOpenAIModel>();
        public String apiKey;
        public AuthType authType = AuthType.API_KEY;
        public String tenantId;
        public String appId;
        public String appSecret;
        public String tokenEndpoint;
        public String authorizationEndpoint;
        public boolean refreshTokenRotation;
        public int maxParallelism = 8;
        public List<AbstractSQLConnection.CustomDatabaseProperty> customHeaders = new ArrayList<AbstractSQLConnection.CustomDatabaseProperty>();
        // Flags enabling the corresponding HardcodedAzureOpenAIModel entries
        public boolean allowDavinciFinetuning;
        public boolean allowBabbageFinetuning;
        public boolean allowGPT35Turbo0125Finetuning;
        public boolean allowGPT35Turbo0613Finetuning;
        public boolean allowGPT35Turbo1106Finetuning;
        public String subscriptionId;
        public String resourceGroup;
        public String azureMLConnection;

        /**
         * Maps this connection's {@link AuthType} to the generic Azure auth type.
         *
         * @throws IllegalArgumentException for an unmapped auth type
         */
        @Override
        public ConnectionWithAzureAuthCredentials.AuthType getAuthType() {
            switch (this.authType) {
                case API_KEY:
                    return ConnectionWithAzureAuthCredentials.AuthType.KEY;
                case OAUTH2_APP:
                    return ConnectionWithAzureAuthCredentials.AuthType.OAUTH2_APP;
                default:
                    throw new IllegalArgumentException("Unhandled auth mode " + String.valueOf(this.authType));
            }
        }

        @Override
        public String getKey() {
            return this.apiKey;
        }

        @Override
        public String getOauth2TenantId() {
            return this.tenantId;
        }

        @Override
        public String getOauth2AppId() {
            return this.appId;
        }

        @Override
        public String getOauth2AppSecret() {
            return this.appSecret;
        }

        @Override
        public String getOauth2AuthorizationEndpoint() {
            return this.authorizationEndpoint;
        }

        @Override
        public String getOauth2TokenEndpoint() {
            return this.tokenEndpoint;
        }

        @Override
        public boolean getRefreshTokenRotation() {
            return this.refreshTokenRotation;
        }
    }

    /** A deployment configured on the Azure OpenAI resource, identified by its deployment name. */
    public static class AzureOpenAIDeployment
    extends AbstractLLMConnection.CustomModel<AzureOpenAIModel> {
        public String name;
        public OpenAIConnection.OpenAIModelType deploymentType;
        public AzureOpenAIMaxTokensAPIMode maxTokensAPIMode = AzureOpenAIMaxTokensAPIMode.LEGACY;
        public String underlyingModelName;
        @Nullable
        public OpenAIImageHandling imageHandlingMode = OpenAIImageHandling.DALL_E_3;
        @Nullable
        public OpenAIChatAPI api = OpenAIChatAPI.CHAT_COMPLETIONS;

        /**
         * Builds the model for this deployment. The deployment name is used as id, deployment id and
         * display name.
         *
         * <p>Costs set on the deployment take precedence. Any cost left unset is looked up in
         * {@link OpenAIPricing} from the underlying model name: the embedding cost for embedding
         * deployments, and the prompt and completion costs for all other deployments.
         *
         * @throws IllegalStateException if {@link #deploymentType} is not set
         */
        @Override
        public AzureOpenAIModel toModel() {
            if (this.deploymentType == null) {
                throw new IllegalStateException("Undefined deployment type for Azure Open AI deployment " + this.name);
            }
            AzureOpenAIModel model = new AzureOpenAIModel();
            model.loadFromCustomModel(this);
            model.id = this.name;
            model.deploymentId = this.name;
            model.displayName = this.name;
            model.deploymentType = this.deploymentType;
            model.maxTokensAPIMode = this.maxTokensAPIMode;
            model.useChatApi = this.deploymentType.useChatAPI;
            model.canBeFineTuned = false;
            model.underlyingModelName = this.underlyingModelName;
            if (this.deploymentType == OpenAIConnection.OpenAIModelType.TEXT_EMBEDDING_EXTRACTION) {
                model.embeddingCost = this.embeddingCost != null ? this.embeddingCost : OpenAIPricing.getAzureOpenAIEmbeddingCostPer1KTokens(this.underlyingModelName);
            } else {
                model.promptCost = this.promptCost != null ? this.promptCost : OpenAIPricing.getAzureOpenAIPromptCostPer1KTokens(this.underlyingModelName);
                model.completionCost = this.completionCost != null ? this.completionCost : OpenAIPricing.getAzureOpenAICompletionCostPer1KTokens(this.underlyingModelName);
            }
            if (this.deploymentType == OpenAIConnection.OpenAIModelType.IMAGE_GENERATION && this.imageHandlingMode != null) {
                model.imageHandlingMode = this.imageHandlingMode;
            }
            // The null guard applies to every chat deployment type (explicit parentheses; `&&` binds tighter than `||`)
            boolean isChatDeployment = this.deploymentType == OpenAIConnection.OpenAIModelType.COMPLETION_CHAT
                    || this.deploymentType == OpenAIConnection.OpenAIModelType.COMPLETION_CHAT_MULTIMODAL
                    || this.deploymentType == OpenAIConnection.OpenAIModelType.COMPLETION_CHAT_NO_SYSTEM_PROMPT;
            if (isChatDeployment && this.api != null) {
                model.api = this.api;
            }
            return model;
        }
    }

    /**
     * Runtime view of a model usable through this connection. It is either a deployment
     * ({@link #deploymentType} set) or a bare model, i.e. a custom or hardcoded one, which is only
     * usable for fine-tuning.
     */
    public static class AzureOpenAIModel
    extends AbstractLLMConnection.BaseModel
    implements LLMModelHandle.FineTuneableModel<AzureOpenAIModel> {
        @Nullable
        public String baseModelId;
        @Nullable
        public OpenAIConnection.OpenAIModelType deploymentType;
        @Nullable
        public AzureOpenAIMaxTokensAPIMode maxTokensAPIMode;
        @Nullable
        public String deploymentId;
        public String underlyingModelName;
        public boolean useChatApi;
        public boolean isDKUFineTuned = false;
        @Nullable
        public OpenAIImageHandling imageHandlingMode;
        @Nullable
        public OpenAIChatAPI api;

        /** Builds a ref of type AZURE_OPENAI_DEPLOYMENT for deployments, AZURE_OPENAI_MODEL otherwise. */
        @Override
        public LLMStructuredRef asStructuredRef(String connection) {
            LLMStructuredRef.LLMType type = this.isDeployment()
                    ? LLMStructuredRef.LLMType.AZURE_OPENAI_DEPLOYMENT
                    : LLMStructuredRef.LLMType.AZURE_OPENAI_MODEL;
            return LLMStructuredRef.forAzureOpenAIConnection(connection, this.getId(), type);
        }

        /** A model is a deployment iff its deployment type is set. */
        public boolean isDeployment() {
            return this.deploymentType != null;
        }

        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            AbstractLLMConnection.ModelCapabilities capabilities = new AbstractLLMConnection.ModelCapabilities();
            capabilities.canGenerateCrossLanguageOutput = true;
            // System messages are only supported through the chat API
            capabilities.handlesSystemMessage = this.useChatApi;
            capabilities.supportsImageInputs = this.canBeUsedForPurpose(AbstractLLMConnection.LLMUsagePurpose.IMAGE_INPUT);
            capabilities.temperatureRange = OpenAIConnection.TEMPERATURE_RANGE;
            capabilities.topKRange = OpenAIConnection.TOP_K_RANGE;
            return capabilities;
        }

        /** A model is invalid when its id (model or deployment name) is blank. */
        @Override
        public Optional<String> getInvalidityReason() {
            if (StringUtils.isBlank(this.id)) {
                return Optional.of("Empty model/deployment name");
            }
            return Optional.empty();
        }

        /**
         * Deployments support the purposes matched by their deployment type. Bare models support only
         * fine-tuning, and only when they are fine-tunable.
         */
        @Override
        public boolean canBeUsedForPurpose(@Nonnull AbstractLLMConnection.LLMUsagePurpose purpose) {
            if (this.isDeployment()) {
                return this.deploymentType.matchesPurpose(purpose);
            }
            return this.canBeFineTuned() && purpose.equals(AbstractLLMConnection.LLMUsagePurpose.FINE_TUNING);
        }

        /**
         * Builds the model resulting from fine-tuning this one. Its id is the remote model id from
         * {@code snippetData}, its deployment id comes from the snippet's deployment when present, and
         * this model's id is recorded as the base model id. Display name, chat-API flag, embedding
         * size, token limit and costs are copied from this model.
         */
        @Override
        public AzureOpenAIModel toFineTunedModel(LLMModelSnippetData snippetData) {
            AzureOpenAIModel fineTunedModel = new AzureOpenAIModel();
            fineTunedModel.id = snippetData.llmSMInfo.remoteModelId;
            if (snippetData.deployment != null) {
                fineTunedModel.deploymentId = snippetData.deployment.deploymentId;
            }
            fineTunedModel.useChatApi = this.useChatApi;
            fineTunedModel.baseModelId = this.getId();
            fineTunedModel.isDKUFineTuned = true;
            fineTunedModel.canBeFineTuned = true;
            fineTunedModel.displayName = this.displayName;
            fineTunedModel.embeddingSize = this.embeddingSize;
            fineTunedModel.maxTokensLimit = this.maxTokensLimit;
            fineTunedModel.promptCost = this.promptCost;
            fineTunedModel.completionCost = this.completionCost;
            fineTunedModel.embeddingCost = this.embeddingCost;
            return fineTunedModel;
        }

        /** Returns the base model id for DKU fine-tuned models, and {@code null} otherwise. */
        @Override
        public String getBaseModelId() {
            return this.isDKUFineTuned ? this.baseModelId : null;
        }
    }

    /** A user-declared (non-deployment) model; its id doubles as the underlying model name. */
    public static class CustomAzureOpenAIModel
    extends AbstractLLMConnection.CustomModel<AzureOpenAIModel> {
        public String id;
        public String displayName;
        public boolean useChatAPI;

        @Override
        public AzureOpenAIModel toModel() {
            AzureOpenAIModel model = new AzureOpenAIModel();
            model.loadFromCustomModel(this);
            model.id = this.id;
            model.underlyingModelName = this.id;
            model.displayName = this.displayName;
            model.canBeFineTuned = this.canBeFineTuned;
            model.useChatApi = this.useChatAPI;
            model.deploymentType = null;
            return model;
        }
    }

    /**
     * Built-in fine-tunable base models. Each one is enabled only when its predicate on the connection
     * params (an {@code allow*Finetuning} flag) returns true.
     */
    public static enum HardcodedAzureOpenAIModel implements AbstractLLMConnection.IHardcodedConnectionModel<AzureOpenAIModel>
    {
        GPT3_DAVINCI("davinci-002", "GPT 3 - Davinci", false, true, p -> p.allowDavinciFinetuning),
        GPT3_BABBAGE("babbage-002", "GPT 3 - Babbage", false, true, p -> p.allowBabbageFinetuning),
        GPT35_TURBO_0125("gpt-35-turbo-0125", "GPT 3.5 Turbo - 0125", true, true, p -> p.allowGPT35Turbo0125Finetuning),
        GPT35_TURBO_0613("gpt-35-turbo-0613", "GPT 3.5 Turbo - 0613", true, true, p -> p.allowGPT35Turbo0613Finetuning),
        GPT35_TURBO_1106("gpt-35-turbo-1106", "GPT 3.5 Turbo - 1106", true, true, p -> p.allowGPT35Turbo1106Finetuning);

        public final String id;
        public final String displayName;
        public final boolean useChatAPI;
        public final boolean canBeFineTuned;
        public final Function<AzureOpenAIConnectionParams, Boolean> allowedModel;

        private HardcodedAzureOpenAIModel(String id, String displayName, boolean useChatAPI, boolean canBeFineTuned, Function<AzureOpenAIConnectionParams, Boolean> allowedModel) {
            this.id = id;
            this.displayName = displayName;
            this.useChatAPI = useChatAPI;
            this.canBeFineTuned = canBeFineTuned;
            this.allowedModel = allowedModel;
        }

        /** Builds a bare (non-deployment) model whose id is also its underlying model name. */
        @Override
        public AzureOpenAIModel toModel() {
            AzureOpenAIModel model = new AzureOpenAIModel();
            model.id = this.id;
            model.displayName = this.displayName;
            model.underlyingModelName = this.id;
            model.useChatApi = this.useChatAPI;
            model.canBeFineTuned = this.canBeFineTuned;
            model.deploymentType = null;
            return model;
        }
    }

    /** Authentication modes supported by this connection. */
    public static enum AuthType {
        OAUTH2_APP,
        API_KEY;

    }

    /**
     * Selects how the max-tokens limit is sent to the API. Deployments default to {@link #LEGACY}.
     */
    public static enum AzureOpenAIMaxTokensAPIMode {
        MODERN,
        LEGACY;

    }
}

