/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.promptstudio;

import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.datasets.DatasetSelectionToMemTable;
import com.dataiku.dip.datasets.SamplingParam;
import com.dataiku.dip.datasets.SingleThreadPusherToMemTable;
import com.dataiku.dip.exceptions.DSSInternalErrorException;
import com.dataiku.dip.futures.FutureProgress;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMAuditHelper;
import com.dataiku.dip.llm.LLMRefEnricherService;
import com.dataiku.dip.llm.PromptChatClient;
import com.dataiku.dip.llm.PromptDef;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMMeshClient;
import com.dataiku.dip.llm.online.LLMMeshClientFactory;
import com.dataiku.dip.llm.prompts.PromptExpander;
import com.dataiku.dip.llm.promptstudio.PromptResponse;
import com.dataiku.dip.llm.promptstudio.PromptResponsePreview;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.llm.promptstudio.PromptStudiosCRUDService;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsageReportingService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.streaming.endpoints.httpsse.MiniSSEEmitter;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Runs one prompt of a prompt studio against an LLM and assembles the resulting {@link PromptResponse}.
 *
 * <p>Depending on {@code promptDef.promptMode}, the engine does one of three things:
 * <ul>
 *   <li>expands a prompt template over a set of records (from a dataset, from inline queries, or from
 *       records forced through {@link #setForcedRecordsForPromptTemplate});</li>
 *   <li>sends a single raw prompt;</li>
 *   <li>streams a chat answer ({@link #streamChatResponse}).</li>
 * </ul>
 * Spring-managed collaborators are injected in the constructor through {@code SpringUtils.autowire}.
 */
public class PromptExecutionEngine {
    @Autowired
    private TransactionService transactionService;
    @Autowired
    private DatasetsDAO datasetsDAO;
    @Autowired
    private AuditTrailService auditTrailService;
    @Autowired
    private ComputeResourceUsageReportingService cruReportingService;
    @Autowired
    private PromptStudiosCRUDService promptStudiosCrudService;
    @Autowired
    private LLMRefEnricherService llmRefEnricherService;
    private final AuthCtx authCtx;
    private final PromptStudio promptStudio;
    private final PromptStudio.PromptStudioPrompt promptStudioPrompt;
    private final PromptDef promptDef;
    /** Run identifier copied into every produced {@link PromptResponse}; may stay null. */
    private String promptRunId;
    /** When non-null, used instead of the dataset / inline queries as the records the template is expanded over. */
    private MemTable forcedRecordsForPromptTemplate;
    /** Synthetic column holding inline single inputs when the LLM is not prompt-driven. */
    private static final String DKU_SINGLE_INLINE_INPUT = "__dku_single_inline_input__";
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.promptstudio.engine");

    /**
     * Creates an engine for one prompt of a prompt studio.
     *
     * @param authCtx            used for LLM access and dataset reads, and as the {@code runBy} of responses
     * @param promptStudio       studio owning the prompt (provides the project key)
     * @param promptStudioPrompt prompt to execute; its {@code prompt} field is used as the prompt definition
     */
    public PromptExecutionEngine(AuthCtx authCtx, PromptStudio promptStudio, PromptStudio.PromptStudioPrompt promptStudioPrompt) {
        this.authCtx = authCtx;
        this.promptStudioPrompt = promptStudioPrompt;
        this.promptDef = promptStudioPrompt.prompt;
        this.promptStudio = promptStudio;
        SpringUtils.getInstance().autowire((Object)this);
        logger.info((Object)"PEE initialized");
    }

    /** Sets the run identifier stamped on produced responses. */
    public void setPromptRunId(String promptRunId) {
        this.promptRunId = promptRunId;
    }

    /** Forces the records used for template expansion, bypassing the dataset / inline queries. */
    public void setForcedRecordsForPromptTemplate(MemTable mt) {
        this.forcedRecordsForPromptTemplate = mt;
    }

    /**
     * Executes the prompt in template or raw mode and returns the responses, with validation results and stats.
     *
     * <p>Query construction depends on the mode:
     * <ul>
     *   <li>Template modes: one query is built per record returned by the record source.</li>
     *   <li>Raw mode: a single query is built from {@code rawPromptText}.</li>
     * </ul>
     * Compute resource usage is reported when available. An audit event is emitted per response when needed.
     *
     * @return the assembled response, with one {@code SingleInputPromptResponse} per query
     * @throws IllegalArgumentException if a raw prompt is run against a non-prompt-driven LLM
     * @throws DSSInternalErrorException if the prompt is in CHAT mode (use {@link #streamChatResponse})
     * @throws Exception if resolving the LLM, reading records or querying the LLM fails
     */
    public PromptResponse respond() throws Exception {
        EnrichedLLMStructuredRef enrichedLLMRef = this.resolveLLMRef();
        logger.info((Object)"Starting to respond to prompt");
        ArrayList<LLMClient.SingleCompletionQuery> in = new ArrayList<LLMClient.SingleCompletionQuery>();
        MemTable recordsFromPromptTemplate = null;
        PromptExpander pe = new PromptExpander(this.authCtx, this.promptStudio.projectKey, enrichedLLMRef.promptDriven, this.promptDef);
        switch (this.promptDef.promptMode) {
            case PROMPT_TEMPLATE_TEXT:
            case PROMPT_TEMPLATE_STRUCTURED: {
                if (!enrichedLLMRef.promptDriven) {
                    // Non-prompt-driven LLMs are fed a single column: the synthetic inline column,
                    // or the configured dataset column.
                    if (this.promptDef.promptTemplateQueriesSource == PromptStudio.PromptTemplateQueriesSource.INLINE) {
                        pe.setSingleColumnForUnpromptMode(DKU_SINGLE_INLINE_INPUT);
                    } else {
                        pe.setSingleColumnForUnpromptMode(this.promptDef.singleInputColumn);
                    }
                }
                // Reuse the ref resolved above instead of resolving it a second time.
                recordsFromPromptTemplate = this.getRecordsForPromptTemplate(enrichedLLMRef);
                logger.info((Object)("Columns in records: " + recordsFromPromptTemplate.columns.values().stream().map(col -> col.getName()).collect(Collectors.joining(","))));
                for (int recordIdx = 0; recordIdx < recordsFromPromptTemplate.nrows(); ++recordIdx) {
                    in.add(pe.expand(recordsFromPromptTemplate, recordsFromPromptTemplate.rows.get(recordIdx)));
                }
                break;
            }
            case RAW_PROMPT: {
                if (!enrichedLLMRef.promptDriven) {
                    throw new IllegalArgumentException("Cannot run a single-shot prompt with a non-promptable LLM");
                }
                LLMClient.SingleCompletionQuery recordQuery = new LLMClient.SingleCompletionQuery();
                recordQuery.messages.add(new LLMClient.ChatMessage("user", this.promptDef.rawPromptText));
                pe.expandVariables(recordQuery.messages);
                in.add(recordQuery);
                break;
            }
            case CHAT: {
                throw new DSSInternalErrorException("Prompt mode CHAT is not supported");
            }
        }
        AnyLoc usedDataset = null;
        if (this.promptDef.promptMode.canUseDataset && this.promptDef.promptTemplateQueriesSource == PromptStudio.PromptTemplateQueriesSource.DATASET) {
            usedDataset = AnyLoc.resolveSmartNullSafe(this.promptStudio.projectKey, this.promptStudioPrompt.dataset);
        }
        GuardrailsPipelineSettings connectionGuardrailsPipelineSettings = GuardrailsPipelineUtils.getConnectionAndLLMLevelSettings(this.authCtx, this.promptStudio.projectKey, enrichedLLMRef);
        GuardrailsPipelineSettings guardrailsPipelineSettings = GuardrailsPipelineUtils.mergeEnforcementSettings(connectionGuardrailsPipelineSettings, this.promptDef.guardrailsPipelineSettings);
        try (LLMMeshClient llmMeshClient = LLMMeshClientFactory.get(this.authCtx, this.promptStudio.projectKey, enrichedLLMRef, guardrailsPipelineSettings, usedDataset, in.size());){
            List<LLMClient.SimpleCompletionResponseOrError> responses;
            try (FutureProgress.AutocloseableFutureProgressState ignored = FutureProgress.pushAutoCloseableState("Querying LLM", (double)in.size(), FutureProgressState.StateUnit.RECORDS);){
                responses = llmMeshClient.completeQueries(in, this.promptStudioPrompt.llmSettings.toFullSettings());
            }
            ComputeResourceUsage cru = llmMeshClient.getTotalCRU(ComputeResourceUsage.LLMUsageType.COMPLETION);
            if (cru != null) {
                this.cruReportingService.reportComplete(cru);
            }
            PromptResponse promptResponse = new PromptResponse();
            this.stampRunMetadata(promptResponse);
            promptResponse.querySource = this.promptDef.promptTemplateQueriesSource;
            switch (this.promptDef.promptMode) {
                case PROMPT_TEMPLATE_TEXT:
                case PROMPT_TEMPLATE_STRUCTURED: {
                    assert (recordsFromPromptTemplate != null);
                    promptResponse.mainPromptTemplateInputs = this.getMainPromptTemplateInputsFromPromptDef(enrichedLLMRef.promptDriven);
                    for (int recordIdx = 0; recordIdx < recordsFromPromptTemplate.nrows(); ++recordIdx) {
                        PromptResponse.SingleInputPromptResponse sipr = new PromptResponse.SingleInputPromptResponse();
                        sipr.mainInputs = this.getMainInputsFromRecords(recordsFromPromptTemplate, recordIdx, enrichedLLMRef.promptDriven, pe.getSingleColumnForUnpromptedMode());
                        this.fillSingleInputPromptResponse(sipr, responses.get(recordIdx));
                        promptResponse.responses.add(sipr);
                    }
                    break;
                }
                case RAW_PROMPT: {
                    // Non-prompt-driven LLMs were already rejected while building the query above.
                    PromptResponse.SingleInputPromptResponse sipr = new PromptResponse.SingleInputPromptResponse();
                    this.fillSingleInputPromptResponse(sipr, responses.get(0));
                    promptResponse.responses.add(sipr);
                    break;
                }
            }
            for (int recordIdx = 0; recordIdx < promptResponse.responses.size(); ++recordIdx) {
                LLMAuditHelper.emitLLMCompletionAuditFromBackendIfNeeded(this.auditTrailService, enrichedLLMRef, llmMeshClient.getConnection(), in.get(recordIdx), responses.get(recordIdx));
            }
            this.fillStats(promptResponse);
            return promptResponse;
        }
    }

    /** Copies run metadata (prompt id, author, timestamp, run id) onto a response. */
    private void stampRunMetadata(PromptResponse promptResponse) {
        promptResponse.promptId = this.promptStudioPrompt.id;
        promptResponse.runBy = this.authCtx.getIdentifier();
        promptResponse.runOn = System.currentTimeMillis();
        promptResponse.runId = this.promptRunId;
    }

    /** Computes the record counts and the average estimated cost per 1000 records of a response. */
    private void fillStats(PromptResponse promptResponse) {
        promptResponse.stats.hasNoValidation = this.promptDef.resultValidation.hasNoSetRules();
        promptResponse.stats.testedRecords = promptResponse.responses.size();
        promptResponse.stats.validRecords = promptResponse.responses.stream().filter(pr -> !pr.error && pr.validationStatus == PromptResponse.ResponseValidationStatus.VALID).count();
        promptResponse.stats.invalidRecords = promptResponse.responses.stream().filter(pr -> !pr.error && pr.validationStatus == PromptResponse.ResponseValidationStatus.INVALID).count();
        promptResponse.stats.failedRecords = promptResponse.responses.stream().filter(pr -> pr.error).count();
        // Average only over responses that carry a cost estimate; no estimates at all means a cost of 0.
        double averageCost = promptResponse.responses.stream()
                .map(pr -> pr.estimatedCost)
                .filter(Objects::nonNull)
                .mapToDouble(Double::doubleValue)
                .average()
                .orElse(0.0);
        double costPer1K = averageCost * 1000.0;
        // Non-finite individual costs can yield NaN; report 0 in that case.
        promptResponse.stats.estimatedCostPer1KRecords = Double.isNaN(costPer1K) ? 0.0 : costPer1K;
    }

    /**
     * Streams a chat answer for a CHAT-mode prompt to {@code emitter} and returns the validated response.
     *
     * @param emitter SSE emitter the chat client streams to
     * @return a response holding the single chat answer
     * @throws Exception if resolving the LLM or streaming fails
     */
    public PromptResponse streamChatResponse(MiniSSEEmitter emitter) throws Exception {
        assert (this.promptDef.promptMode == PromptStudio.PromptMode.CHAT);
        PromptChatClient pcc = new PromptChatClient(this.authCtx, this.getPromptChat());
        PromptResponse.SingleInputPromptResponse chatResponse = pcc.streamChatResponse(emitter);
        return this.buildChatPromptResponse(chatResponse);
    }

    /**
     * Wraps a chat answer into a {@link PromptResponse}, validates it, and copies the validation outcome
     * onto the matching message of {@code promptDef.chatMessages}.
     *
     * @throws IllegalStateException if {@code chatResponse.lastMessageId} is not in the prompt's chat messages
     */
    private PromptResponse buildChatPromptResponse(PromptResponse.SingleInputPromptResponse chatResponse) {
        PromptResponse promptResponse = new PromptResponse();
        this.stampRunMetadata(promptResponse);
        this.validateResponse(chatResponse);
        PromptStudio.ConversationMessage assistantMessage = this.promptDef.chatMessages.get(chatResponse.lastMessageId);
        if (assistantMessage == null) {
            throw new IllegalStateException("Chat response references unknown message id: " + chatResponse.lastMessageId);
        }
        assistantMessage.validationStatus = chatResponse.validationStatus;
        assistantMessage.validationMessage = chatResponse.validationMessage;
        promptResponse.responses.add(chatResponse);
        return promptResponse;
    }

    /**
     * Re-stamps {@code promptResponse} for this run and truncates its first response's chat history.
     *
     * <p>Only the ancestors of {@code sourceUserMessageId} are kept, excluding that message itself. Each
     * kept message has its {@code version} reset to 0, and {@code lastMessageId} is set to the source
     * message's parent.
     *
     * @param promptResponse      response to modify in place
     * @param sourceUserMessageId id of the user message the fork starts from
     * @throws IllegalArgumentException if {@code sourceUserMessageId} is not in the chat history
     * @throws IllegalStateException if a parent message referenced in the history is missing
     */
    public void forkResponse(PromptResponse promptResponse, String sourceUserMessageId) {
        this.stampRunMetadata(promptResponse);
        List<PromptResponse.SingleInputPromptResponse> responses = promptResponse.responses;
        if (responses.isEmpty()) {
            return;
        }
        PromptResponse.SingleInputPromptResponse sipr = responses.get(0);
        PromptStudio.ConversationMessage sourceMessage = sipr.chatMessages.get(sourceUserMessageId);
        if (sourceMessage == null) {
            throw new IllegalArgumentException("Unknown source user message: " + sourceUserMessageId);
        }
        String parentId = sourceMessage.parentId;
        sipr.lastMessageId = parentId;
        HashMap<String, PromptStudio.ConversationMessage> newChatMessages = new HashMap<String, PromptStudio.ConversationMessage>();
        while (parentId != null) {
            PromptStudio.ConversationMessage parentMessage = sipr.chatMessages.get(parentId);
            if (parentMessage == null) {
                throw new IllegalStateException("Chat history is missing parent message: " + parentId);
            }
            parentMessage.version = 0;
            newChatMessages.put(parentId, parentMessage);
            parentId = parentMessage.parentId;
        }
        sipr.chatMessages = newChatMessages;
    }

    /** Copies an LLM completion (or error) into {@code sipr} via the CRUD service, then validates it. */
    private void fillSingleInputPromptResponse(PromptResponse.SingleInputPromptResponse sipr, LLMClient.SimpleCompletionResponseOrError resp) {
        this.promptStudiosCrudService.fillSingleInputPromptResponse(sipr, resp);
        this.validateResponse(sipr);
    }

    /** Logs the prompt's result-validation rules, then runs {@code sipr.validate} with them. */
    private void validateResponse(PromptResponse.SingleInputPromptResponse sipr) {
        logger.info((Object)("Validating prompt result with rules: " + JSON.json((Object)this.promptDef.resultValidation)));
        sipr.validate(this.promptDef.resultValidation);
    }

    /**
     * Builds a preview of the inputs the prompt would be run on, without calling the LLM.
     *
     * @return the preview; left empty unless the queries source is DATASET and the mode can use a dataset
     * @throws Exception if resolving the LLM or reading the dataset fails
     */
    public PromptResponsePreview getPromptDatasetPreview() throws Exception {
        PromptResponsePreview promptResponsePreview = new PromptResponsePreview();
        if (this.promptDef.promptTemplateQueriesSource == PromptStudio.PromptTemplateQueriesSource.DATASET && this.promptDef.promptMode.canUseDataset) {
            EnrichedLLMStructuredRef enrichedLLMRef = this.resolveLLMRef();
            MemTable recordsFromPromptTemplate = this.getRecordsForPromptTemplate(enrichedLLMRef);
            promptResponsePreview.promptId = this.promptStudioPrompt.id;
            promptResponsePreview.mainPromptTemplateInputs = this.getMainPromptTemplateInputsFromPromptDef(enrichedLLMRef.promptDriven);
            for (int recordIdx = 0; recordIdx < recordsFromPromptTemplate.nrows(); ++recordIdx) {
                PromptResponsePreview.SingleInputPromptResponsePreview siprp = new PromptResponsePreview.SingleInputPromptResponsePreview();
                siprp.mainInputs = this.getMainInputsFromRecords(recordsFromPromptTemplate, recordIdx, enrichedLLMRef.promptDriven, this.promptDef.singleInputColumn);
                promptResponsePreview.responses.add(siprp);
            }
        }
        return promptResponsePreview;
    }

    /**
     * Describes the main template inputs shown alongside results.
     *
     * <p>For prompt-driven LLMs, there is one entry per prompt input, named via
     * {@code promptDef.getInputName}. Otherwise, there is a single "Single input" entry.
     */
    private List<PromptStudio.PromptTemplateInput> getMainPromptTemplateInputsFromPromptDef(boolean isPromptDriven) {
        ArrayList<PromptStudio.PromptTemplateInput> mainPromptTemplateInputs = new ArrayList<PromptStudio.PromptTemplateInput>();
        if (isPromptDriven) {
            for (PromptStudio.PromptTemplateInput input : this.promptDef.getInputs()) {
                PromptStudio.PromptTemplateInput newInput = new PromptStudio.PromptTemplateInput();
                newInput.name = this.promptDef.getInputName(input);
                newInput.type = input.type;
                mainPromptTemplateInputs.add(newInput);
            }
        } else {
            PromptStudio.PromptTemplateInput input = new PromptStudio.PromptTemplateInput();
            input.name = "Single input";
            mainPromptTemplateInputs.add(input);
        }
        return mainPromptTemplateInputs;
    }

    /**
     * Extracts the displayed input values of one record.
     *
     * <p>For prompt-driven LLMs, there is one value per prompt input, or null when the input has no name.
     * Otherwise, the single value is taken from {@code singleInputColumn}.
     */
    private List<String> getMainInputsFromRecords(MemTable recordsFromPromptTemplate, int recordIndex, boolean isPromptDriven, String singleInputColumn) {
        ArrayList<String> mainInputs = new ArrayList<String>();
        if (isPromptDriven) {
            for (PromptStudio.PromptTemplateInput input : this.promptDef.getInputs()) {
                String inputName = this.promptDef.getInputName(input);
                String mainInput = null;
                if (inputName != null) {
                    MemColumn mc = recordsFromPromptTemplate.getColumn(inputName);
                    mainInput = recordsFromPromptTemplate.rows.get(recordIndex).get(mc);
                }
                mainInputs.add(mainInput);
            }
        } else {
            MemColumn mc = recordsFromPromptTemplate.getColumn(singleInputColumn);
            mainInputs.add(recordsFromPromptTemplate.rows.get(recordIndex).get(mc));
        }
        return mainInputs;
    }

    /** Assembles the chat client input from the prompt definition and the studio prompt's LLM settings. */
    private PromptChatClient.PromptChat getPromptChat() throws Exception {
        PromptChatClient.PromptChat promptChat = new PromptChatClient.PromptChat();
        promptChat.projectKey = this.promptStudio.projectKey;
        promptChat.llmSettings = this.promptStudioPrompt.llmSettings;
        promptChat.enrichedLLMRef = this.resolveLLMRef();
        promptChat.guardrailsPipelineSettings = this.promptDef.guardrailsPipelineSettings;
        promptChat.messages = this.promptDef.chatMessages;
        promptChat.lastUserMessage = this.promptDef.lastUserMessage;
        promptChat.systemMessage = this.promptDef.chatSystemMessage;
        promptChat.context = this.promptDef.chatContext;
        promptChat.streamingDisabled = this.promptDef.streamingDisabled;
        return promptChat;
    }

    /** Resolves the enriched reference of the LLM selected for this prompt. */
    private EnrichedLLMStructuredRef resolveLLMRef() throws Exception {
        return this.llmRefEnricherService.getEnrichedLLMRef(this.promptStudioPrompt.llmId, this.authCtx, this.promptStudio.getProjectKey());
    }

    /**
     * Returns the records the prompt template is expanded over.
     *
     * <p>Records forced via {@link #setForcedRecordsForPromptTemplate} take precedence. Otherwise, they
     * come from the configured dataset (sampled head-first, up to {@code nbRows} rows) or from the inline
     * queries. The LLM reference is only resolved when it is needed, i.e. for inline queries.
     *
     * @throws IllegalStateException if the queries source is neither DATASET nor INLINE
     * @throws Exception if resolving the LLM or reading the dataset fails
     */
    public MemTable getRecordsForPromptTemplate() throws Exception {
        return this.getRecordsForPromptTemplate(null);
    }

    /**
     * Same as {@link #getRecordsForPromptTemplate()}, but reuses an already-resolved LLM reference.
     *
     * @param knownLLMRef resolved LLM reference, or null to resolve it lazily if needed
     */
    private MemTable getRecordsForPromptTemplate(EnrichedLLMStructuredRef knownLLMRef) throws Exception {
        if (this.forcedRecordsForPromptTemplate != null) {
            return this.forcedRecordsForPromptTemplate;
        }
        PromptStudio.PromptTemplateQueriesSource source = this.promptDef.promptTemplateQueriesSource;
        switch (source) {
            case DATASET: {
                return this.datasetToMemTable(this.promptStudioPrompt.dataset);
            }
            case INLINE: {
                EnrichedLLMStructuredRef llmRef = knownLLMRef != null ? knownLLMRef : this.resolveLLMRef();
                return this.inlineQueriesToMemTable(llmRef.promptDriven);
            }
        }
        throw new IllegalStateException("Unsupported prompt template queries source: " + source);
    }

    /**
     * Builds an in-memory table from the inline queries.
     *
     * <p>For prompt-driven LLMs, there is one column per prompt input (plus any extra keys present in the
     * query data). Otherwise, there is a single synthetic {@value #DKU_SINGLE_INLINE_INPUT} column.
     */
    private MemTable inlineQueriesToMemTable(boolean promptDriven) throws Exception {
        MemTable mt = new MemTable();
        if (promptDriven) {
            for (PromptStudio.PromptTemplateInput pti : this.promptDef.getInputs()) {
                mt.column(pti.name);
            }
            for (PromptStudio.InlinePromptTemplateQuery iptq : this.promptStudioPrompt.inlinePromptTemplateQueries) {
                Row r = mt.row();
                for (Map.Entry<String, String> e : iptq.data.entrySet()) {
                    r.put((Column)mt.column(e.getKey()), e.getValue());
                }
                mt.appendRow(r);
            }
        } else {
            mt.column(DKU_SINGLE_INLINE_INPUT);
            for (PromptStudio.InlinePromptTemplateQuery iptq : this.promptStudioPrompt.inlinePromptTemplateQueries) {
                Row r = mt.row();
                r.put((Column)mt.column(DKU_SINGLE_INLINE_INPUT), iptq.singleInputData);
                mt.appendRow(r);
            }
        }
        return mt;
    }

    /**
     * Reads the head of a dataset into memory.
     *
     * <p>The dataset definition is loaded inside a read transaction. Rows are then pushed with
     * HEAD_SEQUENTIAL sampling, capped at {@code promptStudioPrompt.nbRows} records.
     *
     * @param datasetRef dataset reference, resolved relative to the studio's project
     */
    private MemTable datasetToMemTable(String datasetRef) throws Exception {
        MemTable mt = new MemTable();
        SerializedDataset sd = null;
        try (Transaction t = this.transactionService.retrieveOrBeginRead();){
            sd = (SerializedDataset)this.datasetsDAO.getMandatory(AnyLoc.resolveSmart(this.promptStudio.getProjectKey(), datasetRef));
        }
        Dataset dataset = Dataset.fromSerialized(sd);
        SingleThreadPusherToMemTable stmt = new SingleThreadPusherToMemTable(this.authCtx, dataset, mt);
        DatasetSelectionToMemTable dsmt = new DatasetSelectionToMemTable();
        dsmt.samplingMethod = SamplingParam.SamplingMethod.HEAD_SEQUENTIAL;
        dsmt.maxRecords = this.promptStudioPrompt.nbRows;
        stmt.setDatasetSelection(dsmt);
        stmt.push();
        return mt;
    }
}

