/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.deployer.apideployer.datamodel.config;

import com.dataiku.dip.containers.exec.ContainerExecUtils;
import com.dataiku.dip.deployer.apideployer.datamodel.config.AbstractFullyManagedAPIDeploymentInfra;
import com.dataiku.dip.deployer.apideployer.deployments.APIServiceDeploymentsService;
import com.dataiku.dip.deployer.apideployer.infra.ApiNodeInfraManager;
import com.dataiku.dip.deployer.apideployer.infra.SageMakerInfraManager;
import com.dataiku.dip.deployer.common.datamodel.config.AbstractDeploymentInfra;
import com.dataiku.dip.externalinfras.sagemaker.SageMakerInputValidator;
import com.dataiku.dip.externalinfras.sagemaker.SageMakerUtils;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dss.shadelibawssk2.software.amazon.awssdk.services.sagemaker.model.ProductionVariantInstanceType;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import javax.annotation.Nonnull;

/**
 * Deployment infrastructure settings for serving API services through Amazon SageMaker.
 *
 * <p>Both constructors force {@code prePushMode} to {@code DockerPrepushHookMode.ECR}. The
 * infra type reported is {@code InfraType.SAGEMAKER}, and {@link #getInfraManager(DKULogger)}
 * builds a {@link SageMakerInfraManager} over this object.
 */
public class SageMakerAPIDeploymentInfra extends AbstractFullyManagedAPIDeploymentInfra {
    /**
     * AWS region name. It is filled by {@link #trySetDefaultValues_NT()} and checked by
     * {@link SageMakerInputValidator#validateRegionName} in {@link #verifyFields}.
     */
    public String awsRegion;
    /** Execution role ARN; presumably passed to SageMaker by the infra manager (TODO confirm). */
    public String executionRoleArn;
    /**
     * Endpoint mode; defaults to {@code REAL_TIME}. It is compared with the previous value in
     * {@link #verifyFields} when deployments exist.
     */
    public AbstractFullyManagedAPIDeploymentInfra.EndpointType endpointType = AbstractFullyManagedAPIDeploymentInfra.EndpointType.REAL_TIME;
    /** KMS key id at infra level; its exact use is not visible here (TODO confirm). */
    public String kmsKeyId;
    /** Settings for real-time endpoints; never null. */
    @Nonnull
    public RealTimeConfig realTimeConfig = new RealTimeConfig();
    /** Settings for serverless endpoints; never null. */
    @Nonnull
    public ServerlessConfig serverlessConfig = new ServerlessConfig();
    /** Network isolation flag; defaults to {@code false}. */
    public Boolean enableNetworkIsolation = false;
    /** VPC settings; defaults to empty security-group and subnet lists. */
    public VPCConfig vpcConfig = new VPCConfig();
    /** Blue/green update settings; all numeric fields default to unset ({@code null}). */
    public BlueGreenUpdatePolicy blueGreenUpdatePolicy = new BlueGreenUpdatePolicy();
    /** Data capture settings; capture is off by default (see {@link DataCaptureConfig#isEnabled()}). */
    public DataCaptureConfig dataCaptureConfig = new DataCaptureConfig();
    /** Whether the endpoint name may be overridden; defaults to {@code false}. */
    public boolean allowOverrideEndpointName = false;

    /** No-arg constructor, presumably for reflective deserialization (confirm). Forces ECR pre-push. */
    private SageMakerAPIDeploymentInfra() {
        this.prePushMode = ContainerExecUtils.ContainerBaseConfig.DockerPrepushHookMode.ECR;
    }

    /**
     * Creates a SageMaker infra and forces the ECR pre-push mode.
     *
     * @param id                infra identifier, forwarded to the parent constructor
     * @param stage             deployment stage, forwarded to the parent constructor
     * @param userIdentifier    creating user, forwarded to the parent constructor
     * @param governCheckPolicy govern check policy, forwarded to the parent constructor
     */
    SageMakerAPIDeploymentInfra(String id, String stage, String userIdentifier, AbstractDeploymentInfra.GovernCheckPolicy governCheckPolicy) {
        super(id, stage, userIdentifier, governCheckPolicy);
        this.prePushMode = ContainerExecUtils.ContainerBaseConfig.DockerPrepushHookMode.ECR;
    }

    /** @return the configured {@link #endpointType} */
    @Override
    public AbstractFullyManagedAPIDeploymentInfra.EndpointType getEndpointType() {
        return this.endpointType;
    }

    /** @return always {@code InfraType.SAGEMAKER} */
    @Override
    public AbstractDeploymentInfra.InfraType getInfraType() {
        return AbstractDeploymentInfra.InfraType.SAGEMAKER;
    }

    /** Overwrites {@link #awsRegion} with {@code SageMakerUtils.getDefaultConfiguredRegion_NT()}. */
    @Override
    public void trySetDefaultValues_NT() {
        this.awsRegion = SageMakerUtils.getDefaultConfiguredRegion_NT();
    }

    /**
     * Builds a new {@link SageMakerInfraManager} for this infra.
     *
     * @param logger not used by this implementation
     * @return a fresh manager wrapping {@code this}
     */
    @Override
    public ApiNodeInfraManager getInfraManager(DKULogger logger) {
        return new SageMakerInfraManager(this);
    }

    /**
     * Validates these settings against the previously saved version of the infra.
     *
     * <p>After the parent checks run, {@code oldInfra} must itself be a SageMaker infra;
     * otherwise the error built by {@code ErrorContext.iaef} is thrown. If deployments already
     * use this infra id, the AWS region and the endpoint mode are passed, with their previous
     * values, to {@code verifyFieldExistingDeployments}. That call presumably rejects a change
     * while deployments exist (confirm in the parent class). Finally the region name is
     * validated.
     *
     * <p>NOTE(review): if the parent check lets a {@code null} {@code oldInfra} through, the
     * type check fails and building the error message dereferences {@code oldInfra.id}. That
     * raises a {@code NullPointerException}. Confirm callers never pass null.
     *
     * @param authCtx  authentication context, forwarded to the parent check
     * @param oldInfra previously saved version of this infra
     * @throws IOException as declared; presumably raised by one of the delegated checks
     */
    @Override
    public void verifyFields(AuthCtx authCtx, AbstractDeploymentInfra oldInfra) throws IOException {
        super.verifyFields(authCtx, oldInfra);
        APIServiceDeploymentsService deploymentsService = (APIServiceDeploymentsService) SpringUtils.getBean(APIServiceDeploymentsService.class);
        if (!(oldInfra instanceof SageMakerAPIDeploymentInfra)) {
            throw ErrorContext.iaef("Old infra with id %s is not a SageMaker infra.", (Object) oldInfra.id, new Object[0]);
        }
        if (deploymentsService.hasDeploymentForInfra(this.id)) {
            SageMakerAPIDeploymentInfra previous = (SageMakerAPIDeploymentInfra) oldInfra;
            this.verifyFieldExistingDeployments(previous.awsRegion, this.awsRegion, "AWS Region");
            // The Object casts pin the same overload the compiled code used.
            this.verifyFieldExistingDeployments((Object) previous.endpointType, (Object) this.endpointType, "Mode");
        }
        SageMakerInputValidator.validateRegionName(this.awsRegion);
    }

    /** Real-time endpoint settings, compared by value. */
    public static class RealTimeConfig {
        /** Instance type name; defaults to the string form of {@code ML_M5_XLARGE}. */
        public String instanceType = ProductionVariantInstanceType.ML_M5_XLARGE.toString();
        /** Initial instance count; defaults to 1. */
        public Integer initialInstanceCount = 1;
        /** Container startup health check timeout in seconds; no default ({@code null}). */
        public Integer containerStartupHealthCheckTimeoutInSeconds;

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            // Exact class match keeps equals symmetric if a subclass is ever introduced.
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            RealTimeConfig other = (RealTimeConfig) o;
            return Objects.equals(this.instanceType, other.instanceType)
                    && Objects.equals(this.initialInstanceCount, other.initialInstanceCount)
                    && Objects.equals(this.containerStartupHealthCheckTimeoutInSeconds, other.containerStartupHealthCheckTimeoutInSeconds);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.instanceType, this.initialInstanceCount, this.containerStartupHealthCheckTimeoutInSeconds);
        }

        @Override
        public String toString() {
            return "RealTimeConfig{instanceType='" + this.instanceType + "', initialInstanceCount=" + this.initialInstanceCount + ", containerStartupHealthCheckTimeoutInSeconds=" + this.containerStartupHealthCheckTimeoutInSeconds + "}";
        }
    }

    /** Serverless endpoint settings, compared by value. */
    public static class ServerlessConfig {
        /** Maximum concurrency; defaults to 5. */
        public Integer maxConcurrency = 5;
        /** Memory size in MB; defaults to 6144. */
        public Integer memorySizeInMB = 6144;

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            ServerlessConfig other = (ServerlessConfig) o;
            return Objects.equals(this.maxConcurrency, other.maxConcurrency)
                    && Objects.equals(this.memorySizeInMB, other.memorySizeInMB);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.maxConcurrency, this.memorySizeInMB);
        }

        @Override
        public String toString() {
            return "ServerlessConfig{maxConcurrency=" + this.maxConcurrency + ", memorySizeInMB=" + this.memorySizeInMB + "}";
        }
    }

    /** VPC settings: security group ids and subnets, both empty by default. */
    public static class VPCConfig {
        public List<String> securityGroupIds = new ArrayList<String>();
        public List<String> subnets = new ArrayList<String>();
    }

    /** Blue/green update settings; numeric fields are unset ({@code null}) by default. */
    public static class BlueGreenUpdatePolicy {
        public Integer maximumExecutionTimeoutInSeconds;
        public Integer terminationWaitInSeconds;
        public TrafficRoutingConfig trafficRoutingConfig = new TrafficRoutingConfig();
    }

    /** Data capture settings; capture is disabled unless input or output capture is switched on. */
    public static class DataCaptureConfig {
        public boolean captureInput = false;
        public boolean captureOutput = false;
        public CaptureContentTypeHeader captureContentTypeHeader = new CaptureContentTypeHeader();
        /** S3 connection name; presumably refers to a DSS connection (confirm). */
        public String s3Connection;
        public String s3BasePath = "";
        /** Initial sampling percentage; defaults to 100. */
        public Integer initialSamplingPercentage = 100;
        public String kmsKeyId;

        /** @return {@code true} if input capture, output capture, or both are enabled */
        public boolean isEnabled() {
            return this.captureInput || this.captureOutput;
        }
    }

    /** Content-type lists for data capture, both empty by default. */
    public static class CaptureContentTypeHeader {
        public List<String> csvContentTypes = new ArrayList<String>();
        public List<String> jsonContentTypes = new ArrayList<String>();
    }

    /** Traffic routing settings for blue/green updates. The allowed {@code type} values are not visible here. */
    public static class TrafficRoutingConfig {
        public String type;
        public CapacitySize capacitySize = new CapacitySize();
        public Integer waitIntervalInSeconds;
    }

    /** Capacity size: a type name plus a value, both unset by default. */
    public static class CapacitySize {
        public String type;
        public Integer value;
    }
}

