/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.kernel;

import com.codahale.metrics.Reservoir;
import com.codahale.metrics.Snapshot;
import com.codahale.metrics.UniformSnapshot;
import com.dataiku.dip.kernel.KernelPool;
import com.dataiku.dip.kernel.KernelScalingStrategy;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dss.shadelib.com.google.common.collect.LinkedHashMultimap;
import com.google.common.annotations.VisibleForTesting;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class MovingAverageCapacityStrategy
implements KernelScalingStrategy {
    @VisibleForTesting
    final Map<String, SlidingWindowOrderedReservoir> queuedPerHashMA = new HashMap<String, SlidingWindowOrderedReservoir>();
    @VisibleForTesting
    final Map<String, SlidingWindowOrderedReservoir> activePerHashMA = new HashMap<String, SlidingWindowOrderedReservoir>();
    private static final int LOGGING_FREQUENCY = 30;
    private int logDebugCounter = 1;
    private final DKULogger logger;

    public MovingAverageCapacityStrategy(DKULogger logger) {
        this.logger = logger;
    }

    @VisibleForTesting
    static double calculateRatio(long lastTick, long currentTick, long arrivedAt, long completedAt) {
        assert (lastTick < currentTick);
        long start = Math.max(lastTick, arrivedAt);
        long end = Math.min(currentTick, completedAt);
        double ratio = (double)Math.max(0L, end - start) / (double)(currentTick - lastTick);
        return ratio;
    }

    private Set<String> measureCapacityAndUtilization(List<KernelPool.RequestSnapshot> requests, LinkedHashMultimap<String, KernelPool.KernelSnapshot> kernels, int autoscaleTimeWindow, long lastTick, long currentTick) {
        HashSet<String> allHashes = new HashSet<String>();
        requests.forEach(r -> allHashes.add(r.kernelDescHash));
        allHashes.addAll(kernels.keySet());
        this.queuedPerHashMA.keySet().removeIf(hash -> !allHashes.contains(hash));
        this.activePerHashMA.keySet().removeIf(hash -> !allHashes.contains(hash));
        HashMap<String, Double> queuePerHash = new HashMap<String, Double>();
        HashMap<String, Double> activePerHash = new HashMap<String, Double>();
        for (KernelPool.RequestSnapshot request : requests) {
            double ratio = MovingAverageCapacityStrategy.calculateRatio(lastTick, currentTick, request.arrivedAtTime, request.state == KernelPool.RequestState.COMPLETED ? request.completedAtTime : currentTick);
            if (request.state == KernelPool.RequestState.QUEUED) {
                MovingAverageCapacityStrategy.incrementKey(queuePerHash, request.kernelDescHash, ratio);
                continue;
            }
            if (request.state != KernelPool.RequestState.DISPATCHED && request.state != KernelPool.RequestState.COMPLETED) continue;
            MovingAverageCapacityStrategy.incrementKey(activePerHash, request.kernelDescHash, ratio);
        }
        for (String hash2 : allHashes) {
            double queueValue = queuePerHash.getOrDefault(hash2, 0.0);
            double activeValue = activePerHash.getOrDefault(hash2, 0.0);
            int queuedUpdate = (int)(queueValue > 0.0 ? Math.max(queueValue, 1.0) : queueValue);
            int activeUpdate = (int)(activeValue > 0.0 ? Math.max(activeValue, 1.0) : activeValue);
            MovingAverageCapacityStrategy.updateMovingAverages(this.queuedPerHashMA, hash2, queuedUpdate, autoscaleTimeWindow);
            MovingAverageCapacityStrategy.updateMovingAverages(this.activePerHashMA, hash2, activeUpdate, autoscaleTimeWindow);
        }
        if (this.logger.isDebugEnabled() || this.logger.isTraceEnabled()) {
            for (String hash2 : allHashes) {
                this.log(this.buildMessage(kernels, queuePerHash, activePerHash, hash2));
            }
        }
        return allHashes;
    }

    private void log(String message) {
        if (this.logger.isTraceEnabled()) {
            this.logger.trace((Object)message);
        } else if (this.logger.isDebugEnabled() && this.logDebugCounter % 30 == 0) {
            this.logger.debug((Object)message);
        }
    }

    private void incrementDebugCounter() {
        if (this.logger.isDebugEnabled()) {
            this.logDebugCounter = this.logDebugCounter % 30 == 0 ? 1 : ++this.logDebugCounter;
        }
    }

    private String buildMessage(LinkedHashMultimap<String, KernelPool.KernelSnapshot> kernels, Map<String, Double> queuePerHash, Map<String, Double> activePerHash, String hash) {
        long totalCapacity = kernels.get((Object)hash).stream().filter(k -> k.state == KernelPool.KernelState.READY || k.state == KernelPool.KernelState.STARTING).mapToInt(k -> k.maxRequests).sum();
        return "[" + hash + "] queue=" + String.valueOf(queuePerHash.getOrDefault(hash, 0.0)) + " (~" + (double)Math.round(10.0 * this.queuedPerHashMA.get(hash).getSnapshot().getMean()) / 10.0 + "), active=" + String.valueOf(activePerHash.getOrDefault(hash, 0.0)) + "/" + totalCapacity + " (~" + (double)Math.round(10.0 * this.activePerHashMA.get(hash).getSnapshot().getMean()) / 10.0 + "), kernels=" + kernels.get((Object)hash).stream().map(k -> k.state.toString() + "(" + k.nbProcessingRequests + ")").collect(Collectors.joining(","));
    }

    @Override
    public synchronized KernelScalingStrategy.ScalingResult execute(List<KernelPool.RequestSnapshot> requests, LinkedHashMultimap<String, KernelPool.KernelSnapshot> kernels, int autoscaleTimeWindow, long lastTick, long currentTick) {
        Set<String> allHashes = this.measureCapacityAndUtilization(requests, kernels, autoscaleTimeWindow, lastTick, currentTick);
        HashSet<String> toScaleUp = new HashSet<String>();
        HashSet<String> toScaleDown = new HashSet<String>();
        for (String hash : allHashes) {
            SlidingWindowOrderedReservoir queuedMA = this.queuedPerHashMA.get(hash);
            SlidingWindowOrderedReservoir activeMA = this.activePerHashMA.get(hash);
            double todoMAValues = queuedMA == null ? 0.0 : queuedMA.getSnapshot().getMean();
            double totalCapacity = kernels.get((Object)hash).stream().filter(k -> k.state == KernelPool.KernelState.READY || k.state == KernelPool.KernelState.STARTING).mapToInt(k -> k.maxRequests).sum();
            double activeMAValues = activeMA == null ? 0.0 : activeMA.getSnapshot().getMean();
            KernelPool.KernelSnapshot typicalKernel = kernels.get((Object)hash).stream().findFirst().orElse(null);
            double totalCapacityIfScaledDown = typicalKernel == null ? 0.0 : Math.max(0.0, totalCapacity - (double)typicalKernel.maxRequests);
            this.log("[" + hash + "] Total moving avg=" + (todoMAValues + activeMAValues) + ", Total capacity=" + totalCapacity + ", Total capacity if scaled down=" + totalCapacityIfScaledDown + ", Capacity threshold to scale down=" + totalCapacityIfScaledDown * 0.8);
            if (todoMAValues + activeMAValues > totalCapacity) {
                toScaleUp.add(hash);
                continue;
            }
            if (!(todoMAValues + activeMAValues <= totalCapacityIfScaledDown * 0.8)) continue;
            toScaleDown.add(hash);
        }
        this.incrementDebugCounter();
        KernelScalingStrategy.ScalingResult result = new KernelScalingStrategy.ScalingResult();
        result.up = toScaleUp;
        result.down = toScaleDown;
        return result;
    }

    private static void incrementKey(Map<String, Double> map, String key, double value) {
        map.put(key, map.getOrDefault(key, 0.0) + value);
    }

    @VisibleForTesting
    static void updateMovingAverages(Map<String, SlidingWindowOrderedReservoir> map, String key, int value, int autoscaleTimeWindow) {
        SlidingWindowOrderedReservoir ma = map.get(key);
        if (ma == null) {
            ma = new SlidingWindowOrderedReservoir(autoscaleTimeWindow);
            map.put(key, ma);
        } else if (ma.size() != autoscaleTimeWindow) {
            ma.resize(autoscaleTimeWindow);
        }
        ma.update(value);
    }

    static class SlidingWindowOrderedReservoir
    implements Reservoir {
        private long[] measurements;
        private int lastInsertedIndex;

        public SlidingWindowOrderedReservoir(int size) {
            this.measurements = new long[size];
            this.lastInsertedIndex = size - 1;
        }

        public synchronized int size() {
            return this.measurements.length;
        }

        public synchronized void update(long value) {
            int insertIndex = (this.lastInsertedIndex + 1) % this.measurements.length;
            this.measurements[insertIndex] = value;
            this.lastInsertedIndex = insertIndex;
        }

        public synchronized Snapshot getSnapshot() {
            return new UniformSnapshot(this.measurements);
        }

        public synchronized void resize(int size) {
            long[] newMeasurements = new long[size];
            int valuesCount = Math.min(size, this.measurements.length);
            int startIndex = (this.measurements.length + this.lastInsertedIndex - valuesCount + 1) % this.measurements.length;
            for (int i = 0; i < valuesCount; ++i) {
                newMeasurements[i] = this.measurements[(startIndex + i) % this.measurements.length];
            }
            this.measurements = newMeasurements;
            this.lastInsertedIndex = valuesCount - 1;
        }
    }
}

