/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction;

import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.PreTrainModelingParams;
import com.dataiku.dip.analysis.model.prediction.EnsembleParams;
import com.dataiku.dip.analysis.model.prediction.MetricParams;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.algorithms.EnsembleMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.PredictionAlgorithmMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.VirtualAlgorithmMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.ArimaMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.AutoArimaMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.CausalForestMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.CrostonMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.CustomScikitMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.DeepNeuralNetworkMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.ETSMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.GluonTSMXNetDeepARMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.GluonTSMXNetMQCNNMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.GluonTSMXNetSimpleFeedForwardMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.GluonTSMXNetTransformerMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.GluonTSNPTSForecasterMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.GluonTSTorchDeepARMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.GluonTSTorchSimpleFeedForwardMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.KNNMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.KerasMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.LarsMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.LassoMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.LightGBMMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.NeuralNetworkMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.OLSMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.ProphetMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyCustomPluginMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyDecisionTreeMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyExtraTreesMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyGradientBoostingMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyLogisticMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyRandomForestMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.RidgeMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.SGDMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.SVMMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.SeasonalLoessMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.SeasonalNaiveMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.TrivialIdentityTimeseriesMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.XGBoostMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.MLLibCustomMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.MLLibDecisionTreeMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.MLLibGBTMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.MLLibLinearRegressionMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.MLLibLogisticRegressionMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.MLLibNBMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.MLLibRandomForestMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.SparklingDeepLearningMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.SparklingGBMMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.SparklingGLMMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.SparklingNBMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.spark.SparklingRandomForestMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.vertica.VerticaLinearRegressionMeta;
import com.dataiku.dip.analysis.model.prediction.algorithms.vertica.VerticaLogitMeta;
import com.dataiku.dip.utils.JSON;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import javax.annotation.Nullable;

/**
 * Pre-training modeling parameters for a single prediction algorithm.
 *
 * <p>Holds the selected {@link Algorithm}, one hyperparameter search-space field per supported
 * algorithm (the {@code *_grid} fields), causal-learning settings, and evaluation / grid-search
 * settings. Most behavior (naming, descriptions, probability support) is delegated to
 * {@code algorithm.meta}.
 *
 * <p>NOTE(review): the public snake_case field names look like they are bound to a serialized
 * (JSON) form by name — confirm before renaming any of them.
 */
public class PreTrainPredictionModelingParams
implements PreTrainModelingParams {
    /** The algorithm these parameters describe. */
    public Algorithm algorithm;

    // Per-algorithm hyperparameter search spaces.
    // NOTE(review): presumably only the field matching `algorithm` is populated — confirm.
    public PredictionModelingParams.RandomForestHyperparametersSpace rf_regressor_grid;
    public PredictionModelingParams.RandomForestHyperparametersSpace rf_classifier_grid;
    public PredictionModelingParams.RandomForestHyperparametersSpace extra_trees_grid;
    public PredictionModelingParams.GBTClassificationHyperparametersSpace gbt_classifier_grid;
    public PredictionModelingParams.GBTRegressionHyperparametersSpace gbt_regressor_grid;
    public PredictionModelingParams.DecisionTreeHyperparametersSpace dtc_classifier_grid;
    public PredictionModelingParams.LogisticRegressionHyperparametersSpace logit_grid;
    public PredictionModelingParams.NeuralNetworkHyperparametersSpace neural_network_grid;
    public PredictionModelingParams.SVMHyperparametersSpace svc_grid;
    public PredictionModelingParams.SVMHyperparametersSpace svr_grid;
    public PredictionModelingParams.LeastSquareHyperparametersSpace least_squares_grid;
    public PredictionModelingParams.SGDClassificationHyperparametersSpace sgd_grid;
    public PredictionModelingParams.SGDRegressionHyperparametersSpace sgd_reg_grid;
    public PredictionModelingParams.RidgeRegressionHyperparametersSpace ridge_grid;
    public PredictionModelingParams.LassoHyperparametersSpace lasso_grid;
    public PredictionModelingParams.LarsHyperparametersSpace lars_grid;
    public PredictionModelingParams.KNNHyperparametersSpace knn_grid;
    public PredictionModelingParams.LightGBMHyperParametersSpace lightgbm_regression_grid;
    public PredictionModelingParams.LightGBMHyperParametersSpace lightgbm_classification_grid;
    public PredictionModelingParams.XGBoostHyperparametersSpace xgboost_grid;
    public PredictionModelingParams.DeepNeuralNetworkHyperParameterSpace deep_neural_network_regression_grid;
    public PredictionModelingParams.DeepNeuralNetworkHyperParameterSpace deep_neural_network_classification_grid;
    public PredictionModelingParams.CustomPythonParams custom_python;
    public PredictionModelingParams.CustomPythonPluginParams plugin_python_grid;
    public PredictionModelingParams.KerasCodeParams keras;

    // Time-series forecasting search spaces.
    public PredictionModelingParams.TrivialIdentityTimeseriesSpace trivial_identity_timeseries_grid;
    public PredictionModelingParams.SeasonalNaiveSpace seasonal_naive_timeseries_grid;
    public PredictionModelingParams.AutoArimaSpace autoarima_timeseries_grid;
    public PredictionModelingParams.ArimaSpace arima_grid;
    public PredictionModelingParams.CrostonSpace croston_timeseries_grid;
    public PredictionModelingParams.ETSSpace ets_timeseries_grid;
    public PredictionModelingParams.SeasonalLoessSpace seasonal_loess_timeseries_grid;
    public PredictionModelingParams.ProphetSpace prophet_timeseries_grid;
    public PredictionModelingParams.GluonTSNPTSForecasterSpace gluonts_npts_timeseries_grid;
    public PredictionModelingParams.GluonTSTorchSimpleFeedForwardSpace gluonts_torch_simple_feed_forward_timeseries_grid;
    public PredictionModelingParams.GluonTSTorchDeepARSpace gluonts_torch_deepar_timeseries_grid;
    public PredictionModelingParams.GluonTSMXNetSimpleFeedForwardSpace gluonts_simple_feed_forward_timeseries_grid;
    public PredictionModelingParams.GluonTSMXNetDeepARSpace gluonts_deepar_timeseries_grid;
    public PredictionModelingParams.GluonTSMXNetTransformerSpace gluonts_transformer_timeseries_grid;
    public PredictionModelingParams.GluonTSMXNetMQCNNSpace gluonts_mqcnn_timeseries_grid;

    // Spark MLlib, Sparkling Water (H2O) and Vertica grids.
    public PredictionModelingParams.MLLibLogisticRegressionGridParams mllib_logit_grid;
    public PredictionModelingParams.MLLibDecisionTreeGridParams mllib_dt_grid;
    public PredictionModelingParams.MLLibNaiveBayesGridParams mllib_naive_bayes_grid;
    public PredictionModelingParams.MLLibLinearRegressionGridParams mllib_linreg_grid;
    public PredictionModelingParams.MLLibTreesEnsembleGridParams mllib_rf_grid;
    public PredictionModelingParams.MLLibTreesEnsembleGridParams mllib_gbt_grid;
    public PredictionModelingParams.MLLibCustomGridParams custom_mllib_grid;
    public PredictionModelingParams.H2ODeepLearningGridParams deep_learning_sparkling_grid;
    public PredictionModelingParams.H2OGBMGridParams gbm_sparkling_grid;
    public PredictionModelingParams.H2OGLMGridParams glm_sparkling_grid;
    public PredictionModelingParams.H2ORandomForestGridParams rf_sparkling_grid;
    public PredictionModelingParams.H2ONaiveBayesGridParams nb_sparkling_grid;
    public PredictionModelingParams.VerticaLinearRegParams vertica_linreg_grid;
    public PredictionModelingParams.VerticaLogisticRegParams vertica_logit_grid;

    // Ensemble settings.
    public EnsembleParams ensemble_params;
    public int max_ensemble_nodes_serialized;

    // Causal prediction settings.
    public PredictionModelingParams.CausalForestHyperparameterSpace causal_forest_grid;
    /** Causal learning method, or {@code null} when this is not a causal model. */
    public PredictionModelingParams.CausalLearningMethod causal_method;
    /** Meta-learner used when {@link #causal_method} is META_LEARNER; may be {@code null}. */
    public PredictionModelingParams.CausalMetaLearner meta_learner;

    /**
     * Algorithms listed as causal base learners.
     *
     * <p>How this list is consumed is not visible in this class. It is exposed as an
     * unmodifiable view because {@link Arrays#asList} alone only forbids resizing: any caller
     * could otherwise overwrite entries of this shared constant via {@code set}.
     */
    public static final List<Algorithm> CAUSAL_BASE_LEARNERS = Collections.unmodifiableList(Arrays.asList(
            Algorithm.RANDOM_FOREST_CLASSIFICATION,
            Algorithm.RANDOM_FOREST_REGRESSION,
            Algorithm.EXTRA_TREES,
            Algorithm.CUSTOM_PLUGIN,
            Algorithm.SCIKIT_MODEL,
            Algorithm.RIDGE_REGRESSION,
            Algorithm.LASSO_REGRESSION,
            Algorithm.LEASTSQUARE_REGRESSION,
            Algorithm.LOGISTIC_REGRESSION,
            Algorithm.SVC_CLASSIFICATION,
            Algorithm.SVM_REGRESSION,
            Algorithm.SGD_CLASSIFICATION,
            Algorithm.SGD_REGRESSION,
            Algorithm.GBT_CLASSIFICATION,
            Algorithm.GBT_REGRESSION,
            Algorithm.DECISION_TREE_CLASSIFICATION,
            Algorithm.DECISION_TREE_REGRESSION,
            Algorithm.KNN,
            Algorithm.NEURAL_NETWORK,
            Algorithm.LARS,
            Algorithm.LIGHTGBM_CLASSIFICATION,
            Algorithm.LIGHTGBM_REGRESSION,
            Algorithm.XGBOOST_CLASSIFICATION,
            Algorithm.XGBOOST_REGRESSION,
            Algorithm.DEEP_NEURAL_NETWORK_CLASSIFICATION,
            Algorithm.DEEP_NEURAL_NETWORK_REGRESSION));

    // Evaluation and search settings.
    /** Evaluation metric settings; may be {@code null} (not set by either constructor). */
    public MetricParams metrics;
    public boolean autoOptimizeThreshold;
    public double forcedClassifierThreshold;
    /** Set from {@code algorithm.meta.isShiftWindowsCompatible()} by the two-arg constructor. */
    public boolean isShiftWindowsCompatible = false;
    public int gridLength = 1;
    public PredictionModelingParams.GridSearchParams grid_search_params = new PredictionModelingParams.GridSearchParams();
    public boolean pluginAlgoCustomGridSearch = false;
    public boolean computeLearningCurves;
    public boolean skipExpensiveReports;
    public PredictionModelingParams.PropensityModeling propensityModeling;

    /** Creates empty parameters; fields are expected to be filled in afterwards. */
    public PreTrainPredictionModelingParams() {
    }

    /**
     * Creates parameters for {@code algorithm}, copying shared settings from {@code pmp}.
     *
     * @param algorithm the algorithm these parameters describe; its meta also determines
     *     {@link #isShiftWindowsCompatible}
     * @param pmp task-level modeling params supplying {@code skipExpensiveReports},
     *     {@code propensityModeling} and the grid-search params (copied via
     *     {@code JSON.deepCopy} rather than shared by reference)
     */
    public PreTrainPredictionModelingParams(Algorithm algorithm, PredictionModelingParams pmp) {
        this.algorithm = algorithm;
        this.skipExpensiveReports = pmp.skipExpensiveReports;
        this.grid_search_params = (PredictionModelingParams.GridSearchParams)JSON.deepCopy((Object)pmp.gridSearchParams);
        this.propensityModeling = pmp.propensityModeling;
        this.isShiftWindowsCompatible = algorithm.meta.isShiftWindowsCompatible();
    }

    /**
     * Returns whether the selected algorithm's meta reports probabilities for these parameters.
     *
     * @return the result of {@code algorithm.meta.hasProbabilities(this)}
     */
    public boolean hasProbabilities() {
        return this.algorithm.meta.hasProbabilities(this);
    }

    /**
     * Sets the causal learning method and its (optional) meta-learner.
     *
     * @param method the causal learning method
     * @param metaLearner the meta-learner, or {@code null}
     */
    public void setCausalMethodFields(PredictionModelingParams.CausalLearningMethod method, @Nullable PredictionModelingParams.CausalMetaLearner metaLearner) {
        this.causal_method = method;
        this.meta_learner = metaLearner;
    }

    /**
     * Builds the pre-training description from the algorithm meta, then appends the causal
     * learning method (when set) and the meta-learner display name (when the method is
     * META_LEARNER and a meta-learner is present).
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription() {
        ModelTrainInfo.PreSearchDescription desc = this.algorithm.meta.generatePreTrainDescription(this);
        if (this.causal_method != null) {
            desc.withSVParam("Causal learning method", this.causal_method.name());
            // meta_learner is @Nullable (see setCausalMethodFields): guard before dereferencing.
            if (this.usesMetaLearner()) {
                desc.withSVParam("Meta-learner", this.meta_learner.displayName);
            }
        }
        return desc;
    }

    /**
     * Returns the algorithm meta's name for these parameters, prefixed with
     * {@code "<meta-learner display name> | "} when a META_LEARNER meta-learner is set.
     */
    @Override
    public String generateName() {
        String algorithmName = this.algorithm.meta.generateName(this);
        if (this.usesMetaLearner()) {
            return this.meta_learner.displayName + " | " + algorithmName;
        }
        return algorithmName;
    }

    /**
     * Returns the evaluation metric name.
     *
     * @return {@code null} when no metrics or no evaluation metric is set; the custom metric
     *     name for {@code CUSTOM}; otherwise the metric enum's {@code toString()}
     */
    @Override
    public String getEvaluationMetricName() {
        if (this.metrics == null || this.metrics.evaluationMetric == null) {
            return null;
        }
        if (this.metrics.evaluationMetric == MetricParams.EvaluationMetric.CUSTOM) {
            return this.metrics.customEvaluationMetricName;
        }
        return this.metrics.evaluationMetric.toString();
    }

    /**
     * Checks {@code currentSearchStrategy} against the algorithm's maximum supported strategy.
     *
     * <p>NOTE(review): relies on {@code Strategy} constants being declared in increasing order
     * of capability — confirm if the enum is ever reordered.
     *
     * @param currentSearchStrategy the requested strategy
     * @return the maximum supported strategy if the requested one is declared after it,
     *     otherwise empty
     */
    public Optional<PredictionModelingParams.GridSearchParams.Strategy> checkMaximumSearchStrategy(PredictionModelingParams.GridSearchParams.Strategy currentSearchStrategy) {
        PredictionModelingParams.GridSearchParams.Strategy maxSupportedStrategy = this.algorithm.meta.getMaximumSupportedSearchStrategy();
        // Enum.compareTo follows declaration order, exactly like comparing ordinals.
        if (currentSearchStrategy.compareTo(maxSupportedStrategy) > 0) {
            return Optional.of(maxSupportedStrategy);
        }
        return Optional.empty();
    }

    /** Returns true when the causal method is META_LEARNER and a meta-learner is actually set. */
    private boolean usesMetaLearner() {
        return this.causal_method == PredictionModelingParams.CausalLearningMethod.META_LEARNER
                && this.meta_learner != null;
    }

    /**
     * Prediction algorithms, each bound to the {@link PredictionAlgorithmMeta} that implements
     * its algorithm-specific behavior; {@link #backendType} is taken from that meta.
     */
    public enum Algorithm {
        PYTHON_ENSEMBLE(new EnsembleMeta(true)),
        SPARK_ENSEMBLE(new EnsembleMeta(false)),
        RANDOM_FOREST_CLASSIFICATION(new PyRandomForestMeta(true)),
        EXTRA_TREES(new PyExtraTreesMeta()),
        CUSTOM_PLUGIN(new PyCustomPluginMeta()),
        SCIKIT_MODEL(new CustomScikitMeta()),
        RANDOM_FOREST_REGRESSION(new PyRandomForestMeta(false)),
        RIDGE_REGRESSION(new RidgeMeta()),
        LASSO_REGRESSION(new LassoMeta()),
        LEASTSQUARE_REGRESSION(new OLSMeta()),
        LOGISTIC_REGRESSION(new PyLogisticMeta()),
        SVC_CLASSIFICATION(new SVMMeta(true)),
        SVM_REGRESSION(new SVMMeta(false)),
        SGD_CLASSIFICATION(new SGDMeta(true)),
        SGD_REGRESSION(new SGDMeta(false)),
        GBT_CLASSIFICATION(new PyGradientBoostingMeta(true)),
        GBT_REGRESSION(new PyGradientBoostingMeta(false)),
        DECISION_TREE_CLASSIFICATION(new PyDecisionTreeMeta(true)),
        DECISION_TREE_REGRESSION(new PyDecisionTreeMeta(false)),
        KNN(new KNNMeta()),
        NEURAL_NETWORK(new NeuralNetworkMeta()),
        LARS(new LarsMeta()),
        LIGHTGBM_CLASSIFICATION(new LightGBMMeta(true)),
        LIGHTGBM_REGRESSION(new LightGBMMeta(false)),
        XGBOOST_CLASSIFICATION(new XGBoostMeta(true)),
        XGBOOST_REGRESSION(new XGBoostMeta(false)),
        DEEP_NEURAL_NETWORK_CLASSIFICATION(new DeepNeuralNetworkMeta(true)),
        DEEP_NEURAL_NETWORK_REGRESSION(new DeepNeuralNetworkMeta(false)),
        KERAS_CODE(new KerasMeta()),
        SPARKLING_DEEP_LEARNING(new SparklingDeepLearningMeta()),
        SPARKLING_GBM(new SparklingGBMMeta()),
        SPARKLING_RF(new SparklingRandomForestMeta()),
        SPARKLING_GLM(new SparklingGLMMeta()),
        SPARKLING_NB(new SparklingNBMeta()),
        MLLIB_LOGISTIC_REGRESSION(new MLLibLogisticRegressionMeta()),
        MLLIB_DECISION_TREE(new MLLibDecisionTreeMeta()),
        MLLIB_LINEAR_REGRESSION(new MLLibLinearRegressionMeta()),
        MLLIB_RANDOM_FOREST(new MLLibRandomForestMeta()),
        MLLIB_NAIVE_BAYES(new MLLibNBMeta()),
        MLLIB_GBT(new MLLibGBTMeta()),
        MLLIB_CUSTOM(new MLLibCustomMeta()),
        VERTICA_LINEAR_REGRESSION(new VerticaLinearRegressionMeta()),
        VERTICA_LOGISTIC_REGRESSION(new VerticaLogitMeta()),
        VIRTUAL_MLFLOW_PYFUNC(new VirtualAlgorithmMeta()),
        VIRTUAL_PROXY_MODEL(new VirtualAlgorithmMeta()),
        TRIVIAL_IDENTITY_TIMESERIES(new TrivialIdentityTimeseriesMeta()),
        SEASONAL_NAIVE(new SeasonalNaiveMeta()),
        AUTO_ARIMA(new AutoArimaMeta()),
        ARIMA(new ArimaMeta()),
        CROSTON(new CrostonMeta()),
        ETS(new ETSMeta()),
        SEASONAL_LOESS(new SeasonalLoessMeta()),
        PROPHET(new ProphetMeta()),
        GLUONTS_NPTS_FORECASTER(new GluonTSNPTSForecasterMeta()),
        GLUONTS_TORCH_SIMPLE_FEEDFORWARD(new GluonTSTorchSimpleFeedForwardMeta()),
        GLUONTS_TORCH_DEEPAR(new GluonTSTorchDeepARMeta()),
        GLUONTS_SIMPLE_FEEDFORWARD(new GluonTSMXNetSimpleFeedForwardMeta()),
        GLUONTS_DEEPAR(new GluonTSMXNetDeepARMeta()),
        GLUONTS_TRANSFORMER(new GluonTSMXNetTransformerMeta()),
        GLUONTS_MQCNN(new GluonTSMXNetMQCNNMeta()),
        CAUSAL_FOREST(new CausalForestMeta());

        /** Backend type, copied from {@code meta.backendType()} at construction. */
        public final MLTask.BackendType backendType;
        /** Algorithm-specific behavior (naming, descriptions, capabilities). */
        public final PredictionAlgorithmMeta meta;

        private Algorithm(PredictionAlgorithmMeta meta) {
            this.backendType = meta.backendType();
            this.meta = meta;
        }
    }
}

