package com.dataiku.dip.plugins.cpr;

import java.io.IOException;
import java.lang.reflect.Field; 
import java.util.Arrays;
import java.util.HashSet;
import java.util.regex.Pattern;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

import com.dataiku.dip.connections.AbstractSQLConnection.CustomDatabaseProperty;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.DSSConnection;
import com.dataiku.dip.connections.DSSConnection.ConnectionUsableBy;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.ExposedObject;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.coremodel.SerializedProject;
import com.dataiku.dip.coremodel.SerializedProject.PermissionItem;
import com.dataiku.dip.coremodel.SerializedProject.ProjectAppType;
import com.dataiku.dip.cuspol.CustomPolicyHooks;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dao.SQLNotebooksDAO;
import com.dataiku.dip.dao.StreamingEndpointsDAO;
import com.dataiku.dip.dao.UsersDAO;
import com.dataiku.dip.dao.UsersDAO.User;
import com.dataiku.dip.datasets.fs.BuiltinFSDatasets.UploadedFilesConfig;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.futures.SimpleFutureThread;
import com.dataiku.dip.managedfolder.ManagedFolder;
import com.dataiku.dip.managedfolder.ManagedFolderDAO;
import com.dataiku.dip.plugins.RegularPluginsRegistryService;
import com.dataiku.dip.projects.apps.AppsService;
import com.dataiku.dip.projects.apps.AppsService.AppOrigin;
import com.dataiku.dip.scheduler.ScenariosDAO;
import com.dataiku.dip.scheduler.scenarios.Scenario;
import com.dataiku.dip.scheduler.scenarios.StepBasedScenarioRunner.StepBasedScenarioParams;
import com.dataiku.dip.scheduler.steps.ExecuteSQLStepRunner;
import com.dataiku.dip.scheduler.steps.ExecuteSQLStepRunner.ExecuteSQLStepParams;
import com.dataiku.dip.scheduler.steps.Step;
import com.dataiku.dip.scheduler.triggers.SQLQueryTriggerRunner;
import com.dataiku.dip.scheduler.triggers.SQLQueryTriggerRunner.SQLQueryTriggerParams;
import com.dataiku.dip.scheduler.triggers.Trigger;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.datasets.DatasetSaveService.DatasetCreationContext;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TaggableObjectsService.TaggableObject;
import com.dataiku.dip.sqlnotebooks.SQLNotebook;
import com.dataiku.dip.streaming.endpoints.model.StreamingEndpoint;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonObject;

/**
 * Custom policy hooks that restrict which connections may be used from which projects.
 *
 * <p>The hooks are driven by the settings of the {@code connections-projects-restrictions} plugin:
 * <ul>
 *   <li>{@code enabled}: when present and false, no check is performed at all;</li>
 *   <li>{@code restriction_mode}: name of a {@link Mode} value; when absent,
 *       {@link Mode#ONLY_IF_FREELY_USABLE_BY_ALL} is used.</li>
 * </ul>
 *
 * <p>A connection that does not exist (or an empty connection name) is always considered
 * compliant, since there is nothing to protect.
 */
public class ConnectionsProjectsRestrictionsHooks extends CustomPolicyHooks {
    private static final DKULogger logger = DKULogger.getLogger("dku.plugins.cpr.hooks");

    /** Identifier under which the settings of this plugin are looked up. */
    private static final String PLUGIN_ID = "connections-projects-restrictions";

    @Autowired private ProjectsService projectsService;
    @Autowired private RegularPluginsRegistryService regularPluginsRegistryService;
    @Autowired private ConnectionsDAO connectionsDAO;
    @Autowired private UsersDAO usersDAO;

    /* The list of objects that are checked in a project are:
     *   - Dataset
     *   - Managed folder
     *   - Streaming endpoint
     *   - Scenario
     *   - SQL Notebook
     *
     * Check that in all parts of the code where we iterate over types, we handle all of them
     */
    @Autowired private DatasetsDAO datasetsDAO;
    @Autowired private ManagedFolderDAO managedFoldersDAO;
    @Autowired private StreamingEndpointsDAO streamingEndpointsDAO;
    @Autowired private SQLNotebooksDAO sqlNotebooksDAO;
    @Autowired private ScenariosDAO scenariosDAO;

    /** Strategy used to decide whether a connection may be used in a project. */
    enum Mode {
        /** The connection must be usable by every permission item (user or group) of the project. */
        ONLY_IF_FREELY_USABLE_BY_ALL,
        /**
         * The project (or the app it is an instance of) must be listed in the
         * {@code dku.security.allowedInProjects} (resp. {@code dku.security.allowedInApps})
         * property of the connection.
         */
        EXPLICIT_PROJECTS_LIST_PER_CONNECTION
    }

    /**
     * Tells whether the plugin has been explicitly disabled in its settings.
     *
     * @param pluginSettings the plugin configuration
     * @return true only if an {@code enabled} entry exists and is false
     */
    private static boolean isPluginDisabled(JsonObject pluginSettings) {
        if (pluginSettings.has("enabled") && !pluginSettings.get("enabled").getAsBoolean()) {
            logger.debug("CPR plugin disabled, not checking anything");
            return true;
        }
        return false;
    }

    /**
     * Builds the exception thrown when an object contained in a project uses a forbidden connection.
     *
     * @param projectKey        key of the project being saved
     * @param objectDescription object kind followed by its identifier, e.g. {@code "dataset foo"}
     */
    private static Exception saveDenied(String projectKey, String objectDescription) {
        return new Exception("Save denied: project " + projectKey + " contains " + objectDescription
                + " which is not allowed in this project according to strict connection rules");
    }

    /**
     * Builds the exception thrown when an object is shared to a project that may not use its connection.
     *
     * @param sourceProjectKey key of the project sharing the object
     * @param objectKind       human readable kind of the shared object, e.g. {@code "dataset"}
     * @param localName        name of the shared object in the source project
     * @param targetProjectKey key of the project the object is shared to
     */
    private static Exception shareDenied(String sourceProjectKey, String objectKind, String localName, String targetProjectKey) {
        return new Exception("Save denied: project " + sourceProjectKey + " shares " + objectKind + " " + localName
                + " to project " + targetProjectKey + ", but this target project may not use the connection of this " + objectKind);
    }

    /**
     * Checks whether a single project permission item (a group or a user) is allowed to use the connection.
     *
     * <p>Items that cannot be evaluated (unknown user, neither user nor group set) are allowed by default.
     *
     * @param conn the connection to check
     * @param pi   one permission item of the project
     * @return true if the connection is usable by all, or by the group / one of the groups of the user
     */
    private boolean isConnectionFreelyUsableByProjectPermissionItem(DSSConnection conn, PermissionItem pi) throws IOException {
        if (conn.usableBy == ConnectionUsableBy.ALL) {
            return true;
        }
        if (StringUtils.isNotBlank(pi.group)) {
            return conn.allowedGroups.contains(pi.group);
        }
        if (StringUtils.isBlank(pi.user)) {
            logger.warn("Invalid permission item: no user nor group");
            return true; // Default allow
        }

        User u = usersDAO.getOrNullUnsafe(pi.user);
        if (u == null) {
            logger.warn("User not found: " + pi.user);
            return true; // Default allow
        }
        for (String allowedGroup : conn.allowedGroups) {
            if (u.groups.contains(allowedGroup)) {
                return true;
            }
        }
        return false;
    }

    /**
     * {@link Mode#ONLY_IF_FREELY_USABLE_BY_ALL} check: every permission item of the project must be
     * allowed to use the connection.
     */
    private boolean isUsableByAllProjectPermissionItems(DSSConnection conn, String connectionName, SerializedProject sp) throws IOException {
        for (PermissionItem pi : sp.permissions) {
            if (!isConnectionFreelyUsableByProjectPermissionItem(conn, pi)) {
                logger.info("Connection " + connectionName + " is not usable by project permission item: " + JSON.json(pi));
                return false;
            }
        }
        return true;
    }

    /**
     * Matches a key against one entry of an allowed list, where {@code *} stands for any sequence
     * of characters.
     *
     * <p>The entry is turned into a regular expression by only expanding {@code *}; this relies on
     * entries being projectKey-like (only [a-zA-Z0-9_]). An entry that is not a valid regular
     * expression is logged and treated as not matching, rather than failing the whole check.
     */
    private static boolean matchesAllowedPattern(String allowedPattern, String key) {
        try {
            return Pattern.matches(allowedPattern.replace("*", ".*"), key);
        } catch (java.util.regex.PatternSyntaxException e) {
            logger.warn("Ignoring invalid allowed pattern '" + allowedPattern + "': " + e.getMessage());
            return false;
        }
    }

    /**
     * {@link Mode#EXPLICIT_PROJECTS_LIST_PER_CONNECTION} check.
     *
     * <p>For a regular project, the project key must match one of the comma-separated entries of the
     * {@code dku.security.allowedInProjects} property of the connection. For an app instance whose
     * origin can be resolved, the key of the app (project key of the template, or the full app id for
     * plugin apps) must match one of the entries of {@code dku.security.allowedInApps} instead.
     */
    private boolean isExplicitlyAllowedForProject(DSSConnection conn, String connectionName, SerializedProject sp) throws IOException, DKUSecurityException {
        ProjectAppType projectAppType = sp.projectAppType;
        String generatingAppId = StringUtils.defaultIfEmpty(sp.generatingAppId, "");
        if (projectAppType == ProjectAppType.APP_TEMPLATE) {
            // might be an app getting instantiated. But the first save of the new project
            // is done with the copy of the app template (with a different name) so you have
            // to fight a bit to get the generating appId which is not yet set (the postSaveHook
            // is what fills sp.generatingAppId)
            Thread currentThread = Thread.currentThread();
            if (currentThread instanceof SimpleFutureThread
                    && "AppInstantiationFutureThread".equals(currentThread.getClass().getSimpleName())) {
                // there's a chance it's an instantiation, dig further
                try {
                    // use reflection because both class and field are private
                    Field appIdField = currentThread.getClass().getDeclaredField("appId");
                    appIdField.setAccessible(true);
                    generatingAppId = (String) appIdField.get(currentThread);
                    projectAppType = ProjectAppType.APP_INSTANCE;
                } catch (ReflectiveOperationException | RuntimeException e) {
                    // best effort: fall back to checking the project as a regular project
                    logger.warn("Unable to check if current project is an app instance because " + e);
                }
            }
        }

        AppOrigin appOrigin = null;
        if (StringUtils.isNotBlank(generatingAppId)) {
            try {
                appOrigin = AppOrigin.fromAppId(generatingAppId);
            } catch (Exception e) {
                logger.warn("Unable to find app origin of " + generatingAppId);
            }
        }

        final String projectKeyToCheck;
        final String projectKeyInMessage;
        final String allowedProp;
        if (projectAppType == ProjectAppType.APP_INSTANCE && appOrigin != null) {
            // For plugin apps, the actual value of the appId is
            // "PLUGIN_" + pluginId + "_" + elementId
            // as seen in CustomAppTemplatesService.
            // It can't readily be simplified, so let's use it for
            // values to put in dku.security.allowedInApps.
            projectKeyToCheck = appOrigin == AppOrigin.PROJECT
                    ? AppsService.getProjectKey(generatingAppId)
                    : generatingAppId;
            projectKeyInMessage = String.format("%s which is an instance of %s ", sp.projectKey, projectKeyToCheck);
            allowedProp = CustomDatabaseProperty.getDkuPropertyOrDefault(conn.getDkuProperties(), "dku.security.allowedInApps", "");
        } else {
            projectKeyToCheck = sp.projectKey;
            projectKeyInMessage = sp.projectKey;
            allowedProp = CustomDatabaseProperty.getDkuPropertyOrDefault(conn.getDkuProperties(), "dku.security.allowedInProjects", "");
        }

        Set<String> allowed = Arrays.stream(allowedProp.split(","))
                .map(String::trim)
                .filter(StringUtils::isNotBlank)
                .collect(Collectors.toSet());
        if (allowed.stream().anyMatch(pattern -> matchesAllowedPattern(pattern, projectKeyToCheck))) {
            return true;
        }
        logger.info("Connection " + connectionName + " is not allowed for project " + projectKeyInMessage + " (allowed: " + allowedProp + ")");
        return false;
    }

    /**
     * Checks whether a connection may be used in a project, according to the configured {@link Mode}.
     *
     * @param authCtx        authentication context used to look the connection up
     * @param sp             project in which the connection would be used
     * @param pluginSettings plugin configuration (read for {@code restriction_mode})
     * @param connectionName name of the connection; a blank name or an unknown connection is compliant
     * @return true if the connection may be used in the project
     * @throws IllegalArgumentException if {@code restriction_mode} is not the name of a {@link Mode}
     */
    private boolean isConnectionCompliant(AuthCtx authCtx, SerializedProject sp, JsonObject pluginSettings, String connectionName) throws IOException, DKUSecurityException {
        if (StringUtils.isBlank(connectionName)) {
            // Same policy as for datasets / folders / streaming endpoints without connection:
            // e.g. a SQL step or notebook whose connection is not set yet has nothing to check
            logger.debug("No connection set, nothing to check");
            return true;
        }

        DSSConnection conn = connectionsDAO.getConnection(authCtx, connectionName);
        if (conn == null) {
            logger.info("Connection " + connectionName + " does not exist, nothing to check");
            return true;
        }

        Mode mode = Mode.ONLY_IF_FREELY_USABLE_BY_ALL;
        if (pluginSettings.has("restriction_mode")) {
            mode = Mode.valueOf(pluginSettings.get("restriction_mode").getAsString());
        }

        switch (mode) {
        case ONLY_IF_FREELY_USABLE_BY_ALL:
            return isUsableByAllProjectPermissionItems(conn, connectionName, sp);
        case EXPLICIT_PROJECTS_LIST_PER_CONNECTION:
            return isExplicitlyAllowedForProject(conn, connectionName, sp);
        default:
            throw new IllegalStateException("Unsupported restriction mode: " + mode);
        }
    }

    /**
     * Checks the connection of a dataset and, for uploaded files datasets, its upload connection.
     *
     * @return true if every connection used by the dataset may be used in project {@code sp}
     */
    private boolean isDatasetCompliant(AuthCtx authCtx, SerializedProject sp, JsonObject pluginSettings, SerializedDataset sd) throws IOException, DKUSecurityException {
        logger.info("Checking compliance of dataset " + sd.getFullName() + " in project  " + sp.projectKey);
        Dataset dataset = Dataset.fromSerialized(sd);
        Set<String> connectionsToCheck = new HashSet<>();

        String mainConnection = dataset.getParams().getConnection();
        if (mainConnection != null) {
            connectionsToCheck.add(mainConnection);
        }

        if ("UploadedFiles".equals(dataset.getType())) {
            UploadedFilesConfig upCfg = (UploadedFilesConfig) dataset.getParams();
            if (upCfg.uploadConnection != null) {
                connectionsToCheck.add(upCfg.uploadConnection);
            }
        }

        for (String connection : connectionsToCheck) {
            if (!isConnectionCompliant(authCtx, sp, pluginSettings, connection)) {
                logger.info("Dataset " + sd.name + " is not compliant because connection " + connection + " is not");
                return false;
            }
        }
        return true;
    }

    /** Checks the connection of a managed folder, if it has one. */
    private boolean isManagedFolderCompliant(AuthCtx authCtx, SerializedProject sp, JsonObject pluginSettings, ManagedFolder mf) throws IOException, DKUSecurityException {
        String connection = mf.getParams().connection;
        if (connection != null && !isConnectionCompliant(authCtx, sp, pluginSettings, connection)) {
            logger.info("Managed Folder " + mf.id + " is not compliant because connection " + connection + " is not");
            return false;
        }
        return true;
    }

    /** Checks the connection of a streaming endpoint, if it has one. */
    private boolean isStreamingEndpointCompliant(AuthCtx authCtx, SerializedProject sp, JsonObject pluginSettings, StreamingEndpoint se) throws IOException, DKUSecurityException {
        String connection = se.getParams().getConnection();
        if (connection != null && !isConnectionCompliant(authCtx, sp, pluginSettings, connection)) {
            logger.info("Streaming endpoint " + se.id + " is not compliant because connection " + connection + " is not");
            return false;
        }
        return true;
    }

    /** Checks the connection a SQL notebook runs on. */
    private boolean isSQLNotebookCompliant(AuthCtx authCtx, SerializedProject sp, JsonObject pluginSettings, SQLNotebook notebook) throws IOException, DKUSecurityException {
        if (!isConnectionCompliant(authCtx, sp, pluginSettings, notebook.connection)) {
            logger.info("SQL Notebook " + notebook.id + " is not compliant because connection " + notebook.connection + " is not");
            return false;
        }
        return true;
    }

    /**
     * Checks the connections used by the SQL query triggers of a scenario and, for step-based
     * scenarios, by its "execute SQL" steps.
     */
    private boolean isScenarioCompliant(AuthCtx authCtx, SerializedProject sp, JsonObject pluginSettings, Scenario scenario) throws IOException, DKUSecurityException {
        for (Trigger trigger : scenario.getTriggers()) {
            if (!SQLQueryTriggerRunner.META.getType().equals(trigger.getType())) {
                continue;
            }
            SQLQueryTriggerParams sqtp = trigger.getParamsAs(SQLQueryTriggerParams.class);
            if (!isConnectionCompliant(authCtx, sp, pluginSettings, sqtp.connection)) {
                logger.info("Scenario " + scenario.name + " is not compliant because connection " + sqtp.connection + " is not");
                return false;
            }
        }

        if (scenario.getParams() instanceof StepBasedScenarioParams) {
            StepBasedScenarioParams sbsp = scenario.getParamsAs(StepBasedScenarioParams.class);

            for (Step step : sbsp.steps) {
                if (!ExecuteSQLStepRunner.META.getType().equals(step.getType())) {
                    continue;
                }
                ExecuteSQLStepParams essp = step.getParamsAs(ExecuteSQLStepParams.class);
                if (!isConnectionCompliant(authCtx, sp, pluginSettings, essp.connection)) {
                    logger.info("Scenario " + scenario.name + " is not compliant because connection " + essp.connection + " is not");
                    return false;
                }
            }
        }

        return true;
    }

    /**
     * Checks every object that {@code sourceProject} shares to another project: the target project
     * must itself be allowed to use the connection of the shared object.
     *
     * <p>Rules without target project, rules whose target project does not exist and shared objects
     * that cannot be found are skipped.
     *
     * @throws Exception if a shared dataset, managed folder or streaming endpoint uses a connection
     *                   that its target project may not use
     */
    private void checkComplianceOfEachExposedObject(AuthCtx authCtx, SerializedProject sourceProject, JsonObject pluginSettings) throws Exception {
        if (sourceProject.exposedObjects == null) {
            return;
        }

        for (ExposedObject eo : sourceProject.exposedObjects.objects) {
            for (ExposedObject.Rule rule : eo.rules) {
                if (StringUtils.isBlank(rule.targetProject)) {
                    continue;
                }
                SerializedProject targetProject = projectsService.getOrNullUnsafe(rule.targetProject);
                if (targetProject == null) {
                    continue;
                }

                switch (eo.type) {
                case DATASET: {
                    SerializedDataset sd = datasetsDAO.getOrNull(sourceProject.projectKey, eo.localName);
                    if (sd != null && !isDatasetCompliant(authCtx, targetProject, pluginSettings, sd)) {
                        throw shareDenied(sourceProject.projectKey, "dataset", eo.localName, rule.targetProject);
                    }
                    break;
                }
                case MANAGED_FOLDER: {
                    ManagedFolder mf = managedFoldersDAO.getOrNull(sourceProject.projectKey, eo.localName);
                    if (mf != null && !isManagedFolderCompliant(authCtx, targetProject, pluginSettings, mf)) {
                        throw shareDenied(sourceProject.projectKey, "folder", eo.localName, rule.targetProject);
                    }
                    break;
                }
                case STREAMING_ENDPOINT: {
                    StreamingEndpoint se = streamingEndpointsDAO.getOrNull(sourceProject.projectKey, eo.localName);
                    if (se != null && !isStreamingEndpointCompliant(authCtx, targetProject, pluginSettings, se)) {
                        throw shareDenied(sourceProject.projectKey, "streaming endpoint", eo.localName, rule.targetProject);
                    }
                    break;
                }
                case SQL_NOTEBOOK: {
                    // Not really shareable, no need to handle
                    break;
                }
                case SCENARIO: {
                    // Not really shareable, no need to handle
                    break;
                }

                /* No need to handle */
                case ANALYSIS: // Already checked through dataset save
                case DASHBOARD: // Irrelevant
                case INSIGHT: // Already checked through dataset save
                case STATISTICS_WORKSHEET: // Already checked through dataset save
                case WORKSPACE: // Irrelevant
                case API_DEPLOYER_DEPLOYMENT: // Not shareable
                case API_DEPLOYER_INFRA:// Not shareable
                case API_DEPLOYER_SERVICE:// Not shareable
                case PROJECT_DEPLOYER_DEPLOYMENT:// Not shareable
                case PROJECT_DEPLOYER_INFRA:// Not shareable
                case PROJECT_DEPLOYER_PROJECT:// Not shareable
                case MODEL_COMPARISON: // Irrelevant
                case MODEL_EVALUATION_STORE: // Irrelevant
                case ARTICLE: // Irrelevant
                case FLOW_ZONE: // Not shareable
                case WORKSPACE_LINK: // Not shareable
                case REPORT: // If it uses connection in code, already checked
                case WEB_APP: // If it uses connection in code, already checked
                case JUPYTER_NOTEBOOK: // If it uses connection in code, already checked
                case SAVED_MODEL: // Irrelevant
                case RECIPE: // Not shareable
                case LAMBDA_SERVICE: // Not shareable
                case PROJECT: // Not shareable
                case CODE_STUDIO: // Not shareable
                case CODE_STUDIO_TEMPLATE: // Not shareable
                case LABELING_TASK: // Not shareable
                    break;
                }
            }
        }
    }

    /**
     * Checks every checked object (datasets, managed folders, streaming endpoints, scenarios and
     * SQL notebooks) currently stored in the project.
     *
     * @throws Exception on the first object using a connection that is not allowed in the project
     */
    private void checkComplianceOfEachProjectObject(AuthCtx authCtx, SerializedProject project, JsonObject pluginSettings) throws Exception {
        String projectKey = project.projectKey;

        for (SerializedDataset sd : datasetsDAO.listUnsafe(projectKey)) {
            if (!isDatasetCompliant(authCtx, project, pluginSettings, sd)) {
                throw saveDenied(projectKey, "dataset " + sd.name);
            }
        }

        for (ManagedFolder mf : managedFoldersDAO.listUnsafe(projectKey)) {
            if (!isManagedFolderCompliant(authCtx, project, pluginSettings, mf)) {
                throw saveDenied(projectKey, "managed folder " + mf.id);
            }
        }

        for (StreamingEndpoint se : streamingEndpointsDAO.listUnsafe(projectKey)) {
            if (!isStreamingEndpointCompliant(authCtx, project, pluginSettings, se)) {
                throw saveDenied(projectKey, "streaming endpoint " + se.id);
            }
        }

        for (Scenario s : scenariosDAO.listUnsafe(projectKey)) {
            if (!isScenarioCompliant(authCtx, project, pluginSettings, s)) {
                throw saveDenied(projectKey, "scenario " + s.id);
            }
        }

        for (SQLNotebook sn : sqlNotebooksDAO.listUnsafe(projectKey)) {
            if (!isSQLNotebookCompliant(authCtx, project, pluginSettings, sn)) {
                throw saveDenied(projectKey, "SQL notebook " + sn.id);
            }
        }
    }

    /**
     * Denies the creation of a dataset that uses a connection forbidden in its project.
     *
     * @throws Exception if the dataset is not compliant
     */
    @Override
    public void onPreDatasetCreation(AuthCtx authCtx, SerializedDataset serializedDataset, DatasetCreationContext context) throws Exception {
        JsonObject pluginSettings = regularPluginsRegistryService.getSettings(PLUGIN_ID).config;
        if (isPluginDisabled(pluginSettings)) {
            return;
        }

        SerializedProject sp = projectsService.getMandatory(serializedDataset.getProjectKey());

        if (!isDatasetCompliant(authCtx, sp, pluginSettings, serializedDataset)) {
            throw new Exception("Cannot create this dataset: it uses connections that are forbidden in this project");
        }
    }

    // We do not need to check onPreDataExport to ExportDestinationType.DATASET because it would require a non-compliant dataset to be created
    // first and will thus be blocked

    /**
     * Denies the save of a project or of a checked object (dataset, managed folder, streaming
     * endpoint, scenario, SQL notebook) that would break the connection rules. Other object types
     * are not checked.
     *
     * @param before unused
     * @param after  the object about to be saved
     * @throws Exception if the saved object is not compliant
     */
    @Override
    public void onPreObjectSave(AuthCtx authCtx, TaggableObject before, TaggableObject after) throws Exception {
        JsonObject pluginSettings = regularPluginsRegistryService.getSettings(PLUGIN_ID).config;
        if (isPluginDisabled(pluginSettings)) {
            return;
        }

        logger.info("onPreObjectSave: " + after.getTaggableType() + ": " + after.getFullId());

        if (after instanceof SerializedProject) {
            SerializedProject afterProject = (SerializedProject) after;

            /* Project was saved
             *   - Check each of the checked objects in the project
             *   - Check each item that this project shares to other projects
             */
            checkComplianceOfEachProjectObject(authCtx, afterProject, pluginSettings);
            checkComplianceOfEachExposedObject(authCtx, afterProject, pluginSettings);

        } else if (after instanceof SerializedDataset) {
            SerializedProject sp = projectsService.getMandatoryUnsafe(after.getProjectKey());
            if (!isDatasetCompliant(authCtx, sp, pluginSettings, (SerializedDataset) after)) {
                throw saveDenied(after.getProjectKey(), "dataset " + after.getId());
            }
        } else if (after instanceof ManagedFolder) {
            SerializedProject sp = projectsService.getMandatoryUnsafe(after.getProjectKey());
            if (!isManagedFolderCompliant(authCtx, sp, pluginSettings, (ManagedFolder) after)) {
                throw saveDenied(after.getProjectKey(), "managed folder " + after.getId());
            }
        } else if (after instanceof StreamingEndpoint) {
            SerializedProject sp = projectsService.getMandatoryUnsafe(after.getProjectKey());
            if (!isStreamingEndpointCompliant(authCtx, sp, pluginSettings, (StreamingEndpoint) after)) {
                throw saveDenied(after.getProjectKey(), "streaming endpoint " + after.getId());
            }
        } else if (after instanceof Scenario) {
            SerializedProject sp = projectsService.getMandatoryUnsafe(after.getProjectKey());
            if (!isScenarioCompliant(authCtx, sp, pluginSettings, (Scenario) after)) {
                throw saveDenied(after.getProjectKey(), "scenario " + after.getId());
            }
        } else if (after instanceof SQLNotebook) {
            SerializedProject sp = projectsService.getMandatoryUnsafe(after.getProjectKey());
            if (!isSQLNotebookCompliant(authCtx, sp, pluginSettings, (SQLNotebook) after)) {
                throw saveDenied(after.getProjectKey(), "SQL Notebook " + after.getId());
            }
        }
    }

    /**
     * Denies the direct use of a SQL connection from a project in which it is not allowed.
     *
     * @param user              authentication context of the caller
     * @param contextProjectKey key of the project from which the connection is used; when blank,
     *                          nothing can be checked and the usage is allowed
     * @param connectionName    name of the connection being used
     * @throws Exception if the connection is not allowed in the context project
     */
    @Override
    public void onPreSQLConnectionDirectUse(AuthCtx user, String contextProjectKey, String connectionName) throws Exception {
        JsonObject pluginSettings = regularPluginsRegistryService.getSettings(PLUGIN_ID).config;
        if (isPluginDisabled(pluginSettings)) {
            return;
        }

        if (StringUtils.isBlank(contextProjectKey)) {
            /* If we don't have a context project key, we can't check anything. Note that using client.sql_query(connection=) will
             * fall in that case and is therefore not "protectable" */
            return;
        }
        SerializedProject sp = projectsService.getMandatoryUnsafe(contextProjectKey);
        if (!isConnectionCompliant(user, sp, pluginSettings, connectionName)) {
            throw new Exception("SQL usage denied, connection " + connectionName + " is not allowed in this project according to strict connection rules");
        }
    }
}
