/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.DSSMetrics;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLPaths;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.ml.prediction.AbstractPythonPredictionMLTaskHandler;
import com.dataiku.dip.analysis.ml.prediction.PartitionedExtractService;
import com.dataiku.dip.analysis.ml.prediction.PartitionedModelsService;
import com.dataiku.dip.analysis.ml.prediction.PredictionMLTaskHandlingStrategy;
import com.dataiku.dip.analysis.ml.prediction.StratPredictionTrainAdditionalThread;
import com.dataiku.dip.analysis.ml.prediction.split.PredictionSplitGenerator;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.prediction.split.SplitGeneratorFactory;
import com.dataiku.dip.analysis.ml.shared.ModelStateHelper;
import com.dataiku.dip.analysis.ml.shared.PRNSTrainThread;
import com.dataiku.dip.analysis.ml.shared.WorkSetPreparator;
import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.PartitionedModelExtract;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.datasets.DatasetHandler;
import com.dataiku.dip.datasets.DatasetSelection;
import com.dataiku.dip.datasets.LatestPartitionsSelector;
import com.dataiku.dip.distributed.metrics.ContainerUsageMetrics;
import com.dataiku.dip.input.DatasetHandlerFactory;
import com.dataiku.dip.partitioning.DimensionType;
import com.dataiku.dip.partitioning.Partition;
import com.dataiku.dip.partitioning.PartitionFactory;
import com.dataiku.dip.partitioning.PartitioningUtils;
import com.dataiku.dip.partitioning.StratifiedModelUtils;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.datasets.DatasetAccessService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
import java.io.File;
import java.io.IOException;
import java.lang.invoke.LambdaMetafactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;

public class StratifiedPythonPredictionMLTaskHandler
extends AbstractPythonPredictionMLTaskHandler {
    private final PartitionedExtractService partitionedExtractService;
    private final TransactionService transactionService;
    private final DatasetAccessService datasetAccessService;
    private final PartitionedModelsService partitionedModelsService;
    private final Set<FullModelId> baseFmiFilter;
    private List<String> partitionNames;
    private Multimap<String, WorkSet.PreprocessingSet> partitionWorkMap;
    private Queue<WorkSet.PreprocessingSet> workParamsQueue;
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.analysis.prediction.strat.prns");

    @Override
    protected void checkSplits(SplitDesc splitDesc) {
    }

    public StratifiedPythonPredictionMLTaskHandler(AnalysisCoreParams acp, MLTaskLoc taskLoc, PredictionMLTask task, PredictionMLTaskHandlingStrategy predictionMLTaskHandlingStrategy, String sessionId, List<FullModelId> fullModelIds, AuthCtx user) {
        super(acp, taskLoc, task, predictionMLTaskHandlingStrategy, sessionId, fullModelIds, user);
        if (fullModelIds != null) {
            this.baseFmiFilter = new HashSet<FullModelId>();
            for (FullModelId fullModelId : fullModelIds) {
                this.baseFmiFilter.add(fullModelId.getPartitionedBaseModel());
            }
        } else {
            this.baseFmiFilter = null;
        }
        this.partitionedExtractService = (PartitionedExtractService)SpringUtils.getBean(PartitionedExtractService.class);
        this.transactionService = (TransactionService)SpringUtils.getBean(TransactionService.class);
        this.datasetAccessService = (DatasetAccessService)SpringUtils.getBean(DatasetAccessService.class);
        this.partitionedModelsService = (PartitionedModelsService)SpringUtils.getBean(PartitionedModelsService.class);
    }

    @Override
    public void init(WorkSetPreparator preparator) throws Exception {
        this.handleDatasetPartitions();
        super.init_();
        this.preparePartitionWorksets(this.ws, preparator);
    }

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    @Override
    public void train() throws Exception {
        DSSMetrics.registry().meter(this.getDSSMetricName()).mark();
        try {
            logger.info((Object)"Preparing base & partitions splits");
            this.generateSplitsAndDispatchWork();
            this.setBaseModelsState(this.ws, ModelTrainInfo.ModelTrainState.PENDING);
            logger.info((Object)"Launching the training threads");
            this.workParamsQueue = new LinkedList<WorkSet.PreprocessingSet>(this.partitionWorkMap.values());
            if (this.workParamsQueue.isEmpty()) {
                logger.error((Object)"No partition to train");
                this.setBaseModelsState(this.ws, ModelTrainInfo.ModelTrainState.FAILED);
                throw new IllegalStateException("No partition to train");
            }
            this.runTraining(null);
        }
        catch (Exception e) {
            logger.error((Object)"Failure while training main loop", (Throwable)e);
            StratifiedPythonPredictionMLTaskHandler stratifiedPythonPredictionMLTaskHandler = this;
            synchronized (stratifiedPythonPredictionMLTaskHandler) {
                PartitionedModelExtract.PartitionState state = this.isAborting() ? PartitionedModelExtract.PartitionState.ABORTED : PartitionedModelExtract.PartitionState.FAILED;
                this.setPartitionedNonFinalStates(this.ws, state);
                this.fillBaseModelFolders(this.ws);
            }
            throw e;
        }
    }

    private static com.dataiku.scoring.builders.DimensionType toDimensionType(DimensionType dimensionType) {
        switch (dimensionType) {
            case DISCRETE: {
                return com.dataiku.scoring.builders.DimensionType.DISCRETE;
            }
            case YEAR: {
                return com.dataiku.scoring.builders.DimensionType.YEAR;
            }
            case MONTH: {
                return com.dataiku.scoring.builders.DimensionType.MONTH;
            }
            case DAY: {
                return com.dataiku.scoring.builders.DimensionType.DAY;
            }
            case HOUR: {
                return com.dataiku.scoring.builders.DimensionType.HOUR;
            }
        }
        throw new IllegalArgumentException("Unsupported: " + String.valueOf(dimensionType));
    }

    /*
     * Unable to fully structure code
     */
    private void handleDatasetPartitions() throws Exception {
        if (!((PredictionMLTask)this.task).getPartitionedModel().isPresent()) {
            throw new IllegalArgumentException("ML Task is not partitioned.");
        }
        partitionedModel = ((PredictionMLTask)this.task).getPartitionedModel().get();
        datasetSelection = partitionedModel.ssdSelection;
        inputPartitions = new ArrayList<Partition>();
        t = this.transactionService.beginRead();
        try {
            datasetSmartName = StringUtils.defaultIfBlank((String)((PredictionMLTask)this.task).splitParams.ssdDatasetSmartName, (String)this.acp.inputDatasetSmartName);
            dataset = this.datasetAccessService.getMandatoryFromRef(this.acp.projectKey, datasetSmartName);
        }
        finally {
            if (t != null) {
                t.close();
            }
        }
        partitionedModel.dimensionNames = dataset.getPartitioningSchema().getDimensionNames();
        partitionedModel.dimensionTypes = dataset.getPartitioningSchema().getDimensionTypes().stream().map((Function<DimensionType, com.dataiku.scoring.builders.DimensionType>)LambdaMetafactory.metafactory(null, null, null, (Ljava/lang/Object;)Ljava/lang/Object;, toDimensionType(com.dataiku.dip.partitioning.DimensionType ), (Lcom/dataiku/dip/partitioning/DimensionType;)Lcom/dataiku/scoring/builders/DimensionType;)()).collect(Collectors.toList());
        dh = DatasetHandlerFactory.build(this.user, dataset);
        try {
            switch (1.$SwitchMap$com$dataiku$dip$datasets$DatasetSelection$PartitionSelectionMethod[datasetSelection.partitionSelectionMethod.ordinal()]) {
                case 1: {
                    inputPartitions.addAll(dh.listPartitions());
                    ** break;
lbl25:
                    // 1 sources

                    break;
                }
                case 2: {
                    scheme = dataset.getPartitioningSchema();
                    if (datasetSelection.selectedPartitions == null) ** break;
                    for (String p : datasetSelection.selectedPartitions) {
                        inputPartitions.add(PartitionFactory.fromIdentifier(scheme, p));
                    }
                    break;
                }
                case 3: {
                    inputPartitions.addAll(LatestPartitionsSelector.select(this.user, dataset, datasetSelection.latestPartitionsN));
                    break;
                }
                ** default:
lbl39:
                // 1 sources

                break;
            }
        }
        finally {
            if (dh != null) {
                dh.close();
            }
        }
        this.partitionNames = new ArrayList<String>();
        for (Partition partition : PartitioningUtils.sort(inputPartitions, true)) {
            this.partitionNames.add(partition.id());
        }
    }

    private void preparePartitionWorksets(WorkSet ws, WorkSetPreparator preparator) throws IOException {
        preparator.withPreprocessingIdSuffix("-base").prepare(ws);
        for (FullModelId fullModelId : this.filteredFmis(ws.getModelIds())) {
            ModelUserMeta userMeta = fullModelId.parseModelFile("user_meta.json", ModelUserMeta.class);
            userMeta.name = StratifiedModelUtils.addSuffixIfAbsent(userMeta.name);
            JSON.prettyToFile((Object)userMeta, (File)fullModelId.getModelFile("user_meta.json"));
        }
        for (FullModelId baseFmi : this.filteredFmis(ws.getModelIds())) {
            this.partitionedExtractService.createOrRetrieve(baseFmi, this.partitionNames);
        }
        this.partitionWorkMap = ArrayListMultimap.create();
        for (String partitionName : this.partitionNames) {
            WorkSet partitionWorkSet = (WorkSet)JSON.deepCopy((Object)ws);
            preparator.withPreprocessingIdSuffix(StratifiedModelUtils.generatePartitionSuffix(partitionName)).prepare(partitionWorkSet);
            this.partitionWorkMap.putAll((Object)partitionName, partitionWorkSet.preprocessingSets);
        }
    }

    private void generateSplitsAndDispatchWork() throws Exception {
        PredictionSplitGenerator generator = SplitGeneratorFactory.buildForPartitionedMLTask(this.acp, this.taskLoc, ((PredictionMLTask)this.task).splitParams, this.user, null, null);
        String instanceId = generator.getExpectedInstanceId_NT();
        JSON.prettyToFile((Object)new SplitDesc.SplitRef(instanceId), (File)new File(MLPaths.sessionFolder(this.taskLoc, this.sessionId), "split_ref.json"));
        DKUFileUtils.mkdirs((File)this.taskLoc.getSplitsFolder());
        Schema baseSplitSchema = generator.getSplitSchema();
        SplitDesc baseSplitDesc = new SplitDesc(generator.getPolicyId(), instanceId, ((PredictionMLTask)this.task).splitParams, baseSplitSchema);
        for (Map.Entry entry : this.partitionWorkMap.asMap().entrySet()) {
            String partitionName = (String)entry.getKey();
            Collection ppSets = (Collection)entry.getValue();
            SplitDesc partSplitDesc = this.generatePartitionSplit(partitionName, baseSplitSchema);
            for (WorkSet.PreprocessingSet ppSet : ppSets) {
                ppSet.partitionSplitDesc = partSplitDesc;
                File ppRunFolder = MLPaths.preprocessingFolder(this.taskLoc, this.sessionId, ppSet.preprocessingId);
                JSON.prettyToFile((Object)new SplitDesc.SplitRef(partSplitDesc.instanceId), (File)new File(ppRunFolder, "split_ref.json"));
            }
            StratifiedModelUtils.mergeSplitDesc(baseSplitDesc, partSplitDesc);
        }
        JSON.prettyToFile((Object)baseSplitDesc, (File)generator.getSplitDescFile(instanceId));
    }

    private void setBaseModelsState(WorkSet ws, ModelTrainInfo.ModelTrainState state) {
        for (FullModelId globalModelId : this.filteredFmis(ws.getModelIds())) {
            ModelStateHelper.setModelState(globalModelId, state);
        }
    }

    private void setPartitionedNonFinalStates(WorkSet ws, PartitionedModelExtract.PartitionState stateIfNotDone) {
        for (FullModelId baseFmi : this.filteredFmis(ws.getModelIds())) {
            this.partitionedExtractService.setPartitionedNonFinalStates(baseFmi, stateIfNotDone);
        }
    }

    private void fillBaseModelFolders(WorkSet ws) throws IOException {
        for (FullModelId globalModelId : this.filteredFmis(ws.getModelIds())) {
            this.partitionedModelsService.gatherBaseModelInfo(globalModelId);
        }
    }

    private SplitDesc generatePartitionSplit(String partition, Schema baseSplitSchema) throws Exception {
        PredictionSplitGenerator splitGenerator = SplitGeneratorFactory.buildForPartitionedMLTask(this.acp, this.taskLoc, ((PredictionMLTask)this.task).splitParams, this.user, partition, baseSplitSchema);
        splitGenerator.updateSplitIfNeeded_NT();
        return splitGenerator.getUpToDateSplitDesc();
    }

    private List<FullModelId> filteredFmis(List<FullModelId> fmis) {
        if (this.baseFmiFilter == null || this.baseFmiFilter.isEmpty()) {
            return fmis;
        }
        ArrayList<FullModelId> result = new ArrayList<FullModelId>();
        for (FullModelId fmi : fmis) {
            FullModelId baseFmi = fmi.getPartitionedBaseModel();
            if (!this.baseFmiFilter.contains(baseFmi)) continue;
            result.add(fmi);
        }
        return result;
    }

    @Override
    protected PRNSTrainThread createTrainThread(SplitDesc splitDesc) {
        String pyfunc = this.predictionMLTaskHandlingStrategy.getPythonFunction((PredictionMLTask)this.task);
        StratPredictionTrainAdditionalThread thread = new StratPredictionTrainAdditionalThread(this.user, this.workParamsQueue, this.fullModelIds, this.taskLoc, this.rpcp, this.sessionId, (PredictionMLTask)this.task, this, pyfunc);
        thread.forceEnvVars(this.predictionMLTaskHandlingStrategy.getForcedEnvVars());
        return thread;
    }

    @Override
    public Map<FullModelId, ContainerUsageMetrics> getContainerUsageMetricsPerModel() {
        HashMap<FullModelId, ContainerUsageMetrics> containerUsageMetricsPerModel = new HashMap<FullModelId, ContainerUsageMetrics>();
        for (PRNSTrainThread tat : this.processingThreads) {
            ContainerUsageMetrics containerUsageMetricsForThread = tat.getContainerUsageMetrics();
            for (FullModelId fmi : tat.getCurrentFullModelIds()) {
                containerUsageMetricsPerModel.compute(fmi.getPartitionedBaseModel(), (key, oldValue) -> {
                    ContainerUsageMetrics newValue = new ContainerUsageMetrics(0, 0, 0, containerUsageMetricsForThread.isKubernetesEnabled());
                    if (oldValue != null) {
                        newValue = newValue.add(oldValue);
                    }
                    return newValue.add(containerUsageMetricsForThread);
                });
            }
        }
        return containerUsageMetricsPerModel;
    }
}

