/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scheduler.steps;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.apideployer.datamodel.actual.AbstractDeploymentBasicInfo;
import com.dataiku.dip.apideployer.datamodel.actual.AbstractDeploymentLightStatus;
import com.dataiku.dip.apideployer.datamodel.actual.PublishedApiServicePackageInfo;
import com.dataiku.dip.apideployer.deployments.APIServiceDeploymentsService;
import com.dataiku.dip.custom.PluginUsagesInspector;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DSSInternalErrorException;
import com.dataiku.dip.exceptions.UnauthorizedException;
import com.dataiku.dip.nodeclients.APIDeployerClientProxyUser;
import com.dataiku.dip.scheduler.reports.ReportItem;
import com.dataiku.dip.scheduler.scenarios.Scenario;
import com.dataiku.dip.scheduler.steps.NonFatalStepParams;
import com.dataiku.dip.scheduler.steps.Step;
import com.dataiku.dip.scheduler.steps.StepMeta;
import com.dataiku.dip.scheduler.steps.StepParams;
import com.dataiku.dip.scheduler.steps.StepRun;
import com.dataiku.dip.scheduler.steps.StepRunner;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.variables.VariablesContext;
import com.dataiku.dip.variables.VariablesService;
import com.dataiku.lambda.model.serverconfig.BundledSMVersion;
import com.dataiku.lambda.model.serverconfig.GenerationsMapping;
import com.dataiku.lambda.model.serverconfig.LambdaEndpointConfig;
import com.dataiku.lambda.model.serverconfig.PredictionEndpointConfig;
import java.io.IOException;
import java.lang.invoke.LambdaMetafactory;
import java.util.Comparator;
import java.util.List;
import java.util.function.Supplier;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

public class RetrieveActiveVersionOfDeployedModelRunner
implements StepRunner {
    private final Scenario scenario;
    private final Step step;
    private final RetrieveActiveVersionOfDeployedModelParams params;
    public static final StepMeta META = new StepMeta(){

        @Override
        public String getType() {
            return "retrieve_active_model_version_deployment";
        }

        @Override
        public Class<? extends StepParams> paramsClass() {
            return RetrieveActiveVersionOfDeployedModelParams.class;
        }

        @Override
        public StepRunner buildRunner(Scenario scenario, Step step) {
            return new RetrieveActiveVersionOfDeployedModelRunner(scenario, step, step.getParamsAs(RetrieveActiveVersionOfDeployedModelParams.class));
        }

        @Override
        public String buildName(Step step) {
            return "Retrieve active version of deployed model";
        }

        @Override
        public String buildId(Step step) {
            return "retrieve_active_version";
        }

        @Override
        public StepMeta.UnavailableStepInfo checkStepForDeletedPluginComponents(Scenario scenario, Step step, PluginUsagesInspector pluginUsagesInspector) {
            return null;
        }
    };
    @Autowired
    private APIServiceDeploymentsService deploymentsService;
    @Autowired
    private VariablesService variablesService;
    private static DKULogger logger = DKULogger.getLogger((String)"dip.scenario.retrieveactivemodelversion");

    public RetrieveActiveVersionOfDeployedModelRunner(Scenario scenario, Step step, RetrieveActiveVersionOfDeployedModelParams params) {
        this.scenario = scenario;
        this.step = step;
        this.params = params;
    }

    @Override
    public void run(StepRun stepRun, ReportItem.StepDone reportItem) throws Exception {
        VariablesContext variablesContext = this.variablesService.getContext(this.scenario.getProjectKey());
        String deploymentId = variablesContext.expand(this.params.deploymentId);
        if (StringUtils.isBlank((String)deploymentId)) {
            throw new IllegalArgumentException("A deployment ID must be specified");
        }
        if (!deploymentId.matches("^[\\w-]+$")) {
            throw new IllegalArgumentException("Deployment ID contains non authorized characters: " + deploymentId);
        }
        if (StringUtils.isBlank((String)this.params.endpointId)) {
            throw new IllegalArgumentException("An endpoint ID must be specified");
        }
        if (StringUtils.isBlank((String)this.params.variableName)) {
            throw new IllegalArgumentException("A target variable name must be specified");
        }
        GeneralSettingsDAO.GeneralSettings gs = ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN();
        GeneralSettingsDAO.DeployerRef ref = gs.deployerClientSettings.getRef();
        AbstractDeploymentLightStatus deploymentStatus = switch (ref.mode) {
            case GeneralSettingsDAO.DeployerMode.DISABLED -> throw ErrorContext.iae((String)"Deployer support is disabled");
            case GeneralSettingsDAO.DeployerMode.LOCAL -> this.getLocalDeploymentStatus(stepRun, deploymentId);
            case GeneralSettingsDAO.DeployerMode.REMOTE -> this.getRemoteDeploymentStatus(stepRun, deploymentId, ref);
            default -> throw new DSSInternalErrorException("Unknown Deployer client settings mode " + String.valueOf((Object)ref.mode));
        };
        if (!(deploymentStatus.deploymentBasicInfo instanceof AbstractDeploymentBasicInfo.AbstractAPIServiceDeploymentBasicInfo)) {
            throw new DSSInternalErrorException(deploymentId + " is not an API deployment");
        }
        AbstractDeploymentBasicInfo.AbstractAPIServiceDeploymentBasicInfo apiDeploymentBasicInfo = (AbstractDeploymentBasicInfo.AbstractAPIServiceDeploymentBasicInfo)deploymentStatus.deploymentBasicInfo;
        List<GenerationsMapping.MappingEntry> entries = apiDeploymentBasicInfo.generationsMapping.getEntries();
        String generation = ((GenerationsMapping.MappingEntry)entries.stream().max((Comparator)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;Ljava/lang/Object;)I, lambda$run$0(com.dataiku.lambda.model.serverconfig.GenerationsMapping$MappingEntry com.dataiku.lambda.model.serverconfig.GenerationsMapping$MappingEntry ), (Lcom/dataiku/lambda/model/serverconfig/GenerationsMapping$MappingEntry;Lcom/dataiku/lambda/model/serverconfig/GenerationsMapping$MappingEntry;)I)()).orElseThrow((Supplier<IllegalArgumentException>)LambdaMetafactory.metafactory(null, null, null, ()Ljava/lang/Object;, lambda$run$1(), ()Ljava/lang/IllegalArgumentException;)())).generation;
        PublishedApiServicePackageInfo publishedPackageInfo = deploymentStatus.packages.stream().filter(p -> StringUtils.equals((String)p.id, (String)generation)).findFirst().map(packageInfo -> (PublishedApiServicePackageInfo)packageInfo).orElseThrow(() -> new IllegalArgumentException("Deployment has no package matching the active version"));
        LambdaEndpointConfig endpoint = publishedPackageInfo.endpoints.stream().filter(e -> StringUtils.equals((String)e.id, (String)this.params.endpointId)).findFirst().orElseThrow(() -> new IllegalArgumentException("Endpoint not found: " + this.params.endpointId));
        if (!(endpoint instanceof PredictionEndpointConfig)) {
            throw new IllegalArgumentException("Endpoint is not a standard prediction endpoint");
        }
        PredictionEndpointConfig predictionEndpointConfig = (PredictionEndpointConfig)endpoint;
        BundledSMVersion bundledSMVersion = publishedPackageInfo.stdModels.stream().filter(m -> StringUtils.equals((String)m.id, (String)predictionEndpointConfig.modelId)).findFirst().orElseThrow(() -> new IllegalArgumentException("No model found for enpoint"));
        String smId = String.format("S-%s-%s-%s", bundledSMVersion.originalProjectKey, bundledSMVersion.originalSavedModelId, bundledSMVersion.originalSavedModelVersion);
        logger.info((Object)("Setting active model version " + smId + " of endpoint " + this.params.endpointId + " of deployment " + deploymentId + " in scenario variable " + this.params.variableName));
        stepRun.getScenarioRun().getVariables().addProperty(this.params.variableName, smId);
    }

    private AbstractDeploymentLightStatus getLocalDeploymentStatus(StepRun stepRun, String deploymentId) throws UnauthorizedException, IOException {
        DSSAuthCtx authCtx = stepRun.getScenarioRun().getRunAsUser();
        return this.deploymentsService.getLightStatusMandatoryUnsafe_NT_Check(deploymentId, authCtx);
    }

    private AbstractDeploymentLightStatus getRemoteDeploymentStatus(StepRun stepRun, String deploymentId, GeneralSettingsDAO.DeployerRef ref) throws IOException {
        try (APIDeployerClientProxyUser client = new APIDeployerClientProxyUser(ref, stepRun.getScenarioRun().runAsUser);){
            AbstractDeploymentLightStatus.APIServiceDeploymentLightStatus aPIServiceDeploymentLightStatus = client.getDeploymentLightStatus(deploymentId);
            return aPIServiceDeploymentLightStatus;
        }
    }

    private static /* synthetic */ IllegalArgumentException lambda$run$1() {
        return new IllegalArgumentException("Entries list of deployment is empty");
    }

    private static /* synthetic */ int lambda$run$0(GenerationsMapping.MappingEntry e1, GenerationsMapping.MappingEntry e2) {
        return Double.compare(e1.proba, e2.proba);
    }

    public static class RetrieveActiveVersionOfDeployedModelParams
    extends NonFatalStepParams
    implements StepParams {
        private String deploymentId;
        private String endpointId;
        private String variableName;

        RetrieveActiveVersionOfDeployedModelParams(String deploymentId) {
            this.deploymentId = deploymentId;
        }

        private RetrieveActiveVersionOfDeployedModelParams() {
        }
    }
}

