/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.ModelDataUtilsService;
import com.dataiku.dip.analysis.ml.prediction.overrides.FormulaOverridesOutcomeComputer;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.clustering.ClusteringMLTask;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.analysis.model.prediction.ClassificationModelIntrinsicPerf;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.prediction.overrides.MLOverridesParams;
import com.dataiku.dip.analysis.model.preprocessing.ClassicalPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.PredictionPreprocessingParams;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.memimpl.MemRow;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.futures.FutureProgress;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.shaker.SampleBuilder;
import com.dataiku.dip.shaker.filter.FilterRequest;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.shaker.server.DataService;
import com.dataiku.dip.shaker.server.MemScriptRunner;
import com.dataiku.dip.shaker.services.TypeInferrer2;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.warnings.WarningsContext;
import com.dataiku.scoring.models.overrides.MLOverridesParamsBase;
import com.dataiku.scoring.pipelines.ClassificationResult;
import com.dataiku.scoring.pipelines.OverrideInfo;
import com.dataiku.scoring.pipelines.Result;
import com.dataiku.scoring.pipelines.overrides.OverridesOutcomeComputer;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.codec.digest.DigestUtils;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

@Service
public class PredictedDataService {
    /** Injected helper; this class uses it to apply filters/sorts and to compute detailed column analyses. */
    @Autowired
    private ModelDataUtilsService modelDataUtilsService;
    /** Type inferrer run on each loaded predicted-data table (see the {@code processFullAuto} call in this class). */
    TypeInferrer2 inferer = new TypeInferrer2();
    /** Logger for this service, obtained under the name "dku.prediction". */
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.prediction");

    /**
     * Computes the identifier of the predicted-data sample for a model and a refreshable selection.
     *
     * @param fmi      model whose predicted data is sampled
     * @param sampling selection settings; its {@code selection} and {@code _refreshTrigger} are part of the id
     * @return the sample id (see {@link #buildSampleId(FullModelId, Object, Object)} for its layout)
     * @throws IOException if the model task or model files cannot be read
     */
    public static String getSampleId(FullModelId fmi, SerializedShakerScript.RefreshableStreamableSelection sampling) throws IOException {
        logger.info((Object)("Compute sample id for " + fmi.toString() + " SAMPLING=  " + JSON.log((Object)sampling)));
        return PredictedDataService.buildSampleId(fmi, sampling.selection, sampling._refreshTrigger);
    }

    /**
     * Computes the identifier of the predicted-data sample for a model and exploration sampling settings.
     *
     * @param fmi      model whose predicted data is sampled
     * @param sampling exploration sampling settings; its {@code selection} and {@code _refreshTrigger} are part of the id
     * @return the sample id (see {@link #buildSampleId(FullModelId, Object, Object)} for its layout)
     * @throws IOException if the model task or model files cannot be read
     */
    public static String getSampleId(FullModelId fmi, SerializedShakerScript.ShakerExplorationSampleSettings sampling) throws IOException {
        logger.info((Object)("Compute sample id for " + fmi.toString() + " SAMPLING=  " + JSON.log((Object)sampling)));
        return PredictedDataService.buildSampleId(fmi, sampling.selection, sampling._refreshTrigger);
    }

    /**
     * Single place where the sample id is assembled, so both {@code getSampleId} overloads always agree.
     * <p>
     * Layout: {@code fmi.toString()}, then {@code "-cut=" + activeClassifierThreshold} only for a
     * probability-aware binary classification model, then the MD5 hex of the JSON-serialized selection,
     * then the refresh trigger rendered by string concatenation.
     *
     * @param fmi            model the sample belongs to
     * @param selection      selection object, serialized with {@code JSON.json} before hashing
     * @param refreshTrigger value appended verbatim (as text) at the end of the id
     */
    private static String buildSampleId(FullModelId fmi, Object selection, Object refreshTrigger) throws IOException {
        String sampleId = fmi.toString();
        MLTask task = fmi.getHeadMLTask();
        if (task.taskType == MLTask.MLTaskType.PREDICTION) {
            PredictionMLTask pmlTask = (PredictionMLTask)task;
            // The threshold changes the derived prediction columns, so it must be part of the id.
            if (pmlTask.predictionType == PredictionMLTask.PredictionType.BINARY_CLASSIFICATION && fmi.parseModelFile("iperf.json", ClassificationModelIntrinsicPerf.class).probaAware) {
                sampleId = sampleId + "-cut=" + fmi.parseModelFile("user_meta.json", ModelUserMeta.class).activeClassifierThreshold;
            }
        }
        return sampleId + DigestUtils.md5Hex((String)JSON.json(selection)) + refreshTrigger;
    }

    /**
     * Loads the predicted data for a model with the default derived-columns computer, then applies
     * the given filters and the script's sorting.
     * <p>
     * Runs while holding this service instance's monitor ({@code synchronized}).
     *
     * @param fmi     model whose predicted data is loaded
     * @param script  shaker script providing {@code explorationSampling} and {@code sorting}
     * @param filters filters forwarded to the 5-argument overload
     * @return the table with its report, as returned by the 5-argument overload
     */
    public synchronized MemScriptRunner.TableWithReport getUncachedFiltered_NT(FullModelId fmi, SerializedShakerScript script, FilterRequest filters) throws Exception {
        // Locals are evaluated in the same left-to-right order as the original argument list.
        SerializedShakerScript.ShakerExplorationSampleSettings samplingSettings = script.explorationSampling;
        DerivedColumnsComputer defaultComputer = PredictedDataService.getDefaultDerivedColumnsComputer(fmi);
        List<SerializedShakerScript.TableSorting> sortOrder = script.sorting;
        return this.getUncachedFiltered_NT(fmi, samplingSettings, filters, defaultComputer, sortOrder);
    }

    /**
     * Loads the unfiltered predicted data for a model using the default derived-columns computer
     * chosen by {@link #getDefaultDerivedColumnsComputer(FullModelId)}.
     * <p>
     * Runs while holding this service instance's monitor ({@code synchronized}).
     *
     * @param fmi      model whose predicted data is loaded
     * @param sampling sampling settings forwarded to the 3-argument overload
     * @return the table with its report, as returned by the 3-argument overload
     */
    public synchronized MemScriptRunner.TableWithReport getUncachedUnfiltered_NOTRANSACTION(FullModelId fmi, SerializedShakerScript.ShakerExplorationSampleSettings sampling) throws Exception {
        DerivedColumnsComputer defaultComputer = PredictedDataService.getDefaultDerivedColumnsComputer(fmi);
        return this.getUncachedUnfiltered_NOTRANSACTION(fmi, sampling, defaultComputer);
    }

    /**
     * Loads the unfiltered predicted data, then lets {@code ModelDataUtilsService.applyFiltersAndSorts}
     * process the loaded result with the given filters and sorting.
     * <p>
     * Runs while holding this service instance's monitor ({@code synchronized}).
     *
     * @param fmi                    model whose predicted data is loaded
     * @param explorationSampling    sampling settings used to load the data
     * @param filters                filters handed to {@code applyFiltersAndSorts}
     * @param derivedColumnsComputer computer run on the loaded table
     * @param sorting                sort order handed to {@code applyFiltersAndSorts}
     * @return the same {@code TableWithReport} instance that was loaded and passed to {@code applyFiltersAndSorts}
     */
    public synchronized MemScriptRunner.TableWithReport getUncachedFiltered_NT(FullModelId fmi, SerializedShakerScript.ShakerExplorationSampleSettings explorationSampling, FilterRequest filters, DerivedColumnsComputer derivedColumnsComputer, List<SerializedShakerScript.TableSorting> sorting) throws Exception {
        final MemScriptRunner.TableWithReport loaded =
                this.getUncachedUnfiltered_NOTRANSACTION(fmi, explorationSampling, derivedColumnsComputer);
        this.modelDataUtilsService.applyFiltersAndSorts(loaded, filters, sorting);
        return loaded;
    }

    /**
     * Loads the predicted-data sample of a model into memory, building the on-disk sample first if no
     * sample with the computed id exists, then runs the derived-columns computer and type inference on it.
     * <p>
     * Steps, in order: compute the sample id; look up the sample metadata; on a miss, clear existing
     * predicted-data samples and build a new one according to the task type; read the sample; run
     * {@code derivedPredictedColumnsComputer.compute}; record initial row/column counts; compact the table;
     * attach the warnings output; run type inference.
     * <p>
     * Runs while holding this service instance's monitor ({@code synchronized}).
     * <p>
     * NOTE(review): this body is raw CFR output ("Unable to fully structure code"). It is NOT valid Java
     * as written: locals are assigned without declarations, the {@code ** break;} / {@code lblNN:} lines
     * are decompiler markers, and {@code 1.$SwitchMap$...} / {@code $assertionsDisabled} are compiler
     * synthetics. The declared types of {@code splitDesc} and {@code sampleToRead} are not visible in this
     * file, so the method must be restructured against the original sources.
     *
     * @param fmi                             model whose predicted data is loaded
     * @param sampling                        sampling settings; used for the sample id and the inference cache key
     * @param derivedPredictedColumnsComputer computer run on the table right after it is read
     * @return a {@code TableWithReport} carrying the table, the sample used, initial sizes and warnings
     */
    public synchronized MemScriptRunner.TableWithReport getUncachedUnfiltered_NOTRANSACTION(FullModelId fmi, SerializedShakerScript.ShakerExplorationSampleSettings sampling, DerivedColumnsComputer derivedPredictedColumnsComputer) throws Exception {
        // Progress state for the whole method; closed in the outer finally.
        // NOTE(review): pushed with 4.0 here but updateState(5.0) is called below — confirm the intended total.
        state = FutureProgress.pushAutoCloseableState((String)"Computing", (double)4.0, (FutureProgressState.StateUnit)FutureProgressState.StateUnit.NONE);
        try {
            sampleId = PredictedDataService.getSampleId(fmi, sampling);
            PredictedDataService.logger.info((Object)("PredictedData disk sample id is " + sampleId));
            task = fmi.getHeadMLTask();
            // Fetched up front, even though it is only used when the sample has to be built.
            splitDesc = fmi.getSplitDesc();
            before = System.currentTimeMillis();
            ret = new MemScriptRunner.TableWithReport();
            // Nothing in this method passes warningsContext anywhere else; only its output is read below.
            warningsContext = new WarningsContext();
            sampleToRead = SampleBuilder.getPredictedDataSampleMeta(fmi, sampleId);
            FutureProgress.updateState((double)1.0);
            if (sampleToRead != null) {
                PredictedDataService.logger.info((Object)"Disk sample cache hit");
            } else {
                PredictedDataService.logger.info((Object)"Disk sample cache miss");
                // On a miss, existing predicted-data samples of this model are cleared before rebuilding.
                SampleBuilder.clearPredictedDataSamples(fmi);
                buildingState = FutureProgress.pushAutoCloseableState((String)"Building sample");
                try {
                    // Synthetic enum switch map on task.taskType. Judging by the builders called,
                    // case 1 is presumably CLUSTERING and case 2 PREDICTION — confirm against sources.
                    switch (1.$SwitchMap$com$dataiku$dip$analysis$model$MLTask$MLTaskType[task.taskType.ordinal()]) {
                        case 1: {
                            SampleBuilder.buildPredictedSampleForClustering(fmi, sampleId, splitDesc);
                            ** break;
lbl23:
                            // 1 sources

                            break;
                        }
                        case 2: {
                            // Time series forecasting tasks use a dedicated builder that takes no split description.
                            if (task instanceof PredictionMLTask.TimeseriesForecastingMLTask) {
                                SampleBuilder.buildPredictedSampleForTimeseriesForecast(fmi, sampleId);
                                ** break;
lbl28:
                                // 1 sources

                            } else {
                                SampleBuilder.buildPredictedSampleForPrediction(fmi, sampleId, splitDesc);
                            }
                            break;
                        }
                        // Any other task type builds nothing.
                        ** default:
lbl32:
                        // 1 sources

                        break;
                    }
                }
                finally {
                    if (buildingState != null) {
                        buildingState.close();
                    }
                }
                // Re-read the metadata of the sample that was just built.
                sampleToRead = SampleBuilder.getPredictedDataSampleMeta(fmi, sampleId);
            }
            // Decompiled form of an assertion that the sample metadata is non-null (active only with -ea).
            if (!PredictedDataService.$assertionsDisabled && sampleToRead == null) {
                throw new AssertionError();
            }
            FutureProgress.updateState((double)2.0);
            PredictedDataService.logger.info((Object)("Opening sample " + sampleId));
            ret.usedSample = sampleToRead;
            readingState = FutureProgress.pushAutoCloseableState((String)"Reading sample");
            try {
                ret.table = SampleBuilder.readPredictedSample(fmi, sampleToRead.id);
                // Derived columns are computed before the initial row/column counts are recorded.
                derivedPredictedColumnsComputer.compute(ret.table);
                ret.initialRows = ret.table.nrows();
                ret.initialCols = ret.table.ncols();
                PredictedDataService.logger.info((Object)("Reading sample done, read " + ret.table.nrows() + " rows"));
            }
            finally {
                if (readingState != null) {
                    readingState.close();
                }
            }
            FutureProgress.updateState((double)3.0);
            ret.table.compact();
            FutureProgress.updateState((double)4.0);
            PredictedDataService.logger.info((Object)("Serialized warnings " + String.valueOf(warningsContext) + " -> " + JSON.log((Object)warningsContext.getOutput())));
            ret.warnings = warningsContext.getOutput();
            inferringState = FutureProgress.pushAutoCloseableState((String)"Detecting types");
            try {
                // Inference cache key: model id string followed by the JSON form of the sampling settings.
                infererCacheKey = fmi.toString() + JSON.json((Object)sampling);
                this.inferer.processFullAuto(infererCacheKey, ret.table);
            }
            finally {
                if (inferringState != null) {
                    inferringState.close();
                }
            }
            FutureProgress.updateState((double)5.0);
            inferDone = System.currentTimeMillis();
            PredictedDataService.logger.info((Object)("PredictedDataService done time =  " + (inferDone - before)));
            var15_23 = ret;
            return var15_23;
        }
        finally {
            if (state != null) {
                state.close();
            }
        }
    }

    /**
     * Loads the unfiltered predicted data for the script's exploration sampling and delegates the
     * detailed analysis of one column to {@code ModelDataUtilsService}.
     *
     * @param fmi                 model whose predicted data is analysed
     * @param ss                  shaker script; its {@code explorationSampling} selects the sample
     * @param column              name of the column to analyse
     * @param alphanumMaxResults  forwarded unchanged to {@code ModelDataUtilsService.getDetailedColumnAnalysis}
     * @return the analysis produced by {@code ModelDataUtilsService.getDetailedColumnAnalysis}
     */
    public DataService.ColumnDetailedAnalysis getDetailedColumnAnalysis(FullModelId fmi, SerializedShakerScript ss, String column, int alphanumMaxResults) throws Exception {
        MemScriptRunner.TableWithReport loaded = this.getUncachedUnfiltered_NOTRANSACTION(fmi, ss.explorationSampling);
        MemTable sampleTable = loaded.table;
        return this.modelDataUtilsService.getDetailedColumnAnalysis(sampleTable, ss, column, alphanumMaxResults);
    }

    /**
     * Creates and initializes the per-row overrides outcome computer of a model.
     *
     * @param fmi           model whose override parameters are read
     * @param columnFactory column factory handed to the outcome computer
     * @return an initialized computer, or {@code null} when the model has no overrides or when an
     *         {@code IOException} is raised (the failure is logged as a warning and overrides are ignored)
     */
    private static FormulaOverridesOutcomeComputer.RowFormulaOutcomeComputer initOutcomeComputer(FullModelId fmi, ColumnFactory columnFactory) {
        try {
            MLOverridesParams params = fmi.getOverridesParams();
            if (!params.hasOverrides()) {
                return null;
            }
            FormulaOverridesOutcomeComputer.RowFormulaOutcomeComputer computer =
                    new FormulaOverridesOutcomeComputer.RowFormulaOutcomeComputer(params, columnFactory);
            computer.init();
            return computer;
        } catch (IOException e) {
            // Best effort: the predicted data is still served, just without overrides.
            logger.warn((Object)"Could not load overrides, ignoring them", (Throwable)e);
            return null;
        }
    }

    /**
     * Picks the derived-columns computer matching the model's head ML task.
     * <ul>
     *   <li>CLUSTERING: {@link ClusteringDerivedColumnsComputer}</li>
     *   <li>PREDICTION / BINARY_CLASSIFICATION: {@link PredictionBinaryClassificationDerivedColumnsComputer}</li>
     *   <li>PREDICTION / MULTICLASS: {@link PredictionMulticlassDerivedColumnsComputer}</li>
     *   <li>anything else: {@link NoopDerivedColumnsComputer}</li>
     * </ul>
     *
     * @param fmi model whose head ML task decides the computer
     * @return a new computer instance
     * @throws IOException if the task or the files needed by the chosen computer cannot be read
     */
    public static DerivedColumnsComputer getDefaultDerivedColumnsComputer(FullModelId fmi) throws IOException {
        MLTask headTask = fmi.getHeadMLTask();
        switch (headTask.taskType) {
            case CLUSTERING:
                return new ClusteringDerivedColumnsComputer(fmi);
            case PREDICTION: {
                PredictionMLTask predictionTask = (PredictionMLTask)headTask;
                switch (predictionTask.predictionType) {
                    case BINARY_CLASSIFICATION:
                        return new PredictionBinaryClassificationDerivedColumnsComputer(fmi);
                    case MULTICLASS:
                        return new PredictionMulticlassDerivedColumnsComputer(fmi);
                    default:
                        return new NoopDerivedColumnsComputer();
                }
            }
            default:
                return new NoopDerivedColumnsComputer();
        }
    }

    /**
     * Computes columns derived from a model's predicted data, either over a whole table or row by row.
     * <p>
     * {@link BaseDerivedColumnsComputer} implements the table-level entry point as: initialize columns,
     * compute each row, clean ephemeral columns.
     */
    public interface DerivedColumnsComputer {
        /**
         * Computes the derived columns for a whole in-memory table.
         *
         * @param table table to process
         */
        void compute(MemTable table) throws IOException;

        /**
         * Computes the derived columns for a single row.
         *
         * @param row           row to process
         * @param columnFactory column factory associated with the row
         */
        void compute(Row row, ColumnFactory columnFactory) throws IOException;

        /**
         * Prepares the columns used by {@link #compute(Row, ColumnFactory)}.
         *
         * @param columnFactory factory the columns are obtained from
         */
        void initializeComputeColumns(ColumnFactory columnFactory) throws IOException;

        /**
         * Removes columns that were only needed while computing.
         *
         * @param columnFactory factory the ephemeral columns are removed from
         */
        void cleanEphemeralComputeColumns(ColumnFactory columnFactory) throws IOException;
    }

    /**
     * Derived-columns computer for binary classification models.
     * <p>
     * For probability-aware models it re-derives, per row, the prediction from the positive-class
     * probability and the model's active threshold, optionally applies overrides, and fills the
     * {@code prediction_correct} and {@code costmatrix_gain} columns. For models that are not
     * probability-aware, {@link #compute(Row, ColumnFactory)} does nothing.
     * <p>
     * NOTE(review): when the model is not probability-aware, {@code positiveValue}/{@code negativeValue}
     * are never assigned, so {@link #initializeComputeColumns(ColumnFactory)} asks for columns named
     * {@code "proba_null"} — confirm whether {@code ColumnFactory.column} creates missing columns.
     */
    public static class PredictionBinaryClassificationDerivedColumnsComputer extends BaseDerivedColumnsComputer {
        private final FullModelId fmi;
        private final PredictionMLTask.ClassicalPredictionMLTask mlTask;
        /** Value of {@code probaAware} read from the model's {@code iperf.json}. */
        private final boolean probaAware;
        /** Source value mapped to class 1; only set for probability-aware models. */
        private String positiveValue;
        /** Source value mapped to class 0; only set for probability-aware models. */
        private String negativeValue;
        /** User meta holding the active threshold; only set for probability-aware models. */
        private ModelUserMeta mum;
        private FormulaOverridesOutcomeComputer.RowFormulaOutcomeComputer overridesOutcomeComputer;
        private Column positiveProbaCol;
        private Column negativeProbaCol;
        private Column targetCol;
        private Column predictionCol;
        private Column correctCol;
        private Column cmgCol;
        private Column overrideCol;
        private Column uncertaintyCol;
        private boolean withOverrides;

        /**
         * Reads the task, the {@code probaAware} flag and, for probability-aware models, the user meta
         * and the positive/negative class values.
         *
         * @param fmi model to compute derived columns for
         * @throws IllegalArgumentException if the model is probability-aware and its type is neither ANALYSIS nor SAVED
         */
        public PredictionBinaryClassificationDerivedColumnsComputer(FullModelId fmi) throws IOException {
            this.fmi = fmi;
            this.mlTask = (PredictionMLTask.ClassicalPredictionMLTask)fmi.getHeadMLTask();
            this.probaAware = fmi.parseModelFile("iperf.json", ClassificationModelIntrinsicPerf.class).probaAware;
            if (!this.probaAware) {
                return;
            }
            this.mum = fmi.parseModelFile("user_meta.json", ModelUserMeta.class);
            if (FullModelId.Type.ANALYSIS.equals((Object)fmi.getType())) {
                ClassicalPredictionPreprocessingParams analysisParams = this.mlTask.getPreprocessingParams();
                this.positiveValue = analysisParams.getSourceValueForMapped(1);
                this.negativeValue = analysisParams.getSourceValueForMapped(0);
            } else if (FullModelId.Type.SAVED.equals((Object)fmi.getType())) {
                ResolvedPredictionPreprocessingParams resolvedParams = (ResolvedPredictionPreprocessingParams)fmi.getResolvedPreprocessingParams();
                this.positiveValue = resolvedParams.getSourceValueForMapped(1);
                this.negativeValue = resolvedParams.getSourceValueForMapped(0);
            } else {
                throw new IllegalArgumentException("Invalid FMI type: " + fmi.getType().name());
            }
        }

        /**
         * Initializes the overrides computer and resolves every column used by
         * {@link #compute(Row, ColumnFactory)}. The override and uncertainty columns are only
         * resolved when the model has overrides.
         */
        @Override
        public void initializeComputeColumns(ColumnFactory columnFactory) throws IOException {
            this.overridesOutcomeComputer = PredictedDataService.initOutcomeComputer(this.fmi, columnFactory);
            // Column lookups are kept in their original order.
            this.positiveProbaCol = columnFactory.column("proba_" + this.positiveValue);
            this.negativeProbaCol = columnFactory.column("proba_" + this.negativeValue);
            this.targetCol = columnFactory.column(this.mlTask.targetVariable);
            this.predictionCol = columnFactory.column("prediction");
            this.correctCol = columnFactory.column("prediction_correct");
            this.cmgCol = columnFactory.column("costmatrix_gain");
            this.withOverrides = this.overridesOutcomeComputer != null;
            if (this.withOverrides) {
                this.overrideCol = columnFactory.column("override");
                this.uncertaintyCol = columnFactory.column("prediction_uncertainty");
            } else {
                this.overrideCol = null;
                this.uncertaintyCol = null;
            }
        }

        /** Deletes the uncertainty column, which is only resolved when overrides are enabled. */
        @Override
        public void cleanEphemeralComputeColumns(ColumnFactory columnFactory) throws IOException {
            if (!this.withOverrides) {
                return;
            }
            columnFactory.deleteColumn(this.uncertaintyCol.getName());
        }

        /**
         * Recomputes the prediction-related columns of one row.
         * <p>
         * Does nothing when the model is not probability-aware or when the positive-class
         * probability of the row is NaN. When an override declines the prediction, the prediction
         * and probability columns are set to {@code null} and the correctness / cost-matrix columns
         * are left untouched.
         */
        @Override
        public void compute(Row row, ColumnFactory columnFactory) {
            if (!this.probaAware) {
                return;
            }
            double positiveProba = row.getAsDoubleOrNaN(this.positiveProbaCol);
            if (Double.isNaN(positiveProba)) {
                return;
            }
            boolean predictedPositive = positiveProba > this.mum.activeClassifierThreshold;
            row.put(this.predictionCol, predictedPositive ? this.positiveValue : this.negativeValue);

            if (this.withOverrides) {
                // The uncertainty value is written only for the duration of the override evaluation.
                row.put(this.uncertaintyCol, 1.0 - Math.max(positiveProba, 1.0 - positiveProba));
                OverridesOutcomeComputer.OutcomeCandidate candidate = this.overridesOutcomeComputer.getOutcomeCandidate(row);
                row.delete(this.uncertaintyCol);
                if (candidate == null) {
                    row.put(this.overrideCol, OverrideInfo.noMatch().toJson());
                } else {
                    // Snapshot of the raw (pre-override) result, embedded in the override info.
                    double[] probas = new double[]{row.getAsDoubleOrNaN(this.negativeProbaCol), row.getAsDoubleOrNaN(this.positiveProbaCol)};
                    ClassificationResult rawResult = new ClassificationResult(row.get(this.predictionCol), probas, null);
                    MLOverridesParamsBase.MLOverride.Outcome outcome = candidate.outcome;
                    String[] classes = new String[]{this.negativeValue, this.positiveValue};
                    ClassificationResult.RawResult raw = new ClassificationResult.RawResult(rawResult, classes);

                    if (outcome.type == MLOverridesParamsBase.MLOverride.Outcome.Type.DECLINED) {
                        OverrideInfo declinedInfo = OverrideInfo.declined((String)candidate.overrideName, (Result.RawResult)raw);
                        row.put(this.predictionCol, null);
                        row.put(this.positiveProbaCol, null);
                        row.put(this.negativeProbaCol, null);
                        row.put(this.overrideCol, declinedInfo.toJson());
                        return;
                    }

                    OverrideInfo appliedInfo = new OverrideInfo(candidate.overrideName, Boolean.valueOf(!Objects.equals(outcome.category, rawResult.getPrediction())), (Result.RawResult)raw);
                    row.put(this.overrideCol, appliedInfo.toJson());
                    row.put(this.predictionCol, outcome.category);
                    // An enforced category gets probability 1 for that class and 0 for the other.
                    predictedPositive = this.positiveValue.equals(outcome.category);
                    row.put(this.positiveProbaCol, predictedPositive ? 1.0 : 0.0);
                    row.put(this.negativeProbaCol, predictedPositive ? 0.0 : 1.0);
                }
            }

            String targetVal = row.get(this.targetCol);
            // A blank target counts as "not correct".
            boolean correct = !StringUtils.isBlank(targetVal)
                    && targetVal.equals(predictedPositive ? this.positiveValue : this.negativeValue);
            row.put(this.correctCol, correct);
            row.put(this.cmgCol, this.costMatrixGain(predictedPositive, correct));
        }

        /** Returns the cost-matrix weight (tp/fp/tn/fn gain) matching the predicted class and correctness. */
        private double costMatrixGain(boolean predictedPositive, boolean correct) {
            if (predictedPositive) {
                if (correct) {
                    return this.mlTask.modeling.metrics.costMatrixWeights.tpGain;
                }
                return this.mlTask.modeling.metrics.costMatrixWeights.fpGain;
            }
            if (correct) {
                return this.mlTask.modeling.metrics.costMatrixWeights.tnGain;
            }
            return this.mlTask.modeling.metrics.costMatrixWeights.fnGain;
        }
    }

    /**
     * Derived-columns computer for multiclass models: maps the target and prediction values to class
     * ids through the target forward map and fills {@code actual_class_id}, {@code predicted_class_id}
     * and {@code prediction_correct}.
     */
    public static class PredictionMulticlassDerivedColumnsComputer extends BaseDerivedColumnsComputer {
        private final PredictionMLTask predictionMLTask;
        /** Target forward map: class value to class id. */
        private final Map<String, Integer> forwardMap;
        private Column targetCol;
        private Column predictionCol;
        private Column actualIdCol;
        private Column predictedIdCol;
        private Column correctCol;

        /**
         * Reads the target forward map from the task's preprocessing params (ANALYSIS models) or from
         * the resolved preprocessing params (SAVED models).
         *
         * @param fmi model to compute derived columns for
         * @throws IllegalArgumentException if the model type is neither ANALYSIS nor SAVED
         */
        public PredictionMulticlassDerivedColumnsComputer(FullModelId fmi) throws IOException {
            this.predictionMLTask = (PredictionMLTask)fmi.getHeadMLTask();
            if (FullModelId.Type.ANALYSIS.equals((Object)fmi.getType())) {
                PredictionPreprocessingParams analysisParams = this.predictionMLTask.getPreprocessingParams();
                this.forwardMap = analysisParams.getTargetForwardMap();
            } else if (FullModelId.Type.SAVED.equals((Object)fmi.getType())) {
                ResolvedPredictionPreprocessingParams resolvedParams = (ResolvedPredictionPreprocessingParams)fmi.getResolvedPreprocessingParams();
                this.forwardMap = resolvedParams.getTargetForwardMap();
            } else {
                throw new IllegalArgumentException("Invalid FMI type: " + fmi.getType().name());
            }
        }

        /** Resolves the target, prediction, class-id and correctness columns. */
        @Override
        public void initializeComputeColumns(ColumnFactory columnFactory) throws IOException {
            this.targetCol = columnFactory.column(this.predictionMLTask.targetVariable);
            this.predictionCol = columnFactory.column("prediction");
            this.actualIdCol = columnFactory.column("actual_class_id");
            this.predictedIdCol = columnFactory.column("predicted_class_id");
            this.correctCol = columnFactory.column("prediction_correct");
        }

        /**
         * Writes the class ids that exist in the map and marks the row correct when both ids are
         * equal, including the case where neither value maps to an id.
         */
        @Override
        public void compute(Row row, ColumnFactory columnFactory) {
            String actualLabel = row.get(this.targetCol);
            String predictedLabel = row.get(this.predictionCol);
            Integer actualId = this.forwardMap.get(actualLabel);
            Integer predictedId = this.forwardMap.get(predictedLabel);
            if (actualId != null) {
                row.put(this.actualIdCol, actualId.intValue());
            }
            if (predictedId != null) {
                row.put(this.predictedIdCol, predictedId.intValue());
            }
            // Null-safe equality: two unmapped values compare as equal.
            row.put(this.correctCol, Objects.equals(actualId, predictedId));
        }

        /** Runs the base table computation, then moves {@code fold_id} to the end when that column is present. */
        @Override
        public void compute(MemTable table) throws IOException {
            super.compute(table);
            if (!table.hasNonDeletedColumn("fold_id")) {
                return;
            }
            table.moveAtEnd("fold_id");
        }
    }

    /** Derived-columns computer that leaves every row untouched. */
    public static class NoopDerivedColumnsComputer extends BaseDerivedColumnsComputer {
        /** Intentionally does nothing. */
        @Override
        public void compute(Row row, ColumnFactory columnFactory) throws IOException {
            // no derived columns for this kind of model
        }
    }

    /**
     * Derived-columns computer for clustering models: replaces the raw cluster label with the name
     * stored in the model's user meta and fills {@code cluster_id}.
     */
    public static class ClusteringDerivedColumnsComputer extends BaseDerivedColumnsComputer {
        private final ClusteringMLTask mlTask;
        /** User meta providing {@code clusterMetas}, keyed by raw cluster label. */
        private final ModelUserMeta mum;
        private Column clusterLabelsCol;
        private Column clusterIdCol;

        /**
         * @param fmi clustering model; its head task and {@code user_meta.json} are read
         */
        public ClusteringDerivedColumnsComputer(FullModelId fmi) throws IOException {
            this.mlTask = (ClusteringMLTask)fmi.getHeadMLTask();
            this.mum = fmi.parseModelFile("user_meta.json", ModelUserMeta.class);
        }

        /** Resolves the {@code cluster_labels} and {@code cluster_id} columns. */
        @Override
        public void initializeComputeColumns(ColumnFactory columnFactory) throws IOException {
            this.clusterLabelsCol = columnFactory.column("cluster_labels");
            this.clusterIdCol = columnFactory.column("cluster_id");
        }

        /**
         * Rewrites the cluster label of one row. Rows are left untouched when isolation forest is
         * enabled, when the label is blank, or when the user meta has no entry for the label.
         */
        @Override
        public void compute(Row row, ColumnFactory columnFactory) throws IOException {
            if (this.mlTask.modeling.isolation_forest.enabled) {
                return;
            }
            String rawLabel = row.get(this.clusterLabelsCol);
            if (StringUtils.isBlank(rawLabel)) {
                return;
            }
            ModelUserMeta.ClusterMeta meta = this.mum.clusterMetas.get(rawLabel);
            if (meta == null) {
                return;
            }
            row.put(this.clusterLabelsCol, meta.name);
            if ("cluster_outliers".equals(rawLabel)) {
                // Outliers get the id -1.
                row.put(this.clusterIdCol, "-1");
            } else if (rawLabel.startsWith("cluster_")) {
                // The id is the raw label with "cluster_" removed.
                row.put(this.clusterIdCol, rawLabel.replace("cluster_", ""));
            }
        }
    }

    /**
     * Skeleton implementation of {@link DerivedColumnsComputer}: the table-level computation is
     * expressed through the per-row one, and the two column hooks default to no-ops.
     */
    public abstract static class BaseDerivedColumnsComputer implements DerivedColumnsComputer {
        /**
         * Initializes the columns, computes every row of {@code table.rows} in iteration order, then
         * cleans the ephemeral columns. Cleanup is not reached when a row computation throws.
         */
        @Override
        public void compute(MemTable table) throws IOException {
            // The table is passed as the column factory to every hook.
            final ColumnFactory columns = table;
            this.initializeComputeColumns(columns);
            for (MemRow current : table.rows) {
                this.compute(current, columns);
            }
            this.cleanEphemeralComputeColumns(columns);
        }

        /** Default: no columns to prepare. */
        @Override
        public void initializeComputeColumns(ColumnFactory columnFactory) throws IOException {
            // nothing to initialize by default
        }

        /** Default: no ephemeral columns to remove. */
        @Override
        public void cleanEphemeralComputeColumns(ColumnFactory columnFactory) throws IOException {
            // nothing to clean by default
        }
    }
}

