/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.kernel;

import com.codahale.metrics.Reservoir;
import com.codahale.metrics.Snapshot;
import com.codahale.metrics.UniformSnapshot;
import com.dataiku.dip.kernel.KernelPool;
import com.dataiku.dip.kernel.KernelScalingStrategy;
import com.dataiku.dip.utils.DKULogger;
import com.google.common.annotations.VisibleForTesting;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.math3.util.Precision;

public class MovingAverageCapacityStrategy
implements KernelScalingStrategy {
    @VisibleForTesting
    final SlidingWindowOrderedReservoir queuedMA = new SlidingWindowOrderedReservoir(1);
    @VisibleForTesting
    final SlidingWindowOrderedReservoir activeMA = new SlidingWindowOrderedReservoir(1);
    private static final int LOGGING_FREQUENCY = 30;
    private final String kernelDescHash;
    private int logDebugCounter = 1;
    private final DKULogger logger;

    public MovingAverageCapacityStrategy(String kernelDescHash, DKULogger logger) {
        this.logger = logger;
        this.kernelDescHash = kernelDescHash;
    }

    @VisibleForTesting
    static double calculateRatio(long lastTick, long currentTick, long arrivedAt, long completedAt) {
        assert (lastTick < currentTick);
        long start = Math.max(lastTick, arrivedAt);
        long end = Math.min(currentTick, completedAt);
        return (double)Math.max(0L, end - start) / (double)(currentTick - lastTick);
    }

    private void measureCapacityAndUtilization(List<KernelPool.RequestSnapshot> requests, List<KernelPool.KernelSnapshot> kernels, int autoscaleTimeWindow, long lastTick, long currentTick) {
        double nbQueuedRequests = 0.0;
        double nbActiveRequests = 0.0;
        for (KernelPool.RequestSnapshot request : requests) {
            double ratio = MovingAverageCapacityStrategy.calculateRatio(lastTick, currentTick, request.arrivedAtTime, request.state == KernelPool.RequestState.COMPLETED ? request.completedAtTime : currentTick);
            if (request.state == KernelPool.RequestState.QUEUED) {
                nbQueuedRequests += ratio;
                continue;
            }
            if (request.state != KernelPool.RequestState.DISPATCHED && request.state != KernelPool.RequestState.COMPLETED) continue;
            nbActiveRequests += ratio;
        }
        int queuedUpdate = (int)(nbQueuedRequests > 0.0 ? Math.max(nbQueuedRequests, 1.0) : nbQueuedRequests);
        int activeUpdate = (int)(nbActiveRequests > 0.0 ? Math.max(nbActiveRequests, 1.0) : nbActiveRequests);
        MovingAverageCapacityStrategy.updateMovingAverages(this.queuedMA, queuedUpdate, autoscaleTimeWindow);
        MovingAverageCapacityStrategy.updateMovingAverages(this.activeMA, activeUpdate, autoscaleTimeWindow);
        if (this.logger.isDebugEnabled() || this.logger.isTraceEnabled()) {
            this.log(this.buildMessage(kernels, nbQueuedRequests, nbActiveRequests));
        }
    }

    private void log(String message) {
        if (this.logger.isTraceEnabled()) {
            this.logger.trace((Object)message);
        } else if (this.logger.isDebugEnabled() && this.logDebugCounter % 30 == 0) {
            this.logger.debug((Object)message);
        }
    }

    private void incrementDebugCounter() {
        if (this.logger.isDebugEnabled()) {
            this.logDebugCounter = this.logDebugCounter % 30 == 0 ? 1 : ++this.logDebugCounter;
        }
    }

    private String buildMessage(List<KernelPool.KernelSnapshot> kernels, double queuePerHash, double activePerHash) {
        long totalCapacity = kernels.stream().filter(k -> k.state == KernelPool.KernelState.READY || k.state == KernelPool.KernelState.STARTING).mapToInt(k -> k.maxRequests).sum();
        return "[" + this.kernelDescHash + "] queue=" + queuePerHash + " (~" + (double)Math.round(10.0 * this.queuedMA.getSnapshot().getMean()) / 10.0 + "), active=" + activePerHash + "/" + totalCapacity + " (~" + (double)Math.round(10.0 * this.activeMA.getSnapshot().getMean()) / 10.0 + "), kernels=" + kernels.stream().map(k -> k.state.toString() + "(" + k.nbProcessingRequests + ")").collect(Collectors.joining(","));
    }

    @Override
    public synchronized int execute(List<KernelPool.RequestSnapshot> requests, List<KernelPool.KernelSnapshot> kernels, int autoscaleTimeWindow, long lastTick, long currentTick) {
        this.measureCapacityAndUtilization(requests, kernels, autoscaleTimeWindow, lastTick, currentTick);
        double todoMAValues = this.queuedMA.getSnapshot().getMean();
        long totalCapacity = kernels.stream().filter(k -> k.state == KernelPool.KernelState.READY || k.state == KernelPool.KernelState.STARTING).mapToInt(k -> k.maxRequests).sum();
        double activeMAValues = this.activeMA.getSnapshot().getMean();
        KernelPool.KernelSnapshot typicalKernel = kernels.stream().findFirst().orElse(null);
        long totalCapacityIfScaledDown = typicalKernel == null ? 0L : Math.max(0L, totalCapacity - (long)typicalKernel.maxRequests);
        this.log("[" + this.kernelDescHash + "] Total moving avg=" + Precision.round((double)(todoMAValues + activeMAValues), (int)2) + ", Total capacity=" + totalCapacity + ", Total capacity if scaled down=" + totalCapacityIfScaledDown + ", Capacity threshold to scale down=" + Precision.round((double)((double)totalCapacityIfScaledDown * 0.8), (int)1));
        this.incrementDebugCounter();
        if (todoMAValues + activeMAValues > (double)totalCapacity) {
            return 1;
        }
        if (todoMAValues + activeMAValues <= (double)totalCapacityIfScaledDown * 0.8) {
            return -1;
        }
        return 0;
    }

    @VisibleForTesting
    static void updateMovingAverages(SlidingWindowOrderedReservoir ma, int value, int autoscaleTimeWindow) {
        if (ma.size() != autoscaleTimeWindow) {
            ma.resize(autoscaleTimeWindow);
        }
        ma.update(value);
    }

    static class SlidingWindowOrderedReservoir
    implements Reservoir {
        private long[] measurements;
        private int lastInsertedIndex;

        public SlidingWindowOrderedReservoir(int size) {
            this.measurements = new long[size];
            this.lastInsertedIndex = size - 1;
        }

        public synchronized int size() {
            return this.measurements.length;
        }

        public synchronized void update(long value) {
            int insertIndex = (this.lastInsertedIndex + 1) % this.measurements.length;
            this.measurements[insertIndex] = value;
            this.lastInsertedIndex = insertIndex;
        }

        public synchronized Snapshot getSnapshot() {
            return new UniformSnapshot(this.measurements);
        }

        public synchronized void resize(int size) {
            long[] newMeasurements = new long[size];
            int valuesCount = Math.min(size, this.measurements.length);
            int startIndex = (this.measurements.length + this.lastInsertedIndex - valuesCount + 1) % this.measurements.length;
            for (int i = 0; i < valuesCount; ++i) {
                newMeasurements[i] = this.measurements[(startIndex + i) % this.measurements.length];
            }
            this.measurements = newMeasurements;
            this.lastInsertedIndex = valuesCount - 1;
        }
    }
}

