/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.governance.formats;

import com.dataiku.dip.llm.governance.GuardrailMeta;
import com.dataiku.dip.llm.governance.GuardrailParams;
import com.dataiku.dip.llm.governance.GuardrailRunner;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import java.util.EnumSet;

/**
 * Guardrail that checks the text of an LLM completion response against an
 * expected format ({@link ExpectedFormat}).
 *
 * <p>The guardrail is exposed to the pipeline through {@link #META}, which builds
 * a private {@code Runner} per pipeline element. Only completion responses are
 * inspected; embedding and image-generation traffic is passed through untouched.
 */
public class ResponseFormatChecker {
    /**
     * Descriptor of this guardrail: its type identifier, its parameter class,
     * the flags it declares and the factory for its runner.
     */
    public static final GuardrailMeta META = new GuardrailMeta() {

        /** @return the type identifier of this guardrail, {@code "ResponseFormatChecker"} */
        @Override
        public String getType() {
            return "ResponseFormatChecker";
        }

        /** @return the class used to hold this guardrail's parameters */
        @Override
        public Class<? extends GuardrailParams> paramsClass() {
            return Params.class;
        }

        /**
         * Declares that this guardrail operates on responses only.
         *
         * @param elt pipeline element carrying this guardrail's parameters
         * @return a fresh, mutable set containing only {@code OPERATES_ON_RESPONSES}
         */
        @Override
        public EnumSet<GuardrailMeta.GuardrailFlag> getFlags(GuardrailsPipelineSettings.GuardrailsPipelineElement elt) {
            // The parameters do not influence the flags. The call is retained so that
            // any failure it raises on this element still surfaces here, as before.
            // NOTE(review): confirm whether it can be dropped entirely.
            elt.getParamsCopyAs(Params.class);
            return EnumSet.of(GuardrailMeta.GuardrailFlag.OPERATES_ON_RESPONSES);
        }

        /**
         * Builds the runner for one pipeline element.
         *
         * @param authCtx     authentication context handed to the runner
         * @param projectKey  project key handed to the runner
         * @param elt         pipeline element carrying this guardrail's parameters
         * @param bypassToken not used by this guardrail
         * @return a new runner bound to {@code elt}
         */
        @Override
        public GuardrailRunner buildRunner(AuthCtx authCtx, String projectKey, GuardrailsPipelineSettings.GuardrailsPipelineElement elt, String bypassToken) throws Exception {
            return new Runner(authCtx, projectKey, elt);
        }
    };
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.guardrails.expectedformat");

    /**
     * Runner that performs the actual format check on completion responses.
     */
    private static class Runner
    extends GuardrailRunner {
        private final AuthCtx authCtx;
        private final String projectKey;
        private final GuardrailsPipelineSettings.GuardrailsPipelineElement elt;
        /** Parameters obtained from the pipeline element at construction time. */
        private final Params p;

        Runner(AuthCtx authCtx, String projectKey, GuardrailsPipelineSettings.GuardrailsPipelineElement elt) {
            this.authCtx = authCtx;
            this.projectKey = projectKey;
            this.elt = elt;
            this.p = elt.getParamsCopyAs(Params.class);
        }

        /** No initialization is required. */
        @Override
        public void init() throws Exception {
        }

        /** No resources are held, so there is nothing to release. */
        @Override
        public void close() {
        }

        /**
         * Not supported: this guardrail only declares {@code OPERATES_ON_RESPONSES},
         * so it is presumably never invoked on queries.
         *
         * @throws Error always
         */
        @Override
        public GuardrailRunner.CompletionQueryGuardrailResponse processCompletionQuery(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            throw new Error("unreachable");
        }

        /**
         * Checks the response text against the configured {@link ExpectedFormat}.
         *
         * <p>For {@code JSON_ARRAY} / {@code JSON_OBJECT} the text is parsed into the
         * matching Gson type; for {@code NONE} (or a missing format) nothing is checked.
         *
         * @param context  guardrail context, forwarded to the pass result
         * @param query    the originating query (not inspected)
         * @param response the response whose {@code text} is checked
         * @param trace    trace span (not used)
         * @return a "pass" result wrapping {@code response} when the check succeeds
         * @throws IllegalArgumentException if parsing yields {@code null} for the expected type
         * @throws Exception presumably whatever {@code JSON.parse} raises on malformed input
         */
        @Override
        public GuardrailRunner.CompletionResponseGuardrailResponse processCompletionResponse(GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.SimpleCompletionResponseOrError response, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            // A switch on a null enum reference throws NullPointerException. The field is
            // public and mutable and can end up null (e.g. an explicit null in the stored
            // params), so fall back to the declared default, NONE, instead of crashing.
            ExpectedFormat format = this.p.expectedFormat;
            if (format == null) {
                format = ExpectedFormat.NONE;
            }

            switch (format) {
                case JSON_ARRAY: {
                    JsonArray arr = (JsonArray) JSON.parse((String) response.text, JsonArray.class);
                    if (arr == null) {
                        throw new IllegalArgumentException("Cannot parse to JSON array");
                    }
                    break;
                }
                case JSON_OBJECT: {
                    JsonObject obj = (JsonObject) JSON.parse((String) response.text, JsonObject.class);
                    if (obj == null) {
                        throw new IllegalArgumentException("Cannot parse to JSON object");
                    }
                    break;
                }
                case NONE:
                default:
                    // No format expected: nothing to verify.
                    break;
            }
            return GuardrailRunner.CompletionResponseGuardrailResponse.pass(context, response);
        }

        /**
         * Not supported: no streamed handler is provided by this guardrail.
         *
         * @throws Error always
         */
        @Override
        public LLMClient.StreamedCompletionResponseConsumer newStreamedCompletionResponseHandler(LLMClient.StreamedCompletionResponseConsumer underlying, GuardrailRunner.GuardrailContext context, LLMClient.SingleCompletionQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            throw new Error("unreachable");
        }

        /** Embedding queries are not inspected; always passes {@code query} through. */
        @Override
        public GuardrailRunner.EmbeddingQueryGuardrailResponse processEmbeddingQuery(GuardrailRunner.GuardrailContext context, LLMClient.EmbeddingQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            return GuardrailRunner.EmbeddingQueryGuardrailResponse.pass(context, query);
        }

        /** Image-generation queries are not inspected; always passes {@code query} through. */
        @Override
        public GuardrailRunner.ImageGenerationQueryGuardrailResponse processImageGenerationQuery(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.LLMMeshTraceSpan trace) throws Exception {
            return GuardrailRunner.ImageGenerationQueryGuardrailResponse.pass(context, query);
        }

        /** Image-generation responses are not inspected; always passes {@code response} through. */
        @Override
        public GuardrailRunner.ImageGenerationResponseGuardrailResponse processImageGenerationResponse(GuardrailRunner.GuardrailContext context, LLMClient.ImageGenerationQuery query, LLMClient.ImageGenerationResponseOrError response, LLMClient.LLMMeshTraceSpan trace) {
            return GuardrailRunner.ImageGenerationResponseGuardrailResponse.pass(context, response);
        }
    }

    /**
     * Parameters of this guardrail.
     */
    public static class Params
    implements GuardrailParams {
        /** Format the response text must match; defaults to {@link ExpectedFormat#NONE}. */
        public ExpectedFormat expectedFormat = ExpectedFormat.NONE;
    }

    /**
     * Formats a completion response can be required to match.
     */
    public enum ExpectedFormat {
        /** The response text must parse as a JSON array. */
        JSON_ARRAY,
        /** The response text must parse as a JSON object. */
        JSON_OBJECT,
        /** No format is enforced. */
        NONE;

    }
}

