/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.shaker.processors.cleansing;

import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.Processor;
import com.dataiku.dip.datalayer.memimpl.MemRow;
import com.dataiku.dip.datalayer.memimpl.MemTable;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.shaker.ProcessorWithRecordedReport;
import com.dataiku.dip.shaker.facet.CountMap;
import com.dataiku.dip.shaker.model.StepParams;
import com.dataiku.dip.shaker.processors.AppliesToProcessor;
import com.dataiku.dip.shaker.processors.AppliesToUtils;
import com.dataiku.dip.shaker.processors.Category;
import com.dataiku.dip.shaker.processors.MemTableProcessor;
import com.dataiku.dip.shaker.processors.ProcessorCapabilities;
import com.dataiku.dip.shaker.processors.ProcessorMeta;
import com.dataiku.dip.shaker.processors.ProcessorTag;
import com.dataiku.dip.shaker.server.ProcessorDesc;
import com.dataiku.dip.shaker.sql.ProcessorSQLTranslator;
import com.dataiku.dip.shaker.sql.SQLQueryWithSchema;
import com.dataiku.dip.sql.SQLDialect;
import com.dataiku.dip.sql.queries.ExpressionBuilder;
import com.dataiku.dip.sql.queries.QueryAst;
import com.dataiku.dip.sql.queries.QueryUtils;
import com.dataiku.dip.sql.queries.SelectQueryBuilder;
import com.dataiku.dip.util.ParamDesc;
import com.dataiku.dip.util.SecretKeyGenerator;
import com.dataiku.dip.utils.Pair;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.mutable.MutableInt;

public class FillEmptyWithComputedValue {
    /**
     * Descriptor of the "FillEmptyWithComputedValue" processor: naming, documentation,
     * capabilities, UI description and factories for the in-memory ({@link MemTableImpl})
     * and SQL ({@link SQLTranslator}) implementations.
     * The processor is declared deprecated and is not displayed in the library (see {@code describe}).
     */
    public static final ProcessorMeta<MemTableImpl, Parameter> META = new ProcessorMeta<MemTableImpl, Parameter>(){

        @Override
        public String getName() {
            return "FillEmptyWithComputedValue";
        }

        @Override
        public String getDocPage() {
            return "fill-empty-with-computed-value";
        }

        @Override
        public Category getCategory() {
            return Category.CLEANSING;
        }

        @Override
        public Set<ProcessorTag> getTags() {
            return Sets.newHashSet(ProcessorTag.CLEANSING);
        }

        /**
         * Returns the translated help text.
         * The English fallback lists all three supported modes; it previously omitted "mode"
         * even though {@link ComputedValueMode#MODE} is offered and implemented.
         */
        @Override
        public String getHelp(String language) {
            return this.translate(language, "SHAKER.PROCESSOR.FillEmptyWithComputedValue.HELP", "This processor imputes missing values with mean, median or mode");
        }

        @Override
        public Class<Parameter> stepParamClass() {
            return Parameter.class;
        }

        /**
         * Declares NO_STREAM_IMPL and NATIVE_SPARK_IMPL. SQL translation is declared possible
         * unless the median is requested on a known dialect that offers neither the aggregate
         * nor the window approximate-percentile operator.
         */
        @Override
        public ProcessorMeta.ProcessorCapabilitiesSummary getCapabilities(StepParams sp, ProcessorWithRecordedReport.ProcessorRecordedReport report, SQLDialect dialect) {
            Parameter params = (Parameter)sp;
            ProcessorMeta.ProcessorCapabilitiesSummary ret = new ProcessorMeta.ProcessorCapabilitiesSummary().withCan(ProcessorCapabilities.NO_STREAM_IMPL, ProcessorCapabilities.NATIVE_SPARK_IMPL);
            boolean medianUnsupported = params.mode == ComputedValueMode.MEDIAN
                    && dialect != null
                    && dialect.getOperator(QueryUtils.OperatorType.PERCENTILE_APPROX_AGG) == null
                    && dialect.getOperator(QueryUtils.OperatorType.PERCENTILE_APPROX_WIN) == null;
            if (medianUnsupported) {
                ret.withCould(ProcessorCapabilities.SQL_TRANSLATABLE, "Cannot use SQL engine, the database doesn't support percentile function");
            } else {
                ret.withCan(ProcessorCapabilities.SQL_TRANSLATABLE);
            }
            return ret;
        }

        @Override
        public String getNativeSparkClassname() {
            return "com.dataiku.dip.shaker.processors.cleansing.FillEmptyWithComputedValueNS";
        }

        /**
         * Builds the UI description: a single "mode" select with MEAN / MEDIAN / MODE options.
         * The processor is flagged deprecated and hidden from the library.
         */
        @Override
        public ProcessorDesc describe(String language) {
            String title = this.translate(language, "SHAKER.PROCESSOR.FillEmptyWithComputedValue.DESCRIPTION", "Impute with computed value");
            String modeLabel = this.translate(language, "SHAKER.PROCESSOR.FillEmptyWithComputedValue.DESCRIPTION.COMPUTED_VALUE", "Computed value");
            String modeTooltip = this.translate(language, "SHAKER.PROCESSOR.FillEmptyWithComputedValue.DESCRIPTION.COMPUTED_VALUE.TOOLTIP", "Computation mode for the replacement value");
            // Option values mirror the ComputedValueMode constant names (presumably bound to Parameter.mode by name)
            String[] modeValues = new String[]{"MEAN", "MEDIAN", "MODE"};
            String[] modeDisplayNames = new String[]{
                this.translate(language, "SHAKER.PROCESSOR.FillEmptyWithComputedValue.DESCRIPTION.COMPUTED_VALUE.MEAN", "Mean"),
                this.translate(language, "SHAKER.PROCESSOR.FillEmptyWithComputedValue.DESCRIPTION.COMPUTED_VALUE.MEDIAN", "Median"),
                this.translate(language, "SHAKER.PROCESSOR.FillEmptyWithComputedValue.DESCRIPTION.COMPUTED_VALUE.MODE", "Mode")
            };
            return ProcessorDesc.withGenericForm(this.getName(), title)
                    .withParam(ParamDesc.advancedSelect("mode", modeLabel, modeTooltip, modeValues, modeDisplayNames))
                    .deprecate()
                    .doNotDisplayInLibrary();
        }

        @Override
        public MemTableImpl build(Parameter parameter) {
            return new MemTableImpl(parameter);
        }

        @Override
        public ProcessorSQLTranslator getSQLTranslator(StepParams parameter, ProcessorWithRecordedReport.ProcessorRecordedReport report) throws IOException {
            return new SQLTranslator((Parameter)parameter);
        }

        @Override
        public Object selfReport(Parameter parameter) {
            return AppliesToProcessor.selfReport(parameter);
        }
    };

    /**
     * In-memory implementation. For each selected column it:
     * <ol>
     *   <li>collects the non-empty cells of the non-deleted rows;</li>
     *   <li>computes a replacement value from them (mean, median or mode);</li>
     *   <li>writes that value into every empty cell of the column in non-deleted rows.</li>
     * </ol>
     * Mean and median parse cells with {@link Double#parseDouble}, so a non-numeric
     * non-empty cell makes processing fail with a NumberFormatException.
     */
    public static class MemTableImpl
    implements MemTableProcessor,
    Processor {
        private final Parameter param;

        public MemTableImpl(Parameter param) {
            this.param = param;
        }

        @Override
        public void process(MemTable table) throws InterruptedException {
            List<Column> consideredColumns = AppliesToUtils.getColumnsForMemTable(this.param, table);
            for (Column col : consideredColumns) {
                FutureProgressState.checkInterrupt();
                List<String> values = this.collectNonEmptyValues(table, col);
                String replacementValue = switch (this.param.mode) {
                    case MEAN -> this.getMean(values);
                    case MEDIAN -> this.getMedian(values);
                    case MODE -> this.getMode(values);
                    default -> throw new IllegalStateException("Unsupported mode: " + this.param.mode);
                };
                for (MemRow row : table.rows) {
                    if (row.isDeleted() || StringUtils.isNotEmpty(row.get(col, ""))) {
                        continue;
                    }
                    row.put(col, replacementValue);
                }
            }
            table.compact();
        }

        /** Returns the non-empty values of {@code col} over the non-deleted rows, in row order. */
        private List<String> collectNonEmptyValues(MemTable table, Column col) {
            List<String> values = new ArrayList<>();
            for (MemRow row : table.rows) {
                if (row.isDeleted()) {
                    continue;
                }
                String v = row.get(col, "");
                if (StringUtils.isNotEmpty(v)) {
                    values.add(v);
                }
            }
            return values;
        }

        /**
         * Most frequent value, or {@code null} when there are no values.
         * On ties, the first maximum in the CountMap's iteration order wins
         * (NOTE(review): that order depends on the CountMap's map type; confirm).
         */
        private String getMode(List<String> values) {
            if (values.isEmpty()) {
                return null;
            }
            CountMap<String> counts = new CountMap<String>();
            for (String v : values) {
                counts.inc(v);
            }
            return (String)Collections.max(counts.getMap().entrySet(), Comparator.comparingInt(a -> ((MutableInt)a.getValue()).intValue())).getKey();
        }

        /** Arithmetic mean formatted with {@link Double#toString}, or "" when there are no values. */
        private String getMean(List<String> values) {
            if (values.isEmpty()) {
                return "";
            }
            double sum = 0.0;
            for (String v : values) {
                sum += Double.parseDouble(v);
            }
            return Double.toString(sum / (double)values.size());
        }

        /**
         * Returns the element at index {@code size / 2} of the values sorted numerically,
         * or "" when there are no values. For even sizes this is the upper of the two middle
         * elements, not their average. The original string is returned, not a reformatted number.
         * Sorting uses {@link Double#compare}, a total order that is valid even when a cell
         * parses to NaN.
         */
        private String getMedian(List<String> values) {
            if (values.isEmpty()) {
                return "";
            }
            List<Pair<String, Double>> comparables = new ArrayList<>(values.size());
            for (String v : values) {
                comparables.add(new Pair<String, Double>(v, Double.parseDouble(v)));
            }
            comparables.sort(Comparator.comparingDouble((Pair<String, Double> p) -> p.second));
            return comparables.get(values.size() / 2).first;
        }
    }

    /**
     * SQL implementation. For each selected column it builds a subquery, expected to yield
     * a single row, that computes the replacement value over the step input. It LEFT-joins
     * that subquery onto the current query with an always-true condition, then replaces the
     * column by a coalesce of the current value and the computed one.
     */
    private static class SQLTranslator
    implements ProcessorSQLTranslator {
        private final Parameter parameter;
        private final ExpressionBuilder.ExpressionBuilderFactory ebf = new ExpressionBuilder.ExpressionBuilderFactory();

        private SQLTranslator(Parameter parameter) {
            this.parameter = parameter;
        }

        @Override
        public SQLQueryWithSchema translate(SQLQueryWithSchema chain) {
            List<String> appliesToColumns = chain.getAppliesToColumns(this.parameter);
            // Statistics must be computed on the columns as they enter this step,
            // so nest the current query if it already created or modified any of them.
            if (chain.isAnyCreatedOrModifiedByCurrentQuery(appliesToColumns)) {
                chain = chain.makeSubquery();
            }
            // Copy of the step input (presumably aliased "values") that the statistic subqueries read from
            SQLQueryWithSchema queryBeforeStep = chain.getCopy("values");
            for (String column : appliesToColumns) {
                chain = this.handle(chain, column, queryBeforeStep);
            }
            return chain;
        }

        /** Adds a WHERE clause excluding NULLs, and additionally empty strings for STRING columns. */
        private void filterOutNullOrEmpty(SelectQueryBuilder query, SQLQueryWithSchema queryBeforeStep, String column) {
            if (queryBeforeStep.getMandatoryCurrentColumn(column).getType() == Type.STRING) {
                query.where(queryBeforeStep.col(column).isNullOrEmptyString().not());
            } else {
                query.where(queryBeforeStep.col(column).isnotnull());
            }
        }

        /** Returns the column expression, cast to float when its declared type is not numeric. */
        private ExpressionBuilder castToNumberIfNeeded(SQLQueryWithSchema queryBeforeStep, String column) {
            ExpressionBuilder colExpr = queryBeforeStep.col(column);
            if (!queryBeforeStep.getMandatoryCurrentColumn(column).getType().isNumeric()) {
                colExpr = colExpr.castToFloat();
            }
            return colExpr;
        }

        /** Joins the replacement-value subquery for {@code column} and rewrites its select expression. */
        private SQLQueryWithSchema handle(SQLQueryWithSchema chain, String column, SQLQueryWithSchema queryBeforeStep) {
            // Random suffix keeps aliases distinct when several columns are joined into the same query
            String randomness = SecretKeyGenerator.generate(6);
            String valuesSubQueryName = "value_" + randomness;
            SchemaColumn schemaColumn = chain.getMandatoryCurrentColumn(column);
            SelectQueryBuilder valuePickingQuery;
            Type outputType;
            switch (this.parameter.mode) {
                case MEAN: {
                    valuePickingQuery = this.buildMeanQuery(queryBeforeStep, column);
                    // AVG yields a floating-point value whatever the input column type
                    outputType = Type.DOUBLE;
                    break;
                }
                case MEDIAN: {
                    valuePickingQuery = this.buildMedianQuery(chain, queryBeforeStep, column, schemaColumn);
                    outputType = schemaColumn.getType();
                    break;
                }
                case MODE: {
                    valuePickingQuery = this.buildModeQuery(queryBeforeStep, column, randomness);
                    outputType = schemaColumn.getType();
                    break;
                }
                default: {
                    throw new IllegalStateException("Unsupported mode: " + this.parameter.mode);
                }
            }
            ExpressionBuilder currentValue = chain.col(chain.getCurrentMainAlias(), column);
            // 1 = 1: attach the computed value to every row of the current query
            chain.join(valuePickingQuery, QueryAst.JoinType.LEFT, valuesSubQueryName).on(this.ebf.cst(1).eq(this.ebf.cst(1)));
            ExpressionBuilder coalescedValue = outputType == Type.STRING
                    ? currentValue.coalesceNullOrEmptyString(queryBeforeStep.col(valuesSubQueryName, column))
                    : currentValue.coalesce(queryBeforeStep.col(valuesSubQueryName, column));
            chain.replaceSelect(column, coalescedValue.cast(outputType, schemaColumn.getMaxLength()), column);
            chain.markColumnModified(column);
            return chain;
        }

        /** SELECT AVG(column) over the non-null / non-empty values of the step input. */
        private SelectQueryBuilder buildMeanQuery(SQLQueryWithSchema queryBeforeStep, String column) {
            SelectQueryBuilder query = queryBeforeStep.subQuery("values");
            this.filterOutNullOrEmpty(query, queryBeforeStep, column);
            query.select(this.castToNumberIfNeeded(queryBeforeStep, column).avg(), column);
            return query;
        }

        /**
         * Approximate 50th percentile, cast back to the column type. Uses the aggregate operator
         * when the dialect has one; otherwise uses the window operator plus SELECT DISTINCT
         * to collapse the per-row results.
         */
        private SelectQueryBuilder buildMedianQuery(SQLQueryWithSchema chain, SQLQueryWithSchema queryBeforeStep, String column, SchemaColumn schemaColumn) {
            SelectQueryBuilder query = queryBeforeStep.subQuery("values");
            this.filterOutNullOrEmpty(query, queryBeforeStep, column);
            if (chain.getDialect().getOperator(QueryUtils.OperatorType.PERCENTILE_APPROX_AGG) != null) {
                query.select(this.castToNumberIfNeeded(queryBeforeStep, column).percentileApproxAgg(0.5).cast(schemaColumn.getType(), schemaColumn.getMaxLength()), column);
                return query;
            }
            query.select(this.castToNumberIfNeeded(queryBeforeStep, column).percentileApproxWin(0.5).cast(schemaColumn.getType(), schemaColumn.getMaxLength()), column);
            query.selectDistinct();
            return query;
        }

        /**
         * Most frequent non-null / non-empty value. It counts each value, ranks the values by
         * decreasing count with ROW_NUMBER, and keeps rank 1. Ties are broken arbitrarily by
         * the database.
         */
        private SelectQueryBuilder buildModeQuery(SQLQueryWithSchema queryBeforeStep, String column, String randomness) {
            String countColumnName = "cnt_" + randomness;
            String rankColumnName = "rank_" + randomness;

            SelectQueryBuilder countsQuery = queryBeforeStep.subQuery("values");
            countsQuery.select(queryBeforeStep.col(column));
            countsQuery.select(this.ebf.count("*"), countColumnName);
            countsQuery.group(queryBeforeStep.col(column));
            this.filterOutNullOrEmpty(countsQuery, queryBeforeStep, column);

            QueryAst.Window byCountDesc = SelectQueryBuilder.window(
                    new ArrayList<ExpressionBuilder>(),
                    Lists.<ExpressionBuilder>newArrayList(this.ebf.col(countColumnName)),
                    Lists.<QueryAst.OrderType>newArrayList(QueryAst.OrderType.DESC));
            byCountDesc = SelectQueryBuilder.unboundedWindow(byCountDesc);

            SelectQueryBuilder rankQuery = new SelectQueryBuilder();
            rankQuery.from(countsQuery, "counts");
            rankQuery.select(queryBeforeStep.col(column));
            rankQuery.select(this.ebf.rowNumber().over(byCountDesc), rankColumnName);

            SelectQueryBuilder pickQuery = new SelectQueryBuilder();
            pickQuery.select(queryBeforeStep.col(column));
            pickQuery.from(rankQuery, "ranks");
            pickQuery.where(this.ebf.col(rankColumnName).eq(this.ebf.cst(1)));
            return pickQuery;
        }
    }

    /**
     * Step parameters: the column selection inherited from
     * {@link AppliesToProcessor.AppliesToParams} (resolved via AppliesToUtils / getAppliesToColumns),
     * plus the statistic used to fill empty cells.
     */
    public static class Parameter
    extends AppliesToProcessor.AppliesToParams {
        private static final long serialVersionUID = -1L;
        /** Statistic used to compute the replacement value; defaults to MEAN. */
        public ComputedValueMode mode = ComputedValueMode.MEAN;
    }

    /**
     * Statistic used to compute the value written into empty cells.
     * The constant names double as the option values of the "mode" select.
     */
    public enum ComputedValueMode {
        /** Arithmetic mean of the non-empty values. */
        MEAN,
        /** Middle value (upper middle in memory, approximate percentile in SQL). */
        MEDIAN,
        /** Most frequent non-empty value. */
        MODE
    }
}

