/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.interactivemodel;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.analysis.coreservices.CacheableKernelService;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.interactivemodel.InteractiveModelKernel;
import com.dataiku.dip.analysis.ml.interactivemodel.InteractiveModelParams;
import com.dataiku.dip.analysis.ml.interactivemodel.InteractiveModelResponse;
import com.dataiku.dip.analysis.ml.interactivemodel.InteractiveModelResultCache;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.memimpl.MemColumn;
import com.dataiku.dip.datalayer.memimpl.MemRow;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.futures.FutureAborter;
import com.dataiku.dip.futures.FuturePayload;
import com.dataiku.dip.futures.FutureProgress;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.futures.FutureResponse;
import com.dataiku.dip.futures.FutureService;
import com.dataiku.dip.futures.SimpleFutureThread;
import com.dataiku.dip.io.AlivenessCheckableAndClosable;
import com.dataiku.dip.meanings.MeaningsDAO;
import com.dataiku.dip.meanings.model.UserDefinedMeaning;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.resourceusage.CurrentComputeResourceUsageContext;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.Pair;
import com.dataiku.dip.variables.VariablesService;
import com.dataiku.lambda.endpoints.PreparationStep;
import com.google.gson.JsonObject;
import com.google.gson.reflect.TypeToken;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Scores records against a model ("interactive scoring") using a Python kernel.
 *
 * <p>Each scoring request runs as a future ({@link InteractiveModelFutureThread}). Kernels are cached
 * through the {@link CacheableKernelService} base class, keyed on the caller's auth-context string and
 * the model-id string (see {@link InteractiveModelKernelBuilder#getKey}). Per-record results are
 * memoized in an {@link InteractiveModelResultCache} shared by all requests handled by this service.
 */
@Service
public class InteractiveModelService extends CacheableKernelService {
    @Autowired
    FutureService futureService;
    @Autowired
    MeaningsDAO meaningsDAO;
    @Autowired
    TransactionService transactionService;
    @Autowired
    VariablesService variablesService;

    /**
     * Timeout handed to {@link FutureService#runFuture} by {@link #compute_NT}.
     * NOTE(review): the unit is presumably milliseconds; confirm against FutureService.
     */
    public static final int ASYNC_TIMEOUT = 500;

    /** Per-record result cache shared by every scoring request handled by this service. */
    private final InteractiveModelResultCache resultCache = new InteractiveModelResultCache();

    private static final DKULogger logger = DKULogger.getLogger("dku.ml.interactivemodel");

    /**
     * Runs an interactive scoring computation for {@code records} as a future.
     *
     * @param authCtx             caller's auth context; it is cast to {@link DSSAuthCtx}
     * @param fmi                 model to score with
     * @param computationParams   scoring parameters forwarded to the kernel
     * @param records             records to score, one JSON object per record
     * @param computeEvenIfCached when true, skip result-cache lookups and recompute every record
     *                            (fresh results are still written to the cache)
     * @return the future response produced by {@link FutureService#runFuture} with {@link #ASYNC_TIMEOUT}
     * @throws Exception propagated from {@code runFuture}
     */
    public FutureResponse<InteractiveModelResponse.InteractiveModelResult[]> compute_NT(AuthCtx authCtx, FullModelId fmi, InteractiveModelParams.ComputationParams computationParams, List<JsonObject> records, boolean computeEvenIfCached) throws Exception {
        InteractiveModelFutureThread ft = new InteractiveModelFutureThread((DSSAuthCtx) authCtx, fmi, computeEvenIfCached, computationParams, records);
        return this.futureService.runFuture(ft, ASYNC_TIMEOUT, new TypeToken<FutureResponse<InteractiveModelResponse.InteractiveModelResult[]>>(){});
    }

    /**
     * Drops every cached kernel whose key's model-id component equals {@code fmi.toString()}, and
     * every cached per-record result for that model.
     *
     * @param fmi the model that changed
     */
    public void invalidateModel(FullModelId fmi) {
        logger.info("Model ('" + String.valueOf(fmi) + "') has changed, invalidating caches");
        // Kernel keys are (authCtx string, model-id string); see InteractiveModelKernelBuilder#getKey.
        this.invalidateCache(pair -> ((String) pair.second).equals(fmi.toString()));
        this.resultCache.invalidateForModel(fmi);
    }

    /**
     * Reads the kernel cache expiration from {@code dku.ml.interactive_scoring.kernel.expirationTimeS},
     * passing 600 as the fallback value to {@code getIntParam}.
     */
    @Override
    protected int getKernelCacheExpirationTimeInSeconds() {
        return DKUApp.getParams().getIntParam("dku.ml.interactive_scoring.kernel.expirationTimeS", 600);
    }

    /**
     * Future that scores a batch of records. Its no-arg {@link #compute()} wraps the kernel-backed
     * {@link #compute(InteractiveModelParams.ComputationParams, List)} in a
     * {@link CachedInteractiveModelComputationRunner}, so only records missing from the result cache
     * reach the kernel.
     */
    class InteractiveModelFutureThread
    extends SimpleFutureThread<InteractiveModelResponse.InteractiveModelResult[]>
    implements InteractiveModelComputationRunner {
        final FullModelId fmi;
        final boolean computeEvenIfCached;
        private final InteractiveModelParams.ComputationParams computationParams;
        private final List<JsonObject> records;

        InteractiveModelFutureThread(DSSAuthCtx owner, FullModelId fmi, boolean computeEvenIfCached, InteractiveModelParams.ComputationParams computationParams, List<JsonObject> records) {
            super(owner);
            this.fmi = fmi;
            this.computeEvenIfCached = computeEvenIfCached;
            this.computationParams = computationParams;
            this.records = records;
        }

        public FuturePayload getPayload() {
            return FuturePayload.newSimple("score_interactive_model", "Score interactive model");
        }

        /** Entry point of the future: scores {@link #records}, serving cached results where allowed. */
        @Override
        protected InteractiveModelResponse.InteractiveModelResult[] compute() throws Exception {
            CachedInteractiveModelComputationRunner cachedRunner = new CachedInteractiveModelComputationRunner(this.fmi, InteractiveModelService.this.resultCache, this.computeEvenIfCached, this);
            return cachedRunner.compute(this.computationParams, this.records);
        }

        /**
         * Scores {@code records} with the (possibly newly started) kernel for this owner and model,
         * bypassing the result cache.
         *
         * <p>While running, an abort hook is registered that kills the kernel if the future is aborted.
         * The current thread's compute-resource-usage context is set for the duration and cleared in
         * {@code finally}.
         *
         * @throws RuntimeException wrapping the cause if the kernel cannot be acquired or started
         */
        @Override
        public InteractiveModelResponse.InteractiveModelResult[] compute(InteractiveModelParams.ComputationParams computationParams, List<JsonObject> records) throws Exception {
            InteractiveModelResponse.InteractiveModelResult[] results;
            InteractiveModelKernelBuilder builder = new InteractiveModelKernelBuilder(this.fmi);
            try (FutureAborter.AutoCloseableAbortHook ignored = FutureAborter.pushAutoCloseableHook(() -> {
                // Best effort: an abort must not fail because the kernel could not be killed.
                try {
                    ((InteractiveModelKernel) InteractiveModelService.this.acquireKernel(this.owner, builder)).killKernel();
                } catch (Exception e) {
                    logger.warn("Error while aborting interactive scoring kernel", e);
                }
            })) {
                InteractiveModelKernel kernel;
                ComputeResourceUsageContext cruContext = this.fmi.getComputeResourceUsageContext(this.owner);
                CurrentComputeResourceUsageContext.setInCurrentThread(cruContext);
                try {
                    kernel = (InteractiveModelKernel) InteractiveModelService.this.acquireKernel(this.owner, builder);
                    kernel.startIfNeeded();
                } catch (Exception e) {
                    throw new RuntimeException("Could not start Python kernel", e);
                }
                try (FutureProgress.AutocloseableFutureProgressState ignored1 = FutureProgress.pushAutoCloseableState("Computing", 1.0, FutureProgressState.StateUnit.NONE)) {
                    results = this.actuallyComputeResults(kernel, computationParams, records);
                }
            } finally {
                // Runs after the abort hook and progress state are closed, whatever happened above.
                CurrentComputeResourceUsageContext.clear();
            }
            return results;
        }

        /**
         * Optionally runs the model's preparation script over {@code records}, then scores them with
         * {@code kernel}.
         *
         * <p>The preparation script is applied only when {@code computationParams.applyPreparationScript}
         * is set and the model is not an external MLflow model version. Rows the script deletes are not
         * sent to the kernel; their slots in the returned array are {@code null}.
         */
        private InteractiveModelResponse.InteractiveModelResult[] actuallyComputeResults(InteractiveModelKernel kernel, InteractiveModelParams.ComputationParams computationParams, List<JsonObject> records) throws Exception {
            boolean applyPreparationScript = computationParams.applyPreparationScript && !this.fmi.isExternalMLflowModelVersion();
            if (!applyPreparationScript) {
                return kernel.compute(computationParams, records);
            }
            MemTable memTable = this.tableFromJson(records);
            List<UserDefinedMeaning> meaningList;
            try (Transaction ignored = InteractiveModelService.this.transactionService.beginRead()) {
                meaningList = InteractiveModelService.this.meaningsDAO.listUnsafe();
            }
            PreparationStep preparationStep = new PreparationStep(this.fmi.getSessionFolder(), meaningList, InteractiveModelService.this.variablesService.getForProject(this.fmi.getProjectKey()));
            preparationStep.process(memTable);
            List<JsonObject> preparedRecords = this.tableToJson(memTable);
            InteractiveModelResponse.InteractiveModelResult[] results = kernel.compute(computationParams, preparedRecords);
            return this.realignResultsWithDroppedRows(results, memTable);
        }

        /**
         * Spreads {@code results}, which cover only the non-deleted rows of {@code input}, back over all
         * rows of {@code input}. Deleted rows get {@code null}.
         * NOTE(review): assumes the kernel returns one result per record sent, in the same order.
         */
        private InteractiveModelResponse.InteractiveModelResult[] realignResultsWithDroppedRows(InteractiveModelResponse.InteractiveModelResult[] results, MemTable input) {
            int rowCount = input.nrows();
            if (rowCount == results.length) {
                return results;
            }
            InteractiveModelResponse.InteractiveModelResult[] realignedResults = new InteractiveModelResponse.InteractiveModelResult[rowCount];
            int resultIdx = 0;
            for (int i = 0; i < rowCount; ++i) {
                Row row = input.rows.get(i);
                if (row.isDeleted()) {
                    realignedResults[i] = null;
                } else {
                    realignedResults[i] = results[resultIdx++];
                }
            }
            return realignedResults;
        }

        /** Converts the non-deleted rows of {@code memTable} to JSON objects keyed by column name. */
        private List<JsonObject> tableToJson(MemTable memTable) {
            List<JsonObject> result = new ArrayList<>(memTable.nrows());
            for (MemRow row : memTable.rows) {
                if (row.isDeleted()) {
                    continue;
                }
                JsonObject record = new JsonObject();
                for (MemColumn col : memTable.columnsList) {
                    record.addProperty(col.getName(), row.get(col));
                }
                result.add(record);
            }
            return result;
        }

        /** Builds a {@link MemTable} with one row per item, then calls {@code addDumbRowsIfEmpty()}. */
        private MemTable tableFromJson(List<JsonObject> items) {
            MemTable table = new MemTable();
            for (JsonObject item : items) {
                table.addRowFromJsonObject(item);
            }
            table.addDumbRowsIfEmpty();
            return table;
        }
    }

    /** Creates {@link InteractiveModelKernel}s for one model, keyed by (auth context, model id). */
    private static class InteractiveModelKernelBuilder
    implements CacheableKernelService.CacheableKernelBuilder {
        private final FullModelId fmi;

        public InteractiveModelKernelBuilder(FullModelId fmi) {
            this.fmi = fmi;
        }

        @Override
        public AlivenessCheckableAndClosable createKernel(AuthCtx authCtx) throws Exception {
            return new InteractiveModelKernel(authCtx, this.fmi, false, null);
        }

        /** Cache key: (auth context string, model-id string). {@code invalidateModel} matches on the second part. */
        @Override
        public Pair<String, String> getKey(AuthCtx authCtx) {
            return new Pair<>(authCtx.toString(), this.fmi.toString());
        }
    }

    /**
     * Decorator that serves per-record results from an {@link InteractiveModelResultCache} and
     * delegates only cache misses (or every record when {@code computeEvenIfCached}) to the wrapped
     * runner. Freshly computed results are written back to the cache.
     */
    static class CachedInteractiveModelComputationRunner
    implements InteractiveModelComputationRunner {
        private final InteractiveModelComputationRunner actualComputationRunner;
        private final InteractiveModelResultCache cache;
        private final FullModelId fmi;
        private final boolean computeEvenIfCached;

        CachedInteractiveModelComputationRunner(FullModelId fmi, InteractiveModelResultCache cache, boolean computeEvenIfCached, InteractiveModelComputationRunner actualComputationRunner) {
            this.fmi = fmi;
            this.computeEvenIfCached = computeEvenIfCached;
            this.cache = cache;
            this.actualComputationRunner = actualComputationRunner;
        }

        /**
         * Returns one result per input record, in input order.
         *
         * @throws IllegalStateException if the wrapped runner returns a different number of results
         *                               than records it was given
         */
        @Override
        public InteractiveModelResponse.InteractiveModelResult[] compute(InteractiveModelParams.ComputationParams computationParams, List<JsonObject> records) throws Exception {
            logger.info("Will return results from " + records.size() + " record(s)");
            InteractiveModelResponse.InteractiveModelResult[] finalResults = new InteractiveModelResponse.InteractiveModelResult[records.size()];
            // Parallel lists: remainingIndices.get(k) is the position in `records` of remainingRecords.get(k).
            List<Integer> remainingIndices = new ArrayList<>();
            List<JsonObject> remainingRecords = new ArrayList<>();
            if (this.computeEvenIfCached) {
                logger.info("Bypassing cache, computing result with python kernel");
                int i = 0;
                for (JsonObject record : records) {
                    remainingIndices.add(i);
                    remainingRecords.add(record);
                    ++i;
                }
            } else {
                int cachedCount = 0;
                int i = 0;
                for (JsonObject record : records) {
                    InteractiveModelResultCache.CachedValue cachedValue = this.cache.get(this.fmi, computationParams, record);
                    if (cachedValue != null) {
                        finalResults[i] = (InteractiveModelResponse.InteractiveModelResult) cachedValue.value;
                        ++cachedCount;
                    } else {
                        remainingIndices.add(i);
                        remainingRecords.add(record);
                    }
                    ++i;
                }
                logger.info("Got " + cachedCount + " cached result(s), will run a computation on " + remainingRecords.size() + " record(s).");
            }
            if (!remainingRecords.isEmpty()) {
                InteractiveModelResponse.InteractiveModelResult[] computed = this.actualComputationRunner.compute(computationParams, remainingRecords);
                if (computed.length != remainingRecords.size()) {
                    throw new IllegalStateException("Interactive scoring returned " + computed.length + " result(s) for " + remainingRecords.size() + " record(s)");
                }
                for (int k = 0; k < computed.length; ++k) {
                    this.cache.put(this.fmi, computationParams, remainingRecords.get(k), computed[k]);
                    finalResults[remainingIndices.get(k)] = computed[k];
                }
            }
            return finalResults;
        }
    }

    /** Computes one interactive-scoring result per input record. */
    static interface InteractiveModelComputationRunner {
        public InteractiveModelResponse.InteractiveModelResult[] compute(InteractiveModelParams.ComputationParams var1, List<JsonObject> var2) throws Exception;
    }
}

