/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.openai;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.llm.LLMSavedModelInfo;
import com.dataiku.dip.analysis.ml.llm.LLMStepwiseTrainingMetrics;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.online.DatasetRowJsonifier;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.RemoteFineTuningClient;
import com.dataiku.dip.llm.online.openai.RawOpenAIClient;
import com.dataiku.dip.recipes.nlp.finetuning.FineTuningRecipePayloadParams;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Optional;

public abstract class GenericOpenAIFineTuningClient
implements RemoteFineTuningClient {
    LLMClient llmClient;
    private static DKULogger logger = DKULogger.getLogger((String)"dku.llm.openai");

    public GenericOpenAIFineTuningClient(LLMClient llmClient) {
        this.llmClient = llmClient;
    }

    protected abstract RawOpenAIClient getRawClient();

    @Override
    public String uploadFile(File f, String runId, boolean isTrainingFile) throws Exception {
        String fileId = this.getRawClient().uploadFile(f, isTrainingFile ? "training.jsonl" : "validation.jsonl", "fine-tune");
        while (Arrays.asList("uploaded", "pending").contains(this.getFileStatus(fileId))) {
            logger.info((Object)("File " + fileId + " is still being processed by OpenAI, retrying in 5 seconds"));
            Thread.sleep(5000L);
        }
        return fileId;
    }

    @Override
    public void deleteFile(String fileId) throws Exception {
        this.getRawClient().deleteFile(fileId);
    }

    private String getFileStatus(String fileId) throws IOException {
        return this.getRawClient().getFileStatus(fileId);
    }

    @Override
    public String startFineTuning(String baseModel, String trainingFileURI, Optional<String> validationFileURI, String runId, FineTuningRecipePayloadParams.FineTuningHyperparameters hyperparameters, boolean useDefaults) throws Exception {
        return this.getRawClient().fineTuneStart(trainingFileURI, validationFileURI, baseModel, hyperparameters, useDefaults);
    }

    @Override
    public void waitForJobCompletion(String jobId) throws Exception {
        RawOpenAIClient rawClient = this.getRawClient();
        while (true) {
            Thread.sleep(30000L);
            try {
                JsonObject ftGetResp = rawClient.fineTuneGet(jobId);
                String ftJobStatus = ftGetResp.get("status").getAsString();
                logger.info((Object)("fine tune job status: " + ftJobStatus));
                if (!"succeeded".equals(ftJobStatus)) {
                    if ("failed".equals(ftJobStatus)) {
                        throw new RemoteFineTuningClient.FinetuningFailedException("fine-tuning failed: " + ftGetResp.get("error").getAsJsonObject().get("message").getAsString());
                    }
                    if (!"cancelled".equals(ftJobStatus)) continue;
                    throw new RemoteFineTuningClient.FinetuningFailedException("fine-tuning cancelled.");
                }
                logger.info((Object)"Done !");
            }
            catch (ExternalJSONAPIClient.JSONAPIClientException e) {
                throw e;
            }
            catch (IOException e) {
                if (!e.getMessage().contains("Connection reset")) continue;
                logger.warn((Object)"Connection reset, retrying");
                continue;
            }
            break;
        }
    }

    @Override
    public JsonObject getJob(String jobId) throws Exception {
        return this.getRawClient().fineTuneGet(jobId);
    }

    @Override
    public void cancelJob(String jobId) throws Exception {
        if (jobId != null) {
            this.getRawClient().fineTuneCancel(jobId);
        }
    }

    @Override
    public void deleteFinetunedModel(String modelId, String jobId) throws Exception {
        if (jobId != null) {
            this.getRawClient().deleteFinetunedModel(modelId, jobId);
        }
    }

    protected void enrichSavedModel(String jobId, EnrichedLLMStructuredRef enrichedLLMRef, JsonObject fineTuningResponse, LLMSavedModelInfo llmSmi, FullModelId fmi) throws Exception {
        String finalModelId = fineTuningResponse.get("fine_tuned_model").getAsString();
        try {
            llmSmi.remoteModelId = this.getRawClient().getBestCheckpointModel(jobId, finalModelId);
            if (llmSmi.remoteModelId.equals(finalModelId)) {
                logger.info((Object)"There is no better checkpoint version, defaulting to the final model");
            } else {
                logger.info((Object)("Chose checkpoint model " + llmSmi.remoteModelId + " instead of final model " + finalModelId + " because it has a better train loss"));
            }
        }
        catch (Exception e) {
            llmSmi.remoteModelId = finalModelId;
            logger.warn((Object)"Failed to get a better checkpoint version, defaulting to the final model");
        }
        JsonObject hyperparameters = fineTuningResponse.get("hyperparameters").getAsJsonObject();
        LLMStepwiseTrainingMetrics llmStepwiseTrainingMetrics = new LLMStepwiseTrainingMetrics();
        llmSmi.nbEpochs = hyperparameters.get("n_epochs").getAsInt();
        llmSmi.batchSize = hyperparameters.get("batch_size").getAsInt();
        llmSmi.learningRateMultiplier = Float.valueOf(hyperparameters.get("learning_rate_multiplier").getAsFloat());
        this.getRawClient().fillLLMStepwiseTrainingMetrics(jobId, llmStepwiseTrainingMetrics, llmSmi);
        JSON.prettyToFile((Object)llmStepwiseTrainingMetrics, (File)fmi.getLLMStepwiseTrainingMetricsFile());
    }

    private JsonObject createMessage(String role, String content) {
        JsonObject message = new JsonObject();
        message.addProperty("role", role);
        message.addProperty("content", content);
        return message;
    }

    @Override
    public DatasetRowJsonifier getJsonifyFunction() {
        if (!this.isChatModel()) {
            return this.getJsonifyForPromptModelFunction();
        }
        return (prompt, completion, systemMessage) -> {
            JsonObject res = new JsonObject();
            JsonArray messagesArray = new JsonArray();
            res.add("messages", (JsonElement)messagesArray);
            if (systemMessage.isPresent()) {
                messagesArray.add((JsonElement)this.createMessage("system", (String)systemMessage.get()));
            }
            messagesArray.add((JsonElement)this.createMessage("user", prompt));
            messagesArray.add((JsonElement)this.createMessage("assistant", completion));
            return res;
        };
    }
}

