/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction;

import com.dataiku.dip.analysis.ml.prediction.guess.CausalPredictionGuesser;
import com.dataiku.dip.analysis.ml.prediction.guess.ClassicalPredictionGuesser;
import com.dataiku.dip.analysis.ml.prediction.guess.DeepHubPredictionGuesser;
import com.dataiku.dip.analysis.ml.prediction.guess.PredictionGuessPolicy;
import com.dataiku.dip.analysis.ml.prediction.guess.PredictionGuesser;
import com.dataiku.dip.analysis.ml.prediction.guess.TimeseriesForecastingGuesser;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.SplitParams;
import com.dataiku.dip.analysis.model.core.TrainExecutionParams;
import com.dataiku.dip.analysis.model.prediction.CalibrationParams;
import com.dataiku.dip.analysis.model.prediction.DeepHubPreTrainModelingParams;
import com.dataiku.dip.analysis.model.prediction.PartitionedModelParams;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedCausalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedDeepHubPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionCoreParams;
import com.dataiku.dip.analysis.model.prediction.ResolvedTimeseriesForecastingCoreParams;
import com.dataiku.dip.analysis.model.prediction.TimeOrderingParams;
import com.dataiku.dip.analysis.model.prediction.TimestepParams;
import com.dataiku.dip.analysis.model.prediction.UncertaintyParams;
import com.dataiku.dip.analysis.model.prediction.WeightParams;
import com.dataiku.dip.analysis.model.prediction.assertions.MLAssertionsParams;
import com.dataiku.dip.analysis.model.prediction.overrides.MLOverridesParams;
import com.dataiku.dip.analysis.model.preprocessing.CausalPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.ClassicalPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.PredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.TabularPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.TimeseriesForecastingPreprocessingParams;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.j2ts.annotations.UIModel;
import java.io.IOException;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Optional;
import javax.annotation.Nullable;

public abstract class PredictionMLTask
extends MLTask {
    /** Prediction type of this task; {@link #setPredictionType(PredictionType)} validates new values. */
    public PredictionType predictionType;
    /** Name of the variable the model predicts. */
    public String targetVariable;
    /** Split parameters of this task. */
    public SplitParams splitParams;

    /**
     * Creates a prediction task and tags it with the {@code PREDICTION} task type.
     */
    public PredictionMLTask() {
        this.taskType = MLTask.MLTaskType.PREDICTION;
    }

    @Override
    public abstract PredictionPreprocessingParams getPreprocessingParams();

    /**
     * @return the prediction types that {@link #setPredictionType(PredictionType)} accepts for this task
     */
    protected abstract EnumSet<PredictionType> getSupportedPredictionTypes();

    /**
     * Sets the prediction type after checking that this task supports it.
     *
     * @param predictionType the new prediction type
     * @throws IllegalArgumentException if {@code predictionType} is null or not part of
     *         {@link #getSupportedPredictionTypes()}
     */
    public void setPredictionType(PredictionType predictionType) {
        EnumSet<PredictionType> allowed = this.getSupportedPredictionTypes();
        // EnumSet.contains(null) is false, so a null type is rejected here as well.
        if (allowed.contains(predictionType)) {
            this.predictionType = predictionType;
            return;
        }
        throw new IllegalArgumentException(String.format("Unsupported prediction type: %s. Valid types are: %s", predictionType, allowed));
    }

    /**
     * @return whether this task currently trains partitioned models
     */
    public abstract boolean isPartitioned();

    /**
     * Returns a {@link PredictionGuesser} for this task, built on the given in-memory table.
     *
     * @param table in-memory table handed to the guesser
     */
    public abstract PredictionGuesser<? extends PredictionMLTask> getGuesser(MemTable table);

    /**
     * @return the ML assertions settings, or empty when this kind of task has none
     */
    public abstract Optional<MLAssertionsParams> getAssertionsParams();

    /**
     * Returns the ML overrides settings. The base implementation has none; subclasses that carry
     * overrides settings override this method.
     *
     * @return always empty here
     */
    public Optional<MLOverridesParams> getOverridesParams() {
        return Optional.empty();
    }

    /**
     * @return the partitioned-model settings, or empty when this kind of task has none
     */
    public abstract Optional<PartitionedModelParams> getPartitionedModel();

    /**
     * Resolves this task's settings into a {@link ResolvedPredictionCoreParams}.
     *
     * @param projectKey project key; the visible implementations forward it to
     *        {@code TrainExecutionParams.fromMLTask}
     * @throws IOException if resolving the parameters fails
     */
    @Override
    public abstract ResolvedPredictionCoreParams buildResolvedCoreParams(String projectKey) throws IOException;

    /**
     * Kind of prediction a task performs. Every capability query is answered by the
     * {@link PredictionTypeCategory} the type belongs to.
     */
    @UIModel
    public static enum PredictionType {
        BINARY_CLASSIFICATION(PredictionTypeCategory.CLASSIFICATION),
        REGRESSION(PredictionTypeCategory.REGRESSION),
        MULTICLASS(PredictionTypeCategory.CLASSIFICATION),
        DEEP_HUB_IMAGE_OBJECT_DETECTION(PredictionTypeCategory.DEEP_HUB_IMAGE),
        DEEP_HUB_IMAGE_CLASSIFICATION(PredictionTypeCategory.DEEP_HUB_IMAGE),
        TIMESERIES_FORECAST(PredictionTypeCategory.TIMESERIES_FORECAST),
        CAUSAL_BINARY_CLASSIFICATION(PredictionTypeCategory.CAUSAL),
        CAUSAL_REGRESSION(PredictionTypeCategory.CAUSAL);

        /** Category this type belongs to; the source of every capability flag below. */
        public final PredictionTypeCategory category;

        PredictionType(PredictionTypeCategory category) {
            this.category = category;
        }

        /** @return whether tasks of this type can be partitioned */
        public boolean supportsPartitioning() {
            return category.supportsPartitioning;
        }

        /** @return whether this type supports ensembles */
        public boolean supportsEnsemble() {
            return category.supportsEnsemble;
        }

        /** @return whether this type supports ML assertions */
        public boolean supportsMLAssertions() {
            return category.supportsMLAssertions;
        }

        /** @return whether this type supports ML overrides */
        public boolean supportsMLOverrides() {
            return category.supportsMLOverrides;
        }

        /** @return whether this type supports the guess/train/deploy flow */
        public boolean supportsGuessTrainDeploy() {
            return category.supportsGuessTrainDeploy;
        }

        /**
         * @return true only for the plain classification and regression categories
         */
        public boolean supportsMLflowAndExternal() {
            PredictionTypeCategory own = category;
            boolean classification = own == PredictionTypeCategory.CLASSIFICATION;
            boolean regression = own == PredictionTypeCategory.REGRESSION;
            return classification || regression;
        }

        /**
         * Tells whether settings can be copied between this type and {@code otherType}.
         * Categories that allow copying within themselves accept any type of the same category;
         * the others accept only the exact same type.
         *
         * @param otherType the other prediction type, possibly null (never compatible in that case)
         */
        public boolean canCopySettings(@Nullable PredictionType otherType) {
            if (!category.canCopySettingsWithinCategory) {
                return this == otherType;
            }
            if (otherType == null) {
                return false;
            }
            return otherType.category == category;
        }
    }

    /**
     * Prediction task estimating the effect of a treatment variable on the target
     * (causal binary classification or causal regression).
     */
    public static class CausalPredictionMLTask
    extends TabularPredictionMLTask {
        public String positiveClass;
        /** Name of the variable holding the treatment. */
        public String treatmentVariable;
        /** Value of the treatment variable identifying the control group. */
        public String controlValue;
        public List<String> treatmentValues = new ArrayList<String>();
        public boolean enableMultiTreatment;
        public CausalPredictionPreprocessingParams preprocessing;

        /** Creates a causal task with the CAUSAL_PREDICTION guess policy, the only policy it supports. */
        public CausalPredictionMLTask() {
            this.guessPolicy = PredictionGuessPolicy.CAUSAL_PREDICTION;
        }

        @Override
        public CausalPredictionPreprocessingParams getPreprocessingParams() {
            return this.preprocessing;
        }

        @Override
        protected EnumSet<PredictionType> getSupportedPredictionTypes() {
            return EnumSet.of(PredictionType.CAUSAL_BINARY_CLASSIFICATION, PredictionType.CAUSAL_REGRESSION);
        }

        @Override
        protected EnumSet<PredictionGuessPolicy> getSupportedGuessPolicies() {
            return EnumSet.of(PredictionGuessPolicy.CAUSAL_PREDICTION);
        }

        /**
         * Copies the prediction and treatment settings into a fresh resolved-params object.
         * The treatment values list is shared with the resolved params, not copied.
         */
        @Override
        public ResolvedCausalPredictionCoreParams buildResolvedCoreParams(String projectKey) throws IOException {
            ResolvedCausalPredictionCoreParams resolved = new ResolvedCausalPredictionCoreParams();
            // What is predicted
            resolved.prediction_type = this.predictionType;
            resolved.target_variable = this.targetVariable;
            resolved.positive_class = this.positiveClass;
            // How the treatment is defined
            resolved.treatment_variable = this.treatmentVariable;
            resolved.control_value = this.controlValue;
            resolved.treatment_values = this.treatmentValues;
            resolved.enable_multi_treatment = this.enableMultiTreatment;
            // Causal tasks always resolve to the PY_MEMORY backend, whatever this.backendType holds.
            resolved.backendType = MLTask.BackendType.PY_MEMORY;
            resolved.executionParams = TrainExecutionParams.fromMLTask(this, projectKey);
            resolved.diagnosticsSettings = this.diagnosticsSettings;
            return resolved;
        }

        @Override
        public CausalPredictionGuesser getGuesser(MemTable table) {
            return new CausalPredictionGuesser(this, table);
        }

        /** Causal tasks carry no ML assertions. */
        @Override
        public Optional<MLAssertionsParams> getAssertionsParams() {
            return Optional.empty();
        }
    }

    /**
     * Time series forecasting task (prediction type {@code TIMESERIES_FORECAST}).
     */
    public static class TimeseriesForecastingMLTask
    extends TabularPredictionMLTask {
        /**
         * Formatter returned by {@link #getDateTimeFormatter()}. DateTimeFormatter is immutable and
         * thread-safe, so one shared instance avoids re-parsing the pattern on every call.
         */
        private static final DateTimeFormatter DATE_TIME_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS");

        public TimeseriesForecastingPreprocessingParams preprocessing;
        public String timeVariable;
        public TimestepParams timestepParams = new TimestepParams();
        public long predictionLength;
        public List<String> timeseriesIdentifiers = new ArrayList<String>();
        // List defaults are growable ArrayLists (like timeseriesIdentifiers) rather than fixed-size
        // Arrays.asList views, so add/remove on a default-initialized task does not throw
        // UnsupportedOperationException.
        public List<Double> quantilesToForecast = new ArrayList<Double>(Arrays.asList(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9));
        public boolean customTrainTestSplit = false;
        public List<FoldInterval> customTrainTestIntervals = new ArrayList<FoldInterval>();
        public TimeseriesEvaluationParams evaluationParams = new TimeseriesEvaluationParams();
        public boolean skipTooShortTimeseriesForTraining = false;

        /** Creates a forecasting task with the TIMESERIES_DEFAULT guess policy. */
        public TimeseriesForecastingMLTask() {
            this.guessPolicy = PredictionGuessPolicy.TIMESERIES_DEFAULT;
        }

        @Override
        public TimeseriesForecastingPreprocessingParams getPreprocessingParams() {
            return this.preprocessing;
        }

        @Override
        protected EnumSet<PredictionType> getSupportedPredictionTypes() {
            return EnumSet.of(PredictionType.TIMESERIES_FORECAST);
        }

        @Override
        protected EnumSet<PredictionGuessPolicy> getSupportedGuessPolicies() {
            return EnumSet.of(PredictionGuessPolicy.TIMESERIES_DEFAULT, PredictionGuessPolicy.TIMESERIES_STATISTICAL, PredictionGuessPolicy.TIMESERIES_DEEP_LEARNING);
        }

        /**
         * Delegates to the guesser of the current guess policy.
         * NOTE(review): throws NullPointerException if {@code guessPolicy} has been set to null.
         */
        @Override
        public PredictionGuesser<TimeseriesForecastingMLTask> getGuesser(MemTable table) {
            return (TimeseriesForecastingGuesser)this.guessPolicy.meta.getGuesser(this, table);
        }

        /** Forecasting tasks carry no ML assertions. */
        @Override
        public Optional<MLAssertionsParams> getAssertionsParams() {
            return Optional.empty();
        }

        /**
         * Copies the forecasting settings into a fresh resolved-params object. The lists and
         * parameter objects are shared with the resolved params, not copied.
         */
        @Override
        public ResolvedTimeseriesForecastingCoreParams buildResolvedCoreParams(String projectKey) throws IOException {
            ResolvedTimeseriesForecastingCoreParams rtfcp = new ResolvedTimeseriesForecastingCoreParams();
            rtfcp.prediction_type = this.predictionType;
            rtfcp.target_variable = this.targetVariable;
            rtfcp.timeVariable = this.timeVariable;
            rtfcp.timeseriesIdentifiers = this.timeseriesIdentifiers;
            rtfcp.timestepParams = this.timestepParams;
            rtfcp.predictionLength = this.predictionLength;
            rtfcp.quantilesToForecast = this.quantilesToForecast;
            rtfcp.customTrainTestSplit = this.customTrainTestSplit;
            rtfcp.customTrainTestIntervals = this.customTrainTestIntervals;
            rtfcp.evaluationParams = this.evaluationParams;
            rtfcp.skipTooShortTimeseriesForTraining = this.skipTooShortTimeseriesForTraining;
            rtfcp.partitionedModel = this.partitionedModel;
            rtfcp.backendType = this.backendType;
            rtfcp.executionParams = TrainExecutionParams.fromMLTask(this, projectKey);
            rtfcp.diagnosticsSettings = this.diagnosticsSettings;
            return rtfcp;
        }

        /**
         * @return a formatter for the pattern {@code yyyy-MM-dd HH:mm:ss.SSS}; the same immutable
         *         instance is returned on every call
         */
        public static DateTimeFormatter getDateTimeFormatter() {
            return DATE_TIME_FORMATTER;
        }

        /** One custom train/test fold, expressed as lists of strings for each side. */
        public static class FoldInterval {
            public List<String> train;
            public List<String> test;
        }

        /** Evaluation settings: gap size (default 0) and test size (default 10). */
        public static class TimeseriesEvaluationParams {
            public long gapSize = 0L;
            public long testSize = 10L;
        }
    }

    /**
     * Classical tabular prediction task: binary classification, multiclass or regression.
     */
    public static class ClassicalPredictionMLTask
    extends TabularPredictionMLTask {
        public WeightParams weight = new WeightParams();
        public TimeOrderingParams time = new TimeOrderingParams();
        public CalibrationParams calibration = new CalibrationParams();
        public ClassicalPredictionPreprocessingParams preprocessing;
        public MLAssertionsParams assertionsParams = new MLAssertionsParams();
        public MLOverridesParams overridesParams = new MLOverridesParams();
        public UncertaintyParams uncertainty = new UncertaintyParams();
        /** Forwarded to the resolved params only when an input feature is of type IMAGE. */
        @Nullable
        public String managedFolderSmartId;

        /** Creates a classical task with the DEFAULT guess policy. */
        public ClassicalPredictionMLTask() {
            this.guessPolicy = PredictionGuessPolicy.DEFAULT;
        }

        @Override
        public ClassicalPredictionPreprocessingParams getPreprocessingParams() {
            return this.preprocessing;
        }

        @Override
        protected EnumSet<PredictionGuessPolicy> getSupportedGuessPolicies() {
            return EnumSet.of(PredictionGuessPolicy.DEFAULT, PredictionGuessPolicy.INTERPRETABLE, PredictionGuessPolicy.PERFORMANCE, PredictionGuessPolicy.ALGORITHMS, PredictionGuessPolicy.DEEP, PredictionGuessPolicy.CUSTOM, PredictionGuessPolicy.SIMPLE_FORMULA, PredictionGuessPolicy.DECISION_TREE, PredictionGuessPolicy.EXPLANATORY);
        }

        @Override
        protected EnumSet<PredictionType> getSupportedPredictionTypes() {
            return EnumSet.of(PredictionType.BINARY_CLASSIFICATION, PredictionType.MULTICLASS, PredictionType.REGRESSION);
        }

        /**
         * Delegates to the guesser of the current guess policy.
         * NOTE(review): throws NullPointerException if {@code guessPolicy} has been set to null.
         */
        @Override
        public PredictionGuesser<ClassicalPredictionMLTask> getGuesser(MemTable table) {
            return (ClassicalPredictionGuesser)this.guessPolicy.meta.getGuesser(this, table);
        }

        @Override
        public Optional<MLAssertionsParams> getAssertionsParams() {
            return Optional.ofNullable(this.assertionsParams);
        }

        @Override
        public Optional<MLOverridesParams> getOverridesParams() {
            return Optional.ofNullable(this.overridesParams);
        }

        /**
         * Copies the classical settings into a fresh resolved-params object. The managed folder id
         * is forwarded only when {@link #isManagedFolderUsed()} is true, and is null otherwise.
         */
        @Override
        public ResolvedClassicalPredictionCoreParams buildResolvedCoreParams(String projectKey) throws IOException {
            ResolvedClassicalPredictionCoreParams rpcp = new ResolvedClassicalPredictionCoreParams();
            rpcp.prediction_type = this.predictionType;
            rpcp.target_variable = this.targetVariable;
            rpcp.weight = this.weight;
            rpcp.time = this.time;
            rpcp.calibration = this.calibration;
            rpcp.partitionedModel = this.partitionedModel;
            rpcp.backendType = this.backendType;
            rpcp.executionParams = TrainExecutionParams.fromMLTask(this, projectKey);
            rpcp.diagnosticsSettings = this.diagnosticsSettings;
            rpcp.managedFolderSmartId = this.isManagedFolderUsed() ? this.managedFolderSmartId : null;
            rpcp.uncertainty = this.uncertainty;
            return rpcp;
        }

        /**
         * Tells whether any feature has role INPUT and type IMAGE, which is the condition for
         * forwarding {@link #managedFolderSmartId}.
         *
         * <p>Null-safe: missing preprocessing, missing per-feature settings, null entries and
         * null role/type values count as "no image input" instead of raising NullPointerException.
         */
        private boolean isManagedFolderUsed() {
            if (this.preprocessing == null || this.preprocessing.per_feature == null) {
                return false;
            }
            for (FeaturePreprocessingParams featureParams : this.preprocessing.per_feature.values()) {
                // Enum constants compared with == : identical to Enum.equals, but safe on null fields.
                if (featureParams != null
                        && featureParams.role == FeaturePreprocessingParams.Role.INPUT
                        && featureParams.type == FeaturePreprocessingParams.FeatureType.IMAGE) {
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Common base for tabular prediction tasks: carries the guess policy, the partitioned-model
     * settings and the modeling settings.
     */
    public abstract static class TabularPredictionMLTask
    extends PredictionMLTask {
        public PredictionGuessPolicy guessPolicy;
        public PartitionedModelParams partitionedModel = new PartitionedModelParams();
        public PredictionModelingParams modeling;

        @Override
        public abstract TabularPredictionPreprocessingParams getPreprocessingParams();

        /**
         * @return the guess policies that {@link #setGuessPolicy(PredictionGuessPolicy)} accepts
         */
        protected abstract EnumSet<PredictionGuessPolicy> getSupportedGuessPolicies();

        /**
         * Sets the guess policy after checking that this task supports it.
         *
         * @throws IllegalArgumentException if {@code guessPolicy} is null or unsupported
         */
        public void setGuessPolicy(PredictionGuessPolicy guessPolicy) {
            EnumSet<PredictionGuessPolicy> allowed = this.getSupportedGuessPolicies();
            if (allowed.contains(guessPolicy)) {
                this.guessPolicy = guessPolicy;
                return;
            }
            throw new IllegalArgumentException(String.format("Unsupported guess policy: %s. Valid policies are: %s", guessPolicy, allowed));
        }

        /**
         * A task is partitioned when partitioning is enabled in its settings and its prediction
         * type (if already set) supports partitioning.
         */
        @Override
        public boolean isPartitioned() {
            if (this.partitionedModel == null) {
                return false;
            }
            // An unset prediction type does not block partitioning.
            boolean typeAllowsPartitioning = this.predictionType == null || this.predictionType.supportsPartitioning();
            return typeAllowsPartitioning && this.partitionedModel.isEnabled();
        }

        @Override
        public Optional<PartitionedModelParams> getPartitionedModel() {
            return Optional.ofNullable(this.partitionedModel);
        }
    }

    /**
     * Deep-hub image prediction task (object detection or image classification). It is never
     * partitioned and carries no ML assertions.
     */
    @UIModel
    public static class DeepHubPredictionMLTask
    extends PredictionMLTask {
        public PredictionPreprocessingParams preprocessing;
        public DeepHubPreTrainModelingParams modeling;
        public String pathColumn;
        public String managedFolderSmartId;

        @Override
        public PredictionPreprocessingParams getPreprocessingParams() {
            return this.preprocessing;
        }

        @Override
        protected EnumSet<PredictionType> getSupportedPredictionTypes() {
            return EnumSet.of(PredictionType.DEEP_HUB_IMAGE_OBJECT_DETECTION, PredictionType.DEEP_HUB_IMAGE_CLASSIFICATION);
        }

        @Override
        public boolean isPartitioned() {
            return false;
        }

        @Override
        public Optional<PartitionedModelParams> getPartitionedModel() {
            return Optional.empty();
        }

        @Override
        public Optional<MLAssertionsParams> getAssertionsParams() {
            return Optional.empty();
        }

        @Override
        public DeepHubPredictionGuesser getGuesser(MemTable table) {
            return new DeepHubPredictionGuesser(this, table);
        }

        /**
         * Copies the deep-hub settings, including the path column and the managed folder id,
         * into a fresh resolved-params object.
         */
        @Override
        public ResolvedDeepHubPredictionCoreParams buildResolvedCoreParams(String projectKey) throws IOException {
            ResolvedDeepHubPredictionCoreParams core = new ResolvedDeepHubPredictionCoreParams();
            core.prediction_type = this.predictionType;
            core.target_variable = this.targetVariable;
            core.pathColumn = this.pathColumn;
            core.managedFolderSmartId = this.managedFolderSmartId;
            core.backendType = this.backendType;
            core.diagnosticsSettings = this.diagnosticsSettings;
            core.executionParams = TrainExecutionParams.fromMLTask(this, projectKey);
            return core;
        }
    }

    /**
     * Family of prediction types, holding one capability flag per feature.
     * Categories built with a single boolean get that value for every flag.
     */
    public enum PredictionTypeCategory {
        CLASSIFICATION(true),
        REGRESSION(true),
        DEEP_HUB_IMAGE(false),
        // Forecasting: partitioning and within-category settings copy only.
        TIMESERIES_FORECAST(true, false, false, false, false, false, true),
        CAUSAL(false);

        public final boolean supportsPartitioning;
        public final boolean supportsEnsemble;
        public final boolean supportsDocumentExport;
        public final boolean supportsMLAssertions;
        public final boolean supportsMLOverrides;
        public final boolean supportsGuessTrainDeploy;
        public final boolean canCopySettingsWithinCategory;

        /** Uniform category: every capability flag is set to {@code simple}. */
        PredictionTypeCategory(boolean simple) {
            this(simple, simple, simple, simple, simple, simple, simple);
        }

        /** Category with each capability flag given explicitly. */
        PredictionTypeCategory(boolean supportsPartitioning, boolean supportsEnsemble, boolean supportsDocumentExport, boolean supportsMLAssertions, boolean supportsMLOverrides, boolean supportsGuessTrainDeploy, boolean canCopySettingsWithinCategory) {
            this.supportsPartitioning = supportsPartitioning;
            this.supportsEnsemble = supportsEnsemble;
            this.supportsDocumentExport = supportsDocumentExport;
            this.supportsMLAssertions = supportsMLAssertions;
            this.supportsMLOverrides = supportsMLOverrides;
            this.supportsGuessTrainDeploy = supportsGuessTrainDeploy;
            this.canCopySettingsWithinCategory = canCopySettingsWithinCategory;
        }
    }
}

