/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance.toxicity;

import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.governance.AbstractLocalHuggingFaceSimpleGuardrailDetector;
import com.dataiku.dip.llm.governance.GuardrailsCodes;
import com.dataiku.dip.llm.governance.GuardrailsPipelineRunner;
import com.dataiku.dip.llm.governance.toxicity.ToxicityDetectionGuardrail;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import org.apache.commons.lang3.StringUtils;

/**
 * Toxicity detection guardrail stage backed by a model served through a local
 * HuggingFace connection.
 *
 * <p>Two response formats are handled, selected by the handling mode of the configured
 * model: in {@code TEXT_GENERATION_LLAMA_GUARD} mode the content is accepted only when
 * the response text starts with {@code "safe"}; in every other mode the response text is
 * parsed as JSON carrying a numeric {@code toxic} score that is compared against the
 * configured threshold.
 *
 * <p>The detector fails closed: a missing, unparseable or non-numeric verdict results in
 * a {@link GuardrailsPipelineRunner.LLMUsageEnforcerException}.
 */
public class LocalHuggingFaceToxicityDetectionPipeline
extends AbstractLocalHuggingFaceSimpleGuardrailDetector {
    /**
     * Validates the guardrail settings, resolves the detection model on the named
     * connection and initializes the LLM client used to run it.
     *
     * @param authCtx    authentication context used to look up the connection and to
     *                   initialize the LLM client
     * @param projectKey project key forwarded to the LLM client initialization
     * @param settings   guardrail parameters (connection name, model ID, threshold)
     * @throws IllegalArgumentException if the connection name or model ID is blank, or
     *                                  if the threshold is not strictly between 0 and 1
     *                                  (NaN included)
     * @throws Exception                propagated from connection lookup or LLM client
     *                                  initialization
     */
    public LocalHuggingFaceToxicityDetectionPipeline(AuthCtx authCtx, String projectKey, ToxicityDetectionGuardrail.Params settings) throws Exception {
        super(DKULogger.getLogger("dku.llm.governance.toxicity"));
        if (StringUtils.isBlank(settings.huggingFaceLocalConnectionName)) {
            throw new IllegalArgumentException("This LLM connection uses a toxicity detector that is missing its HuggingFace local connection name");
        }
        if (StringUtils.isBlank(settings.huggingFaceLocalModelId)) {
            throw new IllegalArgumentException("This LLM connection uses a toxicity detector that is missing its HuggingFace local model ID");
        }
        float threshold = settings.huggingFaceLocalThreshold;
        // Written as a negated "in range" test on purpose: every comparison involving NaN
        // is false, so the form "t <= 0 || t >= 1" would accept a NaN threshold, and a NaN
        // threshold makes every later "score >= threshold" check false (nothing is flagged).
        if (!(threshold > 0.0f && threshold < 1.0f)) {
            throw new IllegalArgumentException("The toxicity detection threshold must be between 0 and 1 (both excluded).");
        }
        ConnectionsDAO connectionsDAO = (ConnectionsDAO)SpringUtils.getBean(ConnectionsDAO.class);
        HuggingFaceLocalConnection connection = (HuggingFaceLocalConnection)connectionsDAO.getMandatoryConnection(authCtx, settings.huggingFaceLocalConnectionName);
        HuggingFaceLocalConnection.HFLocalModel enabledDetectionModel = connection.getEnabledDetectionModel(settings.huggingFaceLocalModelId);
        this.handlingMode = enabledDetectionModel.handlingMode;
        LLMStructuredRef llmRef = enabledDetectionModel.asStructuredRef(connection.name);
        super.initLLMClient(authCtx, projectKey, llmRef);
        this.detectionThreshold = threshold;
    }

    /**
     * Builds the exception reported when this stage rejects content or cannot reach a
     * verdict.
     *
     * @param isQuery {@code true} to use the query-side error source and code,
     *                {@code false} to use the response-side ones
     * @param message human-readable reason carried by the exception
     * @return a new exception; the caller is responsible for throwing it
     */
    @Override
    public GuardrailsPipelineRunner.LLMUsageEnforcerException newExceptionForPipelineStage(boolean isQuery, String message) {
        return new GuardrailsPipelineRunner.LLMUsageEnforcerException(isQuery ? LLMClient.LLMResponseErrorSource.QUERY_TOXICITY_DETECTION : LLMClient.LLMResponseErrorSource.RESPONSE_TOXICITY_DETECTION, isQuery ? GuardrailsCodes.ERR_LLM_QUERY_TOXIC : GuardrailsCodes.ERR_LLM_RESPONSE_TOXIC, message);
    }

    /**
     * Interprets the detection model's response and throws if the content must be denied.
     *
     * @param response      completion returned by the detection model
     * @param contentSource label inserted into the denial message to name what was checked
     * @param isQuery       whether the checked content is the query ({@code true}) or the
     *                      response ({@code false}); selects the error source and code
     * @throws GuardrailsPipelineRunner.LLMUsageEnforcerException if the content is flagged
     *         as toxic, or if the model returned no text, unparseable text, or a missing
     *         or NaN score
     */
    @Override
    public void parseResponseAndDetect(LLMClient.SimpleCompletionResponse response, String contentSource, boolean isQuery) throws GuardrailsPipelineRunner.LLMUsageEnforcerException {
        // No usable text means no verdict: deny explicitly instead of failing with an NPE.
        if (response == null || response.text == null) {
            throw this.invalidResultException(isQuery, contentSource);
        }
        if (this.isLLamaGuardModel()) {
            // Anything that does not start with "safe" (ignoring leading whitespace) is denied.
            if (!response.text.stripLeading().startsWith("safe")) {
                throw this.flaggedException(isQuery, contentSource);
            }
            return;
        }
        GenericToxicityDetectionResult result;
        try {
            result = JSON.parse(response.text, GenericToxicityDetectionResult.class);
        }
        catch (Exception e) {
            this.logger.error((Object)("Failed to parse toxicity detection response: " + e.getMessage()));
            throw this.newExceptionForPipelineStage(isQuery, "Failed to parse toxicity detection response.");
        }
        // The parser may yield no object at all or an object without a score, and a NaN
        // score would make both comparisons below false and let the content through.
        if (result == null || result.toxic == null || result.toxic.isNaN()) {
            throw this.invalidResultException(isQuery, contentSource);
        }
        double score = result.toxic;
        // Negative scores are outside the expected range and are denied as well.
        if (score < 0.0 || score >= (double)this.detectionThreshold) {
            throw this.flaggedException(isQuery, contentSource);
        }
    }

    /** Builds the exception used when the content is flagged as toxic. */
    private GuardrailsPipelineRunner.LLMUsageEnforcerException flaggedException(boolean isQuery, String contentSource) {
        return this.newExceptionForPipelineStage(isQuery, String.format("LLM %s denied: flagged by content moderation: toxicity", contentSource));
    }

    /** Builds the exception used when the detector produced no usable verdict. */
    private GuardrailsPipelineRunner.LLMUsageEnforcerException invalidResultException(boolean isQuery, String contentSource) {
        return this.newExceptionForPipelineStage(isQuery, String.format("LLM %s denied: toxicity detection didn't return a valid result", contentSource));
    }

    /** Returns whether the configured model is handled in Llama Guard text-generation mode. */
    private boolean isLLamaGuardModel() {
        return this.handlingMode == HuggingFaceLocalConnection.HuggingFaceHandlingMode.TEXT_GENERATION_LLAMA_GUARD;
    }

    /** JSON shape parsed from the response text of models not handled in Llama Guard mode. */
    private static class GenericToxicityDetectionResult {
        /** Toxicity score; {@code null} when the field is absent from the parsed JSON. */
        public Double toxic;

        private GenericToxicityDetectionResult() {
        }
    }
}

