/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.lambda.mgmt.devserver;

import com.dataiku.common.rpc.InternalAPIClient;
import com.dataiku.common.server.APIError;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.memimpl.MemRow;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.datasets.DatasetSelection;
import com.dataiku.dip.datasets.DatasetSelectionToMemTable;
import com.dataiku.dip.datasets.PartitionableHandler;
import com.dataiku.dip.datasets.SamplingParam;
import com.dataiku.dip.datasets.SingleThreadPusherToMemTable;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.exceptions.DSSInternalErrorException;
import com.dataiku.dip.lambda.mgmt.devserver.LambdaDevServerKernel;
import com.dataiku.dip.lambda.mgmt.devserver.LambdaDevServerService;
import com.dataiku.dip.partitioning.DimensionValue;
import com.dataiku.dip.partitioning.Partition;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.shaker.types.Boolean;
import com.dataiku.dip.shaker.types.DoubleMeaning;
import com.dataiku.dip.shaker.types.LongMeaning;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.Pair;
import com.dataiku.dip.utils.StringTransmogrifier;
import com.dataiku.lambda.model.studioconfig.ApiEndpointQuery;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonNull;
import com.google.gson.JsonObject;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;
import org.apache.directory.api.util.Strings;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Service used to exercise API endpoints against the Lambda dev server and to
 * build sample test queries out of the first rows of a dataset.
 *
 * <p>Two entry points are exposed:
 * <ul>
 *   <li>{@link #playTestQueries_NT} runs a list of test queries against an endpoint
 *       and collects, for each of them, either the response or the error;</li>
 *   <li>{@link #getSampleQueriesFromDataset} reads the head of a dataset and turns
 *       each row into an {@link ApiEndpointQuery}.</li>
 * </ul>
 */
@Service
public class APIDesignerTestQueriesService {
    @Autowired
    private LambdaDevServerService lambdaDevServerService;
    private static final DKULogger logger = DKULogger.getLogger("dku.lambda.devserver");

    /** Upper bound on the number of records read from the dataset when sampling queries. */
    private static final int MAX_SAMPLE_RECORDS = 100;

    /** The only sampling method accepted by {@link #getSampleQueriesFromDataset}. */
    private static final String SUPPORTED_SAMPLING_METHOD = "HEAD_SEQUENTIAL";

    /**
     * Runs every test query against the given endpoint of the dev server.
     *
     * <p>If the dev server is reported as stopped, the service is deployed on it again
     * first; a deployment error aborts the whole call. After that, a failure of an
     * individual query does not abort the others: it is recorded in the matching
     * {@link ResponseOrError#error}. The dev server state is captured after each query,
     * whether it succeeded or not.
     *
     * <p>NOTE(review): this source was recovered by a decompiler, which reported having
     * removed a self-catching try block in this method ("possible behaviour change").
     * Verify against the original sources if exact exception flow matters.
     *
     * @param serviceId   id of the API service holding the endpoint
     * @param endpointId  id of the endpoint to query
     * @param testQueries queries to run, in order
     * @param testType    kind of call to issue; its call name is passed to the dev server client
     * @param apiKey      key used to obtain the dev server API client
     * @param user        caller on whose behalf the dev server is used
     * @param projectKey  project owning the dev server
     * @return one {@link ResponseOrError} per test query, in the same order
     * @throws APIError.SerializedErrorException if redeploying the service on a stopped dev server fails
     * @throws Exception on any failure outside of the individual query executions
     */
    public TestResponses playTestQueries_NT(String serviceId, String endpointId, List<ApiEndpointQuery> testQueries, LambdaTestType testType, String apiKey, AuthCtx user, String projectKey) throws Exception {
        if (this.lambdaDevServerService.getState(user, projectKey).status == LambdaDevServerKernel.KernelState.STOPPED) {
            logger.info("Lambda dev server is stopped. It must be recreated with service deployed on it again.");
            LambdaDevServerService.DeploymentResponse deployment = this.lambdaDevServerService.deployService_NT(projectKey, serviceId, user);
            if (deployment.error != null) {
                throw new APIError.SerializedErrorException(deployment.error);
            }
        }

        String targetUnixUser = this.lambdaDevServerService.getTargetUnixUser(user, projectKey);
        String unixUserSuffix = Strings.isNotEmpty(targetUnixUser) ? ", and target unix user " + targetUnixUser : "";
        logger.infoV("Lambda dev server is running for user %s, projectKey %s%s", new Object[]{user.toString(), projectKey, unixUserSuffix});

        LambdaDevServerKernel kernel = this.lambdaDevServerService.getOrCreateKernel(user, projectKey);
        LambdaDevServerKernel.LambdaDevServerClient client = kernel.getAPIClient(apiKey);

        TestResponses testResponses = new TestResponses();
        for (ApiEndpointQuery testQuery : testQueries) {
            ResponseOrError outcome = new ResponseOrError();
            // The handle is never read: it is only held open for the duration of the call.
            try (LambdaDevServerKernel.KernelHandle kernelHandle = kernel.startUsing()) {
                InternalAPIClient.ResponseStringWithHeaders rawResponse = client.runTestQuery(serviceId, endpointId, testType.getCallName(), testQuery.q);
                outcome.headers = rawResponse.headers;
                outcome.response = rawResponse.response;
                if (rawResponse.response != null && hasJsonContentType(rawResponse.headers)) {
                    try {
                        outcome.response = JSON.parse((String)rawResponse.response, JsonElement.class);
                    }
                    catch (Exception ignored) {
                        // Best effort: the body is announced as JSON but does not parse.
                        // Keep the raw body that was already stored above.
                    }
                }
            }
            catch (Exception e) {
                outcome.error = new APIError(e, !DKUApp.hideErrorStacks(), !DKUApp.hideErrorStacks(), !DKUApp.hideLogTails());
            }
            finally {
                outcome.serverState = this.lambdaDevServerService.getState(user, projectKey);
            }
            testResponses.responses.add(outcome);
        }
        return testResponses;
    }

    /**
     * Tells whether the headers announce a JSON body, i.e. contain a {@code Content-Type}
     * header (name compared case-insensitively) whose value starts with
     * {@code application/json} (also case-insensitively).
     *
     * <p>Null-safe on purpose: a missing header list, a null entry or a null header name
     * simply means "not JSON" instead of raising a {@link NullPointerException} that would
     * turn an otherwise successful response into an error.
     */
    private static boolean hasJsonContentType(List<Pair<String, String>> headers) {
        if (headers == null) {
            return false;
        }
        for (Pair<String, String> header : headers) {
            if (header != null
                    && "Content-Type".equalsIgnoreCase(header.first)
                    && StringUtils.startsWithIgnoreCase(header.second, "application/json")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds sample test queries from the first rows of a dataset.
     *
     * <p>At most {@code min(100, batchSize)} records are requested from the dataset. Each row
     * becomes a JSON object keyed by column name, with values converted according to the
     * column type declared in the dataset schema (columns absent from the schema are treated
     * as strings; column types not handled by the conversion are left out). When a partition
     * was picked from the saved model, its dimension values are added for every dimension
     * that is not already a schema column.
     *
     * <p>For a time series forecasting saved model, a single query is returned, holding all
     * rows under {@code "items"} (even when no row was read). Otherwise one query per row is
     * returned, holding the row under {@code "features"}. Query names are derived from the
     * dataset name and made distinct from each other and from {@code queriesName}.
     *
     * @param serializedDataset  dataset to sample
     * @param owner              identity used to read the dataset
     * @param batchSize          maximum number of rows to turn into queries
     * @param method             sampling method; only {@code "HEAD_SEQUENTIAL"} is supported
     * @param savedModel         saved model behind the endpoint, if any
     * @param queriesName        names already taken by existing queries, if any
     * @param shouldIncludeNulls whether null cells are emitted as JSON nulls (otherwise omitted)
     * @return the generated queries
     * @throws DSSInternalErrorException if {@code method} is not supported
     * @throws Exception if reading the dataset or listing partitions fails
     */
    public List<ApiEndpointQuery> getSampleQueriesFromDataset(SerializedDataset serializedDataset, AuthCtx owner, int batchSize, String method, @Nullable SavedModel savedModel, @Nullable Set<String> queriesName, boolean shouldIncludeNulls) throws Exception {
        if (!SUPPORTED_SAMPLING_METHOD.equals(method)) {
            throw new DSSInternalErrorException("Unsupported sampling method: " + method);
        }

        // Load the head of the dataset into an in-memory table.
        Dataset dataset = Dataset.fromSerialized(serializedDataset);
        Schema schema = dataset.getSchema();
        MemTable table = new MemTable();
        DatasetSelectionToMemTable selection = new DatasetSelectionToMemTable();
        selection.maxRecords = Math.min(MAX_SAMPLE_RECORDS, batchSize);
        selection.samplingMethod = SamplingParam.SamplingMethod.HEAD_SEQUENTIAL;
        Partition partition = this.readQueryPartitionToSelection(owner, savedModel, dataset, selection);
        SingleThreadPusherToMemTable pusher = new SingleThreadPusherToMemTable(owner, dataset, table);
        pusher.setDatasetSelection(selection);
        pusher.push();

        // Beware: in this file "Boolean" is the imported com.dataiku.dip.shaker.types.Boolean
        // meaning, not java.lang.Boolean.
        Boolean booleanMeaning = new Boolean();
        LongMeaning longMeaning = new LongMeaning();
        DoubleMeaning doubleMeaning = new DoubleMeaning();

        // Convert each row into a JSON object of typed features.
        MemTable.RowsIterator rows = table.getInterruptibleRows();
        List<JsonObject> allFeatures = new ArrayList<>();
        for (int i = 0; i < batchSize && rows.hasNext(); ++i) {
            MemRow row = rows.next();
            JsonObject features = new JsonObject();
            for (Column column : table.columns()) {
                String columnName = column.getName();
                String raw = row.get(column);
                if (raw == null) {
                    if (shouldIncludeNulls) {
                        features.add(columnName, JsonNull.INSTANCE);
                    }
                    continue;
                }
                SchemaColumn schemaColumn = schema.getColumn(columnName);
                Type type = schemaColumn != null ? schemaColumn.getType() : Type.STRING;
                switch (type) {
                    case STRING:
                    case DATE:
                    case GEOPOINT:
                    case GEOMETRY:
                    case MAP:
                    case OBJECT:
                    case ARRAY: {
                        features.addProperty(columnName, raw);
                        break;
                    }
                    case DOUBLE:
                    case FLOAT: {
                        features.addProperty(columnName, (Number)doubleMeaning.doubleValue(raw));
                        break;
                    }
                    case BIGINT:
                    case INT:
                    case SMALLINT:
                    case TINYINT: {
                        features.addProperty(columnName, (Number)longMeaning.longValue(raw));
                        break;
                    }
                    case BOOLEAN: {
                        features.addProperty(columnName, booleanMeaning.parseNoFail(raw));
                        break;
                    }
                    default: {
                        // Any other column type is left out of the generated features.
                        break;
                    }
                }
            }
            addMissingPartitionDimensions(features, partition, schema);
            allFeatures.add(features);
        }

        // Names are derived from the dataset name, avoiding the names already in use.
        StringTransmogrifier namer = new StringTransmogrifier(" #", 1, Integer.valueOf(0), false, true);
        if (queriesName != null) {
            for (String existingName : queriesName) {
                namer.addAlreadyTransmogrified(existingName);
            }
        }

        List<ApiEndpointQuery> queries = new ArrayList<>();
        if (isTimeseriesForecast(savedModel)) {
            // Forecasting endpoints take all the rows at once, as a single query.
            JsonArray items = new JsonArray();
            for (JsonObject features : allFeatures) {
                items.add(features);
            }
            JsonObject payload = new JsonObject();
            payload.add("items", items);
            queries.add(new ApiEndpointQuery(namer.transmogrify(dataset.getName()), payload));
        } else {
            for (JsonObject features : allFeatures) {
                JsonObject payload = new JsonObject();
                payload.add("features", features);
                queries.add(new ApiEndpointQuery(namer.transmogrify(dataset.getName()), payload));
            }
        }
        return queries;
    }

    /** Tells whether the saved model is a prediction model of the time series forecast kind. */
    private static boolean isTimeseriesForecast(@Nullable SavedModel savedModel) {
        return savedModel != null
                && MLTask.MLTaskType.PREDICTION.equals(savedModel.miniTask.taskType)
                && PredictionMLTask.PredictionType.TIMESERIES_FORECAST.equals(((PredictionMLTask)savedModel.miniTask).predictionType);
    }

    /**
     * Adds to {@code features} the id of each dimension value of {@code partition} whose
     * dimension name is not a column of {@code schema}. Does nothing without a partition.
     */
    private static void addMissingPartitionDimensions(JsonObject features, @Nullable Partition partition, Schema schema) {
        if (partition == null) {
            return;
        }
        for (Map.Entry<?, ?> dimension : partition.getDimensionValues().entrySet()) {
            if (schema.getColumn((String)dimension.getKey()) != null) {
                continue;
            }
            features.addProperty((String)dimension.getKey(), ((DimensionValue)dimension.getValue()).id());
        }
    }

    /**
     * Picks the partition to associate with the sampled queries when the saved model is
     * partitioned, and restricts {@code selection} to it when the dataset can provide it.
     *
     * <ul>
     *   <li>No saved model, a non-partitioned one, or one without any partition: returns
     *       {@code null} and leaves {@code selection} untouched.</li>
     *   <li>Dataset partitioned with a partitioning schema equal to the saved model's and
     *       at least one partition in common: one common partition is selected on
     *       {@code selection} and returned.</li>
     *   <li>Otherwise: the first saved model partition is returned and {@code selection}
     *       is left untouched.</li>
     * </ul>
     *
     * @return the chosen partition, or {@code null} when there is none
     * @throws Exception if building a handler or listing partitions fails
     */
    private Partition readQueryPartitionToSelection(AuthCtx owner, SavedModel savedModel, Dataset dataset, DatasetSelectionToMemTable selection) throws Exception {
        if (savedModel == null || !savedModel.isPartitioned()) {
            return null;
        }

        List<? extends Partition> savedModelPartitions;
        try (PartitionableHandler smHandler = savedModel.buildHandler(owner)) {
            savedModelPartitions = smHandler.listPartitions();
        }
        if (savedModelPartitions.isEmpty()) {
            return null;
        }

        boolean samePartitioning = dataset.getPartitioningSchema().isPartitioned()
                && savedModel.getPartitioningSchema().equals(dataset.getPartitioningSchema());
        if (samePartitioning) {
            List<? extends Partition> datasetPartitions;
            try (PartitionableHandler dsHandler = dataset.buildHandler(owner)) {
                datasetPartitions = dsHandler.listPartitions();
            }
            Set<Partition> commonPartitions = Sets.intersection(Sets.<Partition>newHashSet(savedModelPartitions), Sets.<Partition>newHashSet(datasetPartitions));
            if (!commonPartitions.isEmpty()) {
                Partition shared = commonPartitions.iterator().next();
                selection.partitionSelectionMethod = DatasetSelection.PartitionSelectionMethod.SELECTED;
                selection.selectedPartitions = Lists.newArrayList(shared.id());
                return shared;
            }
        }
        // Nothing usable on the dataset side: fall back on the first saved model partition.
        return savedModelPartitions.iterator().next();
    }

    /** Results of a {@link #playTestQueries_NT} call: one entry per test query, in order. */
    public static class TestResponses {
        List<ResponseOrError> responses = new ArrayList<>();
    }

    /** Outcome of a single test query. */
    public static class ResponseOrError {
        /** Set when running the query raised an exception. */
        public APIError error;
        /** Response body: the parsed JSON element when it was announced as JSON and parsed, else the raw body. */
        public Object response;
        /** State of the dev server captured right after the query, successful or not. */
        public LambdaDevServerService.DevServerState serverState;
        /** Headers returned along with the response. */
        public List<Pair<String, String>> headers;
    }

    /** Kinds of test call, each mapped to the call name handed to the dev server client. */
    public static enum LambdaTestType {
        PREDICT("predict"),
        PREDICT_EFFECT("predict-effect"),
        FORECAST("forecast"),
        FUNCTION("run"),
        LOOKUP("lookup"),
        QUERY("query"),
        PROMPT("prompt");

        private final String callName;

        private LambdaTestType(String callName) {
            this.callName = callName;
        }

        /** @return the call name associated with this test type */
        public String getCallName() {
            return this.callName;
        }
    }
}

