/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.server.api.ml;

import com.dataiku.dip.analysis.coreservices.AnalysisCRUDService;
import com.dataiku.dip.analysis.coreservices.AnalysisDataService;
import com.dataiku.dip.analysis.coreservices.ClusteringService;
import com.dataiku.dip.analysis.coreservices.PredictionService;
import com.dataiku.dip.analysis.ml.MLSparkParams;
import com.dataiku.dip.analysis.ml.clustering.guess.ClusteringGuessPolicy;
import com.dataiku.dip.analysis.ml.prediction.guess.PredictionGuessPolicy;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.cluster.ClusterSelector;
import com.dataiku.dip.cluster.SparkSettings;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.datalayer.utils.RecipeCreationUtils;
import com.dataiku.dip.exceptions.APIIllegalArgumentException;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.IPermissionsService;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.security.auth.MetaAuthService;
import com.dataiku.dip.server.api.PublicAPICodes;
import com.dataiku.dip.server.api.PublicAPIControllerBase;
import com.dataiku.dip.server.api.ml.PublicAPIMLLabController;
import com.dataiku.dip.server.controllers.AuditInline;
import com.dataiku.dip.server.controllers.AuditedCall;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.shaker.server.MemScriptRunner;
import com.dataiku.dip.transactions.ifaces.RWTransaction;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.Id;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ErrorContext;
import com.google.common.collect.Lists;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.util.ArrayList;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;

/**
 * Public API controller for the visual analyses ("lab") of a project, served under
 * {@code /publicapi/projects/{projectKey}/lab}.
 *
 * <p>Every endpoint resolves the caller from the request ticket/API key and checks a
 * project-level privilege ({@code READ_CONF} for reads, {@code WRITE_CONF} for mutations)
 * before reading or modifying analyses.
 */
@Controller
@RequestMapping(value={"/publicapi/projects/{projectKey}/lab"})
public class PublicAPILabController
extends PublicAPIControllerBase {
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private MetaAuthService authService;
    @Autowired
    private IPermissionsService permissionsService;
    @Autowired
    private PredictionService predictionService;
    @Autowired
    private ClusteringService clusteringService;
    @Autowired
    private AuditTrailService auditTrailService;
    @Autowired
    private AnalysisCRUDService analysisCRUDService;
    @Autowired
    private AnalysisDataService dataService;
    static final DKULogger logger = DKULogger.getLogger("dku.api.analysis.controller");

    /**
     * Lists the visual analyses of a project.
     *
     * <p>Requires {@code READ_CONF} on the project. Writes a JSON array of {@link AnalysisRef}
     * (analysis id, analysis name and input dataset smart name).
     */
    @AuditedCall(value={"msgType", "analysis-list", "projectKey", "${projectKey}"})
    @RequestMapping(value={"/"}, method={RequestMethod.GET})
    public void list(HttpServletRequest req, HttpServletResponse resp, @PathVariable String projectKey) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getTicketOrKey(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.READ_CONF});
            ArrayList<AnalysisRef> ret = new ArrayList<>();
            for (AnalysisCoreParams.AnalysisListItem ali : this.analysisCRUDService.listHeadsUnsafe(projectKey, null, true)) {
                AnalysisRef ref = new AnalysisRef();
                ref.analysisId = ali.id;
                ref.analysisName = ali.name;
                ref.inputDataset = ali.inputDatasetSmartName;
                ret.add(ref);
            }
            PublicAPILabController.writeJSON((HttpServletResponse)resp, (Object)ret);
        }
    }

    /**
     * Creates a new visual analysis on a dataset.
     *
     * <p>Requires {@code WRITE_CONF} on the project. The request body is an
     * {@link AnalysisCreationInfo}. Its {@code inputDataset} is mandatory. When
     * {@code analysisName} is blank, a default name derived from the dataset is used.
     * Emits an {@code analysis-create} audit event and writes the new analysis id as JSON.
     */
    @AuditInline
    @RequestMapping(value={"/"}, method={RequestMethod.POST})
    public void create(HttpServletRequest req, HttpServletResponse resp, @PathVariable String projectKey) throws Exception {
        String analysisId;
        AuthCtx authCtx;
        try (Transaction t = this.transactionService.beginRead()) {
            authCtx = this.authService.getTicketOrKey(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.WRITE_CONF});
        }
        AnalysisCreationInfo creationInfo = (AnalysisCreationInfo)this.getRequestBodyAs(req, AnalysisCreationInfo.class);
        this.checkNotBlank(creationInfo.inputDataset, "Dataset not specified", new Object[0]);
        if (StringUtils.isBlank((String)creationInfo.analysisName)) {
            creationInfo.analysisName = "API-created visual analysis on " + creationInfo.inputDataset;
        }
        try (RWTransaction t = this.transactionService.beginWriteAsLoggedInUser(authCtx)) {
            analysisId = this.analysisCRUDService.create(projectKey, creationInfo.inputDataset, creationInfo.analysisName);
            t.commit("Created analysis for API-created model on " + creationInfo.inputDataset);
        }
        this.auditTrailService.generic("analysis-create").with("projectKey", projectKey).with("datasetRef", creationInfo.inputDataset).with("analysisId", analysisId).emit();
        PublicAPILabController.writeJSON((HttpServletResponse)resp, (Object)new Id(analysisId));
    }

    /**
     * Returns the core settings of one analysis as JSON.
     *
     * <p>Requires {@code READ_CONF} on the project.
     */
    @AuditedCall(value={"msgType", "analysis-get", "projectKey", "${projectKey}", "analysisId", "${analysisId}"})
    @RequestMapping(value={"/{analysisId}/"}, method={RequestMethod.GET})
    public void get(HttpServletRequest req, HttpServletResponse resp, @PathVariable String projectKey, @PathVariable String analysisId) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getTicketOrKey(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.READ_CONF});
            AnalysisCoreParams analysis = this.analysisCRUDService.getCoreMandatory(projectKey, analysisId);
            PublicAPILabController.writeJSON((HttpServletResponse)resp, (Object)analysis);
        }
    }

    /**
     * Saves the core settings of an analysis from the request body.
     *
     * <p>Requires {@code WRITE_CONF} on the project. The body's {@code projectKey},
     * {@code name} and {@code id} must be non-blank. Its {@code projectKey} and {@code id}
     * must also match the URL path variables. Writes the commit message back as the response.
     */
    @AuditedCall(value={"msgType", "analysis-save", "projectKey", "${projectKey}", "analysisId", "${analysisId}"})
    @RequestMapping(value={"/{analysisId}/"}, method={RequestMethod.PUT})
    public void save(HttpServletRequest req, HttpServletResponse resp, @PathVariable String projectKey, @PathVariable String analysisId) throws Exception {
        // Authenticate and authorize before parsing/validating the untrusted body,
        // consistent with the other mutating endpoints of this controller.
        AuthCtx authCtx;
        try (Transaction t = this.transactionService.beginRead()) {
            authCtx = this.authService.getTicketOrKey(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.WRITE_CONF});
        }
        AnalysisCoreParams acp = (AnalysisCoreParams)this.getRequestBodyAs(req, AnalysisCoreParams.class);
        this.require(StringUtils.isNotBlank((String)acp.projectKey), "Required field 'projectKey' is missing.");
        this.require(StringUtils.isNotBlank((String)acp.name), "Required field 'name' is missing.");
        this.require(StringUtils.isNotBlank((String)acp.id), "Required field 'id' is missing.");
        // Prevent the body from redirecting the write to another project/analysis than the URL.
        this.require(acp.projectKey.equals(projectKey), "Analysis projectKey does not match the requested URL");
        this.require(acp.id.equals(analysisId), "Analysis id does not match the requested URL");
        String commitMessage = "Saved analysis settings for " + projectKey + " " + analysisId;
        try (RWTransaction t = this.transactionService.beginWriteAsLoggedInUser(authCtx)) {
            this.analysisCRUDService.saveCore(acp, false);
            t.commit(commitMessage);
        }
        this.writeMessage(resp, commitMessage, new Object[0]);
    }

    /**
     * Deletes an analysis.
     *
     * <p>Requires {@code WRITE_CONF} on the project. The analysis is looked up in a read
     * transaction first. Presumably {@code getCoreMandatoryUnsafe} throws when the analysis
     * does not exist; TODO confirm. Writes the commit message back as the response.
     */
    @AuditedCall(value={"msgType", "analysis-delete", "projectKey", "${projectKey}", "analysisId", "${analysisId}"})
    @RequestMapping(value={"/{analysisId}/"}, method={RequestMethod.DELETE})
    public void delete(HttpServletRequest req, HttpServletResponse resp, @PathVariable String projectKey, @PathVariable String analysisId) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getTicketOrKey(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.WRITE_CONF});
            this.analysisCRUDService.getCoreMandatoryUnsafe(projectKey, analysisId);
        }
        String commitMessage = "Deleted analysis " + analysisId + " in " + projectKey;
        // NOTE(review): unlike create/save, the write transaction is opened from the request
        // (beginWriteForAPI) rather than beginWriteAsLoggedInUser(authCtx). Confirm this is intended.
        try (RWTransaction t = this.transactionService.beginWriteForAPI(req)) {
            this.analysisCRUDService.delete(projectKey, analysisId);
            t.commit(commitMessage);
        }
        this.writeMessage(resp, commitMessage, new Object[0]);
    }

    /**
     * Lists the ML tasks of an analysis.
     *
     * <p>Requires {@code READ_CONF} on the project. Writes a
     * {@code PublicAPIMLLabController.LabRefList} holding, for each task, its analysis id,
     * task id, name, task type and the analysis input dataset.
     */
    @AuditedCall(value={"msgType", "analysis-mltasks-list", "projectKey", "${projectKey}", "analysisId", "${analysisId}"})
    @RequestMapping(value={"/{analysisId}/models/"}, method={RequestMethod.GET})
    public void listMLLabTasks(HttpServletRequest req, HttpServletResponse resp, @PathVariable String projectKey, @PathVariable String analysisId) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getTicketOrKey(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.READ_CONF});
            PublicAPIMLLabController.LabRefList ret = new PublicAPIMLLabController.LabRefList();
            AnalysisCoreParams.AnalysisListItem ali = this.analysisCRUDService.getHeadMandatoryUnsafe(projectKey, analysisId, true);
            for (AnalysisCRUDService.MLTaskHead mth : ali.mlTasks) {
                PublicAPIMLLabController.NamedLabRef nlr = new PublicAPIMLLabController.NamedLabRef();
                nlr.analysisId = ali.id;
                nlr.mlTaskId = mth.mlTaskId;
                nlr.mlTaskName = mth.name;
                nlr.taskType = mth.taskType.toString();
                nlr.inputDataset = ali.inputDatasetSmartName;
                ret.mlTasks.add(nlr);
            }
            PublicAPILabController.writeJSON((HttpServletResponse)resp, (Object)ret);
        }
    }

    /**
     * Creates an ML task inside an existing analysis.
     *
     * <p>Requires {@code WRITE_CONF} on the project. The body is a
     * {@code PublicAPIMLLabController.AnalysisAndMLTaskCreationInfo}. When {@code taskType}
     * is {@code "PREDICTION"}, a prediction task is created via
     * {@code PredictionService.createAndGuess_NT}. Any other value creates a clustering task
     * via {@code ClusteringService.createAndGuess_NT}. The {@code guessPolicy} is parsed as
     * the matching guess-policy enum. For Spark-based backends, a blank {@code sparkConfig}
     * is replaced by the project's default Spark config. Emits an audit event and writes the
     * new task reference as JSON.
     *
     * @throws APIIllegalArgumentException if the Vertica backend is requested, or if a
     *     prediction target variable is not a column of the analysis data
     */
    @AuditInline
    @RequestMapping(value={"/{analysisId}/models/"}, method={RequestMethod.POST})
    public void createMLTask(HttpServletRequest req, HttpServletResponse resp, @PathVariable String projectKey, @PathVariable String analysisId) throws Exception {
        MLSparkParams sparkParams;
        AnalysisCoreParams acp;
        AuthCtx authCtx;
        try (Transaction t = this.transactionService.beginRead()) {
            authCtx = this.authService.getTicketOrKey(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.WRITE_CONF});
        }
        PublicAPIMLLabController.AnalysisAndMLTaskCreationInfo creationInfo = (PublicAPIMLLabController.AnalysisAndMLTaskCreationInfo)this.getRequestBodyAs(req, PublicAPIMLLabController.AnalysisAndMLTaskCreationInfo.class);
        PublicAPIMLLabController.getAndCheckMLTaskCreation(creationInfo);
        try (Transaction t = this.transactionService.beginRead()) {
            acp = this.analysisCRUDService.getCoreMandatory(projectKey, analysisId);
            if (creationInfo.backendType == null) {
                throw ErrorContext.iae("Invalid ML Backend type");
            }
            if (creationInfo.backendType.isSparkBased()) {
                SparkSettings sparkSettings = new ClusterSelector().selectForProject(authCtx, projectKey).getSparkSettings();
                if (StringUtils.isBlank((String)creationInfo.sparkConfig)) {
                    creationInfo.sparkConfig = sparkSettings.getDefault().name;
                }
                // Result intentionally unused: presumably validates that the named Spark config
                // exists (TODO confirm getByName throws otherwise).
                sparkSettings.getByName(creationInfo.sparkConfig);
                sparkParams = RecipeCreationUtils.setupMLSparkParams((AuthCtx)authCtx, (String)projectKey, (String)creationInfo.sparkConfig, (Boolean)creationInfo.useGlobalMetastore);
            } else {
                sparkParams = null;
            }
            if (creationInfo.backendType == MLTask.BackendType.KERAS) {
                authCtx.failIfNoSafeCode("use a custom Keras architecture");
            }
            if (creationInfo.backendType == MLTask.BackendType.VERTICA) {
                throw new APIIllegalArgumentException((InfoMessage.MessageCode)PublicAPICodes.ERR_PUBLICAPI_INVALID_PARAMETER, "Cannot create new analysis and ML task: Vertica ML backend is no longer supported");
            }
        }
        String mlTaskId;
        if ("PREDICTION".equals(creationInfo.taskType)) {
            // Reject a target that is not a column of the analysis data before creating anything.
            MemScriptRunner.TableWithReport dataTable = this.dataService.getCachedUnfiltered_NOTRANSACTION(acp, authCtx);
            if (!dataTable.table.columns.containsKey(creationInfo.targetVariable)) {
                throw new APIIllegalArgumentException((InfoMessage.MessageCode)PublicAPICodes.ERR_PUBLICAPI_INVALID_PARAMETER, "Invalid target variable '" + creationInfo.targetVariable + "', not present in analysis");
            }
            PredictionGuessPolicy predictionPolicy = PredictionGuessPolicy.valueOf((String)creationInfo.guessPolicy);
            mlTaskId = this.predictionService.createAndGuess_NT(authCtx, acp, creationInfo.targetVariable, null, creationInfo.backendType, sparkParams, predictionPolicy, PublicAPIMLLabController.getAndCheckPredictionType(creationInfo.predictionType), creationInfo.timeVariable, creationInfo.timeseriesIdentifiers, creationInfo.treatmentVariable);
            this.auditTrailService.generic("analysis-prediction-task-create").with("projectKey", projectKey).with("analysisId", analysisId).with("inputDatasetRef", creationInfo.inputDataset).with("mlTaskId", mlTaskId).with("targetVariable", creationInfo.targetVariable).emit();
        } else {
            ClusteringGuessPolicy clusteringPolicy = ClusteringGuessPolicy.valueOf((String)creationInfo.guessPolicy);
            mlTaskId = this.clusteringService.createAndGuess_NT(authCtx, acp, creationInfo.backendType, sparkParams, clusteringPolicy);
            this.auditTrailService.generic("analysis-clustering-task-create").with("projectKey", projectKey).with("analysisId", analysisId).with("inputDatasetRef", creationInfo.inputDataset).with("mlTaskId", mlTaskId).emit();
        }
        PublicAPILabController.writeJSON((HttpServletResponse)resp, (Object)new PublicAPIMLLabController.LabRef(analysisId, mlTaskId));
    }

    /** JSON item returned by {@link #list}: one visual analysis of the project. */
    static class AnalysisRef {
        String analysisId;
        String analysisName;
        /** Smart name of the analysis input dataset. */
        String inputDataset;

        AnalysisRef() {
        }
    }

    /** Request body of {@link #create}. */
    static class AnalysisCreationInfo {
        /** Optional; a default name is generated when blank. */
        public String analysisName;
        /** Mandatory: the dataset the analysis is built on. */
        public String inputDataset;

        AnalysisCreationInfo() {
        }
    }
}

