/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scoring.exports;

import com.dataiku.dip.analysis.model.prediction.ResolvedPredictionPreprocessingParams;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.scoring.exports.Columns;
import com.dataiku.dip.sql.SQLDialect;
import com.dataiku.dip.sql.queries.ExpressionBuilder;
import com.dataiku.dip.sql.queries.ExpressionUtils;
import com.dataiku.dip.sql.queries.SelectQueryBuilder;
import com.dataiku.dip.utils.DKUtils;
import com.dataiku.dss.shadelib.org.joda.time.DateTime;
import com.dataiku.scoring.pipelines.Binarizer;
import com.dataiku.scoring.pipelines.CategoricalEncoder;
import com.dataiku.scoring.pipelines.DatetimeCyclicalEncoder;
import com.dataiku.scoring.pipelines.DropRow;
import com.dataiku.scoring.pipelines.Dummifier;
import com.dataiku.scoring.pipelines.FeatureSelection;
import com.dataiku.scoring.pipelines.Flagger;
import com.dataiku.scoring.pipelines.ImputeWithValue;
import com.dataiku.scoring.pipelines.Interactions;
import com.dataiku.scoring.pipelines.PreprocessingPipeline;
import com.dataiku.scoring.pipelines.Processor;
import com.dataiku.scoring.pipelines.Rescaler;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.apache.log4j.Logger;

public class SQLPreprocessing {
    private static Logger logger = Logger.getLogger((String)"dku.sql.preprocessing");
    private static final ExpressionBuilder.ExpressionBuilderFactory EBF = new ExpressionBuilder.ExpressionBuilderFactory();
    static final String BACKUP_PREFIX = "__dku_in_";
    public static final DateTime EPOCH_FOR_DOCTOR = DKUtils.getISODateFormatter().withZoneUTC().parseDateTime("1900-01-01T00:00:00.000Z");

    private static Columns backupColumns(Collection<String> columnsIn, Collection<String> toBackup) {
        Columns columns = new Columns(columnsIn);
        for (String s : toBackup) {
            columns.addExpression(EBF.col(s), BACKUP_PREFIX + s);
        }
        return columns;
    }

    private static SelectWithSchema coerce(SelectWithSchema query, ResolvedPredictionPreprocessingParams params, Schema inputSchema) {
        ArrayList toBeCastedStringToNumeric = Lists.newArrayList();
        for (Map.Entry feature : params.per_feature.entrySet()) {
            SchemaColumn schemaColumn = inputSchema.getColumn((String)feature.getKey());
            if (schemaColumn == null || ((FeaturePreprocessingParams)feature.getValue()).type != FeaturePreprocessingParams.FeatureType.NUMERIC || schemaColumn.getType() != Type.STRING) continue;
            toBeCastedStringToNumeric.add(schemaColumn);
        }
        Columns columns = query.columns;
        for (SchemaColumn column : toBeCastedStringToNumeric) {
            ExpressionBuilder col = EBF.col(column.getName());
            col = EBF.caseWhen(col.isNullOrEmptyString(), EBF.nullValue(column.getType(), column.getMaxLength()), col);
            col = col.castToFloat();
            columns.addExpression(col, column.getName());
        }
        SelectQueryBuilder sqb = columns.selectFrom(query.query);
        columns.flatten();
        return new SelectWithSchema(sqb, columns);
    }

    static SelectWithSchema normalize(SelectWithSchema query, ResolvedPredictionPreprocessingParams params, Dataset inputDataset, SQLDialect dialect) {
        ArrayList toNormalize = Lists.newArrayList();
        HashSet toCast = Sets.newHashSet();
        for (Map.Entry entry : params.per_feature.entrySet()) {
            SchemaColumn schemaColumn = inputDataset.getSchema().getColumn((String)entry.getKey());
            if (schemaColumn == null && ((FeaturePreprocessingParams)entry.getValue()).role == FeaturePreprocessingParams.Role.INPUT) {
                throw new IllegalArgumentException("Feature '" + (String)entry.getKey() + "' doesn't exist in the input");
            }
            if (((FeaturePreprocessingParams)entry.getValue()).role != FeaturePreprocessingParams.Role.INPUT || ((FeaturePreprocessingParams)entry.getValue()).type != FeaturePreprocessingParams.FeatureType.NUMERIC || schemaColumn == null || !schemaColumn.getType().isTemporal()) continue;
            toNormalize.add(schemaColumn);
            if (schemaColumn.getType() == Type.DATE) continue;
            toCast.add(schemaColumn);
        }
        if (toNormalize.isEmpty()) {
            return query;
        }
        Columns columns = query.columns;
        for (SchemaColumn column : toNormalize) {
            ExpressionBuilder col = ExpressionUtils.getAdjustedColumn(EBF.col(column.getName()), column, inputDataset.getParams(), inputDataset.isManaged(), dialect);
            if (toCast.contains(column)) {
                col.expr.outputType.dssType = column.getType();
                col = col.castToDate();
            }
            columns.addExpression(col.minusDate(EBF.cst(EPOCH_FOR_DOCTOR), "SECOND"), column.getName());
        }
        SelectQueryBuilder selectQueryBuilder = columns.selectFrom(query.query);
        columns.flatten();
        return new SelectWithSchema(selectQueryBuilder, columns);
    }

    public static SelectWithSchema preprocessingQuery(SQLDialect dialect, Dataset inputDataset, PreprocessingPipeline pipeline, ResolvedPredictionPreprocessingParams rppp, SelectQueryBuilder from, List<String> featureColumns, List<String> columnsToKeep, Double unrecordedValue) {
        SelectQueryBuilder sqb = from;
        Columns columns = SQLPreprocessing.backupColumns(featureColumns, columnsToKeep);
        SelectWithSchema normalized = SQLPreprocessing.normalize(new SelectWithSchema(sqb, columns), rppp, inputDataset, dialect);
        SelectWithSchema coerced = SQLPreprocessing.coerce(normalized, rppp, inputDataset.getSchema());
        sqb = coerced.query;
        columns = coerced.columns;
        logger.info((Object)("Sparse features unrecorded value: " + unrecordedValue));
        for (Processor step : pipeline.getStages()) {
            if (step instanceof Rescaler) {
                columns = SQLPreprocessing.rescalerQuery((Rescaler)step, columns);
            } else if (step instanceof Dummifier) {
                columns = SQLPreprocessing.dummifierQuery(dialect, (Dummifier)step, columns, unrecordedValue);
            } else if (step instanceof ImputeWithValue) {
                columns = SQLPreprocessing.imputeQuery((ImputeWithValue)step, columns);
            } else if (step instanceof CategoricalEncoder) {
                columns = SQLPreprocessing.categoricalEncoderQuery((CategoricalEncoder)step, columns);
            } else if (step instanceof Flagger) {
                columns = SQLPreprocessing.flaggerQuery((Flagger)step, columns, unrecordedValue);
            } else if (step instanceof Binarizer) {
                columns = SQLPreprocessing.binarizerQuery((Binarizer)step, columns, unrecordedValue);
            } else if (step instanceof DatetimeCyclicalEncoder) {
                columns = SQLPreprocessing.datetimeCyclicalQuery((DatetimeCyclicalEncoder)step, columns);
            } else if (step instanceof Interactions) {
                columns = SQLPreprocessing.interactionsQuery((Interactions)step, columns, unrecordedValue);
            } else if (step instanceof DropRow) {
                SelectWithSchema sws = SQLPreprocessing.dropRows(new SelectWithSchema(sqb, columns), (DropRow)step);
                columns = sws.columns;
                sqb = sws.query;
            } else {
                if (step instanceof FeatureSelection.Drop) continue;
                throw new UnsupportedOperationException("Step cannot be converted to SQL : " + String.valueOf(step.getClass()));
            }
            sqb = columns.selectFrom(sqb);
            columns.flatten();
        }
        return new SelectWithSchema(sqb, columns);
    }

    private static SelectWithSchema dropRows(SelectWithSchema query, DropRow dr) {
        String[] dropped = dr.getColumns();
        if (dropped.length == 0) {
            return query;
        }
        SelectQueryBuilder sqb = query.columns.selectFrom(query.query);
        ExpressionBuilder where = EBF.col(dropped[0]).isnotnull();
        for (int i = 1; i < dropped.length; ++i) {
            where = where.and(EBF.col(dropped[i]).isnotnull());
        }
        sqb.where(where);
        query.columns.flatten();
        return new SelectWithSchema(sqb, query.columns);
    }

    private static Columns rescalerQuery(Rescaler step, Columns schema) {
        String[] columns = step.getColumns();
        double[] shifts = step.getShifts();
        double[] invScales = step.getInv_scales();
        for (int i = 0; i < columns.length; ++i) {
            ExpressionBuilder eb = EBF.col(columns[i]).minus(shifts[i]).time(invScales[i]);
            schema.addExpression(eb, columns[i]);
        }
        return schema;
    }

    private static Columns interactionsQuery(Interactions interactions, Columns schema, Double missingValue) {
        for (Interactions.Interaction i : interactions.interactions) {
            if (i instanceof Interactions.NumericalNumericalInteraction) {
                schema = SQLPreprocessing.numNumQuery((Interactions.NumericalNumericalInteraction)i, schema);
                continue;
            }
            if (i instanceof Interactions.NumericalCategoricalInteraction) {
                schema = SQLPreprocessing.numCatQuery((Interactions.NumericalCategoricalInteraction)i, schema, missingValue);
                continue;
            }
            schema = SQLPreprocessing.catCatQuery((Interactions.CategoricalCategoricalInteraction)i, schema, missingValue);
        }
        return schema;
    }

    private static Columns numNumQuery(Interactions.NumericalNumericalInteraction interaction, Columns schema) {
        ExpressionBuilder col = EBF.col(interaction.column_1).time(EBF.col(interaction.column_2));
        if (interaction.rescale) {
            col = col.minus(EBF.cst(interaction.shift)).time(EBF.cst(interaction.invScale));
        }
        schema.addExpression(col, interaction.outName);
        return schema;
    }

    private static Columns numCatQuery(Interactions.NumericalCategoricalInteraction interaction, Columns schema, Double missingValue) {
        for (String s : interaction.values) {
            ExpressionBuilder col = EBF.col(interaction.cat).coalesce("N/A");
            ExpressionBuilder dum = EBF.caseWhen(new Object[0]).caseWhen(col.eq(s), EBF.col(interaction.num), missingValue);
            schema.addExpression(dum, interaction.outName(s));
        }
        return schema;
    }

    private static Columns catCatQuery(Interactions.CategoricalCategoricalInteraction interaction, Columns schema, Double missingValue) {
        for (Interactions.CategoricalCategoricalInteraction.StringPair p : interaction.values) {
            ExpressionBuilder col1 = EBF.col(interaction.column_1).coalesce("N/A");
            ExpressionBuilder col2 = EBF.col(interaction.column_2).coalesce("N/A");
            ExpressionBuilder dum = EBF.caseWhen(new Object[0]).caseWhen(col1.eq(p.a).and(col2.eq(p.b)), 1.0, missingValue);
            schema.addExpression(dum, interaction.outName(p));
        }
        return schema;
    }

    private static Columns dummifierQuery(SQLDialect dialect, Dummifier step, Columns schema, Double unrecordedValue) {
        String[] columns = step.getColumns();
        for (int i = 0; i < columns.length; ++i) {
            for (String s : (Set)step.getLevels().get(i)) {
                ExpressionBuilder col = EBF.col(columns[i]);
                ExpressionBuilder dum = EBF.caseWhen(new Object[0]).caseWhen(col.eq(s), 1.0, unrecordedValue);
                schema.addExpression(dum, Dummifier.dummifyName((String)columns[i], (String)s));
            }
            ExpressionBuilder col = EBF.col(columns[i]);
            ExpressionBuilder naDummy = EBF.caseWhen(new Object[0]).caseWhen(col.isnull(), 1.0, unrecordedValue);
            schema.addExpression(naDummy, Dummifier.dummifyName((String)columns[i], (String)"N/A"));
            if (!step.getWithOthers()[i]) continue;
            ExpressionBuilder otherDummies = EBF.caseWhen(col.castToString(dialect.getDefaultVarcharLen()).inList((Collection)step.getLevels().get(i)).or(col.isnull()), unrecordedValue, 1.0);
            schema.addExpression(otherDummies, Dummifier.dummifyName((String)columns[i], (String)"__Others__"));
        }
        return schema;
    }

    private static Columns imputeQuery(ImputeWithValue step, Columns columns) {
        for (Map.Entry e : step.getColumnMapping().entrySet()) {
            ExpressionBuilder eb = EBF.col((String)e.getKey()).coalesce(e.getValue());
            columns.addExpression(eb, (String)e.getKey());
        }
        return columns;
    }

    private static Columns categoricalEncoderQuery(CategoricalEncoder step, Columns schema) {
        String[] columns = step.getColumns();
        List values = step.getEncodings();
        double[][] defaults = step.getDefaults();
        String[][] outputNames = step.getOutputNames();
        for (int i = 0; i < columns.length; ++i) {
            int j;
            ExpressionBuilder[] eb = new ExpressionBuilder[outputNames[i].length];
            for (j = 0; j < eb.length; ++j) {
                eb[j] = EBF.caseWhen(new Object[0]);
            }
            for (Map.Entry e : ((Map)values.get(i)).entrySet()) {
                double[] vals = (double[])e.getValue();
                if (Objects.equals(e.getKey(), "_default_")) continue;
                for (int j2 = 0; j2 < vals.length; ++j2) {
                    eb[j2] = eb[j2].caseWhen(EBF.col(columns[i]).eq(e.getKey()), vals[j2]);
                }
            }
            for (j = 0; j < eb.length; ++j) {
                eb[j] = eb[j].caseWhen(defaults[i][j]);
                schema.addExpression(eb[j], outputNames[i][j]);
            }
        }
        return schema;
    }

    private static Columns flaggerQuery(Flagger step, Columns columns, Double missingValue) {
        String[] cols = step.getColumns();
        String[] out = step.getOutputNames();
        for (int i = 0; i < cols.length; ++i) {
            columns.addExpression(EBF.caseWhen(new Object[0]).caseWhen(EBF.col(cols[i]).isnull(), missingValue, 1.0), out[i]);
        }
        return columns;
    }

    private static Columns binarizerQuery(Binarizer step, Columns columns, Double missingValue) {
        String[] cols = step.getColumns();
        String[] out = step.getOutputColumns();
        double[] thresholds = step.getThresholds();
        for (int i = 0; i < cols.length; ++i) {
            columns.addExpression(EBF.caseWhen(new Object[0]).caseWhen(EBF.col(cols[i]).gt(thresholds[i]), 1.0, missingValue), out[i]);
        }
        return columns;
    }

    private static Columns datetimeCyclicalQuery(DatetimeCyclicalEncoder step, Columns columns) {
        for (Map.Entry entry : step.mapping.entrySet()) {
            String column = (String)entry.getKey();
            for (DatetimeCyclicalEncoder.Period period : (Set)entry.getValue()) {
                ExpressionBuilder date = EBF.cst(EPOCH_FOR_DOCTOR).dateAdd(EBF.col(column), "SECOND");
                ExpressionBuilder expr = date.minusDate(date.dateTrunc(period.name()), "SECOND").time(2, Math.PI).div(period.durationInSeconds);
                String prefix = "datetime_cyclical:" + column + ":" + period.name().toLowerCase() + ":";
                columns.addExpression(expr.sin(), prefix + "sin");
                columns.addExpression(expr.cos(), prefix + "cos");
            }
        }
        return columns;
    }

    static class SelectWithSchema {
        final SelectQueryBuilder query;
        final Columns columns;

        SelectWithSchema(SelectQueryBuilder query, Columns columns) {
            this.query = query;
            this.columns = columns;
        }
    }
}

