/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.classification.user_provided;

import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.exec.AbstractInitializedRunner;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutputToSIP;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.RowFactory;
import com.dataiku.dip.llm.LLMAuditHelper;
import com.dataiku.dip.llm.PromptDef;
import com.dataiku.dip.llm.online.CompletionRecipeLLMMeshClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.prompts.PromptExpander;
import com.dataiku.dip.llm.promptstudio.PromptStudio;
import com.dataiku.dip.recipes.nlp.classification.user_provided.NLPLLMUserProvidedClassificationRecipePayloadParams;
import com.dataiku.dip.recipes.nlp.classification.user_provided.NLPLLMUserProvidedClassificationRecipeSchemaComputer;
import com.dataiku.dip.recipes.nlp.common.NLPLLMRecipeRunnerBase;
import com.dataiku.dip.recipes.nlp.common.NLPRecipeParallelRunInputFeedThread;
import com.dataiku.dip.server.services.AuditPrivilegedClient;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.StringTransmogrifier;
import com.dataiku.dip.warnings.WarningsContext;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;

public class NLPLLMUserProvidedClassificationRecipeRunner
extends NLPLLMRecipeRunnerBase {
    /** Recipe parameters deserialized from the JSON payload; assigned in {@link #setPayload(String)}. */
    private NLPLLMUserProvidedClassificationRecipePayloadParams desc;
    /** Logger shared by every instance of this runner (category {@code dku.recipes.nlp.classification}). */
    static DKULogger logger = DKULogger.getLogger((String)"dku.recipes.nlp.classification");

    /**
     * Creates a runner bound to a job activity.
     *
     * @param activity the job activity, forwarded unchanged to the base runner constructor
     */
    public NLPLLMUserProvidedClassificationRecipeRunner(JobActivity activity) {
        super(activity);
    }

    /**
     * Deserializes the recipe payload into {@code desc} and registers the id of the LLM it references.
     *
     * @param payload JSON text describing the recipe parameters
     */
    @Override
    public void setPayload(String payload) {
        NLPLLMUserProvidedClassificationRecipePayloadParams parsedParams =
                (NLPLLMUserProvidedClassificationRecipePayloadParams)JSON.parse(payload, NLPLLMUserProvidedClassificationRecipePayloadParams.class);
        this.desc = parsedParams;
        this.setLlmId(parsedParams.llmId);
    }

    /**
     * Finds the position of a class name in a list of normalized class names.
     *
     * <p>The name is lower-cased and trimmed before the lookup, which is the same normalization
     * {@code getPossibleClasses} applies when it builds the list.
     *
     * @param name the class name to look up; may be {@code null}
     * @param possibleClassNames the normalized class names to search
     * @return the index of the matching entry, or {@code -1} when {@code name} is {@code null}
     *         or matches no entry
     */
    private static int getPossibleClassIdx(String name, List<String> possibleClassNames) {
        if (name == null) {
            // A missing name (empty LLM answer, example without an output class) cannot match any
            // class; report "not found" instead of failing with a NullPointerException.
            return -1;
        }
        return possibleClassNames.indexOf(name.toLowerCase().trim());
    }

    /**
     * Serializes one classification answer as a JSON object string.
     *
     * <p>Properties are added in the order {@code explanation} (only when provided),
     * {@code class_name}, {@code class_id}, matching the answer format described in the prompt.
     *
     * @param className value of the {@code class_name} property
     * @param classIdx value of the {@code class_id} property
     * @param explanation value of the {@code explanation} property, or {@code null} to leave it out
     * @return the string form of the JSON object
     */
    private static String generateOutput(String className, int classIdx, @Nullable String explanation) {
        JsonObject json = new JsonObject();
        boolean hasExplanation = explanation != null;
        if (hasExplanation) {
            json.addProperty("explanation", explanation);
        }
        json.addProperty("class_name", className);
        json.addProperty("class_id", Integer.valueOf(classIdx));
        return json.toString();
    }

    /**
     * Validates the configured classes and returns their normalized names.
     *
     * <p>Each name is lower-cased and trimmed. The returned list keeps the order of
     * {@code desc.possibleClasses}, so an index in the result is also an index in that list.
     *
     * @param desc the recipe parameters holding the configured classes
     * @return the normalized class names, in configuration order
     * @throws IllegalArgumentException if a class entry or its name is missing or blank, or if two
     *         classes share the same normalized name
     */
    public static List<String> getPossibleClasses(NLPLLMUserProvidedClassificationRecipePayloadParams desc) {
        List<String> possibleClasses = new ArrayList<>();
        for (NLPLLMUserProvidedClassificationRecipePayloadParams.ClassificationClass possibleClass : desc.possibleClasses) {
            String rawName = possibleClass == null ? null : possibleClass.name;
            if (rawName == null) {
                // A missing entry or name is reported as a validation error, not as a bare NullPointerException.
                throw new IllegalArgumentException("One of the given classes is empty");
            }
            // Must stay identical to the normalization done in getPossibleClassIdx.
            String normalizedName = rawName.toLowerCase().trim();
            if (possibleClasses.contains(normalizedName)) {
                throw new IllegalArgumentException("Some of the given classes are duplicates");
            }
            if (StringUtils.isBlank((CharSequence)normalizedName)) {
                throw new IllegalArgumentException("One of the given classes is empty");
            }
            possibleClasses.add(normalizedName);
        }
        return possibleClasses;
    }

    /**
     * Runs the classification recipe.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Checks that an LLM id and an input column are configured and validates the classes.</li>
     *   <li>Resolves the output columns on the first {@code "main"} output.</li>
     *   <li>Opens the audit and LLM mesh clients, starts the asynchronous completion stream and the
     *       thread that feeds input rows into it.</li>
     *   <li>Consumes responses until the stream returns an empty {@code Optional}. For each response
     *       it copies the context entries into a new output row, writes either the classification
     *       result or the error message, emits the row and then emits the audit event.</li>
     *   <li>Signals the last row, joins the feeding thread, rethrows its failure if any, and calls
     *       {@code handleCRU}.</li>
     * </ol>
     *
     * @throws IllegalArgumentException if no LLM or input column is configured, or the classes are invalid
     * @throws IOException if the input feeding thread ended with an exception
     * @throws Exception any other failure raised by the clients or the output
     */
    @Override
    public void run() throws Exception {
        logger.info((Object)"Classification recipe API runner started");
        if (StringUtils.isBlank((CharSequence)this.desc.llmId)) {
            throw new IllegalArgumentException("No LLM was specified");
        }
        if (StringUtils.isBlank((CharSequence)this.desc.inputColumn)) {
            throw new IllegalArgumentException("No input column was specified");
        }
        List<String> possibleClasses = NLPLLMUserProvidedClassificationRecipeRunner.getPossibleClasses(this.desc);
        AbstractInitializedRunner.Output output = (AbstractInitializedRunner.Output)((List)this.outputs.get("main")).get(0);
        StringTransmogrifier transmogrifier = this.getOutputColTransmogrifier();
        Columns columns = new Columns();
        columns.outputClass = output.cf.column(transmogrifier.transmogrify(NLPLLMUserProvidedClassificationRecipeSchemaComputer.NLPLLMUserProvidedClassificationRecipeColumn.OUTPUT_CLASS.name));
        if (this.desc.explainOutput) {
            columns.explanation = output.cf.column(transmogrifier.transmogrify(NLPLLMUserProvidedClassificationRecipeSchemaComputer.NLPLLMUserProvidedClassificationRecipeColumn.EXPLANATION.name));
        }
        columns.rawLLMOutput = output.cf.column(transmogrifier.transmogrify(NLPLLMUserProvidedClassificationRecipeSchemaComputer.NLPLLMUserProvidedClassificationRecipeColumn.LLM_RAW_OUTPUT.name));
        columns.errorMessage = output.cf.column(transmogrifier.transmogrify(NLPLLMUserProvidedClassificationRecipeSchemaComputer.NLPLLMUserProvidedClassificationRecipeColumn.LLM_ERROR_MSG.name));
        try (AuditPrivilegedClient auditClient = new AuditPrivilegedClient();){
            ProcessorOutputToSIP processorOutput = new ProcessorOutputToSIP(output.out);
            try (CompletionRecipeLLMMeshClient meshClient = this.buildCompletionRecipeClient(null);){
                this.enrichedLLMRef = meshClient.getEnrichedRef();
                this.plcStream = meshClient.completeQueriesAsyncStream(this.buildCompletionSettings());
                InputFeedThread ift = new InputFeedThread((ColumnFactory)output.cf, (RowFactory)output.rf, possibleClasses);
                ift.start();
                while (true) {
                    logger.info((Object)"Fetching next response from PLCS");
                    Optional<?> o = this.plcStream.fetchNextResponse();
                    logger.info((Object)"Fetched response from PLCS");
                    if (!o.isPresent()) {
                        // An empty Optional marks the end of the response stream.
                        break;
                    }
                    CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext response = (CompletionRecipeLLMMeshClient.SimpleCompletionResponseOrErrorWithContext)o.get();
                    Row outputRow = output.rf.row();
                    // Copy the context values attached to the query into the output row.
                    for (Map.Entry<?, ?> e : ((Map<?, ?>)response.context).entrySet()) {
                        outputRow.put((Column)output.cf.column((String)e.getKey()), (String)e.getValue());
                    }
                    LLMClient.SimpleCompletionResponseOrError scr = response.scr;
                    if (scr.ok) {
                        if (scr.predictedClass != null) {
                            outputRow.put(columns.outputClass, scr.predictedClass);
                        } else if (this.enrichedLLMRef.promptDriven) {
                            this.fillRowFromPromptDrivenAnswer(outputRow, columns, scr.text, possibleClasses);
                        } else {
                            outputRow.put(columns.outputClass, scr.text);
                        }
                    } else {
                        outputRow.put(columns.errorMessage, scr.errorMessage);
                        this.activity.warnContext.addWarning(WarningsContext.WarningType.LLM_QUERY_ERROR, scr.errorMessage, logger);
                    }
                    processorOutput.emitRow(outputRow);
                    LLMAuditHelper.emitLLMCompletionAuditFromJobIfNeeded(this.authCtx, auditClient, this.enrichedLLMRef, meshClient.getConnection(), response.completionQuery, scr);
                }
                logger.info((Object)"Terminated");
                processorOutput.lastRowEmitted();
                ift.join();
                if (ift.getException() != null) {
                    throw new IOException("Input feeding failed", ift.getException());
                }
                this.handleCRU(meshClient);
            }
        }
    }

    /**
     * Interprets the text answer of a prompt-driven LLM and writes the result into the output row.
     *
     * <p>The answer is first read as a JSON object: {@code class_id} is tried first, then
     * {@code class_name}. If anything in that path fails, the whole text is matched against the
     * class names instead. When no class can be determined the raw answer is stored in the raw
     * output column.
     *
     * @param outputRow the row being built for this response
     * @param columns the resolved output columns
     * @param llmText the text returned by the LLM
     * @param possibleClasses the normalized class names, index-aligned with {@code desc.possibleClasses}
     */
    private void fillRowFromPromptDrivenAnswer(Row outputRow, Columns columns, String llmText, List<String> possibleClasses) {
        try {
            JsonObject jo = (JsonObject)JSON.parse((String)llmText, JsonObject.class);
            String matchedName = null;
            if (jo.has("class_id")) {
                int classId = jo.get("class_id").getAsInt();
                if (classId == -1) {
                    // -1 is the "no class matches" answer requested by the prompt: keep the raw answer too.
                    matchedName = "";
                    outputRow.put(columns.rawLLMOutput, llmText);
                } else if (classId >= 0 && classId < this.desc.possibleClasses.size()) {
                    // The lower bound keeps other negative ids from throwing in List.get, which would
                    // jump to the catch block and skip the class_name fallback below.
                    matchedName = this.desc.possibleClasses.get(classId).name;
                }
            }
            if (matchedName == null && jo.has("class_name")) {
                int nameIdx = NLPLLMUserProvidedClassificationRecipeRunner.getPossibleClassIdx(jo.get("class_name").getAsString(), possibleClasses);
                if (nameIdx > -1) {
                    matchedName = this.desc.possibleClasses.get(nameIdx).name;
                }
            }
            if (matchedName != null) {
                outputRow.put(columns.outputClass, matchedName);
            } else {
                outputRow.put(columns.outputClass, "");
                outputRow.put(columns.rawLLMOutput, llmText);
            }
            if (this.desc.explainOutput && jo.has("explanation")) {
                outputRow.put(columns.explanation, jo.get("explanation").getAsString());
            }
        }
        catch (Exception e) {
            // Deliberate best effort: the answer was not usable JSON, so try the whole text as a class name.
            int textIdx = NLPLLMUserProvidedClassificationRecipeRunner.getPossibleClassIdx(llmText, possibleClasses);
            if (textIdx > -1) {
                outputRow.put(columns.outputClass, this.desc.possibleClasses.get(textIdx).name);
            } else {
                outputRow.put(columns.rawLLMOutput, llmText);
            }
        }
    }

    /**
     * Builds the classification prompt for the given recipe parameters.
     *
     * <p>Convenience overload: validates and normalizes the configured classes, then delegates to
     * the two-argument overload.
     *
     * @param desc the recipe parameters
     * @return the prompt definition
     * @throws IllegalArgumentException if the configured classes are invalid
     */
    public static PromptDef getPrompt(NLPLLMUserProvidedClassificationRecipePayloadParams desc) {
        List<String> normalizedClassNames = NLPLLMUserProvidedClassificationRecipeRunner.getPossibleClasses(desc);
        return NLPLLMUserProvidedClassificationRecipeRunner.getPrompt(desc, normalizedClassNames);
    }

    /**
     * Builds the classification prompt from the recipe parameters and the normalized class names.
     *
     * <p>The prompt prefix lists every configured class as a JSON object, states the "no match"
     * convention (empty {@code class_name}, {@code class_id} of -1) and describes the expected JSON
     * answer, with an {@code explanation} property when {@code desc.explainOutput} is set. A single
     * input named "text to classify" is bound to the input column. Examples whose output class is
     * not one of the possible classes are left out.
     *
     * @param desc the recipe parameters
     * @param possibleClasses the normalized class names used to resolve class ids
     * @return the prompt definition
     */
    public static PromptDef getPrompt(NLPLLMUserProvidedClassificationRecipePayloadParams desc, List<String> possibleClasses) {
        PromptDef prompt = PromptDef.forRecipe();

        StringBuilder prefix = new StringBuilder("You are a helpful assistant that classifies the following text into one of these classes ");
        if (desc.outputDescription != null) {
            prefix.append("(").append(desc.outputDescription).append(")");
        }
        prefix.append(": ");
        String classListing = desc.possibleClasses.stream()
                .map(pc -> NLPLLMUserProvidedClassificationRecipeRunner.generateOutput(pc.name, NLPLLMUserProvidedClassificationRecipeRunner.getPossibleClassIdx(pc.name, possibleClasses), null))
                .collect(Collectors.joining(", "));
        prefix.append(classListing);
        prefix.append("\n\n");
        prefix.append("No other classes are allowed. If you think no classes are a proper match for the text, predict an empty string (\"\") for the class_name, and -1 for the class_id\n");
        if (desc.explainOutput) {
            prefix.append("You must classify the document into exactly one of the listed classes above and give a very short explanation. Answer using only this JSON format:\n");
            prefix.append("{\"explanation\": \"<short explanation>\", \"class_name\": \"<predicted class_name>\", \"class_id\": <predicted class_id>}\n\n");
        } else {
            prefix.append("You must classify the document into exactly one of the listed classes above. Answer using only this JSON format:\n");
            prefix.append("{\"class_name\": \"<predicted class_name>\", \"class_id\": <predicted class_id>}\n\n");
        }
        prompt.structuredPromptPrefix = prefix.toString();

        PromptStudio.PromptTemplateInput textInput = new PromptStudio.PromptTemplateInput();
        textInput.name = "text to classify";
        textInput.datasetColumnName = desc.inputColumn;
        prompt.getInputs().add(textInput);

        // Example explanations are only included when the recipe asks the LLM to explain its answer.
        final boolean withExplanation = desc.explainOutput;
        prompt.structuredPromptExamples = desc.examples.stream()
                .filter(example -> NLPLLMUserProvidedClassificationRecipeRunner.getPossibleClassIdx(example.outputClass, possibleClasses) != -1)
                .map(example -> PromptStudio.StructuredPromptTemplateExample.newSingleInput(
                        example.input,
                        NLPLLMUserProvidedClassificationRecipeRunner.generateOutput(
                                example.outputClass,
                                NLPLLMUserProvidedClassificationRecipeRunner.getPossibleClassIdx(example.outputClass, possibleClasses),
                                withExplanation ? example.explanation : null)))
                .collect(Collectors.toList());
        return prompt;
    }

    /**
     * Builds the completion settings for the current LLM.
     *
     * <p>When the LLM is not prompt-driven, the configured class names are passed as class labels.
     * When the LLM requires a hypothesis template, the sanitized template is set as well.
     *
     * @return the completion settings
     * @throws IllegalArgumentException if a hypothesis template is required and is invalid
     */
    private LLMClient.CompletionSettings buildCompletionSettings() {
        LLMClient.CompletionSettings settings = new LLMClient.CompletionSettings();
        boolean promptDriven = this.enrichedLLMRef.promptDriven;
        if (!promptDriven) {
            List<String> labels = new ArrayList<>();
            for (NLPLLMUserProvidedClassificationRecipePayloadParams.ClassificationClass candidate : this.desc.possibleClasses) {
                labels.add(candidate.name);
            }
            settings.classLabels = labels;
        }
        if (this.enrichedLLMRef.customClassificationRequiresHypothesisTemplate) {
            String sanitizedTemplate = NLPLLMUserProvidedClassificationRecipeRunner.sanitizeHypothesisTemplate(this.desc.hypothesisTemplate);
            settings.hypothesisTemplate = sanitizedTemplate;
        }
        return settings;
    }

    /**
     * Converts a hypothesis template to its sanitized form and validates it.
     *
     * <p>Every {@code {class}} placeholder is replaced by {@code {}}. The result is accepted only if
     * it consists of at most 300 units, each being either a character other than <code>{</code> or
     * the pair {@code {}}. Note that the check does not require a placeholder to be present.
     *
     * @param template the user-provided template; {@code null} is rejected
     * @return the template with {@code {class}} replaced by {@code {}}
     * @throws IllegalArgumentException if the template is {@code null}, too long, or contains an
     *         opening brace that is not part of a placeholder
     */
    public static String sanitizeHypothesisTemplate(String template) {
        if (template == null) {
            // A missing template is a configuration error: report it like any other invalid
            // template rather than failing with a NullPointerException.
            throw new IllegalArgumentException("Invalid hypothesis template, it should be a sentence containing the {class} placeholder.");
        }
        String sanitized = template.replace("{class}", "{}");
        if (!sanitized.matches("^(?:[^{]|\\{\\}){0,300}$")) {
            throw new IllegalArgumentException("Invalid hypothesis template, it should be a sentence containing the {class} placeholder.");
        }
        return sanitized;
    }

    /**
     * Holder for the output columns resolved in {@code run()}.
     *
     * <p>No explicit constructor is declared: the implicit one has the same (private) access as
     * the class.
     */
    private static class Columns {
        /** Column receiving the predicted class. */
        Column outputClass;

        /** Column receiving the explanation; only resolved when the recipe asks for explanations. */
        Column explanation;

        /** Column receiving the raw LLM answer when no class could be determined from it. */
        Column rawLLMOutput;

        /** Column receiving the error message of a failed query. */
        Column errorMessage;
    }

    /**
     * Input feeding thread for this recipe: turns each input row into a completion query.
     *
     * <p>For a prompt-driven LLM the query comes from expanding the classification prompt;
     * otherwise it is a single user message holding the input column value.
     */
    private class InputFeedThread
    extends NLPRecipeParallelRunInputFeedThread {
        /** Expander for the classification prompt; stays {@code null} when the LLM is not prompt-driven. */
        private PromptExpander promptExpander;

        /**
         * @param cf column factory handed to the base thread
         * @param rf row factory handed to the base thread
         * @param possibleClasses normalized class names used to build the prompt
         * @throws IOException declared by this constructor; the visible code does not throw it directly
         */
        InputFeedThread(ColumnFactory cf, RowFactory rf, List<String> possibleClasses) throws IOException {
            super(NLPLLMUserProvidedClassificationRecipeRunner.this.authCtx, NLPLLMUserProvidedClassificationRecipeRunner.this.recipe, NLPLLMUserProvidedClassificationRecipeRunner.this.activity, NLPLLMUserProvidedClassificationRecipeRunner.this.plcStream, cf, rf);
            boolean promptDriven = NLPLLMUserProvidedClassificationRecipeRunner.this.enrichedLLMRef.promptDriven;
            if (promptDriven) {
                PromptDef classificationPrompt = NLPLLMUserProvidedClassificationRecipeRunner.getPrompt(NLPLLMUserProvidedClassificationRecipeRunner.this.desc, possibleClasses);
                this.promptExpander = new PromptExpander(NLPLLMUserProvidedClassificationRecipeRunner.this.enrichedLLMRef, classificationPrompt, NLPLLMUserProvidedClassificationRecipeRunner.this.recipe.getProjectKey());
            }
        }

        /**
         * Builds the completion query for one input row.
         *
         * @param row the input row
         * @return the expanded prompt when a prompt expander exists, otherwise a query with one
         *         "user" message containing the value of the input column
         */
        @Override
        public LLMClient.SingleCompletionQuery buildCompletionQuery(Row row) {
            if (this.promptExpander == null) {
                String inputText = row.get(this.cf.column(NLPLLMUserProvidedClassificationRecipeRunner.this.desc.inputColumn));
                LLMClient.SingleCompletionQuery query = new LLMClient.SingleCompletionQuery();
                query.messages.add(new LLMClient.ChatMessage("user", inputText));
                return query;
            }
            return this.promptExpander.expand(this.cf, row);
        }
    }
}

