/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scheduler.steps;

import com.dataiku.dip.analysis.ml.clustering.flow.ClusteringSMMgmtService;
import com.dataiku.dip.analysis.ml.llm.LLMSMMgmtService;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionSMMgmtService;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.clustering.ClusteringModelSnippetData;
import com.dataiku.dip.analysis.model.llm.LLMModelSnippetData;
import com.dataiku.dip.analysis.model.prediction.PredictionModelSnippetData;
import com.dataiku.dip.custom.PluginUsagesInspector;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.dataflow.graph.FlowComputable;
import com.dataiku.dip.scheduler.reports.ReportItem;
import com.dataiku.dip.scheduler.scenarios.Scenario;
import com.dataiku.dip.scheduler.steps.FlowComputableSpecification;
import com.dataiku.dip.scheduler.steps.Step;
import com.dataiku.dip.scheduler.steps.StepMeta;
import com.dataiku.dip.scheduler.steps.StepParams;
import com.dataiku.dip.scheduler.steps.StepParamsWithComputables;
import com.dataiku.dip.scheduler.steps.StepRun;
import com.dataiku.dip.scheduler.steps.StepRunner;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.transactions.ifaces.RWTransaction;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.Collection;
import java.util.Comparator;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

public class SetActiveModelVersionStepRunner
implements StepRunner {
    public static final StepMeta META = new StepMeta(){

        @Override
        public Class<? extends StepParams> paramsClass() {
            return SetActiveModelVersionStepParams.class;
        }

        @Override
        public String getType() {
            return "activate_version";
        }

        @Override
        public StepRunner buildRunner(Scenario scenario, Step step) {
            return new SetActiveModelVersionStepRunner(scenario, step, step.getParamsAs(SetActiveModelVersionStepParams.class));
        }

        @Override
        public String buildName(Step step) {
            SetActiveModelVersionStepParams params = step.getParamsAs(SetActiveModelVersionStepParams.class);
            return "set active version of " + params.modelId;
        }

        @Override
        public String buildId(Step step) {
            SetActiveModelVersionStepParams params = step.getParamsAs(SetActiveModelVersionStepParams.class);
            StringBuilder sb = new StringBuilder();
            sb.append("activate_model");
            if (params != null) {
                sb.append("_");
                if (StringUtils.isBlank((String)params.getProjectKey())) {
                    sb.append(params.getProjectKey());
                    sb.append(".");
                }
                sb.append(params.getModelId());
            }
            return sb.toString();
        }

        @Override
        public StepMeta.UnavailableStepInfo checkStepForDeletedPluginComponents(Scenario sc, Step step, PluginUsagesInspector pluginUsagesInspector) {
            return null;
        }
    };
    private final SetActiveModelVersionStepParams params;
    private final Scenario scenario;
    private final Step step;
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private SavedModelsDAO savedModelsDAO;
    @Autowired
    private PredictionSMMgmtService predictionSMMgmtService;
    @Autowired
    private ClusteringSMMgmtService clusteringSMMgmtService;
    @Autowired
    private LLMSMMgmtService llmSMMgmtService;
    private static DKULogger logger = DKULogger.getLogger((String)"dip.scenario.step.setmodelversion");

    SetActiveModelVersionStepRunner(Scenario scenario, Step step, SetActiveModelVersionStepParams params) {
        this.scenario = scenario;
        this.step = step;
        this.params = params;
    }

    private String getMostRecentVersion_NT(String projectKey, String modelId) throws IOException {
        SavedModel sm;
        TransactionContext.assertNoAttachedTransaction();
        try (Transaction t = this.transactionService.beginRead();){
            sm = (SavedModel)this.savedModelsDAO.getMandatory(projectKey, modelId);
        }
        if ((switch (sm.getType()) {
            case MLTask.MLTaskType.PREDICTION -> this.predictionSMMgmtService.getStatus_NT((SavedModel)sm).versions.stream().max(Comparator.comparing(v -> ((PredictionModelSnippetData)v.snippet).trainDate));
            case MLTask.MLTaskType.CLUSTERING -> this.clusteringSMMgmtService.getStatus_NT((SavedModel)sm).versions.stream().max(Comparator.comparing(v -> ((ClusteringModelSnippetData)v.snippet).sessionDate));
            case MLTask.MLTaskType.LLM_GENERIC_RAW, MLTask.MLTaskType.LLM_GENERIC_PROMPTABLE_COMPLETION, MLTask.MLTaskType.LLM_CLASSIFICATION -> LLMSMMgmtService.getStatus_NT((SavedModel)sm).versions.stream().max(Comparator.comparing(v -> ((LLMModelSnippetData)v.snippet).versionTag.getLastModifiedOn()));
            default -> throw new IllegalArgumentException(String.format("Unknown type %s", new Object[]{sm.getType()}));
        }).isPresent()) {
            return mostRecentVersion.get().versionId;
        }
        throw ErrorContext.iae((String)String.format("Model %s in project %s has no version", projectKey, modelId));
    }

    @Override
    public void run(StepRun stepRun, ReportItem.StepDone stepReportItem) throws Exception {
        String projectKey;
        logger.info((Object)("Start step " + this.step.getName()));
        String versionId = this.params.getVersionId();
        String modelId = this.params.getModelId();
        String string = projectKey = this.params.getProjectKey() != null ? this.params.getProjectKey() : this.scenario.getProjectKey();
        if (StringUtils.isBlank((String)versionId)) {
            versionId = this.getMostRecentVersion_NT(projectKey, modelId);
            logger.info((Object)String.format("No version specified, activating the most recent: %s", versionId));
        }
        try (RWTransaction t = this.transactionService.beginWriteAsLoggedInUser((AuthCtx)stepRun.getScenarioRun().getRunAsUser());){
            boolean schemaChanged;
            SavedModel sm = (SavedModel)this.savedModelsDAO.getMandatory(projectKey, this.params.getModelId());
            if (sm.getType() == MLTask.MLTaskType.PREDICTION) {
                schemaChanged = this.predictionSMMgmtService.setActive(sm, versionId);
            } else if (sm.getType() == MLTask.MLTaskType.CLUSTERING) {
                schemaChanged = this.clusteringSMMgmtService.setActive(sm, versionId);
            } else {
                throw new IllegalArgumentException("Unsupported model type:" + String.valueOf((Object)sm.getType()));
            }
            if (schemaChanged) {
                logger.info((Object)"The active version does not have the same preparation script schema as the previous one");
            }
            t.commit("Set active version of model " + this.params.projectKey + "." + this.params.modelId + " to " + versionId);
        }
        logger.info((Object)("Done step " + this.step.getName()));
        stepReportItem.withOutcome(ReportItem.Outcome.SUCCESS);
    }

    public static class SetActiveModelVersionStepParams
    implements StepParams,
    StepParamsWithComputables {
        public String projectKey;
        public String modelId;
        public String versionId;

        public String getProjectKey() {
            return this.projectKey;
        }

        public String getModelId() {
            return this.modelId;
        }

        public String getVersionId() {
            return this.versionId;
        }

        public SetActiveModelVersionStepParams withVersionId(String versionId) {
            this.versionId = versionId;
            return this;
        }

        public SetActiveModelVersionStepParams withModelId(String modelId) {
            this.modelId = modelId;
            return this;
        }

        @Override
        public Collection<FlowComputableSpecification> getComputablesSpec() {
            FlowComputableSpecification item = new FlowComputableSpecification();
            item.projectKey = this.projectKey;
            item.itemId = this.modelId;
            item.type = FlowComputable.FCType.SAVED_MODEL;
            return Lists.newArrayList((Object[])new FlowComputableSpecification[]{item});
        }
    }
}

