/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.apideployer.deployments;

import com.dataiku.dip.apideployer.datamodel.config.AbstractFullyManagedAPIDeploymentInfra;
import com.dataiku.dip.apideployer.datamodel.config.SageMakerAPIDeployment;
import com.dataiku.dip.apideployer.datamodel.config.SageMakerAPIDeploymentInfra;
import com.dataiku.dip.apideployer.deployments.FullyManagedAPIServiceDeploymentConfigManager;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.EC2Connection;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.datasets.FSProviderCodes;
import com.dataiku.dip.datasets.fs.ChrootUtils;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.exceptions.DSSIllegalArgumentException;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.PathUtils;
import com.dataiku.dip.variables.VariablesContext;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.CaptureContentTypeHeader;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.CaptureMode;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.CaptureOption;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.ContainerDefinition;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.ContainerMode;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.DataCaptureConfig;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.DeploymentConfig;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.ProductionVariant;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.TrafficRoutingConfig;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.VpcConfig;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.ObjectUtils;
import org.springframework.beans.factory.annotation.Autowired;

public class SageMakerDeploymentConfigManager
extends FullyManagedAPIServiceDeploymentConfigManager {
    private final SageMakerAPIDeployment deployment;
    private final SageMakerAPIDeploymentInfra infra;
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private ConnectionsDAO connectionsDAO;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.deployer.deployment.sagemaker.config-manager");

    public SageMakerDeploymentConfigManager(SageMakerAPIDeployment deployment, SageMakerAPIDeploymentInfra infra, VariablesContext vc) {
        super(vc);
        this.deployment = deployment;
        this.infra = infra;
        SpringUtils.getInstance().autowire((Object)this);
    }

    @Override
    public SageMakerAPIDeployment getDeployment() {
        return this.deployment;
    }

    @Override
    public SageMakerAPIDeploymentInfra getInfra() {
        return this.infra;
    }

    public String generateResourceName() {
        return this.generateDeploymentConfigName(63, "dss", "-");
    }

    public ContainerDefinition getContainerDefinition(String containerImageUri) {
        Map<String, String> environment = this.getEnvironmentVariablesMap();
        return (ContainerDefinition)ContainerDefinition.builder().mode(ContainerMode.SINGLE_MODEL).image(containerImageUri).environment(environment).build();
    }

    public Boolean getEnableNetworkIsolation() {
        return this.infra.enableNetworkIsolation;
    }

    public String getExecutionRoleArn() {
        return this.infra.executionRoleArn;
    }

    public VpcConfig getVpcConfig() {
        VpcConfig vpcConfig = null;
        if (!this.infra.vpcConfig.securityGroupIds.isEmpty() || !this.infra.vpcConfig.subnets.isEmpty()) {
            vpcConfig = (VpcConfig)VpcConfig.builder().securityGroupIds(this.infra.vpcConfig.securityGroupIds).subnets(this.infra.vpcConfig.subnets).build();
        }
        return vpcConfig;
    }

    private String getFullS3Uri(EC2Connection connection) throws DKUSecurityException {
        String basePath = this.getVC().expand(this.infra.dataCaptureConfig.s3BasePath);
        String pathInBucket = PathUtils.makeLeadingNoTrailing((String)PathUtils.canonical((String)ChrootUtils.getChrootedPath(connection.params.chroot, basePath, false)));
        if (pathInBucket.contains("/../") || pathInBucket.contains("/./")) {
            throw new DKUSecurityException("`.` and `..` segments not permitted in S3 path").withCode((InfoMessage.MessageCode)FSProviderCodes.ERR_FSPROVIDER_ILLEGAL_PATH);
        }
        return String.format("s3://%s%s", connection.params.getDefaultManagedBucket(), pathInBucket);
    }

    private EC2Connection getS3Connection(AuthCtx authCtx, String s3Connection) throws IOException, DKUSecurityException {
        EC2Connection connection;
        if (StringUtils.isBlank((String)s3Connection)) {
            throw new DSSIllegalArgumentException("S3 connection not specified in the infrastructure.");
        }
        try (Transaction t = this.transactionService.retrieveOrBeginRead();){
            connection = this.connectionsDAO.getMandatoryConnectionAs(authCtx, s3Connection, EC2Connection.class);
            if (!connection.isFreelyUsableBy(authCtx)) {
                throw new DKUSecurityException("User is not allowed to access connection " + s3Connection);
            }
            if (StringUtils.isBlank((String)connection.params.getDefaultManagedBucket())) {
                throw new DSSIllegalArgumentException("S3 Connection used for data capture must have a default bucket for managed datasets and folders.");
            }
        }
        return connection;
    }

    @Nullable
    public DataCaptureConfig getDataCaptureConfig(AuthCtx authCtx) throws IOException, DKUSecurityException {
        if (!this.infra.dataCaptureConfig.isEnabled()) {
            logger.debugV("Data capture is not enabled for infrastructure %s.", new Object[]{this.infra.id});
            return null;
        }
        logger.debugV("Data capture enabled for infrastructure %s. Configuring it for deployment %s.", new Object[]{this.infra.id, this.deployment.id});
        EC2Connection connection = this.getS3Connection(authCtx, this.infra.dataCaptureConfig.s3Connection);
        ArrayList<CaptureOption> captureOptions = new ArrayList<CaptureOption>();
        if (this.infra.dataCaptureConfig.captureInput) {
            captureOptions.add((CaptureOption)CaptureOption.builder().captureMode(CaptureMode.INPUT).build());
        }
        if (this.infra.dataCaptureConfig.captureOutput) {
            captureOptions.add((CaptureOption)CaptureOption.builder().captureMode(CaptureMode.OUTPUT).build());
        }
        CaptureContentTypeHeader.Builder captureContentTypeHeader = CaptureContentTypeHeader.builder();
        if (!this.infra.dataCaptureConfig.captureContentTypeHeader.csvContentTypes.isEmpty()) {
            captureContentTypeHeader.csvContentTypes(this.infra.dataCaptureConfig.captureContentTypeHeader.csvContentTypes);
        }
        if (!this.infra.dataCaptureConfig.captureContentTypeHeader.jsonContentTypes.isEmpty()) {
            captureContentTypeHeader.jsonContentTypes(this.infra.dataCaptureConfig.captureContentTypeHeader.jsonContentTypes);
        }
        DataCaptureConfig.Builder dataCaptureConfig = DataCaptureConfig.builder().enableCapture(Boolean.valueOf(true)).initialSamplingPercentage(this.infra.dataCaptureConfig.initialSamplingPercentage).captureOptions(captureOptions).destinationS3Uri(this.getFullS3Uri(connection)).captureContentTypeHeader((CaptureContentTypeHeader)captureContentTypeHeader.build());
        String kmsKeyid = this.infra.dataCaptureConfig.kmsKeyId;
        if (StringUtils.isNotBlank((String)kmsKeyid)) {
            dataCaptureConfig.kmsKeyId(kmsKeyid);
        }
        return (DataCaptureConfig)dataCaptureConfig.build();
    }

    public ProductionVariant getProductionVariant(String modelName) {
        ProductionVariant.Builder productionVariant = ProductionVariant.builder().modelName(modelName).variantName(modelName).initialVariantWeight(Float.valueOf(1.0f));
        if (AbstractFullyManagedAPIDeploymentInfra.EndpointType.REAL_TIME.equals((Object)this.infra.endpointType)) {
            SageMakerAPIDeploymentInfra.RealTimeConfig realTimeConfig = this.deployment.getRealTimeConfig(this.infra);
            productionVariant.containerStartupHealthCheckTimeoutInSeconds(realTimeConfig.containerStartupHealthCheckTimeoutInSeconds).initialInstanceCount(realTimeConfig.initialInstanceCount).instanceType(realTimeConfig.instanceType);
        } else {
            SageMakerAPIDeploymentInfra.ServerlessConfig serverlessConfig = this.deployment.getServerlessConfig(this.infra);
            productionVariant.serverlessConfig(c2 -> c2.maxConcurrency(serverlessConfig.maxConcurrency).memorySizeInMB(serverlessConfig.memorySizeInMB));
        }
        return (ProductionVariant)productionVariant.build();
    }

    public String getKmsKeyId() {
        return (String)ObjectUtils.defaultIfNull((Object)this.infra.kmsKeyId, (Object)"");
    }

    public DeploymentConfig getDeploymentConfig() {
        SageMakerAPIDeploymentInfra.TrafficRoutingConfig trafficRoutingConfig = this.infra.blueGreenUpdatePolicy.trafficRoutingConfig;
        if (StringUtils.isBlank((String)trafficRoutingConfig.type)) {
            return null;
        }
        TrafficRoutingConfig.Builder trafficRoutingConfigBuilder = TrafficRoutingConfig.builder().type(trafficRoutingConfig.type).waitIntervalInSeconds(trafficRoutingConfig.waitIntervalInSeconds);
        switch (trafficRoutingConfig.type) {
            case "CANARY": {
                trafficRoutingConfigBuilder.canarySize(s -> s.type(trafficRoutingConfig.capacitySize.type).value(trafficRoutingConfig.capacitySize.value));
                break;
            }
            case "LINEAR": {
                trafficRoutingConfigBuilder.linearStepSize(s -> s.type(trafficRoutingConfig.capacitySize.type).value(trafficRoutingConfig.capacitySize.value));
                break;
            }
        }
        return (DeploymentConfig)DeploymentConfig.builder().blueGreenUpdatePolicy(p -> p.maximumExecutionTimeoutInSeconds(this.infra.blueGreenUpdatePolicy.maximumExecutionTimeoutInSeconds).terminationWaitInSeconds(this.infra.blueGreenUpdatePolicy.terminationWaitInSeconds).trafficRoutingConfiguration((TrafficRoutingConfig)trafficRoutingConfigBuilder.build())).build();
    }
}

