/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.agents.tools.datasets;

import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolParams;
import com.dataiku.dip.agents.tools.AgentToolRunner;
import com.dataiku.dip.agents.tools.filtering.SimpleFilter;
import com.dataiku.dip.agents.tools.utils.JsonSchema;
import com.dataiku.dip.agents.tools.utils.JsonSchemaElement;
import com.dataiku.dip.classpathfix.DKUDoubles;
import com.dataiku.dip.classpathfix.DKULongs;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.datalayer.ProcessorOutput;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.datalayer.memimpl.MemTableAppendingOutput;
import com.dataiku.dip.datasets.SamplingParam;
import com.dataiku.dip.datasets.StreamableDatasetSelection;
import com.dataiku.dip.datasets.UniversalSingleThreadPusher;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.exceptions.ForbiddenObjectException;
import com.dataiku.dip.exceptions.UnauthorizedException;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.tickets.APITicketService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.ITaggingService;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import com.google.common.collect.Lists;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonNull;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Agent tool that looks up one or several records of a dataset, using a filter produced by the LLM.
 *
 * <p>{@link #META} describes the tool to the agent framework (type, params class, dependencies,
 * access checks, and the JSON schema exposed to the LLM) and builds the {@link Runner} that
 * performs the lookup by pushing the filtered dataset into an in-memory table.
 */
public class DatasetRowLookupTool {
    /** Tool metadata registered under the type {@code "DatasetRowLookup"}. */
    public static final AgentToolMeta META = new AgentToolMeta(false) {

        @Override
        public String getType() {
            return "DatasetRowLookup";
        }

        @Override
        public Class<? extends AgentToolParams> paramsClass() {
            return Params.class;
        }

        /**
         * Declares the looked-up dataset as the only dependency of the tool.
         *
         * @return a mutable list holding the dataset dependency, or an empty list when no dataset
         *     is selected
         */
        @Override
        public List<SavedModel.AgentDependency> getDependencies(AgentTool tool) {
            Params p = tool.getParamsCopyAs(Params.class);
            if (p.datasetRef == null) {
                return new ArrayList<SavedModel.AgentDependency>();
            }
            // Pass the element itself (not an Object[]) so that Lists.newArrayList infers
            // ArrayList<SavedModel.AgentDependency> rather than ArrayList<Object>.
            return Lists.newArrayList(new SavedModel.AgentDependency(ITaggingService.TaggableType.DATASET, p.datasetRef));
        }

        /**
         * Checks that the selected dataset is available in the tool's project, by delegating to
         * {@code ProjectsService.failIfLocNotAvailableInProject}. When no dataset is selected, only
         * a warning is logged.
         */
        @Override
        public void checkAccessDependency(AuthCtx authCtx, AgentTool tool) throws IOException, ForbiddenObjectException {
            Params p = tool.getParamsCopyAs(Params.class);
            if (p.datasetRef == null) {
                logger.warn((Object)"No dataset selected. Skipping access check to dependency.");
                return;
            }
            AnyLoc datasetLoc = AnyLoc.resolveSmart(tool.projectKey, p.datasetRef);
            ((ProjectsService)SpringUtils.getBean(ProjectsService.class)).failIfLocNotAvailableInProject(ITaggingService.TaggableType.DATASET, datasetLoc, tool.projectKey);
        }

        /**
         * Builds the descriptor shown to the LLM.
         *
         * <p>The description lists the columns the LLM may filter on, with their type and comment.
         * The input schema exposes a single {@code filter} property, which references
         * {@code #/$defs/filter}. When no dataset is selected, a description carrying an error
         * note is returned, without an input schema.
         *
         * <p>Must be called without an attached transaction; it opens its own read transaction to
         * load the dataset definition.
         */
        @Override
        public AgentToolMeta.ToolDescriptor getResultingDescriptor(AuthCtx authCtx, String projectKey, AgentTool tool) throws IOException {
            TransactionContext.assertNoAttachedTransaction();
            Params p = tool.getParamsCopyAs(Params.class);
            AgentToolMeta.ToolDescriptor td = new AgentToolMeta.ToolDescriptor(tool.name);
            boolean multipleRecords = p.retrievalMode == AnswerMode.MULTIPLE_RECORDS;
            StringBuilder description = new StringBuilder(multipleRecords
                    ? "Get records (up to " + p.maxRecords + ") from the dataset"
                    : "Get a single record from the dataset");
            if (p.datasetRef == null) {
                td.description = description.append(" (error: no dataset selected).").toString();
                return td;
            }
            description.append(' ').append(p.datasetRef).append("\n\n");
            description.append("The columns that are available for you to lookup are:\n\n");

            DatasetsDAO datasetsDAO = (DatasetsDAO)SpringUtils.getBean(DatasetsDAO.class);
            TransactionService transactionService = (TransactionService)SpringUtils.getBean(TransactionService.class);
            SerializedDataset sd;
            try (Transaction t = transactionService.beginRead()) {
                AnyLoc datasetLoc = AnyLoc.resolveSmart(tool.projectKey, p.datasetRef);
                sd = (SerializedDataset)datasetsDAO.getMandatoryUnsafe(datasetLoc);
            }

            // A null/empty allow-list means every schema column may be used in the filter.
            boolean restrictLookupColumns = p.usableLookupColumns != null && !p.usableLookupColumns.isEmpty();
            ArrayList<String> filterColumns = new ArrayList<String>();
            for (SchemaColumn col : sd.getSchema().getColumns()) {
                if (restrictLookupColumns && !p.usableLookupColumns.contains(col.getName())) {
                    continue;
                }
                filterColumns.add(col.getName());
                description.append("  * ").append(col.getName())
                        .append(" (type: ").append(String.valueOf(col.getType())).append(")");
                if (StringUtils.isNotBlank((String)col.comment)) {
                    description.append(" description: ").append(col.comment);
                }
                description.append("\n");
            }
            if (StringUtils.isNotBlank((String)tool.additionalDescriptionForLLM)) {
                description.append("\n\n").append(tool.additionalDescriptionForLLM);
            }
            td.description = description.toString();

            td.inputSchema = JsonSchema.newObject("https://dataiku.com/agents/tools/datasets/row-lookup/input", "Lookup settings for a row of a dataset");
            List<String> supportedOperators = List.of("EQUALS", "NOT_EQUALS", "GREATER_THAN", "GREATER_OR_EQUAL", "LESS_THAN", "LESS_OR_EQUAL", "DEFINED", "NOT_DEFINED", "CONTAINS", "MATCHES (regex)", "IN_ANY_OF", "IN_NONE_OF", "AND", "OR");
            JsonSchemaElement filter = SimpleFilter.jsonSchemaElement("#/$defs/filter", supportedOperators, filterColumns);
            filter.description = multipleRecords ? "The filter to search for the records" : "The filter to search for the record";
            td.inputSchema.$defs = new HashMap<>();
            td.inputSchema.$defs.put("filter", filter);
            td.inputSchema.properties.put("filter", JsonSchemaElement.ref("#/$defs/filter"));
            return td;
        }

        /**
         * Builds the confirmation message shown before the tool is called.
         *
         * <p>NOTE(review): the message announces "the following filter", but the filter carried by
         * {@code input} is never rendered. The dataset reference is also inserted into HTML
         * without escaping. Confirm whether either behavior is intended.
         */
        @Override
        public AgentToolMeta.ToolCallDescription getToolCallDescription_NT(AuthCtx authCtx, String projectKey, AgentTool tool, LLMClient.FunctionTool descriptor, AgentToolRunner.AgentToolInput input) {
            TransactionContext.assertNoAttachedTransaction();
            Params p = tool.getParamsCopyAs(Params.class);
            String description = String.format("I'm about to lookup in dataset <b>%s</b> with the following filter.%n", p.datasetRef)
                    + "\n"
                    + "Do you want to proceed?";
            return new AgentToolMeta.ToolCallDescription(description);
        }

        /** Creates a runner bound to the tool's own project key and a copy of its params. */
        @Override
        public AgentToolRunner buildRunner(AuthCtx authCtx, String projectKey, AgentTool tool, boolean devKernel) throws CodedException {
            return new Runner(authCtx, tool.projectKey, tool.getParamsCopyAs(Params.class));
        }
    };
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.agents.tools.datasets.lookup");

    /**
     * Executes lookups for one tool instance.
     *
     * <p>The {@code @Autowired} services are injected in {@link #init()} through
     * {@code SpringUtils.getInstance().autowire}, so {@code init()} must run before {@link #run}.
     */
    public static class Runner
    implements AgentToolRunner {
        private Dataset dataset;
        private final Params params;
        private final String sourceProjectKey;
        private final AuthCtx authCtx;
        @Autowired
        private DatasetsDAO datasetsDAO;
        @Autowired
        private TransactionService transactionService;
        @Autowired
        private APITicketService apiTicketService;

        public Runner(AuthCtx authCtx, String sourceProjectKey, Params p) {
            this.authCtx = authCtx;
            this.sourceProjectKey = sourceProjectKey;
            this.params = p;
        }

        /**
         * Injects services and loads the target dataset definition inside a read transaction.
         *
         * @throws IllegalArgumentException if no dataset reference is configured
         */
        @Override
        public void init() throws IOException {
            SpringUtils.getInstance().autowire((Object)this);
            if (StringUtils.isBlank((String)this.params.datasetRef)) {
                throw new IllegalArgumentException("Dataset to lookup is not specified in tool");
            }
            SerializedDataset sd;
            try (Transaction t = this.transactionService.beginRead()) {
                sd = (SerializedDataset)this.datasetsDAO.getMandatory(AnyLoc.resolveSmart(this.sourceProjectKey, this.params.datasetRef));
            }
            this.dataset = Dataset.fromSerialized(sd);
        }

        /**
         * Returns true when an allow-list of lookup columns is configured and the filter references
         * at least one column outside it.
         */
        private boolean filterContainsNonAllowedColumns(SimpleFilter sf) {
            if (this.params.usableLookupColumns == null || this.params.usableLookupColumns.isEmpty()) {
                return false;
            }
            for (String columnName : sf.listColumnsNamesRecursively()) {
                if (!this.params.usableLookupColumns.contains(columnName)) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Runs the lookup for the {@code filter} argument of {@code input}.
         *
         * <p>The output JSON takes one of three shapes:
         * <ul>
         *   <li>In single-record mode with a match, the record's columns are top-level keys.</li>
         *   <li>In single-record mode without a match, it is
         *       {@code {"response": "NO_MATCHING_RECORDS"}}.</li>
         *   <li>In multiple-records mode, it is {@code {"rows": [...]}}.</li>
         * </ul>
         * The same values are also attached as a {@code RECORDS} source.
         * When the filter uses columns outside the allow-list, a {@code status} failure object is
         * returned instead.
         */
        @Override
        public AgentToolRunner.AgentToolOutput run(AgentToolRunner.AgentToolInput input) throws Exception {
            SimpleFilter sf = this.safeParseArgument(input, SimpleFilter.class, "filter");
            if (this.filterContainsNonAllowedColumns(sf)) {
                AgentToolRunner.AgentToolOutput ret = new AgentToolRunner.AgentToolOutput();
                ret.output = JF.obj().with("status", "failed: the filter includes columns that are not allowed to be used for lookup").get();
                return ret;
            }
            StreamableDatasetSelection sds = new StreamableDatasetSelection();
            sds.filter = sf.toRealFilter();
            sds.filter.enabled = true;
            logger.info((Object)("Performing lookup in dataset " + this.dataset.getFullName() + " with  filter: " + JSON.log((Object)sds.filter)));
            sds.samplingMethod = SamplingParam.SamplingMethod.HEAD_SEQUENTIAL;
            sds.maxRecords = this.params.retrievalMode == AnswerMode.SINGLE_RECORD ? 1L : (long)this.params.maxRecords;
            MemTable table = new MemTable();
            MemTableAppendingOutput tableOutput = new MemTableAppendingOutput(table);
            AuthCtx datasetInteractionAuthCtx = this.resolveDatasetInteractionAuthCtx(input);
            logger.info((Object)("Performing lookup in dataset with mode: " + String.valueOf((Object)this.params.datasetInteractionUserMode) + " identity: " + datasetInteractionAuthCtx.getIdentifier()));
            UniversalSingleThreadPusher.push(datasetInteractionAuthCtx, this.dataset, sds, (ProcessorOutput)tableOutput, table, table);
            logger.info((Object)("Enumeration found " + table.nrows() + " rows"));

            // A null/empty list of returned columns means "all columns of the dataset schema".
            List<String> fetchedColumns = this.params.returnedColumns == null || this.params.returnedColumns.isEmpty()
                    ? this.dataset.getSchema().columns.stream().map(SchemaColumn::getName).collect(Collectors.toList())
                    : this.params.returnedColumns;
            AgentToolRunner.AgentToolOutput output = new AgentToolRunner.AgentToolOutput();
            LLMClient.SourceRecords sourceRecords = new LLMClient.SourceRecords();
            sourceRecords.columns = fetchedColumns;
            sourceRecords.data = new ArrayList<List<JsonElement>>();
            JsonObject out = new JsonObject();
            if (this.params.retrievalMode == AnswerMode.SINGLE_RECORD) {
                if (table.nrows() > 0) {
                    sourceRecords.data.add(this.convertRow(table, 0, fetchedColumns, out));
                } else {
                    out.addProperty("response", "NO_MATCHING_RECORDS");
                }
            } else {
                JsonArray rows = new JsonArray();
                for (int i = 0; i < table.nrows(); ++i) {
                    JsonObject row = new JsonObject();
                    sourceRecords.data.add(this.convertRow(table, i, fetchedColumns, row));
                    rows.add((JsonElement)row);
                }
                out.add("rows", (JsonElement)rows);
            }
            output.output = out;

            AgentToolRunner.Source source = new AgentToolRunner.Source();
            source.toolCallDescription = "Searched records in " + this.params.datasetRef;
            LLMClient.SourceItem item = new LLMClient.SourceItem();
            item.type = "RECORDS";
            item.records = sourceRecords;
            source.items.add(item);
            output.sources.add(source);
            return output;
        }

        /**
         * Chooses the identity used to read the dataset, according to
         * {@link Params#datasetInteractionUserMode}. The caller ticket is read from the
         * {@code dkuCallerTicket} context key of {@code input}.
         *
         * @throws UnauthorizedException if a caller ticket is present but cannot be resolved
         */
        private AuthCtx resolveDatasetInteractionAuthCtx(AgentToolRunner.AgentToolInput input) throws Exception {
            switch (this.params.datasetInteractionUserMode) {
                case AS_TOOL_RUNNER:
                    return this.authCtx;
                case AS_CALLER_IF_AVAILABLE: {
                    String dkuCallerTicket = this.safeReadStringContextKeyOrNull(input, "dkuCallerTicket");
                    // Without a caller ticket, fall back to the tool runner's identity.
                    return dkuCallerTicket == null ? this.authCtx : this.authCtxFromCallerTicket(dkuCallerTicket);
                }
                case AS_CALLER:
                    return this.authCtxFromCallerTicket(this.safeReadStringContextKey(input, "dkuCallerTicket"));
                default:
                    throw new IllegalStateException("Unsupported dataset interaction user mode: " + this.params.datasetInteractionUserMode);
            }
        }

        /** Resolves a caller ticket to its auth context, failing if the ticket is not recognized. */
        private AuthCtx authCtxFromCallerTicket(String dkuCallerTicket) throws Exception {
            AuthCtx ticketAuthCtx = this.apiTicketService.getTicketAuthCtx(dkuCallerTicket);
            if (ticketAuthCtx == null) {
                throw new UnauthorizedException("Invalid caller ticket", "invalid-ticket");
            }
            return ticketAuthCtx;
        }

        /**
         * Converts row {@code rowIndex} of {@code table}, restricted to {@code columns}. Each value
         * is added to {@code target} under its column name. The values are also returned in column
         * order, for the source records attached to the output.
         */
        private List<JsonElement> convertRow(MemTable table, int rowIndex, List<String> columns, JsonObject target) {
            List<JsonElement> values = new ArrayList<JsonElement>(columns.size());
            for (String column : columns) {
                JsonElement elt = this.eltFromValue(table.rows.get(rowIndex).get(column));
                values.add(elt);
                target.add(column, elt);
            }
            return values;
        }

        /**
         * Converts a raw cell value to JSON.
         *
         * <p>Conversion rules:
         * <ul>
         *   <li>{@code null} becomes JSON null.</li>
         *   <li>An integral value becomes a long.</li>
         *   <li>Another finite numeric value becomes a double.</li>
         *   <li>Everything else stays a string.</li>
         * </ul>
         * Non-finite doubles (NaN, Infinity) are kept as strings because they are not valid JSON
         * numbers.
         */
        private JsonElement eltFromValue(String strv) {
            if (strv == null) {
                return JsonNull.INSTANCE;
            }
            Long lv = DKULongs.tryParse(strv);
            if (lv != null) {
                return new JsonPrimitive(lv);
            }
            Double dv = DKUDoubles.tryParse(strv);
            if (dv != null && !dv.isNaN() && !dv.isInfinite()) {
                return new JsonPrimitive(dv);
            }
            return new JsonPrimitive(strv);
        }

        @Override
        public void close() throws Exception {
        }
    }

    /** Configuration of a dataset row lookup tool. */
    public static class Params
    implements AgentToolParams {
        /** Dataset to look up, resolved relative to the tool's project with {@code AnyLoc.resolveSmart}. */
        public String datasetRef;
        /** Columns the LLM may filter on; null or empty allows every schema column. */
        public List<String> usableLookupColumns;
        /** Columns returned for each record; null or empty returns every schema column. */
        public List<String> returnedColumns;
        /** Whether a single record or several records are returned. */
        public AnswerMode retrievalMode = AnswerMode.SINGLE_RECORD;
        /** Maximum number of records returned in {@link AnswerMode#MULTIPLE_RECORDS} mode. */
        public int maxRecords = 5;
        /** Identity used to read the dataset. */
        public DatasetInteractionUserMode datasetInteractionUserMode = DatasetInteractionUserMode.AS_TOOL_RUNNER;
    }

    /** Identity used by the runner when reading the dataset. */
    public static enum DatasetInteractionUserMode {
        /** Use the auth context the runner was built with. */
        AS_TOOL_RUNNER,
        /** Use the caller's ticket when the context provides one, else the runner's auth context. */
        AS_CALLER_IF_AVAILABLE,
        /** Require the caller's ticket from the context. */
        AS_CALLER;

    }

    /** Number of records a lookup returns. */
    public static enum AnswerMode {
        /** At most one record; the response is the record itself. */
        SINGLE_RECORD,
        /** Up to {@link Params#maxRecords} records under a {@code rows} array. */
        MULTIPLE_RECORDS;

    }
}

