/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.lambda.endpoints.predictcommon;

import com.codahale.metrics.Meter;
import com.codahale.metrics.Timer;
import com.dataiku.dip.DSSMetrics;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.security.audit.AuditTrailService;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.lambda.APINodeMetrics;
import com.dataiku.lambda.ServiceGenContext;
import com.dataiku.lambda.controllers.RequestMetadata;
import com.dataiku.lambda.endpoints.LambdaEndpointHandler;
import com.dataiku.lambda.endpoints.pool.PipelinePool;
import com.dataiku.lambda.endpoints.pool.PoolCallbacks;
import com.dataiku.lambda.endpoints.predictcommon.PipelineMessage;
import com.dataiku.lambda.endpoints.predictcommon.PoolablePipelineWithEnrich;
import com.dataiku.lambda.model.api.ExplanationsQuery;
import com.dataiku.lambda.model.api.ForecastQuery;
import com.dataiku.lambda.model.api.MultiplePredictionQuery;
import com.dataiku.lambda.model.api.PredictionResponse;
import com.dataiku.lambda.model.api.ResponseElements;
import com.dataiku.lambda.model.api.SinglePredictionQuery;
import com.dataiku.lambda.model.serverconfig.LambdaEndpointConfig;
import com.dataiku.lambda.model.serverconfig.PredictionEndpointConfig;
import com.dataiku.lambda.model.serverconfig.QueryAPIKey;
import com.dataiku.lambda.services.ServiceManager;
import com.google.common.collect.Lists;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;

/**
 * Common base for prediction-style endpoint handlers backed by a {@link PipelinePool} of
 * {@code P} pipelines.
 *
 * <p>The public entry points ({@code predict(...)} overloads and {@code forecast(...)}) turn the
 * incoming query into a {@link PipelineMessage}, then delegate scoring to the subclass-provided
 * {@link #predict(long, PipelineMessage)}. The call is timed with {@link #totalTimer}, the
 * demo-mode query limit is enforced, and one audit event is emitted per result item (or a
 * single error event when scoring fails).
 *
 * @param <C> endpoint configuration type
 * @param <P> pooled pipeline type produced by {@link #instantiatePipeline()}
 */
public abstract class PredictionEndpointHandlerBase<C extends LambdaEndpointConfig, P extends PoolablePipelineWithEnrich>
extends LambdaEndpointHandler<C>
implements PoolCallbacks<P> {
    /** Number of queries served before demo mode starts rejecting requests. */
    private static final int DEMO_MODE_MAX_QUERIES = 20;

    private static final DKULogger logger = DKULogger.getLogger("dku.lambda.prediction.handler");

    /** " (serviceId/endpointId)" suffix appended to this handler's log messages; set in {@link #init}. */
    protected String loggerSuffix;
    /** Pool of pipelines created via {@link #instantiatePipeline()}; set in {@link #init}. */
    protected PipelinePool<P> pool;
    protected Meter requestsMeter;
    protected Meter successRequestMeter;
    protected Timer enrichTimer;
    protected Timer predictTimer;
    /** Times the whole {@code predictAndLog} call (scoring + auditing). */
    protected Timer totalTimer;
    protected boolean demoMode;
    // Not read or written by this class beyond its default; presumably used by subclasses.
    protected boolean destroying = false;
    /**
     * Queries seen so far while in demo mode. Saturates at {@code DEMO_MODE_MAX_QUERIES + 1} so
     * the counter can never wrap around and silently lift the limit.
     */
    private final AtomicInteger totalNbQueries = new AtomicInteger();

    public PredictionEndpointHandlerBase(C config) {
        super(config);
    }

    /**
     * Initializes the parent handler, creates the pipeline pool and the per-endpoint metrics, runs
     * the subclass hook {@link #_init(ServiceGenContext)}, and finally initializes the pool.
     */
    @Override
    public final void init(ServiceGenContext context) throws Exception {
        super.init(context);
        assert this.lmContext != null;
        this.pool = new PipelinePool<>(this, context.getServiceId(), this.config.id, this.serverEndpointConfig.pool);
        this.demoMode = this.lmContext.isDemoMode();
        this.loggerSuffix = " (" + context.getServiceId() + "/" + this.config.id + ")";
        this.requestsMeter = APINodeMetrics.endpointMeter(context.getServiceId(), this.config.id, "requests");
        this.successRequestMeter = APINodeMetrics.endpointMeter(context.getServiceId(), this.config.id, "successRequests");
        this.enrichTimer = APINodeMetrics.endpointTimer(context.getServiceId(), this.config.id, "enrich");
        this.predictTimer = APINodeMetrics.endpointTimer(context.getServiceId(), this.config.id, "predict");
        this.totalTimer = APINodeMetrics.endpointTimer(context.getServiceId(), this.config.id, "totalProcessing");
        // Subclass setup must happen before the pool starts instantiating pipelines.
        this._init(context);
        this.pool.init();
    }

    /** Subclass hook called from {@link #init} after metrics are set up and before the pool is initialized. No-op by default. */
    protected void _init(ServiceGenContext context) throws Exception {
    }

    /** Destroys the pipeline pool. Safe to call even if {@link #init} failed before the pool was created. */
    @Override
    public synchronized void destroy() {
        logger.info("Destroy endpoint" + this.loggerSuffix);
        if (this.pool != null) {
            this.pool.destroy();
        }
    }

    @Override
    public abstract P instantiatePipeline() throws Exception;

    /**
     * Scores the rows carried by {@code message}.
     *
     * @param startTime the {@code startTimeN} value received by the public entry points
     *                  (presumably a {@code System.nanoTime()} timestamp — confirm with callers)
     * @param message   the prepared pipeline message
     * @return the response plus batch-level audit elements; its {@code response.results} is
     *         iterated index-aligned with {@code message.itemsToPredict} for non-forecast calls
     */
    protected abstract EnrichedPredictionResponse predict(long startTime, PipelineMessage message) throws Exception;

    /**
     * Runs {@link #predict(long, PipelineMessage)} under the total-processing timer and audits
     * every result item. On failure a single error audit event is emitted and the exception is
     * rethrown unchanged.
     *
     * @throws SecurityException when demo mode is on and the query limit is exceeded
     */
    private PredictionResponse predictAndLog(long startTime, ServiceManager.RefcountedEndpoint re, PipelineMessage message, boolean isForecast) throws Exception {
        this.enforceDemoModeLimit();
        ResponseElements.Context apiContext = this.newContext();
        try (DSSMetrics.TimeCtx totalTimerCtx = DSSMetrics.timeCtx(this.totalTimer)) {
            EnrichedPredictionResponse eresp = this.predict(startTime, message);
            PredictionResponse resp = eresp.response;
            resp.apiContext = apiContext;
            for (int i = 0; i < resp.results.size(); ++i) {
                PredictionResponse.PredictionResponseItem result = (PredictionResponse.PredictionResponseItem) resp.results.get(i);
                // Forecast messages carry no per-row prediction items, so there is nothing to attach.
                MultiplePredictionQuery.Item item = isForecast ? null : message.itemsToPredict.get(i);
                JsonObject input = item != null ? item.features : null;
                this.audit(re, eresp, input, result, i, item, apiContext, null);
            }
            return resp;
        } catch (Exception exc) {
            this.audit(re, null, null, null, 0, null, apiContext, exc);
            throw exc;
        }
    }

    /**
     * In demo mode, counts the query and rejects it once more than {@link #DEMO_MODE_MAX_QUERIES}
     * queries have been seen. The counter saturates instead of overflowing.
     */
    private void enforceDemoModeLimit() {
        if (!this.demoMode) {
            return;
        }
        int seen = this.totalNbQueries.updateAndGet(n -> Math.min(n + 1, DEMO_MODE_MAX_QUERIES + 1));
        if (seen > DEMO_MODE_MAX_QUERIES) {
            throw new SecurityException("Your license does not cover usage of the API node. Please contact Dataiku");
        }
    }

    /**
     * Best-effort emission of one audit event; any exception is logged as a warning and swallowed
     * so auditing never fails the request.
     *
     * @param resp     the whole batch response, or {@code null} when reporting an error
     * @param input    the item's input features, or {@code null}
     * @param itemResp the result for item {@code idx}, or {@code null}
     * @param item     the original query item (provides the user context), or {@code null}
     * @param error    the scoring failure, or {@code null} on success
     */
    private void audit(ServiceManager.RefcountedEndpoint re, EnrichedPredictionResponse resp, JsonObject input, PredictionResponse.PredictionResponseItem itemResp, int idx, MultiplePredictionQuery.Item item, ResponseElements.Context apiContext, Exception error) {
        try {
            LambdaEndpointHandler.AuditResult output = itemResp != null
                    ? this.createResult((JsonObject) JSON.parse(JSON.json(itemResp), JsonObject.class))
                    : null;
            String timing = resp != null && resp.response != null ? JSON.json(resp.response.timing) : null;
            Optional<AuditTrailService.EmittableAuditObj> geo = this.createEvent(re, logger, input, output, error, timing, JSON.json(apiContext));
            if (!geo.isPresent()) {
                return;
            }
            AuditTrailService.EmittableAuditObj geoObj = geo.get();
            if (resp != null && resp.response != null) {
                geoObj.with("batchSize", String.valueOf(resp.response.results.size()));
            }
            if (resp != null) {
                geoObj.with("batchIdx", String.valueOf(idx)).withAll(resp.globalAdditionalAuditElements);
            }
            if (item != null) {
                geoObj.with("userContext", item.context);
            }
            if (itemResp != null) {
                geoObj.withAll(itemResp.additionalAuditElements);
            }
            geoObj.emit();
        } catch (Exception e) {
            logger.warn("failed to log queries", e);
        }
    }

    @Override
    protected String getMessageType() {
        return "prediction-query";
    }

    @Override
    protected String getQueryType() {
        return "prediction";
    }

    @Override
    protected String getInputPropertyName() {
        return "features";
    }

    /** Scores a single prediction query. */
    public PredictionResponse predict(long startTimeN, ServiceManager.RefcountedEndpoint re, SinglePredictionQuery query, QueryAPIKey apiKey, @Nullable RequestMetadata requestMetadata) throws Exception {
        MultiplePredictionQuery.Item asItem = query;
        MemTable table = this.tableFromItems(Lists.newArrayList(asItem));
        PipelineMessage message = newPipelineMessage(apiKey, table);
        message.httpRequestMetadata = requestMetadata;
        message.itemsToPredict.add(asItem);
        message.explanations = this.getExplanationsParams(query.explanations);
        message.pyPredictionAdvancedOptions = query.pyPredictionAdvancedOptions;
        return this.predictAndLog(startTimeN, re, message, false);
    }

    /** Scores a batch of prediction items. */
    public PredictionResponse predict(long startTimeN, ServiceManager.RefcountedEndpoint re, MultiplePredictionQuery query, QueryAPIKey apiKey, @Nullable RequestMetadata requestMetadata) throws Exception {
        MemTable table = this.tableFromItems(query.items);
        PipelineMessage message = newPipelineMessage(apiKey, table);
        message.httpRequestMetadata = requestMetadata;
        message.itemsToPredict.addAll(query.items);
        message.explanations = this.getExplanationsParams(query.explanations);
        message.pyPredictionAdvancedOptions = query.pyPredictionAdvancedOptions;
        return this.predictAndLog(startTimeN, re, message, false);
    }

    /** Runs a forecast query; rows come from {@code query.items} and nothing is put in {@code itemsToPredict}. */
    public PredictionResponse forecast(long startTimeN, ServiceManager.RefcountedEndpoint re, ForecastQuery query, QueryAPIKey apiKey) throws Exception {
        MemTable table = this.tableFromItemsToForecast(query.items);
        PipelineMessage message = newPipelineMessage(apiKey, table);
        message.itemsToForecast.addAll(query.items);
        message.pyPredictionAdvancedOptions = query.pyPredictionAdvancedOptions;
        return this.predictAndLog(startTimeN, re, message, true);
    }

    /**
     * Builds a message carrying {@code apiKey} and {@code table}, with one {@code null}
     * pre-predict ignore reason per table row (no row is ignored up front).
     */
    private static PipelineMessage newPipelineMessage(QueryAPIKey apiKey, MemTable table) {
        PipelineMessage message = new PipelineMessage();
        message.apiKey = apiKey;
        message.table = table;
        for (int i = 0; i < table.rows.size(); ++i) {
            message.prePredictIgnoreReasons.add(null);
        }
        return message;
    }

    /**
     * Resolves the explanation parameters for this request.
     *
     * <p>Returns {@code null} for non-{@link PredictionEndpointConfig} endpoints. Otherwise returns
     * the request's query (mutated in place) or a fresh one. When explanations are enabled by the
     * endpoint config or explicitly by the request, unset fields are filled from the endpoint's
     * {@code individualExplanationParams}.
     *
     * @throws IllegalArgumentException if explanations are explicitly requested on a Java-scoring endpoint
     */
    private ExplanationsQuery getExplanationsParams(ExplanationsQuery explanationsQuery) {
        if (!(this.config instanceof PredictionEndpointConfig)) {
            return null;
        }
        PredictionEndpointConfig predictionConfig = (PredictionEndpointConfig) this.config;
        ExplanationsQuery resolved = explanationsQuery != null ? explanationsQuery : new ExplanationsQuery();
        boolean explicitlyEnabled = Boolean.TRUE.equals(resolved.enabled);
        if (predictionConfig.useJava && explicitlyEnabled) {
            throw new IllegalArgumentException("Explanations are not compatible with java scoring");
        }
        if (predictionConfig.outputExplanations || explicitlyEnabled) {
            if (resolved.enabled == null) {
                resolved.enabled = true;
            }
            if (resolved.method == null) {
                resolved.method = predictionConfig.individualExplanationParams.method;
            }
            if (resolved.nExplanations == null) {
                resolved.nExplanations = predictionConfig.individualExplanationParams.nbExplanations;
            }
            if (resolved.nMonteCarloSteps == null) {
                resolved.nMonteCarloSteps = predictionConfig.individualExplanationParams.shapleyBackgroundSize;
            }
        }
        return resolved;
    }

    /** Builds a table with one row per item's {@code features}; dumb rows are added if it would be empty. */
    private MemTable tableFromItems(List<MultiplePredictionQuery.Item> queryItems) {
        MemTable table = new MemTable();
        for (MultiplePredictionQuery.Item queryItem : queryItems) {
            table.addRowFromJsonObject(queryItem.features);
        }
        table.addDumbRowsIfEmpty();
        return table;
    }

    /** Builds a table with one row per forecast item; dumb rows are added if it would be empty. */
    private MemTable tableFromItemsToForecast(List<JsonObject> queryItems) {
        MemTable table = new MemTable();
        for (JsonObject queryItem : queryItems) {
            table.addRowFromJsonObject(queryItem);
        }
        table.addDumbRowsIfEmpty();
        return table;
    }

    /**
     * Snapshots every column of each predicted row after enrichment.
     *
     * @return {@code null} when neither auditing nor returning is requested; otherwise one entry per
     *         item in {@code pm.itemsToPredict}, which is {@code null} for rows that have a
     *         pre-predict ignore reason
     */
    protected List<JsonObject> collectPostEnrichDataIfNeeded(PipelineMessage pm, boolean auditIt, boolean returnIt) {
        if (!auditIt && !returnIt) {
            return null;
        }
        int itemCount = pm.itemsToPredict.size();
        List<JsonObject> postEnrich = new ArrayList<>(itemCount);
        for (int i = 0; i < itemCount; ++i) {
            if (pm.prePredictIgnoreReasons.get(i) != null) {
                postEnrich.add(null);
                continue;
            }
            JsonObject rowJson = new JsonObject();
            Row row = (Row) pm.table.rows.get(i);
            for (Column col : pm.table.columns()) {
                rowJson.addProperty(col.getName(), row.get(col));
            }
            postEnrich.add(rowJson);
        }
        return postEnrich;
    }

    /**
     * Attaches post-enrich data to each result item: as {@code postEnrich} when {@code returnIt},
     * and under the {@code "postEnrich"} audit element when {@code auditIt}. Does nothing when both
     * flags are false.
     *
     * @param postEnrich must be index-aligned with {@code resp.results} when either flag is set
     */
    protected void enrichPredictionResponseWithPostEnrichData(PredictionResponse resp, List<JsonObject> postEnrich, boolean auditIt, boolean returnIt) {
        if (!returnIt && !auditIt) {
            return;
        }
        assert resp.results.size() == postEnrich.size() : "results/postEnrich size mismatch";
        for (int i = 0; i < resp.results.size(); ++i) {
            PredictionResponse.PredictionResponseItem resultItem = (PredictionResponse.PredictionResponseItem) resp.results.get(i);
            JsonObject itemPostEnrich = postEnrich.get(i);
            if (returnIt) {
                resultItem.postEnrich = itemPostEnrich;
            }
            if (auditIt) {
                resultItem.additionalAuditElements.add("postEnrich", (JsonElement) itemPostEnrich);
            }
        }
    }

    /** A prediction response plus audit elements that apply to every item of the batch. */
    public static class EnrichedPredictionResponse {
        public PredictionResponse response;
        /** Merged into each item's audit event by {@code audit}. */
        public JsonObject globalAdditionalAuditElements = new JsonObject();
    }
}

