/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.server.controllers.analysis;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.analysis.coreservices.AnalysisCRUDService;
import com.dataiku.dip.analysis.coreservices.MLBaseService;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLBackendsService;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.security.auth.UIAuthService;
import com.dataiku.dip.server.controllers.AuditInline;
import com.dataiku.dip.server.controllers.AuditNotNeeded;
import com.dataiku.dip.server.controllers.AuditedCall;
import com.dataiku.dip.server.controllers.DIPInternalControllerBase;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.shaker.model.SerializedShakerScript;
import com.dataiku.dip.transactions.ifaces.RWTransaction;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.apache.commons.io.IOUtils;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;

/**
 * Internal UI API endpoints shared by the ML tasks of visual analyses.
 *
 * <p>Covers:
 * <ul>
 *   <li>listing ML backends;</li>
 *   <li>reading and copying ML task settings;</li>
 *   <li>stopping and aborting training;</li>
 *   <li>managing queued training sessions;</li>
 *   <li>serving training logs and train-diagnosis archives.</li>
 * </ul>
 *
 * <p>Most endpoints check permissions inside a short-lived transaction. Several then call
 * {@code MLBaseService} "_NT" methods, or stream files to the response, only after that
 * transaction has been closed.
 */
@Controller
public class AnalysisMLCommonController
extends DIPInternalControllerBase {
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private ProjectsService projectsService;
    @Autowired
    private AnalysisCRUDService analysisCRUDService;
    @Autowired
    private MLBackendsService mlBackendsService;
    @Autowired
    private MLBaseService mlBaseService;
    @Autowired
    private AuditTrailService auditTrailService;
    @Autowired
    private UIAuthService authService;

    /**
     * Writes, as JSON, the ML backends available for a task of type {@code taskType} trained on
     * {@code datasetSmartName} in project {@code projectKey}.
     *
     * <p>NOTE(review): no project-level permission is checked in this method; presumably
     * {@code MLBackendsService.listMLBackends} enforces access using the passed auth context —
     * confirm.
     */
    @AuditNotNeeded
    @RequestMapping(value={"/api/analysis/mlcommon/list-backends"})
    public void listMLBackends(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String datasetSmartName, @RequestParam String taskType) throws Exception {
        List<MLBackendsService.MLBackendDesc> backends;
        try (Transaction t = this.transactionService.beginRead();){
            AuthCtx authCtx = this.authService.getMandatoryUser(req);
            backends = this.mlBackendsService.listMLBackends(authCtx, projectKey, datasetSmartName, taskType);
        }
        // Serialize only once the read transaction has been released.
        AnalysisMLCommonController.writeJSON(resp, backends);
    }

    /**
     * Forgets the feature selection of the given ML task. Requires WRITE_CONF on the project.
     *
     * <p>NOTE(review): this presumably mutates ML task state, yet it is {@code @AuditNotNeeded}
     * and runs inside a read transaction. Confirm that
     * {@code MLBaseService.forgetFeatureSelection} needs neither a write transaction nor an
     * audit entry.
     */
    @AuditNotNeeded
    @RequestMapping(value={"/api/analysis/mlcommon/forget-feature-selection"})
    public void forgetFeatureSelection(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String analysisId, @RequestParam String mlTaskId) throws Exception {
        try (Transaction t = this.transactionService.beginRead();){
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            MLTaskLoc loc = new MLTaskLoc(projectKey, analysisId, mlTaskId);
            this.mlBaseService.forgetFeatureSelection(loc);
        }
    }

    /**
     * Writes, as JSON, the current settings of the given ML task. Requires READ_CONF on the
     * project.
     */
    @AuditedCall(value={"msgType", "analysis-ml-get-settings", "projectKey", "${projectKey}", "analysisId", "${analysisId}"})
    @RequestMapping(value={"/api/analysis/mlcommon/get-current-settings"})
    public void getMLCurrentSettings(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String analysisId, @RequestParam String mlTaskId) throws Exception {
        try (Transaction t = this.transactionService.beginRead();){
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            MLTaskLoc loc = new MLTaskLoc(projectKey, analysisId, mlTaskId);
            AnalysisMLCommonController.writeJSON(resp, (Object)this.analysisCRUDService.getMLTask(loc));
        }
    }

    /**
     * Stops the grid search of every model listed in {@code fullModelIds}.
     *
     * <p>The ids may belong to several projects. The caller must hold WRITE_CONF on each of
     * them, and all checks run before anything is stopped.
     */
    @AuditedCall(value={"msgType", "analysis-ml-train-abort", "fullModelIds", "${fullModelIds}"})
    @RequestMapping(value={"/api/analysis/mlcommon/stop-grid-search"})
    public void stopGridSearch(HttpServletRequest req, HttpServletResponse resp, @RequestParam String[] fullModelIds) throws Exception {
        ArrayList<FullModelId> fullModelIdList = new ArrayList<FullModelId>();
        HashSet<String> projectKeys = new HashSet<String>();
        for (String fullModelId : fullModelIds) {
            FullModelId fmi = FullModelId.parse(fullModelId);
            fullModelIdList.add(fmi);
            projectKeys.add(fmi.getProjectKey());
        }
        try (Transaction t = this.transactionService.beginRead();){
            for (String projectKey : projectKeys) {
                this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            }
        }
        // Performed outside of any transaction (note the _NT suffix).
        this.mlBaseService.stopGridSearch_NT(fullModelIdList);
    }

    /**
     * Stops the grid search of one training session of the given ML task. Requires WRITE_CONF on
     * the project.
     *
     * <p>The audit entry records {@code projectKey}, like the other project-scoped endpoints of
     * this controller, so the entry can be tied back to its project.
     */
    @AuditedCall(value={"msgType", "analysis-ml-train-abort", "projectKey", "${projectKey}", "analysisId", "${analysisId}", "mlTaskId", "${mlTaskId}", "sessionId", "${sessionId}"})
    @RequestMapping(value={"/api/analysis/mlcommon/stop-grid-search-session"})
    public void stopGridSearchSession(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String analysisId, @RequestParam String mlTaskId, @RequestParam String sessionId) throws Exception {
        MLTaskLoc mlTaskLoc = new MLTaskLoc(projectKey, analysisId, mlTaskId);
        try (Transaction t = this.transactionService.beginRead();){
            this.projectsService.checkPerm(req, mlTaskLoc.analysisProjectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
        }
        this.mlBaseService.stopGridSearchSession_NT(mlTaskLoc, sessionId);
    }

    /**
     * Aborts the training of the given ML task. Requires WRITE_CONF on the project.
     *
     * @param pauseQueue forwarded to {@code abort_NT}. An absent or empty parameter maps to
     *                   {@code false} explicitly, rather than through Spring's implicit
     *                   primitive fallback.
     */
    @AuditedCall(value={"msgType", "analysis-ml-train-abort", "projectKey", "${projectKey}", "analysisId", "${analysisId}"})
    @RequestMapping(value={"/api/analysis/mlcommon/train-abort"})
    public void trainAbort(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String analysisId, @RequestParam String mlTaskId, @RequestParam(defaultValue="false") boolean pauseQueue) throws Exception {
        try (Transaction t = this.transactionService.beginRead();){
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
        }
        this.mlBaseService.abort_NT(new MLTaskLoc(projectKey, analysisId, mlTaskId), pauseQueue);
    }

    /**
     * Aborts the training of a subset of the models of the given ML task. Requires WRITE_CONF on
     * the project.
     *
     * <p>Every id in {@code fullModelIds} must belong to the task identified by
     * {@code projectKey}/{@code analysisId}/{@code mlTaskId}. Otherwise an {@link Exception} is
     * thrown before anything is audited or aborted.
     *
     * @param pauseQueue forwarded to {@code abort_NT}. An absent or empty parameter maps to
     *                   {@code false}.
     */
    @AuditInline
    @RequestMapping(value={"/api/analysis/mlcommon/train-abort-partial"})
    public void trainAbort(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String analysisId, @RequestParam String mlTaskId, @RequestParam String[] fullModelIds, @RequestParam(defaultValue="false") boolean pauseQueue) throws Exception {
        try (Transaction t = this.transactionService.beginRead();){
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
        }
        MLTaskLoc mlTaskLoc = new MLTaskLoc(projectKey, analysisId, mlTaskId);
        ArrayList<FullModelId> fullModelIdList = new ArrayList<FullModelId>();
        for (String fullModelId : fullModelIds) {
            FullModelId fmi = FullModelId.parse(fullModelId);
            // The permission check above only covers projectKey: reject models from other tasks.
            if (!fmi.getTaskLoc().equals(mlTaskLoc)) {
                throw new Exception("FullModelId doesnt match specified TaskLoc : " + fullModelId);
            }
            fullModelIdList.add(fmi);
        }
        // Audited inline (instead of via @AuditedCall) so the validated id list can be recorded.
        this.auditTrailService.generic("analysis-ml-train-abort-partial").with("projectKey", projectKey).with("analysisId", analysisId).with("mlTaskId", mlTaskId).with("fullModelIds", StringUtils.join(fullModelIdList, ",")).emit();
        this.mlBaseService.abort_NT(mlTaskLoc, fullModelIdList, pauseQueue);
    }

    /**
     * Deletes the given ML task in a UI write transaction and commits it. Requires WRITE_CONF on
     * the project.
     */
    @AuditedCall(value={"msgType", "analysis-mltask-delete", "projectKey", "${projectKey}", "analysisId", "${analysisId}", "mlTaskId", "${mlTaskId}"})
    @RequestMapping(value={"/api/analysis/mlcommon/delete-mltask"})
    public void deleteTask(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String analysisId, @RequestParam String mlTaskId) throws Exception {
        try (RWTransaction t = this.transactionService.beginWriteForUI(req);){
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            this.mlBaseService.deleteMLTask(new MLTaskLoc(projectKey, analysisId, mlTaskId));
            t.commit("Delete ML Task " + mlTaskId + " in " + projectKey + "." + analysisId);
        }
    }

    /**
     * Writes, as JSON, the training sessions queued for the given ML task. Requires READ_CONF on
     * the project.
     *
     * <p>The permission check runs in a read transaction. The queue itself is read after that
     * transaction is closed.
     */
    @AuditedCall(value={"msgType", "analysis-mltask-list-queued-sessions", "projectKey", "${projectKey}", "analysisId", "${analysisId}"})
    @RequestMapping(value={"/api/analysis/mlcommon/list-queued-sessions"})
    public void listQueuedSessions(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String analysisId, @RequestParam String mlTaskId) throws Exception {
        try (Transaction t = this.transactionService.beginRead();){
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.READ_CONF);
        }
        MLTaskLoc loc = new MLTaskLoc(projectKey, analysisId, mlTaskId);
        AnalysisMLCommonController.writeJSON(resp, this.mlBaseService.listQueuedSessions(loc));
    }

    /**
     * Pauses the session queue of the given ML task by calling {@code setQueueState(loc, false)}.
     * Requires WRITE_CONF on the project.
     */
    @AuditedCall(value={"msgType", "analysis-mltask-pause-queue-sessions", "projectKey", "${projectKey}", "analysisId", "${analysisId}"})
    @RequestMapping(value={"/api/analysis/mlcommon/pause-queue"})
    public void pauseQueue(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String analysisId, @RequestParam String mlTaskId) throws Exception {
        MLTaskLoc loc = new MLTaskLoc(projectKey, analysisId, mlTaskId);
        try (Transaction t = this.transactionService.beginRead();){
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
        }
        this.mlBaseService.setQueueState(loc, false);
    }

    /**
     * Deletes queued training sessions of the given ML task. Requires WRITE_CONF on the project.
     *
     * @param sessionIds a JSON array, parsed into a {@link List}. It presumably holds session id
     *                   strings — confirm against the UI caller.
     */
    @AuditedCall(value={"msgType", "analysis-mltask-delete-queued-sessions", "projectKey", "${projectKey}", "analysisId", "${analysisId}"})
    @RequestMapping(value={"/api/analysis/mlcommon/delete-queued-sessions"})
    public void deleteQueuedSessions(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String analysisId, @RequestParam String mlTaskId, @RequestParam String sessionIds) throws Exception {
        try (Transaction t = this.transactionService.beginRead();){
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
        }
        MLTaskLoc loc = new MLTaskLoc(projectKey, analysisId, mlTaskId);
        List sessionList = (List)JSON.parse(sessionIds, List.class);
        this.mlBaseService.deleteQueuedSessionsById(loc, sessionList);
    }

    /**
     * Streams the preprocessing {@code train.log} of a model as UTF-8 plain text. Requires
     * READ_CONF on the model's project.
     *
     * <p>Responses:
     * <ul>
     *   <li>403 when logs are hidden from non-admin users (see {@code checkCanAccessLogs});</li>
     *   <li>404 when the log file does not exist.</li>
     * </ul>
     *
     * <p>The user is resolved with the no-XSRF variant. NOTE(review): presumably so the log can be
     * opened as a plain link.
     */
    @AuditedCall(value={"msgType", "analysis-ml-get-log", "fullModelId", "${fullModelId}"})
    @RequestMapping(value={"/api/analysis/mlcommon/cat-activity-log"})
    public void catActivityLog(HttpServletRequest req, HttpServletResponse resp, @RequestParam String fullModelId) throws Exception {
        FullModelId parsedFullModelId = FullModelId.parse(fullModelId);
        try (Transaction t = this.transactionService.beginRead();){
            AuthCtx authCtx = this.authService.getMandatoryUserNoXSRF(req);
            this.projectsService.checkPerm(authCtx, parsedFullModelId.getProjectKey(), Privileges.ProjectLevelPrivilegeType.READ_CONF);
            if (!checkCanAccessLogs(authCtx, resp)) {
                // The 403 response has already been written.
                return;
            }
        }
        File ppLogFile = DKUFileUtils.getWithinFollowLink((File)parsedFullModelId.getPreprocessingFolder(), new String[]{"train.log"});
        if (ppLogFile.exists()) {
            resp.setContentType("text/plain; charset=UTF-8");
            try (FileInputStream fis = new FileInputStream(ppLogFile);){
                IOUtils.copy(fis, resp.getOutputStream());
            }
        } else {
            resp.setStatus(404);
            resp.getWriter().write("Analysis ML log file not found");
        }
    }

    /**
     * Streams a zip "train diagnosis" archive for an analysis model as an attachment.
     *
     * <p>Requires WRITE_CONF on the model's project. EXPORT_DATASETS_DATA is also required when
     * {@code includeTrainingData} is true.
     *
     * <p>Responses:
     * <ul>
     *   <li>404 when the model does not exist;</li>
     *   <li>400 when the model is not of type {@code ANALYSIS}.</li>
     * </ul>
     */
    @AuditedCall(value={"msgType", "analysis-ml-download-train-diagnosis", "fullModelId", "${fullModelId}", "includeTrainingData", "${includeTrainingData}"})
    @RequestMapping(value={"/api/analysis/mlcommon/download-train-diagnosis"})
    public void downloadTrainDiagnosis(HttpServletRequest req, HttpServletResponse resp, @RequestParam String fullModelId, @RequestParam(defaultValue="false") boolean includeTrainingData) throws Exception {
        AuthCtx authCtx;
        FullModelId fmi = FullModelId.parse(fullModelId);
        try (Transaction t = this.transactionService.beginRead();){
            authCtx = this.authService.getMandatoryUserNoXSRF(req);
            this.projectsService.checkPerm(authCtx, fmi.getProjectKey(), Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            if (includeTrainingData) {
                this.projectsService.checkPerm(authCtx, fmi.getProjectKey(), Privileges.ProjectLevelPrivilegeType.EXPORT_DATASETS_DATA);
            }
        }
        if (!fmi.exists()) {
            resp.setStatus(404);
            resp.getWriter().write("Model dir not found");
        } else if (fmi.getType() != FullModelId.Type.ANALYSIS) {
            resp.setStatus(400);
            resp.getWriter().write("Supports only ANALYSIS models");
        } else {
            String filename = "dss-train-diag-" + String.valueOf(fmi) + ".zip";
            resp.setContentType("application/zip");
            resp.setHeader("Content-Disposition", "attachment; filename=\"" + filename + "\"");
            resp.setStatus(200);
            this.mlBaseService.generateTrainDiagnosis(authCtx, fmi, includeTrainingData, (OutputStream)resp.getOutputStream());
        }
    }

    /**
     * Copies the features handling of a source ML task onto a destination ML task. It then saves
     * the destination and writes it back as JSON.
     *
     * <p>Requires READ_CONF on the source project and WRITE_CONF on the destination project.
     * Both tasks are loaded in a read transaction, the copy happens in memory, and the result is
     * saved in a separate write transaction opened as the logged-in user.
     */
    @AuditedCall(value={"msgType", "analysis-ml-copy-features", "projectKeyFrom", "${projectKeyFrom}", "analysisIdFrom", "${analysisIdFrom}", "mlTaskIdFrom", "${mlTaskIdFrom}", "projectKeyTo", "${projectKeyTo}", "analysisIdTo", "${analysisIdTo}", "mlTaskIdTo", "${mlTaskIdTo}"})
    @RequestMapping(value={"/api/analysis/mlcommon/copy-features-handling"})
    public void copyFeaturesHandling(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKeyFrom, @RequestParam String analysisIdFrom, @RequestParam String mlTaskIdFrom, @RequestParam String projectKeyTo, @RequestParam String analysisIdTo, @RequestParam String mlTaskIdTo) throws Exception {
        MLTask taskTo;
        MLTaskLoc locTo;
        MLTask taskFrom;
        AuthCtx u;
        try (Transaction t = this.transactionService.beginRead();){
            u = this.authService.getMandatoryUser(req);
            this.projectsService.checkPerm(req, projectKeyFrom, Privileges.ProjectLevelPrivilegeType.READ_CONF);
            MLTaskLoc locFrom = new MLTaskLoc(projectKeyFrom, analysisIdFrom, mlTaskIdFrom);
            taskFrom = this.analysisCRUDService.getMLTask(locFrom);
            this.projectsService.checkPerm(req, projectKeyTo, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            locTo = new MLTaskLoc(projectKeyTo, analysisIdTo, mlTaskIdTo);
            taskTo = this.analysisCRUDService.getMLTask(locTo);
        }
        this.mlBaseService.copyFeaturesHandling(taskFrom, taskTo);
        try (RWTransaction rw = this.transactionService.beginWriteAsLoggedInUser(u);){
            this.analysisCRUDService.saveMLTask(locTo, taskTo, true);
            rw.commit("Copied features handling from " + taskFrom.id + " to " + taskTo.id);
        }
        AnalysisMLCommonController.writeJSON(resp, (Object)taskTo);
    }

    /**
     * Reverts the analysis script to the one used by the given ML training session, and writes
     * the resulting script as JSON. Requires WRITE_CONF on the project.
     *
     * <p>NOTE(review): unlike {@code deleteTask}, this method never calls {@code t.commit(...)}.
     * Presumably {@code revertScriptToMLSession_T} commits through the transaction it receives —
     * confirm, otherwise the revert may be discarded.
     */
    @AuditedCall(value={"msgType", "analysis-ml-revert-script-to-session", "projectKey", "${projectKey}", "analysisId", "${analysisId}", "mlTaskId", "${mlTaskId}", "sessionId", "${sessionId}"})
    @RequestMapping(value={"/api/analysis/mlcommon/revert-script-to-session"})
    public void revertScriptToMLSession(HttpServletRequest req, HttpServletResponse resp, @RequestParam String projectKey, @RequestParam String analysisId, @RequestParam String mlTaskId, @RequestParam String sessionId) throws Exception {
        SerializedShakerScript script;
        try (RWTransaction t = this.transactionService.beginWriteForUI(req);){
            this.projectsService.checkPerm(req, projectKey, Privileges.ProjectLevelPrivilegeType.WRITE_CONF);
            script = this.analysisCRUDService.revertScriptToMLSession_T(projectKey, analysisId, mlTaskId, sessionId, t);
        }
        AnalysisMLCommonController.writeJSON(resp, (Object)script);
    }

    /**
     * Checks whether the caller may read training logs.
     *
     * <p>Access is denied when {@code ApplicationConfigurator.hideLogs()} is true and the user is
     * not an admin. In that case a 403 plain-text response is written and {@code false} is
     * returned.
     *
     * @param authCtx the authenticated caller
     * @param resp    the response to write the 403 to when access is denied
     * @return {@code true} if the caller may read logs; {@code false} if the 403 was written
     * @throws IOException if writing the 403 body fails
     */
    private static boolean checkCanAccessLogs(AuthCtx authCtx, HttpServletResponse resp) throws IOException {
        if (ApplicationConfigurator.hideLogs() && !authCtx.isAdmin()) {
            resp.setStatus(403);
            resp.setContentType("text/plain; charset=UTF-8");
            resp.getWriter().write("Logs visibility is restricted. Contact your administrator if you need to access these logs");
            return false;
        }
        return true;
    }
}

