/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.flow;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.ScoringRecipeUtils;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionRecipesMeta;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionScoringRecipeSchemaComputer;
import com.dataiku.dip.analysis.ml.prediction.flow.TabularPredictionScoringRecipePayloadParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.connections.SnowflakeConnection;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.JobAuthCtxService;
import com.dataiku.dip.dataflow.exec.sql.NonLoopingSQLEngineVisualRecipeRunner;
import com.dataiku.dip.dataflow.graph.FlowDataset;
import com.dataiku.dip.dataflow.utils.FlowJobUtils;
import com.dataiku.dip.datasets.DatasetInspector;
import com.dataiku.dip.datasets.DatasetUtils;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.partitioning.Partition;
import com.dataiku.dip.queries.QueryBunch;
import com.dataiku.dip.recipes.consistency.RecipeCodes;
import com.dataiku.dip.scoring.exports.snowflake.JarsBuilder;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.SingleWriteTransactionTransactionService;
import com.dataiku.dip.shaker.sql.FinalSchemaCaster;
import com.dataiku.dip.shaker.sql.SQLQueryWithSchema;
import com.dataiku.dip.sql.SQLDialect;
import com.dataiku.dip.sql.SQLUtils;
import com.dataiku.dip.sql.queries.ExpressionBuilder;
import com.dataiku.dip.sql.queries.ExpressionUtils;
import com.dataiku.dip.sql.queries.SelectQueryBuilder;
import com.dataiku.dip.util.AutoDelete;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.PathUtils;
import com.dataiku.scoring.builders.Build;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;

/**
 * Runs a tabular prediction scoring recipe inside Snowflake through a temporary Java UDF.
 *
 * <p>The generated {@link QueryBunch} is made of:
 * <ul>
 *   <li>pre-queries that {@code PUT} the compiled scoring jars into the connection's Java UDF
 *       stage and {@code CREATE} a uniquely named function on top of them;</li>
 *   <li>the main query, which calls that function on an {@code OBJECT_CONSTRUCT} of the model's
 *       input features and unpacks the returned object into prediction columns;</li>
 *   <li>post-queries that {@code DROP} the function and {@code REMOVE} the staged files.</li>
 * </ul>
 */
public class SnowflakeJavaUDFPredictionRecipeSubrunner
extends NonLoopingSQLEngineVisualRecipeRunner {
    private final FullModelId fmi;
    final TabularPredictionScoringRecipePayloadParams desc;
    private final Schema predictedColumns;
    /** Generated suffix used in the UDF name, the stage sub-path and the jar build of this run. */
    private final String modelSuffix;
    private static final ExpressionBuilder.ExpressionBuilderFactory EBF = new ExpressionBuilder.ExpressionBuilderFactory();

    /**
     * Tells whether this runner can score the given model.
     *
     * @param fmi the model to check
     * @return {@code false} for partitioned base models, {@code true} otherwise
     */
    public static boolean isModelSupported(FullModelId fmi) {
        return !fmi.isPartitionedBaseModel();
    }

    SnowflakeJavaUDFPredictionRecipeSubrunner(JobActivity activity, FullModelId fmi, Schema predictedColumns, TabularPredictionScoringRecipePayloadParams desc) {
        super(activity);
        this.fmi = fmi;
        this.predictedColumns = predictedColumns;
        this.desc = desc;
        this.modelSuffix = SecretKeyGenerator.generate(12);
    }

    /**
     * Builds the pre-queries, main query and post-queries of the scoring run.
     *
     * <p>The single write transaction is stashed while the query is generated and is always
     * unstashed afterwards, even when generation fails.
     *
     * <p>Note: this class comes from decompiled code and the decompiler reported removing a
     * self-catching {@code try} block in this method, so the original exception handling may
     * have differed slightly.
     *
     * @param conn the SQL connection the recipe runs on
     * @return the statements to execute
     * @throws Exception if the query cannot be generated
     */
    @Override
    protected QueryBunch getQueryBunch(AbstractSQLConnection conn) throws Exception {
        AuthCtx authCtx = ((JobAuthCtxService)SpringUtils.getBean(JobAuthCtxService.class)).getAuthCtx();
        this.sqlQueryIsCommentFree = true;
        this.sqlQueryMayContainUnionOrSelect = false;
        this.needExecutionPlan = false;
        SingleWriteTransactionTransactionService.DetransactionalizedCallable<QueryBunch> callableSqlQuery = this.getCallableSqlQuery(authCtx, this.activity, this.desc, this.fmi, this.predictedColumns, conn);
        SingleWriteTransactionTransactionService transactionService = (SingleWriteTransactionTransactionService)SpringUtils.getBean(SingleWriteTransactionTransactionService.class);
        transactionService.stashTheSingleTransaction();
        try {
            return callableSqlQuery.call_NT();
        }
        finally {
            transactionService.unstashTheSingleTransaction();
        }
    }

    /**
     * Resolves the input/output datasets of the activity and returns a callable that computes the
     * output schema and generates the scoring statements.
     *
     * @param auth             authentication context used to build the schema computer
     * @param activity         the job activity being run
     * @param desc             the scoring recipe payload
     * @param fmi              the model to score with
     * @param predictedColumns columns produced by the prediction
     * @param conn             the SQL connection the recipe runs on
     * @return a callable producing the {@link QueryBunch}
     * @throws CodedException if the input dataset is not an SQL or Hive dataset
     * @throws Exception      if datasets or the schema computer cannot be resolved
     */
    private SingleWriteTransactionTransactionService.DetransactionalizedCallable<QueryBunch> getCallableSqlQuery(AuthCtx auth, JobActivity activity, final TabularPredictionScoringRecipePayloadParams desc, final FullModelId fmi, Schema predictedColumns, final AbstractSQLConnection conn) throws Exception {
        logger.info((Object)"Generating SQL query for Snowflake prediction scoring");
        DatasetsDAO datasetsDAO = (DatasetsDAO)SpringUtils.getBean(DatasetsDAO.class);
        FlowDataset inputFDS = activity.getSubgraph().getSourceDatasets().get(0);
        final Dataset inputDataset = inputFDS.getMandatory(datasetsDAO);
        if (!DatasetInspector.isSQLOrHive(inputDataset)) {
            throw new CodedException((InfoMessage.MessageCode)RecipeCodes.ERR_RECIPE_CANNOT_USE_ENGINE, "Cannot use the SQL engine on a non-SQL dataset.");
        }
        // The activity may have no target dataset; a blank Dataset is used later in that case.
        final Dataset outputDataset = activity.getSubgraph().getTargetsDatasets().isEmpty()
                ? null
                : activity.getSubgraph().getTargetsDatasets().get(0).getMandatory(datasetsDAO);
        final List<Partition> inputPartitions = activity.getSubgraph().getSourcePartitions(inputFDS);
        final List<String> keptColumns;
        if (desc.filterInputColumns) {
            // NOTE(review): this aliases desc.keptInputColumns, so the removals below also alter
            // the payload serialized for the schema computer. Kept as is because copying the list
            // could change the computed output schema - confirm before changing.
            keptColumns = desc.keptInputColumns;
        } else {
            keptColumns = new ArrayList<String>();
            for (SchemaColumn inputColumn : inputDataset.getSchema().columns) {
                keptColumns.add(inputColumn.getName());
            }
        }
        for (SchemaColumn predictedColumn : predictedColumns.getColumns()) {
            String predictedName = predictedColumn.getName();
            if (keptColumns.contains(predictedName)) {
                logger.warn((Object)("Column " + predictedName + " will be overwritten by prediction, not keeping it"));
                keptColumns.remove(predictedName);
            }
        }
        final PredictionScoringRecipeSchemaComputer schemaComputer = (PredictionScoringRecipeSchemaComputer)PredictionRecipesMeta.SCORING_META.buildSchemaComputer(auth, activity);
        return new SingleWriteTransactionTransactionService.DetransactionalizedCallable<QueryBunch>(){

            @Override
            public QueryBunch call_NT() throws Exception {
                schemaComputer.setPayload(JSON.json((Object)desc));
                Schema outputSchema = schemaComputer.getSchema_NT();
                SQLUtils.SQLTable table = DatasetUtils.getResolvedTableWithSparkSQLFallback(inputDataset, conn.getDialect(), null);
                Dataset workingOutputDataset = outputDataset == null ? new Dataset() : outputDataset;
                workingOutputDataset.setSchema(outputSchema);
                QueryBunch queryBunch = SnowflakeJavaUDFPredictionRecipeSubrunner.this.buildQuery(fmi, inputDataset, inputPartitions, workingOutputDataset, keptColumns, table, conn);
                logger.info((Object)("Generated query length=" + queryBunch.query.length()));
                logger.info((Object)("Generated query: " + queryBunch.query));
                return queryBunch;
            }
        };
    }

    /**
     * Compiles the scoring jar and generates the statements that upload it, create the UDF, score
     * the input table and clean up.
     *
     * <p>Configuration problems (missing pipeline type, blank stage, partitioned model) are
     * reported before the jar is compiled, so an invalid setup does not leave a compilation
     * folder behind.
     *
     * @param fmi             the model to score with
     * @param inputDataset    the dataset being scored
     * @param inputPartitions partitions of the input to restrict the query to; may be null or empty
     * @param outputDataset   dataset whose schema drives the final casts
     * @param columnsToKeep   not used by this method
     * @param table           resolved input table
     * @param connRaw         the connection; must be a {@link SnowflakeConnection}
     * @return the statements to execute
     * @throws IOException                   if the pipeline metadata has no type
     * @throws IllegalArgumentException      if no Java UDF stage is configured on the connection
     * @throws UnsupportedOperationException for partitioned base models and classification-only
     *                                       pipelines
     * @throws Exception                     if compilation or query generation fails
     */
    private QueryBunch buildQuery(FullModelId fmi, Dataset inputDataset, List<Partition> inputPartitions, Dataset outputDataset, List<String> columnsToKeep, SQLUtils.SQLTable table, AbstractSQLConnection connRaw) throws Exception {
        File resources = fmi.getModelFolder().getAbsoluteFile();
        Build.DssPipelineMeta meta = Build.pipelineMeta(resources.toURI().toURL());
        if (meta.type == null) {
            throw new IOException("Failed to parse a valid type from the dss_pipeline_meta.json");
        }
        SnowflakeConnection conn = (SnowflakeConnection)connRaw;
        final String stage = conn.params.javaUDFStage;
        if (StringUtils.isBlank(stage)) {
            throw new IllegalArgumentException("Cannot use Java functions in Snowflake: you must specify a writable stage");
        }
        if (fmi.isPartitionedBaseModel()) {
            throw new UnsupportedOperationException("unsupported partitioned models");
        }

        // NOTE(review): the compilation folder is never closed here; the PUT pre-queries below
        // reference files by absolute path, presumably inside it - confirm who deletes it.
        AutoDelete compilationFolder = FlowJobUtils.getTmpFolder("jar-compilation", "tmp");
        JarsBuilder jarsBuilder = new JarsBuilder((File)compilationFolder, fmi, this.modelSuffix);
        jarsBuilder.compileAndPackageJar(meta, this.getForcedClassifierThreshold());
        String functionName = "dssScoringRecipeRun" + this.modelSuffix;
        final String runPathInStage = PathUtils.makeLeadingNoTrailing((String)(StringUtils.defaultIfBlank((String)conn.params.javaUDFPathInStage, (String)"") + "/scoring-run-" + this.modelSuffix));

        // Only features with the INPUT role are passed to the UDF.
        ResolvedPredictionPreprocessingParams rppp = (ResolvedPredictionPreprocessingParams)JSON.parseFile(new File(resources, "rpreprocessing_params.json"), ResolvedPredictionPreprocessingParams.class);
        List<String> featureColumns = new ArrayList<String>();
        for (Map.Entry<?, ?> feature : rppp.per_feature.entrySet()) {
            if (((FeaturePreprocessingParams)feature.getValue()).role == FeaturePreprocessingParams.Role.INPUT) {
                featureColumns.add((String)feature.getKey());
            }
        }

        // Inner query: every input column plus the raw UDF result as RESULT.
        SelectQueryBuilder nestedQuery = new SelectQueryBuilder();
        nestedQuery.select("*");
        nestedQuery.from(table, "data");
        boolean restrictToPartitions = inputPartitions != null
                && !inputPartitions.isEmpty()
                && inputDataset.getPartitioningSchema() != null
                && inputDataset.getPartitioningSchema().isPartitioned();
        if (restrictToPartitions) {
            nestedQuery.where(ExpressionUtils.getPartitionFilterClause(inputDataset.getPartitioningSchema(), inputDataset, inputPartitions, (SQLDialect)conn.getDialect()));
        }
        SQLQueryWithSchema outerQuery = new SQLQueryWithSchema();
        outerQuery.setDialect(conn.getDialect());
        // Each feature becomes a 'key', "column" pair; both sides are escaped so that names
        // containing quotes or backslashes cannot break out of the literal or identifier.
        String objectArguments = featureColumns.stream()
                .map(featureColumn -> quoteStringLiteral(featureColumn) + ", " + quoteIdentifier(featureColumn))
                .collect(Collectors.joining(","));
        nestedQuery.select(EBF.expr(functionName + "( OBJECT_CONSTRUCT(" + objectArguments + ") )"), "RESULT");

        // Outer query: unpack RESULT according to the pipeline type.
        // NOTE(review): there is no default branch, so any other type adds no output column.
        switch (meta.type) {
            case REGRESSION: {
                outerQuery.select("*");
                outerQuery.select(EBF.expr("RESULT:value:prediction"), "prediction");
                break;
            }
            case CLASSIFICATION_ONLY: {
                throw new UnsupportedOperationException("unsupported classif only");
            }
            case BINARY_PROBABILISTIC:
            case MULTICLASS_PROBABILISTIC: {
                outerQuery.select("*");
                outerQuery.select(EBF.expr("RESULT:value:prediction"), "prediction");
                for (int classIndex = 0; classIndex < meta.classes.length; ++classIndex) {
                    outerQuery.select(EBF.expr("RESULT:value:probabilities[" + classIndex + "]"), "proba_" + meta.classes[classIndex]);
                }
                break;
            }
        }
        if (ScoringRecipeUtils.ModelMetadataUtils.schemaIncludesModelMetadata(outputDataset.getSchema()) != null) {
            ScoringRecipeUtils.ModelMetadataUtils.sqlAddModelMetadata(outerQuery, fmi);
        }
        outerQuery.from(nestedQuery, "__object");
        outerQuery.initWithSchema(outputDataset.getSchema().getCopy());
        String mainSQL = new FinalSchemaCaster().getCasted(outerQuery, outputDataset.getSchema()).applyInsertIntoCasts(outputDataset).toSQL(conn.getDialect());

        // Pre-queries: upload every jar, then create the function importing all of them.
        // Collected into an ArrayList because the CREATE FUNCTION statement is appended below.
        List<String> preQueries = jarsBuilder.getFilesToUploadToSnowflake().stream()
                .map(f -> String.format("PUT file://%s @%s%s/ AUTO_COMPRESS=false; \n", f.getAbsolutePath(), stage, runPathInStage))
                .collect(Collectors.toCollection(ArrayList::new));
        String imports = jarsBuilder.getFilesToUploadToSnowflake().stream()
                .map(f -> String.format("'@%s%s/%s'", stage, runPathInStage, f.getName()))
                .collect(Collectors.joining(","));
        preQueries.add(String.format("CREATE FUNCTION %s (input OBJECT) returns OBJECT LANGUAGE JAVA RUNTIME_VERSION = '17' IMPORTS = (%s) HANDLER='%s.snowflakePredictObject';\n", functionName, imports, jarsBuilder.getClassFullyQualifiedName()));

        QueryBunch queryBunch = new QueryBunch();
        queryBunch.preQueries = preQueries;
        queryBunch.postQueries.add(String.format("DROP FUNCTION %s (OBJECT);\n", functionName));
        queryBunch.postQueries.add(String.format("REMOVE @%s%s;\n", stage, runPathInStage));
        queryBunch.query = mainSQL;
        return queryBunch;
    }

    /**
     * Renders a value as a Snowflake single-quoted string constant.
     * Backslashes are doubled first, then single quotes are doubled.
     */
    private static String quoteStringLiteral(String value) {
        return "'" + value.replace("\\", "\\\\").replace("'", "''") + "'";
    }

    /**
     * Renders a name as a Snowflake double-quoted identifier, doubling embedded double quotes.
     */
    private static String quoteIdentifier(String name) {
        return "\"" + name.replace("\"", "\"\"") + "\"";
    }

    /**
     * @return the threshold forced by the recipe when it overrides the model-specified one,
     *         {@code null} otherwise
     */
    private Double getForcedClassifierThreshold() {
        return this.desc.overrideModelSpecifiedThreshold ? Double.valueOf(this.desc.forcedClassifierThreshold) : null;
    }
}

