/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model;

import com.dataiku.dip.analysis.ml.MLDiagnostics;
import com.dataiku.dip.analysis.ml.MLSparkParams;
import com.dataiku.dip.analysis.ml.SparkConstants;
import com.dataiku.dip.analysis.model.MLTaskAdapter;
import com.dataiku.dip.analysis.model.core.GpuConfig;
import com.dataiku.dip.analysis.model.core.ResolvedCoreParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.PreprocessingParams;
import com.dataiku.dip.code.CodeEnvModel;
import com.dataiku.dip.code.CodeEnvSelection;
import com.dataiku.dip.containers.exec.ContainerExecSelection;
import com.dataiku.dip.coremodel.SimpleKeyValue;
import com.dataiku.dip.pivot.frontend.model.ChartDef;
import com.dataiku.dip.pivot.frontend.model.CustomMeasure;
import com.dataiku.dip.pivot.frontend.model.DimensionDef;
import com.dataiku.dip.pivot.frontend.model.HierarchyDef;
import com.dataiku.dip.recipes.ParamsWithContainerizable;
import com.dataiku.dip.recipes.ParamsWithSelectableCodeEnv;
import com.dataiku.dip.server.services.TaggableObjectsService;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.theming.model.DSSVisualizationTheme;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.warnings.WarningsContext;
import com.google.common.collect.Lists;
import com.google.gson.TypeAdapterFactory;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Base settings shared by every kind of ML task.
 *
 * <p>Holds the task identity ({@link #id}, {@link #name}, {@link #taskType}), the training
 * backend and its execution environment (code env, container, Spark and GPU settings), the
 * diagnostics settings, and the configuration used to display predicted data (shaker script,
 * charts, custom measures, reusable dimensions, hierarchies). Concrete subclasses supply the
 * preprocessing parameters and build the resolved core parameters.
 *
 * <p>Loading this class registers {@link MLTaskAdapter} as a Gson {@link TypeAdapterFactory}
 * with {@link JSON} (see the static initializer). NOTE(review): presumably this is what selects
 * the concrete subclass during (de)serialization — confirm against MLTaskAdapter.
 */
public abstract class MLTask implements ParamsWithSelectableCodeEnv, ParamsWithContainerizable {
    public String id;
    public String initiator;
    /** Per-diagnostic enablement; by default every known diagnostic type is enabled. */
    public DiagnosticsSettings diagnosticsSettings = new DiagnosticsSettings();
    public MLTaskType taskType;
    public String name;
    /** Training backend; defaults to {@link BackendType#PY_MEMORY}. */
    public BackendType backendType = BackendType.PY_MEMORY;
    // NOTE(review): presumably caps how many models of this task train in parallel — confirm.
    public int maxConcurrentModelTraining = 2;
    public GpuConfig gpuConfig = new GpuConfig();
    /**
     * Code environment selection; defaults to the built-in env. Only reported by
     * {@link #collectCodeEnvUsage} when the backend is Python-based.
     */
    public CodeEnvSelection envSelection = CodeEnvSelection.builtInEnvSelection();
    public ContainerExecSelection containerSelection = new ContainerExecSelection();
    // NOTE(review): the Spark settings below presumably only matter when
    // backendType.isSparkBased() — confirm.
    public MLSparkParams sparkParams = new MLSparkParams();
    public SparkConstants.Checkpoint sparkCheckpoint = SparkConstants.Checkpoint.NONE;
    public String sparkCheckpointDir;
    // Configuration for displaying predicted data.
    public SerializedShakerScript predictionDisplayScript = new SerializedShakerScript();
    public List<PredictedDataChart> predictionDisplayCharts = new ArrayList<PredictedDataChart>();
    public List<CustomMeasure> customMeasures = new ArrayList<CustomMeasure>();
    public List<DimensionDef> reusableDimensions = new ArrayList<DimensionDef>();
    public List<HierarchyDef> hierarchies = new ArrayList<HierarchyDef>();
    public List<SimpleKeyValue> labels = new ArrayList<SimpleKeyValue>();

    /**
     * Returns the preprocessing parameters of this task.
     *
     * @return the preprocessing parameters, possibly {@code null}; the methods of this class
     *     treat {@code null} as "no per-feature preprocessing configured"
     */
    public abstract PreprocessingParams getPreprocessingParams();

    /**
     * Builds the resolved core parameters for this task.
     *
     * @param var1 decompiled parameter name; its meaning is not visible here — TODO confirm
     * @return the resolved core parameters
     * @throws IOException as declared by implementations that perform I/O
     */
    public abstract ResolvedCoreParams buildResolvedCoreParams(String var1) throws IOException;

    @Override
    public CodeEnvSelection getCodeEnvSelection() {
        return this.envSelection;
    }

    @Override
    public void setCodeEnvSelection(CodeEnvSelection envSelection) {
        this.envSelection = envSelection;
    }

    @Override
    public ContainerExecSelection getContainerSelection() {
        return this.containerSelection;
    }

    @Override
    public void setContainerSelection(ContainerExecSelection containerSelection) {
        this.containerSelection = containerSelection;
    }

    /**
     * Reports the code environment this task depends on.
     *
     * <p>A usage is produced only when the backend is Python-based AND the env mode is
     * {@code EXPLICIT_ENV}; any other configuration yields an empty list.
     *
     * @param object the taggable object owning this task; supplies the project key and id
     *     recorded in the usage
     * @return a mutable list holding zero or one PYTHON / MODEL usage
     */
    @Override
    public List<CodeEnvModel.CodeEnvUsage> collectCodeEnvUsage(TaggableObjectsService.TaggableObject object) {
        List<CodeEnvModel.CodeEnvUsage> usages = new ArrayList<>();
        if (this.backendType.isPythonBased() && this.envSelection.envMode == CodeEnvSelection.EnvMode.EXPLICIT_ENV) {
            usages.add(new CodeEnvModel.CodeEnvUsage(
                    CodeEnvModel.EnvLang.PYTHON,
                    this.envSelection.envName,
                    CodeEnvModel.EnvUsage.MODEL,
                    object.getProjectKey(),
                    object.getId()));
        }
        return usages;
    }

    /**
     * Collects the connections referenced by the per-feature preprocessing parameters.
     *
     * @return the distinct connection names reported by each feature's
     *     {@code getUsedConnection()}; empty when there are no preprocessing parameters
     */
    public Set<String> getUsedConnections() {
        PreprocessingParams preprocessingParams = this.getPreprocessingParams();
        if (preprocessingParams == null) {
            return new HashSet<String>();
        }
        return preprocessingParams.per_feature.values().stream()
                .map(FeaturePreprocessingParams::getUsedConnection)
                .flatMap(Optional::stream)
                .collect(Collectors.toSet());
    }

    /**
     * Applies connection renames to every feature's preprocessing parameters.
     *
     * @param replacements map passed unchanged to each feature's {@code replaceConnections}
     * @return {@code true} if at least one feature reported a replacement; {@code false} when
     *     none did or when there are no preprocessing parameters
     */
    public boolean replaceConnections(Map<String, String> replacements) {
        PreprocessingParams preprocessingParams = this.getPreprocessingParams();
        if (preprocessingParams == null) {
            // Same null handling as getUsedConnections(): nothing to rewrite, so no NPE.
            return false;
        }
        boolean replaced = false;
        for (FeaturePreprocessingParams featureParams : preprocessingParams.per_feature.values()) {
            // Non-short-circuit |=: every feature must receive the replacements,
            // even after an earlier one already reported a change.
            replaced |= featureParams.replaceConnections(replacements);
        }
        return replaced;
    }

    static {
        JSON.registerFactory((TypeAdapterFactory) new MLTaskAdapter());
    }

    /** Global diagnostics switch plus one enablement entry per diagnostic type. */
    public static class DiagnosticsSettings {
        public boolean enabled = true;
        public List<DiagnosticSetting> settings = new ArrayList<DiagnosticSetting>();

        /** Creates settings with every {@link MLDiagnostics.DiagnosticsTypes} value enabled. */
        public DiagnosticsSettings() {
            for (MLDiagnostics.DiagnosticsTypes diagnosticsType : MLDiagnostics.DiagnosticsTypes.values()) {
                WarningsContext.WarningType warningType = diagnosticsType.warningType;
                this.settings.add(new DiagnosticSetting(warningType, true));
            }
        }
    }

    /**
     * Training backend of a task, with static capability flags.
     *
     * <p>Constructor arguments, in order: sparkBased, pythonBased, supportsDiagnostics,
     * supportsExplanations.
     */
    public static enum BackendType {
        PY_MEMORY(false, true, true, true),
        MLLIB(true, false, false, false),
        VERTICA(false, false, false, false),
        H2O(true, false, false, false),
        KERAS(false, true, true, false),
        DEEP_HUB(false, true, true, false);

        private final boolean sparkBased;
        private final boolean pythonBased;
        private final boolean supportsDiagnostics;
        private final boolean supportsExplanations;

        public boolean isSparkBased() {
            return this.sparkBased;
        }

        /** Whether this backend is Python-based (decides code-env reporting in MLTask). */
        public boolean isPythonBased() {
            return this.pythonBased;
        }

        public boolean supportsDiagnostics() {
            return this.supportsDiagnostics;
        }

        public boolean supportsExplanations() {
            return this.supportsExplanations;
        }

        private BackendType(boolean sparkBased, boolean pythonBased, boolean supportsDiagnostics, boolean supportsExplanations) {
            this.sparkBased = sparkBased;
            this.pythonBased = pythonBased;
            this.supportsDiagnostics = supportsDiagnostics;
            this.supportsExplanations = supportsExplanations;
        }
    }

    /** A chart definition together with the visualization theme used to render it. */
    public static class PredictedDataChart {
        public ChartDef def;
        public DSSVisualizationTheme theme;
    }

    /** Kind of ML task. */
    public static enum MLTaskType {
        PREDICTION,
        CLUSTERING,
        LLM_GENERIC_RAW,
        LLM_GENERIC_PROMPTABLE_COMPLETION,
        LLM_CLASSIFICATION
    }

    /** Enablement flag for one diagnostic, keyed by its warning type. */
    public static class DiagnosticSetting {
        public WarningsContext.WarningType type;
        public boolean enabled;

        // NOTE(review): presumably kept for JSON deserialization — confirm before removing.
        public DiagnosticSetting() {
        }

        public DiagnosticSetting(WarningsContext.WarningType type, boolean enabled) {
            this.type = type;
            this.enabled = enabled;
        }
    }
}

