/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance.pii;

import com.dataiku.dip.llm.governance.GuardrailMeta;
import com.dataiku.dip.llm.governance.GuardrailParams;
import com.dataiku.dip.llm.governance.GuardrailRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.online.LLMChatMessageUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.pii.PIIClient;
import com.dataiku.dip.llm.pii.PIIKernelPool;
import com.dataiku.dip.llm.pii.PresidioBasedPIIHandlingServer;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.util.EnumSet;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Guardrail that runs PII detection, through a {@link PIIClient} obtained from the
 * {@link PIIKernelPool}, on LLM queries and/or completion responses.
 *
 * <p>Which directions are checked is controlled by {@link Params#filterQueries} and
 * {@link Params#filterResponses}. When the PII client returns a redacted payload, the
 * redacted content replaces the original content in place. When entities are recognized,
 * they are written to the trace span and attached to the guardrail response as an audit
 * payload under {@code "piiDetectedEntities"}.
 *
 * <p>NOTE(review): every return path in {@link Runner} passes the request/response on
 * (possibly with an audit payload). Any rejecting behaviour would have to come from the
 * PII client itself, e.g. by throwing. Confirm against {@link PIIClient}.
 */
public class PIIDetectionGuardrail {
    /**
     * Registration metadata: type {@code "PIIDetector"}, with parameters deserialized as
     * {@link Params}.
     */
    public static final GuardrailMeta META = new GuardrailMeta(){

        @Override
        public String getType() {
            return "PIIDetector";
        }

        @Override
        public Class<? extends GuardrailParams> paramsClass() {
            return Params.class;
        }

        /**
         * Returns the directions this guardrail element operates on. The flags mirror the
         * element's {@code filterQueries} / {@code filterResponses} toggles.
         */
        @Override
        public EnumSet<GuardrailMeta.GuardrailFlag> getFlags(GuardrailsPipelineSettings.GuardrailsPipelineElement elt) {
            Params p = elt.getParamsCopyAs(Params.class);
            EnumSet<GuardrailMeta.GuardrailFlag> es = EnumSet.noneOf(GuardrailMeta.GuardrailFlag.class);
            if (p.filterQueries) {
                es.add(GuardrailMeta.GuardrailFlag.OPERATES_ON_QUERIES);
            }
            if (p.filterResponses) {
                es.add(GuardrailMeta.GuardrailFlag.OPERATES_ON_RESPONSES);
            }
            return es;
        }

        /**
         * Creates a runner for one pipeline element. {@code bypassToken} is not used by this
         * guardrail.
         */
        @Override
        public GuardrailRunner buildRunner(AuthCtx authCtx, String projectKey, GuardrailsPipelineSettings.GuardrailsPipelineElement elt, String bypassToken) throws Exception {
            return new Runner(authCtx, projectKey, elt);
        }
    };
    // NOTE(review): not referenced anywhere in this file.
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.guardrails.pii");

    /**
     * Per-element runner. It obtains a {@link PIIClient} in {@link #init()} and applies it to
     * the queries and responses it is asked to process.
     */
    private static class Runner
    extends GuardrailRunner {
        private final AuthCtx authCtx;
        private final String projectKey;
        private final GuardrailsPipelineSettings.GuardrailsPipelineElement elt;
        /** Private copy of the element's parameters, taken at construction time. */
        private final Params p;
        /** Injected by the {@code SpringUtils...autowire(this)} call in the constructor. */
        @Autowired
        private PIIKernelPool piiKernelPool;
        /** Set by {@link #init()}; the process* methods assume init() has run. */
        private PIIClient piiClient;

        Runner(AuthCtx authCtx, String projectKey, GuardrailsPipelineSettings.GuardrailsPipelineElement elt) {
            this.authCtx = authCtx;
            this.projectKey = projectKey;
            this.elt = elt;
            this.p = elt.getParamsCopyAs(Params.class);
            SpringUtils.getInstance().autowire((Object)this);
        }

        /** Acquires a PII client from the pool for this user, project and parameter set. */
        @Override
        public void init() throws Exception {
            this.piiClient = this.piiKernelPool.getClient(this.authCtx, this.projectKey, this.p);
        }

        /**
         * No-op. NOTE(review): the client obtained in {@link #init()} is not released here;
         * confirm that {@link PIIKernelPool} does not expect it back.
         */
        @Override
        public void close() {
        }

        /**
         * When query filtering is enabled, runs PII detection on the completion query.
         * Redacted messages, if returned, replace {@code query.messages}, and recognized
         * entities are recorded on the trace and in the audit payload.
         */
        @Override
        public GuardrailRunner.CompletionQueryGuardrailResponse processCompletionQuery(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            if (this.p.filterQueries) {
                trace.setInput("checkedText", LLMChatMessageUtils.completionQueryToStringsOfTexts(query));
                CompletableFuture<PIIClient.CompletionQueryPIIDetectionResponse> future = this.piiClient.processAsync(query);
                PIIClient.CompletionQueryPIIDetectionResponse pdcr = (PIIClient.CompletionQueryPIIDetectionResponse)DKUCompletableFuture.collectResponse(future);
                if (pdcr.redactedQuery != null) {
                    query.messages = pdcr.redactedQuery.messages;
                }
                if (pdcr.recognizedEntities != null && pdcr.recognizedEntities.size() > 0) {
                    trace.outputs = new JsonObject();
                    trace.outputs.add("detectedEntities", (JsonElement)JSON.toJsonArray(pdcr.recognizedEntities));
                    return GuardrailRunner.CompletionQueryGuardrailResponse.passWithAudit(context, query, JF.obj().with("piiDetectedEntities", (JsonElement)JSON.toJsonArray(pdcr.recognizedEntities)).get());
                }
            }
            return GuardrailRunner.CompletionQueryGuardrailResponse.pass(context, query);
        }

        /**
         * When response filtering is enabled, runs PII detection on the completion response
         * text. Redacted text, if returned, replaces {@code response.text}, and recognized
         * entities are recorded on the trace and in the audit payload.
         */
        @Override
        public GuardrailRunner.CompletionResponseGuardrailResponse processCompletionResponse(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError response, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            if (this.p.filterResponses) {
                trace.setInput("checkedText", response.text);
                CompletableFuture<PIIClient.CompletionResponsePIIDetectionResponse> future = this.piiClient.processAsync(response);
                PIIClient.CompletionResponsePIIDetectionResponse pdcr = (PIIClient.CompletionResponsePIIDetectionResponse)DKUCompletableFuture.collectResponse(future);
                if (pdcr.redactedResponse != null) {
                    response.text = pdcr.redactedResponse.text;
                }
                if (pdcr.recognizedEntities != null && pdcr.recognizedEntities.size() > 0) {
                    trace.outputs = new JsonObject();
                    trace.outputs.add("detectedEntities", (JsonElement)JSON.toJsonArray(pdcr.recognizedEntities));
                    return GuardrailRunner.CompletionResponseGuardrailResponse.passWithAudit(context, response, JF.obj().with("piiDetectedEntities", (JsonElement)JSON.toJsonArray(pdcr.recognizedEntities)).get());
                }
            }
            return GuardrailRunner.CompletionResponseGuardrailResponse.pass(context, response);
        }

        /**
         * Always throws. NOTE(review): the "unreachable" message suggests the pipeline never
         * requests a streaming handler from this guardrail; confirm with the pipeline code.
         */
        @Override
        public LLMClient.StreamedCompletionResponseConsumer newStreamedCompletionResponseHandler(LLMClient.StreamedCompletionResponseConsumer underlying, GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            throw new Error("unreachable");
        }

        /**
         * When query filtering is enabled, runs PII detection on the embedding query text.
         * Redacted text, if returned, replaces {@code query.text}, and recognized entities
         * are recorded on the trace and in the audit payload.
         */
        @Override
        public GuardrailRunner.EmbeddingQueryGuardrailResponse processEmbeddingQuery(GuardrailRunner.GuardrailContext context, LLMClient.EmbeddingQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            if (this.p.filterQueries) {
                trace.setInput("checkedText", query.text);
                CompletableFuture<PIIClient.EmbeddingQueryPIIDetectionResponse> future = this.piiClient.processAsync(query);
                PIIClient.EmbeddingQueryPIIDetectionResponse processedQuery = (PIIClient.EmbeddingQueryPIIDetectionResponse)DKUCompletableFuture.collectResponse(future);
                // Treat the redacted payload as optional, as the completion paths do: without
                // it, keep the original text instead of failing with a NullPointerException.
                if (processedQuery.redactedQuery != null) {
                    query.text = processedQuery.redactedQuery.text;
                }
                if (processedQuery.recognizedEntities != null && processedQuery.recognizedEntities.size() > 0) {
                    trace.outputs = new JsonObject();
                    trace.outputs.add("detectedEntities", (JsonElement)JSON.toJsonArray(processedQuery.recognizedEntities));
                    return GuardrailRunner.EmbeddingQueryGuardrailResponse.passWithAudit(context, query, JF.obj().with("piiDetectedEntities", (JsonElement)JSON.toJsonArray(processedQuery.recognizedEntities)).get());
                }
            }
            return GuardrailRunner.EmbeddingQueryGuardrailResponse.pass(context, query);
        }

        /**
         * When query filtering is enabled, runs PII detection on the image-generation
         * prompts and negative prompts. The traced input joins all prompt texts with
         * newlines. Redacted prompts, if returned, replace the originals, and recognized
         * entities are recorded on the trace and in the audit payload.
         */
        @Override
        public GuardrailRunner.ImageGenerationQueryGuardrailResponse processImageGenerationQuery(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            if (this.p.filterQueries) {
                trace.setInput("checkedTexts", Stream.concat(query.prompts.stream(), query.negativePrompts.stream()).map(q -> q.prompt).collect(Collectors.joining("\n")));
                CompletableFuture<PIIClient.ImageGenerationQueryPIIDetectionResponse> future = this.piiClient.processAsync(query);
                PIIClient.ImageGenerationQueryPIIDetectionResponse processedQuery = (PIIClient.ImageGenerationQueryPIIDetectionResponse)DKUCompletableFuture.collectResponse(future);
                // Same optional-redaction guard as the other query paths.
                if (processedQuery.redactedQuery != null) {
                    query.prompts = processedQuery.redactedQuery.prompts;
                    query.negativePrompts = processedQuery.redactedQuery.negativePrompts;
                }
                if (processedQuery.recognizedEntities != null && processedQuery.recognizedEntities.size() > 0) {
                    trace.outputs = new JsonObject();
                    trace.outputs.add("detectedEntities", (JsonElement)JSON.toJsonArray(processedQuery.recognizedEntities));
                    return GuardrailRunner.ImageGenerationQueryGuardrailResponse.passWithAudit(context, query, JF.obj().with("piiDetectedEntities", (JsonElement)JSON.toJsonArray(processedQuery.recognizedEntities)).get());
                }
            }
            return GuardrailRunner.ImageGenerationQueryGuardrailResponse.pass(context, query);
        }

        /** Image-generation responses are passed through unchanged, without PII checks. */
        @Override
        public GuardrailRunner.ImageGenerationResponseGuardrailResponse processImageGenerationResponse(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.ImageGenerationResponseOrError response, LLMClient.LLMMeshTraceSpan trace) {
            return GuardrailRunner.ImageGenerationResponseGuardrailResponse.pass(context, response);
        }
    }

    /**
     * Parameters of a PII-detection guardrail element. They extend the Presidio PII handling
     * settings with per-direction toggles.
     */
    public static class Params
    extends PresidioBasedPIIHandlingServer.PresidioBasedPIIHandlingSettings
    implements GuardrailParams {
        /** Check queries (completion, embedding, image generation); enabled by default. */
        public boolean filterQueries = true;
        /** Check completion responses; disabled by default. */
        public boolean filterResponses = false;
    }
}

