/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.common;

import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.graph.FlowDataset;
import com.dataiku.dip.dataflow.graph.FlowRecipe;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutput;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.RowFactory;
import com.dataiku.dip.datasets.DatasetSelection;
import com.dataiku.dip.datasets.StreamableDatasetSelection;
import com.dataiku.dip.datasets.UniversalSingleThreadPusher;
import com.dataiku.dip.llm.online.CompletionRecipeLLMMeshClient;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.ParallelLLMClient;
import com.dataiku.dip.partitioning.Partition;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Feeder thread for parallel NLP/LLM recipes: streams the rows of the recipe's {@code "main"}
 * input dataset and submits one completion query per row to a
 * {@link CompletionRecipeLLMMeshClient.CompletionsStreamer}.
 *
 * <p>Subclasses decide how a row becomes a query via {@link #buildCompletionQuery(Row)}. Each
 * query is submitted together with a context map holding the row's value for every column of the
 * input schema, keyed by column name. Failures are logged and exposed through
 * {@link #getException()} rather than propagated out of {@link #run()}.
 */
public abstract class NLPRecipeParallelRunInputFeedThread
extends Thread {
    /** Columns of the selected input dataset's schema, keyed by column name. */
    private final Map<String, Column> inputColumns = new HashMap<String, Column>();
    /** The {@code "main"} input dataset being read. */
    private Dataset dataset;
    /** Selection applied to {@link #dataset}: full, or restricted to the job's source partitions. */
    private StreamableDatasetSelection ds;
    protected final ColumnFactory cf;
    protected final RowFactory rf;
    private final AuthCtx authCtx;
    private final CompletionRecipeLLMMeshClient.CompletionsStreamer plcStream;
    @Autowired
    private DatasetsDAO datasetsDAO;
    /**
     * Failure recorded by {@link #run()}. Volatile because it is written by this thread and read
     * through the public {@link #getException()} (presumably by the thread that started this one).
     */
    private volatile Throwable exception;
    private int submittedRows = 0;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.recipes.nlp.inputfeeder");

    /**
     * Resolves the recipe's first {@code "main"} input that has a schema and prepares the selection
     * used to read it.
     *
     * @param authCtx security context passed to the dataset pusher when reading the input
     * @param recipe recipe whose {@code "main"} inputs are inspected
     * @param activity job activity providing the source partitions for partitioned inputs
     * @param plcStream stream receiving one completion query per input row
     * @param cf column factory used to resolve the input columns
     * @param rf row factory passed to the dataset pusher
     * @throws IOException propagated from loading the input dataset definitions
     * @throws IllegalStateException if no {@code "main"} input has a schema
     */
    public NLPRecipeParallelRunInputFeedThread(AuthCtx authCtx, FlowRecipe recipe, JobActivity activity, CompletionRecipeLLMMeshClient.CompletionsStreamer plcStream, ColumnFactory cf, RowFactory rf) throws IOException {
        this.authCtx = authCtx;
        this.cf = cf;
        this.rf = rf;
        this.plcStream = plcStream;
        // Injects @Autowired fields (datasetsDAO), which are used just below.
        SpringUtils.getInstance().autowire((Object)this);
        for (SerializedRecipe.RecipeInput recipeInput : recipe.getModel().getInputsForRole("main")) {
            Dataset candidate = Dataset.fromSerialized((SerializedDataset)this.datasetsDAO.getMandatory(recipeInput.getLoc(recipe.getProjectKey())));
            if (candidate.getSchema() == null) {
                continue;
            }
            // Stop at the first input with a schema, partitioned or not, so that the dataset and
            // its selection always describe the same input.
            this.dataset = candidate;
            this.ds = NLPRecipeParallelRunInputFeedThread.buildSelection(candidate, activity);
            break;
        }
        if (this.dataset == null) {
            throw new IllegalStateException("No 'main' input dataset with a schema found (project " + recipe.getProjectKey() + ")");
        }
        Schema schema = new Schema(this.dataset.getSchema());
        for (SchemaColumn sc : schema.getColumns()) {
            this.inputColumns.put(sc.getName(), cf.column(sc.getName()));
        }
    }

    /**
     * Builds the selection used to read {@code dataset}: the full dataset, restricted to the job's
     * source partitions when the dataset is partitioned.
     */
    private static StreamableDatasetSelection buildSelection(Dataset dataset, JobActivity activity) throws IOException {
        StreamableDatasetSelection selection = StreamableDatasetSelection.full();
        if (dataset.getPartitioningSchema().isPartitioned()) {
            FlowDataset inputFD = activity.getSubgraph().getSourceDataset(dataset.getFullName());
            selection.partitionSelectionMethod = DatasetSelection.PartitionSelectionMethod.SELECTED;
            selection.selectedPartitions = Lists.newArrayList();
            for (Partition p : activity.getSubgraph().getSourcePartitions(inputFD)) {
                selection.selectedPartitions.add(p.id());
            }
        }
        return selection;
    }

    /**
     * Builds the completion query to submit for one input row.
     *
     * @param var1 a row read from the selected {@code "main"} input dataset
     * @return the query to submit to the completions stream for that row
     */
    public abstract LLMClient.SingleCompletionQuery buildCompletionQuery(Row var1);

    /**
     * Pushes every selected row of the input dataset through {@link #buildCompletionQuery(Row)}
     * into the completions stream, then signals end-of-input.
     *
     * <p>End-of-input ({@code plcStream.done()}) is signalled exactly once, on both the success
     * and failure paths. Failures are logged and recorded for {@link #getException()}. If finishing
     * the stream fails after an earlier failure, that second failure is attached to the first as a
     * suppressed exception.
     */
    @Override
    public void run() {
        logger.info((Object)"Start input feed thread");
        try {
            UniversalSingleThreadPusher.push(this.authCtx, this.dataset, this.ds, new ProcessorOutput(){

                public void emitRow(Row row) throws Exception {
                    LLMClient.SingleCompletionQuery query = NLPRecipeParallelRunInputFeedThread.this.buildCompletionQuery(row);
                    ParallelLLMClient.SingleCompletionQueryWithTrace queryWithTrace = new ParallelLLMClient.SingleCompletionQueryWithTrace(query);
                    // The row's input column values travel with the query, keyed by column name.
                    HashMap<String, String> rowContext = new HashMap<String, String>();
                    for (Column inputColumn : NLPRecipeParallelRunInputFeedThread.this.inputColumns.values()) {
                        rowContext.put(inputColumn.getName(), row.get(inputColumn));
                    }
                    NLPRecipeParallelRunInputFeedThread.this.plcStream.submit(queryWithTrace, rowContext);
                    if (++NLPRecipeParallelRunInputFeedThread.this.submittedRows % 100 == 0) {
                        logger.info((Object)("Input Feed Thread sent " + NLPRecipeParallelRunInputFeedThread.this.submittedRows + " records to PLCS"));
                    }
                }

                public void lastRowEmitted() {
                    // Intentionally empty: run() signals end-of-input once push() returns.
                }

                public void cancel() {
                }

                public void setMaxMemoryUsed(long size) {
                }
            }, this.cf, this.rf);
            logger.info((Object)"Done USTP");
        }
        catch (Throwable e) {
            logger.error((Object)"Input feeding failed", e);
            this.exception = e;
        }
        this.finishStream();
    }

    /**
     * Signals end-of-input to the completions stream and joins the handle returned by
     * {@code done()}, recording any failure instead of letting it escape {@link #run()}.
     */
    private void finishStream() {
        try {
            this.plcStream.done().join();
        }
        catch (Throwable e) {
            logger.error((Object)"Failed to finish the completions stream", e);
            Throwable earlier = this.exception;
            if (earlier == null) {
                this.exception = e;
            } else if (earlier != e) {
                // Keep the original failure as primary; self-suppression is illegal, hence the guard.
                earlier.addSuppressed(e);
            }
        }
    }

    /**
     * Returns the failure recorded while feeding the input or finishing the stream, or
     * {@code null} if none has been recorded. Join this thread first to read its final value.
     *
     * @return the recorded failure, or {@code null}
     */
    public Throwable getException() {
        return this.exception;
    }
}

