/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.huggingface;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.hf.IModelCacheDownloadService;
import com.dataiku.dip.analysis.ml.llm.LLMSavedModelInfo;
import com.dataiku.dip.code.CodeEnvModel;
import com.dataiku.dip.code.CodeEnvResolutionService;
import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.containers.exec.KubernetesExecUtils;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.exec.AbstractPythonRecipeRunner;
import com.dataiku.dip.dataflow.utils.FlowJobUtils;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClientFactory;
import com.dataiku.dip.recipes.consistency.RecipeCodes;
import com.dataiku.dip.recipes.nlp.finetuning.FineTuningRecipePayloadParams;
import com.dataiku.dip.recipes.nlp.finetuning.FineTuningRunnerInterface;
import com.dataiku.dip.remoterun.RemoteRunsRegistry;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonObject;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

public class HuggingFaceFineTuningRunner
extends AbstractPythonRecipeRunner
implements FineTuningRunnerInterface {
    @Autowired
    private CodeEnvResolutionService codeEnvResolutionService;
    private final AuthCtx authCtx;
    private final LLMStructuredRef llmRef;
    private final FineTuningRecipePayloadParams desc;
    protected String baseRecipeJobDir;
    private final File outModelFolder;
    private final String inputDatasetName;
    private final Optional<String> validationDatasetName;
    private final ContainerExecRuntimeConfig containerConfig;
    private final FullModelId fmi;
    private final Optional<String> inputHuggingfaceModelId;
    private final Optional<FullModelId> inputSavedmodelAdaptFmi;
    private final String originalLLMId;
    private final boolean originalLLMIsUnreferenced;
    private final HuggingFaceLocalConnection.HuggingFaceHandlingMode handlingMode;
    protected static final DKULogger logger = DKULogger.getLogger((String)"dku.recipes.llm.huggingface.finetune");

    public HuggingFaceFineTuningRunner(AuthCtx authCtx, LLMStructuredRef llmRef, JobActivity activity, FineTuningRecipePayloadParams desc, File outModelFolder, String inputDatasetName, Optional<String> validationDatasetName, ContainerExecRuntimeConfig containerConfig, FullModelId fmi, Optional<String> inputHuggingfaceModelId, Optional<FullModelId> inputSavedmodelAdaptFmi, String originalLLMId, boolean originalLLMIsUnreferenced, HuggingFaceLocalConnection.HuggingFaceHandlingMode handlingMode) {
        super(activity);
        this.authCtx = authCtx;
        this.llmRef = llmRef;
        this.desc = desc;
        this.outModelFolder = outModelFolder;
        this.inputDatasetName = inputDatasetName;
        this.validationDatasetName = validationDatasetName;
        this.containerConfig = containerConfig;
        this.fmi = fmi;
        this.inputHuggingfaceModelId = inputHuggingfaceModelId;
        this.inputSavedmodelAdaptFmi = inputSavedmodelAdaptFmi;
        this.originalLLMId = originalLLMId;
        this.originalLLMIsUnreferenced = originalLLMIsUnreferenced;
        this.handlingMode = handlingMode;
        this.baseRecipeJobDir = "fine-tuning-recipe";
    }

    @Override
    public void init() throws Exception {
    }

    /*
     * Unable to fully structure code
     */
    @Override
    public LLMSavedModelInfo runFineTuning() throws Exception {
        block19: {
            hfConnection = (HuggingFaceLocalConnection)this.getLLMClient().getConnection();
            envName = hfConnection.params.getCodeEnvName();
            this.codeEnvResolutionService.checkEnvExists(CodeEnvModel.EnvLang.PYTHON, envName);
            baseModelInCache = null;
            if (hfConnection.params.useDSSModelCache) {
                if (this.inputHuggingfaceModelId.isEmpty()) {
                    HuggingFaceFineTuningRunner.logger.debug((Object)"DSS model cache can't be used to fine-tune already fine-tuned models");
                } else {
                    originalLLMRef = LLMStructuredRef.decodeId(this.originalLLMId);
                    baseModelInCache = this.inputHuggingfaceModelId.get();
                    modelCacheService = (IModelCacheDownloadService)SpringUtils.getBean(IModelCacheDownloadService.class);
                    modelCacheService.downloadModelIfNeeded_Check(baseModelInCache, this.authCtx, hfConnection.name);
                    HuggingFaceFineTuningRunner.logger.info((Object)("Fine-tuning will use the base model from DSS model cache at " + modelCacheService.getLocalModelPath(baseModelInCache)));
                }
            }
            hfConnection.ensureDecrypted();
            hfToken = hfConnection.params.apiKey;
            envVariables = StringUtils.isBlank((String)hfToken) == false ? Map.of("HF_TOKEN", hfToken) : Collections.emptyMap();
            llmSmi = new LLMSavedModelInfo();
            llmSmi.llmType = LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_HUGGINGFACE_TRANSFORMER;
            llmSmi.quantizationMode = HuggingFaceLocalConnection.HuggingFaceLocalConnectionParams.QuantizationMode.valueOf(this.desc.hyperparameters.localHuggingFace.quantization.toString());
            llmSmi.handlesSystemMessage = true;
            llmSmi.originalLLMId = this.originalLLMId;
            llmSmi.originalLLMIsUnreferenced = this.originalLLMIsUnreferenced;
            llmSmi.huggingFaceHandlingMode = this.handlingMode;
            this.activity.setStatusMessage("Executing Python script");
            llmSmi.codeEnv = codeEnvRef = new CodeEnvModel.UsedCodeEnvRef(CodeEnvModel.EnvLang.PYTHON, envName);
            HuggingFaceFineTuningRunner.logger.info((Object)("Run training in code env " + StringUtils.defaultIfBlank((String)envName, (String)"built-in") + " (set at deploy-time)"));
            additionalLogsDir = FlowJobUtils.getJobMadeDir("training-recipe", "additional-logs");
            mainLogFile = FlowJobUtils.getJobTouchedFile("training-recipe", "python.log");
            payload = new JsonObject();
            payload.addProperty("inputDatasetName", this.inputDatasetName);
            if (this.validationDatasetName.isPresent()) {
                payload.addProperty("validationDatasetName", this.validationDatasetName.get());
            }
            if (this.inputHuggingfaceModelId.isPresent()) {
                payload.addProperty("inputLlmHfId", this.inputHuggingfaceModelId.get());
            }
            if (this.inputSavedmodelAdaptFmi.isPresent()) {
                payload.addProperty("inputSavedModelFolderContext", this.inputSavedmodelAdaptFmi.get().getModelFolder().getAbsolutePath());
            }
            if (baseModelInCache != null) {
                payload.addProperty("baseModelInCache", baseModelInCache);
            }
            HuggingFaceFineTuningRunner.logger.info((Object)("Run recipe in code env " + StringUtils.defaultIfBlank((String)envName, (String)"built-in")));
            outputTmpDir = FlowJobUtils.getTmpFolder("fine-tuning-recipe", "pyrun");
            try {
                JSON.prettyToFile((Object)this.desc, (File)new File((File)outputTmpDir, "desc.json"));
                writablePaths = Arrays.asList(new String[]{this.outModelFolder.getAbsolutePath(), this.outModelFolder.getParentFile().getAbsolutePath()});
                readablePaths = new ArrayList<String>(writablePaths);
                readablePaths.addAll(Arrays.asList(new String[]{outputTmpDir.getAbsolutePath()}));
                if (this.inputSavedmodelAdaptFmi.isPresent()) {
                    FilesystemACLUtils.grantFSReadACLs(this.authCtx, this.projectKey, new File[]{this.inputSavedmodelAdaptFmi.get().getFolderEnsuringSecurity()});
                    readablePaths.addAll(Arrays.asList(new String[]{this.inputSavedmodelAdaptFmi.get().getModelFolder().getAbsolutePath()}));
                }
                if (this.containerConfig == null) {
                    arguments = new String[]{this.outModelFolder.getAbsolutePath(), this.inputDatasetName, new File((File)outputTmpDir, "desc.json").getAbsolutePath(), this.validationDatasetName.isPresent() != false ? this.validationDatasetName.get() : "", this.inputHuggingfaceModelId.isPresent() != false ? this.inputHuggingfaceModelId.get() : "", this.inputSavedmodelAdaptFmi.isPresent() != false ? this.inputSavedmodelAdaptFmi.get().getModelFolder().getAbsolutePath() : "", baseModelInCache != null ? baseModelInCache : ""};
                    this.executeModule(envName, (File)outputTmpDir, "dataiku.huggingface.fine_tuning.fine_tuning_recipe", true, envVariables, arguments);
                    break block19;
                }
                switch (2.$SwitchMap$com$dataiku$dip$containers$exec$ContainerExecRuntimeConfig$Container[this.containerConfig.type.ordinal()]) {
                    case 1: {
                        this.executeDockerCodeRecipe(codeEnvRef, this.containerConfig, this.outModelFolder, mainLogFile, outputTmpDir, RemoteRunsRegistry.ExecutionType.RECIPE_FINE_TUNING_LLM, payload.toString(), envVariables, readablePaths, writablePaths);
                        ** break;
lbl59:
                        // 1 sources

                        break;
                    }
                    case 2: {
                        this.executeKubernetesCodeRecipe(codeEnvRef, this.containerConfig, this.outModelFolder, mainLogFile, additionalLogsDir, outputTmpDir, RemoteRunsRegistry.ExecutionType.RECIPE_FINE_TUNING_LLM, payload.toString(), envVariables, readablePaths, writablePaths, new KubernetesExecUtils.KubernetesFailureCodeProvider(){

                            @Override
                            public InfoMessage.MessageCode codeForOOMKilled() {
                                return RecipeCodes.ERR_RECIPE_ML_TRAINING_K8S_OOM;
                            }
                        });
                        ** break;
lbl63:
                        // 1 sources

                        break;
                    }
                    default: {
                        throw new UnsupportedOperationException("Unknown execution container: " + String.valueOf((Object)this.containerConfig.type));
                    }
                }
            }
            finally {
                if (outputTmpDir != null) {
                    outputTmpDir.close();
                }
            }
        }
        localLLMInfo = this.fmi.parseModelFile("llm_info.json", LLMSavedModelInfo.class);
        llmSmi.quantizationMode = localLLMInfo.quantizationMode;
        llmSmi.checkpointMode = localLLMInfo.checkpointMode;
        llmSmi.loraRank = localLLMInfo.loraRank;
        llmSmi.loraAlpha = localLLMInfo.loraAlpha;
        llmSmi.loraDropout = localLLMInfo.loraDropout;
        llmSmi.neftuneNoiseAlpha = localLLMInfo.neftuneNoiseAlpha;
        llmSmi.initialLearningRate = localLLMInfo.initialLearningRate;
        llmSmi.nbEpochs = localLLMInfo.nbEpochs;
        llmSmi.batchSize = localLLMInfo.batchSize;
        llmSmi.totalSteps = localLLMInfo.totalSteps;
        return llmSmi;
    }

    @Override
    public LLMClient getLLMClient() throws Exception {
        LLMClient client = LLMClientFactory.get(this.authCtx, this.projectKey, this.llmRef);
        return client;
    }

    @Override
    public void cancelFineTuning() throws IOException {
        this.notifyBeforeAborting();
    }

    @Override
    public void run() throws Exception {
    }
}

