/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.datalayer.utils;

import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.SinkProcessorOutput;
import com.dataiku.dip.datalayer.memimpl.MemRow;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.datalayer.memimpl.MemTableAppendingOutput;
import com.dataiku.dip.shaker.facet.CountMap;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.warnings.WarningsContext;
import com.dataiku.dip.warnings.WithWarningsContext;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.mutable.MutableInt;

/**
 * Row sink that keeps one bounded reservoir of rows per distinct value
 * ("modality") of a stratification column, and appends the content of every
 * reservoir to the target {@link MemTable} once the stream ends.
 *
 * <p>The size of each reservoir is fixed at construction time from a
 * {@link CountMap} holding the number of rows per modality. Rows whose
 * stratification value is blank are looked up under the
 * {@code "__dku_no_value__"} key.
 *
 * <p>Note: the replacement slot is drawn with {@link Random#nextInt(int)}, so
 * the rows selected for a given seed differ from those selected by earlier
 * revisions that used an unbounded draw followed by a modulo.
 */
public class StratifiedReservoirSamplingProcessorOutput
extends SinkProcessorOutput
implements WithWarningsContext {
    /** Reservoir key used for rows whose stratification value is blank. */
    private static final String NO_VALUE_KEY = "__dku_no_value__";
    private static final DKULogger logger = DKULogger.getLogger("dku.sampling.stratified");

    /** Stratification column, resolved from the table on the first emitted row. */
    private Column cd;
    private WarningsContext wc;
    private final String column;
    private final MemTable table;
    private final Map<String, ModalityReservoir> reservoirs = new HashMap<String, ModalityReservoir>();
    private final Random rnd;
    /** Memory budget for the retained rows; values {@code <= 0} disable the check. */
    private long maxMemoryUsed = -1L;
    /** Sum of {@code getMemoryUsed()} over the rows currently held in the reservoirs. */
    private long memoryUsed = 0L;

    /**
     * Builds a sampler targeting an absolute number of records, split between
     * modalities proportionally to their counts (each share is rounded up).
     *
     * @param table         table receiving the sampled rows
     * @param column        name of the stratification column
     * @param targetRecords total number of records wanted across all modalities
     * @param map           number of rows per modality
     * @param seed          seed of the random generator, or {@code null} for an unseeded generator
     */
    public StratifiedReservoirSamplingProcessorOutput(MemTable table, String column, int targetRecords, CountMap<String> map, Long seed) {
        this.table = table;
        this.column = column;
        this.rnd = seed != null ? new Random(seed) : new Random();
        long totalRecords = map.getTotalCount();
        logger.infoV("Stratified reservoir targetRecords=%d totalRecords=%d", new Object[]{targetRecords, totalRecords});
        for (Map.Entry<String, MutableInt> entry : map) {
            // An empty count map would otherwise divide by zero.
            int targetSize = totalRecords == 0L ? 0 : (int)Math.ceil((double)entry.getValue().longValue() * (double)targetRecords / (double)totalRecords);
            this.registerReservoir(entry.getKey(), targetSize);
        }
    }

    /**
     * Builds a sampler keeping a fixed ratio of each modality (rounded up).
     *
     * @param table       table receiving the sampled rows
     * @param column      name of the stratification column
     * @param targetRatio fraction of each modality's rows to keep
     * @param map         number of rows per modality
     * @param seed        seed of the random generator, or {@code null} for an unseeded generator
     */
    public StratifiedReservoirSamplingProcessorOutput(MemTable table, String column, double targetRatio, CountMap<String> map, Long seed) {
        this.table = table;
        this.column = column;
        this.rnd = seed != null ? new Random(seed) : new Random();
        long totalRecords = map.getTotalCount();
        logger.infoV("Stratified reservoir targetRatio=%.2f totalRecords=%d", new Object[]{targetRatio, totalRecords});
        for (Map.Entry<String, MutableInt> entry : map) {
            int targetSize = (int)Math.ceil((double)entry.getValue().longValue() * targetRatio);
            this.registerReservoir(entry.getKey(), targetSize);
        }
    }

    /**
     * Sets the memory budget for retained rows.
     *
     * @param size budget compared against the accumulated {@code getMemoryUsed()}
     *             of the retained rows; a value {@code <= 0} disables the check
     */
    public void setMaxMemoryUsed(long size) {
        this.maxMemoryUsed = size;
    }

    /**
     * Offers a row to the reservoir of its modality.
     *
     * <p>The row is dropped when the stratification column cannot be resolved
     * (a warning is recorded if a warnings context is set), when its modality
     * has no reservoir, or when that reservoir has a target size of 0.
     *
     * @param row the row to consider; it is cast to {@link MemRow}
     * @throws MemTableAppendingOutput.MemTableSizeLimitReachedException when the
     *         memory budget is reached; the rows retained so far are appended to
     *         the table before the exception is thrown
     */
    public void emitRow(Row row) throws Exception {
        if (this.cd == null) {
            this.cd = this.table.getColumn(this.column);
            if (this.cd == null) {
                if (this.wc != null) {
                    this.wc.addWarning(WarningsContext.WarningType.SAMPLING_BAD_COLUMN, "Column does not exist: " + this.column, logger);
                }
                return;
            }
        }
        String v = row.get(this.cd);
        if (StringUtils.isBlank(v)) {
            v = NO_VALUE_KEY;
        }
        ModalityReservoir mr = this.reservoirs.get(v);
        if (mr == null) {
            logger.warn((Object)("Unknown value: " + v));
            return;
        }
        if (mr.targetSize == 0) {
            logger.warn((Object)("Cannot fill 0-sized reservoir: " + v));
            return;
        }
        MemRow memRow = (MemRow)row;
        if (mr.alreadySent < mr.targetSize) {
            // Filling phase: every row is kept until the reservoir is full.
            mr.rows.add(memRow);
            this.memoryUsed += memRow.getMemoryUsed();
        } else {
            // Replacement phase. alreadySent counts the rows seen before this
            // one, so this row is candidate number alreadySent + 1 and the slot
            // must be drawn uniformly in [0, alreadySent]. The bounded draw also
            // removes the negative index previously obtained when the raw draw
            // was Integer.MIN_VALUE (whose negation is still negative).
            int slot = this.rnd.nextInt(mr.alreadySent + 1);
            if (slot < mr.targetSize) {
                MemRow oldRow = mr.rows.get(slot);
                this.memoryUsed -= oldRow.getMemoryUsed();
                mr.rows.set(slot, memRow);
                this.memoryUsed += memRow.getMemoryUsed();
            }
        }
        if (this.maxMemoryUsed > 0L && this.memoryUsed >= this.maxMemoryUsed) {
            // Hand over what has been sampled so far before aborting.
            this.flushReservoirsToTable();
            throw new MemTableAppendingOutput.MemTableSizeLimitReachedException(this.table.rows.size(), this.memoryUsed);
        }
        ++mr.alreadySent;
    }

    /**
     * Appends the content of every reservoir to the table, then delegates to
     * the parent implementation.
     */
    public void lastRowEmitted() throws Exception {
        logger.info((Object)"Reservoir sampling done");
        this.flushReservoirsToTable();
        super.lastRowEmitted();
    }

    /**
     * Sets the context used to report a missing stratification column.
     *
     * @param wc warnings context; when {@code null} no warning is recorded
     */
    public void setWarningsContext(WarningsContext wc) {
        this.wc = wc;
    }

    /** Creates, logs and registers the reservoir of one modality. */
    private void registerReservoir(String value, int targetSize) {
        ModalityReservoir mr = new ModalityReservoir();
        mr.targetSize = targetSize;
        logger.infoV("Init reservoir for value=%s targetSize=%d", new Object[]{value, mr.targetSize});
        this.reservoirs.put(value, mr);
    }

    /** Appends the rows currently held by every reservoir to the table. */
    private void flushReservoirsToTable() {
        for (ModalityReservoir reservoir : this.reservoirs.values()) {
            this.table.rows.addAll(reservoir.rows);
        }
    }

    /** Sampling state of a single modality. */
    private static class ModalityReservoir {
        /** Rows currently retained for this modality. */
        List<MemRow> rows = new ArrayList<MemRow>();
        /** Maximum number of rows to retain. */
        int targetSize;
        /** Number of rows of this modality processed so far. */
        int alreadySent;

        private ModalityReservoir() {
        }
    }
}

