/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.clustering;

import com.dataiku.dip.analysis.ml.DKUMLUtils;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.ml.PythonMLTaskHandler;
import com.dataiku.dip.analysis.ml.clustering.ClusteringParamsExpander;
import com.dataiku.dip.analysis.ml.clustering.ClusteringTrainAdditionalThread;
import com.dataiku.dip.analysis.ml.clustering.extract.ClusteringSampleExtractor;
import com.dataiku.dip.analysis.ml.prediction.split.SplitDesc;
import com.dataiku.dip.analysis.ml.shared.PRNSTrainThread;
import com.dataiku.dip.analysis.ml.shared.WorkSetPreparator;
import com.dataiku.dip.analysis.model.clustering.ClusteringMLTask;
import com.dataiku.dip.analysis.model.clustering.ResolvedClusteringCoreParams;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.security.AuthCtx;
import java.util.Set;

/**
 * {@link PythonMLTaskHandler} specialisation for {@link ClusteringMLTask}s.
 *
 * <p>It resolves the clustering core params, expands the work set, prepares the
 * sample split and supplies the train threads, metric name and concurrency level
 * requested by the base handler.
 *
 * <p>NOTE(review): the order in which the base class invokes these hooks is not
 * visible from this file; {@link #createTrainThread(SplitDesc)} reads {@code rccp},
 * which is only assigned by {@link #init(WorkSetPreparator)}.
 */
public class PythonClusteringMLTaskHandler
extends PythonMLTaskHandler<ClusteringMLTask> {
    /**
     * Model ids forwarded as-is to every train thread; {@code null} when the handler
     * was built through the constructor that does not take this set.
     */
    private Set<String> fullModelIds;

    /** Resolved clustering core params; {@code null} until {@link #init(WorkSetPreparator)} has run. */
    protected ResolvedClusteringCoreParams rccp;

    /**
     * Builds a handler without a set of full model ids (the set stays {@code null}).
     *
     * @param acp       analysis core params, passed to the base handler
     * @param taskLoc   location of the ML task, passed to the base handler
     * @param task      the clustering task, passed to the base handler
     * @param sessionId training session id, passed to the base handler
     * @param user      authentication context, passed to the base handler
     */
    public PythonClusteringMLTaskHandler(AnalysisCoreParams acp, MLTaskLoc taskLoc, ClusteringMLTask task, String sessionId, AuthCtx user) {
        this(acp, taskLoc, task, sessionId, (Set<String>) null, user);
    }

    /**
     * Builds a handler that forwards {@code fullModelIds} to the train threads it creates.
     *
     * @param acp          analysis core params, passed to the base handler
     * @param taskLoc      location of the ML task, passed to the base handler
     * @param task         the clustering task, passed to the base handler
     * @param sessionId    training session id, passed to the base handler
     * @param fullModelIds model ids handed to each train thread; stored by reference, not copied
     * @param user         authentication context, passed to the base handler
     */
    public PythonClusteringMLTaskHandler(AnalysisCoreParams acp, MLTaskLoc taskLoc, ClusteringMLTask task, String sessionId, Set<String> fullModelIds, AuthCtx user) {
        super(acp, taskLoc, task, sessionId, user);
        this.fullModelIds = fullModelIds;
    }

    /**
     * Resolves the core params for the task's project, dumps the params on disk,
     * expands the task into a work set and hands that work set to {@code preparator}.
     *
     * @param preparator receives the freshly expanded work set
     * @throws Exception propagated from param resolution, the dump, the expansion or the preparator
     */
    @Override
    public void init(WorkSetPreparator preparator) throws Exception {
        ClusteringMLTask clusteringTask = (ClusteringMLTask) task;

        // The resolved params must be stored before the dump, which takes them as an argument.
        rccp = clusteringTask.buildResolvedCoreParams(acp.projectKey);
        DKUMLUtils.dumpParamsOnDisk(acp, taskLoc, task, sessionId, rccp);

        ClusteringParamsExpander expander = new ClusteringParamsExpander(clusteringTask, sessionId);
        ws = expander.expand();
        preparator.prepare(ws);
    }

    /**
     * Refreshes the sample split if needed and returns its up-to-date descriptor.
     *
     * @return the split descriptor reported by the sample extractor
     * @throws Exception propagated from the sample extractor
     */
    @Override
    protected SplitDesc prepareSplits() throws Exception {
        ClusteringMLTask clusteringTask = (ClusteringMLTask) task;
        ClusteringSampleExtractor extractor = new ClusteringSampleExtractor(taskLoc, acp, clusteringTask.sampling, user);
        extractor.updateSplitIfNeeded_NT();

        SplitDesc upToDate = extractor.getUpToDateSplitDesc();
        // Only checked when the JVM runs with assertions enabled: no test path is expected here.
        assert upToDate.testPath == null;
        return upToDate;
    }

    /**
     * Creates a train thread bound to this handler, its work queue and the given split.
     *
     * @param splitDesc split descriptor handed to the thread
     * @return a new {@link ClusteringTrainAdditionalThread}
     */
    @Override
    protected PRNSTrainThread createTrainThread(SplitDesc splitDesc) {
        ClusteringMLTask clusteringTask = (ClusteringMLTask) task;
        return new ClusteringTrainAdditionalThread(
                user,
                workQueue,
                fullModelIds,
                splitDesc,
                taskLoc,
                rccp,
                sessionId,
                this,
                clusteringTask);
    }

    /**
     * Intentionally a no-op: this handler performs no check on the split.
     *
     * @param splitDesc ignored
     */
    @Override
    protected void checkSplits(SplitDesc splitDesc) {
        // Nothing to verify.
    }

    /**
     * @return the fixed metric name used for this handler's training runs
     */
    @Override
    protected String getDSSMetricName() {
        return "dku.ml.clusteringTrain.pyRegularNoSaveTrain";
    }

    /**
     * @return the task's {@code maxConcurrentModelTraining}, raised to 1 when it is lower
     */
    @Override
    protected int getThreadCountToRun() {
        int requested = ((ClusteringMLTask) task).maxConcurrentModelTraining;
        return requested < 1 ? 1 : requested;
    }
}

