/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.server.api;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.ModelLikeId;
import com.dataiku.dip.analysis.ml.prediction.ModelComparisonSampleDataService;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.mec.AbstractModelEvaluation;
import com.dataiku.dip.mec.FullModelEvaluationId;
import com.dataiku.dip.mec.ModelComparison;
import com.dataiku.dip.mec.ModelComparisonsCRUDService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.IPermissionsService;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.security.auth.MetaAuthService;
import com.dataiku.dip.server.api.PublicAPIControllerBase;
import com.dataiku.dip.server.controllers.AuditInline;
import com.dataiku.dip.server.controllers.AuditNotNeeded;
import com.dataiku.dip.server.controllers.AuditedCall;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.RWTransaction;
import com.dataiku.dip.transactions.ifaces.Transaction;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.ResponseBody;

/**
 * Public REST API for the model comparisons of a project, mounted under
 * {@code /publicapi/projects/{projectKey}/modelcomparisons}.
 *
 * <p>Exposes list / add / get / update / delete endpoints for {@link ModelComparison}s, plus two
 * endpoints that list and clear LLM comparison samples through {@link ModelComparisonSampleDataService}.
 */
@Controller
@RequestMapping(value={"/publicapi/projects/{projectKey}/modelcomparisons"})
public class PublicAPIModelComparisonsController
extends PublicAPIControllerBase {
    /** File read from a saved model / model evaluation to detect whether it is an ensemble. */
    private static final String PRE_TRAIN_MODELING_PARAMS_FILE = "rmodeling_params.json";

    @Autowired
    private TransactionService transactionService;
    @Autowired
    private MetaAuthService authService;
    @Autowired
    private IPermissionsService permissionsService;
    @Autowired
    private ModelComparisonsCRUDService modelComparisonsCRUDService;
    @Autowired
    private AuditTrailService auditTrailService;
    @Autowired
    private ProjectsService projectsService;
    @Autowired
    private ModelComparisonSampleDataService modelComparisonSampleDataService;

    /**
     * Lists the model comparisons of a project.
     *
     * @param projectKey project whose comparisons are listed; the caller needs READ_CONF on it
     * @return the comparisons returned by {@link ModelComparisonsCRUDService#list}
     */
    @AuditedCall(value={"msgType", "modelcomparisons-list", "projectKey", "${projectKey}"})
    @ResponseBody
    @RequestMapping(value={"/"}, method={RequestMethod.GET})
    public List<ModelComparison> list(HttpServletRequest req, @PathVariable String projectKey) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getTicketOrKey(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.READ_CONF});
            return this.modelComparisonsCRUDService.list(projectKey);
        }
    }

    /**
     * Creates a new model comparison in the project.
     *
     * <p>The request body supplies the descriptive fields and the list of compared items; the id of
     * the new comparison is generated by {@code new ModelComparison(projectKey)}, so any id sent by the
     * client is ignored. Every compared item is validated (parseable ref, existing, readable, same task
     * type, not partitioned, not an ensemble) in a read transaction before the comparison is saved in a
     * write transaction. Success and failure are both recorded in the audit trail under
     * {@code modelcomparisons-create}, with the id of the comparison actually created.
     *
     * @param projectKey     project to create the comparison in; the caller needs WRITE_CONF on it
     * @param modelComparison client-supplied definition; its {@code projectKey} must match the URL
     * @return the generated id and the commit message
     */
    @AuditInline
    @ResponseBody
    @RequestMapping(value={"/"}, method={RequestMethod.POST})
    public PublicAPIControllerBase.ResponseMessageWithId add(HttpServletRequest req, @PathVariable String projectKey, @RequestBody ModelComparison modelComparison) throws Exception {
        this.require(StringUtils.isNotBlank(modelComparison.displayName), "Required field 'displayName' is missing.");
        this.require(null != modelComparison.modelTaskType, "Required field 'modelTaskType' is missing or its value is not supported.");
        this.require(StringUtils.isNotBlank(modelComparison.projectKey), "Required field 'projectKey' is missing.");
        // projectKey is known to be non-blank here, so only the equality with the URL needs checking.
        this.require(StringUtils.equals(modelComparison.projectKey, projectKey), "Model Comparison projectKey field does not match request param projectKey.");
        this.require(null != modelComparison.comparedModels, "Required field 'comparedModels' is missing.");

        AuthCtx authCtx = this.authService.getTicketOrKey_NT(req);
        ModelComparison created = new ModelComparison(projectKey);
        this.copyMECFields(modelComparison, created);
        String commitMessage = "Created model comparison " + created.id + " (named " + created.displayName + ") in " + created.projectKey;

        // Phase 1: validation under a read transaction.
        try (Transaction t = this.transactionService.beginRead()) {
            this.validateComparedModels(authCtx, modelComparison, projectKey);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.WRITE_CONF});
            if (this.modelComparisonsCRUDService.getOrNull(created.projectKey, created.id) != null) {
                throw new IllegalArgumentException("The model comparator '" + created.id + "' already exists in project " + created.projectKey);
            }
        }
        catch (Exception e) {
            // Audit with the generated id: the id in the request body is not the one persisted.
            this.auditTrailService.failure("modelcomparisons-create", (Throwable)e).with("projectKey", projectKey).with("mcId", created.id).with("displayName", created.displayName).emit();
            throw e;
        }

        // Phase 2: persist under a write transaction.
        try (RWTransaction t = this.transactionService.beginWriteAsLoggedInUser(authCtx)) {
            this.modelComparisonsCRUDService.save(authCtx, created, true, false);
            t.commit(commitMessage);
            this.auditTrailService.generic("modelcomparisons-create").with("projectKey", projectKey).with("mcId", created.id).with("displayName", created.displayName).emit();
        }
        catch (Exception e) {
            this.auditTrailService.failure("modelcomparisons-create", (Throwable)e).with("projectKey", projectKey).with("mcId", created.id).with("displayName", created.displayName).emit();
            throw e;
        }
        return new PublicAPIControllerBase.ResponseMessageWithId(created.id, commitMessage);
    }

    /**
     * Validates every compared item of {@code modelComparison}: the ref must parse, its main folder must
     * exist, the caller must be able to read it, and it must pass the task-type, partitioning and
     * ensemble checks. Must be called inside a transaction.
     *
     * @throws IllegalArgumentException (or whatever the delegated checks throw) on the first invalid item
     */
    private void validateComparedModels(AuthCtx authCtx, ModelComparison modelComparison, String projectKey) throws Exception {
        for (ModelComparison.ComparedModel cm : modelComparison.comparedModels) {
            ModelLikeId mli = this.checkValidRefId(cm, modelComparison.projectKey);
            if (!mli.getMainFolder().exists()) {
                throw new IllegalArgumentException("Comparated item " + cm.refId + " not found.");
            }
            this.projectsService.failIfNoReadAccess(authCtx, mli.getUnderlyingStore(), projectKey);
            this.checkPredictionTypesConsistency(modelComparison, cm, mli);
            this.checkNoPartitionedModel(cm, mli);
            this.checkNoEnsembledModel(cm, mli);
        }
    }

    /**
     * Checks that the task type of a compared item equals the comparison's declared task type.
     *
     * @throws IllegalArgumentException if the item's prediction type cannot be read (cause preserved)
     */
    private void checkPredictionTypesConsistency(ModelComparison modelComparison, ModelComparison.ComparedModel cm, ModelLikeId mli) {
        try {
            ModelComparison.ModelTaskType modelTaskType = resolveModelTaskType(mli);
            this.require(modelTaskType == modelComparison.modelTaskType, "Model Task types of the model comparator (" + String.valueOf(modelComparison.modelTaskType) + ") and item " + cm.refId + " (" + String.valueOf(modelTaskType) + ") are not consistent.");
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Could not determine prediction type of item " + cm.refId, e);
        }
    }

    /**
     * Derives the task type of a compared item: model evaluations flagged as LLM or agent map to
     * {@code LLM} / {@code AGENT}; everything else is derived from the item's prediction type.
     */
    private static ModelComparison.ModelTaskType resolveModelTaskType(ModelLikeId mli) throws IOException {
        if (mli instanceof FullModelEvaluationId) {
            AbstractModelEvaluation me = ((FullModelEvaluationId)mli).getModelEvaluation();
            if (me.isLLM()) {
                return ModelComparison.ModelTaskType.LLM;
            }
            if (me.isAgent()) {
                return ModelComparison.ModelTaskType.AGENT;
            }
        }
        return ModelComparison.ModelTaskType.from((PredictionMLTask.PredictionType)mli.getPredictionType());
    }

    /**
     * Copies the user-editable fields of {@code source} into {@code target}.
     *
     * <p>The compared-model list of {@code target} is rebuilt in {@code source} order: entries whose
     * {@code refId} already existed in {@code target} are reused as-is, new refs get a fresh
     * {@link ModelComparison.ComparedModel} carrying only the {@code refId}.
     *
     * @throws IllegalArgumentException if {@code source} lists the same {@code refId} twice
     */
    private void copyMECFields(ModelComparison source, ModelComparison target) {
        target.displayName = source.displayName;
        target.checklists = source.checklists;
        target.creationTag = source.creationTag;
        target.versionTag = source.versionTag;
        target.tags = source.tags;
        target.description = source.description;
        target.shortDesc = source.shortDesc;
        target.customFields = source.customFields;
        target.modelTaskType = source.modelTaskType;
        Set<String> alreadySpecifiedRefs = new HashSet<>();
        List<ModelComparison.ComparedModel> previousModels = target.comparedModels;
        target.comparedModels = new ArrayList<>();
        for (ModelComparison.ComparedModel cm : source.comparedModels) {
            if (!alreadySpecifiedRefs.add(cm.refId)) {
                throw new IllegalArgumentException("The reference '" + cm.refId + "' is duplicated in model comparator configuration");
            }
            ModelComparison.ComparedModel newCm = findComparedModel(previousModels, cm.refId);
            if (null == newCm) {
                newCm = new ModelComparison.ComparedModel();
                newCm.refId = cm.refId;
            }
            target.comparedModels.add(newCm);
        }
    }

    /** Returns the first entry of {@code models} whose refId equals {@code refId}, or null (also when {@code models} is null). */
    private static ModelComparison.ComparedModel findComparedModel(List<ModelComparison.ComparedModel> models, String refId) {
        if (models == null) {
            return null;
        }
        for (ModelComparison.ComparedModel candidate : models) {
            if (StringUtils.equals(candidate.refId, refId)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Returns one model comparison; the caller needs READ_CONF on the project.
     */
    @AuditedCall(value={"msgType", "modelcomparisons-get", "projectKey", "${projectKey}", "mcId", "${mcId}"})
    @ResponseBody
    @RequestMapping(value={"/{mcId}"}, method={RequestMethod.GET})
    public ModelComparison get(HttpServletRequest req, @PathVariable String projectKey, @PathVariable String mcId) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getTicketOrKey(req);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.READ_CONF});
            return this.modelComparisonsCRUDService.getMandatory(projectKey, mcId);
        }
    }

    /**
     * Updates an existing model comparison from the request body.
     *
     * <p>The body's {@code projectKey} and {@code id} must match the URL. Compared items are validated
     * exactly as in {@link #add}; the editable fields are then copied onto the stored comparison, whose
     * row-by-row display settings are reset before it is saved.
     *
     * @return the commit message
     */
    @AuditedCall(value={"msgType", "modelcomparisons-update", "projectKey", "${projectKey}", "mcId", "${mcId}"})
    @ResponseBody
    @RequestMapping(value={"/{mcId}"}, method={RequestMethod.PUT})
    public String update(HttpServletRequest req, @PathVariable String projectKey, @PathVariable String mcId) throws Exception {
        ModelComparison modelComparison = (ModelComparison)this.getRequestBodyAs(req, ModelComparison.class);
        this.require(StringUtils.isNotBlank(modelComparison.projectKey), "Required field 'projectKey' is missing.");
        this.require(StringUtils.isNotBlank(modelComparison.displayName), "Required field 'displayName' is missing.");
        this.require(StringUtils.isNotBlank(modelComparison.id), "Required field 'id' is missing.");
        this.require(modelComparison.projectKey.equals(projectKey), "Model comparator projectKey does not match the requested URL.");
        this.require(modelComparison.id.equals(mcId), "Model comparator id does not match the requested URL.");
        // Same requirements as add(): the task type is copied onto the stored comparison.
        this.require(null != modelComparison.modelTaskType, "Required field 'modelTaskType' is missing or its value is not supported.");
        this.require(null != modelComparison.comparedModels, "Required field 'comparedModels' is missing.");
        AuthCtx authCtx = this.authService.getTicketOrKey_NT(req);

        ModelComparison existingMEC;
        try (Transaction t = this.transactionService.beginRead()) {
            this.validateComparedModels(authCtx, modelComparison, projectKey);
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.WRITE_CONF});
            existingMEC = this.modelComparisonsCRUDService.getMandatory(modelComparison.projectKey, modelComparison.id);
        }
        this.copyMECFields(modelComparison, existingMEC);
        String commitMessage = "Updated model comparison " + modelComparison.getFullId();
        try (RWTransaction t = this.transactionService.beginWriteAsLoggedInUser(authCtx)) {
            // The compared items may have changed, so drop the row-by-row shaker script and extra columns.
            existingMEC.displayParams.rowByRowShakerScript = null;
            existingMEC.displayParams.rowByRowAdditionalColumns = null;
            this.modelComparisonsCRUDService.save(authCtx, existingMEC, false, false);
            t.commit(commitMessage);
        }
        return commitMessage;
    }

    /**
     * Parses the {@code refId} of a compared item relative to {@code projectKey}.
     *
     * @throws IllegalArgumentException wrapping the parse error if the ref is invalid
     */
    private ModelLikeId checkValidRefId(ModelComparison.ComparedModel cm, String projectKey) {
        try {
            return ModelLikeId.parse(cm.refId, projectKey);
        }
        catch (IllegalArgumentException iae) {
            throw new IllegalArgumentException("A compared model has an invalid refId", iae);
        }
    }

    /**
     * Rejects partitioned base models and item types other than saved models and model evaluations.
     */
    private void checkNoPartitionedModel(ModelComparison.ComparedModel cm, ModelLikeId mli) {
        switch (mli.getModelLikeType()) {
            case DOCTOR_MODEL: {
                FullModelId fmi = (FullModelId)mli;
                if (fmi.isPartitionedBaseModel()) {
                    throw new IllegalArgumentException("Model " + cm.refId + " is a partitioned model. Not supported by model comparator.");
                }
                break;
            }
            case MODEL_EVALUATION: {
                break;
            }
            default: {
                throw new IllegalArgumentException("Model type  " + String.valueOf(mli.getModelLikeType()) + " of model " + cm.refId + " is not supported by model comparator.");
            }
        }
    }

    /**
     * Rejects ensembled models / model evaluations, detected through non-null {@code ensemble_params}
     * in their pre-train modeling params file (items without that file are accepted).
     *
     * @throws IllegalArgumentException if the item is an ensemble, its type is unsupported, or the file
     *         cannot be read (cause preserved)
     */
    private void checkNoEnsembledModel(ModelComparison.ComparedModel cm, ModelLikeId mli) {
        try {
            switch (mli.getModelLikeType()) {
                case DOCTOR_MODEL: {
                    FullModelId fmi = (FullModelId)mli;
                    if (fmi.checkModelFileExists(PRE_TRAIN_MODELING_PARAMS_FILE)) {
                        PreTrainPredictionModelingParams modeling = (PreTrainPredictionModelingParams)fmi.parseModelFile(PRE_TRAIN_MODELING_PARAMS_FILE, PreTrainPredictionModelingParams.class);
                        if (null != modeling.ensemble_params) {
                            throw new IllegalArgumentException("Model " + cm.refId + " is an ensembled model. Not supported by model comparator.");
                        }
                    }
                    break;
                }
                case MODEL_EVALUATION: {
                    FullModelEvaluationId fme = (FullModelEvaluationId)mli;
                    if (fme.checkEvaluationFileExists(PRE_TRAIN_MODELING_PARAMS_FILE)) {
                        PreTrainPredictionModelingParams modeling = (PreTrainPredictionModelingParams)fme.parseEvaluationFile(PRE_TRAIN_MODELING_PARAMS_FILE, PreTrainPredictionModelingParams.class);
                        if (null != modeling.ensemble_params) {
                            throw new IllegalArgumentException("Model evaluation " + cm.refId + " is an ensembled model evaluation. Not supported by model comparator.");
                        }
                    }
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Model type  " + String.valueOf(mli.getModelLikeType()) + " of model " + cm.refId + " is not supported by model comparator.");
                }
            }
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Error checking if " + cm.refId + " is an ensembled model", e);
        }
    }

    /**
     * Deletes a model comparison; the caller needs WRITE_CONF on the project.
     *
     * @return the commit message
     */
    @AuditedCall(value={"msgType", "modelcomparisons-delete", "projectKey", "${projectKey}", "mcId", "${mcId}"})
    @ResponseBody
    @RequestMapping(value={"/{mcId}"}, method={RequestMethod.DELETE})
    public String delete(HttpServletRequest req, @PathVariable String projectKey, @PathVariable String mcId) throws Exception {
        String commitMessage = "Deleted model comparator " + projectKey + "." + mcId;
        AuthCtx authCtx = this.authService.getTicketOrKey_NT(req);
        try (RWTransaction t = this.transactionService.beginWriteAsLoggedInUser(authCtx)) {
            this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.WRITE_CONF});
            this.modelComparisonsCRUDService.delete(authCtx, projectKey, mcId);
            t.commit(commitMessage);
        }
        return commitMessage;
    }

    /**
     * Lists LLM comparison samples for the project, or for all projects when {@code allProjects} is set
     * (admin only; otherwise READ_CONF on the project is required).
     *
     * <p>Permissions are checked inside a read transaction; the listing itself runs after that
     * transaction is closed.
     */
    @AuditNotNeeded
    @RequestMapping(value={"/llm/list-samples"}, method={RequestMethod.GET})
    @ResponseBody
    public List<ModelComparisonSampleDataService.SampleBasicInfo> listSamples(HttpServletRequest req, @PathVariable String projectKey, @RequestParam(defaultValue="false", required=false) boolean allProjects) throws Exception {
        try (Transaction t = this.transactionService.beginRead()) {
            AuthCtx authCtx = this.authService.getTicketOrKey(req);
            if (allProjects) {
                this.permissionsService.checkAdmin(authCtx, "You need to be admin to list model comparison samples from all projects.");
            } else {
                this.projectsService.checkPerm(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.READ_CONF});
            }
        }
        return this.modelComparisonSampleDataService.listSamples(projectKey, allProjects);
    }

    /**
     * Clears cached LLM comparison samples older than {@code minDays} for the project, or for all
     * projects when {@code allProjects} is set (admin only; otherwise WRITE_CONF on the project).
     *
     * <p>Permissions are checked inside a write transaction that is never committed; the clearing
     * itself runs after that transaction is closed.
     *
     * @return the value returned by {@code clearSamples} (presumably the number of samples removed — confirm)
     */
    @AuditNotNeeded
    @RequestMapping(value={"/llm/clean-sample-cache"}, method={RequestMethod.DELETE})
    @ResponseBody
    public int deleteOldSamples(HttpServletRequest req, @PathVariable String projectKey, @RequestParam(defaultValue="false", required=false) boolean allProjects, @RequestParam long minDays) throws Exception {
        try (RWTransaction t = this.transactionService.beginWriteForTicket(req)) {
            AuthCtx authCtx = this.authService.getTicketOrKey(req);
            if (allProjects) {
                this.permissionsService.checkAdmin(authCtx, "You need to be admin to clear model comparison samples from all projects.");
            } else {
                this.projectsService.checkPerm(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.WRITE_CONF});
            }
        }
        return this.modelComparisonSampleDataService.clearSamples(projectKey, minDays, allProjects);
    }
}

