/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.llm.LLMSavedModelInfo;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.ProcessorOutput;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.RowFactory;
import com.dataiku.dip.datalayer.streamimpl.StreamColumnFactory;
import com.dataiku.dip.datalayer.streamimpl.StreamRowFactory;
import com.dataiku.dip.datasets.UniversalSingleThreadPusher;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.DatasetRowJsonifier;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.llm.online.RemoteFineTuningClient;
import com.dataiku.dip.recipes.nlp.finetuning.FineTuningRecipePayloadParams;
import com.dataiku.dip.recipes.nlp.finetuning.FineTuningRunnerInterface;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.Pair;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.File;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.Optional;

/**
 * Runs an LLM fine-tuning on a remote provider.
 *
 * <p>The runner serializes the training dataset (and, when provided, the validation dataset) to
 * JSONL files under {@code tmpDir}, uploads them through a {@link RemoteFineTuningClient}, starts a
 * remote fine-tuning job, blocks until that job completes, and returns the metadata of the
 * resulting saved model. Local JSONL files are always deleted after their upload attempt; uploaded
 * remote files are deleted once they can no longer be needed by a remote job.
 */
public class RemoteFineTuningRunner
implements FineTuningRunnerInterface {
    private final AuthCtx authCtx;
    private final Dataset trainingDataset;
    private final Optional<Dataset> validationDataset;
    private final FineTuningRecipePayloadParams desc;
    private final LLMStructuredRef llmRef;
    private final ModelTrainInfo mti;
    private final FullModelId outputFMI;
    private final SavedModel savedModel;
    private final File tmpDir;
    // The three fields below are read by cancelFineTuning(). That method can only see a running job
    // while runFineTuning() is blocked in waitForJobCompletion(), so it is presumably invoked from
    // another thread; volatile makes the writes done in runFineTuning() visible to that caller.
    private volatile String remoteJobId;
    private volatile boolean remoteJobRunning = false;
    private volatile RemoteFineTuningClient ftClient;
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.finetuning");

    public RemoteFineTuningRunner(AuthCtx authCtx, Dataset trainingDataset, Optional<Dataset> validationDataset, FineTuningRecipePayloadParams desc, LLMStructuredRef llmRef, ModelTrainInfo mti, FullModelId outputFMI, SavedModel savedModel, File tmpDir) {
        SpringUtils.getInstance().autowire((Object)this);
        this.authCtx = authCtx;
        this.trainingDataset = trainingDataset;
        this.validationDataset = validationDataset;
        this.desc = desc;
        this.llmRef = llmRef;
        this.mti = mti;
        this.outputFMI = outputFMI;
        this.savedModel = savedModel;
        this.tmpDir = tmpDir;
    }

    /**
     * Builds and uploads the fine-tuning files, runs the remote job to completion and returns the
     * resulting saved model info.
     *
     * <p>Side effects: sets {@code mti.trainRows} (and {@code mti.testRows} when a validation
     * dataset is present). If anything fails while no remote job can be using the uploaded files
     * (before the job was submitted, or after it completed), the uploaded files are deleted on a
     * best-effort basis before the original exception is rethrown.
     *
     * @return the metadata of the fine-tuned model
     * @throws Exception any failure from dataset reading, validation, upload or the remote job
     */
    @Override
    public LLMSavedModelInfo runFineTuning() throws Exception {
        String runId = new SimpleDateFormat("yyyy-MM-dd-HH-mm-ss").format(new Date());
        LLMClient client = this.getLLMClient();
        EnrichedLLMStructuredRef enrichedLLMRef = client.getEnrichedRef();
        this.ftClient = client.newFineTuningClient();
        logger.info("Starting fine-tuning");
        logger.info("runId: " + runId);
        logger.info("llmId: " + this.desc.llmId);
        logger.info("llmRef: " + JSON.json(enrichedLLMRef));

        String remoteTrainingFileURI = null;
        Optional<String> remoteValidationFileURI = Optional.empty();
        // True only between a successful job submission and the job's completion: during that window
        // the remote job may still read the uploaded files, so they must not be deleted on failure.
        // If submission itself fails, DSS holds no job id to track or cancel, so deleting is safe.
        boolean remoteJobMayUseFiles = false;
        LLMSavedModelInfo llmSmi;
        try {
            remoteTrainingFileURI = this.buildValidateAndUploadTrainingFile(runId);
            if (this.validationDataset.isPresent()) {
                remoteValidationFileURI = Optional.of(this.buildAndUploadValidationFile(runId));
            }
            String baseModel = this.ftClient.getBaseModel();
            logger.info("Starting a fine-tuning job based on " + baseModel);
            this.remoteJobRunning = true;
            this.remoteJobId = this.ftClient.startFineTuning(baseModel, remoteTrainingFileURI, remoteValidationFileURI, runId, this.desc.hyperparameters, this.desc.hyperparameters.useDefaults);
            remoteJobMayUseFiles = true;
            logger.info("Remote fine-tuning job started: " + this.remoteJobId + ", waiting for completion");
            this.ftClient.waitForJobCompletion(this.remoteJobId);
            logger.info("Fine-tuning job completed");
            this.remoteJobRunning = false;
            remoteJobMayUseFiles = false;
            llmSmi = this.buildSavedModelInfo(enrichedLLMRef);
        } catch (Exception e) {
            // remoteJobRunning is deliberately left untouched here so that a later cancelFineTuning()
            // can still stop a remote job that may be running after waitForJobCompletion() failed.
            if (!remoteJobMayUseFiles) {
                this.deleteRemoteFilesAfterFailure(remoteTrainingFileURI, remoteValidationFileURI);
            }
            throw e;
        }
        this.deleteRemoteFiles(remoteTrainingFileURI, remoteValidationFileURI);
        return llmSmi;
    }

    /**
     * Builds the training JSONL, validates it with the fine-tuning client and uploads it.
     * The local file is deleted whether or not validation/upload succeed.
     *
     * @return the remote URI of the uploaded training file
     */
    private String buildValidateAndUploadTrainingFile(String runId) throws Exception {
        Pair<File, Long> trainingFile = this.getTrainingJSONL(runId, this.ftClient.getJsonifyFunction());
        this.mti.trainRows = trainingFile.second;
        logger.info("Built fine-tuning training file: " + trainingFile.first);
        try {
            logger.info("Validating fine-tuning training file");
            this.ftClient.assertTrainingFileIsValid(trainingFile.first);
            logger.info("Training file is valid");
            logger.info("Uploading fine-tuning training file");
            String uri = this.ftClient.uploadFile(trainingFile.first, runId, true);
            logger.info("File uploaded: " + uri);
            return uri;
        } finally {
            DKUFileUtils.delete(trainingFile.first);
        }
    }

    /**
     * Builds the validation JSONL and uploads it. The local file is deleted whether or not the
     * upload succeeds. Must only be called when a validation dataset is present.
     *
     * @return the remote URI of the uploaded validation file
     */
    private String buildAndUploadValidationFile(String runId) throws Exception {
        Pair<File, Long> validationFile = this.getValidationJSONL(runId, this.ftClient.getJsonifyFunction());
        this.mti.testRows = validationFile.second;
        logger.info("Built fine-tuning validation file: " + validationFile.first);
        try {
            logger.info("Uploading fine-tuning validation file");
            String uri = this.ftClient.uploadFile(validationFile.first, runId, false);
            logger.info("File uploaded: " + uri);
            return uri;
        } finally {
            DKUFileUtils.delete(validationFile.first);
        }
    }

    /** Assembles the saved model info once the remote job has completed. */
    private LLMSavedModelInfo buildSavedModelInfo(EnrichedLLMStructuredRef enrichedLLMRef) throws Exception {
        logger.info("Storing new saved model info");
        LLMSavedModelInfo llmSmi = new LLMSavedModelInfo();
        llmSmi.handlesSystemMessage = this.ftClient.isChatModel();
        llmSmi.inputLLMName = enrichedLLMRef.friendlyNameShort;
        llmSmi.promptColumn = this.desc.promptColumn;
        llmSmi.completionColumn = this.desc.completionColumn;
        llmSmi.embeddingSize = enrichedLLMRef.embeddingSize;
        llmSmi.maxTokensLimit = enrichedLLMRef.maxTokensLimit;
        llmSmi.validationDataset = this.validationDataset.map(Dataset::getName).orElse(null);
        // Fall back to the LLM id the recipe was configured with when the ref has no base model id.
        llmSmi.originalLLMId = enrichedLLMRef.baseModelId == null ? this.desc.llmId : enrichedLLMRef.baseModelId;
        llmSmi.originalLLMIsUnreferenced = enrichedLLMRef.originalLLMIsUnreferenced;
        logger.info("Adding post fine-tuning metadata");
        llmSmi.remoteJobId = this.remoteJobId;
        JsonObject remoteJob = this.ftClient.getJob(llmSmi.remoteJobId);
        this.ftClient.enrichSavedModelPostFinetuning(llmSmi.remoteJobId, enrichedLLMRef, remoteJob, llmSmi, this.outputFMI);
        return llmSmi;
    }

    /** Deletes the uploaded remote files on the success path; failures propagate to the caller. */
    private void deleteRemoteFiles(String remoteTrainingFileURI, Optional<String> remoteValidationFileURI) throws Exception {
        logger.info("Deleting remote fine-tuning training file: " + remoteTrainingFileURI);
        this.ftClient.deleteFile(remoteTrainingFileURI);
        if (remoteValidationFileURI.isPresent()) {
            logger.info("Deleting remote fine-tuning validation file: " + remoteValidationFileURI.get());
            this.ftClient.deleteFile(remoteValidationFileURI.get());
        }
    }

    /**
     * Best-effort deletion of whatever was uploaded before a failure. Never throws, so the original
     * failure is not masked.
     */
    private void deleteRemoteFilesAfterFailure(String remoteTrainingFileURI, Optional<String> remoteValidationFileURI) {
        if (remoteTrainingFileURI != null) {
            this.deleteRemoteFileQuietly(remoteTrainingFileURI);
        }
        remoteValidationFileURI.ifPresent(this::deleteRemoteFileQuietly);
    }

    /** Deletes one remote file, logging instead of propagating any error. */
    private void deleteRemoteFileQuietly(String uri) {
        try {
            logger.info("Deleting remote fine-tuning file after failure: " + uri);
            this.ftClient.deleteFile(uri);
        } catch (Exception e) {
            logger.warn("Failed to delete remote fine-tuning file " + uri, e);
        }
    }

    /** Creates an LLM client for {@code llmRef} in the saved model's project. */
    @Override
    public LLMClient getLLMClient() throws Exception {
        return LLMClientFactory.get(this.authCtx, this.savedModel.projectKey, this.llmRef);
    }

    /**
     * Requests cancellation of the remote job, if one has been started and is still marked as
     * running. Does nothing otherwise.
     */
    @Override
    public void cancelFineTuning() throws Exception {
        // Snapshot the volatile fields so the checks and the call below see consistent values.
        String jobId = this.remoteJobId;
        RemoteFineTuningClient client = this.ftClient;
        if (jobId != null && client != null && this.remoteJobRunning) {
            logger.info("Canceling fine-tuning job: " + jobId);
            client.cancelJob(jobId);
        }
    }

    /**
     * Builds a chat-format training example:
     * {@code {"messages": [{system}?, {user: prompt}, {assistant: completion}]}}; the system entry is
     * only added when {@code systemMessage} is non-null.
     *
     * <p>Not referenced within this class; rows are serialized with the function returned by
     * {@code RemoteFineTuningClient.getJsonifyFunction()}.
     */
    private static JsonObject jsonifyForChatModel(String prompt, String completion, String systemMessage) {
        JsonObject res = new JsonObject();
        JsonArray messagesArray = new JsonArray();
        res.add("messages", messagesArray);
        if (systemMessage != null) {
            JsonObject systemMessageObject = new JsonObject();
            systemMessageObject.addProperty("role", "system");
            systemMessageObject.addProperty("content", systemMessage);
            messagesArray.add(systemMessageObject);
        }
        JsonObject userMessage = new JsonObject();
        userMessage.addProperty("role", "user");
        userMessage.addProperty("content", prompt);
        messagesArray.add(userMessage);
        JsonObject assistantMessage = new JsonObject();
        assistantMessage.addProperty("role", "assistant");
        assistantMessage.addProperty("content", completion);
        messagesArray.add(assistantMessage);
        return res;
    }

    /**
     * Builds a prompt/completion training example: {@code {"prompt": ..., "completion": ...}}.
     *
     * <p>Not referenced within this class; rows are serialized with the function returned by
     * {@code RemoteFineTuningClient.getJsonifyFunction()}.
     */
    private static JsonObject jsonifyForPromptModel(String prompt, String completion) {
        JsonObject res = new JsonObject();
        res.addProperty("prompt", prompt);
        res.addProperty("completion", completion);
        return res;
    }

    /**
     * Serializes every row of {@code dataset} into JSONL text, one JSON object per line.
     *
     * <p>The system message passed to {@code processFunc} comes from the configured column in
     * DYNAMIC mode, from the recipe setting in STATIC mode, and is empty otherwise.
     *
     * @param dataset dataset to read
     * @param processFunc converts (prompt, completion, optional system message) into one JSON object
     * @return the JSONL text and the number of rows processed by the pusher
     * @throws IllegalArgumentException if a row has a null prompt or completion
     */
    private Pair<String, Long> getJSONLFromDataset(Dataset dataset, final DatasetRowJsonifier processFunc) throws Exception {
        final StringBuilder sb = new StringBuilder();
        StreamColumnFactory scf = new StreamColumnFactory();
        StreamRowFactory srf = new StreamRowFactory();
        final Column promptCD = scf.column(this.desc.promptColumn);
        final Column completionCD = scf.column(this.desc.completionColumn);
        // Only resolve a system-message column when each row carries its own system message.
        final Column systemMessageCD = this.desc.systemMessageMode == FineTuningRecipePayloadParams.SystemMessageMode.DYNAMIC ? scf.column(this.desc.systemMessageColumn) : null;
        UniversalSingleThreadPusher threadPush = new UniversalSingleThreadPusher(this.authCtx, dataset, new ProcessorOutput(){

            public void setMaxMemoryUsed(long size) {
            }

            public void lastRowEmitted() throws Exception {
            }

            public void emitRow(Row row) throws Exception {
                String prompt = row.get(promptCD);
                String completion = row.get(completionCD);
                if (prompt == null) {
                    throw new IllegalArgumentException("Prompt column `" + promptCD.getName() + "` is empty in dataset row: " + String.valueOf(row));
                }
                if (completion == null) {
                    throw new IllegalArgumentException("Completion column `" + completionCD.getName() + "` is empty in dataset row: " + String.valueOf(row));
                }
                String systemMessageforRow = null;
                if (RemoteFineTuningRunner.this.desc.systemMessageMode == FineTuningRecipePayloadParams.SystemMessageMode.DYNAMIC) {
                    systemMessageforRow = row.get(systemMessageCD);
                } else if (RemoteFineTuningRunner.this.desc.systemMessageMode == FineTuningRecipePayloadParams.SystemMessageMode.STATIC) {
                    systemMessageforRow = RemoteFineTuningRunner.this.desc.systemMessage;
                }
                JsonObject jo = processFunc.apply(prompt, completion, Optional.ofNullable(systemMessageforRow));
                sb.append(JSON.json(jo));
                sb.append("\n");
            }

            public void cancel() throws Exception {
            }
        }, scf, srf, true);
        threadPush.push();
        return new Pair<>(sb.toString(), threadPush.getProcessedRowCount());
    }

    /**
     * Writes the JSONL serialization of {@code dataset} to {@code tmpDir/fileName} (UTF-8).
     *
     * @return the written file and the number of rows it contains
     */
    private Pair<File, Long> writeJSONL(Dataset dataset, String fileName, DatasetRowJsonifier rowFunc) throws Exception {
        File f = new File(this.tmpDir, fileName);
        Pair<String, Long> res = this.getJSONLFromDataset(dataset, rowFunc);
        DKUFileUtils.writeFileUTF8(f, res.first);
        return new Pair<>(f, res.second);
    }

    /** Writes {@code training-<runId>.jsonl} from the training dataset. */
    private Pair<File, Long> getTrainingJSONL(String runId, DatasetRowJsonifier rowFunc) throws Exception {
        return this.writeJSONL(this.trainingDataset, "training-" + runId + ".jsonl", rowFunc);
    }

    /**
     * Writes {@code validation-<runId>.jsonl} from the validation dataset.
     *
     * @throws IllegalArgumentException if no validation dataset was provided
     */
    private Pair<File, Long> getValidationJSONL(String runId, DatasetRowJsonifier rowFunc) throws Exception {
        if (this.validationDataset.isEmpty()) {
            throw new IllegalArgumentException("No validation file was provided");
        }
        return this.writeJSONL(this.validationDataset.get(), "validation-" + runId + ".jsonl", rowFunc);
    }
}

