/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.nlp.finetuning;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLFlowUtils;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.llm.LLMSMMgmtService;
import com.dataiku.dip.analysis.ml.llm.LLMSavedModelInfo;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.analysis.model.llm.LLMModelSnippetData;
import com.dataiku.dip.code.CodeEnvModel;
import com.dataiku.dip.code.CodeEnvResolutionService;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.HuggingFaceLocalConnection;
import com.dataiku.dip.containers.exec.ContainerExecConfigSelector;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.coremodel.SerializedRecipe;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.dao.SavedModelsDAO;
import com.dataiku.dip.dataflow.JobActivity;
import com.dataiku.dip.dataflow.JobAuthCtxService;
import com.dataiku.dip.dataflow.RecipeRunnableSubgraph;
import com.dataiku.dip.dataflow.exec.AbortableRecipeRunner;
import com.dataiku.dip.dataflow.exec.FinalCommitable;
import com.dataiku.dip.dataflow.exec.RecipeRunnerWithPayload;
import com.dataiku.dip.dataflow.graph.FlowRecipe;
import com.dataiku.dip.dataflow.graph.FlowSavedModel;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.dataflow.utils.FlowJobUtils;
import com.dataiku.dip.license.LicenseRestrictionException;
import com.dataiku.dip.license.LicenseStatusService;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.RemoteFineTuningRunner;
import com.dataiku.dip.llm.online.huggingface.HuggingFaceFineTuningRunner;
import com.dataiku.dip.llm.savedmodels.SavedModelVersionDeploymentCRUDService;
import com.dataiku.dip.recipes.RecipeRunner;
import com.dataiku.dip.recipes.nlp.finetuning.FineTuningRecipePayloadParams;
import com.dataiku.dip.recipes.nlp.finetuning.FineTuningRunnerInterface;
import com.dataiku.dip.rpc.TicketBasedIntercomAPIClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.security.tickets.APITicketService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.server.services.licensing.AbstractLicenseFeaturesStatusBuilder;
import com.dataiku.dip.server.services.licensing.LicenseFeaturesStatusBuilder;
import com.dataiku.dip.transactions.ifaces.IsolationLevel;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.util.AutoDelete;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.common.collect.Sets;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

public class FineTuningRecipeRunner
implements RecipeRunner,
AbortableRecipeRunner,
RecipeRunnerWithPayload,
FinalCommitable {
    @Autowired
    private JobAuthCtxService authCtxService;
    @Autowired
    private SavedModelsDAO savedModelsDAO;
    @Autowired
    private DatasetsDAO datasetsDAO;
    @Autowired
    private APITicketService ticketService;
    @Autowired
    private LicenseStatusService licenseStatusService;
    @Autowired
    private CodeEnvResolutionService codeEnvResolutionService;
    @Autowired
    private SavedModelVersionDeploymentCRUDService smvDeploymentCRUDService;
    private FineTuningRecipePayloadParams desc;
    private final JobActivity activity;
    private final FlowRecipe recipe;
    private Dataset trainingDataset;
    private Optional<Dataset> validationDataset;
    private FineTuningRunnerInterface runner;
    private SavedModel sm;
    private LLMStructuredRef llmRef;
    private FullModelId outputFMI;
    private ModelTrainInfo mti;
    private String remoteDeploymentID;
    private AuthCtx authCtx;
    private Set<LLMStructuredRef.LLMType> FINETUNABLE_LLM_TYPES = Sets.newHashSet((Object[])new LLMStructuredRef.LLMType[]{LLMStructuredRef.LLMType.OPENAI, LLMStructuredRef.LLMType.AZURE_OPENAI_MODEL, LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_OPENAI, LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_AZURE_OPENAI, LLMStructuredRef.LLMType.HUGGINGFACE_TRANSFORMER_LOCAL, LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_HUGGINGFACE_TRANSFORMER, LLMStructuredRef.LLMType.BEDROCK, LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_BEDROCK});
    private Set<LLMStructuredRef.LLMType> DEPLOYABLE_LLM_TYPES = Sets.newHashSet((Object[])new LLMStructuredRef.LLMType[]{LLMStructuredRef.LLMType.AZURE_OPENAI_MODEL, LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_AZURE_OPENAI, LLMStructuredRef.LLMType.BEDROCK, LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_BEDROCK});
    static DKULogger logger = DKULogger.getLogger((String)"dku.recipes.nlp.finetuning");

    protected FineTuningRecipeRunner(JobActivity activity) {
        this.activity = activity;
        this.recipe = ((RecipeRunnableSubgraph)activity.getSubgraph()).getRecipe();
    }

    @Override
    public void init() throws Exception {
        this.authCtx = this.authCtxService.getAuthCtx();
    }

    @Override
    public void setPayload(String payload) {
        this.desc = (FineTuningRecipePayloadParams)JSON.parse((String)payload, FineTuningRecipePayloadParams.class);
        this.llmRef = LLMStructuredRef.decodeId(this.desc.llmId);
    }

    private FineTuningRunnerInterface getAbstractFineTuningRunner() throws Exception {
        if (Arrays.asList(LLMStructuredRef.LLMType.OPENAI, LLMStructuredRef.LLMType.AZURE_OPENAI_MODEL, LLMStructuredRef.LLMType.BEDROCK, LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_BEDROCK, LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_OPENAI, LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_AZURE_OPENAI).contains((Object)this.llmRef.type)) {
            try (AutoDelete tmpDir = FlowJobUtils.getTmpFolder("finetuning", "tmp");){
                RemoteFineTuningRunner remoteFineTuningRunner = new RemoteFineTuningRunner(this.authCtx, this.trainingDataset, this.validationDataset, this.desc, this.llmRef, this.mti, this.outputFMI, this.sm, (File)tmpDir);
                return remoteFineTuningRunner;
            }
        }
        if (Arrays.asList(LLMStructuredRef.LLMType.HUGGINGFACE_TRANSFORMER_LOCAL, LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_HUGGINGFACE_TRANSFORMER).contains((Object)this.llmRef.type)) {
            ContainerExecRuntimeConfig containerConfig = new ContainerExecConfigSelector().select_autoTXN(this.authCtxService.getAuthCtx(), this.recipe.getProjectKey(), this.desc.containerSelection);
            HuggingFaceLocalConnection hfConnection = (HuggingFaceLocalConnection)ConnectionsDAO.get().getMandatoryConnection(this.authCtx, this.llmRef.connection);
            this.codeEnvResolutionService.checkEnvExists(CodeEnvModel.EnvLang.PYTHON, hfConnection.params.getCodeEnvName());
            hfConnection.ensureDecrypted();
            LLMStructuredRef inputLLMRef = LLMStructuredRef.decodeId(this.desc.llmId);
            HuggingFaceLocalConnection.HuggingFaceHandlingMode handlingMode = null;
            TransactionService ts = (TransactionService)SpringUtils.getBean(TransactionService.class);
            SavedModel inputSm = null;
            AnyLoc loc = null;
            if (this.llmRef.type.equals((Object)LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_HUGGINGFACE_TRANSFORMER)) {
                try (Transaction t = ts.retrieveOrBeginRead(IsolationLevel.YOLO);){
                    loc = AnyLoc.resolveSmart(this.sm.projectKey, inputLLMRef.savedModelSmartId);
                    inputSm = (SavedModel)((SavedModelsDAO)SpringUtils.getBean(SavedModelsDAO.class)).getMandatory(loc);
                }
                LLMSMMgmtService.LLMSMVersionHeader vh = LLMSMMgmtService.getStatus_NT((SavedModel)inputSm).versions.stream().filter(v -> v.versionId.equals(inputLLMRef.savedModelVersionId)).findFirst().orElseThrow(() -> new IllegalArgumentException("SMV not found"));
                HuggingFaceLocalConnection.HFLocalModel llmModelFromSMInfo = hfConnection.getLLMModelFromSMInfo(((LLMModelSnippetData)vh.snippet).llmSMInfo);
                handlingMode = llmModelFromSMInfo.handlingMode;
            }
            HuggingFaceLocalConnection.HFLocalModel hfLocalModel = null;
            if (this.llmRef.type.equals((Object)LLMStructuredRef.LLMType.HUGGINGFACE_TRANSFORMER_LOCAL)) {
                hfLocalModel = (HuggingFaceLocalConnection.HFLocalModel)hfConnection.getLLMModel(this.llmRef).getModel();
                handlingMode = hfLocalModel.handlingMode;
            }
            Optional<String> inputHuggingfaceModelId = Optional.ofNullable(hfLocalModel).map(m -> m.huggingFaceId);
            Optional<FullModelId> inputSavedmodelAdaptFmi = this.llmRef.type.equals((Object)LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_HUGGINGFACE_TRANSFORMER) ? Optional.of(new FullModelId(inputSm.projectKey, loc.getId(), inputLLMRef.savedModelVersionId)) : Optional.empty();
            String originalLLMId = this.llmRef.type.equals((Object)LLMStructuredRef.LLMType.HUGGINGFACE_TRANSFORMER_LOCAL) ? this.llmRef.id : new FullModelId((String)inputSm.projectKey, (String)loc.getId(), (String)inputLLMRef.savedModelVersionId).getLlmSavedModelInfo().originalLLMId;
            boolean originalLLMIsUnreferenced = this.llmRef.type.equals((Object)LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_HUGGINGFACE_TRANSFORMER) ? new FullModelId((String)inputSm.projectKey, (String)loc.getId(), (String)inputLLMRef.savedModelVersionId).getLlmSavedModelInfo().originalLLMIsUnreferenced : false;
            return new HuggingFaceFineTuningRunner(this.authCtx, this.llmRef, this.activity, this.desc, this.outputFMI.getModelFolder(), this.trainingDataset.getFullName(), this.validationDataset.map(Dataset::getFullName), containerConfig, this.outputFMI, inputHuggingfaceModelId, inputSavedmodelAdaptFmi, originalLLMId, originalLLMIsUnreferenced, handlingMode);
        }
        throw new IllegalArgumentException("Wrong LLM type. Fine-tuning not supported for " + String.valueOf((Object)this.llmRef.type));
    }

    @Override
    public void run() throws Exception {
        LicenseStatusService.LicensingStatus ls = this.licenseStatusService.getLicensingStatus();
        AbstractLicenseFeaturesStatusBuilder.LicenseFeaturesStatus lfs = LicenseFeaturesStatusBuilder.getFeaturesStatus(ls);
        if (!lfs.advancedLLMMeshAllowed) {
            throw new LicenseRestrictionException("Fine-tuning requires the \"Advanced LLM Mesh\" add-on");
        }
        String newVersionId = "" + System.currentTimeMillis();
        RecipeRunnableSubgraph subgraph = (RecipeRunnableSubgraph)this.activity.getSubgraph();
        FlowSavedModel fsm = (FlowSavedModel)subgraph.getTargets().get(0);
        this.activity.getTargetStatus((String)fsm.getFullId()).modelVersionId = newVersionId;
        logger.info((Object)"Fine-tuning recipe runner started");
        if (StringUtils.isBlank((CharSequence)this.desc.llmId)) {
            throw new IllegalArgumentException("No LLM was specified");
        }
        LLMStructuredRef llmRef = LLMStructuredRef.decodeId(this.desc.llmId);
        if (!this.FINETUNABLE_LLM_TYPES.contains((Object)llmRef.type)) {
            throw new IllegalArgumentException("Fine-tuning not supported. The fine-tuning recipe supports Azure Open AI, Open AI, Bedrock and Local HuggingFace models.");
        }
        if (StringUtils.isBlank((CharSequence)this.desc.promptColumn)) {
            throw new IllegalArgumentException("No prompt column was specified");
        }
        if (StringUtils.isBlank((CharSequence)this.desc.completionColumn)) {
            throw new IllegalArgumentException("No completion column was specified");
        }
        SerializedRecipe.RecipeOutput ro = (SerializedRecipe.RecipeOutput)this.recipe.getModel().getOutputsForRole("finetuned_model").stream().findFirst().orElseThrow(() -> new IllegalArgumentException("model output not found"));
        this.sm = (SavedModel)this.savedModelsDAO.getMandatory(ro.getLoc(this.recipe.getProjectKey()));
        SerializedRecipe.RecipeInput recipeMainInput = (SerializedRecipe.RecipeInput)this.recipe.getModel().getInputsForRole("main").stream().findFirst().orElseThrow(() -> new IllegalArgumentException("training dataset input not found"));
        this.trainingDataset = Dataset.fromSerialized((SerializedDataset)this.datasetsDAO.getMandatory(recipeMainInput.getLoc(this.recipe.getProjectKey())));
        Optional recipeValidationDSInput = this.recipe.getModel().getInputsForRole("validation_dataset").stream().findFirst();
        this.validationDataset = Optional.empty();
        if (recipeValidationDSInput.isPresent()) {
            this.validationDataset = Optional.of(Dataset.fromSerialized((SerializedDataset)this.datasetsDAO.getMandatory(((SerializedRecipe.RecipeInput)recipeValidationDSInput.get()).getLoc(this.recipe.getProjectKey()))));
        }
        this.mti = new ModelTrainInfo();
        this.mti.startTime = System.currentTimeMillis();
        this.outputFMI = new FullModelId(this.sm.projectKey, this.sm.id, newVersionId);
        MLPaths.createIfNeededSavedModelFolderAndRestrictPermissions(this.sm);
        DKUFileUtils.mkdirs((File)this.outputFMI.getModelFolder());
        FilesystemACLUtils.grantFSReadACLs(this.authCtx, this.sm.projectKey, this.outputFMI.getFolderEnsuringSecurity());
        FilesystemACLUtils.grantFSFullACLs(this.authCtx, this.sm.projectKey, this.outputFMI.getModelFolder());
        JSON.prettyToFile((Object)this.desc, (File)new File(this.outputFMI.getModelFolder(), "desc.json"));
        this.runner = this.getAbstractFineTuningRunner();
        SpringUtils.getInstance().autowire((Object)this.runner);
        this.runner.ensureConnectionAllowsFinetuning();
        LLMSavedModelInfo llmSavedModelInfo = this.runner.runFineTuning();
        llmSavedModelInfo.connection = llmRef.connection;
        llmSavedModelInfo.inputLLMId = this.desc.llmId;
        llmSavedModelInfo.promptColumn = this.desc.promptColumn;
        llmSavedModelInfo.completionColumn = this.desc.completionColumn;
        llmSavedModelInfo.trainingDataset = this.trainingDataset.getName();
        llmSavedModelInfo.validationDataset = this.validationDataset.map(ds -> ds.getName()).orElse(null);
        this.outputFMI.saveLLMInfo(llmSavedModelInfo);
        this.mti.endTime = System.currentTimeMillis();
        this.mti.trainingTime = this.mti.endTime - this.mti.startTime;
        this.mti.state = ModelTrainInfo.ModelTrainState.DONE;
        this.outputFMI.saveModelTrainInfo(this.mti);
        ModelUserMeta mum = new ModelUserMeta();
        mum.name = llmRef.getFinetunedModelMetaName();
        mum.name = mum.name + " - v" + this.sm.lastTrainIndex;
        this.outputFMI.saveUserMeta(mum);
        if (this.desc.deployFinetunedModel && this.DEPLOYABLE_LLM_TYPES.contains((Object)llmRef.type)) {
            this.remoteDeploymentID = this.smvDeploymentCRUDService.deployFineTunedModel(this.authCtx, this.sm, this.outputFMI);
            this.smvDeploymentCRUDService.waitForDeploymentToBeAvailable(this.authCtx, this.sm, this.outputFMI, this.remoteDeploymentID);
        }
        if (this.desc.cleanInactiveSMVDeployments) {
            HashSet<LLMSMMgmtService.LLMSMVersionHeader> versionsToDelete = new HashSet<LLMSMMgmtService.LLMSMVersionHeader>();
            LLMSMMgmtService.LLMSMStatus status = LLMSMMgmtService.getStatus_NT(this.sm);
            for (LLMSMMgmtService.LLMSMVersionHeader versionHeader : status.versions) {
                if (((LLMModelSnippetData)versionHeader.snippet).deployment == null || versionHeader.active && this.sm.publishPolicy == SavedModel.ModelPublishPolicy.MANUAL || ((LLMModelSnippetData)versionHeader.snippet).fullModelId.equals(this.outputFMI.toString())) continue;
                versionsToDelete.add(versionHeader);
            }
            this.smvDeploymentCRUDService.deleteVersionsDeployments(this.authCtx, this.sm, versionsToDelete);
        }
    }

    protected void incrementSavedModelTrainIndex(TicketBasedIntercomAPIClient tClient) throws IOException {
        logger.info((Object)"Increment last train index");
        tClient.postFormToJSON("/dip/api/tintercom/savedmodels/increment-last-train", Void.class, new Object[]{"projectKey", this.recipe.getProjectKey(), "smId", this.sm.id, "jobId", JobContext.getCurrentJobContext().jobId});
        logger.info((Object)"Done Increment last train index");
    }

    @Override
    public void finalCommit() throws Exception {
        String secret = this.ticketService.getSingleTicket().getSecret();
        try (TicketBasedIntercomAPIClient tClient = TicketBasedIntercomAPIClient.forLocalHost(secret);){
            if (this.sm.publishPolicy == SavedModel.ModelPublishPolicy.UNCONDITIONAL || !MLFlowUtils.hasValidActiveVersion(this.sm)) {
                logger.info((Object)"Setting new version as active scoring version");
                String versionToActivate = this.outputFMI.getSavedModelVersionID();
                tClient.postFormToJSON("/dip/api/tintercom/savedmodels/set-active", Void.class, new Object[]{"projectKey", this.recipe.getProjectKey(), "smId", this.sm.id, "versionId", versionToActivate});
                this.sm.activeVersion = versionToActivate;
                this.savedModelsDAO.save(this.sm);
            }
            this.incrementSavedModelTrainIndex(tClient);
        }
    }

    @Override
    public void notifyBeforeAborting() {
        if (this.runner != null) {
            try {
                this.runner.cancelFineTuning();
            }
            catch (UnsupportedOperationException unsupportedOperationException) {
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        if (this.remoteDeploymentID != null) {
            try {
                this.smvDeploymentCRUDService.deleteDeployment(this.authCtx, this.sm.projectKey, this.sm, this.outputFMI);
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }
}

