/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.agents.tools.datasets;

import com.dataiku.dip.agents.tools.AgentTool;
import com.dataiku.dip.agents.tools.AgentToolMeta;
import com.dataiku.dip.agents.tools.AgentToolParams;
import com.dataiku.dip.agents.tools.AgentToolRunner;
import com.dataiku.dip.agents.tools.utils.JsonSchema;
import com.dataiku.dip.agents.tools.utils.JsonSchemaElement;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.coremodel.SerializedDataset;
import com.dataiku.dip.dao.DatasetsDAO;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.streamimpl.StreamColumnFactory;
import com.dataiku.dip.datalayer.streamimpl.StreamRowFactory;
import com.dataiku.dip.datasets.DatasetHandler;
import com.dataiku.dip.datasets.DatasetInspector;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.exceptions.ForbiddenObjectException;
import com.dataiku.dip.input.DatasetHandlerFactory;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.output.Output;
import com.dataiku.dip.output.OutputWriter;
import com.dataiku.dip.partitioning.Partition;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.notifications.backend.DatasetChangedEvent;
import com.dataiku.dip.server.services.ITaggingService;
import com.dataiku.dip.server.services.NeverBuiltComputablesCacheService;
import com.dataiku.dip.server.services.ProjectsService;
import com.dataiku.dip.server.services.PubSubService;
import com.dataiku.dip.server.services.TaggableObjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.AnyLoc;
import com.dataiku.dip.util.DatasetLocUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.warnings.WarningsContext;
import com.google.common.collect.Lists;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;

/**
 * Agent tool that appends a single record, supplied by the LLM as a JSON object called
 * {@code "record"}, to a configured dataset.
 *
 * <p>{@link #META} describes the tool to the agent framework (type, dependencies, access check,
 * LLM-facing descriptor, sample query) and {@link Runner} performs the actual write.
 */
public class DatasetRowAppendTool {
    // Declared before META so it is initialised first; META's methods log through it.
    private static final DKULogger logger = DKULogger.getLogger("dku.agents.tools.datasets.append");

    public static final AgentToolMeta META = new AgentToolMeta(true){

        @Override
        public String getType() {
            return "DatasetRowAppend";
        }

        @Override
        public Class<? extends AgentToolParams> paramsClass() {
            return Params.class;
        }

        /**
         * Returns the configured dataset as the tool's only dependency, or an empty list when no
         * dataset is selected.
         */
        @Override
        public List<SavedModel.AgentDependency> getDependencies(AgentTool tool) {
            Params p = tool.getParamsCopyAs(Params.class);
            List<SavedModel.AgentDependency> dependencies = new ArrayList<>();
            if (p.datasetRef != null) {
                dependencies.add(new SavedModel.AgentDependency(ITaggingService.TaggableType.DATASET, p.datasetRef));
            }
            return dependencies;
        }

        /**
         * Fails if the configured dataset is not available in the tool's project. Only logs a
         * warning when no dataset is selected.
         */
        @Override
        public void checkAccessDependency(AuthCtx authCtx, AgentTool tool) throws IOException, ForbiddenObjectException {
            Params p = tool.getParamsCopyAs(Params.class);
            if (p.datasetRef == null) {
                logger.warn("No dataset selected. Skipping access check to dependency.");
                return;
            }
            AnyLoc datasetLoc = AnyLoc.resolveSmart(tool.projectKey, p.datasetRef);
            ((ProjectsService)SpringUtils.getBean(ProjectsService.class))
                    .failIfLocNotAvailableInProject(ITaggingService.TaggableType.DATASET, datasetLoc, tool.projectKey);
        }

        /**
         * Builds the LLM-facing descriptor. The description lists every non-excluded column
         * (name, type, optional comment). The input schema exposes a {@code record} object with
         * one property per such column: numeric columns are typed as numbers and all others as
         * strings.
         *
         * @throws IllegalArgumentException if the dataset does not support appending
         */
        @Override
        public AgentToolMeta.ToolDescriptor getResultingDescriptor(AuthCtx authCtx, String projectKey, AgentTool tool) throws IOException {
            TransactionContext.assertNoAttachedTransaction();
            Params p = tool.getParamsCopyAs(Params.class);
            AgentToolMeta.ToolDescriptor td = new AgentToolMeta.ToolDescriptor(tool.name);
            StringBuilder description = new StringBuilder("Writes a single record to the dataset");
            if (p.datasetRef == null) {
                td.description = description.append(" (error: no dataset selected).").toString();
                return td;
            }
            description.append(" ").append(p.datasetRef).append("\n\n");
            description.append("Provide the record to write as a single JSON dictionary called \"record\", with one key per column.\nThe columns of the dataset are:\n");
            JsonSchemaElement record = JsonSchemaElement.object("The record to append");
            SerializedDataset sd = this.getDataset(tool.projectKey, p.datasetRef);
            for (SchemaColumn col : sd.getSchema().getColumns()) {
                String columnName = col.getName();
                if (isColumnExcluded(p.restrictColumns, columnName)) {
                    continue;
                }
                description.append("  * ").append(columnName).append(" (type: ").append(String.valueOf(col.getType())).append(")");
                if (StringUtils.isNotBlank(col.comment)) {
                    description.append(" description: ").append(col.comment);
                }
                description.append("\n");
                String propertyDescription = "Value for column " + columnName;
                if (col.getType().isNumeric()) {
                    record.properties.put(columnName, JsonSchemaElement.number(propertyDescription));
                } else {
                    record.properties.put(columnName, JsonSchemaElement.string(propertyDescription));
                }
            }
            if (StringUtils.isNotBlank(tool.additionalDescriptionForLLM)) {
                description.append("\n\n").append(tool.additionalDescriptionForLLM);
            }
            td.description = description.toString();
            td.inputSchema = JsonSchema.newObject("https://dataiku.com/agents/tools/datasets/append/input", "Write a record to a dataset");
            td.inputSchema.properties.put("record", record);
            return td;
        }

        /** Builds the confirmation text naming the target dataset, ending with a proceed prompt. */
        @Override
        public AgentToolMeta.ToolCallDescription getToolCallDescription_NT(AuthCtx authCtx, String projectKey, AgentTool tool, LLMClient.FunctionTool descriptor, AgentToolRunner.AgentToolInput input) {
            TransactionContext.assertNoAttachedTransaction();
            Params p = tool.getParamsCopyAs(Params.class);
            String description = String.format("I'm about to append the following record to dataset <b>%s</b>.%n", p.datasetRef)
                    + "\n"
                    + "Do you want to proceed?";
            return new AgentToolMeta.ToolCallDescription(description);
        }

        /**
         * Returns a sample tool input {@code {"record": {...}}} with one entry per non-excluded
         * column: {@code 1} for numeric columns and a placeholder string otherwise.
         *
         * @throws IllegalArgumentException if no dataset is selected or it cannot be appended to
         */
        @Override
        public JsonObject loadSampleQuery(AuthCtx authCtx, String projectKey, AgentTool tool) throws Exception {
            Params p = tool.getParamsCopyAs(Params.class);
            if (p.datasetRef == null) {
                throw new IllegalArgumentException("No dataset selected.");
            }
            SerializedDataset sd = this.getDataset(tool.projectKey, p.datasetRef);
            JsonObject quickTestRecord = new JsonObject();
            for (SchemaColumn col : sd.getSchema().getColumns()) {
                if (isColumnExcluded(p.restrictColumns, col.getName())) {
                    continue;
                }
                if (col.getType().isNumeric()) {
                    quickTestRecord.addProperty(col.getName(), 1);
                } else {
                    quickTestRecord.addProperty(col.getName(), "<Your value here>");
                }
            }
            JsonObject sampleInput = new JsonObject();
            sampleInput.add("record", quickTestRecord);
            return sampleInput;
        }

        /**
         * Loads the dataset definition inside a read transaction, then checks it can be appended to.
         *
         * @throws IllegalArgumentException if the dataset does not support appending
         */
        private SerializedDataset getDataset(String sourceProjectKey, String datasetRef) throws IOException {
            SerializedDataset sd;
            try (Transaction readTx = ((TransactionService)SpringUtils.getBean(TransactionService.class)).beginRead()) {
                AnyLoc datasetLoc = AnyLoc.resolveSmart(sourceProjectKey, datasetRef);
                sd = (SerializedDataset)((DatasetsDAO)SpringUtils.getBean(DatasetsDAO.class)).getMandatoryUnsafe(datasetLoc);
            }
            checkDatasetCompatibility(sd);
            return sd;
        }

        @Override
        public AgentToolRunner buildRunner(AuthCtx authCtx, String projectKey, AgentTool tool, boolean devKernel) throws CodedException {
            return new Runner(authCtx, tool.projectKey, tool.getParamsCopyAs(Params.class));
        }
    };

    /**
     * Returns true when a non-empty column restriction is configured and does not list
     * {@code columnName}. A null or empty restriction excludes nothing.
     */
    private static boolean isColumnExcluded(List<String> restrictColumns, String columnName) {
        return restrictColumns != null && !restrictColumns.isEmpty() && !restrictColumns.contains(columnName);
    }

    /**
     * Rejects datasets this tool cannot write to: datasets that
     * {@link DatasetInspector#isAppendableDataset} reports as non-appendable, and "Inline"
     * (editable) datasets.
     *
     * @throws IllegalArgumentException describing why the dataset is unsupported
     */
    private static void checkDatasetCompatibility(SerializedDataset sd) {
        Dataset dataset = Dataset.fromSerializedUnsafe(sd.getFullName(), sd);
        DatasetInspector.NotAppendableReason reason = DatasetInspector.isAppendableDataset(dataset);
        if (reason.noAppend) {
            StringBuilder msg = new StringBuilder("This dataset cannot be appended to.");
            if (reason.datasetFormat != null) {
                msg.append(String.format(" The dataset format '%s' does not allow appending.", reason.datasetFormat));
            } else if (reason.customPythonDataset) {
                msg.append(" The dataset is a custom python dataset and does not allow appending.");
            }
            throw new IllegalArgumentException(msg.toString());
        }
        if ("Inline".equals(sd.type)) {
            throw new IllegalArgumentException("Editable datasets are not supported by DatasetRowAppendTool");
        }
    }

    /**
     * Executes the tool. {@link #init()} resolves and validates the configured dataset; each
     * {@link #run} call appends one row built from the {@code record} argument.
     */
    public static class Runner
    implements AgentToolRunner {
        private Dataset dataset;
        private final Params params;
        private final String sourceProjectKey;
        private final AuthCtx authCtx;
        // The @Autowired fields below are injected by SpringUtils autowire() in init().
        @Autowired
        private DatasetsDAO datasetsDAO;
        @Autowired
        private NeverBuiltComputablesCacheService neverBuiltComputablesCacheService;
        @Autowired
        private PubSubService pubSub;
        @Autowired
        private TransactionService transactionService;

        public Runner(AuthCtx authCtx, String sourceProjectKey, Params p) {
            this.authCtx = authCtx;
            this.sourceProjectKey = sourceProjectKey;
            this.params = p;
        }

        /**
         * Autowires services, loads the dataset and validates the configuration.
         *
         * @throws IllegalArgumentException if no dataset is configured, the dataset cannot be
         *     appended to, or the column restriction excludes every dataset column
         */
        @Override
        public void init() throws IOException {
            SpringUtils.getInstance().autowire(this);
            if (StringUtils.isBlank(this.params.datasetRef)) {
                throw new IllegalArgumentException("Dataset to lookup is not specified in tool");
            }
            SerializedDataset sd;
            try (Transaction readTx = this.transactionService.beginRead()) {
                sd = (SerializedDataset)this.datasetsDAO.getMandatory(AnyLoc.resolveSmart(this.sourceProjectKey, this.params.datasetRef));
            }
            this.dataset = Dataset.fromSerialized(sd);
            checkDatasetCompatibility(sd);
            List<String> restrictColumns = this.params.restrictColumns;
            if (restrictColumns != null && !restrictColumns.isEmpty()) {
                boolean anyRestrictedColumnsUsed = false;
                for (SchemaColumn col : sd.getSchema().getColumns()) {
                    if (!isColumnExcluded(restrictColumns, col.getName())) {
                        anyRestrictedColumnsUsed = true;
                        break;
                    }
                }
                if (!anyRestrictedColumnsUsed) {
                    throw new IllegalArgumentException("All dataset columns are excluded. Either restrict the tool to existing columns or don't restrict it at all.");
                }
            }
        }

        /**
         * Appends the {@code record} argument as one row.
         *
         * <p>Unknown columns, columns excluded by the restriction, and array/object values are
         * skipped and reported under {@code "warning"}. JSON nulls leave the cell empty. If no
         * value can be written, the write is cancelled and a failed status is returned.
         *
         * @throws RuntimeException wrapping any failure while building or writing the row (the
         *     writer is cancelled first)
         */
        @Override
        public AgentToolRunner.AgentToolOutput run(AgentToolRunner.AgentToolInput input) throws Exception {
            JsonObject record = this.safeReadObjectArgument(input, "record");
            JF.ObjectBuilder response = JF.obj();
            try (DatasetHandler dh = DatasetHandlerFactory.build(this.authCtx, this.dataset)) {
                WarningsContext wc = new WarningsContext();
                Output output = dh.buildOutput(Partition.newNP(), 0, 1, wc);
                StreamColumnFactory cf = new StreamColumnFactory();
                StreamRowFactory rf = new StreamRowFactory();
                OutputWriter wr = output.getWriter(Output.WriteMode.APPEND);
                try {
                    wr.init(cf);
                    Row row = rf.row();
                    List<String> ignoredColumns = new ArrayList<>();
                    boolean sawSupportedColumn = this.fillRow(record, cf, row, wc, ignoredColumns);
                    if (!ignoredColumns.isEmpty()) {
                        response = response.with("warning", "ignored columns: " + String.join(", ", ignoredColumns));
                    }
                    if (!sawSupportedColumn) {
                        AgentToolRunner.AgentToolOutput failed = new AgentToolRunner.AgentToolOutput();
                        failed.output = response.with("status", "failed: the new row must have at least one existing and non-restricted column").get();
                        wr.cancel();
                        return failed;
                    }
                    // Emitting and finalising are inside the protected region so that a failure
                    // here also cancels the writer instead of leaving it half-open.
                    wr.emitRow(row);
                    wr.lastRowEmitted();
                }
                catch (Throwable t) {
                    cancelAfterFailure(wr, t);
                    throw new RuntimeException("Failed to write", t);
                }
                this.invalidateCaches(dh);
            }
            AgentToolRunner.AgentToolOutput ret = new AgentToolRunner.AgentToolOutput();
            ret.output = response.with("status", "ok").get();
            return ret;
        }

        /**
         * Copies the writable entries of {@code record} into {@code row}. Keys that cannot be
         * written (unknown, restricted, or array/object values) are appended to
         * {@code ignoredColumns}.
         *
         * @return true if at least one value was put into the row
         */
        private boolean fillRow(JsonObject record, StreamColumnFactory cf, Row row, WarningsContext wc, List<String> ignoredColumns) throws Exception {
            boolean sawSupportedColumn = false;
            for (Map.Entry<String, JsonElement> entry : record.entrySet()) {
                String columnName = entry.getKey();
                JsonElement value = entry.getValue();
                if (cf.getColumn(columnName) == null) {
                    wc.addWarning(WarningsContext.WarningType.LLM_TOOL_UNKNOWN_COLUMN, "Not appending unknown column: " + columnName, logger);
                    ignoredColumns.add(columnName);
                    continue;
                }
                if (isColumnExcluded(this.params.restrictColumns, columnName)) {
                    // Report it rather than silently dropping the LLM's value.
                    ignoredColumns.add(columnName);
                    continue;
                }
                if (value.isJsonNull()) {
                    // Explicit null: leave the cell empty.
                    continue;
                }
                if (!value.isJsonPrimitive()) {
                    // Arrays/objects are not written; report them so the LLM knows.
                    ignoredColumns.add(columnName);
                    continue;
                }
                sawSupportedColumn = true;
                putValue(row, cf.column(columnName), value.getAsJsonPrimitive());
            }
            return sawSupportedColumn;
        }

        /**
         * Puts a JSON scalar into the row. Integral numbers that fit in a long are stored as long,
         * other numbers as double, and strings/booleans as their string form.
         *
         * <p>Gson's {@code getAsLong()} does not throw on fractional numbers; it truncates them
         * (3.7 becomes 3) and wraps out-of-range values. Exactness is therefore checked through
         * {@code BigDecimal.longValueExact()} instead.
         */
        private static void putValue(Row row, Column column, JsonPrimitive value) throws Exception {
            if (!value.isNumber()) {
                row.put(column, value.getAsString());
                return;
            }
            long integral;
            try {
                integral = value.getAsBigDecimal().longValueExact();
            }
            catch (ArithmeticException | NumberFormatException notAnExactLong) {
                row.put(column, value.getAsDouble());
                return;
            }
            row.put(column, integral);
        }

        /** Cancels the writer after a failure, keeping the original failure as the primary cause. */
        private static void cancelAfterFailure(OutputWriter wr, Throwable cause) {
            try {
                wr.cancel();
            }
            catch (Exception cancelFailure) {
                if (cancelFailure != cause) {
                    cause.addSuppressed(cancelFailure);
                }
            }
        }

        @Override
        public void close() throws Exception {
        }

        /**
         * Removes the dataset from the never-built cache if present, then publishes a
         * {@link DatasetChangedEvent}. The event carries every partition listed by the handler
         * for partitioned datasets, and "NP" otherwise or when no partition is listed.
         */
        private void invalidateCaches(DatasetHandler dh) throws Exception {
            DatasetLocUtils.DatasetLoc loc = DatasetLocUtils.resolveFull(this.dataset.getFullName());
            SerializedDataset sd = this.dataset.getModel();
            TaggableObjectsService.TaggableObjectRef ref = new TaggableObjectsService.TaggableObjectRef(loc.getProjectKey(), sd.getTaggableType(), sd.getId());
            if (this.neverBuiltComputablesCacheService.contains(ref)) {
                this.neverBuiltComputablesCacheService.remove(ref);
            }
            List<String> partitionsOrNP = new ArrayList<>();
            if (this.dataset.getPartitioningSchema().isPartitioned()) {
                for (Partition p : dh.listPartitions()) {
                    partitionsOrNP.add(p.id());
                }
            }
            if (partitionsOrNP.isEmpty()) {
                partitionsOrNP = List.of("NP");
            }
            this.pubSub.publish(new DatasetChangedEvent(loc, partitionsOrNP));
        }
    }

    /** Tool configuration. */
    public static class Params
    implements AgentToolParams {
        /** Reference to the target dataset, resolved relative to the tool's project. */
        public String datasetRef;
        /** If non-empty, only these columns are exposed to and writable by the LLM. */
        public List<String> restrictColumns;
    }
}

