/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.shaker.processors.cleansing;

import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.Processor;
import com.dataiku.dip.datalayer.memimpl.MemRow;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.datalineage.RecipeLineage;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.shaker.ProcessorWithRecordedReport;
import com.dataiku.dip.shaker.facet.CountMap;
import com.dataiku.dip.shaker.facet.FacetUtils;
import com.dataiku.dip.shaker.model.ProcessorScriptStep;
import com.dataiku.dip.shaker.model.StepParams;
import com.dataiku.dip.shaker.processors.AppliesToProcessor;
import com.dataiku.dip.shaker.processors.AppliesToUtils;
import com.dataiku.dip.shaker.processors.Category;
import com.dataiku.dip.shaker.processors.MemTableProcessor;
import com.dataiku.dip.shaker.processors.ProcessorCapabilities;
import com.dataiku.dip.shaker.processors.ProcessorMeta;
import com.dataiku.dip.shaker.processors.ProcessorTag;
import com.dataiku.dip.shaker.server.ProcessorDesc;
import com.dataiku.dip.shaker.sql.ProcessorSQLTranslator;
import com.dataiku.dip.shaker.sql.SQLQueryWithSchema;
import com.dataiku.dip.sql.SQLDialect;
import com.dataiku.dip.sql.queries.ExpressionBuilder;
import com.dataiku.dip.sql.queries.QueryAst;
import com.dataiku.dip.sql.queries.SelectQueryBuilder;
import com.dataiku.dip.util.ParamDesc;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class MergeLongTailValues {
    public static final ProcessorMeta<MemTableImpl, Parameter> META = new ProcessorMeta<MemTableImpl, Parameter>(){

        @Override
        public String getName() {
            return "MergeLongTailValues";
        }

        @Override
        public String getDocPage() {
            return "merge-long-tail-values";
        }

        @Override
        public Category getCategory() {
            return Category.FILTER;
        }

        @Override
        public Set<ProcessorTag> getTags() {
            return Sets.newHashSet((Object[])new ProcessorTag[]{ProcessorTag.FILTER});
        }

        @Override
        public String getHelp(String language) {
            return this.translate(language, "SHAKER.PROCESSOR.MergeLongTailValues.HELP", "This processor merges values below a certain appearance threshold");
        }

        @Override
        public Class<Parameter> stepParamClass() {
            return Parameter.class;
        }

        @Override
        public Object selfReport(Parameter parameter) {
            JsonObject obj = AppliesToProcessor.selfReport(parameter);
            obj.remove("replacementValue");
            return obj;
        }

        @Override
        public ProcessorMeta.ProcessorCapabilitiesSummary getCapabilities(StepParams sp, ProcessorWithRecordedReport.ProcessorRecordedReport report, SQLDialect dialect) {
            return new ProcessorMeta.ProcessorCapabilitiesSummary().withCan(ProcessorCapabilities.NO_STREAM_IMPL, ProcessorCapabilities.NATIVE_SPARK_IMPL, ProcessorCapabilities.SQL_TRANSLATABLE);
        }

        @Override
        public String getNativeSparkClassname() {
            return "com.dataiku.dip.shaker.processors.cleansing.MergeLongTailValuesNS";
        }

        @Override
        public ProcessorDesc describe(String language) {
            return ProcessorDesc.withGenericForm(this.getName(), this.translate(language, "SHAKER.PROCESSOR.MergeLongTailValues.DESCRIPTION", "Merge long-tail values")).withParam(ParamDesc.advancedSelect("thresholdMode", this.translate(language, "SHAKER.PROCESSOR.MergeLongTailValues.DESCRIPTION.MODE", "Mode"), this.translate(language, "SHAKER.PROCESSOR.MergeLongTailValues.DESCRIPTION.MODE.TOOLTIP", "Threshold mode"), new String[]{"COUNT", "CUM_RATIO"}, new String[]{this.translate(language, "SHAKER.PROCESSOR.MergeLongTailValues.DESCRIPTION.MODE.COUNT", "Count"), this.translate(language, "SHAKER.PROCESSOR.MergeLongTailValues.DESCRIPTION.MODE.CUMULATIVE_RATIO", "Cumulative ratio")})).withParam(ParamDesc.intP("countThreshold", this.translate(language, "SHAKER.PROCESSOR.MergeLongTailValues.DESCRIPTION.COUNT", "Count"), this.translate(language, "SHAKER.PROCESSOR.MergeLongTailValues.DESCRIPTION.COUNT.TOOLTIP", "Max count"), 10)).withParam(ParamDesc.doubleP("cumRatioThreshold", this.translate(language, "SHAKER.PROCESSOR.MergeLongTailValues.DESCRIPTION.CUMULATIVE_RATIO", "Cumulative ratio"), this.translate(language, "SHAKER.PROCESSOR.MergeLongTailValues.DESCRIPTION.CUMULATIVE_RATIO.TOOLTIP", "Max cumulative ratio"), 0.8)).withMNESParam("replacementValue", this.translate(language, "SHAKER.PROCESSOR.MergeLongTailValues.DESCRIPTION.REPLACEMENT_VALUE", "Replacement value"));
        }

        @Override
        public MemTableImpl build(Parameter parameter) {
            return new MemTableImpl(parameter);
        }

        @Override
        public ProcessorSQLTranslator getSQLTranslator(StepParams parameter, ProcessorWithRecordedReport.ProcessorRecordedReport report) throws IOException {
            return new SQLTranslator((Parameter)parameter);
        }

        @Override
        public RecipeLineage getUpdatedRecipeLineage(ProcessorScriptStep pss, RecipeLineage previousRecipeLineage) {
            return previousRecipeLineage;
        }
    };

    public static class MemTableImpl
    implements MemTableProcessor,
    Processor {
        private final Parameter param;

        public MemTableImpl(Parameter param) {
            this.param = param;
        }

        @Override
        public void process(MemTable table) throws InterruptedException {
            List<Column> consideredColumns = AppliesToUtils.getColumnsForMemTable(this.param, table);
            for (Column col : consideredColumns) {
                FutureProgressState.checkInterrupt();
                CountMap<String> counts = new CountMap<String>();
                for (MemRow memRow : table.rows) {
                    if (memRow.isDeleted()) continue;
                    String v = memRow.get(col, "");
                    counts.inc(v);
                }
                List values = new ArrayList<FacetUtils.FacetValue>();
                for (Map.Entry e : counts) {
                    FacetUtils.FacetValue fv = new FacetUtils.FacetValue();
                    fv.key = (String)e.getKey();
                    fv.value = e.getValue().intValue();
                    values.add(fv);
                }
                Collections.sort(values, FacetUtils.Sort.COUNT_FAST);
                switch (this.param.thresholdMode) {
                    case COUNT: {
                        if (values.size() <= this.param.countThreshold) break;
                        values = values.subList(0, this.param.countThreshold);
                        break;
                    }
                    case CUM_RATIO: {
                        double d = 0.0;
                        int lastKept = -1;
                        for (FacetUtils.FacetValue fv : values) {
                            if ((d += (double)fv.value) > (double)table.rows.size() * this.param.cumRatioThreshold) break;
                            ++lastKept;
                        }
                        if (values.size() <= lastKept + 1) break;
                        values = values.subList(0, lastKept + 1);
                    }
                }
                HashSet<String> hashSet = new HashSet<String>();
                for (FacetUtils.FacetValue fv : values) {
                    hashSet.add(fv.key);
                }
                for (MemRow row3 : table.rows) {
                    String v;
                    if (row3.isDeleted() || hashSet.contains(v = row3.get(col, ""))) continue;
                    row3.put(col, this.param.replacementValue);
                }
            }
            table.compact();
        }
    }

    private static class SQLTranslator
    implements ProcessorSQLTranslator {
        private final Parameter parameter;
        private ExpressionBuilder.ExpressionBuilderFactory ebf = new ExpressionBuilder.ExpressionBuilderFactory();

        private SQLTranslator(Parameter parameter) {
            this.parameter = parameter;
        }

        @Override
        public SQLQueryWithSchema translate(SQLQueryWithSchema chain) {
            List<String> appliesToColumns = chain.getAppliesToColumns(this.parameter);
            if (chain.isAnyCreatedOrModifiedByCurrentQuery(appliesToColumns)) {
                chain = chain.makeSubquery();
            }
            SQLQueryWithSchema queryBeforeStep = chain.getCopy("values");
            for (String column : appliesToColumns) {
                chain = this.handle(chain, column, queryBeforeStep);
            }
            return chain;
        }

        private SQLQueryWithSchema handle(SQLQueryWithSchema chain, String column, SQLQueryWithSchema queryBeforeStep) {
            SelectQueryBuilder countsQuery = queryBeforeStep.subQuery("values");
            String randomness = SecretKeyGenerator.generate((int)6);
            String valuesSubQueryName = "values_" + randomness;
            String countColumnName = "cnt_" + randomness;
            String rankColumnName = "rank_" + randomness;
            String isnullColumnName = "isnull_" + randomness;
            countsQuery.group(queryBeforeStep.col(column));
            countsQuery.select(queryBeforeStep.col(column));
            SelectQueryBuilder valuePickingQuery = new SelectQueryBuilder();
            valuePickingQuery.select(queryBeforeStep.col(column));
            valuePickingQuery.select(this.ebf.caseWhen(queryBeforeStep.col(column).isnull(), this.ebf.cst("true"), this.ebf.cst("false")), isnullColumnName);
            switch (this.parameter.thresholdMode) {
                case COUNT: {
                    countsQuery.select(this.ebf.count("*"), countColumnName);
                    QueryAst.Window windowUnbounded = SelectQueryBuilder.window(new ArrayList<ExpressionBuilder>(), Lists.newArrayList((Object[])new ExpressionBuilder[]{this.ebf.col(countColumnName)}), Lists.newArrayList((Object[])new QueryAst.OrderType[]{QueryAst.OrderType.DESC}));
                    windowUnbounded = SelectQueryBuilder.unboundedWindow(windowUnbounded);
                    SelectQueryBuilder rankQuery = new SelectQueryBuilder();
                    rankQuery.from(countsQuery, "counts");
                    rankQuery.select(queryBeforeStep.col(column));
                    rankQuery.select(this.ebf.rowNumber().over(windowUnbounded), rankColumnName);
                    valuePickingQuery.from(rankQuery, "picks");
                    valuePickingQuery.where(this.ebf.col(rankColumnName).lte(this.ebf.cst(this.parameter.countThreshold)));
                    break;
                }
                case CUM_RATIO: {
                    countsQuery.select(this.ebf.count("*").castToFloat(), countColumnName);
                    QueryAst.Window windowUnbounded = SelectQueryBuilder.window(new ArrayList<ExpressionBuilder>(), Lists.newArrayList((Object[])new ExpressionBuilder[]{this.ebf.col(countColumnName)}), Lists.newArrayList((Object[])new QueryAst.OrderType[]{QueryAst.OrderType.DESC}));
                    windowUnbounded.withFrame(QueryAst.WindowFrameMode.ROWS, null, QueryAst.WindowFrameDirection.PRECEDING, null, QueryAst.WindowFrameDirection.FOLLOWING, null);
                    QueryAst.Window window = SelectQueryBuilder.window(new ArrayList<ExpressionBuilder>(), Lists.newArrayList((Object[])new ExpressionBuilder[]{this.ebf.col(countColumnName)}), Lists.newArrayList((Object[])new QueryAst.OrderType[]{QueryAst.OrderType.DESC}));
                    window.withFrame(QueryAst.WindowFrameMode.ROWS, null, QueryAst.WindowFrameDirection.PRECEDING, "0", QueryAst.WindowFrameDirection.FOLLOWING, null);
                    SelectQueryBuilder rankQuery = new SelectQueryBuilder();
                    rankQuery.from(countsQuery, "counts");
                    rankQuery.select(queryBeforeStep.col(column));
                    rankQuery.select(this.ebf.col(countColumnName).sum().over(window).div(this.ebf.col(countColumnName).sum().over(windowUnbounded)), rankColumnName);
                    valuePickingQuery.from(rankQuery, "picks");
                    valuePickingQuery.where(this.ebf.col(rankColumnName).lte(this.ebf.cst(this.parameter.cumRatioThreshold)));
                }
            }
            SchemaColumn schemaColumn = chain.getMandatoryCurrentColumn(column);
            ExpressionBuilder joinOnCol = chain.col(chain.getCurrentMainAlias(), column);
            ExpressionBuilder joinedCol = queryBeforeStep.col(valuesSubQueryName, column);
            ExpressionBuilder nullTweakedJoinOnColumn = (schemaColumn.getType() == Type.STRING ? joinOnCol : joinOnCol.castToString(200)).coalesce(this.ebf.cst("__dku_null__").castToString(200));
            ExpressionBuilder nullTweakedJoinedColumn = (schemaColumn.getType() == Type.STRING ? joinedCol : joinedCol.castToString(200)).coalesce(this.ebf.cst("__dku_null__").castToString(200));
            chain.join(valuePickingQuery, QueryAst.JoinType.LEFT, valuesSubQueryName).on(nullTweakedJoinOnColumn.nullUnsafeEq(nullTweakedJoinedColumn));
            ExpressionBuilder replacement = this.ebf.cst(this.parameter.replacementValue);
            replacement = replacement.cast(schemaColumn.getType(), schemaColumn.maxLength);
            ExpressionBuilder mergedCol = joinedCol.coalesce(replacement).cast(schemaColumn.getType(), schemaColumn.maxLength);
            ExpressionBuilder nullJoinedCol = this.ebf.caseWhen(this.ebf.col(valuesSubQueryName, isnullColumnName).nullUnsafeEq(this.ebf.cst("true")), this.ebf.nullValue(schemaColumn.getType(), schemaColumn.maxLength), mergedCol);
            chain.replaceSelect(column, nullJoinedCol.cast(schemaColumn.getType(), schemaColumn.maxLength), column);
            chain.markColumnModified(column);
            return chain;
        }
    }

    public static class Parameter
    extends AppliesToProcessor.AppliesToParams {
        private static final long serialVersionUID = -1L;
        public LongTailThresholdMode thresholdMode = LongTailThresholdMode.COUNT;
        public int countThreshold = 10;
        public double cumRatioThreshold = 0.8;
        public String replacementValue;
    }

    public static enum LongTailThresholdMode {
        COUNT,
        CUM_RATIO;

    }
}

