/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.classification.user_provided;

import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.recipes.RecipeSchemaComputer;
import com.dataiku.dip.recipes.nlp.classification.user_provided.NLPLLMUserProvidedClassificationRecipePayloadParams;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.JSON;
import com.google.common.base.Preconditions;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Schema computer for the recipe whose payload is described by
 * {@link NLPLLMUserProvidedClassificationRecipePayloadParams}.
 *
 * <p>The single output schema is built from a copy of the {@code "main"} input dataset's schema
 * (when that input role has at least one item) followed by the columns declared in
 * {@link NLPLLMUserProvidedClassificationRecipeColumn}.
 */
public class NLPLLMUserProvidedClassificationRecipeSchemaComputer
        extends RecipeSchemaComputer
        implements RecipeSchemaComputer.RecipeSchemaComputerWithPayload {

    /** Name of the only supported output role; the input role is looked up under the same name. */
    private static final String MAIN_ROLE = "main";

    @Autowired
    private TransactionService transactionService;

    /** Parsed recipe payload; stays {@code null} until {@link #setPayload(String)} has been called. */
    private NLPLLMUserProvidedClassificationRecipePayloadParams desc;

    public NLPLLMUserProvidedClassificationRecipeSchemaComputer(AuthCtx authCtx, JobActivity activity) {
        super(authCtx, activity);
    }

    /**
     * Parses the JSON payload into the recipe parameters used when computing the output schema.
     *
     * @param payload JSON text of the recipe payload
     */
    @Override
    public void setPayload(String payload) {
        this.desc = (NLPLLMUserProvidedClassificationRecipePayloadParams) JSON.parse(
                payload, NLPLLMUserProvidedClassificationRecipePayloadParams.class);
    }

    /**
     * Computes the schema of the {@code "main"} output.
     *
     * @param role output role to compute the schema for; must be {@code "main"}
     * @return a single-element list holding the output schema
     * @throws IllegalArgumentException if {@code role} is {@code null} or anything other than {@code "main"}
     * @throws IllegalStateException if no payload params are available (see {@link #setPayload(String)})
     * @throws Exception if reading the input dataset schema fails
     */
    @Override
    public List<Schema> getSchemasForOutputRole_NT(String role) throws Exception {
        // Constant-first equals so that a null role yields the descriptive IllegalArgumentException
        // instead of a bare NullPointerException; the template defers message formatting to the failure path.
        Preconditions.checkArgument(MAIN_ROLE.equals(role), "Recipe only has output role 'main', got: %s", role);
        // Fail fast, and before opening any transaction, when the payload has not been provided.
        Preconditions.checkState(this.desc != null,
                "No recipe payload params available: setPayload() must be called before computing output schemas");

        boolean hasMainInput = !this.recipe.getInputsUnsafe().get(MAIN_ROLE).items.isEmpty();
        // Without any input item, the output only carries the recipe's own columns.
        Schema schema = hasMainInput ? this.getInputSchemaCopy() : new Schema();

        for (NLPLLMUserProvidedClassificationRecipeColumn column : NLPLLMUserProvidedClassificationRecipeColumn.values()) {
            boolean isExplanation = column == NLPLLMUserProvidedClassificationRecipeColumn.EXPLANATION;
            // The explanation column is only emitted when the payload asks for it.
            if (isExplanation && !this.desc.explainOutput) {
                continue;
            }
            schema.withColumn(column.name, column.type);
        }
        return Collections.singletonList(schema);
    }

    /**
     * Returns a copy of the schema of the dataset plugged into the {@code "main"} input role.
     *
     * <p>The recipe compliance check and the dataset lookup run inside a read transaction; the
     * schema copy is taken from the resolved dataset once that transaction has been closed.
     *
     * @throws Exception if the compliance check or the dataset lookup fails
     */
    private Schema getInputSchemaCopy() throws Exception {
        Dataset inputDataset;
        try (Transaction readTransaction = this.transactionService.beginRead()) {
            this.recipesValidationService.checkComplianceWithRecipeDesc(this.authCtx, this.recipe);
            AnyLoc inputDatasetLoc = this.recipe.getSingleInput(MAIN_ROLE).getLoc(this.recipe.getProjectKey());
            inputDataset = this.datasetAccessService.getMandatoryUnsafe(inputDatasetLoc);
        }
        return inputDataset.getSchema().getCopy();
    }

    /** Columns appended to the output schema, in declaration order. */
    public enum NLPLLMUserProvidedClassificationRecipeColumn {
        OUTPUT_CLASS("prediction", Type.STRING),
        /** Only added to the output schema when {@code explainOutput} is set in the payload. */
        EXPLANATION("prediction_explanation", Type.STRING),
        LLM_RAW_OUTPUT("llm_raw_response", Type.STRING),
        LLM_ERROR_MSG("llm_error_message", Type.STRING);

        /** Column name as written into the output schema. */
        final String name;
        /** Storage type of the column. */
        final Type type;

        NLPLLMUserProvidedClassificationRecipeColumn(String name, Type type) {
            this.name = name;
            this.type = type;
        }
    }
}

