/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.shaker.processors.expr;

import com.dataiku.dip.shaker.model.StepParams;
import com.dataiku.dip.shaker.server.ProcessorDesc;
import com.dataiku.dip.utils.DKULogger;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Splits text into chunks by recursively trying a prioritized list of separators.
 *
 * <p>For a given text, the first separator (in list order) that occurs in the text is used to cut it
 * into pieces. Pieces shorter than {@code chunkSize} are merged back together into chunks, keeping
 * up to {@code chunkOverlap} characters of trailing context from one chunk at the start of the next.
 * Pieces that are too long are split again with the remaining, lower-priority separators, or
 * emitted unchanged when no separator is left. The empty separator always applies and stops the
 * recursion.
 */
public class RecursiveCharacterTextSplitter {
    private final Parameter params;
    /** Compiled pattern per separator value, so each separator is compiled only once per splitter. */
    private final Map<String, Pattern> sepToRegex = new HashMap<>();
    // NOTE(review): not referenced by any code in this class; kept in case of external expectations.
    private static final DKULogger logger = DKULogger.getLogger("dip.shaker.processors.expr.split");

    public RecursiveCharacterTextSplitter(Parameter params) {
        this.params = params;
    }

    /**
     * Splits {@code inputText} into chunks using the enabled separators of the parameters, in order.
     *
     * @param inputText the text to split
     * @return the chunks in document order; may be empty (for example for an empty text)
     * @throws IndexOutOfBoundsException if no separator is enabled (see {@link Parameter#validate()})
     */
    public List<String> splitText(String inputText) {
        List<Separator> enabled = this.params.separators.stream()
                .filter(separator -> separator.enabled)
                .collect(Collectors.toList());
        return this.splitTextInternal(inputText, enabled);
    }

    /**
     * Declares the user-facing parameters of this splitter on the processor descriptor.
     *
     * <p>NOTE(review): {@link Parameter#stripWhitespace} is not declared here; confirm whether that
     * is intentional.
     */
    public static void withParams(ProcessorDesc pd) {
        pd.withMNESParam("separators", "");
        pd.withParam("chunkSize", "int", true, false, "");
        pd.withParam("chunkOverlap", "int", true, false, "");
        pd.withBool("keepSeparator", "");
        pd.withBool("isRegex", "");
    }

    /**
     * Returns the compiled pattern for a separator, compiling and caching it on first use.
     * The value is used verbatim as a regex when {@code isRegex}, otherwise it is matched literally.
     */
    private Pattern getReg(Separator separator) {
        return this.sepToRegex.computeIfAbsent(separator.value,
                value -> Pattern.compile(this.params.isRegex ? value : Pattern.quote(value)));
    }

    /**
     * Cuts {@code text} at every match of {@code separator}.
     *
     * <p>With {@code keepSeparator}, the text is cut just before each match, so every separator
     * stays at the start of the following piece. Empty pieces are dropped, and the separators
     * between pieces are all empty strings. Otherwise the matched text is removed from the pieces
     * and reported separately. In that case pieces may be empty, for example between two adjacent
     * matches.
     *
     * @return pieces plus the separators between them (always exactly one fewer than the pieces)
     */
    private SplitTextSeparators splitTextWithRegex(String text, Pattern separator) {
        if (this.params.keepSeparator) {
            // Zero-width lookahead: split before each match without consuming it.
            Pattern lookahead = Pattern.compile(String.format("(?=%s)", separator.pattern()));
            List<String> splits = Arrays.stream(lookahead.split(text))
                    .filter(Predicate.not(String::isEmpty))
                    .collect(Collectors.toList());
            if (splits.isEmpty()) {
                // Only an empty text yields no non-empty piece. Report a single empty piece so the
                // "n pieces, n - 1 separators" shape holds; nCopies(-1, ...) would throw here.
                return new SplitTextSeparators(List.of(""), List.of());
            }
            List<String> emptySeparators = Collections.nCopies(splits.size() - 1, "");
            return new SplitTextSeparators(splits, emptySeparators);
        }
        Matcher matcher = separator.matcher(text);
        List<String> splitTextList = new ArrayList<>();
        List<String> separatorsList = new ArrayList<>();
        int lastMatchEnd = 0;
        while (matcher.find()) {
            splitTextList.add(text.substring(lastMatchEnd, matcher.start()));
            separatorsList.add(matcher.group());
            lastMatchEnd = matcher.end();
        }
        splitTextList.add(text.substring(lastMatchEnd));
        return new SplitTextSeparators(splitTextList, separatorsList);
    }

    /**
     * Re-assembles consecutive pieces with the separators that originally sat between them.
     *
     * @param docs       the pieces, in order (must not be empty)
     * @param separators the separators between consecutive pieces ({@code docs.size() - 1} of them)
     * @return the joined text, stripped if {@code stripWhitespace}; {@code null} if it ends up empty
     * @throws IllegalArgumentException if the separator count does not match the piece count
     */
    private String joinDocs(List<String> docs, List<String> separators) {
        if (docs.size() - 1 != separators.size()) {
            throw new IllegalArgumentException("Number of separators (" + separators.size()
                    + ") must be equal to number of docs (" + docs.size() + ") - 1");
        }
        StringBuilder textBuilder = new StringBuilder(docs.get(0));
        for (int i = 1; i < docs.size(); ++i) {
            textBuilder.append(separators.get(i - 1)).append(docs.get(i));
        }
        String text = textBuilder.toString();
        if (this.params.stripWhitespace) {
            text = text.strip();
        }
        if (text.isEmpty()) {
            return null;
        }
        return text;
    }

    /**
     * Greedily merges small pieces into chunks of at most {@code chunkSize} characters (counting the
     * separators between pieces).
     *
     * <p>When a chunk is emitted, its leading pieces are dropped until at most {@code chunkOverlap}
     * characters remain and the next piece fits. The remaining pieces start the following chunk.
     *
     * @param splits     the pieces, each presumably shorter than {@code chunkSize}
     * @param separators the separators between consecutive pieces ({@code splits.size() - 1} of them)
     * @return the merged chunks; chunks that join to an empty string are omitted
     * @throws IllegalArgumentException if the separator count does not match the piece count
     */
    private List<String> mergeSplits(List<String> splits, List<String> separators) {
        if (splits.size() - 1 != separators.size()) {
            throw new IllegalArgumentException("Number of separators (" + separators.size()
                    + ") must be equal to number of splits (" + splits.size() + ") - 1");
        }
        List<Integer> sepLengths = separators.stream().map(String::length).collect(Collectors.toList());
        List<String> docs = new ArrayList<>();
        // The chunk in progress is the window splits[currentDocStart, currentDocEnd).
        List<String> currentDocs = new ArrayList<>();
        int currentDocStart = 0;
        int currentDocEnd = 0;
        // Length of the chunk in progress: its pieces plus the separators between them.
        int total = 0;
        for (int i = 0; i < splits.size(); ++i) {
            String split = splits.get(i);
            // Separator between the previous piece and this one.
            int sepLength = i == 0 ? 0 : sepLengths.get(i - 1);
            int len = split.length();
            if (!currentDocs.isEmpty() && total + len + sepLength > this.params.chunkSize) {
                String doc = this.joinDocs(currentDocs, separators.subList(currentDocStart, currentDocEnd - 1));
                if (doc != null) {
                    docs.add(doc);
                }
                // Drop leading pieces while the kept tail exceeds the overlap budget, or while the
                // incoming piece still would not fit.
                while (!currentDocs.isEmpty()
                        && (total > this.params.chunkOverlap
                                || total + len + sepLength > this.params.chunkSize && total > 0)) {
                    // The separator that leaves with the first piece is the one that FOLLOWED it.
                    // It is not the one in front of the incoming piece; with regex separators
                    // their lengths can differ.
                    int followingSepLength = currentDocs.size() > 1 ? sepLengths.get(currentDocStart) : 0;
                    total -= currentDocs.remove(0).length() + followingSepLength;
                    ++currentDocStart;
                }
            }
            currentDocs.add(split);
            ++currentDocEnd;
            total += len + (currentDocs.size() > 1 ? sepLength : 0);
        }
        String doc = this.joinDocs(currentDocs, separators.subList(currentDocStart, currentDocEnd - 1));
        if (doc != null) {
            docs.add(doc);
        }
        return docs;
    }

    /**
     * In regex mode, replaces non-breaking spaces (U+00A0) with plain spaces; otherwise returns the
     * text unchanged.
     *
     * <p>NOTE(review): presumably done because Java's {@code \s} does not match U+00A0 by default.
     * This also changes the emitted chunk text; confirm it is intended.
     */
    private String fixForSpecialChars(String text) {
        if (!this.params.isRegex) {
            return text;
        }
        return text.replace("\u00a0", " ");
    }

    /**
     * Recursive step: picks the first separator present in {@code text}, splits on it, merges the
     * small pieces and recurses with the lower-priority separators on the pieces that are too long.
     *
     * @param text       the text to split
     * @param separators the candidate separators in priority order (must not be empty)
     * @return the chunks of {@code text}, in order
     */
    private List<String> splitTextInternal(String text, List<Separator> separators) {
        List<String> finalChunks = new ArrayList<>();
        text = this.fixForSpecialChars(text);
        // Fallback when no separator occurs in the text: the last one, with nothing left to recurse on.
        Separator separator = separators.get(separators.size() - 1);
        List<Separator> newSeparators = Collections.emptyList();
        for (int i = 0; i < separators.size(); ++i) {
            Separator s = separators.get(i);
            if (s.value.isEmpty()) {
                // The empty separator always applies and ends the recursion.
                separator = s;
                break;
            }
            if (this.getReg(s).matcher(text).find()) {
                separator = s;
                newSeparators = separators.subList(i + 1, separators.size());
                break;
            }
        }
        SplitTextSeparators splitText = this.splitTextWithRegex(text, this.getReg(separator));
        // goodSeparators.get(k) is the separator that followed goodSplits.get(k) in the text.
        List<String> goodSplits = new ArrayList<>();
        List<String> goodSeparators = new ArrayList<>();
        for (int i = 0; i < splitText.splits.size(); ++i) {
            String split = splitText.splits.get(i);
            String actualSeparator = splitText.separators.get(i);
            if (split.length() < this.params.chunkSize) {
                goodSplits.add(split);
                goodSeparators.add(actualSeparator);
                continue;
            }
            if (!goodSplits.isEmpty()) {
                // Only the separators BETWEEN the buffered pieces take part in the merge.
                List<String> mergedText = this.mergeSplits(goodSplits, goodSeparators.subList(0, goodSplits.size() - 1));
                finalChunks.addAll(mergedText);
                goodSplits.clear();
                goodSeparators.clear();
            }
            if (newSeparators.isEmpty()) {
                // No finer separator left: emit the oversized piece as-is.
                finalChunks.add(split);
                continue;
            }
            finalChunks.addAll(this.splitTextInternal(split, newSeparators));
        }
        if (!goodSplits.isEmpty()) {
            List<String> mergedText = this.mergeSplits(goodSplits, goodSeparators.subList(0, goodSplits.size() - 1));
            finalChunks.addAll(mergedText);
        }
        return finalChunks;
    }

    /** User-configurable settings of the splitter. */
    public static class Parameter implements StepParams {
        /**
         * Candidate separators in priority order.
         *
         * <p>NOTE(review): the first two defaults contain a literal backslash followed by 'n'. With
         * {@code isRegex == false}, {@link RecursiveCharacterTextSplitter#getReg} quotes them, so
         * they would match the two-character text "\n" rather than a newline. That holds unless the
         * values are unescaped elsewhere; confirm.
         */
        public List<Separator> separators = List.of(Separator.asDefault("\\n\\n", "Double new lines"), Separator.asDefault("\\n", "New Lines"), Separator.asDefault(" ", "Spaces"), Separator.asDefault("", "Each character"));
        /** Pieces shorter than this are merged; merged chunks are kept within this many characters. */
        public int chunkSize = 4000;
        /** Maximum number of characters carried over from one merged chunk into the next. */
        public int chunkOverlap = 200;
        /** Keep each separator at the start of the following piece instead of between pieces. */
        public boolean keepSeparator = true;
        /** Interpret separator values as regular expressions instead of literal text. */
        public boolean isRegex = false;
        /** Strip leading/trailing whitespace of merged chunks (and drop chunks that become empty). */
        public boolean stripWhitespace = true;

        /**
         * Checks the settings.
         *
         * @throws IllegalArgumentException if chunkSize is not positive, chunkOverlap is not below
         *     chunkSize, no separator is enabled, or two enabled separators share the same value
         */
        public void validate() throws IllegalArgumentException {
            if (this.chunkSize <= 0) {
                throw new IllegalArgumentException("Chunk size must be greater than 0");
            }
            if (this.chunkOverlap >= this.chunkSize) {
                throw new IllegalArgumentException("Chunk overlap must be less than chunk size");
            }
            if (this.separators.isEmpty()) {
                throw new IllegalArgumentException("Separators must not be empty");
            }
            List<Separator> active = this.separators.stream()
                    .filter(separator -> separator.enabled)
                    .collect(Collectors.toList());
            if (active.isEmpty()) {
                throw new IllegalArgumentException("Separators must contain at least one enabled separator");
            }
            if (active.stream().map(s -> s.value).collect(Collectors.toSet()).size() != active.size()) {
                throw new IllegalArgumentException("All enabled separators must be different");
            }
        }
    }

    /** One candidate separator; an empty {@code value} means "split between every character". */
    public static class Separator implements Serializable {
        public String value;
        public boolean isDefault = false;
        public String description = "";
        public boolean enabled = true;

        public Separator(String value) {
            this.value = value;
        }

        /** Creates a separator flagged as one of the built-in defaults. */
        public static Separator asDefault(String value, String description) {
            Separator sep = new Separator(value);
            sep.isDefault = true;
            sep.description = description;
            return sep;
        }
    }

    /**
     * Result of one split: the pieces and, at index {@code k}, the separator that followed piece
     * {@code k}. A trailing "" is appended so both lists have the same size.
     */
    private static class SplitTextSeparators {
        final List<String> splits;
        final List<String> separators;

        public SplitTextSeparators(List<String> splits, List<String> separators) {
            if (splits.size() - 1 != separators.size()) {
                throw new IllegalArgumentException("Number of separators (" + separators.size()
                        + ") must be equal to number of splits (" + splits.size() + ") - 1");
            }
            this.splits = splits;
            this.separators = Stream.concat(separators.stream(), Stream.of("")).collect(Collectors.toList());
        }
    }
}

