/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.code;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.code.CodeEnvModel;
import com.dataiku.dip.code.DSSInternalCodeEnvsService;
import com.dataiku.dip.code.StandardPythonInterpreter;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.licensing.LicenseEnforcementService;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.DKUtils;
import com.google.common.annotations.VisibleForTesting;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Builds the package presets offered when creating a Python code environment on a design node.
 *
 * <p>A preset is an id plus a version (see {@link CodeEnvPreset}), a human-readable description and a
 * list of requirement lines. Which presets {@link #forPython} returns, and the exact pins they contain,
 * depend on the interpreter version, the operating system, conda usage, the selected core package set
 * and, for the LLM evaluation preset, the license.
 */
public class DesignNodeCodeEnvPackagePresets {
    private static final DKULogger logger = DKULogger.getLogger("dku.codeenvs.presets");

    /** Core package set id whose visual ML presets do not get gluonts. */
    private static final String LEGACY_PANDAS023_PACKAGE_SET = "LEGACY_PANDAS023";

    /** Matches a {@code #} comment and the rest of the line in a requirements file. */
    private static final Pattern REQUIREMENT_COMMENT = Pattern.compile("#.*$");

    /**
     * Rewrites each requirement line with {@code interpreter.updateRequirement} and joins the results.
     *
     * @param packages    requirement lines to update
     * @param interpreter interpreter whose {@code updateRequirement} is applied to every line
     * @return the updated requirements, separated by {@code \n} (no trailing newline)
     */
    public static String updatePackagesRequirements(String[] packages, StandardPythonInterpreter interpreter) {
        return Arrays.stream(packages).map(interpreter::updateRequirement).collect(Collectors.joining("\n"));
    }

    /**
     * Describes the Tensorflow deep-learning preset for the given interpreter and host.
     *
     * <p>Python 2.7, and macOS hosts that cannot use the Mac GPU, get the CPU label. The Mac GPU is usable
     * when {@code DKUtils.isOnMacSilicon()} is true and the {@code dku.codeenv.arm.supported} property is
     * {@code "true"}.
     */
    private static String tensorflowPresetDesc(StandardPythonInterpreter interpreter) {
        boolean canUseMacGPU = DKUtils.isOnMacSilicon()
                && Boolean.parseBoolean(DKUApp.getProperty("dku.codeenv.arm.supported", "false"));
        if (interpreter == StandardPythonInterpreter.PYTHON27 || (DKUtils.isOsMacOS() && !canUseMacGPU)) {
            return "Visual Deep Learning: Tensorflow (CPU)";
        }
        if (canUseMacGPU) {
            return "Visual Deep Learning: Tensorflow (GPU - Mac Metal)";
        }
        return "Visual Deep Learning: Tensorflow (GPU)";
    }

    /** Describes the main visual ML preset; older interpreters get a narrower label. */
    private static String doctorPresetDesc(StandardPythonInterpreter interpreter) {
        return switch (interpreter) {
            case PYTHON27, PYTHON34, PYTHON35 -> "Visual ML";
            case PYTHON36 -> "Visual ML and Timeseries Forecasting (CPU)";
            default -> "Visual ML, Causal ML, and Timeseries Forecasting (CPU)";
        };
    }

    /** Preset mirroring the requirements of the internal Hugging Face local-models code env. */
    @VisibleForTesting
    static CodeEnvPackagePreset getLocalHFPreset() {
        return getPresetFromInternalCodeEnvResources(CodeEnvPreset.LOCAL_HF, "Local Hugging Face models for LLM Mesh", DSSInternalCodeEnvsService.DSSInternalCodeEnvType.HUGGINGFACE_LOCAL_CODE_ENV);
    }

    /** Preset mirroring the requirements of the internal RAG code env. */
    @VisibleForTesting
    static CodeEnvPackagePreset getTextEmbeddingExtractionPreset() {
        return getPresetFromInternalCodeEnvResources(CodeEnvPreset.TEXT_EMBEDDING_EXTRACTION, "RAG and Agents", DSSInternalCodeEnvsService.DSSInternalCodeEnvType.RAG_CODE_ENV);
    }

    /**
     * Builds a preset whose packages are the requirement lines shipped with an internal code env.
     *
     * <p>Reads {@code spec/<python package file name>} under the internal env's resource spec folder as
     * UTF-8, strips {@code #} comments and surrounding whitespace, and skips lines left blank.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the requirements file cannot be read
     */
    private static CodeEnvPackagePreset getPresetFromInternalCodeEnvResources(CodeEnvPreset presetType, String description, DSSInternalCodeEnvsService.DSSInternalCodeEnvType internalCodeEnvType) {
        CodeEnvPackagePreset preset = new CodeEnvPackagePreset(presetType, description);
        DSSInternalCodeEnvsService.DSSInternalCodeEnv internalCodeEnv = new DSSInternalCodeEnvsService.DSSInternalCodeEnv(internalCodeEnvType);
        File requirements = DKUFileUtils.getWithin(internalCodeEnv.getResourceSpecFolder(), new String[]{"spec", CodeEnvModel.EnvLang.PYTHON.getPackageFileName()});
        // Explicit UTF-8: the platform default charset differs across OSes (e.g. Windows).
        try (BufferedReader reader = new BufferedReader(new FileReader(requirements, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String requirement = REQUIREMENT_COMMENT.matcher(line).replaceFirst("").trim();
                if (!requirement.isBlank()) {
                    preset.packages.add(requirement);
                }
            }
        }
        catch (IOException e) {
            throw new RuntimeException(String.format("Failed to read requirements for preset %s from resources of internal code env %s", presetType.name(), internalCodeEnv.getCodeEnvName()), e);
        }
        return preset;
    }

    /**
     * Lists the package presets offered for a Python code env.
     *
     * <p>Presets backed by internal code env resources (local Hugging Face, RAG) are skipped with a
     * warning if their requirements cannot be loaded.
     *
     * @param interpreter    the code env's interpreter; nothing is offered for PYTHON314 and later
     * @param useConda       whether the env uses conda (only changes the streaming preset's pins)
     * @param corePackageSet id of the selected core package set; {@code "LEGACY_PANDAS023"} keeps gluonts
     *                       out of the visual ML presets. May be {@code null}, which is treated as non-legacy.
     * @return a new mutable list of presets
     */
    public static List<CodeEnvPackagePreset> forPython(StandardPythonInterpreter interpreter, boolean useConda, String corePackageSet) {
        if (interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON314)) {
            return new ArrayList<CodeEnvPackagePreset>();
        }
        List<String> corePackages = corePackagesFor(interpreter);

        // mxnet and its companion pins are kept apart so the CUDA 11 preset can swap in mxnet-cu112.
        List<String> doctorBasePackages = new ArrayList<String>(corePackages);
        List<String> mxNetBasePackages = new ArrayList<String>();
        mxNetBasePackages.add("patsy<1.0.2");
        if (interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON36)) {
            // Constant-first comparison: a null package set must not throw.
            if (!LEGACY_PANDAS023_PACKAGE_SET.equals(corePackageSet)) {
                doctorBasePackages.add(interpreter.gluontsRequirement());
            }
            if (interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON37) && !DKUtils.isOsWindows()) {
                doctorBasePackages.add(interpreter.cmdstanpyRequirement);
                doctorBasePackages.add(interpreter.prophetRequirement);
            }
            if (!interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON313)) {
                doctorBasePackages.add(interpreter.pmdarimaRequirement());
                mxNetBasePackages.add(interpreter.numpy1Requirement());
                mxNetBasePackages.add(interpreter.mxnetRequirement());
            }
        }

        CodeEnvPackagePreset doctor = new CodeEnvPackagePreset(CodeEnvPreset.DOCTOR, doctorPresetDesc(interpreter));
        doctor.packages.addAll(doctorBasePackages);
        doctor.packages.addAll(mxNetBasePackages);
        doctor.packages.addAll(interpreter.cpuTorchWithLinkRequirement());

        CodeEnvPackagePreset doctorTimeseriesGPU = new CodeEnvPackagePreset(CodeEnvPreset.DOCTOR_TIMESERIES_GPU_CUDA_LATEST, "Visual Machine Learning and Timeseries Forecasting (GPU)");
        doctorTimeseriesGPU.packages.addAll(doctorBasePackages);

        List<String> doctorDeepLearningBasePackages = new ArrayList<String>(corePackages);
        Collections.addAll(doctorDeepLearningBasePackages, interpreter.h5pyRequirement(), interpreter.pillowRequirement());
        CodeEnvPackagePreset doctorDeepLearning = new CodeEnvPackagePreset(CodeEnvPreset.DOCTOR_DL, tensorflowPresetDesc(interpreter));
        doctorDeepLearning.packages.addAll(doctorDeepLearningBasePackages);
        doctorDeepLearning.packages.addAll(interpreter.tensorflowSupportRequirements());

        CodeEnvPackagePreset doctorSentenceEmbedding = new CodeEnvPackagePreset(CodeEnvPreset.DOCTOR_SENTENCE_EMBEDDING, "[Deprecated] Visual Machine Learning with Sentence Embedding (GPU)");
        doctorSentenceEmbedding.packages.addAll(doctorBasePackages);
        doctorSentenceEmbedding.packages.add(interpreter.sentenceTransformersRequirement());
        doctorSentenceEmbedding.packages.add(interpreter.tokenizersRequirement());
        if (interpreter == StandardPythonInterpreter.PYTHON37) {
            // NOTE(review): this preset is only added to the result for interpreters >= PYTHON38 and
            // < PYTHON313 (end of this method), so this pin currently never reaches a returned preset.
            doctorSentenceEmbedding.packages.add("safetensors<0.5");
        }

        CodeEnvPackagePreset doctorTimeseriesGPUCuda112 = new CodeEnvPackagePreset(CodeEnvPreset.DOCTOR_TIMESERIES_GPU_CUDA112, "[Deprecated] Visual Machine Learning and Timeseries Forecasting (GPU - CUDA 11)");
        doctorTimeseriesGPUCuda112.packages.addAll(doctorBasePackages);
        doctorTimeseriesGPUCuda112.packages.addAll(mxNetBasePackages.stream().map(pkg -> pkg.startsWith("mxnet") ? pkg.replace("mxnet", "mxnet-cu112") : pkg).toList());

        CodeEnvPackagePreset doctorCausalPredictions = new CodeEnvPackagePreset(CodeEnvPreset.DOCTOR_CAUSAL, "Visual Causal Machine Learning");
        doctorCausalPredictions.packages.addAll(doctorBasePackages);
        if (interpreter == StandardPythonInterpreter.PYTHON36) {
            // corePackagesFor only adds these from PYTHON37 onwards.
            Collections.addAll(doctorCausalPredictions.packages, interpreter.tdigestRequirement(), interpreter.econmlRequirement());
        }
        doctorCausalPredictions.packages.addAll(interpreter.cpuTorchWithLinkRequirement());

        CodeEnvPackagePreset streaming = streamingPreset(useConda);

        CodeEnvPackagePreset deepNeuralNetwork = new CodeEnvPackagePreset(CodeEnvPreset.DOCTOR_DEEP_NEURAL_NETWORK, "Visual Machine Learning with Deep Neural Network (GPU)");
        deepNeuralNetwork.packages.addAll(doctorBasePackages);
        deepNeuralNetwork.packages.add(interpreter.skorchRequirement());
        deepNeuralNetwork.packages.addAll(interpreter.gpuTorchWithLinkRequirement());

        boolean linuxLike = !DKUtils.isOsMacOS() && !DKUtils.isOsWindows();
        List<CodeEnvPackagePreset> basePresets = new ArrayList<CodeEnvPackagePreset>();
        basePresets.add(doctor);
        if (interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON36) && linuxLike) {
            basePresets.add(doctorTimeseriesGPU);
        }
        if (DSSInternalCodeEnvsService.getSupportedInterpreters(DSSInternalCodeEnvsService.DSSInternalCodeEnvType.HUGGINGFACE_LOCAL_CODE_ENV).contains(interpreter)) {
            try {
                basePresets.add(getLocalHFPreset());
            }
            catch (Exception e) {
                // Best effort: a missing/unreadable resource must not hide the other presets.
                logger.warn("Failed to get local HF preset", e);
            }
        }
        if (DSSInternalCodeEnvsService.getSupportedInterpreters(DSSInternalCodeEnvsService.DSSInternalCodeEnvType.RAG_CODE_ENV).contains(interpreter)) {
            try {
                basePresets.add(getTextEmbeddingExtractionPreset());
            }
            catch (Exception e) {
                logger.warn("Failed to get text embedding extraction preset", e);
            }
        }
        CodeEnvPackagePreset llmEvaluation = llmEvaluationPreset(interpreter);
        // The license is only consulted once the interpreter range matches (short-circuit).
        if (interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON38) && interpreter.isVersionLowerOrEqual(StandardPythonInterpreter.PYTHON313) && licenseAllowsLLMEvaluation()) {
            basePresets.add(llmEvaluation);
        }
        if (interpreter == StandardPythonInterpreter.PYTHON36) {
            basePresets.add(doctorCausalPredictions);
        }
        if (interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON38)) {
            basePresets.add(deepNeuralNetwork);
        }
        basePresets.add(streaming);
        if (interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON38)) {
            basePresets.add(doctorDeepLearning);
        }
        if (interpreter == StandardPythonInterpreter.PYTHON27 && linuxLike) {
            CodeEnvPackagePreset doctorDeepLearningGPU = new CodeEnvPackagePreset(CodeEnvPreset.DOCTOR_DL_GPU, "Visual Deep Learning: Tensorflow (GPU - CUDA 10.0 / cuDNN 7.0)");
            doctorDeepLearningGPU.packages.addAll(doctorDeepLearningBasePackages);
            doctorDeepLearningGPU.packages.addAll(interpreter.tensorflowSupportRequirements());
            // NOTE(review): String.replace rewrites every occurrence, and any requirement that merely starts
            // with "tensorflow" is renamed too; confirm the Python 2.7 tensorflow requirements allow this.
            doctorDeepLearningGPU.packages.replaceAll(pkg -> pkg.startsWith("tensorflow") ? pkg.replace("tensorflow", "tensorflow-gpu") : pkg);
            basePresets.add(doctorDeepLearningGPU);
        }
        if (interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON36)) {
            if (interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON38) && !interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON313)) {
                basePresets.add(doctorSentenceEmbedding);
            }
            if (linuxLike) {
                basePresets.add(doctorTimeseriesGPUCuda112);
            }
        }
        return basePresets;
    }

    /** Builds the core package list shared by the visual ML and deep learning presets. */
    private static List<String> corePackagesFor(StandardPythonInterpreter interpreter) {
        List<String> corePackages = new ArrayList<String>();
        Collections.addAll(corePackages, interpreter.markupSafeRequirement(), interpreter.jinja2Requirement(), interpreter.cloudpickleRequirement(), interpreter.flaskRequirement(), interpreter.itsdangerousRequirement, interpreter.lightGbmRequirement(), interpreter.scikitLearnRequirement(), interpreter.scikitOptRequirement(), interpreter.scipyRequirement(), interpreter.statsmodelsRequirement(), interpreter.werkzeugRequirement, interpreter.xgboostRequirement());
        interpreter.statsforecastRequirement().ifPresent(corePackages::add);
        interpreter.pyamlPackage().ifPresent(corePackages::add);
        if (interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON37)) {
            Collections.addAll(corePackages, interpreter.tdigestRequirement(), interpreter.econmlRequirement());
        }
        return corePackages;
    }

    /** Builds the streaming preset; conda envs get version ranges, others get exact pins. */
    private static CodeEnvPackagePreset streamingPreset(boolean useConda) {
        CodeEnvPackagePreset streaming = new CodeEnvPackagePreset(CodeEnvPreset.STREAMING, "Native Streaming Access (Kafka & HTTP SSE)");
        if (useConda) {
            streaming.packages.add("pykafka>=2.8.0,<2.9");
            streaming.packages.add("sseclient>=0.0.26,<0.1");
        } else {
            streaming.packages.add("pykafka==2.8.0");
            streaming.packages.add("sseclient==0.0.26");
        }
        return streaming;
    }

    /** Builds the agent and LLM evaluation preset (torch pinned per OS and Python version). */
    private static CodeEnvPackagePreset llmEvaluationPreset(StandardPythonInterpreter interpreter) {
        CodeEnvPackagePreset llmEvaluation = new CodeEnvPackagePreset(CodeEnvPreset.AGENT_AND_LLM_EVALUATION, "Agent and LLM Evaluation");
        llmEvaluation.packages.add(interpreter.ragasRequirement());
        llmEvaluation.packages.add(interpreter.langchainRequirement());
        if (!interpreter.isVersionGreaterOrEqual(StandardPythonInterpreter.PYTHON313)) {
            llmEvaluation.packages.add(interpreter.numpy1Requirement());
        }
        llmEvaluation.packages.add("bert-score");
        llmEvaluation.packages.add(llmEvaluationTorchRequirement(interpreter));
        llmEvaluation.packages.add(interpreter.cpuTorchLinkRequirement);
        llmEvaluation.packages.add("sacrebleu");
        llmEvaluation.packages.add("rouge-score");
        llmEvaluation.packages.add(interpreter.pydanticRequirement());
        llmEvaluation.packages.add(interpreter.scipyRequirement());
        llmEvaluation.packages.add(interpreter.scikitLearnRequirement());
        return llmEvaluation;
    }

    /** Torch pin for the LLM evaluation preset: macOS gets plain builds, other OSes the {@code +cpu} ones. */
    private static String llmEvaluationTorchRequirement(StandardPythonInterpreter interpreter) {
        boolean upToPython38 = interpreter.isVersionLowerOrEqual(StandardPythonInterpreter.PYTHON38);
        if (DKUtils.isOsMacOS()) {
            return upToPython38 ? "torch==2.2.2" : "torch==2.6.0";
        }
        return upToPython38 ? "torch==2.3.1+cpu" : "torch==2.6.0+cpu";
    }

    /**
     * Tells whether the LLM evaluation preset may be offered.
     *
     * @return {@code true} when the {@code PRESET_GENERATION} environment variable is {@code "1"};
     *         otherwise the {@code advancedLLMMeshAllowed} flag of the license features status
     */
    protected static boolean licenseAllowsLLMEvaluation() {
        if ("1".equals(System.getenv("PRESET_GENERATION"))) {
            return true;
        }
        LicenseEnforcementService licenseEnforcementService = (LicenseEnforcementService)SpringUtils.getBean(LicenseEnforcementService.class);
        return licenseEnforcementService.getFeaturesStatus().advancedLLMMeshAllowed;
    }

    /** Known preset ids, each carrying the version reported by {@link CodeEnvPackagePreset#getVersion()}. */
    public static enum CodeEnvPreset {
        DOCTOR(2, 0),
        DOCTOR_DL(2, 0),
        DOCTOR_SENTENCE_EMBEDDING(2, 1),
        DOCTOR_TIMESERIES_GPU_CUDA112(2, 0),
        DOCTOR_TIMESERIES_GPU_CUDA_LATEST(1, 0),
        DOCTOR_CAUSAL(2, 0),
        STREAMING(1, 0),
        DOCTOR_DEEP_NEURAL_NETWORK(2, 0),
        DOCTOR_DL_GPU(2, 0),
        LOCAL_HF(1, 12),
        TEXT_EMBEDDING_EXTRACTION(1, 8),
        AGENT_AND_LLM_EVALUATION(1, 2);

        public final int major;
        public final int minor;

        private CodeEnvPreset(int major, int minor) {
            this.major = major;
            this.minor = minor;
        }
    }

    /**
     * A concrete preset: id and version copied from a {@link CodeEnvPreset}, a description and the
     * requirement lines. Fields are package-private and mutable; {@link #packages} is filled in by the
     * builders above.
     */
    public static class CodeEnvPackagePreset {
        String id;
        String description;
        int major;
        int minor;
        List<String> packages;

        /** Creates a preset using {@code packages} as its (not copied) requirement list. */
        public CodeEnvPackagePreset(CodeEnvPreset preset, String description, List<String> packages) {
            this.id = preset.name();
            this.description = description;
            this.major = preset.major;
            this.minor = preset.minor;
            this.packages = packages;
        }

        /** Creates a preset with an empty, mutable requirement list. */
        public CodeEnvPackagePreset(CodeEnvPreset preset, String description) {
            this(preset, description, new ArrayList<String>());
        }

        /** @return the preset id, i.e. the {@link CodeEnvPreset} constant name */
        public String getId() {
            return this.id;
        }

        /** @return the version as {@code "major.minor"} */
        public String getVersion() {
            return String.format("%s.%s", this.major, this.minor);
        }

        /** @return the requirement lines joined with {@code \n} */
        public String getRequirementsString() {
            return String.join("\n", this.packages);
        }
    }
}

