/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.dataflow.pipeline;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLFlowUtils;
import com.dataiku.dip.analysis.ml.MLSparkParams;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionRecipesBasicService;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionRecipesService;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionScoringJobDefBuilder;
import com.dataiku.dip.analysis.ml.prediction.flow.TabularPredictionScoringRecipePayloadParams;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.prediction.ClassicalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.pipeline.RecipePipelineHelper;
import com.dataiku.dip.dataflow.pipeline.SparkPipelineDef;
import com.dataiku.dip.dataflow.pipeline.SqlPipelineElement;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.logging.MainLoggingConfigurator;
import com.dataiku.dip.recipes.AbstractSparkRecipeParams;
import com.dataiku.dip.recipes.SqlPipelineRecipeParams;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.shaker.resources.ResourcesGatherer;
import com.dataiku.dip.spark.SparkOverrideConfig;
import com.dataiku.dip.sql.SQLDialect;
import com.dataiku.dip.sql.SparkSQLDialect;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.Pair;
import com.google.common.collect.ImmutableList;
import java.io.File;
import java.io.IOException;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Pipeline helper for prediction scoring recipes.
 *
 * <p>Decides whether a scoring recipe may take part in a Spark or SQL pipeline, and produces
 * the pipeline contribution for it:
 * <ul>
 *   <li>a SQL query ({@link #generateSqlQuery}) for SQL pipelines;</li>
 *   <li>a {@code "scoring"} Spark element ({@link #finalizeElt}) for Spark pipelines.</li>
 * </ul>
 *
 * <p>It also rewrites the recipe payload when the pipelineability flags change.
 */
public class PredictionScoringRecipePipelineHelper
extends RecipePipelineHelper {
    @Autowired
    private PredictionRecipesBasicService predictionRecipesBasicService;
    /** Parsed recipe payload; may be {@code null} if the payload parses to null. */
    private TabularPredictionScoringRecipePayloadParams params;
    /** Spark section of {@link #params}; set by {@link #initializeForSpark()}. */
    private MLSparkParams sparkParams;
    /** SQL pipeline section of {@link #params}; set by {@link #initializeForSql()}. */
    private SqlPipelineRecipeParams sqlPipelineParams;
    private boolean isValidConfiguration;
    /** Scorability of the recipe; only computed in JEK/CDE processes, otherwise {@code null}. */
    private MLFlowUtils.PredictionScoringRecipeScorability scorability;
    private String engineType;

    public PredictionScoringRecipePipelineHelper(AuthCtx authCtx, SerializedRecipe recipe, String payload, RecipePipelineHelper.PipelineType pipelineType, JobActivity jobActivity) {
        super(authCtx, recipe, payload, pipelineType, jobActivity);
    }

    /**
     * Autowires this helper, parses the payload and, in JEK/CDE processes, computes the scorability.
     *
     * <p>In JEK/CDE processes the configuration stays invalid here. The engine-specific
     * initializers ({@link #initializeForSpark()} / {@link #initializeForSql()}) then derive
     * validity from the computed scorability.
     *
     * <p>In any other process the configuration is marked valid and the engine type comes from
     * {@code engineType()}.
     *
     * <p>If scorability computation fails with a security or I/O error, the failure is logged
     * and the recipe remains not pipelineable.
     */
    @Override
    protected void initialize(JobActivity jobActivity) {
        SpringUtils.getInstance().autowire((Object)this);
        this.isValidConfiguration = false;
        this.params = (TabularPredictionScoringRecipePayloadParams)JSON.parse((String)this.payload, TabularPredictionScoringRecipePayloadParams.class);
        try {
            if (ApplicationConfigurator.getProcessType() == MainLoggingConfigurator.ProcessType.JEK || ApplicationConfigurator.getProcessType() == MainLoggingConfigurator.ProcessType.CDE) {
                if (this.params != null) {
                    MLFlowUtils.PredictionScoringRecipeStatusComputer computer = new MLFlowUtils.PredictionScoringRecipeStatusComputer(this.recipe, JSON.json((Object)this.params));
                    this.scorability = computer.getScorability_T(this.authCtx, jobActivity);
                    this.engineType = this.scorability.engineToUse.type;
                }
            } else {
                // NOTE(review): outside JEK/CDE no scorability is computed, so validity is assumed
                // here — presumably re-checked when the job actually runs; confirm.
                this.isValidConfiguration = true;
                this.engineType = this.engineType();
            }
        }
        catch (DKUSecurityException | IOException e) {
            logger.warnV(e, "Unable to compute Java scorability of %s, fallback to not pipelineable", new Object[]{this.recipe.getDisplayName()});
        }
    }

    /**
     * Reads the Spark settings from the payload.
     *
     * <p>When scorability is known, the recipe is valid only if it is Spark/Java scorable.
     */
    @Override
    protected void initializeForSpark() {
        if (this.scorability != null) {
            this.isValidConfiguration = this.scorability.isSparkJava;
        }
        this.sparkParams = this.params.sparkParams;
        this.allowPipelineStart = this.sparkParams.pipelineAllowStart;
        this.allowPipelineMerge = this.sparkParams.pipelineAllowMerge;
    }

    /**
     * Reads the SQL pipeline settings from the payload.
     *
     * <p>When scorability is known, the recipe is valid only if all of these hold:
     * <ul>
     *   <li>it is SQL compatible;</li>
     *   <li>its backend is not Vertica;</li>
     *   <li>it does not use partitioned-dispatch models.</li>
     * </ul>
     */
    @Override
    protected void initializeForSql() {
        if (this.scorability != null) {
            this.isValidConfiguration = this.scorability.isSqlCompatible && this.scorability.backendType != MLTask.BackendType.VERTICA && this.scorability.modelPartitionMode != MLFlowUtils.ModelPartitionMode.PARTITIONED_DISPATCH;
        }
        this.sqlPipelineParams = this.params.sqlPipelineParams;
        this.allowPipelineStart = this.sqlPipelineParams.pipelineAllowStart;
        this.allowPipelineMerge = this.sqlPipelineParams.pipelineAllowMerge;
    }

    /**
     * Wraps the single scoring query in a {@link SqlPipelineElement}.
     *
     * @return a one-element immutable list
     */
    @Override
    List<SqlPipelineElement> generateSqlQueries(JobActivity activity, SQLDialect dialect) throws Exception {
        SqlPipelineElement element = new SqlPipelineElement(activity, this.generateSqlQuery(activity, dialect));
        // Flags describing the generated query to the SQL pipeline machinery.
        element.sqlQueryIsCommentFree = true;
        element.sqlQueryMayContainUnionOrSelect = false;
        element.needExecutionPlan = false;
        // Pass the element uncast: casting to Object made this ImmutableList<Object>, which is
        // not assignable to List<SqlPipelineElement>.
        return ImmutableList.of(element);
    }

    /**
     * Builds the SQL scoring query for the model referenced by the computed scorability.
     *
     * <p>Loads the model details and its {@code core_params.json}, resolves the model to use
     * through {@code MLFlowUtils.getModelPartitionModeAndFullModelId}, and delegates query
     * generation to {@code PredictionRecipesService.getSqlQuery}.
     *
     * @throws IllegalStateException if no scorability was computed for this recipe
     */
    @Override
    protected String generateSqlQuery(JobActivity activity, SQLDialect dialect) throws Exception {
        if (this.scorability == null) {
            // Scorability is only computed in JEK/CDE processes and may be missing after a logged
            // failure in initialize(); fail with context instead of a bare NullPointerException.
            throw new IllegalStateException("Scorability of recipe " + this.recipe.getDisplayName() + " was not computed, cannot generate SQL query");
        }
        ClassicalPredictionModelDetails details = PredictionResultsReader.makeDetails(this.scorability.fmi);
        Schema columnsToAddForPrediction = this.predictionRecipesBasicService.getColumnsToAddForPrediction(this.params, details, this.scorability.model);
        ResolvedClassicalPredictionCoreParams coreParams = (ResolvedClassicalPredictionCoreParams)JSON.parseFile((File)new File(this.scorability.fmi.getModelFolder(), "core_params.json"), ResolvedClassicalPredictionCoreParams.class);
        Pair<MLFlowUtils.ModelPartitionMode, FullModelId> modelPartitionModeAndFullModelId = MLFlowUtils.getModelPartitionModeAndFullModelId(activity, coreParams, this.scorability.fmi);
        return PredictionRecipesService.getSqlQuery(this.authCtx, activity, this.params, (FullModelId)modelPartitionModeAndFullModelId.second, columnsToAddForPrediction, dialect);
    }

    /**
     * Updates the Spark pipelineability flags.
     *
     * @return the payload re-serialized with the new flags
     */
    @Override
    public String setSparkPipelineability(boolean allowStart, boolean allowMerge) {
        this.sparkParams.pipelineAllowStart = allowStart;
        this.sparkParams.pipelineAllowMerge = allowMerge;
        return JSON.pretty((Object)this.params);
    }

    /**
     * Updates the SQL pipelineability flags.
     *
     * @return the payload re-serialized with the new flags
     */
    @Override
    String setSqlPipelineability(boolean allowStart, boolean allowMerge) {
        this.sqlPipelineParams.pipelineAllowStart = allowStart;
        this.sqlPipelineParams.pipelineAllowMerge = allowMerge;
        return JSON.pretty((Object)this.params);
    }

    /** @return the engine type resolved in {@link #initialize}; may be {@code null} */
    @Override
    protected String getEngineType() {
        return this.engineType;
    }

    @Override
    protected boolean isValidConfiguration() {
        return this.isValidConfiguration;
    }

    /** Requires {@link #initializeForSpark()} to have run. */
    @Override
    protected boolean useGlobalMetastore() {
        return this.sparkParams.sparkUseGlobalMetastore;
    }

    /**
     * Turns the Spark pipeline element into a {@code "scoring"} element.
     *
     * <p>Also grants filesystem read ACLs on the scored model's folder.
     */
    @Override
    protected void finalizeElt(SparkPipelineDef.SparkPipelineElt elt, JobActivity activity, SparkSQLDialect dialect, ResourcesGatherer gatherer) throws Exception {
        elt.type = "scoring";
        elt.scoring = new PredictionScoringJobDefBuilder(activity, this.params, gatherer).build();
        FilesystemACLUtils.grantFSReadACLs(this.authCtx, elt.scoring.fmi.getProjectKey(), elt.scoring.fmi.getFolderEnsuringSecurity());
    }

    /** Requires {@link #initializeForSpark()} to have run. */
    @Override
    protected SparkOverrideConfig sparkConfig() {
        return this.sparkParams.sparkConf;
    }

    /** Requires {@link #initializeForSpark()} to have run. */
    @Override
    protected AbstractSparkRecipeParams.SparkExecutionEngine sparkExecutionEngine() {
        return this.sparkParams.sparkExecutionEngine;
    }
}

