/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance.promptinjection;

import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.governance.AbstractLocalHuggingFaceSimpleGuardrailDetector;
import com.dataiku.dip.llm.governance.GuardrailsCodes;
import com.dataiku.dip.llm.governance.GuardrailsPipelineRunner;
import com.dataiku.dip.llm.governance.promptinjection.PromptInjectionDetectionGuardrail;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;

/**
 * Prompt injection detector that delegates scoring to a detection model exposed by a
 * local HuggingFace connection.
 *
 * <p>The non-system messages of a completion query are concatenated and handed to the
 * parent class's {@code checkDetection}; the model's answer is then interpreted by
 * {@link #parseResponseAndDetect}. The query is rejected when the computed jailbreak
 * score is not provably inside {@code [0, detectionThreshold)}.
 */
public class LocalHuggingFacePromptInjectionDetectionPipeline
extends AbstractLocalHuggingFaceSimpleGuardrailDetector {
    /**
     * Validates the guardrail settings, resolves the detection model on the configured
     * connection and initializes the LLM client used for detection.
     *
     * @param authCtx    authentication context used to look up the connection and to
     *                   initialize the LLM client
     * @param projectKey project key forwarded to the LLM client initialization
     * @param settings   guardrail parameters (connection name, model ID, threshold)
     * @throws IllegalArgumentException if the connection name or model ID is blank, or
     *                                  if the threshold is not strictly between 0 and 1
     *                                  (NaN included)
     * @throws Exception                propagated from connection lookup or client
     *                                  initialization
     */
    public LocalHuggingFacePromptInjectionDetectionPipeline(AuthCtx authCtx, String projectKey, PromptInjectionDetectionGuardrail.Params settings) throws Exception {
        super(DKULogger.getLogger("dku.llm.governance.promptInjection"));
        if (StringUtils.isBlank(settings.huggingFaceLocalConnectionName)) {
            throw new IllegalArgumentException("This LLM connection uses a prompt injection detector that is missing its HuggingFace local connection name");
        }
        if (StringUtils.isBlank(settings.huggingFaceLocalModelId)) {
            throw new IllegalArgumentException("This LLM connection uses a prompt injection detector that is missing its HuggingFace local model ID");
        }
        // Written as a negated in-range test so that a NaN threshold is rejected too:
        // every comparison involving NaN is false, hence "<= 0 || >= 1" would let it through.
        boolean thresholdInRange = settings.huggingFaceLocalThreshold > 0.0f && settings.huggingFaceLocalThreshold < 1.0f;
        if (!thresholdInRange) {
            throw new IllegalArgumentException("The prompt injection detection threshold must be between 0 and 1 (both excluded).");
        }
        ConnectionsDAO connectionsDAO = (ConnectionsDAO)SpringUtils.getBean(ConnectionsDAO.class);
        HuggingFaceLocalConnection connection = (HuggingFaceLocalConnection)connectionsDAO.getMandatoryConnection(authCtx, settings.huggingFaceLocalConnectionName);
        HuggingFaceLocalConnection.HFLocalModel enabledDetectionModel = connection.getEnabledDetectionModel(settings.huggingFaceLocalModelId);
        this.handlingMode = enabledDetectionModel.handlingMode;
        LLMStructuredRef llmRef = enabledDetectionModel.asStructuredRef(connection.name);
        super.initLLMClient(authCtx, projectKey, llmRef);
        this.detectionThreshold = settings.huggingFaceLocalThreshold;
    }

    /**
     * Runs prompt injection detection on a completion query.
     *
     * <p>The text of every message whose role is not {@code "system"} is joined with
     * newlines and submitted for detection under the content source {@code "query"}.
     *
     * @param query the completion query to inspect
     * @param trace trace span forwarded to the detection call
     * @throws Exception propagated from the detection call
     */
    @Override
    public void processCompletionQuery(LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        // Constant-first equals: a message with a null role is not a system message, so it
        // is still scanned instead of aborting the check with a NullPointerException.
        String content = query.messages.stream()
                .filter(m -> !"system".equals(m.role))
                .map(m -> m.getTextEvenIfNotTextOnly())
                .collect(Collectors.joining("\n"));
        super.checkDetection(content, "query", trace);
    }

    /**
     * Builds the exception used to reject content at this pipeline stage.
     *
     * <p>The error source and code are always the query prompt injection ones;
     * {@code isQuery} is not consulted.
     *
     * @param isQuery unused by this implementation
     * @param message human-readable reason for the rejection
     * @return a new exception carrying the prompt injection error source and code
     */
    @Override
    public GuardrailsPipelineRunner.LLMUsageEnforcerException newExceptionForPipelineStage(boolean isQuery, String message) {
        return new GuardrailsPipelineRunner.LLMUsageEnforcerException(LLMClient.LLMResponseErrorSource.QUERY_PROMPT_INJECTION_DETECTION, GuardrailsCodes.ERR_LLM_QUERY_PROMPT_INJECTION, message);
    }

    /**
     * Interprets the detection model's answer and rejects the content when it is flagged.
     *
     * <p>The response text is parsed as JSON with optional {@code JAILBREAK} and
     * {@code SAFE} scores. The jailbreak score is {@code JAILBREAK} when present,
     * otherwise {@code 1 - SAFE}. The check fails closed: content is accepted only when
     * the score is within {@code [0, detectionThreshold)}; a negative, NaN or
     * above-threshold score is denied.
     *
     * @param response      the detection model's response
     * @param contentSource label of the inspected content, used in error messages
     * @param isQuery       forwarded to {@link #newExceptionForPipelineStage}
     * @throws GuardrailsPipelineRunner.LLMUsageEnforcerException if the response cannot
     *         be parsed, carries no usable score, or the score flags the content
     */
    @Override
    public void parseResponseAndDetect(LLMClient.SimpleCompletionResponse response, String contentSource, boolean isQuery) throws GuardrailsPipelineRunner.LLMUsageEnforcerException {
        PromptGuardPromptInjectionDetectionResult result;
        try {
            result = (PromptGuardPromptInjectionDetectionResult)JSON.parse((String)response.text, PromptGuardPromptInjectionDetectionResult.class);
        }
        catch (Exception e) {
            this.logger.error((Object)("Failed to parse prompt injection detection response: " + e.getMessage()));
            throw this.newExceptionForPipelineStage(isQuery, "Failed to parse prompt injection detection response.");
        }
        // A null parse result (e.g. a literal JSON null payload) is treated like a result
        // with no scores, rather than dereferenced.
        if (result == null || (result.JAILBREAK == null && result.SAFE == null)) {
            throw this.newExceptionForPipelineStage(isQuery, String.format("LLM %s denied: prompt injection detection didn't return a valid result", contentSource));
        }
        double jailBreakScore = result.JAILBREAK != null ? result.JAILBREAK : 1.0 - result.SAFE;
        // Fail closed: only a score provably inside [0, threshold) is accepted. Expressed
        // as a positive range test so that NaN (for which all comparisons are false) is denied.
        boolean consideredSafe = jailBreakScore >= 0.0 && jailBreakScore < (double)this.detectionThreshold;
        if (!consideredSafe) {
            throw this.newExceptionForPipelineStage(isQuery, String.format("LLM %s denied: flagged by content moderation: prompt injection", contentSource));
        }
    }

    /**
     * JSON shape of the detection model's answer. Field names are the JSON keys and must
     * not be renamed. Either score may be absent (null).
     */
    private static class PromptGuardPromptInjectionDetectionResult {
        public Double JAILBREAK;
        public Double SAFE;

        private PromptGuardPromptInjectionDetectionResult() {
        }
    }
}

