/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml;

import com.dataiku.dip.DSSTempUtils;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.CustomPythonPredictionAlgoService;
import com.dataiku.dip.analysis.model.core.PreTrainModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.exceptions.UnavailableTypeException;
import com.dataiku.dip.plugins.IPluginsRegistryService;
import com.dataiku.dip.plugins.model.PluginDesc;
import com.dataiku.dip.remoterun.RemoteRunFileExchangeService;
import com.dataiku.dip.transactions.fs.NativeFS;
import com.dataiku.dip.transactions.fs.RelFile;
import com.dataiku.dip.transactions.fs.ifaces.ReadWriteFS;
import com.dataiku.dip.transactions.fs.utils.FSUtils;
import com.dataiku.dip.util.AutoDelete;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dss.shadelib.org.apache.commons.io.filefilter.TrueFileFilter;
import jakarta.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.FilenameFilter;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Resolves which plugins (and which of their custom Python algorithms) an ML model depends on.
 * It copies the corresponding plugin files into a work directory under {@value #PLUGINS_FOLDER_NAME}.
 * It also computes the Python path entries and resource-folder environment variables for such a directory.
 */
@Service
public class MLPluginsService {
    @Autowired
    private IPluginsRegistryService pluginsService;
    @Autowired
    private CustomPythonPredictionAlgoService customPyAlgorithmsService;
    /** Folder, inside a work directory, that receives one sub-folder per used plugin. */
    private static final String PLUGINS_FOLDER_NAME = "dku-ml-plugins";
    /** Per-plugin sub-folder holding the copy of the plugin's python-lib folder. */
    private static final String PYTHON_LIB_FOLDER_NAME = "python-lib";
    /** Per-plugin sub-folder holding the copy of the plugin's resource folder. */
    private static final String RESOURCE_FOLDER_NAME = "resource";
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.ml.plugins");

    /**
     * Packages the files of one plugin custom algorithm (plus the plugin's python-lib and resource
     * folders) into a temporary folder and streams it as a tar archive in the HTTP response.
     *
     * @param pluginId  id of the plugin providing the algorithm
     * @param elementId id of the custom algorithm inside the plugin
     * @param resp      response the archive is written to (status 200, {@code application/x-tar})
     * @throws IOException if copying or archiving fails
     * @throws UnavailableTypeException if the algorithm does not exist in the plugin
     */
    public void downloadMlPlugins(String pluginId, String elementId, HttpServletResponse resp) throws IOException {
        Map<String, PluginInfo> usedPlugins = this.getUsedPlugins(pluginId, elementId);
        // The temp folder is removed when the try block exits (AutoDelete).
        try (AutoDelete file = DSSTempUtils.getTempFolder((String)"ml-plugins-lib");){
            this.copyNecessaryPluginFiles((File)file, usedPlugins);
            resp.setStatus(200);
            resp.setContentType("application/x-tar");
            resp.setHeader("Content-Disposition", "attachment; filename=\"" + file.getName() + "\"");
            RemoteRunFileExchangeService.archiveDirectoryToOutputStream((File)file, (FilenameFilter)TrueFileFilter.INSTANCE, (OutputStream)resp.getOutputStream());
        }
    }

    /**
     * Builds the used-plugins map for a single explicit custom algorithm.
     *
     * @throws UnavailableTypeException if the algorithm does not exist in the plugin
     */
    private Map<String, PluginInfo> getUsedPlugins(String pluginId, String elementId) {
        if (!this.customPyAlgorithmsService.exists(pluginId, elementId)) {
            throw UnavailableTypeException.fromTypeAndPlugin(elementId, pluginId);
        }
        PluginDesc pluginDesc = this.pluginsService.getDesc(pluginId);
        PluginInfo pluginInfo = new PluginInfo(pluginDesc);
        pluginInfo.addCustomAlgo(elementId);
        HashMap<String, PluginInfo> ret = new HashMap<String, PluginInfo>();
        ret.put(pluginId, pluginInfo);
        return ret;
    }

    /**
     * Returns the plugins used by the given modeling params, keyed by plugin id.
     * Only prediction modeling params can reference plugins; any other kind yields an empty map.
     */
    public Map<String, PluginInfo> getUsedPlugins(PreTrainModelingParams modeling) throws IOException {
        if (modeling instanceof PreTrainPredictionModelingParams) {
            PreTrainPredictionModelingParams predictionModeling = (PreTrainPredictionModelingParams)modeling;
            return this.getUsedPlugins(predictionModeling);
        }
        return new HashMap<String, PluginInfo>();
    }

    /**
     * Returns the plugins used by a saved model.
     * External MLflow model versions yield an empty map.
     * Other models are resolved by parsing their {@code rmodeling_params.json}.
     *
     * @throws IOException if the modeling params file cannot be read
     */
    public Map<String, PluginInfo> getUsedPlugins(FullModelId fmi) throws IOException {
        if (fmi.isExternalMLflowModelVersion()) {
            return new HashMap<String, PluginInfo>();
        }
        PreTrainModelingParams modeling = fmi.parseModelFile("rmodeling_params.json", PreTrainModelingParams.class);
        return this.getUsedPlugins(modeling);
    }

    /**
     * Resolves the plugins used by prediction modeling params.
     * A CUSTOM_PLUGIN algorithm yields its single plugin.
     * A PYTHON_ENSEMBLE yields the merge of the plugins used by every sub-model.
     * Any other algorithm uses no plugin.
     *
     * @throws UnavailableTypeException if a referenced custom algorithm no longer exists
     */
    private Map<String, PluginInfo> getUsedPlugins(PreTrainPredictionModelingParams modelingParams) {
        switch (modelingParams.algorithm) {
            case CUSTOM_PLUGIN: {
                if (modelingParams.plugin_python_grid == null) {
                    logger.warn((Object)"Cannot retrieve used plugin, 'plugin_python_grid' is undefined");
                    return new HashMap<String, PluginInfo>();
                }
                if (!this.customPyAlgorithmsService.exists(modelingParams.plugin_python_grid.pluginId, modelingParams.plugin_python_grid.elementId)) {
                    throw UnavailableTypeException.fromTypeAndPlugin(modelingParams.plugin_python_grid.elementId, modelingParams.plugin_python_grid.pluginId);
                }
                PluginDesc pluginDesc = this.pluginsService.getDesc(modelingParams.plugin_python_grid.pluginId);
                // Null-safe comparison: a descriptor without a version must not crash resolution.
                if (!Objects.equals(pluginDesc.version, modelingParams.plugin_python_grid.pluginVersion)) {
                    logger.warn((Object)("Using a plugin algorithm ('" + modelingParams.plugin_python_grid.name + "') for which version has changed. It was created with version '" + modelingParams.plugin_python_grid.pluginVersion + "' and now it's '" + pluginDesc.version + "'"));
                }
                HashMap<String, PluginInfo> ret = new HashMap<String, PluginInfo>();
                PluginInfo pluginInfo = new PluginInfo(pluginDesc);
                pluginInfo.addCustomAlgo(modelingParams.plugin_python_grid.elementId);
                ret.put(modelingParams.plugin_python_grid.pluginId, pluginInfo);
                return ret;
            }
            case PYTHON_ENSEMBLE: {
                // Mirror the plugin_python_grid guard above instead of failing with an NPE.
                if (modelingParams.ensemble_params == null || modelingParams.ensemble_params.modeling_params == null) {
                    logger.warn((Object)"Cannot retrieve used plugins, 'ensemble_params' is undefined");
                    return new HashMap<String, PluginInfo>();
                }
                return PluginInfo.mergeMapsPluginInfos(modelingParams.ensemble_params.modeling_params.stream().map(this::getUsedPlugins).collect(Collectors.toList()));
            }
        }
        return new HashMap<String, PluginInfo>();
    }

    /** Name of the environment variable exposing the resource folder of the given plugin. */
    private static String getMLPluginResourceFolderEnvVar(String pluginId) {
        return "DKU_CUSTOM_ML_RESOURCE_FOLDER_" + pluginId;
    }

    /** Creates {@code workDir} if needed, then copies the plugin files into it. */
    private void copyNecessaryPluginFiles(File workDir, Map<String, PluginInfo> usedPlugins) throws IOException {
        if (!workDir.exists()) {
            DKUFileUtils.mkdirs((File)workDir);
        }
        this.copyNecessaryPluginFiles((ReadWriteFS)NativeFS.from((File)workDir).build(), usedPlugins);
    }

    /**
     * Copies the files of every used plugin into {@code workDirFS}.
     * The layout is {@code <PLUGINS_FOLDER_NAME>/<pluginId>/{python-lib,resource,<customAlgosFolder>/<algo>}}.
     * An empty {@code __init__.py} is written in each package folder.
     *
     * <p>A plugin whose folder already exists is skipped entirely. Custom algorithms requested in a
     * later call for an already-copied plugin are therefore not added.
     *
     * @param workDirFS   destination filesystem
     * @param usedPlugins plugins to copy, keyed by plugin id; {@code null} or empty is a no-op
     * @throws IOException if any copy or write fails
     */
    public void copyNecessaryPluginFiles(ReadWriteFS workDirFS, Map<String, PluginInfo> usedPlugins) throws IOException {
        if (usedPlugins == null || usedPlugins.isEmpty()) {
            return;
        }
        RelFile pluginsFolder = RelFile.fromPath((String)PLUGINS_FOLDER_NAME);
        if (!workDirFS.exists(pluginsFolder)) {
            workDirFS.makeDirectory(pluginsFolder);
            this.createEmptyInitPyFile(workDirFS, pluginsFolder);
        }
        for (PluginInfo pluginInfo : usedPlugins.values()) {
            if (pluginInfo == null) {
                // mergeMapsPluginInfos may carry null values through; there is nothing to copy for them.
                continue;
            }
            RelFile pluginFolder = pluginsFolder.append(new String[]{pluginInfo.pluginId});
            if (workDirFS.exists(pluginFolder)) {
                logger.info((Object)("ml plugin files for plugin: " + pluginInfo.pluginId + " have already been copied, not copying them again"));
                continue;
            }
            logger.info((Object)("Adding ml plugin files for plugin: " + pluginInfo.pluginId));
            workDirFS.makeDirectory(pluginFolder);
            this.createEmptyInitPyFile(workDirFS, pluginFolder);
            File libDir = this.pluginsService.getPluginPythonlibFolder(pluginInfo.pluginId);
            if (libDir != null && libDir.isDirectory()) {
                RelFile newLibDir = pluginFolder.append(new String[]{PYTHON_LIB_FOLDER_NAME});
                FSUtils.newRecursiveCopy().from(libDir).to(workDirFS, newLibDir).run();
            }
            File resourceDir = this.pluginsService.getPluginResourceFolder(pluginInfo.pluginId);
            if (resourceDir != null && resourceDir.isDirectory()) {
                RelFile newResourceDir = pluginFolder.append(new String[]{RESOURCE_FOLDER_NAME});
                FSUtils.newRecursiveCopy().from(resourceDir).to(workDirFS, newResourceDir).run();
            }
            if (!pluginInfo.customAlgos.isEmpty()) {
                this.copyCustomAlgos(workDirFS, pluginFolder, pluginInfo);
            }
        }
    }

    /** Copies each custom algorithm folder of {@code pluginInfo} under {@code pluginFolder}, as a Python package. */
    private void copyCustomAlgos(ReadWriteFS workDirFS, RelFile pluginFolder, PluginInfo pluginInfo) throws IOException {
        String customAlgosFolderName = this.customPyAlgorithmsService.getFolderName();
        RelFile customAlgosFolder = pluginFolder.append(new String[]{customAlgosFolderName});
        workDirFS.makeDirectory(customAlgosFolder);
        this.createEmptyInitPyFile(workDirFS, customAlgosFolder);
        for (String customAlgo : pluginInfo.customAlgos) {
            RelFile customAlgoFolder = customAlgosFolder.append(new String[]{customAlgo});
            File sourceAlgoFolder = DKUFileUtils.getWithin((File)this.pluginsService.getActualPluginFolder(pluginInfo.pluginId), (String[])new String[]{customAlgosFolderName, customAlgo});
            FSUtils.newRecursiveCopy().from(sourceAlgoFolder).to(workDirFS, customAlgoFolder).run();
            this.createEmptyInitPyFile(workDirFS, customAlgoFolder);
        }
    }

    /** Writes an empty {@code __init__.py} in {@code directory}. */
    private void createEmptyInitPyFile(ReadWriteFS readWriteFS, RelFile directory) throws IOException {
        readWriteFS.writeStringUTF8(directory.append(new String[]{"__init__.py"}), "");
    }

    /**
     * Computes the Python path entries and resource env variables for a folder laid out as
     * {@link #copyNecessaryPluginFiles(ReadWriteFS, Map)} produces it.
     *
     * <p>The outer folder itself is added to the Python path, followed by each plugin's
     * {@code python-lib} folder in plugin-folder name order. Each plugin with a {@code resource}
     * folder gets a {@code DKU_CUSTOM_ML_RESOURCE_FOLDER_<pluginId>} entry.
     *
     * @param mlPluginsOuterFolder work directory; if it is not a directory, both results are empty
     */
    public static MLPluginEnvInfo getPluginsLibsPath(File mlPluginsOuterFolder) {
        ArrayList<String> additionalPythonPaths = new ArrayList<String>();
        HashMap<String, String> resourcePaths = new HashMap<String, String>();
        if (!mlPluginsOuterFolder.isDirectory()) {
            return new MLPluginEnvInfo(additionalPythonPaths, resourcePaths);
        }
        additionalPythonPaths.add(mlPluginsOuterFolder.getAbsolutePath());
        File pluginsFolder = DKUFileUtils.getWithin((File)mlPluginsOuterFolder, (String[])new String[]{PLUGINS_FOLDER_NAME});
        if (!pluginsFolder.isDirectory()) {
            return new MLPluginEnvInfo(additionalPythonPaths, resourcePaths);
        }
        File[] pluginDirs = pluginsFolder.listFiles();
        if (pluginDirs == null) {
            // listFiles() returns null on I/O error even for an existing directory.
            logger.warn((Object)("Cannot list ML plugins folder: " + pluginsFolder.getAbsolutePath()));
            return new MLPluginEnvInfo(additionalPythonPaths, resourcePaths);
        }
        // listFiles() order is unspecified; sort so the Python path (and thus module resolution) is reproducible.
        Arrays.sort(pluginDirs);
        for (File pluginDir : pluginDirs) {
            if (!pluginDir.isDirectory()) continue;
            File pluginLibDir = DKUFileUtils.getWithin((File)pluginDir, (String[])new String[]{PYTHON_LIB_FOLDER_NAME});
            if (pluginLibDir.isDirectory()) {
                additionalPythonPaths.add(pluginLibDir.getAbsolutePath());
            }
            File resourceDir = DKUFileUtils.getWithin((File)pluginDir, (String[])new String[]{RESOURCE_FOLDER_NAME});
            if (resourceDir.isDirectory()) {
                resourcePaths.put(MLPluginsService.getMLPluginResourceFolderEnvVar(pluginDir.getName()), resourceDir.getAbsolutePath());
            }
        }
        return new MLPluginEnvInfo(additionalPythonPaths, resourcePaths);
    }

    /** A used plugin: its id, its version, and the names of the custom algorithms used from it. */
    public static class PluginInfo {
        String pluginId;
        String pluginVersion;
        public Set<String> customAlgos = new HashSet<String>();

        public PluginInfo(PluginDesc pluginDesc) {
            this.pluginId = pluginDesc.id;
            this.pluginVersion = pluginDesc.version;
        }

        public PluginInfo(String pluginId, String pluginVersion) {
            this.pluginId = pluginId;
            this.pluginVersion = pluginVersion;
        }

        /** Records a custom algorithm name; null or blank names are ignored. */
        void addCustomAlgo(String customAlgo) {
            if (StringUtils.isNotBlank((String)customAlgo)) {
                this.customAlgos.add(customAlgo);
            }
        }

        /** Records several custom algorithm names, with the same filtering as {@link #addCustomAlgo}. */
        void addCustomAlgos(Set<String> customAlgos) {
            if (customAlgos != null) {
                for (String customAlgo : customAlgos) {
                    this.addCustomAlgo(customAlgo);
                }
            }
        }

        /**
         * Merges two infos for the same plugin.
         * The result is a new instance with {@code pluginInfo1}'s id and version and the union of both algorithm sets.
         * If one argument is null, the other is returned as-is.
         */
        public static PluginInfo merge(PluginInfo pluginInfo1, PluginInfo pluginInfo2) {
            if (pluginInfo1 == null) {
                return pluginInfo2;
            }
            if (pluginInfo2 == null) {
                return pluginInfo1;
            }
            assert (pluginInfo1.pluginId.equals(pluginInfo2.pluginId));
            PluginInfo ret = new PluginInfo(pluginInfo1.pluginId, pluginInfo1.pluginVersion);
            ret.addCustomAlgos(pluginInfo1.customAlgos);
            ret.addCustomAlgos(pluginInfo2.customAlgos);
            return ret;
        }

        /**
         * Merges several plugin-id-keyed maps into one; null maps are skipped.
         * When a plugin id repeats, its infos are combined with {@link #merge}, the later entry taking the first-argument role.
         */
        static Map<String, PluginInfo> mergeMapsPluginInfos(List<Map<String, PluginInfo>> usedPluginsList) {
            HashMap<String, PluginInfo> ret = new HashMap<String, PluginInfo>();
            if (usedPluginsList == null || usedPluginsList.isEmpty()) {
                return ret;
            }
            for (Map<String, PluginInfo> usedPlugins : usedPluginsList) {
                if (usedPlugins == null) continue;
                for (Map.Entry<String, PluginInfo> pluginInfoEntry : usedPlugins.entrySet()) {
                    String pluginId = pluginInfoEntry.getKey();
                    PluginInfo pluginInfo = pluginInfoEntry.getValue();
                    if (!ret.containsKey(pluginId)) {
                        ret.put(pluginId, pluginInfo);
                        continue;
                    }
                    ret.put(pluginId, PluginInfo.merge(pluginInfo, (PluginInfo)ret.get(pluginId)));
                }
            }
            return ret;
        }
    }

    /** Python path entries and env variables (name to resource folder path) needed to run plugin ML code. */
    public static class MLPluginEnvInfo {
        public List<String> additionalPythonPaths;
        public Map<String, String> pluginResourcesEnv;

        MLPluginEnvInfo(List<String> additionalPythonPaths, Map<String, String> pluginResourcesEnv) {
            this.additionalPythonPaths = additionalPythonPaths;
            this.pluginResourcesEnv = pluginResourcesEnv;
        }
    }
}

