/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.server.api;

import com.dataiku.common.server.DKUControllerBase;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.exceptions.ForbiddenObjectException;
import com.dataiku.dip.kernel.KernelPool;
import com.dataiku.dip.llm.langchain.PythonLLMServerKernelPool;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.security.IPermissionsService;
import com.dataiku.dip.security.Privileges;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.security.auth.MetaAuthService;
import com.dataiku.dip.server.api.PublicAPIControllerBase;
import com.dataiku.dip.server.controllers.AuditInline;
import com.dataiku.dip.server.controllers.AuditedCall;
import com.dataiku.dip.server.controllers.NotFoundException;
import com.dataiku.dip.server.services.ITaggingService;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import jakarta.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.time.Instant;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nullable;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpStatus;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.ResponseBody;
import org.springframework.web.bind.annotation.ResponseStatus;

@Controller
public class PublicAPIAgentsController
extends PublicAPIControllerBase {
    @Autowired
    private IPermissionsService permissionsService;
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private MetaAuthService authService;
    @Autowired
    private SavedModelsDAO savedModelsDAO;
    @Autowired
    private PythonLLMServerKernelPool agentPool;
    @Autowired
    private AuditTrailService auditTrailService;
    @Autowired
    private ProjectsService projectsService;

    private void checkAgentAccess(AuthCtx authCtx, String projectKey, AnyLoc agentRef, boolean needWrite) throws IOException, DKUSecurityException {
        Privileges.ProjectLevelPrivilegeType privilege = needWrite ? Privileges.ProjectLevelPrivilegeType.WRITE_CONF : Privileges.ProjectLevelPrivilegeType.READ_CONF;
        this.permissionsService.checkProjectPrivileges(authCtx, projectKey, new Privileges.ProjectLevelPrivilegeType[]{privilege});
        if (!Objects.equals(agentRef.getProjectKey(), projectKey)) {
            if (!this.projectsService.hasExposedObjectAccess(ITaggingService.TaggableType.SAVED_MODEL, authCtx, agentRef, projectKey)) {
                throw new ForbiddenObjectException("Agent %s from project %s is not shared to project %s.".formatted(agentRef.getId(), agentRef.getProjectKey(), projectKey));
            }
            if (needWrite) {
                this.permissionsService.checkProjectPrivileges(authCtx, agentRef.getProjectKey(), new Privileges.ProjectLevelPrivilegeType[]{Privileges.ProjectLevelPrivilegeType.WRITE_CONF});
            }
        }
    }

    @AuditInline
    @RequestMapping(value={"/publicapi/projects/{projectKey}/agents/{agentId}/actions/shutdown"}, method={RequestMethod.POST})
    @ResponseStatus(value=HttpStatus.NO_CONTENT)
    public void shutdownAgentKernel(HttpServletRequest req, @PathVariable String projectKey, @PathVariable String agentId, @RequestBody ShutdownRequest body) throws Exception {
        try {
            SavedModel model;
            AnyLoc loc = AnyLoc.resolveSmart((String)projectKey, (String)agentId);
            try (Transaction t = this.transactionService.beginRead();){
                AuthCtx authCtx = this.authService.getTicketOrKey(req);
                this.checkAgentAccess(authCtx, projectKey, loc, true);
                model = (SavedModel)this.savedModelsDAO.getMandatoryUnsafe(loc);
            }
            this.ensureIsAgent(model);
            String parsedVersionId = body.versionId == null ? model.activeVersion : body.versionId;
            model.getVersion(parsedVersionId).orElseThrow(() -> new NotFoundException("Agent version '%s' does not exist".formatted(parsedVersionId)));
            this.agentPool.shutdownKernels(loc.getProjectKey(), loc.getId(), parsedVersionId, body.force);
            this.auditTrailService.generic("agent-shutdown").with("projectKey", projectKey).with("agentId", agentId).with("versionId", body.versionId).emit();
        }
        catch (Exception e) {
            this.auditTrailService.failure("agent-shutdown", (Throwable)e).with("projectKey", projectKey).with("agentId", agentId).with("versionId", body.versionId).emit();
            throw e;
        }
    }

    @AuditedCall(value={"msgType", "agent-status", "projectKey", "${projectKey}", "agentId", "${agentId}"})
    @RequestMapping(value={"/publicapi/projects/{projectKey}/agents/{agentId}/status"}, method={RequestMethod.GET})
    @ResponseBody
    public StatusResponse getAgentStatus(HttpServletRequest req, @PathVariable String projectKey, @PathVariable String agentId, @RequestBody StatusRequest body) throws Exception {
        SavedModel model;
        AnyLoc loc = AnyLoc.resolveSmart((String)projectKey, (String)agentId);
        try (Transaction t = this.transactionService.beginRead();){
            AuthCtx authCtx = this.authService.getTicketOrKey(req);
            this.checkAgentAccess(authCtx, projectKey, loc, false);
            model = (SavedModel)this.savedModelsDAO.getMandatoryUnsafe(loc);
        }
        this.ensureIsAgent(model);
        String parsedVersionId = body.versionId == null ? model.activeVersion : body.versionId;
        model.getVersion(parsedVersionId).orElseThrow(() -> new NotFoundException("Agent version '%s' does not exist".formatted(parsedVersionId)));
        List<StatusResponse.KernelStatus> kernels = this.agentPool.getKernelDump().streamKernels().filter(dump -> {
            if (dump.state == KernelPool.KernelState.DEAD) return false;
            Object patt9447$temp = dump.kernelDesc;
            if (!(patt9447$temp instanceof PythonLLMServerKernelPool.KernelDesc)) return false;
            PythonLLMServerKernelPool.KernelDesc desc = (PythonLLMServerKernelPool.KernelDesc)patt9447$temp;
            if (!Objects.equals(loc.getProjectKey(), desc.getProjectKey())) return false;
            if (!Objects.equals(loc.getId(), desc.getSavedModelId())) return false;
            if (!Objects.equals(parsedVersionId, desc.getSavedModelVersionId())) return false;
            return true;
        }).map(dump -> new StatusResponse.KernelStatus(PublicAPIAgentsController.toEpochMillis(dump.startInstant), PublicAPIAgentsController.toEpochMillis(dump.readyInstant), dump.nbActiveRequests, dump.nbCancelledRequests, dump.nbFailedRequests, dump.nbSuccessfulRequests)).toList();
        return new StatusResponse(model.getId(), parsedVersionId, kernels);
    }

    @Nullable
    private static Long toEpochMillis(@Nullable Instant instant) {
        return instant == null ? null : Long.valueOf(instant.toEpochMilli());
    }

    @AuditInline
    @RequestMapping(value={"/publicapi/projects/{projectKey}/agents/{agentId}/actions/wakeup"}, method={RequestMethod.POST})
    @ResponseStatus(value=HttpStatus.NO_CONTENT)
    public void wakeUpAgent(HttpServletRequest req, @PathVariable String projectKey, @PathVariable String agentId, @RequestBody WakeUpRequest body) throws Exception {
        try {
            SavedModel model;
            AuthCtx authCtx;
            AnyLoc loc = AnyLoc.resolveSmart((String)projectKey, (String)agentId);
            try (Transaction t = this.transactionService.beginRead();){
                authCtx = this.authService.getTicketOrKey(req);
                this.checkAgentAccess(authCtx, projectKey, loc, false);
                model = (SavedModel)this.savedModelsDAO.getMandatoryUnsafe(loc);
            }
            this.ensureIsAgent(model);
            String parsedVersionId = body.versionId == null ? model.activeVersion : body.versionId;
            SavedModel.SavedModelInlineVersion smiv = (SavedModel.SavedModelInlineVersion)model.getVersion(parsedVersionId).orElseThrow(() -> new NotFoundException("Agent version '%s' does not exist".formatted(parsedVersionId)));
            this.agentPool.wakeUp((DSSAuthCtx)authCtx, loc.getProjectKey(), model, smiv);
            this.auditTrailService.generic("agent-wake-up").with("projectKey", projectKey).with("agentId", agentId).with("versionId", body.versionId).emit();
        }
        catch (Exception e) {
            this.auditTrailService.failure("agent-wake-up", (Throwable)e).with("projectKey", projectKey).with("agentId", agentId).with("versionId", body.versionId).emit();
            throw e;
        }
    }

    private void ensureIsAgent(SavedModel sm) {
        Set<SavedModel.SavedModelType> agentTypes = Set.of(SavedModel.SavedModelType.PYTHON_AGENT, SavedModel.SavedModelType.TOOLS_USING_AGENT, SavedModel.SavedModelType.PLUGIN_AGENT);
        if (!agentTypes.contains(sm.savedModelType)) {
            throw new DKUControllerBase.MalformedRequestException("Saved model is not an agent, found type %s".formatted(sm.savedModelType));
        }
    }

    public static class ShutdownRequest {
        String versionId;
        boolean force;
    }

    public static class StatusRequest {
        String versionId;
    }

    public static class StatusResponse {
        public String agentId;
        public String agentVersionId;
        public List<KernelStatus> kernels;

        private StatusResponse() {
        }

        public StatusResponse(String agentId, String agentVersionId, List<KernelStatus> kernels) {
            this.agentId = agentId;
            this.agentVersionId = agentVersionId;
            this.kernels = kernels;
        }

        public static class KernelStatus {
            public Long startedOn;
            public Long readyOn;
            public long nbActiveRequests;
            public int nbCancelledRequests;
            public int nbFailedRequests;
            public int nbSuccessfulRequests;

            private KernelStatus() {
            }

            public KernelStatus(Long startedOn, Long readyOn, long nbActiveRequests, int nbCancelledRequests, int nbFailedRequests, int nbSuccessfulRequests) {
                this.startedOn = startedOn;
                this.readyOn = readyOn;
                this.nbActiveRequests = nbActiveRequests;
                this.nbCancelledRequests = nbCancelledRequests;
                this.nbFailedRequests = nbFailedRequests;
                this.nbSuccessfulRequests = nbSuccessfulRequests;
            }
        }
    }

    public static class WakeUpRequest {
        String versionId;
    }
}

