/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance;

import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.governance.GuardrailsPipelineRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.governance.SimpleNonModifyingGuardrailDetector;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.llm.online.LLMTracingUtils;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;

/**
 * Base class for non-modifying guardrail detectors that screen text by sending it to an
 * {@link LLMClient} configured as a classifier, presumably a locally hosted HuggingFace model
 * (see {@link HuggingFaceLocalConnection}).
 *
 * <p>For each completion query/response, embedding query and image-generation prompt, the
 * relevant text is sent to the model with {@code textClassificationOutputMode = ALL}. The single
 * resulting response is handed to {@link #parseResponseAndDetect}, which subclasses implement to
 * decide whether the content must be blocked.
 *
 * <p>If the model rejects the input with a tensor-size mismatch error (input longer than the
 * model supports), the content is split into smaller chunks and each chunk is checked separately.
 */
public abstract class AbstractLocalHuggingFaceSimpleGuardrailDetector
implements SimpleNonModifyingGuardrailDetector {
    /**
     * Matches the error raised when the input is longer than the model's maximum sequence length.
     * Group 1 captures the size of "tensor b", which is used as the chunk size for the retry.
     */
    private static final Pattern TENSOR_EXCEPTION_PATTERN = Pattern.compile("The size of tensor a \\(\\d+\\) must match the size of tensor b \\((\\d+)\\)");

    private LLMClient llmClient;
    protected float detectionThreshold;
    protected HuggingFaceLocalConnection.HuggingFaceHandlingMode handlingMode;
    protected DKULogger logger;
    private GuardrailsPipelineUtils.LLMClientAuditReporter llmClientAuditReporter;

    /**
     * Autowires Spring dependencies into this instance and stores the logger.
     *
     * @param logger logger used for diagnostics by this detector and its subclasses
     */
    public AbstractLocalHuggingFaceSimpleGuardrailDetector(DKULogger logger) {
        SpringUtils.getInstance().autowire((Object)this);
        this.logger = logger;
    }

    /**
     * Creates the LLM client used for detection and the audit reporter that records its calls.
     * Must be called before any {@code process*} method.
     *
     * @param authCtx    authentication context used to obtain the client
     * @param projectKey project in which the model is resolved
     * @param llmRef     reference to the classification model
     * @throws Exception if the client cannot be created
     */
    protected void initLLMClient(AuthCtx authCtx, String projectKey, LLMStructuredRef llmRef) throws Exception {
        this.llmClient = LLMClientFactory.get(authCtx, projectKey, llmRef);
        this.llmClientAuditReporter = new GuardrailsPipelineUtils.LLMClientAuditReporter(authCtx, llmRef, this.llmClient);
    }

    /** Checks the text of all query messages, joined with newlines. */
    @Override
    public void processCompletionQuery(LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        String content = query.messages.stream().map(m -> m.getTextEvenIfNotTextOnly()).collect(Collectors.joining("\n"));
        this.checkDetection(content, "query", trace);
    }

    /** Checks the text of a completion response. */
    @Override
    public void processCompletionResponse(LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError response, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        this.checkDetection(response.text, "response", trace);
    }

    /** Checks the text of an embedding query. */
    @Override
    public void processEmbeddingQuery(LLMClient.EmbeddingQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        this.checkDetection(query.text, "query", trace);
    }

    /** Checks both the positive and the negative prompts of an image-generation query. */
    @Override
    public void processImageGenerationQuery(LLMClient.ImageGenerationQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        this.checkDetection(query.getConcatenatedPrompts(), "query", trace);
        this.checkDetection(query.getConcatenatedNegativePrompts(), "query", trace);
    }

    /**
     * Runs the classifier on {@code content} and delegates the verdict to
     * {@link #parseResponseAndDetect}. Blank content is skipped.
     *
     * <p>When the model fails with a tensor-size mismatch, the content is re-checked in chunks
     * no longer than the size reported in the error. Chunking is only attempted when it actually
     * shrinks the input; otherwise the original error is rethrown rather than recursing forever.
     *
     * @param content       text to screen; may be {@code null} or blank (then nothing is checked)
     * @param contentSource {@code "query"} or {@code "response"}; used in logs and to pick the
     *                      pipeline stage of the exception
     * @param trace         parent trace span; a {@code DKU_LLM_MESH_LLM_CALL} child span is
     *                      opened for the model call
     * @throws Exception if the model call fails for a reason other than an input that can be
     *                   chunked, if the response count is not exactly one, or if the subclass
     *                   detects a violation
     */
    protected void checkDetection(String content, String contentSource, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        if (StringUtils.isBlank((CharSequence)content)) {
            this.logger.info((Object)String.format("Skipping detection as %s is empty", contentSource));
            return;
        }
        LLMClient.SingleCompletionQuery completionQuery = new LLMClient.SingleCompletionQuery();
        completionQuery.messages.add(new LLMClient.ChatMessage("user", content));
        LLMClient.CompletionSettings completionSettings = new LLMClient.CompletionSettings();
        completionSettings.textClassificationOutputMode = LLMClient.CompletionSettings.ClassificationOutputMode.ALL;
        completionSettings.temperature = 0.1;

        List<LLMClient.SimpleCompletionResponse> responses;
        // try-with-resources closes the span exactly once on every path (it tolerates a null span).
        try (LLMClient.LLMMeshTraceSpan callSpan = trace.withChildSpan("DKU_LLM_MESH_LLM_CALL")) {
            LLMTracingUtils.addIdentifiersAndSetCompletionInput(callSpan, this.llmClient.getEnrichedRef().id, this.llmClient, completionQuery, completionSettings);
            try {
                responses = this.llmClient.completeBatch(Collections.singletonList(completionQuery), completionSettings);
            }
            catch (RuntimeException e) {
                this.llmClientAuditReporter.emitLLMCompletionAuditIfNeeded(completionQuery, LLMClient.SimpleCompletionResponseOrError.fromError(e));
                int chunkSize = extractTensorSizeLimit(e.getMessage());
                if (chunkSize <= 0 || !chunkingMakesProgress(content, chunkSize)) {
                    // Not an input-too-long error, or splitting cannot shrink the input further.
                    throw e;
                }
                this.checkDetectionInChunks(content, contentSource, chunkSize, trace);
                return;
            }
            // Guard get(0) so an empty result reaches the explicit validation below instead of
            // failing with IndexOutOfBoundsException.
            if (!responses.isEmpty()) {
                LLMClient.SimpleCompletionResponseOrError scre = LLMClient.SimpleCompletionResponseOrError.fromSuccess(responses.get(0));
                this.llmClientAuditReporter.emitLLMCompletionAuditIfNeeded(completionQuery, scre);
                LLMTracingUtils.addUsageMetadataAndSetCompletionOutput(callSpan, scre);
            }
        }
        boolean isQuery = Objects.equals(contentSource, "query");
        if (responses.size() != 1) {
            this.logger.error((Object)("Invalid detection response, a single value was expected instead of " + responses.size()));
            throw this.newExceptionForPipelineStage(isQuery, "Failed to run detection.");
        }
        LLMClient.SimpleCompletionResponse response = responses.get(0);
        this.parseResponseAndDetect(response, contentSource, isQuery);
    }

    /**
     * Interprets the classifier output and throws if the content must be blocked.
     *
     * @param response      the single classification response returned by the model
     * @param contentSource {@code "query"} or {@code "response"}
     * @param isQuery       {@code true} when {@code contentSource} is {@code "query"}
     * @throws GuardrailsPipelineRunner.LLMUsageEnforcerException if a violation is detected
     */
    public abstract void parseResponseAndDetect(LLMClient.SimpleCompletionResponse response, String contentSource, boolean isQuery) throws GuardrailsPipelineRunner.LLMUsageEnforcerException;

    /**
     * Builds the exception to throw when detection cannot be performed.
     *
     * @param isQuery {@code true} when the failure concerns the query stage
     * @param message human-readable failure message
     * @return the exception to throw (not thrown by this method)
     */
    public abstract GuardrailsPipelineRunner.LLMUsageEnforcerException newExceptionForPipelineStage(boolean isQuery, String message);

    /**
     * Extracts the maximum tensor size reported by a tensor-size mismatch error.
     *
     * @param message exception message; may be {@code null}
     * @return the captured size, or {@code -1} if the message is absent, does not match, or the
     *         number does not fit in an {@code int}
     */
    private static int extractTensorSizeLimit(String message) {
        if (message == null) {
            return -1;
        }
        Matcher matcher = TENSOR_EXCEPTION_PATTERN.matcher(message);
        if (!matcher.find()) {
            return -1;
        }
        try {
            return Integer.parseInt(matcher.group(1));
        }
        catch (NumberFormatException ignored) {
            // Absurdly large value: treat as "not a chunkable error" and surface the original one.
            return -1;
        }
    }

    /**
     * Tells whether {@link #checkDetectionInChunks} would produce pieces strictly shorter than
     * {@code content}. If the content already fits in one chunk and has no space to split on,
     * chunking would resubmit the identical input and recurse without end.
     */
    private static boolean chunkingMakesProgress(String content, int chunkSize) {
        return content.length() > chunkSize || content.indexOf(' ') >= 0;
    }

    /**
     * Splits {@code content} into pieces of at most {@code chunkSize} characters and checks each.
     * A piece preferably ends at the last space within the window; that space is consumed and not
     * included in either piece. When the window has no space, the piece is cut at the window end
     * (backing off one char so a surrogate pair is not split) and no character is skipped.
     *
     * <p>Note: {@code chunkSize} comes from a tensor (token) size, while the split is by chars.
     *
     * @param chunkSize maximum piece length in chars; must be positive
     */
    private void checkDetectionInChunks(String content, String contentSource, int chunkSize, LLMClient.LLMMeshTraceSpan trace) throws Exception {
        int currentIndex = 0;
        int contentLength = content.length();
        while (currentIndex < contentLength) {
            // long arithmetic avoids overflow for very large chunk sizes
            int endIndex = (int)Math.min((long)currentIndex + chunkSize, contentLength);
            if (endIndex < contentLength && endIndex - 1 > currentIndex && Character.isHighSurrogate(content.charAt(endIndex - 1))) {
                endIndex--;
            }
            int spaceIndex = content.lastIndexOf(' ', endIndex);
            int chunkEnd;
            int nextIndex;
            if (spaceIndex < currentIndex) {
                // No space in the window: hard cut, and resume exactly where this chunk ended.
                chunkEnd = endIndex;
                nextIndex = endIndex;
            } else {
                // Split on the space and skip over it.
                chunkEnd = spaceIndex;
                nextIndex = spaceIndex + 1;
            }
            this.checkDetection(content.substring(currentIndex, chunkEnd), contentSource, trace);
            currentIndex = nextIndex;
        }
    }

    /**
     * Closes the audit reporter (failures are only logged) and then the LLM client. Safe to call
     * even if {@link #initLLMClient} was never invoked.
     *
     * @throws IOException if closing the LLM client fails
     */
    @Override
    public void close() throws IOException {
        if (this.llmClientAuditReporter != null) {
            try {
                this.llmClientAuditReporter.close();
            }
            catch (Exception e) {
                this.logger.warn((Object)"Failed to close llmClientAuditReporter", (Throwable)e);
            }
        }
        if (this.llmClient != null) {
            try {
                this.llmClient.close();
            }
            catch (Exception e) {
                throw new IOException("Failed to close LLM client", e);
            }
        }
    }
}

