/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.clustering.flow;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLFlowUtils;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.ScoringRecipeUtils;
import com.dataiku.dip.analysis.ml.clustering.ClusteringResultsReader;
import com.dataiku.dip.analysis.ml.clustering.flow.ClusteringClusterRecipePayloadParams;
import com.dataiku.dip.analysis.ml.clustering.flow.ClusteringScoringRecipePayloadParams;
import com.dataiku.dip.analysis.model.clustering.ClusteringModelDetails;
import com.dataiku.dip.analysis.model.clustering.PreTrainClusteringModelingParams;
import com.dataiku.dip.analysis.model.clustering.ResolvedClusteringPreprocessingParams;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.RecipeRunnableSubgraph;
import com.dataiku.dip.dataflow.graph.FlowDataset;
import com.dataiku.dip.dataflow.graph.FlowSavedModel;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.JSON;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Computes the prepared-input and output schemas of the clustering flow recipes.
 *
 * <p>Two recipes are covered:
 * <ul>
 *   <li>the cluster recipe, described by a {@link ClusteringClusterRecipePayloadParams};</li>
 *   <li>the scoring recipe, described by a {@link ClusteringScoringRecipePayloadParams}
 *       together with the saved model that is an input of the activity.</li>
 * </ul>
 *
 * <p>Methods suffixed with {@code _NT} open their own read transactions (via
 * {@link TransactionService#beginRead()}) where they need one.
 * NOTE(review): the suffix presumably means "call without a transaction already held".
 * Confirm this convention.
 */
public class ClusteringRecipesBasicService {
    @Autowired
    protected DatasetsDAO datasetsDAO;
    @Autowired
    private TransactionService transactionService;

    /**
     * Computes the schema of the cluster recipe's main input after its preparation script.
     *
     * @param authCtx  authentication context used for schema inference
     * @param activity the job activity of the cluster recipe
     * @param desc     the cluster recipe payload; its {@code script} is mutated (project key set)
     * @return the prepared schema, reconciled with {@code desc.expectedPreparationOutputSchema}
     * @throws Exception if datasets cannot be resolved or schema inference fails
     */
    protected Schema getClusterRecipePreparedSchema_NT(AuthCtx authCtx, JobActivity activity, ClusteringClusterRecipePayloadParams desc) throws Exception {
        SerializedShakerScript script = this.getClusterRecipeScript_NT(activity, desc);
        return this.getPreparedSchema_NT(authCtx, activity, script, desc.preprocessing, desc.expectedPreparationOutputSchema);
    }

    /**
     * Computes the output schema of the cluster recipe.
     *
     * @param authCtx  authentication context used for schema inference
     * @param activity the job activity of the cluster recipe
     * @param desc     the cluster recipe payload
     * @param algo     the clustering algorithm; contributes its additional scoring columns
     * @return the prepared input columns (optionally filtered to {@code desc.keptInputColumns}),
     *     followed by {@code cluster_labels} and the algorithm's additional scoring columns
     * @throws Exception if schema inference fails or a kept column is missing
     */
    public Schema getClusterRecipeOutputSchema_NT(AuthCtx authCtx, JobActivity activity, ClusteringClusterRecipePayloadParams desc, PreTrainClusteringModelingParams.Algorithm algo) throws Exception {
        Schema preparationSchema = this.getClusterRecipePreparedSchema_NT(authCtx, activity, desc);
        List<String> keptColumns = desc.filterInputColumns ? desc.keptInputColumns : null;
        return this.getOutputSchema(preparationSchema, keptColumns, algo.meta.additionalScoringColumns());
    }

    /**
     * Returns the cluster recipe's preparation script, bound to the output dataset's project.
     *
     * <p>The returned object is {@code desc.script} itself. Its {@code contextProjectKey} is
     * overwritten in place with the project key of the activity's single target dataset.
     */
    private SerializedShakerScript getClusterRecipeScript_NT(JobActivity activity, ClusteringClusterRecipePayloadParams desc) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            FlowDataset outputFDS = activity.getSubgraph().getSingleTargetDataset();
            Dataset outputDataset = outputFDS.getMandatory(this.datasetsDAO);
            SerializedShakerScript script = desc.script;
            script.contextProjectKey = outputDataset.getProjectKey();
            return script;
        }
    }

    /**
     * Computes the schema of the scoring recipe's main input after the saved model's
     * preparation script.
     *
     * <p>It uses the active version of the saved model that is an input of {@code activity}.
     *
     * @param authCtx  authentication context used for schema inference
     * @param activity the job activity of the scoring recipe
     * @param desc     the scoring recipe payload (not read here; kept for interface symmetry)
     * @return the prepared schema, reconciled with the schema the model was trained on
     * @throws Exception if the saved model has no usable active version, or if model files
     *     or datasets cannot be read
     */
    protected Schema getScoringRecipePreparedSchema_NT(AuthCtx authCtx, JobActivity activity, ClusteringScoringRecipePayloadParams desc) throws Exception {
        FlowSavedModel fsm = MLFlowUtils.getSMInput(activity);
        SavedModel sm = fsm.getSavedModel();
        // Validate the active version before it is used to locate the model details;
        // previously it was only checked afterwards, while loading the script.
        MLFlowUtils.checkActiveVersion(sm);
        ClusteringModelDetails details = ClusteringResultsReader.makeDetails(new FullModelId(sm.projectKey, sm.getId(), sm.activeVersion));
        SerializedShakerScript script = this.getScoringRecipeScript(sm);
        return this.getPreparedSchema_NT(authCtx, activity, script, details.getPreprocessing(), details.splitDesc.schema);
    }

    /**
     * Computes the output schema of the scoring recipe.
     *
     * @param authCtx  authentication context used for schema inference
     * @param activity the job activity of the scoring recipe
     * @param desc     the scoring recipe payload
     * @param algo     the clustering algorithm; contributes its additional scoring columns
     * @return the prepared input columns (optionally filtered), followed by
     *     {@code cluster_labels}, the algorithm's additional scoring columns and, if
     *     {@code desc.outputModelMetadata} is set, the model metadata columns
     * @throws Exception if schema inference fails or a kept column is missing
     */
    public Schema getScoringRecipeOutputSchema_NT(AuthCtx authCtx, JobActivity activity, ClusteringScoringRecipePayloadParams desc, PreTrainClusteringModelingParams.Algorithm algo) throws Exception {
        Schema preparationSchema = this.getScoringRecipePreparedSchema_NT(authCtx, activity, desc);
        // Copy before appending: additionalScoringColumns() may return a shared or
        // unmodifiable list. Appending to it directly would either accumulate metadata
        // columns across calls or throw UnsupportedOperationException.
        List<SchemaColumn> columnsToAdd = new ArrayList<>(algo.meta.additionalScoringColumns());
        if (desc.outputModelMetadata) {
            columnsToAdd.addAll(ScoringRecipeUtils.ModelMetadataUtils.getModelMetadataSchemaColumns());
        }
        List<String> keptColumns = desc.filterInputColumns ? desc.keptInputColumns : null;
        return this.getOutputSchema(preparationSchema, keptColumns, columnsToAdd);
    }

    /**
     * Loads the preparation script stored with the saved model's active version.
     *
     * <p>The script is read from {@code script.json} in the active version folder, and its
     * {@code contextProjectKey} is set to the saved model's project.
     * The caller must already have validated {@code sm.activeVersion}.
     */
    private SerializedShakerScript getScoringRecipeScript(SavedModel sm) throws Exception {
        File activeModelFolder = MLPaths.savedModelVersionFolder(sm, sm.activeVersion);
        SerializedShakerScript script = (SerializedShakerScript) JSON.parseFile(new File(activeModelFolder, "script.json"), SerializedShakerScript.class);
        script.contextProjectKey = sm.projectKey;
        return script;
    }

    /**
     * Builds a clustering recipe output schema.
     *
     * @param preparationSchema        schema of the prepared input
     * @param keptInputColumns         names of the prepared columns to keep, in output order,
     *     or {@code null} to keep every prepared column
     * @param additionalScoringColumns columns appended after {@code cluster_labels}
     * @return a new schema: the kept prepared columns, then {@code cluster_labels}
     *     ({@link Type#STRING}), then {@code additionalScoringColumns}
     * @throws Exception the error built by {@code ErrorContext.iaef} when a kept column is
     *     absent from {@code preparationSchema}
     */
    public Schema getOutputSchema(Schema preparationSchema, List<String> keptInputColumns, List<SchemaColumn> additionalScoringColumns) throws Exception {
        Schema filteredPreparationSchema;
        if (keptInputColumns != null) {
            filteredPreparationSchema = new Schema();
            for (String colName : keptInputColumns) {
                SchemaColumn col = preparationSchema.getColumn(colName);
                if (col == null) {
                    throw ErrorContext.iaef((String) "Column '%s' is not in input features", (Object) colName, (Object[]) new Object[0]);
                }
                filteredPreparationSchema.addColumn(col);
            }
        } else {
            filteredPreparationSchema = preparationSchema;
        }
        // Copy so the caller's preparation schema is not modified by the additions below.
        Schema outSchema = new Schema(filteredPreparationSchema);
        outSchema.addColumn("cluster_labels", Type.STRING);
        for (SchemaColumn column : additionalScoringColumns) {
            outSchema.addColumn(column);
        }
        return outSchema;
    }

    /**
     * Infers the schema produced by applying {@code script} to the recipe's main input.
     *
     * <p>It then reconciles that schema with {@code schemaUsedForTrain} via
     * {@link MLFlowUtils#getSchemaToUseForPreparedScoringInput}.
     * The main input and output datasets are resolved under a read transaction, which is
     * closed before inference runs.
     */
    private Schema getPreparedSchema_NT(AuthCtx authCtx, JobActivity activity, SerializedShakerScript script, ResolvedClusteringPreprocessingParams preprocessingParams, Schema schemaUsedForTrain) throws Exception {
        RecipeRunnableSubgraph subgraph = (RecipeRunnableSubgraph) activity.getSubgraph();
        Dataset inputDataset = null;
        Dataset outputDataset = null;
        try (Transaction t = this.transactionService.beginRead()) {
            FlowDataset inputFDS = subgraph.getSingleSourceDatasetForRole("main");
            inputDataset = inputFDS.getMandatory(this.datasetsDAO);
            FlowDataset outputFDS = subgraph.getSingleTargetDatasetForRole("main");
            outputDataset = outputFDS.getMandatory(this.datasetsDAO);
        }
        String projectKey = subgraph.getRecipe().getProjectKey();
        Schema desiredPreparedSchema = MLFlowUtils.getInferredPreparationOutputSchema_NT(projectKey, inputDataset, script, outputDataset.getType(), authCtx);
        return MLFlowUtils.getSchemaToUseForPreparedScoringInput(preprocessingParams, schemaUsedForTrain, desiredPreparedSchema);
    }
}

