/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.coreservices;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.analysis.model.ClusteringModelingParams;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.clustering.ClusteringMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.preprocessing.PredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.PreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.TabularPredictionPreprocessingParams;
import com.dataiku.dip.code.PythonCodeEnvPackagesUtils;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

/**
 * Checks whether a Python code env can run a given ML task, based on which optional features
 * (algorithms, search strategies, preprocessing options) the task enables.
 *
 * <p>The relevant feature flags are captured once, in the constructor, from the task's modeling
 * and preprocessing params. The public methods then compare them against a
 * {@link PythonCodeEnvPackagesUtils.CodeEnvVisualMLCompat} describing one code env.
 */
public class MLTaskCodeEnvCompatibilityComputer {
    private final MLTask task;
    private final boolean isBayesianWithSkopt;
    private final boolean isTimeseries;
    private final boolean isMxnetTimeseries;
    private final boolean isTorchTimeseries;
    private final boolean isProphetEnabled;
    private final boolean isGluontsNeeded;
    private final boolean isStatsmodelsNeeded;
    private final boolean isStatsforecastNeeded;
    private final boolean isPmdarimaNeeded;
    private final boolean isClassicalMLTs;
    private final boolean isCausalPrediction;
    private final boolean hasTextFeatureWithSentenceEmbedding;
    private final boolean isMLPEnabled;
    private final boolean isHDBScanEnabled;
    private final boolean isMonotonicConstraintsEnabled;
    private final boolean isKeepMissingNumericalAsNaNEnabled;

    /**
     * Snapshots the feature flags of {@code task} that affect code env compatibility.
     *
     * @param task the ML task to analyze
     * @throws IllegalStateException if the current node is not a design node
     */
    public MLTaskCodeEnvCompatibilityComputer(MLTask task) {
        if (ApplicationConfigurator.getNodeType() != ApplicationConfigurator.DSSNodeType.DESIGN) {
            throw new IllegalStateException("Can only compute env compatibility on design node");
        }
        this.task = task;
        this.isTimeseries = task instanceof PredictionMLTask.TimeseriesForecastingMLTask;
        this.isCausalPrediction = task instanceof PredictionMLTask.CausalPredictionMLTask;
        if (task instanceof PredictionMLTask.TabularPredictionMLTask) {
            PredictionMLTask.TabularPredictionMLTask tabularPredictionMlTask = (PredictionMLTask.TabularPredictionMLTask) task;
            PredictionModelingParams modelingParams = tabularPredictionMlTask.modeling;
            PredictionPreprocessingParams preprocessingParams = ((PredictionMLTask) task).getPreprocessingParams();
            // Only tabular preprocessing params expose the monotonic / keep-as-NaN options. The
            // instanceof test also rejects null, so missing or non-tabular params yield "disabled"
            // instead of a ClassCastException.
            TabularPredictionPreprocessingParams tabularPreprocessingParams =
                    preprocessingParams instanceof TabularPredictionPreprocessingParams
                            ? (TabularPredictionPreprocessingParams) preprocessingParams
                            : null;
            PredictionModelingParams.GridSearchParams gridSearchParams = modelingParams != null ? modelingParams.gridSearchParams : null;

            this.isBayesianWithSkopt = gridSearchParams != null
                    && PredictionModelingParams.GridSearchParams.Strategy.BAYESIAN.equals(gridSearchParams.strategy)
                    && PredictionModelingParams.GridSearchParams.BayesianOptimizer.SCIKIT_OPTIMIZE.equals(gridSearchParams.bayesianOptimizer);
            this.isProphetEnabled = modelingParams != null && modelingParams.prophet_timeseries.enabled;
            this.isMLPEnabled = modelingParams != null
                    && (modelingParams.deep_neural_network_regression.enabled
                        || modelingParams.deep_neural_network_classification.enabled);
            this.isMonotonicConstraintsEnabled = tabularPreprocessingParams != null
                    && tabularPreprocessingParams.isMonotonicConstrainedEnabled();
            this.isKeepMissingNumericalAsNaNEnabled = tabularPreprocessingParams != null
                    && tabularPreprocessingParams.isKeepMissingAsNaNEnabled();
            this.isMxnetTimeseries = modelingParams != null
                    && (modelingParams.gluonts_deepar_timeseries.enabled
                        || modelingParams.gluonts_mqcnn_timeseries.enabled
                        || modelingParams.gluonts_transformer_timeseries.enabled
                        || modelingParams.gluonts_simple_feed_forward_timeseries.enabled);
            this.isTorchTimeseries = modelingParams != null
                    && (modelingParams.gluonts_torch_deepar_timeseries.enabled
                        || modelingParams.gluonts_torch_simple_feed_forward_timeseries.enabled);
            // gluonts is required by the MXNet and Torch models above, and also by the NPTS,
            // seasonal naive and trivial identity models.
            this.isGluontsNeeded = this.isMxnetTimeseries
                    || this.isTorchTimeseries
                    || (modelingParams != null
                        && (modelingParams.gluonts_npts_timeseries.enabled
                            || modelingParams.seasonal_naive_timeseries.enabled
                            || modelingParams.trivial_identity_timeseries.enabled));
            this.isStatsmodelsNeeded = modelingParams != null
                    && (modelingParams.arima_timeseries.enabled
                        || modelingParams.ets_timeseries.enabled
                        || modelingParams.seasonal_loess_timeseries.enabled);
            this.isStatsforecastNeeded = modelingParams != null && modelingParams.croston_timeseries.enabled;
            this.isPmdarimaNeeded = modelingParams != null && modelingParams.autoarima_timeseries.enabled;
            // Random forest / ridge / XGBoost regressors used for a time series forecasting task.
            this.isClassicalMLTs = tabularPredictionMlTask.predictionType == PredictionMLTask.PredictionType.TIMESERIES_FORECAST
                    && modelingParams != null
                    && (modelingParams.random_forest_regression.enabled
                        || modelingParams.ridge_regression.enabled
                        || modelingParams.xgboost.enabled);
        } else {
            this.isBayesianWithSkopt = false;
            this.isProphetEnabled = false;
            this.isMLPEnabled = false;
            this.isMonotonicConstraintsEnabled = false;
            this.isKeepMissingNumericalAsNaNEnabled = false;
            this.isMxnetTimeseries = false;
            this.isTorchTimeseries = false;
            this.isGluontsNeeded = false;
            this.isStatsmodelsNeeded = false;
            this.isStatsforecastNeeded = false;
            this.isPmdarimaNeeded = false;
            this.isClassicalMLTs = false;
        }
        if (task instanceof ClusteringMLTask) {
            ClusteringModelingParams clusteringModelingParams = ((ClusteringMLTask) task).modeling;
            this.isHDBScanEnabled = clusteringModelingParams != null && clusteringModelingParams.hdb_scan_clustering.enabled;
        } else {
            this.isHDBScanEnabled = false;
        }
        PreprocessingParams preprocessingParams = task.getPreprocessingParams();
        this.hasTextFeatureWithSentenceEmbedding = preprocessingParams != null
                && !preprocessingParams.codeEnvSentenceEmbeddedFeaturesAndModels().isEmpty();
    }

    /**
     * Records an incompatibility if a feature is in use and the code env does not support it.
     *
     * @param container          set receiving the incompatibility labels
     * @param condition          whether the task uses the feature
     * @param compatibilityInfo  the code env's compatibility for that feature
     * @param packageDescription label to add instead of the detailed reasons; when {@code null},
     *                           all of {@code compatibilityInfo.reasons} are added
     */
    private static void addIncompatibilityPackageInformations(Set<String> container, boolean condition, PythonCodeEnvPackagesUtils.CompatibilityInfo compatibilityInfo, @Nullable String packageDescription) {
        if (!condition || compatibilityInfo.compatible) {
            return;
        }
        if (packageDescription != null) {
            container.add(packageDescription);
        } else {
            container.addAll(compatibilityInfo.reasons);
        }
    }

    /**
     * Builds a short human-readable description of what this task needs that the code env lacks.
     *
     * <p>An example result is {@code "time series forecasting models using prophet and gluonts"}.
     * When every used feature is compatible, only the base description is returned.
     *
     * @param envCompat the code env's per-feature compatibility
     * @return the base task description, followed by {@code " using "} and the incompatible
     *         features when there are any
     */
    public String getMLTaskDescriptionForCompatibility(PythonCodeEnvPackagesUtils.CodeEnvVisualMLCompat envCompat) {
        LinkedHashSet<String> additionalDescriptionParts = new LinkedHashSet<String>();
        String description = switch (this.task.backendType) {
            case KERAS -> "deep learning models";
            case DEEP_HUB -> "visual computer vision";
            default -> "visual ML";
        };
        if (this.isCausalPrediction) {
            description = "causal visual ML";
        }
        // Checked last so that it takes precedence over the labels above.
        if (this.isTimeseries) {
            description = "time series forecasting models";
        }
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isTorchTimeseries, envCompat.torchTimeseries, "Torch");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isMxnetTimeseries, envCompat.mxnetTimeseries, "MXNet");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isProphetEnabled, envCompat.prophet, "prophet");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isGluontsNeeded, envCompat.gluonts, "gluonts");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isStatsmodelsNeeded, envCompat.statsmodel, "statsmodels");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isStatsforecastNeeded, envCompat.statsforecast, "statsforecast");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isPmdarimaNeeded, envCompat.pmdarima, "pmdarima");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isClassicalMLTs, envCompat.classicalMLTs, "scikit-learn");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isBayesianWithSkopt, envCompat.bayesianSearch, "bayesian search");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.hasTextFeatureWithSentenceEmbedding, envCompat.sentenceEmbedding, "code env based text embedding");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isMLPEnabled, envCompat.deepNeuralNetwork, "deep neural network");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isHDBScanEnabled, envCompat.hdbscan, "HDBScan");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isMonotonicConstraintsEnabled, envCompat.monotonicConstraints, "monotonic constraints");
        addIncompatibilityPackageInformations(additionalDescriptionParts, this.isKeepMissingNumericalAsNaNEnabled, envCompat.keepMissingNumericalAsNaN, "non-imputed empty numerical input");
        if (!additionalDescriptionParts.isEmpty()) {
            description = description + " using " + reduceAdditionalDescription(additionalDescriptionParts);
        }
        return description;
    }

    /**
     * Joins items as an English list.
     *
     * <p>Examples: {@code "a"}, {@code "a and b"}, {@code "a, b and c"}. Items appear in the set's
     * iteration order.
     *
     * @param additionalDescription the items to join; may be {@code null}
     * @return the joined text, or {@code ""} for a {@code null} or empty set
     */
    private static String reduceAdditionalDescription(Set<String> additionalDescription) {
        if (additionalDescription == null || additionalDescription.isEmpty()) {
            return "";
        }
        // Copied into a list (order preserved, not sorted) so that the last item can be addressed.
        ArrayList<String> parts = new ArrayList<String>(additionalDescription);
        int lastIndex = parts.size() - 1;
        if (lastIndex == 0) {
            return parts.get(0);
        }
        return StringUtils.join(parts.subList(0, lastIndex), ", ") + " and " + parts.get(lastIndex);
    }

    /**
     * Lists the detailed reasons why the code env cannot run this task.
     *
     * <p>For the Keras backend, the env's Keras compatibility is checked. For the in-memory Python
     * backend, the checks depend on the task kind:
     * <ul>
     *   <li>time series: the per-algorithm packages are checked;</li>
     *   <li>causal prediction: causal support is checked;</li>
     *   <li>otherwise: regular ML support and the optional features are checked.</li>
     * </ul>
     * For that backend, Bayesian search and HDBScan are also checked. Sentence-embedding text
     * features are checked for every backend.
     *
     * @param envCompat the code env's per-feature compatibility
     * @return the incompatibility reasons in insertion order; empty if the env is compatible
     */
    public Set<String> getIncompatibilityReasons_NT(PythonCodeEnvPackagesUtils.CodeEnvVisualMLCompat envCompat) {
        LinkedHashSet<String> incompatibilityReasons = new LinkedHashSet<String>();
        switch (this.task.backendType) {
            case KERAS -> addIncompatibilityPackageInformations(incompatibilityReasons, true, envCompat.keras, null);
            case PY_MEMORY -> {
                if (this.isTimeseries) {
                    addIncompatibilityPackageInformations(incompatibilityReasons, this.isProphetEnabled, envCompat.prophet, null);
                    addIncompatibilityPackageInformations(incompatibilityReasons, this.isMxnetTimeseries, envCompat.mxnetTimeseries, null);
                    addIncompatibilityPackageInformations(incompatibilityReasons, this.isTorchTimeseries, envCompat.torchTimeseries, null);
                    addIncompatibilityPackageInformations(incompatibilityReasons, this.isGluontsNeeded, envCompat.gluonts, null);
                    addIncompatibilityPackageInformations(incompatibilityReasons, this.isPmdarimaNeeded, envCompat.pmdarima, null);
                    addIncompatibilityPackageInformations(incompatibilityReasons, this.isStatsmodelsNeeded, envCompat.statsmodel, null);
                    addIncompatibilityPackageInformations(incompatibilityReasons, this.isStatsforecastNeeded, envCompat.statsforecast, null);
                    addIncompatibilityPackageInformations(incompatibilityReasons, this.isClassicalMLTs, envCompat.classicalMLTs, null);
                } else if (this.isCausalPrediction) {
                    addIncompatibilityPackageInformations(incompatibilityReasons, true, envCompat.causal, null);
                } else {
                    addIncompatibilityPackageInformations(incompatibilityReasons, true, envCompat.regularMl, null);
                    addIncompatibilityPackageInformations(incompatibilityReasons, this.isMLPEnabled, envCompat.deepNeuralNetwork, null);
                    addIncompatibilityPackageInformations(incompatibilityReasons, this.isMonotonicConstraintsEnabled, envCompat.monotonicConstraints, null);
                    addIncompatibilityPackageInformations(incompatibilityReasons, this.isKeepMissingNumericalAsNaNEnabled, envCompat.keepMissingNumericalAsNaN, null);
                }
                addIncompatibilityPackageInformations(incompatibilityReasons, this.isBayesianWithSkopt, envCompat.bayesianSearch, null);
                addIncompatibilityPackageInformations(incompatibilityReasons, this.isHDBScanEnabled, envCompat.hdbscan, null);
            }
            default -> {
                // No backend-specific checks for other backends.
            }
        }
        addIncompatibilityPackageInformations(incompatibilityReasons, this.hasTextFeatureWithSentenceEmbedding, envCompat.sentenceEmbedding, null);
        return incompatibilityReasons;
    }
}

