/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.pivot.backend.dss.aggregators;

import com.dataiku.dip.pivot.backend.dss.AxisHandler;
import com.dataiku.dip.pivot.backend.dss.LongDataTensor;
import com.dataiku.dip.pivot.backend.dss.aggregators.AbstractAggregator;
import com.dataiku.dip.pivot.backend.dss.aggregators.DoubleAggregator;
import com.dataiku.dip.pivot.backend.model.Aggregation;
import com.dataiku.dip.pivot.backend.model.AxisElt;
import com.dataiku.dip.pivot.backend.model.PivotTableTensorRequest;
import java.io.IOException;

/**
 * Aggregator computing the variance or standard deviation (sample or population flavour, as
 * selected by {@code req.function}) of numeric values.
 *
 * <p>The running mean of each cell (and, when axes are filled, of each axis bin) is stored in the
 * underlying double tensor, while the matching sum of squared distances from that mean (Welford's
 * M2) is stored in parallel arrays owned by this class. Single values are folded in with Welford's
 * online update; partial states are combined with the pairwise (Chan et al.) merge formula.
 */
public class VarianceAggregator extends DoubleAggregator {
    private static final String ERROR_MESSAGE = "Cannot compute variance or standard deviation of non numeric values.";

    /** M2 for each cell of the output tensor, indexed by the tensor's {@code loc}. */
    private double[] squaredOutDT;
    /** M2 for each axis bin: {@code axeSquaredOutDT[axis][bin]}; null when the output tensor has no double axes. */
    private double[][] axeSquaredOutDT;
    /** M2 for each cell of the merge tensor; allocated by {@link #initMerge(int[])}. */
    private double[] squaredMergeDT;
    /** True for STDEV_POPULATION / VARIANCE_POPULATION (divide by n); false for sample flavours (divide by n - 1). */
    private boolean isPopulationAggregation = false;
    /** True for STDEV / STDEV_POPULATION, i.e. the result is the square root of the variance. */
    private boolean isSTDEV = false;

    public VarianceAggregator(Aggregation req) {
        super(req, ERROR_MESSAGE);
        this.init();
    }

    public VarianceAggregator(Aggregation req, int[] numBins) {
        super(req, numBins, ERROR_MESSAGE);
        this.init();
    }

    public VarianceAggregator(Aggregation req, int bins) {
        super(req, bins, ERROR_MESSAGE);
        this.init();
    }

    /**
     * Derives the aggregation flavour from the request and allocates the M2 buffers so that they
     * match the output tensor (one slot per cell, plus one per axis bin when double axes exist).
     */
    private void init() {
        this.isPopulationAggregation = this.req.function == Aggregation.Function.STDEV_POPULATION
                || this.req.function == Aggregation.Function.VARIANCE_POPULATION;
        this.isSTDEV = this.req.function == Aggregation.Function.STDEV
                || this.req.function == Aggregation.Function.STDEV_POPULATION;
        this.squaredOutDT = new double[this.getOutDT().tensorSize];
        double[][] axes = this.getOutDT().getDoubleAxes();
        if (axes != null) {
            // Allocate only the outer array: each row is sized to its own axis below. Sizing the
            // rows from axes[0] was wasted work and failed outright when there were no axes.
            this.axeSquaredOutDT = new double[axes.length][];
            for (int i = 0; i < axes.length; ++i) {
                this.axeSquaredOutDT[i] = new double[axes[i].length];
            }
        }
    }

    /**
     * Folds {@code value} into the running mean/M2 of the cell at {@code coords} and, when
     * {@code fillAxes} is set, of the bin {@code coords[i]} on every axis {@code i}.
     * Non-finite values (NaN, infinities) are not accumulated.
     */
    @Override
    protected void handle(double value, int[] coords, boolean fillAxes) throws IOException {
        if (Double.isFinite(value)) {
            if (fillAxes) {
                // Counts are incremented first: welfordUpdate expects the count including the new value.
                this.getOutDT().incrementAxesNonNullCount(coords);
                for (int i = 0; i < this.getOutDT().numAxes; ++i) {
                    int bin = coords[i];
                    WelfordState axisUpdate = this.welfordUpdate(value,
                            this.getOutDT().getAxisAsDouble(i, bin),
                            this.axeSquaredOutDT[i][bin],
                            this.getOutDT().getAxesNonNullCount(i, bin));
                    this.getOutDT().setAxisAsDouble(i, bin, axisUpdate.mean);
                    this.axeSquaredOutDT[i][bin] = axisUpdate.squaredDistanceFromMean;
                }
            }
            this.getOutDT().incrementNonNullCount(coords);
            int loc = this.getOutDT().loc(coords);
            WelfordState update = this.welfordUpdate(value,
                    this.getOutDT().getAsDouble(coords),
                    this.squaredOutDT[loc],
                    this.getOutDT().getNonNullCount(coords));
            this.getOutDT().setAsDouble(coords, update.mean);
            this.squaredOutDT[loc] = update.squaredDistanceFromMean;
        }
        // NOTE(review): this flag is set for every call, finite values included; confirm this is
        // intended rather than belonging only to the non-finite branch.
        this.getOutDT().hasNullValues = true;
    }

    /** Prepares the merge tensor (via the superclass) and a zeroed M2 buffer of the same size. */
    @Override
    public void initMerge(int[] axisLengths) {
        super.initMerge(axisLengths);
        this.squaredMergeDT = new double[this.getMergeDT().tensorSize];
    }

    /**
     * Combines the states of every cutoff element of {@code axis} into a single "others" state.
     *
     * @return the combined non-null count and the variance/stddev of the combined group
     */
    @Override
    public AbstractAggregator.OtherCategoryProperties<Double> retrieveOthersCategoryProperties(AxisHandler.Axis axis) {
        AbstractAggregator.OtherCategoryProperties<Double> otherCategoryProperties = new AbstractAggregator.OtherCategoryProperties<Double>();
        WelfordState combined = new WelfordState(0.0, 0.0, 0.0);
        for (int x = 0; x < axis.elts.size(); ++x) {
            if (axis.elts.get(x).cutoffed) {
                int origBin = axis.elts.get(x).binIndex;
                WelfordState binState = new WelfordState(this.getOutDT().getAsDouble(origBin),
                        this.squaredOutDT[origBin],
                        this.getOutDT().nonNullCounts[origBin]);
                combined = this.welfordMerge(combined, binState);
            }
        }
        otherCategoryProperties.nonNullCountOfOthersColumn = (long) combined.count;
        otherCategoryProperties.aggrOfOthersColumn = this.computeAggregation(combined.count, combined.squaredDistanceFromMean);
        return otherCategoryProperties;
    }

    /**
     * Finishes the merge via the superclass, then replaces every merged cell's running mean with
     * the final variance/stddev computed from its merged count and M2.
     */
    @Override
    public void mergeEnd(AxisHandler.Axis[] axes, LongDataTensor countTensor) throws IOException {
        super.mergeEnd(axes, countTensor);
        for (int i = 0; i < this.getMergeDT().tensorSize; ++i) {
            double statistic = this.computeAggregation(this.getMergeDT().getNonNullCount(i), this.squaredMergeDT[i]);
            this.getMergeDT().set(i, Double.valueOf(statistic));
        }
    }

    /**
     * Merges the output cell at {@code origCoordinates} into the merge cell at
     * {@code targetCoordinates}.
     *
     * <p>For every axis whose target coordinate equals {@code nbNotCutoff} (presumably the
     * "Others" bin), the axis value is merged with the cell mean. If that axis does not generate
     * an "Others" category, the cell is dropped.
     *
     * @return false if the cell was dropped, true once it has been merged
     */
    @Override
    public boolean mergeTensorAndAxes(PivotTableTensorRequest request, int[] origCoordinates, int[] targetCoordinates, AxisHandler.Axis[] axes) {
        double mean = this.getOutDT().getAsDouble(origCoordinates);
        for (int i = 0; i < axes.length; ++i) {
            int target = targetCoordinates[i];
            if (target != axes[i].nbNotCutoff) {
                continue;
            }
            if (!request.axes[i].sortPrune.generateOthersCategory) {
                return false;
            }
            this.getMergeDT().setAxisAsDouble(i, target, this.mergeDoubleValues(this.getMergeDT().getAxisAsDouble(i, target), mean));
        }
        int origLoc = this.getOutDT().loc(origCoordinates);
        int targetLoc = this.getMergeDT().loc(targetCoordinates);
        WelfordState source = new WelfordState(mean, this.squaredOutDT[origLoc], this.getOutDT().getNonNullCount(origCoordinates));
        WelfordState destination = new WelfordState(this.getMergeDT().getAsDouble(targetCoordinates),
                this.squaredMergeDT[targetLoc],
                this.getMergeDT().getNonNullCount(targetCoordinates));
        WelfordState merged = this.welfordMerge(source, destination);
        this.getMergeDT().incrementNonNullCount(targetCoordinates, this.getOutDT().getNonNullCount(origCoordinates));
        this.getMergeDT().setAsDouble(targetCoordinates, merged.mean);
        this.squaredMergeDT[targetLoc] = merged.squaredDistanceFromMean;
        return true;
    }

    /** Returns the variance/stddev of output cell {@code loc}, using {@code counts[loc]} as its size. */
    @Override
    public double getDoubleValue(int loc, long[] counts) {
        return this.computeAggregation(counts[loc], this.squaredOutDT[loc]);
    }

    /** Same as {@link #getDoubleValue(int, long[])}, boxed; {@code toRealType} is ignored. */
    @Override
    public Double getValue(int loc, long[] counts, boolean toRealType) {
        return this.getDoubleValue(loc, counts);
    }

    /** Sorts the elements of {@code ret} by the per-bin statistic of axis {@code axisIdx}. */
    @Override
    public void sortAxis(AxisHandler.Axis ret, int asc, int axisIdx) {
        // NOTE(review): for sample aggregations on multi-axis tensors the count is lowered by 1
        // here and computeAggregation subtracts 1 again, so the sort key divides by (count - 2)
        // (infinite for count == 2). Confirm this is intended.
        double offset = !this.isPopulationAggregation && this.getOutDT().axesNonNullCounts.length > 1 ? 1.0 : 0.0;
        long[] nnc = this.getOutDT().axesNonNullCounts[axisIdx];
        ret.elts.sort((o1, o2) -> {
            double v1 = this.computeAggregation((double) nnc[o1.binIndex] - offset, this.axeSquaredOutDT[axisIdx][o1.binIndex]);
            double v2 = this.computeAggregation((double) nnc[o2.binIndex] - offset, this.axeSquaredOutDT[axisIdx][o2.binIndex]);
            return this.compareAxisDoubleValues(v1, v2, (AxisElt) o1, (AxisElt) o2, asc, nnc);
        });
    }

    /** Delegates to the superclass's seven-argument overload with a trailing argument of 1. */
    @Override
    protected int compareAxisDoubleValues(double v1, double v2, AxisElt o1, AxisElt o2, int asc, long[] nnc) {
        return this.compareAxisDoubleValues(v1, v2, o1, o2, asc, nnc, 1);
    }

    /**
     * Turns a group size and its M2 into the requested statistic.
     *
     * @param n          number of values in the group
     * @param squaredSum sum of squared distances from the group mean (M2)
     * @return M2 / n (population) or M2 / (n - 1) (sample), square-rooted for STDEV flavours;
     *         0 whenever M2 is exactly 0
     */
    private double computeAggregation(double n, double squaredSum) {
        if (squaredSum == 0.0) {
            // Also short-circuits empty/single-value groups, whose M2 stays 0, before any division.
            return 0.0;
        }
        double divisor = this.isPopulationAggregation ? n : n - 1.0;
        double variance = squaredSum / divisor;
        return this.isSTDEV ? Math.sqrt(variance) : variance;
    }

    /**
     * One step of Welford's online algorithm.
     *
     * @param count group size INCLUDING {@code value}
     * @return the updated mean and M2, with {@code count} unchanged
     */
    private WelfordState welfordUpdate(double value, double mean, double squaredDistanceFromMean, double count) {
        double delta = value - mean;
        double newMean = mean + delta / count;
        double newSquaredDistance = squaredDistanceFromMean + delta * (value - newMean);
        return new WelfordState(newMean, newSquaredDistance, count);
    }

    /**
     * Combines two partial states (pairwise/Chan et al. formula). When both states are empty the
     * 0/0 divisions yield NaN, which is mapped back to 0.
     */
    private WelfordState welfordMerge(WelfordState state1, WelfordState state2) {
        double delta = state2.mean - state1.mean;
        double count = state1.count + state2.count;
        double mean = (state1.count * state1.mean + state2.count * state2.mean) / count;
        double m2 = state1.squaredDistanceFromMean + state2.squaredDistanceFromMean + delta * delta * (state1.count * state2.count) / count;
        return new WelfordState(Double.isNaN(mean) ? 0.0 : mean, Double.isNaN(m2) ? 0.0 : m2, Double.isNaN(count) ? 0.0 : count);
    }

    /** Immutable (mean, M2, count) triple describing a partial group. */
    private static class WelfordState {
        final double squaredDistanceFromMean;
        final double mean;
        final double count;

        public WelfordState(double mean, double squaredDistanceFromMean, double count) {
            this.mean = mean;
            this.squaredDistanceFromMean = squaredDistanceFromMean;
            this.count = count;
        }
    }
}

