/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.datalayer.utils;

import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.FilteringProcessorOutput;
import com.dataiku.dip.datalayer.ProcessorOutput;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.shaker.facet.CountMap;
import com.dataiku.dip.utils.DKULogger;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.mutable.MutableInt;

/**
 * Filtering output that forwards each incoming row downstream with a probability that depends
 * on the value the row holds in one column.
 *
 * <p>The probabilities are derived once, at construction time, from a map of per-value counts:
 * each value gets {@code targetCountPerValue / countOfThatValue}, so frequent values are
 * down-sampled more aggressively than rare ones. Because every row is kept or dropped by an
 * independent random draw, the resulting per-value counts are only approximately balanced.
 */
public class ColumnRebalanceApproximateProcessorOutput
extends FilteringProcessorOutput {
    /**
     * Key looked up in the probabilities map for rows whose column value is blank.
     * NOTE(review): presumably the same key the count map uses for missing values; confirm
     * against the code that builds the map.
     */
    private static final String NO_VALUE_KEY = "__dku_no_value__";

    private static final DKULogger logger = DKULogger.getLogger("dku.sampling.rebalance");

    /** Emission probability for each known column value (may be greater than 1). */
    private final Map<String, Double> probasMap = new HashMap<String, Double>();
    /** Name of the column whose values drive the sampling. */
    private final String column;
    private final ColumnFactory cf;
    private final Random rnd;
    /** Column handle, resolved lazily on the first emitted row. */
    private Column cd;

    /**
     * Builds a rebalancing output from an absolute target number of records.
     *
     * @param downstream    output that receives the rows that are kept
     * @param cf            factory used to resolve {@code column} into a column handle
     * @param column        name of the column to rebalance on
     * @param targetRecords total number of records aimed for across all values
     * @param map           per-value counts used to compute the emission probabilities
     * @param seed          seed for the random generator, or {@code null} for an unseeded one
     */
    public ColumnRebalanceApproximateProcessorOutput(ProcessorOutput downstream, ColumnFactory cf, String column, long targetRecords, CountMap<String> map, Long seed) {
        super(downstream);
        this.cf = cf;
        this.column = column;
        this.rnd = newRandom(seed);
        this.setProbasMap(targetRecords, map);
    }

    /**
     * Builds a rebalancing output from a target expressed as a ratio of the total count.
     *
     * @param downstream  output that receives the rows that are kept
     * @param cf          factory used to resolve {@code column} into a column handle
     * @param column      name of the column to rebalance on
     * @param targetRatio multiplied by {@code map.getTotalCount()} to obtain the target number
     *                    of records
     * @param map         per-value counts used to compute the emission probabilities
     * @param seed        seed for the random generator, or {@code null} for an unseeded one
     */
    public ColumnRebalanceApproximateProcessorOutput(ProcessorOutput downstream, ColumnFactory cf, String column, double targetRatio, CountMap<String> map, Long seed) {
        super(downstream);
        this.cf = cf;
        this.column = column;
        this.rnd = newRandom(seed);
        double targetRecords = (double)map.getTotalCount() * targetRatio;
        this.setProbasMap(targetRecords, map);
    }

    /** Returns a generator seeded with {@code seed}, or an unseeded one when it is {@code null}. */
    private static Random newRandom(Long seed) {
        return seed != null ? new Random(seed) : new Random();
    }

    /**
     * Fills {@link #probasMap} with one emission probability per value found in {@code map}.
     *
     * <p>The starting per-value target is {@code targetRecords / map.size()}. Values whose count
     * is below that target cannot reach it, so their shortfall is summed and, when the guard
     * below holds, spread evenly over the values that do reach the target by raising the
     * per-value target. The raised target is then used for every value, including the small
     * ones.
     *
     * @param targetRecords total number of records aimed for
     * @param map           per-value counts
     */
    private void setProbasMap(double targetRecords, CountMap<String> map) {
        double targetCountPerModality = targetRecords / (double)map.size();
        long missingRecordsInSmallClasses = 0L;
        long largeClasses = 0L;
        for (Map.Entry<String, MutableInt> entry : map) {
            long count = entry.getValue().longValue();
            if ((double)count < targetCountPerModality) {
                // The target is truncated to a whole number before computing the shortfall.
                missingRecordsInSmallClasses += (long)targetCountPerModality - count;
            } else {
                ++largeClasses;
            }
        }
        logger.infoV("Column rebalance targetRecords=%f targetCountPerModality=%f missing=%d largeClasses=%d", new Object[]{targetRecords, targetCountPerModality, missingRecordsInSmallClasses, largeClasses});
        if (missingRecordsInSmallClasses > 0L && (double)missingRecordsInSmallClasses < targetRecords && largeClasses > 0L) {
            targetCountPerModality += (double)missingRecordsInSmallClasses / (double)largeClasses;
        }
        // The original format string ended in "%fd", which printed a stray literal 'd'.
        logger.infoV("Column rebalance fixed targetCountPerModality=%f", new Object[]{targetCountPerModality});
        for (Map.Entry<String, MutableInt> entry : map) {
            long count = entry.getValue().longValue();
            // Not clamped: a value above 1 simply means every row with that value is kept.
            double proba = targetCountPerModality / (double)count;
            logger.infoV(" value=%s totalCount=%d proba=%f", new Object[]{entry.getKey(), count, proba});
            this.probasMap.put(entry.getKey(), proba);
        }
    }

    /**
     * Forwards {@code row} downstream with the probability associated with its column value.
     *
     * <p>The row is dropped, with a warning, when the column cannot be resolved or when its
     * value has no entry in the probabilities map. Column resolution is retried on every call
     * until it succeeds.
     *
     * @param row the row to keep or drop
     * @throws Exception declared by the method signature; the downstream {@code emitRow} call
     *                   is a likely source
     */
    public void emitRow(Row row) throws Exception {
        if (this.cd == null) {
            this.cd = this.cf.getColumn(this.column);
        }
        if (this.cd == null) {
            logger.warn((Object)("Column " + this.column + " not found, not emitting"));
            return;
        }
        String v = row.get(this.cd);
        if (StringUtils.isBlank(v)) {
            v = NO_VALUE_KEY;
        }
        Double proba = this.probasMap.get(v);
        if (proba == null) {
            logger.warn((Object)("Unknown value: " + v));
            return;
        }
        // nextDouble() is in [0, 1), so any probability >= 1 always emits.
        double r = this.rnd.nextDouble();
        if (r < proba) {
            this.downstream.emitRow(row);
        }
    }
}

