/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.ml.prediction.PartitionedExtractService;
import com.dataiku.dip.analysis.ml.prediction.PartitionedModelsService;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.ml.prediction.PredictionTrainAdditionalThread;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.shared.ModelStateHelper;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.ClassicalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.PartitionedModelExtract;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.PredictionModelSnippetData;
import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionCoreParams;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import java.io.File;
import java.io.IOException;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Queue;

/**
 * Training thread for partitioned prediction models.
 *
 * <p>Instead of training a fixed work set, this thread repeatedly polls preprocessing sets from a
 * queue handed in at construction time. Every access to that queue in this class is synchronized
 * on the queue itself; presumably it is shared with sibling threads. While training, the thread
 * keeps the per-partition states and snippets stored in the {@link PartitionedExtractService}
 * cache in sync with each partition's train info.
 */
public class StratPredictionTrainAdditionalThread extends PredictionTrainAdditionalThread {
    private final PartitionedExtractService extractCacheService;
    private final PartitionedModelsService partitionedModelsService;
    /** Preprocessing sets still to be trained; always accessed while holding its own monitor. */
    private final Queue<WorkSet.PreprocessingSet> partitionsWorkQueue;
    /** Set most recently polled from the queue; {@code null} once the queue has been drained. */
    private WorkSet.PreprocessingSet currentSet;
    /**
     * True once {@link #partialAbort(List)} has aborted the kernel for the current set.
     * Reset on each poll, so that {@link #handleError} reports ABORTED rather than FAILED.
     */
    private boolean partialAbort = false;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.analysis.prediction.strat");

    /**
     * Creates a thread that trains the preprocessing sets found in {@code partitionsWorkQueue}.
     *
     * <p>NOTE(review): two superclass constructor arguments are passed as {@code null}. Confirm
     * against the superclass that it tolerates this. Work is supplied here through
     * {@link #pollPreprocessingSet()} instead.
     */
    public StratPredictionTrainAdditionalThread(AuthCtx authCtx, Queue<WorkSet.PreprocessingSet> partitionsWorkQueue, List<FullModelId> fullModelIds, MLTaskLoc taskLoc, ResolvedPredictionCoreParams rpcp, String sessionId, PredictionMLTask task, Object sync, String command) {
        super(authCtx, null, fullModelIds, null, taskLoc, rpcp, sessionId, task, sync, command);
        this.partitionsWorkQueue = partitionsWorkQueue;
        this.extractCacheService = (PartitionedExtractService)SpringUtils.getBean(PartitionedExtractService.class);
        this.partitionedModelsService = (PartitionedModelsService)SpringUtils.getBean(PartitionedModelsService.class);
    }

    /**
     * Takes the next preprocessing set off the shared queue.
     *
     * @return the next set, or {@code null} when the queue is empty (the thread is then done)
     */
    @Override
    public WorkSet.PreprocessingSet pollPreprocessingSet() {
        // Each new set starts without a pending partial abort.
        this.partialAbort = false;
        synchronized (this.partitionsWorkQueue) {
            this.currentSet = this.partitionsWorkQueue.poll();
            if (this.currentSet == null) {
                logger.info((Object)"StratPredictionTrainAdditionalThread done");
            }
            return this.currentSet;
        }
    }

    /**
     * Aborts training for the base models of {@code fmis}. The steps are:
     * <ol>
     *   <li>drop every queued preprocessing set that touches one of those base models;</li>
     *   <li>abort the running kernel if the set it is currently working on matches;</li>
     *   <li>mark every non-final partition of each model as ABORTED;</li>
     *   <li>re-gather the base model info.</li>
     * </ol>
     *
     * @param fmis models (partition or base ids) whose training must stop
     * @throws IOException if updating the base model info fails
     */
    @Override
    public void partialAbort(List<FullModelId> fmis) throws IOException {
        logger.info((Object)"Partially aborting partitioned python train thread");
        synchronized (this.partitionsWorkQueue) {
            this.removePPSMatchingFMIs(this.partitionsWorkQueue, fmis);
        }
        // NOTE(review): this checks the superclass's currentPreprocessingSet, not this.currentSet.
        // Confirm that the superclass keeps it pointing at the set being trained.
        if (this.kernel != null && this.isPreprocessingIdInFullModelIdSet(this.currentPreprocessingSet, fmis)) {
            logger.infoV("Aborting kernel ...%s", new Object[]{this.kernel.getId()});
            try {
                // Raised before aborting so that the resulting error is reported as ABORTED.
                this.partialAbort = true;
                this.kernel.abort();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                logger.info((Object)"Kernel was interrupted", (Throwable)e);
            }
        }
        for (FullModelId fmi : fmis) {
            this.extractCacheService.setPartitionedNonFinalStates(fmi, PartitionedModelExtract.PartitionState.ABORTED);
            this.partitionedModelsService.gatherBaseModelInfo(fmi);
        }
    }

    /** Removes, in place, every preprocessing set that touches one of the base models of {@code fmis}. */
    private void removePPSMatchingFMIs(Collection<WorkSet.PreprocessingSet> preprocessingSets, List<FullModelId> fmis) {
        Iterator<WorkSet.PreprocessingSet> iterator = preprocessingSets.iterator();
        while (iterator.hasNext()) {
            WorkSet.PreprocessingSet pps = iterator.next();
            if (this.isPreprocessingIdInFullModelIdSet(pps, fmis)) {
                logger.infoV("Discarding preprocessingset %s", new Object[]{pps.preprocessingId});
                iterator.remove();
            }
        }
    }

    /**
     * Tells whether any modeling set of {@code pps} shares a partitioned base model with one of
     * {@code fmis}.
     *
     * @return {@code false} when {@code pps} is {@code null}
     */
    @Override
    protected boolean isPreprocessingIdInFullModelIdSet(WorkSet.PreprocessingSet pps, List<FullModelId> fmis) {
        if (pps == null) {
            return false;
        }
        for (WorkSet.ModelingSet ms : pps.modelingSets) {
            FullModelId baseModelingFmi = ms.fullId.getPartitionedBaseModel();
            for (FullModelId fmi : fmis) {
                if (baseModelingFmi.equals(fmi.getPartitionedBaseModel())) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Returns the split description of the partition set most recently polled. */
    @Override
    public SplitDesc getSplitDesc() {
        return this.currentSet.partitionSplitDesc;
    }

    /**
     * Trains {@code pps}, unless {@link #shallNotProcess} excludes it.
     *
     * <p>Before delegating to the superclass, each modeling set is handled as follows:
     * <ul>
     *   <li>the base model's train info file is moved from PENDING to RUNNING, with a start time;</li>
     *   <li>the partition's state in the extract cache is moved from PENDING to RUNNING.</li>
     * </ul>
     * A failure to update the train info file is only logged.
     */
    @Override
    public void process(WorkSet.PreprocessingSet pps) throws Exception {
        if (this.shallNotProcess(pps)) {
            return;
        }
        for (WorkSet.ModelingSet ms : pps.modelingSets) {
            FullModelId partitionFmi = ms.fullId;
            FullModelId baseFmi = partitionFmi.getPartitionedBaseModel();
            final File mtiFile = baseFmi.getModelInfoFile();
            ModelStateHelper.updateModelTrainInfoAtomically(mtiFile, new Runnable() {

                @Override
                public void run() {
                    try {
                        ModelTrainInfo mti = (ModelTrainInfo)JSON.parseFile((File)mtiFile, ModelTrainInfo.class);
                        // Only the first partition to start flips the base model to RUNNING.
                        if (mti.state == ModelTrainInfo.ModelTrainState.PENDING) {
                            mti.startTime = System.currentTimeMillis();
                            mti.state = ModelTrainInfo.ModelTrainState.RUNNING;
                            JSON.prettyToFile((Object)mti, (File)mtiFile);
                        }
                    } catch (IOException e) {
                        logger.warn((Object)"Failed to record running state", (Throwable)e);
                    }
                }
            });
            PartitionedModelExtract extract = this.extractCacheService.read(baseFmi);
            // No extract, or no summary yet for this partition: there is no PENDING state to promote.
            if (extract == null
                    || extract.summaries.get((Object)partitionFmi.getPartitionName()) == null
                    || extract.summaries.get((Object)partitionFmi.getPartitionName()).state != PartitionedModelExtract.PartitionState.PENDING) {
                continue;
            }
            this.extractCacheService.updateStates(baseFmi, PartitionedModelExtract.PartitionState.RUNNING, partitionFmi.getPartitionName());
        }
        super.process(pps);
    }

    /** Runs the superclass post-processing, then refreshes the partition snippets and states of {@code pps}. */
    @Override
    public void postProcess(WorkSet.PreprocessingSet pps) throws IOException {
        super.postProcess(pps);
        this.retrieveSnippetsFromPartitions(pps);
    }

    /**
     * Decides whether {@code pps} is outside the requested models.
     *
     * <p>When {@code fullModelIds} is non-empty, the set is processed only if that list contains
     * either its first modeling set's id or that id's partitioned base model.
     *
     * <p>NOTE(review): the lookups use the {@code toString()} forms. Confirm that the inherited
     * {@code fullModelIds} holds Strings; if it holds {@link FullModelId} objects, these lookups
     * never match. Also, {@code modelingSets} is assumed to be non-empty.
     */
    @Override
    protected boolean shallNotProcess(WorkSet.PreprocessingSet pps) {
        FullModelId currentFmi = new FullModelId(this.taskLoc, this.sessionId, pps.preprocessingId, pps.modelingSets.get(0).modelId);
        if (this.fullModelIds == null || this.fullModelIds.isEmpty()) {
            return false;
        }
        return !this.fullModelIds.contains(currentFmi.toString()) && !this.fullModelIds.contains(currentFmi.getPartitionedBaseModel().toString());
    }

    /**
     * Marks every not-yet-final model of {@code pps} as ABORTED (when a full or partial abort was
     * requested) or FAILED, then refreshes the partition snippets and states.
     * The marking step runs while holding the shared {@code sync} monitor.
     */
    @Override
    public void handleError(WorkSet.PreprocessingSet pps, Throwable e) {
        ModelTrainInfo.ModelTrainState state = this.abort || this.partialAbort ? ModelTrainInfo.ModelTrainState.ABORTED : ModelTrainInfo.ModelTrainState.FAILED;
        synchronized (this.sync) {
            ModelStateHelper.markAllNotFinalAsState(pps, state, e);
        }
        this.retrieveSnippetsFromPartitions(pps);
    }

    /**
     * For each partition of {@code pps}, copies its model snippet into the extract cache and
     * records its resulting state.
     *
     * <p>If the partition's details cannot be read, the state recorded is FAILED. The state is
     * written exactly once per partition, in a {@code finally}, so the partition never keeps a
     * stale state even if an unchecked exception escapes. Once no partition of the base model is
     * RUNNING or PENDING any more, the base model info is gathered.
     */
    private void retrieveSnippetsFromPartitions(WorkSet.PreprocessingSet pps) {
        for (WorkSet.ModelingSet ms : pps.modelingSets) {
            FullModelId partitionFmi = ms.fullId;
            FullModelId baseFmi = partitionFmi.getPartitionedBaseModel();
            PartitionedModelExtract.PartitionState newState = PartitionedModelExtract.PartitionState.FAILED;
            try {
                PredictionModelDetails details = PredictionResultsReader.makeModelDetails(partitionFmi);
                newState = PartitionedModelExtract.PartitionState.fromModelTrainState(details.trainInfo.state);
                MLTask mlTask = partitionFmi.getHeadMLTask();
                if (mlTask instanceof PredictionMLTask.ClassicalPredictionMLTask) {
                    ((ClassicalPredictionModelDetails)details).headTaskCMW = ((PredictionMLTask.ClassicalPredictionMLTask)mlTask).modeling.metrics.costMatrixWeights;
                }
                PredictionModelSnippetData snippet = PredictionResultsReader.makeSnippet(details);
                PredictionResultsReader.addPartitionedModelInfo(snippet, partitionFmi);
                this.extractCacheService.updateSnippet(baseFmi, snippet);
            } catch (IOException ex) {
                logger.warn((Object)"Failed to retrieve partitioned model snippet", (Throwable)ex);
            } finally {
                this.extractCacheService.updateStates(baseFmi, newState, partitionFmi.getPartitionName());
            }
            try {
                boolean isLastPartition = this.extractCacheService.getNumPartitionsWithStates(baseFmi, PartitionedModelExtract.PartitionState.RUNNING, PartitionedModelExtract.PartitionState.PENDING) <= 0;
                if (isLastPartition) {
                    this.partitionedModelsService.gatherBaseModelInfo(baseFmi);
                }
            } catch (IOException e) {
                logger.warn((Object)("Failed to read extract of following fmi: '" + String.valueOf(baseFmi) + "'"), (Throwable)e);
            }
        }
    }
}

