/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.shared;

import com.dataiku.common.server.APIError;
import com.dataiku.common.server.SerializedError;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.DSSTempUtils;
import com.dataiku.dip.analysis.coreservices.AnalysisMLContainerKernel;
import com.dataiku.dip.analysis.coreservices.AnalysisMLKernel;
import com.dataiku.dip.analysis.coreservices.IAnalysisMLKernel;
import com.dataiku.dip.analysis.coreservices.MLBaseService;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.shared.ModelStateHelper;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.containers.exec.ContainerExecConfigSelector;
import com.dataiku.dip.containers.exec.ContainerExecRuntimeConfig;
import com.dataiku.dip.distributed.metrics.ContainerUsageMetrics;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.io.PortRangeParams;
import com.dataiku.dip.io.SingleCommandKernelLink;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.resourceusage.CurrentComputeResourceUsageContext;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.security.rpc.EncryptedRPC;
import com.dataiku.dip.util.AutoDelete;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.common.collect.Lists;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Queue;

public abstract class PRNSTrainThread
extends Thread {
    protected IAnalysisMLKernel kernel;
    private SingleCommandKernelLink link;
    private int workingPid = 0;
    protected boolean abort = false;
    protected final Object sync;
    protected WorkSet.PreprocessingSet currentPreprocessingSet;
    private boolean runsInContainer = false;
    private final SplitDesc splitDesc;
    protected final MLTaskLoc taskLoc;
    protected final List<String> fullModelIds;
    protected final String sessionId;
    private final Queue<WorkSet.PreprocessingSet> workQueue;
    protected Map<String, String> envVars = Collections.emptyMap();
    private final AuthCtx authCtx;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.analysis.ml.python");

    protected PRNSTrainThread(AuthCtx authCtx, Queue<WorkSet.PreprocessingSet> workQueue, SplitDesc splitDesc, MLTaskLoc taskLoc, List<String> fullModelIds, String sessionId, Object sync) {
        this.authCtx = authCtx;
        this.workQueue = workQueue;
        this.splitDesc = splitDesc;
        this.taskLoc = taskLoc;
        this.fullModelIds = fullModelIds;
        this.sessionId = sessionId;
        this.sync = sync;
    }

    public abstract void process(WorkSet.PreprocessingSet var1) throws Exception;

    public abstract void postProcess(WorkSet.PreprocessingSet var1) throws IOException;

    public int getWorkingPid() {
        return this.workingPid;
    }

    public abstract String getProjectKey();

    public abstract String getEnvName() throws IOException;

    public abstract MLTask getTask();

    protected abstract boolean shallNotProcess(WorkSet.PreprocessingSet var1);

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void handleError(WorkSet.PreprocessingSet pps, Throwable e) {
        Object object = this.sync;
        synchronized (object) {
            ModelStateHelper.markAllNotDoneAsFailed(pps, e);
        }
    }

    public void forceEnvVars(Map<String, String> envVars) {
        this.envVars = envVars;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    public final void run() {
        Thread.currentThread().setName("MRT-" + Thread.currentThread().getId());
        ComputeResourceUsageContext cruContext = ComputeResourceUsageContext.forAnalysisMLTrain((AuthCtx)this.authCtx, (String)this.taskLoc.analysisProjectKey, (String)this.taskLoc.analysisId, (String)this.taskLoc.mlTaskId, (String)this.sessionId);
        CurrentComputeResourceUsageContext.setInCurrentThread((ComputeResourceUsageContext)cruContext);
        while (!this.abort) {
            Exception error = null;
            WorkSet.PreprocessingSet pps = this.pollPreprocessingSet();
            if (pps == null) {
                logger.info((Object)"TrainAdditionalThread done");
                return;
            }
            if (this.shallNotProcess(pps)) {
                logger.info((Object)("Not processing preprocessing set: " + pps.preprocessingId));
                continue;
            }
            ArrayList fmis = Lists.newArrayList();
            for (WorkSet.ModelingSet ms : pps.modelingSets) {
                if (ms.fullId == null) continue;
                fmis.add(ms.fullId);
            }
            MLBaseService.MLTaskWorkThread.mlTaskContext.set(new MLBaseService.MLTaskWorkThread.MLTaskContext(fmis));
            logger.info((Object)("Running a preprocessing set: " + pps.preprocessingId + " in " + pps.run_folder));
            try {
                this.currentPreprocessingSet = pps;
                PortRangeParams dssPortRange = ApplicationConfigurator.getPortRangeParams();
                ContainerExecRuntimeConfig containerConfig = new ContainerExecConfigSelector().selectForML_autoTXN(this.authCtx, this.getProjectKey(), this.getTask().containerSelection, this.getTask().backendType);
                this.runsInContainer = containerConfig != null;
                this.link = new SingleCommandKernelLink(SecretKeyGenerator.generate((int)16), dssPortRange, this.runsInContainer ? EncryptedRPC.getSSLContext() : null);
                File runFolderFile = new File(pps.run_folder);
                this.kernel = !this.runsInContainer ? new AnalysisMLKernel(this.link, pps, this.getProjectKey(), this.getEnvName(), this.authCtx, runFolderFile) : new AnalysisMLContainerKernel(this.link, pps, this.getProjectKey(), this.getEnvName(), this.authCtx, runFolderFile, this.taskLoc.getDataFolder(), containerConfig, "doctor-preprocessing-");
                this.kernel.forceEnvVars(this.envVars);
                this.kernel.start();
                Object object = this.sync;
                synchronized (object) {
                    pps.kernel = this.kernel;
                    pps.link = this.link;
                }
                this.workingPid = this.kernel.getPid();
                this.process(pps);
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                logger.error((Object)"Processing failed", (Throwable)e);
                error = e;
            }
            catch (Exception e) {
                logger.warn((Object)"Training failed", (Throwable)e);
                error = e;
            }
            finally {
                this.workingPid = 0;
                if (this.kernel != null) {
                    try {
                        if (!this.abort && !this.kernel.isAborted()) {
                            if (error == null) {
                                try {
                                    this.kernel.waitForResults();
                                    this.postProcess(pps);
                                }
                                catch (Exception e) {
                                    error = e;
                                }
                            } else {
                                SerializedError serializedError = this.kernel.waitForError();
                                if (serializedError != null) {
                                    error = new APIError.SerializedErrorException(serializedError);
                                }
                            }
                        }
                        this.kernel.cleanup();
                        this.kernel.killWithoutMercy();
                    }
                    catch (Exception e) {
                        logger.error((Object)"Failure while destroying ml kernel", (Throwable)e);
                    }
                }
                if (this.link != null) {
                    try {
                        this.link.close();
                    }
                    catch (Exception e) {
                        logger.error((Object)"Failure while closing link to kernel", (Throwable)e);
                    }
                }
                this.kernel = null;
                if (error != null) {
                    logger.error((Object)"Processing failed", (Throwable)error);
                    this.handleError(pps, error);
                }
                MLBaseService.MLTaskWorkThread.mlTaskContext.set(null);
            }
        }
    }

    protected WorkSet.PreprocessingSet remapPreprocessingSet(WorkSet.PreprocessingSet pps, IAnalysisMLKernel kernel) throws IOException, DKUSecurityException, InterruptedException {
        if (this.runsInContainer) {
            pps = (WorkSet.PreprocessingSet)JSON.deepCopy((Object)pps);
            if (pps.needsTmpFolder) {
                pps.tmp_folder = ".";
            }
        } else if (pps.needsTmpFolder) {
            AutoDelete tmpFolder = DSSTempUtils.getTempFolder((String)"ml-training-tmp-folder");
            FilesystemACLUtils.grantFSFullACLs(this.authCtx, this.getProjectKey(), new File[]{tmpFolder});
            pps.tmp_folder = tmpFolder.getAbsolutePath();
            kernel.registerOnShutDownRunnable(() -> ((AutoDelete)tmpFolder).close());
        }
        return pps;
    }

    public WorkSet.PreprocessingSet pollPreprocessingSet() {
        return this.workQueue.poll();
    }

    public SplitDesc getSplitDesc() {
        return this.splitDesc;
    }

    public static void join(List<? extends PRNSTrainThread> threads) throws InterruptedException {
        for (PRNSTrainThread pRNSTrainThread : threads) {
            logger.info((Object)"Joining processing thread ...");
            pRNSTrainThread.join();
            logger.info((Object)"Processing thread joined ...");
        }
    }

    public static void abort(List<? extends PRNSTrainThread> threads) {
        for (PRNSTrainThread pRNSTrainThread : threads) {
            pRNSTrainThread.abort = true;
            if (pRNSTrainThread.kernel != null) {
                logger.info((Object)("Aborting kernel ..." + pRNSTrainThread.kernel.getId()));
                pRNSTrainThread.kernel.killNoWaitNoException(true);
            }
            pRNSTrainThread.interrupt();
        }
    }

    protected boolean isPreprocessingIdInFullModelIdSet(WorkSet.PreprocessingSet pps, List<FullModelId> fullModelIdSet) {
        for (WorkSet.ModelingSet ms : pps.modelingSets) {
            for (FullModelId fmi : fullModelIdSet) {
                if (!ms.fullId.equals(fmi)) continue;
                return true;
            }
        }
        return false;
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public void partialAbort(List<FullModelId> fullModelIdSet) throws IOException {
        logger.info((Object)"Partially aborting python train thread");
        Queue<WorkSet.PreprocessingSet> queue = this.workQueue;
        synchronized (queue) {
            ArrayList<WorkSet.PreprocessingSet> ppss = new ArrayList<WorkSet.PreprocessingSet>();
            while (this.workQueue.peek() != null) {
                WorkSet.PreprocessingSet pps = this.workQueue.remove();
                if (!this.isPreprocessingIdInFullModelIdSet(pps, fullModelIdSet)) {
                    ppss.add(pps);
                    continue;
                }
                logger.info((Object)("Discarding preprocessingset " + pps.preprocessingId));
            }
            this.workQueue.addAll(ppss);
        }
        if (this.kernel != null && this.isPreprocessingIdInFullModelIdSet(this.currentPreprocessingSet, fullModelIdSet)) {
            logger.info((Object)("Aborting kernel ..." + this.kernel.getId()));
            try {
                this.kernel.abort();
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                logger.info((Object)"Kernel was interrupted", (Throwable)e);
            }
        }
    }

    public ContainerUsageMetrics getContainerUsageMetrics() {
        if (this.kernel instanceof AnalysisMLContainerKernel) {
            AnalysisMLContainerKernel analysisMLContainerKernel = (AnalysisMLContainerKernel)this.kernel;
            return analysisMLContainerKernel.getContainerUsageMetrics();
        }
        return new ContainerUsageMetrics(0, 0, 0, false);
    }

    public List<FullModelId> getCurrentFullModelIds() {
        ArrayList<FullModelId> currentFullModelIds = new ArrayList<FullModelId>();
        if (this.currentPreprocessingSet != null) {
            for (WorkSet.ModelingSet modelingSet : this.currentPreprocessingSet.modelingSets) {
                currentFullModelIds.add(modelingSet.fullId);
            }
        }
        return currentFullModelIds;
    }
}

