/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction.flow;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLFlowUtils;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.ScoringRecipeUtils;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.ml.prediction.flow.AbstractPredictionScoringRecipePayloadParams;
import com.dataiku.dip.analysis.ml.prediction.flow.EvaluationDatasetHelper;
import com.dataiku.dip.analysis.ml.prediction.flow.EvaluationRecipePayloadParams;
import com.dataiku.dip.analysis.ml.prediction.flow.PredictionRecipesMeta;
import com.dataiku.dip.analysis.ml.prediction.flow.TabularPredictionScoringRecipePayloadParams;
import com.dataiku.dip.analysis.model.ModelDetailsBase;
import com.dataiku.dip.analysis.model.core.EvaluationCustomEvaluationMetric;
import com.dataiku.dip.analysis.model.core.ResolvedCoreParams;
import com.dataiku.dip.analysis.model.prediction.CausalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.ClassicalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.ClassificationModelIntrinsicPerf;
import com.dataiku.dip.analysis.model.prediction.DeepHubPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.PredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedTimeseriesForecastingCoreParams;
import com.dataiku.dip.analysis.model.prediction.TimeseriesForecastingModelDetails;
import com.dataiku.dip.analysis.model.preprocessing.PredictionPreprocessingParams;
import com.dataiku.dip.coremodel.ConditionalOutput;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dao.RecipesDAO;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.RecipeRunnableSubgraph;
import com.dataiku.dip.dataflow.graph.FlowDataset;
import com.dataiku.dip.dataflow.graph.FlowSavedModel;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.scoring.exports.SQLPrediction;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.variables.VariablesContext;
import com.dataiku.dip.variables.VariablesService;
import java.io.File;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;
import java.util.List;
import java.util.Locale;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Computes the schemas involved in the prediction scoring and evaluation recipes: the prepared
 * input schema, the scoring output schema (input columns plus prediction-related columns) and the
 * evaluation metrics schema.
 */
public class PredictionRecipesBasicService {
    @Autowired
    protected DatasetsDAO datasetsDAO;
    @Autowired
    private VariablesService variablesService;
    @Autowired
    protected TransactionService transactionService;
    /** Output column added when the model defines prediction overrides. */
    public static final String OVERRIDE_INFO_COL = "override";
    /** Name of the prediction uncertainty output column. */
    public static final String UNCERTAINTY_COL = "prediction_uncertainty";
    /** Output column holding the lower bound of the prediction interval (added here for regression models). */
    public static final String PREDICTION_INTERVAL_COL_LOWER = "prediction_interval_lower";
    /** Output column holding the upper bound of the prediction interval (added here for regression models). */
    public static final String PREDICTION_INTERVAL_COL_UPPER = "prediction_interval_upper";
    private static final String PREDICTION_COL = "prediction";
    /** Role name of the recipe's main input and output datasets. */
    private static final String MAIN_ROLE = "main";
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.analysis.flow");

    /**
     * Computes the output schema of an evaluation recipe.
     *
     * <p>The model version used is the one resolved for the evaluation recipe (see
     * {@link #resolveVersionIdOverrideEvaluationRecipe}); the sample weight column is checked for
     * classical prediction models.
     *
     * @param activity the job activity running the recipe
     * @param desc the evaluation recipe payload
     * @param user the authentication context used to infer the preparation output schema
     * @return the evaluation output schema
     * @throws Exception propagated from model, recipe and dataset reads
     */
    public Schema getEvaluationRecipeOutputSchema_NT(JobActivity activity, EvaluationRecipePayloadParams desc, AuthCtx user) throws Exception {
        SavedModel sm = PredictionRecipesBasicService.getSavedModel(activity);
        String versionId = this.getModelVersionToUseForEvaluationRecipe(activity, sm, desc);
        PredictionModelDetails details = PredictionResultsReader.makeModelDetails(new FullModelId(sm.projectKey, sm.id, versionId));
        return this.getScoringOutputSchema_NT(activity, desc, details, sm, true, user);
    }

    /**
     * Computes the output schema of a scoring recipe, using the saved model's active version.
     *
     * @param activity the job activity running the recipe
     * @param desc the scoring recipe payload
     * @param user the authentication context used to infer the preparation output schema
     * @return the scoring output schema
     * @throws Exception propagated from model, recipe and dataset reads
     */
    public Schema getScoringRecipeOutputSchema_NT(JobActivity activity, TabularPredictionScoringRecipePayloadParams desc, AuthCtx user) throws Exception {
        SavedModel sm = PredictionRecipesBasicService.getSavedModel(activity);
        PredictionModelDetails details = PredictionResultsReader.makeModelDetails(new FullModelId(sm.projectKey, sm.id, sm.activeVersion));
        return this.getScoringOutputSchema_NT(activity, desc, details, sm, false, user);
    }

    /**
     * Computes the scoring output schema for the given model details.
     *
     * <p>When {@code sampleWeightNeeded} is true and the model is a classical prediction model, the
     * prepared schema is checked for the sample weight column; for evaluation payloads the check
     * also receives the evaluation dataset type.
     *
     * @param activity the job activity running the recipe
     * @param desc the recipe payload
     * @param details details of the model version being scored
     * @param sm the saved model
     * @param sampleWeightNeeded whether the sample weight column must be present
     * @param user the authentication context used to infer the preparation output schema
     * @return the scoring output schema
     * @throws Exception propagated from model, recipe and dataset reads, or from the sample weight check
     */
    public Schema getScoringOutputSchema_NT(JobActivity activity, TabularPredictionScoringRecipePayloadParams desc, PredictionModelDetails details, SavedModel sm, boolean sampleWeightNeeded, AuthCtx user) throws Exception {
        Schema actualPreparedSchema = this.getPreparedSchema_NT(activity, sm, details, user);
        if (sampleWeightNeeded && details instanceof ClassicalPredictionModelDetails) {
            ResolvedClassicalPredictionCoreParams coreParams = ((ClassicalPredictionModelDetails)details).coreParams;
            if (desc instanceof EvaluationRecipePayloadParams) {
                coreParams.weight.checkHasSampleWeightColumn(actualPreparedSchema, ((EvaluationRecipePayloadParams)desc).evaluationDatasetType);
            } else {
                coreParams.weight.checkHasSampleWeightColumn(actualPreparedSchema);
            }
        }
        return this.getBaseScoringOutputSchema(actualPreparedSchema, desc, details, sm);
    }

    /**
     * Computes the prepared input schema for the recipe of the given activity, using the saved
     * model's active version.
     *
     * @param activity the job activity running the recipe
     * @param user the authentication context used to infer the preparation output schema
     * @return the prepared input schema
     * @throws Exception propagated from model, recipe and dataset reads
     */
    public Schema getPreparedSchema_NT(JobActivity activity, AuthCtx user) throws Exception {
        SavedModel sm = PredictionRecipesBasicService.getSavedModel(activity);
        PredictionModelDetails details = PredictionResultsReader.makeModelDetails(new FullModelId(sm.projectKey, sm.id, sm.activeVersion));
        return this.getPreparedSchema_NT(activity, sm, details, user);
    }

    /**
     * Computes the prepared input schema: the recipe's main input dataset passed through the
     * model's preparation script, reconciled with the model's preprocessing.
     *
     * <p>For external MLflow model versions, the model's non-scored schema is used instead of
     * running the preparation script.
     */
    private Schema getPreparedSchema_NT(JobActivity activity, SavedModel sm, ModelDetailsBase details, AuthCtx user) throws Exception {
        Dataset inputDataset;
        String outputDatasetType = null;
        boolean skipInputChecks = false;
        try (Transaction t = this.transactionService.beginRead();){
            RecipeRunnableSubgraph subgraph = (RecipeRunnableSubgraph)activity.getSubgraph();
            skipInputChecks = PredictionRecipesBasicService.shouldSkipInputChecks(subgraph);
            FlowDataset inputFDS = subgraph.getSingleSourceDatasetForRole(MAIN_ROLE);
            inputDataset = inputFDS.getMandatory(this.datasetsDAO);
            // The output dataset is optional; its type only refines the inferred preparation schema.
            if (CollectionUtils.isNotEmpty(subgraph.getTargetDatasetsForRole(MAIN_ROLE))) {
                FlowDataset outputFDS = subgraph.getSingleTargetDatasetForRole(MAIN_ROLE);
                outputDatasetType = outputFDS.getMandatory(this.datasetsDAO).getType();
            }
        }
        MLFlowUtils.checkActiveVersion(sm);
        // NOTE(review): the external-MLflow check and the preparation script below always use the
        // active version, even when `details` belongs to another version (evaluation recipe with a
        // version override) — confirm this is intended.
        FullModelId fmi = new FullModelId(sm.getProjectKey(), sm.id, sm.activeVersion);
        if (fmi.isExternalMLflowModelVersion()) {
            Schema schema = fmi.getNonScoredSchema();
            return MLFlowUtils.getSchemaToUseForPreparedScoringInput(details.getPreprocessing(), schema, schema, skipInputChecks);
        }
        File activeModelFolder = MLPaths.savedModelVersionFolder(sm, sm.activeVersion);
        SerializedShakerScript script = (SerializedShakerScript)JSON.parseFile((File)new File(activeModelFolder, "script.json"), SerializedShakerScript.class);
        script.contextProjectKey = sm.projectKey;
        Schema desiredPreparedSchema = MLFlowUtils.getInferredPreparationOutputSchema_NT(sm.projectKey, inputDataset, script, outputDatasetType, user);
        return MLFlowUtils.getSchemaToUseForPreparedScoringInput(details.getPreprocessing(), details.splitDesc.schema, desiredPreparedSchema, skipInputChecks);
    }

    /**
     * Tells whether the prepared-input checks must be skipped for the recipe of the given subgraph.
     *
     * <p>Checks are skipped for recipes without an id, and for evaluation recipes whose detected
     * evaluation dataset type is not {@code CLASSIC}. A recipe that cannot be found, or an
     * evaluation recipe without a stored payload, keeps the checks enabled. Called from within the
     * read transaction opened by {@code getPreparedSchema_NT}.
     */
    private static boolean shouldSkipInputChecks(RecipeRunnableSubgraph subgraph) {
        if (subgraph.getRecipe().getId() == null) {
            return true;
        }
        RecipesDAO recipesDAO = (RecipesDAO)SpringUtils.getBean(RecipesDAO.class);
        SerializedRecipe sr = (SerializedRecipe)recipesDAO.getOrNull(subgraph.getRecipe().getProjectKey(), subgraph.getRecipe().getId());
        // getOrNull may return null: do not dereference the recipe before checking it.
        if (sr == null || !PredictionRecipesMeta.EVALUATION_META.getType().equals(sr.type)) {
            return false;
        }
        String payload = recipesDAO.getPayloadOrNull(sr.projectKey, sr.name);
        if (payload == null) {
            return false;
        }
        EvaluationRecipePayloadParams desc = (EvaluationRecipePayloadParams)JSON.parse((String)payload, EvaluationRecipePayloadParams.class);
        boolean skip = desc != null && desc.evaluationDatasetTypeDetected != EvaluationDatasetHelper.EvaluationDatasetType.CLASSIC;
        if (skip) {
            logger.debug((Object)"Evaluation Recipe configured with an API log dataset as input detected. Will skip input checks.");
        }
        return skip;
    }

    /**
     * Returns the saved model consumed by the given activity, after checking its active version.
     */
    static SavedModel getSavedModel(JobActivity activity) {
        FlowSavedModel fsm = MLFlowUtils.getSMInput(activity);
        SavedModel sm = fsm.getSavedModel();
        MLFlowUtils.checkActiveVersion(sm);
        return sm;
    }

    /**
     * Expands the project variables in the evaluation recipe's model version id; a blank id is
     * expanded as the empty string.
     *
     * @param projectKey project whose variables are used for the expansion
     * @param desc the evaluation recipe payload
     * @return the expanded version id (possibly empty)
     */
    public String resolveVersionIdOverrideEvaluationRecipe(String projectKey, EvaluationRecipePayloadParams desc) {
        VariablesContext vc = this.variablesService.getForProject(projectKey);
        return vc.expand(StringUtils.defaultIfBlank((String)desc.modelVersionId, (String)""));
    }

    /** Resolves the model version an evaluation recipe must use, given its version override. */
    private String getModelVersionToUseForEvaluationRecipe(JobActivity activity, SavedModel sm, EvaluationRecipePayloadParams desc) {
        String versionIdOverride = this.resolveVersionIdOverrideEvaluationRecipe(((RecipeRunnableSubgraph)activity.getSubgraph()).getRecipe().getProjectKey(), desc);
        return MLFlowUtils.getModelVersionToUse(sm, versionIdOverride);
    }

    /**
     * Builds the scoring output schema from the prepared schema.
     *
     * <p>When {@code scoring.filterInputColumns} is set, only {@code scoring.keptInputColumns} are
     * kept (in that order), otherwise all prepared columns are kept. The prediction columns are then
     * appended; an input column with the same name as a prediction column is replaced, with a
     * warning.
     *
     * @throws IllegalArgumentException (via {@code ErrorContext.iaef}) if a kept column is not in
     *     the prepared schema
     */
    public Schema getBaseScoringOutputSchema(Schema preparedSchema, AbstractPredictionScoringRecipePayloadParams scoring, ModelDetailsBase details, SavedModel sm) {
        Schema outSchema;
        if (scoring.filterInputColumns) {
            outSchema = new Schema();
            for (String colName : scoring.keptInputColumns) {
                SchemaColumn col = preparedSchema.getColumn(colName);
                if (col == null) {
                    throw ErrorContext.iaef((String)"Column '%s' is not in input features", (Object)colName, (Object[])new Object[0]);
                }
                outSchema.addColumn(col);
            }
        } else {
            outSchema = new Schema(preparedSchema);
        }
        Schema columnsToAdd = this.getColumnsToAddForPrediction(scoring, details, sm);
        for (SchemaColumn columnToAdd : columnsToAdd.getColumns()) {
            SchemaColumn oldColumn = outSchema.getColumn(columnToAdd.getName());
            if (oldColumn != null) {
                logger.warn((Object)("Replacing old column " + JSON.json((Object)oldColumn) + " with scoring output " + JSON.json((Object)columnToAdd)));
                outSchema.removeColumn(oldColumn.getName());
            }
            outSchema.addColumn(columnToAdd);
        }
        return outSchema;
    }

    /**
     * Tells whether the model outputs probabilities: its intrinsic perf is proba-aware or, for
     * non-external-MLflow versions, its modeling params declare probabilities.
     */
    private boolean hasProbas(FullModelId fmi, ClassicalPredictionModelDetails details) {
        boolean hasProbas = details.iperf != null && ((ClassificationModelIntrinsicPerf)details.iperf).probaAware;
        if (fmi != null && !fmi.isExternalMLflowModelVersion()) {
            hasProbas = hasProbas || details.modeling.hasProbabilities();
        }
        return hasProbas;
    }

    /**
     * Lists the columns appended to the scored rows, depending on the model's prediction type.
     *
     * <p>{@code scoring_} and {@code details_} are cast to the concrete types expected by the
     * prediction type. Model metadata columns are appended last when
     * {@code scoring_.outputModelMetadata} is set. The column order is part of the output schema.
     *
     * @throws IllegalArgumentException if the prediction type is not supported
     */
    public Schema getColumnsToAddForPrediction(AbstractPredictionScoringRecipePayloadParams scoring_, ModelDetailsBase details_, SavedModel sm) {
        FullModelId fmi = new FullModelId(sm.getProjectKey(), sm.id, sm.activeVersion);
        Schema columnsToAdd = new Schema();
        ResolvedPredictionCoreParams coreParams = (ResolvedPredictionCoreParams)details_.getCoreParams();
        switch (coreParams.prediction_type) {
            case BINARY_CLASSIFICATION: {
                ClassicalPredictionModelDetails classicalDetails = (ClassicalPredictionModelDetails)details_;
                // NOTE(review): preprocessing is read through the PredictionModelDetails static type;
                // keep it so unless ClassicalPredictionModelDetails is confirmed not to hide that field.
                PredictionModelDetails details = classicalDetails;
                TabularPredictionScoringRecipePayloadParams scoring = (TabularPredictionScoringRecipePayloadParams)scoring_;
                boolean hasProbas = this.hasProbas(fmi, classicalDetails);
                boolean isSQLEngine = "SQL".equals(scoring.engineType);
                if (scoring.outputProbabilities && hasProbas) {
                    this.addProbabilityColumns(columnsToAdd, details.preprocessing.target_remapping);
                }
                // Probability percentiles and conditional outputs are only added for non-SQL engines.
                if (!isSQLEngine && scoring.outputProbaPercentiles && hasProbas) {
                    columnsToAdd.addColumn("proba_percentile", Type.TINYINT);
                }
                columnsToAdd.addColumn(PREDICTION_COL, Type.STRING);
                if (!isSQLEngine && hasProbas) {
                    for (ConditionalOutput conditionalOutput : sm.conditionalOutputs) {
                        columnsToAdd.addColumn(conditionalOutput.name, Type.STRING);
                    }
                }
                this.addOverrideColumn(columnsToAdd, classicalDetails);
                this.addExplanations(columnsToAdd, scoring);
                break;
            }
            case MULTICLASS: {
                ClassicalPredictionModelDetails classicalDetails = (ClassicalPredictionModelDetails)details_;
                // NOTE(review): modeling/preprocessing are read through the PredictionModelDetails
                // static type; keep it so unless field hiding is ruled out.
                PredictionModelDetails details = classicalDetails;
                TabularPredictionScoringRecipePayloadParams scoring = (TabularPredictionScoringRecipePayloadParams)scoring_;
                boolean hasProbas = this.hasProbas(fmi, classicalDetails);
                // With the SQL engine, multiclass probabilities are only output for some algorithms.
                boolean notBadSQL = !"SQL".equals(scoring.engineType) || SQLPrediction.canOutputMulticlassProbas(details.modeling.algorithm);
                if (notBadSQL && scoring.outputProbabilities && hasProbas) {
                    this.addProbabilityColumns(columnsToAdd, details.preprocessing.target_remapping);
                }
                columnsToAdd.addColumn(PREDICTION_COL, Type.STRING);
                this.addOverrideColumn(columnsToAdd, classicalDetails);
                this.addExplanations(columnsToAdd, scoring);
                break;
            }
            case REGRESSION: {
                ClassicalPredictionModelDetails details = (ClassicalPredictionModelDetails)details_;
                columnsToAdd.addColumn(PREDICTION_COL, Type.FLOAT);
                // NOTE(review): explanations come before the override column here, unlike the
                // classification cases; kept as-is since column order is part of the output schema.
                this.addExplanations(columnsToAdd, (TabularPredictionScoringRecipePayloadParams)scoring_);
                this.addOverrideColumn(columnsToAdd, details);
                this.addConfidenceIntervalsColumns(columnsToAdd, details);
                break;
            }
            case DEEP_HUB_IMAGE_CLASSIFICATION: {
                DeepHubPredictionModelDetails details = (DeepHubPredictionModelDetails)details_;
                this.addProbabilityColumns(columnsToAdd, details.preprocessing.target_remapping);
                columnsToAdd.addColumn(PREDICTION_COL, Type.STRING);
                break;
            }
            case DEEP_HUB_IMAGE_OBJECT_DETECTION: {
                columnsToAdd.addColumn(PREDICTION_COL, Type.STRING);
                break;
            }
            case TIMESERIES_FORECAST: {
                columnsToAdd.addColumn("forecast", Type.DOUBLE);
                TimeseriesForecastingModelDetails details = (TimeseriesForecastingModelDetails)details_;
                TabularPredictionScoringRecipePayloadParams scoring = (TabularPredictionScoringRecipePayloadParams)scoring_;
                if (scoring.outputProbabilities && details.modeling.hasProbabilities()) {
                    // One column per forecast quantile, named without the decimal point,
                    // e.g. 0.05 -> "quantile_005".
                    DecimalFormat df = new DecimalFormat("0.####", DecimalFormatSymbols.getInstance(Locale.ENGLISH));
                    for (Double quantile : details.coreParams.quantilesToForecast) {
                        columnsToAdd.addColumn("quantile_" + df.format(quantile).replace(".", ""), Type.DOUBLE);
                    }
                }
                break;
            }
            case CAUSAL_BINARY_CLASSIFICATION: 
            case CAUSAL_REGRESSION: {
                CausalPredictionModelDetails details = (CausalPredictionModelDetails)details_;
                TabularPredictionScoringRecipePayloadParams scoring = (TabularPredictionScoringRecipePayloadParams)scoring_;
                boolean dropMissingTreatment = details.preprocessing.drop_missing_treatment_values;
                boolean isMultiValueTreatment = details.coreParams.enable_multi_treatment && details.coreParams.treatment_values.size() > 2;
                if (isMultiValueTreatment) {
                    // One effect column per non-control treatment (the empty treatment is skipped
                    // when missing treatment values are dropped).
                    for (String treatment : details.coreParams.treatment_values) {
                        if (treatment.equals(details.coreParams.control_value) || treatment.equals("") && dropMissingTreatment) continue;
                        columnsToAdd.addColumn("predicted_effect_" + treatment, Type.FLOAT);
                    }
                    columnsToAdd.addColumn("predicted_best_treatment", Type.STRING);
                } else {
                    columnsToAdd.addColumn("predicted_effect", Type.FLOAT);
                }
                if (scoring.computePropensity && details.modeling.propensityModeling.enabled) {
                    if (isMultiValueTreatment) {
                        // Unlike effects, propensities are also output for the control treatment.
                        for (String treatment : details.coreParams.treatment_values) {
                            if (treatment.equals("") && dropMissingTreatment) continue;
                            columnsToAdd.addColumn("propensity_" + treatment, Type.FLOAT);
                        }
                    } else {
                        columnsToAdd.addColumn("propensity", Type.FLOAT);
                    }
                }
                if (scoring.assignTreatment && !isMultiValueTreatment) {
                    columnsToAdd.addColumn("treatment_recommended", Type.BOOLEAN);
                }
                this.addExplanations(columnsToAdd, scoring);
                break;
            }
            default: {
                throw new IllegalArgumentException("Unsupported prediction type: " + String.valueOf((Object)coreParams.prediction_type));
            }
        }
        if (scoring_.outputModelMetadata) {
            for (SchemaColumn col : ScoringRecipeUtils.ModelMetadataUtils.getModelMetadataSchemaColumns()) {
                columnsToAdd.addColumn(col);
            }
        }
        return columnsToAdd;
    }

    /** Adds the {@link #OVERRIDE_INFO_COL} column when the model defines prediction overrides. */
    private void addOverrideColumn(Schema columnsToAdd, ClassicalPredictionModelDetails details) {
        if (details.overridesParams != null && details.overridesParams.hasOverrides()) {
            columnsToAdd.addColumn(OVERRIDE_INFO_COL, Type.STRING);
        }
    }

    /**
     * Adds the lower/upper prediction interval columns when prediction intervals are enabled.
     *
     * <p>NOTE(review): the bounds are declared as STRING while the regression prediction is FLOAT —
     * confirm this is intended.
     */
    private void addConfidenceIntervalsColumns(Schema columnsToAdd, ClassicalPredictionModelDetails details) {
        if (details.coreParams.isPredictionIntervalEnabled()) {
            columnsToAdd.addColumn(PREDICTION_INTERVAL_COL_LOWER, Type.STRING);
            columnsToAdd.addColumn(PREDICTION_INTERVAL_COL_UPPER, Type.STRING);
        }
    }

    /**
     * Computes the schema of the metrics dataset written by an evaluation recipe.
     *
     * <p>The schema contains the timeseries identifier columns (DSS-managed timeseries models with
     * per-timeseries metrics only), a {@code date} column, one DOUBLE column per requested metric,
     * custom metric and custom evaluation metric, the assertion columns when assertions are
     * computed, and an overrides metrics column when the model defines overrides.
     *
     * @throws IllegalArgumentException (via {@code ErrorContext.iaef}) if a timeseries identifier is
     *     missing from the prepared schema
     * @throws Exception propagated from model, recipe and dataset reads
     */
    public Schema getEvaluationRecipeMetricsSchema_NT(JobActivity activity, EvaluationRecipePayloadParams params, AuthCtx authCtx) throws Exception {
        Schema schema = new Schema();
        SavedModel sm = PredictionRecipesBasicService.getSavedModel(activity);
        String versionId = this.getModelVersionToUseForEvaluationRecipe(activity, sm, params);
        FullModelId fmi = new FullModelId(sm.projectKey, sm.id, versionId);
        PredictionModelDetails details = PredictionResultsReader.makeModelDetails(fmi);
        // Test the cheap flag first so that core params are only read when needed.
        if (SavedModel.SavedModelType.DSS_MANAGED.equals((Object)sm.savedModelType) && params.computePerTimeseriesMetrics) {
            ResolvedCoreParams coreParams = fmi.getResolvedCoreParams();
            if (coreParams instanceof ResolvedTimeseriesForecastingCoreParams) {
                ResolvedTimeseriesForecastingCoreParams rtfcp = (ResolvedTimeseriesForecastingCoreParams)coreParams;
                // Reuse the evaluated version's details, as the evaluation output schema does,
                // instead of re-reading the saved model and the active version's details.
                Schema preparedSchema = this.getPreparedSchema_NT(activity, sm, details, authCtx);
                for (String timeseriesIdentifier : rtfcp.timeseriesIdentifiers) {
                    SchemaColumn col = preparedSchema.getColumn(timeseriesIdentifier);
                    if (col == null) {
                        throw ErrorContext.iaef((String)"Column '%s' is not in input features, but required to compute per-timeseries metrics", (Object)timeseriesIdentifier, (Object[])new Object[0]);
                    }
                    schema.addColumn(col);
                }
            }
        }
        schema.addColumn(new SchemaColumn("date", Type.DATE));
        for (String s : params.metrics) {
            schema.addColumn(new SchemaColumn(s, Type.DOUBLE));
        }
        for (String customMetricName : params.getCustomMetrics()) {
            schema.addColumn(new SchemaColumn("custom_" + customMetricName, Type.DOUBLE));
        }
        for (EvaluationCustomEvaluationMetric m : params.customEvaluationMetrics) {
            schema.addColumn(new SchemaColumn(m.name, Type.DOUBLE));
        }
        if (params.computeAssertions) {
            schema.addColumn("passingAssertionsRatio", Type.DOUBLE);
            schema.addColumn("assertionsMetrics", Type.STRING);
        }
        this.addOverridesMetricsColumnIfRelevant(schema, details);
        logger.infoV("Metrics schema: %s", new Object[]{schema});
        return schema;
    }

    /** Adds the {@code overridesMetrics} column for classical models that define overrides. */
    private void addOverridesMetricsColumnIfRelevant(Schema schema, PredictionModelDetails details) {
        if (details instanceof ClassicalPredictionModelDetails) {
            ClassicalPredictionModelDetails classicalDetails = (ClassicalPredictionModelDetails)details;
            if (classicalDetails.overridesParams != null && classicalDetails.overridesParams.hasOverrides()) {
                schema.addColumn("overridesMetrics", Type.STRING);
            }
        }
    }

    /** Adds one {@code proba_<class>} DOUBLE column per target class, in remapping order. */
    private void addProbabilityColumns(Schema columnsToAdd, List<PredictionPreprocessingParams.MappingValue> targetRemapping) {
        for (PredictionPreprocessingParams.MappingValue targetClass : targetRemapping) {
            columnsToAdd.addColumn("proba_" + targetClass.sourceValue, Type.DOUBLE);
        }
    }

    /** Adds the {@code explanations} column when requested and supported by the scoring backend. */
    private void addExplanations(Schema columnsToAdd, TabularPredictionScoringRecipePayloadParams scoring) {
        if (scoring.outputExplanations && scoring.backendType.supportsExplanations()) {
            columnsToAdd.addColumn("explanations", Type.STRING);
        }
    }
}

