/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.recipes.eda.univariate;

import com.dataiku.dip.datalayer.ColumnFactory;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.RowFactory;
import com.dataiku.dip.eda.compute.computations.Computation;
import com.dataiku.dip.eda.compute.computations.ComputationResult;
import com.dataiku.dip.eda.compute.computations.common.Count;
import com.dataiku.dip.eda.compute.computations.common.GroupedComputation;
import com.dataiku.dip.eda.compute.computations.common.MultiComputation;
import com.dataiku.dip.eda.compute.filtering.AnumFilter;
import com.dataiku.dip.eda.compute.filtering.Filter;
import com.dataiku.dip.eda.compute.grouping.AnumGrouping;
import com.dataiku.dip.eda.worksheets.models.Variable;
import com.dataiku.dip.recipes.eda.UnivariateRecipePayloadParams;
import com.dataiku.dip.recipes.eda.univariate.UnivariateAnalysis;
import com.dataiku.dip.utils.ErrorContext;
import com.dataiku.dip.utils.JSON;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;

/**
 * Univariate analysis that builds a frequency table for a single column.
 *
 * <p>The computation plan combines a global {@link Count} with a {@link Count} grouped by an
 * {@link AnumGrouping} limited to {@link #maxValues} values. Every group produces:
 * <ul>
 *   <li>a {@code COUNT_*} row holding the group's raw count, and</li>
 *   <li>a {@code FREQUENCY_*} row holding the group's count divided by the global count.</li>
 * </ul>
 * The row suffix depends on the group's filter:
 * <ul>
 *   <li>An {@link AnumFilter} (exactly one value) produces {@code *_MODALITY} rows. The value is
 *       written to the {@code key} column.</li>
 *   <li>A filter named {@code "Others"} produces {@code *_OTHERS} rows.</li>
 * </ul>
 */
public class UnivariateAnalysisFrequencyTable extends UnivariateAnalysis {
    /** Value written to the {@code origin} column of every row produced by this analysis. */
    public static final String TYPE = "FREQUENCY_TABLE";

    /**
     * Maximum number of values passed to the {@link AnumGrouping}.
     *
     * <ul>
     *   <li>A value of 0 (or less) disables the analysis: no computation is planned and no rows
     *       are emitted.</li>
     *   <li>Negative values are rejected by {@link #validate}.</li>
     * </ul>
     */
    public int maxValues = 10;

    /**
     * Creates a frequency-table analysis for the given column.
     *
     * @param column the variable whose value frequencies are computed
     */
    public UnivariateAnalysisFrequencyTable(Variable column) {
        this.column = column;
    }

    /** No-arg constructor; presumably used by (de)serialization — confirm before removing. */
    private UnivariateAnalysisFrequencyTable() {
    }

    /**
     * Builds the computation plan.
     *
     * @param payloadParams recipe parameters (unused by this analysis)
     * @return an empty {@link MultiComputation} when {@link #maxValues} {@code <= 0}; otherwise a
     *     {@link MultiComputation} whose index 0 is the global {@link Count} and whose index 1 is
     *     the {@link Count} grouped by {@link AnumGrouping}
     */
    @Override
    public Computation getComputationPlan(UnivariateRecipePayloadParams payloadParams) {
        if (this.maxValues <= 0) {
            return new MultiComputation(new Computation[0]);
        }
        // NOTE(review): meaning of the trailing `true` flag of AnumGrouping is not visible here;
        // extractRow expects it to produce a group whose filter is named "Others" — confirm.
        AnumGrouping grouping = new AnumGrouping(this.column.name, this.maxValues, true);
        return new MultiComputation(new Count(), new GroupedComputation(new Count(), grouping));
    }

    /**
     * Converts the result of {@link #getComputationPlan} into output rows.
     *
     * <p>Row layout:
     * <ul>
     *   <li>When the overall result or the grouped result is unavailable, four placeholder rows
     *       are returned. They carry the collected warnings and errors but no value.</li>
     *   <li>Otherwise, one {@code COUNT_*} row is returned per group, followed by one
     *       {@code FREQUENCY_*} row per group.</li>
     * </ul>
     *
     * @param result the result of the plan returned by {@link #getComputationPlan}
     * @param payloadParams recipe parameters (unused by this analysis)
     * @param cf factory for the output columns
     * @param rf factory for the output rows
     * @return a new mutable list of rows; empty when {@link #maxValues} {@code <= 0}
     */
    @Override
    public List<Row> extractRows(ComputationResult result, UnivariateRecipePayloadParams payloadParams, ColumnFactory cf, RowFactory rf) {
        if (this.maxValues <= 0) {
            return new ArrayList<Row>();
        }
        if (!result.isAvailable() || !result.asMulti().get(1).isAvailable()) {
            List<String> warnings = result.collectWarnings();
            Collection<String> errors = result.collectErrors();
            ArrayList<Row> rows = new ArrayList<Row>();
            // Emit one diagnostic-only row per statistic type so the failure is visible downstream.
            String[] placeholderStatTypes = {"COUNT_MODALITY", "COUNT_OTHERS", "FREQUENCY_MODALITY", "FREQUENCY_OTHERS"};
            for (String statType : placeholderStatTypes) {
                rows.add(this.rowBase(this.column.name, statType, warnings, errors, cf, rf));
            }
            return rows;
        }
        MultiComputation.MultiComputationResult mcr = result.asMulti();
        ComputationResult countAllResult = mcr.get(0);
        List<GroupedComputation.GroupResult> groupResults = mcr.get(1).asGrouped().getGroupedResults();
        List<Row> rows = this.extractRows(groupResults, countAllResult, false, cf, rf);
        rows.addAll(this.extractRows(groupResults, countAllResult, true, cf, rf));
        return rows;
    }

    /**
     * Builds one row per group, either counts or frequencies.
     *
     * @param groupResults the per-group results of the grouped count
     * @param countAllResult the result of the global count, used as the frequency denominator
     * @param normalized {@code true} for frequency rows, {@code false} for count rows
     * @param cf factory for the output columns
     * @param rf factory for the output rows
     * @return a new mutable list; the caller appends to it
     */
    private List<Row> extractRows(List<GroupedComputation.GroupResult> groupResults, ComputationResult countAllResult, boolean normalized, ColumnFactory cf, RowFactory rf) {
        ArrayList<Row> rows = new ArrayList<Row>(groupResults.size());
        for (GroupedComputation.GroupResult groupResult : groupResults) {
            rows.add(this.extractRow(groupResult, countAllResult, normalized, cf, rf));
        }
        return rows;
    }

    /**
     * Builds the count or frequency row for a single group.
     *
     * @throws IllegalArgumentException if the group's filter is an {@link AnumFilter} that does
     *     not hold exactly one value, or is neither an {@link AnumFilter} nor a filter named
     *     {@code "Others"} (including a {@code null} filter)
     */
    private Row extractRow(GroupedComputation.GroupResult groupResult, ComputationResult countAllResult, boolean normalized, ColumnFactory cf, RowFactory rf) {
        Filter filter = groupResult.filter;
        ComputationResult result = groupResult.result;
        // Merge diagnostics into copies: the collections returned by collect*() are not owned
        // here and may be shared with the result or unmodifiable.
        // NOTE(review): copying errors into a list keeps duplicates across the group and the
        // global count; confirm whether collectErrors() was meant to be set-like.
        List<String> warnings = new ArrayList<String>(result.collectWarnings());
        Collection<String> errors = new ArrayList<String>(result.collectErrors());
        if (normalized) {
            // A frequency also depends on the global count, so surface its diagnostics too.
            warnings.addAll(countAllResult.collectWarnings());
            errors.addAll(countAllResult.collectErrors());
        }
        String statType = normalized ? "FREQUENCY" : "COUNT";
        Row row;
        if (filter instanceof AnumFilter) {
            ArrayList<String> modalities = new ArrayList<String>(((AnumFilter) filter).values);
            if (modalities.size() != 1) {
                throw new IllegalArgumentException(String.format("Expected exactly 1 modality, got %d", modalities.size()));
            }
            String modality = modalities.get(0);
            row = this.rowBase(this.column.name, String.format("%s_MODALITY", statType), warnings, errors, cf, rf);
            row.put(cf.column("key"), modality);
        } else if (filter != null && Objects.equals(filter.name, "Others")) {
            row = this.rowBase(this.column.name, String.format("%s_OTHERS", statType), warnings, errors, cf, rf);
        } else {
            String filterType = filter == null ? "null" : filter.getClass().getSimpleName();
            throw new IllegalArgumentException(String.format("Unexpected filter for frequency table: %s", filterType));
        }
        Long count = result.isAvailable() ? result.as(Count.CountResult.class).count : null;
        Long totalCount = countAllResult.isAvailable() ? countAllResult.as(Count.CountResult.class).count : null;
        if (normalized) {
            // Leave the value empty rather than dividing by zero or by an unknown total.
            if (count != null && totalCount != null && totalCount > 0L) {
                double frequency = (double) count.longValue() / (double) totalCount.longValue();
                row.put(cf.column("value"), frequency);
            }
        } else if (count != null) {
            row.put(cf.column("value"), count.longValue());
        }
        return row;
    }

    /**
     * Creates a row pre-filled with the columns shared by every row of this analysis.
     *
     * <p>The row always gets {@code variable}, {@code origin} (set to {@link #TYPE}) and
     * {@code statistic_type}. The {@code warnings} and {@code errors} columns are added as JSON
     * only when the corresponding collection is non-empty.
     */
    private Row rowBase(String columnName, String statsType, Collection<String> warnings, Collection<String> errors, ColumnFactory cf, RowFactory rf) {
        Row row = rf.row();
        row.put(cf.column("variable"), columnName);
        row.put(cf.column("origin"), TYPE);
        row.put(cf.column("statistic_type"), statsType);
        if (!warnings.isEmpty()) {
            row.put(cf.column("warnings"), JSON.json(warnings));
        }
        if (!errors.isEmpty()) {
            row.put(cf.column("errors"), JSON.json(errors));
        }
        return row;
    }

    /**
     * Rejects a negative {@link #maxValues}.
     *
     * <p>Zero is accepted because it disables the analysis (see {@link #getComputationPlan}).
     *
     * @param payloadParams recipe parameters (unused by this analysis)
     */
    @Override
    public void validate(UnivariateRecipePayloadParams payloadParams) {
        if (this.maxValues < 0) {
            throw ErrorContext.iae("The frequency table max values parameter must not be negative");
        }
    }
}

