/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.gh.core.services.python_execution;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.kernel.KernelPool;
import com.dataiku.dip.utils.DKUCompletableFuture;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.ExceptionUtils;
import com.dataiku.dip.utils.SmartLogTail;
import com.dataiku.gh.core.services.python_execution.GovernPythonExecutionClient;
import com.dataiku.gh.core.services.python_execution.IPythonExecutionKernelPool;
import com.dataiku.gh.core.services.python_execution.PythonExecutionRequest;
import com.dataiku.gh.core.services.python_execution.PythonExecutionResponse;
import com.dataiku.gh.core.services.python_execution.artifact_action.ArtifactActionScriptRequest;
import com.dataiku.gh.core.services.python_execution.autogovernance_script.AutoGovernanceScriptRequest;
import com.dataiku.gh.core.services.python_execution.instance_action.ActionScriptRequest;
import com.dataiku.gh.core.services.python_execution.logical_hooks.LogicalHookRequest;
import com.dataiku.gh.core.services.python_execution.migration_paths.MigrationPathRequest;
import com.google.common.collect.EvictingQueue;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.springframework.stereotype.Service;

/**
 * Executes Govern Python scripts on pooled Python kernels.
 *
 * <p>Each supported request class (migration paths, auto-governance scripts, instance actions,
 * artifact actions, logical hooks) is mapped to a {@link KernelType}. Kernels are grouped by that
 * type inside a {@link KernelPool} configured for at most one parallel request per kernel, and the
 * pool is asked to kill a kernel whenever a request on it fails. Kernel start/stop work runs on a
 * dedicated cached thread pool.
 *
 * <p>Limits and timeouts default to the constants below and to the per-type defaults in
 * {@link KernelType}; they can be overridden through {@code dku.govern.python-execution.*}
 * parameters.
 */
@Service
public class PythonExecutionKernelPool
implements IPythonExecutionKernelPool {
    // Declared first: the pool initializer below passes it to KernelPool.
    private static final DKULogger logger = DKULogger.getLogger("gh.services.python-execution.kernel-pool");
    /** Default for {@code dku.govern.python-execution.globalMaxKernelCount}. */
    private static final int GLOBAL_MAX_KERNEL_COUNT = 50;
    /** Default for {@code dku.govern.python-execution.queuedRequestTimeoutMs}. */
    private static final long QUEUED_REQUEST_TIMEOUT_MS = 30000L;
    /** Safety bound on how many causes are walked when unwrapping asynchronous failures. */
    private static final int MAX_CAUSE_DEPTH = 16;

    private final ExecutorService executorService = Executors.newCachedThreadPool(new ThreadFactoryBuilder().setNameFormat("python-execution-kernel-pool-startstop-%d").build());
    private final KernelPool<GovernPythonExecutionClient, KernelDesc, KernelDesc> manager = new KernelPool<>(new KernelPool.KernelController<GovernPythonExecutionClient, KernelDesc, KernelDesc>(){

        @Override
        @Nonnull
        public GovernPythonExecutionClient createKernel(KernelDesc kernelDesc) {
            return new GovernPythonExecutionClient(kernelDesc.kernelType.kernelTypeId);
        }

        @Override
        @Nonnull
        public CompletableFuture<Void> startKernel(GovernPythonExecutionClient kernel, KernelDesc kernelDesc) {
            return DKUCompletableFuture.runAsync(kernel::init, (Executor)PythonExecutionKernelPool.this.executorService);
        }

        @Override
        public int getGlobalMaxKernelCount() {
            return KernelType.getGlobalMaxKernelCount();
        }

        @Override
        public int getAutoscaleTimeWindowSeconds(KernelDesc kernelDesc) {
            return kernelDesc.kernelType.getAutoscaleTimeWindowSeconds();
        }

        // Each kernel handles a single request at a time.
        @Override
        public int getHardMaxParallelRequests(KernelDesc kernelDesc) {
            return 1;
        }

        @Override
        public int getSoftMaxParallelRequests(KernelDesc kernelDesc) {
            return 1;
        }

        @Override
        @Nonnull
        public CompletableFuture<Void> killKernel(GovernPythonExecutionClient kernel) {
            return DKUCompletableFuture.runAsync(kernel::close, (Executor)PythonExecutionKernelPool.this.executorService);
        }

        @Override
        public boolean isAlive(GovernPythonExecutionClient kernel) {
            return kernel.getKernel() != null && kernel.getKernel().isAlive();
        }

        @Override
        public String getKernelId(GovernPythonExecutionClient kernel) {
            return kernel.getKernelId();
        }

        @Override
        public Long getQueuedRequestTimeoutInNs() {
            // toNanos saturates instead of silently overflowing on absurd configured values.
            return TimeUnit.MILLISECONDS.toNanos(KernelType.getQueuedRequestTimeoutMs());
        }

        /** Declares, for every kernel type, how many kernels the pool should keep reserved. */
        @Override
        @Nonnull
        public Collection<KernelPool.ReservedCapacityRequest<KernelDesc, KernelDesc>> getKernelsWithMinCount() {
            KernelType[] types = KernelType.values();
            ArrayList<KernelPool.ReservedCapacityRequest<KernelDesc, KernelDesc>> kernelsWithMinCount = new ArrayList<>(types.length);
            for (KernelType type : types) {
                KernelDesc kernelDesc = KernelDesc.build(type);
                kernelsWithMinCount.add(new KernelPool.ReservedCapacityRequest<>(type.getMinKernelCount(), kernelDesc, type.name(), kernelDesc, type.name()));
            }
            return kernelsWithMinCount;
        }

        @Override
        public Integer getMaxKernelCount(KernelDesc kernelGroup) {
            return kernelGroup.kernelType.getMaxKernelCount();
        }

        @Override
        public boolean killKernelOnRequestFailure() {
            return true;
        }

        @Override
        @Nullable
        public SmartLogTail getKernelLog(GovernPythonExecutionClient kernel) {
            // getKernel() may be null (see isAlive), e.g. when the kernel never started.
            if (kernel == null || kernel.getKernel() == null) {
                return null;
            }
            return kernel.getKernel().getSmartLogTailBuilder().get();
        }
    }, "python-execution", logger);

    /**
     * Runs a single request on a kernel of the type mapped to the request's class and waits for
     * its response.
     *
     * <p>The execution timeout is computed when the request reaches a kernel (see
     * {@link KernelType#computeEffectiveExecutionTimeoutMs()}). Whether the request timed out is
     * then recorded asynchronously to feed that computation for later requests.
     *
     * @param requestSupplier produces the request to execute
     * @param responseClass   class the kernel response is deserialized into
     * @return the kernel response
     * @throws IOException if execution is interrupted, times out or fails
     * @throws IllegalArgumentException if the request class is not mapped to a kernel type
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <REQ extends PythonExecutionRequest, RESP extends PythonExecutionResponse> RESP executePythonScript(Supplier<REQ> requestSupplier, Class<RESP> responseClass) throws IOException {
        REQ request = requestSupplier.get();
        KernelDesc kernelDesc = KernelDesc.build(request.getClass());
        KernelType kernelType = kernelDesc.kernelType;
        CompletableFuture<?> future = this.manager.handle(
                kernel -> kernel.getKernel().executeAsyncIO(request, responseClass)
                        .orTimeout(kernelType.computeEffectiveExecutionTimeoutMs(), TimeUnit.MILLISECONDS),
                kernelDesc, kernelType.name(), request);
        // Failures of dependent stages arrive wrapped in CompletionException, so unwrap before
        // checking; otherwise timeouts would go unnoticed by the recursion detector.
        future.whenCompleteAsync((result, failure) -> kernelType.recordExecutionOutcome(isTimeout(failure)), this.executorService);
        try {
            return (RESP)((PythonExecutionResponse)DKUCompletableFuture.collectResponse((CompletableFuture)future));
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException("Unexpected error (python execution was interrupted). Ask your administrator to investigate.", e);
        }
        catch (TimeoutException e) {
            throw new IOException("Error while performing the action (timed out). Ask your administrator to investigate.", e);
        }
        catch (Exception e) {
            throw new IOException("Unexpected error (python request failed exceptionally). Ask your administrator to investigate.", e);
        }
    }

    /**
     * Runs the requests produced by {@code requestsSupplier} one after the other on a single
     * kernel, until the supplier returns {@code null}.
     *
     * <p>After each response, {@code postResponseConsumer} receives the request/response pair
     * before the next request is pulled from the supplier.
     *
     * @return the last response; the underlying future completes with {@code null} when the
     *         supplier yields no request at all
     * @throws IOException an {@link IOException} raised by the supplier or the consumer (even when
     *                     it surfaces wrapped by the asynchronous pipeline), or a generic one if
     *                     execution is interrupted, times out or fails
     */
    @Override
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <REQ extends PythonExecutionRequest, RESP extends PythonExecutionResponse> RESP executePythonScriptsSequentially(Class<REQ> requestClass, ExceptionUtils.ThrowingSupplier<REQ, IOException> requestsSupplier, Class<RESP> responseClass, ExceptionUtils.ThrowingConsumer<IPythonExecutionKernelPool.RequestResponseData<REQ, RESP>, IOException> postResponseConsumer, Object executionLoggingObject) throws IOException {
        KernelDesc kernelDesc = KernelDesc.build(requestClass);
        CompletableFuture<?> future = this.manager.handle(
                kernel -> this.buildNestedSequentialFuture(null, requestsSupplier, kernel, responseClass, kernelDesc.kernelType, postResponseConsumer),
                kernelDesc, kernelDesc.kernelType.name(), executionLoggingObject);
        try {
            return (RESP)((PythonExecutionResponse)DKUCompletableFuture.collectResponse((CompletableFuture)future));
        }
        catch (UncheckedIOException e) {
            throw e.getCause();
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IOException("Unexpected error (python execution was interrupted). Ask your administrator to investigate.", e);
        }
        catch (TimeoutException e) {
            throw new IOException("Error while performing the action (timed out). Ask your administrator to investigate.", e);
        }
        catch (Exception e) {
            // IOExceptions from the supplier/consumer are tunnelled as UncheckedIOException through
            // thenCompose and therefore reach us wrapped (e.g. in CompletionException).
            IOException wrappedIOException = findWrappedIOException(e);
            if (wrappedIOException != null) {
                throw wrappedIOException;
            }
            throw new IOException("Unexpected error (python request failed exceptionally). Ask your administrator to investigate.", e);
        }
    }

    /**
     * Chains the next request from {@code requestsSupplier} after the current one, recursively,
     * on the same kernel.
     *
     * <p>Checked {@link IOException}s from the supplier or consumer are rethrown as
     * {@link UncheckedIOException} so they can cross the lambda boundary.
     */
    private <REQ extends PythonExecutionRequest, RESP extends PythonExecutionResponse> CompletableFuture<RESP> buildNestedSequentialFuture(@Nullable RESP currentResponse, ExceptionUtils.ThrowingSupplier<REQ, IOException> requestsSupplier, GovernPythonExecutionClient kernelHandle, Class<RESP> responseClass, KernelType kernelType, ExceptionUtils.ThrowingConsumer<IPythonExecutionKernelPool.RequestResponseData<REQ, RESP>, IOException> postResponseConsumer) {
        REQ request;
        try {
            request = requestsSupplier.get();
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        if (request == null) {
            // Supplier exhausted: the last response is the overall result.
            return CompletableFuture.completedFuture(currentResponse);
        }
        return kernelHandle.getKernel().executeAsyncIO(request, responseClass)
                .orTimeout(kernelType.computeEffectiveExecutionTimeoutMs(), TimeUnit.MILLISECONDS)
                .thenCompose(newResponse -> {
                    try {
                        postResponseConsumer.accept(new IPythonExecutionKernelPool.RequestResponseData<>(request, newResponse));
                    }
                    catch (IOException e) {
                        throw new UncheckedIOException(e);
                    }
                    return this.buildNestedSequentialFuture(newResponse, requestsSupplier, kernelHandle, responseClass, kernelType, postResponseConsumer);
                });
    }

    /**
     * Tells whether an asynchronous failure is a timeout, looking through the
     * {@link CompletionException} wrappers added by dependent {@link CompletableFuture} stages.
     */
    private static boolean isTimeout(@Nullable Throwable failure) {
        Throwable current = failure;
        for (int depth = 0; current instanceof CompletionException && current.getCause() != null && depth < MAX_CAUSE_DEPTH; depth++) {
            current = current.getCause();
        }
        return current instanceof TimeoutException;
    }

    /**
     * Returns the {@link IOException} carried by the first {@link UncheckedIOException} found in
     * the cause chain of {@code failure}, or {@code null} if there is none.
     */
    @Nullable
    private static IOException findWrappedIOException(@Nullable Throwable failure) {
        Throwable current = failure;
        for (int depth = 0; current != null && depth < MAX_CAUSE_DEPTH; depth++) {
            if (current instanceof UncheckedIOException) {
                return ((UncheckedIOException)current).getCause();
            }
            current = current.getCause();
        }
        return null;
    }

    @Override
    public KernelPool.PoolDump dump(boolean full) {
        return this.manager.dump(full);
    }

    @Override
    public void killAllKernels(KernelPool.DeathReason deathReason) {
        this.manager.killAllKernels(deathReason);
    }

    @Override
    public void killAllRequests() {
        this.manager.killAllRequests();
    }

    /**
     * Kernel group/descriptor handed to the {@link KernelPool}. It is fully determined by its
     * {@link KernelType}, and equality is defined accordingly: a fresh instance is built for every
     * request, and instances for the same type must be interchangeable.
     */
    private static final class KernelDesc {
        private static final Map<Class<? extends PythonExecutionRequest>, KernelType> REQ_TO_TYPE_MAPPING = Map.ofEntries(Map.entry(MigrationPathRequest.class, KernelType.MIGRATION_PATH), Map.entry(AutoGovernanceScriptRequest.class, KernelType.AUTO_GOVERNANCE), Map.entry(ActionScriptRequest.class, KernelType.ACTION), Map.entry(ArtifactActionScriptRequest.class, KernelType.ARTIFACT_ACTION), Map.entry(LogicalHookRequest.class, KernelType.LOGICAL_HOOK));

        public final KernelType kernelType;

        private KernelDesc(KernelType kernelType) {
            this.kernelType = kernelType;
        }

        public static KernelDesc build(KernelType kernelType) {
            return new KernelDesc(kernelType);
        }

        /**
         * Builds the descriptor for a request class (exact class match).
         *
         * @throws IllegalArgumentException if the class is not mapped to a kernel type
         */
        public static KernelDesc build(Class<? extends PythonExecutionRequest> requestClass) {
            KernelType kernelType = REQ_TO_TYPE_MAPPING.get(requestClass);
            if (kernelType == null) {
                throw new IllegalArgumentException("Unknown request type: " + String.valueOf(requestClass));
            }
            return new KernelDesc(kernelType);
        }

        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (!(other instanceof KernelDesc)) {
                return false;
            }
            return this.kernelType == ((KernelDesc)other).kernelType;
        }

        @Override
        public int hashCode() {
            return this.kernelType.hashCode();
        }
    }

    /**
     * Kinds of Python kernels, one per request family.
     *
     * <p>Each constant carries defaults that can be overridden with
     * {@code dku.govern.python-execution.<kernelTypeId>.<property>} parameters. Each type also
     * keeps a bounded history of recent execution outcomes, used to extend the execution timeout
     * when timeouts pile up. Per the warning logged in that case, this is meant to prevent
     * "potential recursions" (presumably scripts waiting on other kernels of the pool — TODO
     * confirm).
     */
    private enum KernelType {
        MIGRATION_PATH("migration-path", 1, 0, 10, 5000L, 0.9),
        AUTO_GOVERNANCE("auto-governance", 1, 1, 10, 5000L, 0.9),
        ACTION("action", 1, 1, 10, 60000L, 0.9),
        ARTIFACT_ACTION("artifact-action", 1, 1, 10, 60000L, 0.9),
        LOGICAL_HOOK("logical-hook", 1, 1, 10, 5000L, 0.9);

        private final String kernelTypeId;
        private final int defaultAutoscaleTimeWindowSeconds;
        private final int defaultMinCount;
        private final int defaultMaxKernelCount;
        private final long defaultExecutionTimeoutMs;
        private final double defaultPotentialRecursionDetectionRatio;
        /** Capacity of {@link #timeoutResultForLatestExecutions}; fixed once at class initialization. */
        private final int timeoutHistorySize;
        /** Outcome (true = timed out) of the most recent executions; the oldest entries are evicted. */
        private final Collection<Boolean> timeoutResultForLatestExecutions;

        KernelType(String kernelTypeId, int defaultAutoscaleTimeWindowSeconds, int defaultMinCount, int defaultMaxKernelCount, long defaultExecutionTimeoutMs, double defaultPotentialRecursionDetectionRatio) {
            this.kernelTypeId = kernelTypeId;
            this.defaultAutoscaleTimeWindowSeconds = defaultAutoscaleTimeWindowSeconds;
            this.defaultMinCount = defaultMinCount;
            this.defaultMaxKernelCount = defaultMaxKernelCount;
            this.defaultExecutionTimeoutMs = defaultExecutionTimeoutMs;
            this.defaultPotentialRecursionDetectionRatio = defaultPotentialRecursionDetectionRatio;
            // The history is sized from the max kernel count read now; later parameter changes do
            // not resize it, so its size is kept for the ratio computation. Clamped so that a
            // negative setting cannot break enum initialization.
            this.timeoutHistorySize = Math.max(0, this.getMaxKernelCount());
            this.timeoutResultForLatestExecutions = Collections.synchronizedCollection(EvictingQueue.create(this.timeoutHistorySize));
        }

        private int getAutoscaleTimeWindowSeconds() {
            return DKUApp.getParams().getIntParam(this.kernelBasedPropertyName("autoscaleTimeWindowSeconds"), Integer.valueOf(this.defaultAutoscaleTimeWindowSeconds));
        }

        private int getMinKernelCount() {
            return DKUApp.getParams().getIntParam(this.kernelBasedPropertyName("minKernelCount"), Integer.valueOf(this.defaultMinCount));
        }

        private int getMaxKernelCount() {
            return DKUApp.getParams().getIntParam(this.kernelBasedPropertyName("maxKernelCount"), Integer.valueOf(this.defaultMaxKernelCount));
        }

        private double getPotentialRecursionDetectionRatio() {
            return DKUApp.getParams().getDoubleParam(this.kernelBasedPropertyName("potentialRecursionDetectionRatio"), this.defaultPotentialRecursionDetectionRatio);
        }

        private String kernelBasedPropertyName(String propertyName) {
            return "dku.govern.python-execution." + this.kernelTypeId + "." + propertyName;
        }

        /** Records whether the latest execution of this kernel type timed out. */
        private void recordExecutionOutcome(boolean timedOut) {
            this.timeoutResultForLatestExecutions.add(timedOut);
        }

        /** Counts timed-out executions in the recent history. */
        private long countLatestTimeouts() {
            // Collections.synchronizedCollection does not lock for streams/iteration: the caller
            // must hold the collection's monitor, as concurrent adds happen on executor threads.
            synchronized (this.timeoutResultForLatestExecutions) {
                return this.timeoutResultForLatestExecutions.stream().filter(Boolean.TRUE::equals).count();
            }
        }

        /**
         * Computes the execution timeout for the next request of this type.
         *
         * <p>The base value comes from the per-type {@code executionTimeoutMs} parameter, then from
         * the global {@code execution-timeout} parameter, then from the enum default. If the share
         * of timeouts in the recent history exceeds the configured detection ratio, the timeout
         * is extended by {@code queuedRequestTimeoutMs * maxKernelCount}.
         */
        private long computeEffectiveExecutionTimeoutMs() {
            long executionTimeoutMs = DKUApp.getParams().getLongParam(this.kernelBasedPropertyName("executionTimeoutMs"), DKUApp.getParams().getLongParam(KernelType.nonKernelBasedPropertyName("execution-timeout"), this.defaultExecutionTimeoutMs));
            if (this.timeoutHistorySize == 0) {
                // No history is kept, so there is nothing to detect.
                return executionTimeoutMs;
            }
            long countOfLatestTimeout = this.countLatestTimeouts();
            double potentialRecursionDetectionRatio = this.getPotentialRecursionDetectionRatio();
            boolean tooManyTimeouts = (double)countOfLatestTimeout / (double)this.timeoutHistorySize > potentialRecursionDetectionRatio;
            if (!tooManyTimeouts) {
                return executionTimeoutMs;
            }
            long queuedRequestTimeoutMs = KernelType.getQueuedRequestTimeoutMs();
            int maxKernelCount = Math.max(0, this.getMaxKernelCount());
            long effectiveExecutionTimeoutMs = executionTimeoutMs + queuedRequestTimeoutMs * (long)maxKernelCount;
            logger.warnV("Detected %d timeouts in the last %d python execution runs exceeding configured ratio of %f, increasing the execution timeout from %d ms to %d ms to prevent potential recursions", new Object[]{countOfLatestTimeout, this.timeoutHistorySize, potentialRecursionDetectionRatio, executionTimeoutMs, effectiveExecutionTimeoutMs});
            return effectiveExecutionTimeoutMs;
        }

        private static int getGlobalMaxKernelCount() {
            return DKUApp.getParams().getIntParam(KernelType.nonKernelBasedPropertyName("globalMaxKernelCount"), Integer.valueOf(GLOBAL_MAX_KERNEL_COUNT));
        }

        private static long getQueuedRequestTimeoutMs() {
            return DKUApp.getParams().getLongParam(KernelType.nonKernelBasedPropertyName("queuedRequestTimeoutMs"), QUEUED_REQUEST_TIMEOUT_MS);
        }

        private static String nonKernelBasedPropertyName(String propertyName) {
            return "dku.govern.python-execution." + propertyName;
        }
    }
}

