/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.common;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.JobAuthCtxService;
import com.dataiku.dip.dataflow.RecipeRunnableSubgraph;
import com.dataiku.dip.dataflow.exec.AbstractInitializedRunner;
import com.dataiku.dip.dataflow.exec.FlowRunnable;
import com.dataiku.dip.dataflow.exec.PreRunSchemaPropagationHandler;
import com.dataiku.dip.dataflow.exec.RecipeRunnerWithPayload;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.license.LicenseStatusService;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.governance.GuardrailsPipelineSettings;
import com.dataiku.dip.llm.governance.GuardrailsPipelineUtils;
import com.dataiku.dip.llm.online.CompletionRecipeLLMMeshClient;
import com.dataiku.dip.llm.online.ParallelLLMClient;
import com.dataiku.dip.llm.online.RedirectCompletionRecipeLLMMeshClient;
import com.dataiku.dip.recipes.RecipeRunner;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsageReportingService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.services.licensing.AbstractLicenseFeaturesStatusBuilder;
import com.dataiku.dip.server.services.licensing.LicenseFeaturesStatusBuilder;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.Params;
import com.dataiku.dip.utils.StringTransmogrifier;
import com.google.common.collect.Maps;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Shared base for NLP recipe runners that send completion requests to an LLM.
 *
 * <p>Subclasses are expected to select the target model via {@link #setLlmId(String)} before
 * calling {@link #buildCompletionRecipeClient(GuardrailsPipelineSettings)}; {@code initOutputs()},
 * which {@link #init()} calls last, is defined outside this class.
 */
public abstract class NLPLLMRecipeRunnerBase
extends AbstractInitializedRunner
implements RecipeRunner,
FlowRunnable,
RecipeRunnerWithPayload {
    @Autowired
    protected JobAuthCtxService authCtxService;
    @Autowired
    protected ComputeResourceUsageReportingService cruReportingService;
    @Autowired
    protected LicenseStatusService licenseStatusService;
    /** Auth context of the running job; populated in {@link #init()}. */
    protected AuthCtx authCtx;
    /** Encoded LLM identifier, decoded with {@code LLMStructuredRef.decodeId}; set via {@link #setLlmId(String)}. */
    private String llmId;
    protected EnrichedLLMStructuredRef enrichedLLMRef;
    protected CompletionRecipeLLMMeshClient.CompletionsStreamer plcStream;
    /** Whether the current license allows advanced LLM Mesh features; populated in {@link #init()}. */
    protected boolean advancedLLMMeshLicensed;
    static final DKULogger logger = DKULogger.getLogger("dku.recipes.nlp.base");

    /**
     * Creates the runner and initializes the activity's status.
     *
     * @param activity the job activity this runner executes
     */
    public NLPLLMRecipeRunnerBase(JobActivity activity) {
        super(activity);
        this.activity.initStatus();
    }

    /**
     * Sets the encoded identifier of the LLM used by
     * {@link #buildCompletionRecipeClient(GuardrailsPipelineSettings)}.
     *
     * @param llmId encoded LLM identifier
     */
    protected void setLlmId(String llmId) {
        this.llmId = llmId;
    }

    /**
     * Prepares the runner before execution.
     *
     * <p>Resolves the job's auth context and the recipe from the activity's subgraph. It resets
     * the outputs map and runs pre-run schema propagation if needed. It then records whether
     * advanced LLM Mesh features are licensed and finally calls {@code initOutputs()}.
     *
     * @throws Exception if any of the initialization steps fails
     */
    @Override
    public void init() throws Exception {
        this.authCtx = this.authCtxService.getAuthCtx();
        this.recipe = ((RecipeRunnableSubgraph)this.activity.getSubgraph()).getRecipe();
        this.outputs = Maps.newHashMap();
        new PreRunSchemaPropagationHandler(this.activity, this.recipe).propagateIfNeeded();
        LicenseStatusService.LicensingStatus ls = this.licenseStatusService.getLicensingStatus();
        AbstractLicenseFeaturesStatusBuilder.LicenseFeaturesStatus featuresStatus = LicenseFeaturesStatusBuilder.getFeaturesStatus(ls);
        this.advancedLLMMeshLicensed = featuresStatus.advancedLLMMeshAllowed;
        this.initOutputs();
    }

    /**
     * Builds a transmogrifier (separator {@code "_"}) that already knows every column name of the
     * main input dataset, so output column names generated through it can be kept distinct from
     * the input columns.
     *
     * @return a transmogrifier pre-seeded with the main input's column names
     * @throws IOException if the input dataset cannot be loaded
     * @throws IllegalStateException if the recipe has no input dataset
     */
    protected StringTransmogrifier getOutputColTransmogrifier() throws IOException {
        AnyLoc mainInputLoc = this.getMainInputLoc();
        if (mainInputLoc == null) {
            // getMainInputLoc() returns null when no input role has any item; fail with a clear
            // message instead of handing null to the DAO.
            throw new IllegalStateException("Recipe has no input dataset");
        }
        SerializedDataset inputDataset = (SerializedDataset)this.datasetsDAO.getMandatoryUnsafe(mainInputLoc);
        StringTransmogrifier transmogrifier = new StringTransmogrifier("_");
        // Iterate the typed column list directly; the previous raw `List` local made this
        // enhanced-for yield Object and fail to compile.
        for (SchemaColumn inputColumn : inputDataset.getSchema().columns) {
            transmogrifier.addAlreadyTransmogrified(inputColumn.getName());
        }
        return transmogrifier;
    }

    /**
     * Returns the location of the first item of the first input role that has at least one item.
     *
     * <p>NOTE(review): "first" follows the iteration order of the map returned by
     * {@code getInputsUnsafe()}. Whether that order is deterministic is not visible here; confirm.
     *
     * @return the main input's location, resolved against the recipe's project key, or
     *     {@code null} if no input role has any item
     */
    public AnyLoc getMainInputLoc() {
        for (SerializedRecipe.InputRole role : this.recipe.getModel().getInputsUnsafe().values()) {
            Iterator<SerializedRecipe.RecipeInput> items = role.items.iterator();
            if (items.hasNext()) {
                return items.next().getLoc(this.recipe.getProjectKey());
            }
        }
        return null;
    }

    /** No-op: this base runner has nothing to release before an abort. */
    @Override
    public void notifyBeforeAborting() {
    }

    /**
     * Creates the client used to send completion requests for this recipe.
     *
     * <p>The recipe property {@code dku.llm.recipe.backendCentralization} decides which client is
     * built. When it is absent, the instance parameter {@code dku.llm.recipes.backendCentralization}
     * is used, defaulting to {@code true}. When enabled, a
     * {@link RedirectCompletionRecipeLLMMeshClient} is returned. Otherwise a
     * {@link ParallelLLMClient} is built directly. Its guardrails merge the connection-level and
     * LLM-level settings with {@code usageTimeGuardrails}.
     *
     * <p>NOTE(review): the two property keys differ ("recipe" vs "recipes"). They are left
     * unchanged because existing configurations may rely on them, but confirm this is intended.
     *
     * @param usageTimeGuardrails guardrail settings supplied at usage time
     * @return the completion client
     * @throws Exception if the LLM reference or the guardrail settings cannot be resolved
     */
    protected CompletionRecipeLLMMeshClient buildCompletionRecipeClient(GuardrailsPipelineSettings usageTimeGuardrails) throws Exception {
        String projectKey = this.recipe.getProjectKey();
        AnyLoc usedDataset = this.getMainInputLoc();
        LLMStructuredRef llmRef = LLMStructuredRef.decodeId(this.llmId);
        Params p = AbstractSQLConnection.CustomDatabaseProperty.toParams(this.recipe.getModel().dkuProperties);
        boolean llmBackendCentralizationEnabled = p.getBoolParam("dku.llm.recipe.backendCentralization", DKUApp.getParams().getBoolParam("dku.llm.recipes.backendCentralization", true));
        if (llmBackendCentralizationEnabled) {
            return new RedirectCompletionRecipeLLMMeshClient(this.authCtx, llmRef, projectKey, usedDataset, JobContext.getCurrentJobContext(), usageTimeGuardrails);
        }
        logger.info("LLM backend centralization bypassed");
        GuardrailsPipelineSettings connectionGuardrailsPipelineSettings = GuardrailsPipelineUtils.getConnectionAndLLMLevelSettings(this.authCtx, projectKey, llmRef);
        GuardrailsPipelineSettings guardrailsPipelineSettings = GuardrailsPipelineUtils.mergeEnforcementSettings(connectionGuardrailsPipelineSettings, usageTimeGuardrails);
        // NOTE(review): the meaning of the trailing Integer.MAX_VALUE argument (presumably a
        // limit that is effectively disabled) is defined by ParallelLLMClient; confirm there.
        return new ParallelLLMClient(this.authCtx, llmRef, guardrailsPipelineSettings, projectKey, usedDataset, Integer.MAX_VALUE);
    }

    /**
     * Reports the total completion compute-resource usage accumulated by {@code plc}, if any.
     *
     * @param plc the client whose completion usage is reported; nothing is reported when its
     *     total usage is {@code null}
     */
    protected void handleCRU(CompletionRecipeLLMMeshClient plc) {
        ComputeResourceUsage totalCRU = plc.getTotalCRU(ComputeResourceUsage.LLMUsageType.COMPLETION);
        if (totalCRU != null) {
            this.cruReportingService.reportComplete(totalCRU);
        }
    }
}

