/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.distributed.workers;

import com.dataiku.dip.analysis.coreservices.MLBaseService;
import com.dataiku.dip.analysis.ml.distributed.workers.Worker;
import com.dataiku.dip.analysis.ml.distributed.workers.WorkerFactory;
import com.dataiku.dip.analysis.ml.distributed.workers.WorkerInfos;
import com.dataiku.dip.analysis.ml.distributed.workers.WorkerStatus;
import com.dataiku.dip.dataflow.jobrunner.JobContext;
import com.dataiku.dip.distributed.metrics.ContainerUsageMetrics;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.resourceusage.CurrentComputeResourceUsageContext;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.DKULogger;
import com.google.common.base.Preconditions;
import java.io.Closeable;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import javax.annotation.Nullable;

/**
 * A named pool of {@link Worker}s that are created on demand through a {@link WorkerFactory}.
 *
 * <p>Workers are addressed by IDs that must match {@link #ID_PATTERN}: at least two characters, lowercase
 * alphanumerics at both ends, with {@code _} or {@code -} also allowed in between. Once a worker ID has
 * been released it can never be re-acquired from this pool.
 *
 * <p>Thread-safety: every access to the internal worker map and released-ID set happens while holding
 * this instance's monitor. Worker startup runs on a dedicated thread per worker, and the pool's ML-task,
 * job and compute-resource-usage contexts (when non-null) are installed in that thread first.
 */
public class WorkerPool implements Closeable {
    private static final Pattern ID_PATTERN = Pattern.compile("^[a-z0-9][a-z0-9_-]*[a-z0-9]$");
    public final String workerPoolId;
    @Nullable
    public final String projectKey;
    public final String envName;
    public final AuthCtx authCtx;
    @Nullable
    private final MLBaseService.MLTaskWorkThread.MLTaskContext mlTaskContext;
    @Nullable
    private final JobContext jobContext;
    @Nullable
    public final ComputeResourceUsageContext cruContext;
    /** Live (acquired, not yet released) workers keyed by worker ID. Guarded by {@code this}. */
    private final Map<String, Worker> workers = new HashMap<>();
    /** IDs that have been released; they are counted as dead and cannot be re-acquired. Guarded by {@code this}. */
    private final Set<String> releasedWorkerIds = new HashSet<>();
    /** Optional hook run once when the pool is closed; may be {@code null}. */
    private final Runnable closeCallback;
    private final WorkerFactory workerFactory;
    /** Guarded by {@code this}. */
    private boolean closed;
    public final boolean impersonated;
    protected static final DKULogger logger = DKULogger.getLogger("com.dataiku.dip.analysis.ml.distributed.workers");

    /**
     * Creates a new, open worker pool.
     *
     * @param projectKey    owning project key, may be {@code null}
     * @param envName       environment name, stored as-is
     * @param authCtx       authentication context; must not be {@code null}
     * @param workerPoolId  pool identifier; must match {@link #ID_PATTERN}
     * @param cruContext    compute-resource-usage context installed in worker start threads, may be {@code null}
     * @param mlTaskContext ML task context installed in worker start threads, may be {@code null}
     * @param jobContext    job context installed in worker start threads, may be {@code null}
     * @param closeCallback hook run when the pool is closed, may be {@code null}
     * @param workerFactory factory used to instantiate workers on first request
     * @param impersonated  flag stored on the pool and exposed as {@link #impersonated}
     * @throws NullPointerException     if {@code authCtx} is {@code null}
     * @throws IllegalArgumentException if {@code workerPoolId} is not a valid ID
     */
    public WorkerPool(@Nullable String projectKey, String envName, AuthCtx authCtx, String workerPoolId, @Nullable ComputeResourceUsageContext cruContext, @Nullable MLBaseService.MLTaskWorkThread.MLTaskContext mlTaskContext, @Nullable JobContext jobContext, Runnable closeCallback, WorkerFactory workerFactory, boolean impersonated) {
        this.projectKey = projectKey;
        this.envName = envName;
        this.authCtx = Preconditions.checkNotNull(authCtx, "authCtx");
        this.workerPoolId = WorkerPool.validateId(workerPoolId);
        this.closeCallback = closeCallback;
        this.workerFactory = workerFactory;
        this.cruContext = cruContext;
        this.jobContext = jobContext;
        this.mlTaskContext = mlTaskContext;
        this.impersonated = impersonated;
    }

    /**
     * Instantiates a worker through the factory and starts it asynchronously on a dedicated thread.
     * Startup failures are logged, not propagated.
     */
    private Worker createWorker(String workerId) throws IOException, CodedException, DKUSecurityException {
        final Worker worker = this.workerFactory.createWorker(workerId, this);
        Thread starter = new Thread(() -> {
            try {
                // The start thread does not inherit the caller's thread-local contexts, so install
                // the ones captured by this pool before starting the worker.
                if (this.mlTaskContext != null) {
                    MLBaseService.MLTaskWorkThread.mlTaskContext.set(this.mlTaskContext);
                }
                if (this.jobContext != null) {
                    JobContext.setJobContext(this.jobContext);
                }
                if (this.cruContext != null) {
                    CurrentComputeResourceUsageContext.setInCurrentThread(this.cruContext);
                }
                worker.start();
            } catch (Exception e) {
                logger.error("Worker " + worker.getId() + " failed to start", e);
            }
        }, "worker-pool-" + this.workerPoolId + "-start-" + workerId);
        starter.start();
        logger.info("Created " + worker.workerType + " worker " + worker.getId());
        return worker;
    }

    /** @return this pool's identifier */
    public String getId() {
        return this.workerPoolId;
    }

    /**
     * Returns the infos of the worker with the given ID, creating and starting it if it does not exist yet.
     *
     * @param workerId worker identifier; must match {@link #ID_PATTERN}
     * @return the worker's infos, or a {@code DEAD} infos with {@code null} fields if the ID was already released
     * @throws IllegalStateException    if the pool is closed
     * @throws IllegalArgumentException if {@code workerId} is not a valid ID
     */
    public synchronized WorkerInfos requestWorker(String workerId) throws IOException, CodedException, DKUSecurityException {
        this.assertIsNotClosed();
        WorkerPool.validateId(workerId);
        if (this.releasedWorkerIds.contains(workerId)) {
            logger.error("Can't acquire released worker " + workerId);
            return new WorkerInfos(WorkerStatus.DEAD, null, null, null);
        }
        Worker worker = this.workers.get(workerId);
        if (worker == null) {
            worker = this.createWorker(workerId);
            this.workers.put(workerId, worker);
        }
        return worker.getWorkerInfos();
    }

    /**
     * Marks the given worker ID as released and closes the worker if it exists. Releasing an
     * already-released ID is a no-op.
     *
     * @param workerId worker identifier; must match {@link #ID_PATTERN}
     * @throws IllegalStateException    if the pool is closed
     * @throws IllegalArgumentException if {@code workerId} is not a valid ID
     */
    public synchronized void releaseWorker(String workerId) {
        this.assertIsNotClosed();
        WorkerPool.validateId(workerId);
        if (!this.releasedWorkerIds.add(workerId)) {
            return;
        }
        // Remove before closing so the map never holds a released worker, even if close() throws.
        Worker worker = this.workers.remove(workerId);
        if (worker != null) {
            worker.close();
        }
        logger.info("Release worker " + workerId);
    }

    /**
     * Counts live workers by status, and counts every released ID as dead.
     *
     * <p>Synchronized because it iterates collections that {@link #requestWorker} and
     * {@link #releaseWorker} mutate under the same lock.
     *
     * @return a fresh metrics object holding the counts
     */
    public synchronized ContainerUsageMetrics getWorkersUsageMetrics() {
        ContainerUsageMetrics remoteWorkersStatusCount = new ContainerUsageMetrics(0, 0, 0, true);
        for (Map.Entry<String, Worker> entry : this.workers.entrySet()) {
            String workerId = entry.getKey();
            WorkerStatus status = entry.getValue().getWorkerInfos().status;
            switch (status) {
                case DEAD:
                    // Released IDs are counted once, in the loop below.
                    if (!this.releasedWorkerIds.contains(workerId)) {
                        remoteWorkersStatusCount.incrementDead();
                    }
                    break;
                case PENDING:
                    remoteWorkersStatusCount.incrementPending();
                    break;
                case READY:
                    remoteWorkersStatusCount.incrementReady();
                    break;
                default:
                    break;
            }
        }
        for (int i = 0; i < this.releasedWorkerIds.size(); i++) {
            remoteWorkersStatusCount.incrementDead();
        }
        return remoteWorkersStatusCount;
    }

    /** @throws IllegalStateException if the pool has been closed */
    private void assertIsNotClosed() {
        if (this.closed) {
            throw new IllegalStateException("Worker pool " + this.workerPoolId + " is closed");
        }
    }

    /**
     * @return {@code id} unchanged if it matches {@link #ID_PATTERN}
     * @throws IllegalArgumentException if {@code id} is {@code null} or does not match
     */
    private static String validateId(String id) {
        if (id == null || !ID_PATTERN.matcher(id).matches()) {
            throw new IllegalArgumentException("Invalid ID: " + id);
        }
        return id;
    }

    /**
     * Closes the pool. The first call does the following; subsequent calls return immediately.
     * <ol>
     *   <li>Marks the pool closed and empties the worker map.</li>
     *   <li>Runs the close callback, if any.</li>
     *   <li>Closes every worker that was still held.</li>
     * </ol>
     *
     * <p>All workers are closed even if the callback or an individual worker's {@code close()} throws.
     * The first {@link RuntimeException} is rethrown afterwards, with any later ones attached as
     * suppressed exceptions.
     */
    @Override
    public void close() {
        Worker[] workersToClose;
        synchronized (this) {
            if (this.closed) {
                return;
            }
            this.closed = true;
            // Snapshot and clear under the lock so concurrent readers never see the map mid-mutation;
            // the actual close calls happen outside the lock, as before.
            workersToClose = this.workers.values().toArray(new Worker[0]);
            this.workers.clear();
        }
        logger.info("Closing worker pool " + this.workerPoolId);
        RuntimeException failure = null;
        if (this.closeCallback != null) {
            try {
                this.closeCallback.run();
            } catch (RuntimeException e) {
                failure = e;
            }
        }
        for (Worker worker : workersToClose) {
            try {
                worker.close();
            } catch (RuntimeException e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }
}

