/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.sql;

import com.dataiku.dip.classpathfix.DKUDoubles;
import com.dataiku.dip.coremodel.Dataset;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.pivot.UnsupportedOperation;
import com.dataiku.dip.spark.SparkJobHelper;
import com.dataiku.dip.sql.DSSTypeSQLMapping;
import com.dataiku.dip.sql.DatePart;
import com.dataiku.dip.sql.DateRounding;
import com.dataiku.dip.sql.GenericSQLDialect;
import com.dataiku.dip.sql.HiveLikeSQLDialect;
import com.dataiku.dip.sql.SQLAggregateAbility;
import com.dataiku.dip.sql.SQLAggregateType;
import com.dataiku.dip.sql.queries.ExpressionBuilder;
import com.dataiku.dip.sql.queries.QueryAst;
import com.dataiku.dip.sql.queries.QueryUtils;
import com.dataiku.dip.utils.NotImplementedException;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Common SQL dialect for Spark SQL based connections.
 *
 * <p>Builds on {@link HiveLikeSQLDialect} and overrides the parts where Spark SQL differs:
 * <ul>
 *   <li>the {@code timestamp_ntz} type mapping;</li>
 *   <li>Spark renderings of string aggregation, LIKE-based matching, LEAST/GREATEST and boolean tests;</li>
 *   <li>NULL-on-zero division;</li>
 *   <li>date part extraction and truncation for timezone-aware timestamps, dates, and timestamps
 *       without timezone.</li>
 * </ul>
 */
public abstract class AbstractSparkSQLDialect
extends HiveLikeSQLDialect {
    /**
     * Timestamp strings that already end with a zone designator ({@code Z} or a {@code +}/{@code -}
     * offset), with optional fractional seconds. Field separators are matched by {@code .} (any
     * character), so both {@code 2024-01-31T12:00:00Z} and {@code 2024-01-31 12:00:00 +01:00} match.
     */
    private static final Pattern TZ_AWARE_TIMESTAMP_STRING = Pattern.compile("^[0-9]{4}.[0-9]{2}.[0-9]{2}.[0-9]{2}.[0-9]{2}.[0-9]{2}(\\.[0-9]+)?\\s*(Z|[+-][0-9:]+)$");

    /** Delegates to the Hive-like implementation unchanged. */
    @Override
    public String cleanupColumnName(String columnName) {
        return super.cleanupColumnName(columnName);
    }

    /** @return the maximum VARCHAR length used by this dialect (65500). */
    @Override
    public int getMaxPossibleVarcharLen() {
        return 65500;
    }

    /** @return {@code true}: SQL:1999 constructs are allowed. */
    @Override
    public boolean canSQL99() {
        return true;
    }

    /** @return {@code true} for any statement. */
    @Override
    public boolean supportsResultSetMetadataOnPreparedStatement(String sql) {
        return true;
    }

    /**
     * Maps DSS {@code DATETIMENOTZ} columns to Spark's {@code timestamp_ntz} (JDBC TIMESTAMP).
     *
     * <p>JDBC DATE is passed as an additional type code (presumably an accepted alternative when
     * reading metadata; confirm in {@code DSSTypeSQLMapping}). All other types use the Hive-like
     * mapping.
     */
    @Override
    public DSSTypeSQLMapping getSQLType(SchemaColumn schemaColumn, Dataset dataset) {
        switch (schemaColumn.getType()) {
            case DATETIMENOTZ: {
                return new DSSTypeSQLMapping(Type.DATETIMENOTZ, java.sql.Types.TIMESTAMP, "timestamp_ntz", new Integer[]{java.sql.Types.DATE});
            }
        }
        return super.getSQLType(schemaColumn, dataset);
    }

    /** @return {@code true} only for the {@code timestamp_ntz} type name (case-insensitive). */
    @Override
    public boolean lacksTimezoneInfo(String sqlTypeName, int sqlPrecision) {
        return "timestamp_ntz".equalsIgnoreCase(sqlTypeName);
    }

    /** @return the statement that switches the Spark session time zone to UTC. */
    @Override
    public String useUTCTimezone() {
        return "SET TIME ZONE 'UTC'";
    }

    /**
     * Registers the Hive-like operators, then overrides or adds the Spark-specific renderings.
     */
    @Override
    public void initOperators() {
        super.initOperators();
        // String aggregation. args: [column, optional separator constant, optional "distinct" boolean constant].
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.AGG_CONCAT, "concat_ws", QueryUtils.Arity.TERNARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateMinNumberOfParameters(args, 1);
                String column = this.toSQLNoBrackets(args[0]);
                String separator = null;
                boolean distinct = false;
                if (args.length > 1) {
                    QueryAst.ConstExpr separatorExpr = (QueryAst.ConstExpr)args[1];
                    // A missing or NULL-valued separator means "no separator", as in ARRAY_TO_STRING.
                    // Rendering the NULL constant would presumably give concat_ws a NULL separator.
                    if (separatorExpr != null && separatorExpr.value != null) {
                        separator = this.toSQLNoBrackets(separatorExpr);
                    }
                }
                if (args.length > 2) {
                    distinct = AbstractSparkSQLDialect.isTrueConstExpr((QueryAst.ConstExpr)args[2]);
                }
                String collector = distinct ? "collect_set" : "collect_list";
                String sqlSeparator = separator == null ? "''" : separator;
                return "concat_ws(" + sqlSeparator + ", " + collector + "(CAST(" + column + " AS string)))";
            }
        });
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.COLLECT_STRING_LIST, "collect_list", QueryUtils.Arity.UNARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String column = this.toSQLNoBrackets(args[0]);
                return "collect_list(CAST(" + column + " AS string))";
            }
        });
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.COLLECT_STRING_SET, "collect_set", QueryUtils.Arity.UNARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String column = this.toSQLNoBrackets(args[0]);
                return "collect_set(CAST(" + column + " AS string))";
            }
        });
        // Joins an array expression with a constant separator; a NULL separator joins with ''.
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.ARRAY_TO_STRING, "concat_ws", QueryUtils.Arity.BINARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String array = this.toSQLNoBrackets(args[0]);
                QueryAst.ConstExpr separatorExpr = (QueryAst.ConstExpr)args[1];
                if (separatorExpr.value == null) {
                    return "concat_ws('', " + array + ")";
                }
                return "concat_ws(" + this.toSQLNoBrackets(separatorExpr) + ", " + array + ")";
            }
        });
        // args: [column, boolean constant "ignore nulls"]; Spark takes ignoreNulls as a 2nd argument.
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.FIRST_VALUE, "FIRST_VALUE", QueryUtils.Arity.BINARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String column = this.toSQLNoBrackets(args[0]);
                if (AbstractSparkSQLDialect.isTrueConstExpr((QueryAst.ConstExpr)args[1])) {
                    return "FIRST_VALUE(" + column + ", true)";
                }
                return "FIRST_VALUE(" + column + ")";
            }
        });
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.LAST_VALUE, "LAST_VALUE", QueryUtils.Arity.BINARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String column = this.toSQLNoBrackets(args[0]);
                if (AbstractSparkSQLDialect.isTrueConstExpr((QueryAst.ConstExpr)args[1])) {
                    return "LAST_VALUE(" + column + ", true)";
                }
                return "LAST_VALUE(" + column + ")";
            }
        });
        // haystack LIKE '%' || escaped(needle) || '%'
        this.addOperator(new QueryUtils.Operator(this, QueryUtils.OperatorType.CONTAINS, "CONTAINS", QueryUtils.Arity.BINARY, GenericSQLDialect.SQLPriority.LIKE.priority){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String haystack = this.toSQLWithBracketsIfNeeded(args[0], GenericSQLDialect.SQLPriority.LIKE.priority);
                String escapedNeedle = AbstractSparkSQLDialect.escapeLikeWildcardsSQL(this.toSQLNoBrackets(args[1]));
                QueryAst.Expr[] concatArgs = new QueryAst.Expr[]{new QueryAst.ConstExpr("%"), new QueryAst.InlineExpr(escapedNeedle), new QueryAst.ConstExpr("%")};
                String pattern = AbstractSparkSQLDialect.this.getOperator(QueryUtils.OperatorType.CONCAT).apply(concatArgs);
                return haystack + " LIKE " + pattern;
            }
        });
        // haystack LIKE escaped(prefix) || '%'
        this.addOperator(new QueryUtils.Operator(this, QueryUtils.OperatorType.STARTS_WITH, "STARTS_WITH", QueryUtils.Arity.BINARY, GenericSQLDialect.SQLPriority.LIKE.priority){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String haystack = this.toSQLWithBracketsIfNeeded(args[0], GenericSQLDialect.SQLPriority.LIKE.priority);
                String escapedPrefix = AbstractSparkSQLDialect.escapeLikeWildcardsSQL(this.toSQLNoBrackets(args[1]));
                QueryAst.Expr[] concatArgs = new QueryAst.Expr[]{new QueryAst.InlineExpr(escapedPrefix), new QueryAst.ConstExpr("%")};
                String pattern = AbstractSparkSQLDialect.this.getOperator(QueryUtils.OperatorType.CONCAT).apply(concatArgs);
                return haystack + " LIKE " + pattern;
            }
        });
        // Parses a string as a UTC timestamp by appending a '+00' offset before TO_TIMESTAMP.
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.STRING_TO_TIMESTAMPTZ, "STRING_TO_TIMESTAMPTZ", QueryUtils.Arity.UNARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String ret;
                if (args[0] instanceof QueryAst.ConstExpr && ((QueryAst.ConstExpr)args[0]).value instanceof String) {
                    // String constants are quoted directly with the dialect's string quoting.
                    ret = AbstractSparkSQLDialect.this.quoteString((String)((QueryAst.ConstExpr)args[0]).value);
                } else {
                    ret = this.toSQLNoBrackets(args[0]);
                }
                return "TO_TIMESTAMP(CONCAT(" + ret + ", '+00'))";
            }
        });
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.PERCENTILE_APPROX_AGG, QueryUtils.Arity.BINARY){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                String column = this.toSQLNoBrackets(args[0]);
                double percentile = this.getParamAs(args[1], Double.class);
                return "percentile_approx(" + column + "," + percentile + ")";
            }
        });
        // Division whose result is NULL when the denominator is 0 (see getDivisionClause).
        this.addOperator(new QueryUtils.Operator(this, QueryUtils.OperatorType.DIV, "/", QueryUtils.Arity.BINARY, GenericSQLDialect.SQLPriority.TIMES.priority, false){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                return AbstractSparkSQLDialect.this.getDivisionClause(this.toSQLWithBracketsIfNeeded(args[0], GenericSQLDialect.SQLPriority.TIMES.priority), this.toSQLWithBracketsIfNeeded(args[1], GenericSQLDialect.SQLPriority.EQ.priority));
            }
        });
        // Same as DIV, but the numerator is first cast to DOUBLE.
        this.addOperator(new QueryUtils.Operator(this, QueryUtils.OperatorType.FLOAT_DIV, "/", QueryUtils.Arity.BINARY, GenericSQLDialect.SQLPriority.TIMES.priority, false){

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                QueryAst.Expr castedArg = new ExpressionBuilder.ExpressionBuilderFactory().expr(args[0]).cast(new Object[]{Type.DOUBLE}).expr;
                return AbstractSparkSQLDialect.this.getDivisionClause(this.toSQLWithBracketsIfNeeded(castedArg, GenericSQLDialect.SQLPriority.TIMES.priority), this.toSQLWithBracketsIfNeeded(args[1], GenericSQLDialect.SQLPriority.EQ.priority));
            }
        });
        // Spark's LEAST/GREATEST skip NULL arguments; these renderings return NULL as soon as any
        // possibly-NULL argument is NULL.
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.LEAST, "LEAST", QueryUtils.Arity.NARY){

            @Override
            public boolean checkNumberOfParameters(int nArgs) {
                return nArgs > 0;
            }

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                List<String> renderedArgs = Lists.newArrayList(args).stream().map(this::toSQLNoBrackets).collect(Collectors.toList());
                String least = "LEAST(" + String.join(", ", renderedArgs) + ")";
                List<String> nullChecks = Lists.newArrayList(args).stream().filter(AbstractSparkSQLDialect::mayBeNullExpr).map(a -> "(" + this.toSQLNoBrackets(a) + ") IS NULL").collect(Collectors.toList());
                return AbstractSparkSQLDialect.nullIfAnyNullCheck(least, nullChecks);
            }
        });
        this.addOperator(new QueryUtils.Function(this, QueryUtils.OperatorType.GREATEST, "GREATEST", QueryUtils.Arity.NARY){

            @Override
            public boolean checkNumberOfParameters(int nArgs) {
                return nArgs > 0;
            }

            @Override
            public String apply(QueryAst.Expr[] args) {
                this.validateNumberOfParameters(args);
                List<String> renderedArgs = Lists.newArrayList(args).stream().map(this::toSQLNoBrackets).collect(Collectors.toList());
                String greatest = "GREATEST(" + String.join(", ", renderedArgs) + ")";
                List<String> nullChecks = Lists.newArrayList(args).stream().filter(AbstractSparkSQLDialect::mayBeNullExpr).map(a -> "(" + this.toSQLNoBrackets(a) + ") IS NULL").collect(Collectors.toList());
                return AbstractSparkSQLDialect.nullIfAnyNullCheck(greatest, nullChecks);
            }
        });
        this.addOperator(new QueryUtils.Operator(this, QueryUtils.OperatorType.ISTRUE, null, QueryUtils.Arity.UNARY, GenericSQLDialect.SQLPriority.AND.priority){

            @Override
            public String apply(QueryAst.Expr[] args) {
                String x = this.toSQLWithBracketsIfNeeded(args[0], GenericSQLDialect.SQLPriority.EQ.priority);
                return "(" + x + " = TRUE)";
            }
        });
        this.addOperator(new QueryUtils.Operator(this, QueryUtils.OperatorType.ISFALSE, null, QueryUtils.Arity.UNARY, GenericSQLDialect.SQLPriority.AND.priority){

            @Override
            public String apply(QueryAst.Expr[] args) {
                String x = this.toSQLWithBracketsIfNeeded(args[0], GenericSQLDialect.SQLPriority.EQ.priority);
                return "(" + x + " = FALSE)";
            }
        });
        this.addGenericFunction(QueryUtils.OperatorType.COSH, "COSH", QueryUtils.Arity.UNARY);
        this.addGenericFunction(QueryUtils.OperatorType.SINH, "SINH", QueryUtils.Arity.UNARY);
        this.addGenericFunction(QueryUtils.OperatorType.TANH, "TANH", QueryUtils.Arity.UNARY);
        this.addGenericFunction(QueryUtils.OperatorType.DEGREES, "DEGREES", QueryUtils.Arity.UNARY);
        this.addGenericFunction(QueryUtils.OperatorType.RADIANS, "RADIANS", QueryUtils.Arity.UNARY);
        this.addGenericFunction(QueryUtils.OperatorType.ATAN2, "ATAN2", QueryUtils.Arity.BINARY);
    }

    /**
     * Reads a boolean flag from an optional constant argument.
     *
     * @param expr constant expression, may be {@code null}
     * @return {@code true} only when {@code expr} is non-null and holds {@code Boolean.TRUE};
     *         a missing constant or a NULL value counts as {@code false}
     */
    private static boolean isTrueConstExpr(QueryAst.ConstExpr expr) {
        return expr != null && expr.value != null && (Boolean)expr.value;
    }

    /**
     * Wraps {@code sql} in a REGEXP_REPLACE that prefixes each {@code %}, {@code _} and backslash
     * with a backslash, so the value matches literally inside a LIKE pattern.
     *
     * @param sql already-rendered SQL expression
     * @return the escaping SQL expression
     */
    private static String escapeLikeWildcardsSQL(String sql) {
        return "REGEXP_REPLACE(" + sql + ", '([%_\\\\\\\\])', '\\\\\\\\$1')";
    }

    /**
     * @return {@code false} only for constants with a non-null value; every other expression may
     *         evaluate to NULL
     */
    private static boolean mayBeNullExpr(QueryAst.Expr expr) {
        return !(expr instanceof QueryAst.ConstExpr) || ((QueryAst.ConstExpr)expr).value == null;
    }

    /**
     * Wraps {@code call} so that it yields NULL when any of the given IS NULL checks holds.
     *
     * @param call rendered function call
     * @param nullChecks rendered {@code (x) IS NULL} predicates; when empty, {@code call} is returned as-is
     */
    private static String nullIfAnyNullCheck(String call, List<String> nullChecks) {
        if (nullChecks.isEmpty()) {
            return call;
        }
        return "IF(" + String.join(" OR ", nullChecks) + ", NULL, " + call + ")";
    }

    /**
     * Renders {@code numerator / denominator} so that division by zero yields NULL.
     *
     * <p>A numeric-literal denominator is resolved at generation time: 0 becomes {@code NULL},
     * anything else is used as-is. Otherwise a {@code CASE WHEN ... = 0 THEN NULL} guard is emitted.
     */
    @Override
    public String getDivisionClause(String numerator, String denominator) {
        Double literalDenominator = DKUDoubles.tryParse(denominator);
        if (literalDenominator != null) {
            return numerator + " / " + (literalDenominator == 0.0 ? "NULL" : denominator);
        }
        return numerator + " / (CASE WHEN " + denominator + " = 0 THEN NULL ELSE " + denominator + " END)";
    }

    /** @return -1 (presumably "no identifier length limit"; confirm against callers). */
    @Override
    public int getIdentifiersMaxLength() {
        return -1;
    }

    /** Renders a date/time string literal through the {@code STRING_TO_TIMESTAMPTZ} operator, i.e. as UTC. */
    @Override
    public String quoteDate(String str) {
        return this.getOperator(QueryUtils.OperatorType.STRING_TO_TIMESTAMPTZ).apply(new QueryAst.Expr[]{new QueryAst.ConstExpr(str)});
    }

    /** Renders a date/time string literal as {@code TO_UTC_TIMESTAMP('...', 'UTC')}. */
    public String quoteDateAsDeltaPartitionValue(String str) {
        return "TO_UTC_TIMESTAMP(" + this.quoteString(str) + ", 'UTC')";
    }

    /** Renders a date string literal as {@code TO_DATE('...')}. */
    @Override
    public String quoteDateOnly(String str) {
        return "TO_DATE(" + this.quoteString(str) + ")";
    }

    /**
     * Renders a timestamp-without-timezone string literal.
     *
     * <p>On Spark 3.4+ this uses {@code TO_TIMESTAMP_NTZ}. On older versions it uses
     * {@code TO_TIMESTAMP}, appending {@code Z} (UTC) unless the string already carries a zone
     * suffix. A {@code null} string is quoted as-is rather than turned into the literal {@code 'nullZ'}.
     */
    @Override
    public String quoteDatetimeNoTz(String str) {
        if (SparkJobHelper.isSparkAtLeast(3, 4)) {
            return "TO_TIMESTAMP_NTZ(" + this.quoteString(str) + ")";
        }
        if (str == null || TZ_AWARE_TIMESTAMP_STRING.matcher(str).matches()) {
            return "TO_TIMESTAMP(" + this.quoteString(str) + ")";
        }
        return "TO_TIMESTAMP(" + this.quoteString(str + "Z") + ")";
    }

    @Override
    public String datePartExpression(String expr, DatePart part) {
        return this.temporalPartExpression(expr, part, true, false);
    }

    @Override
    public String dateonlyPartExpression(String expr, DatePart part) {
        return this.temporalPartExpression(expr, part, false, true);
    }

    @Override
    public String datetimenotzPartExpression(String expr, DatePart part) {
        return this.temporalPartExpression(expr, part, false, false);
    }

    /**
     * Renders extraction of {@code part} from a temporal SQL expression.
     *
     * @param expr rendered SQL expression
     * @param part the part to extract
     * @param tzAware whether {@code expr} is a timezone-aware timestamp; its hour is computed in UTC
     * @param isDateOnly whether {@code expr} is a DATE; epoch-based parts then count from UNIX_DATE
     * @return the SQL expression
     * @throws NotImplementedException for MILLISECOND_OF_SECOND and unknown parts
     */
    private String temporalPartExpression(String expr, DatePart part, boolean tzAware, boolean isDateOnly) {
        switch (part) {
            case DAY_OF_MONTH: {
                return "DAYOFMONTH(" + expr + ")";
            }
            case DAY_OF_WEEK: {
                return "(1 + " + this.dayOfWeekExpression(expr) + ")";
            }
            case HOUR_OF_DAY: {
                if (tzAware) {
                    // Hour of the UTC day from epoch seconds. PMOD (not MOD) keeps the remainder
                    // non-negative for instants before 1970, where MOD would give a negative hour.
                    return "CAST(PMOD(UNIX_SECONDS(" + expr + "), 86400) / 3600 as INT)";
                }
                return "HOUR(" + expr + ")";
            }
            case MINUTE_OF_HOUR: {
                return "MINUTE(" + expr + ")";
            }
            case SECOND_OF_MINUTE: {
                return "SECOND(" + expr + ")";
            }
            case MILLISECOND_OF_SECOND: {
                throw new NotImplementedException("Extracting milliseconds of a timestamp is not supported in SparkSQL");
            }
            case MONTH_OF_YEAR: {
                return "MONTH(" + expr + ")";
            }
            case SECOND_FROM_EPOCH: {
                if (isDateOnly) {
                    return "(CAST(UNIX_DATE(" + expr + ") AS BIGINT) * 86400)";
                }
                return "TO_UNIX_TIMESTAMP(" + expr + ")";
            }
            case MILLIS_FROM_EPOCH: {
                if (isDateOnly) {
                    return "(CAST(UNIX_DATE(" + expr + ") AS BIGINT) * 86400 * 1000)";
                }
                return "(TO_UNIX_TIMESTAMP(" + expr + ") * 1000)";
            }
            case WEEK_OF_YEAR: {
                return "WEEKOFYEAR(" + expr + ")";
            }
            case QUARTER_OF_YEAR: {
                // Spark's '/' is always fractional, so ((MONTH - 1) / 3) + 1 would give e.g. 1.33 for
                // February; QUARTER returns the integer quarter directly.
                return "QUARTER(" + expr + ")";
            }
            case YEAR: {
                return "YEAR(" + expr + ")";
            }
        }
        throw new NotImplementedException(String.format("Date part '%s' is not supported on SparkSQL", part));
    }

    @Override
    public String dateTrunc(String inputDateExpression, DateRounding rounding) {
        return this.temporalTrunc(inputDateExpression, rounding, true, false);
    }

    @Override
    public String dateonlyTrunc(String inputDateExpression, DateRounding rounding) {
        return this.temporalTrunc(inputDateExpression, rounding, false, true);
    }

    @Override
    public String datetimenotzTrunc(String inputDateExpression, DateRounding rounding) {
        return this.temporalTrunc(inputDateExpression, rounding, false, false);
    }

    /**
     * Renders truncation of a temporal SQL expression to {@code rounding}.
     *
     * <p>Date-only inputs produce a DATE. Other inputs produce a TIMESTAMP.
     *
     * @throws UnsupportedOperation for roundings not handled here
     */
    private String temporalTrunc(String inputDateExpression, DateRounding rounding, boolean tzAware, boolean isDateOnly) {
        switch (rounding) {
            case YEAR:
            case MONTH:
            case DAY:
            case HOUR:
            case MINUTE:
            case SECOND: {
                String secondFromEpoch = this.temporalPartExpression(inputDateExpression, DatePart.SECOND_FROM_EPOCH, tzAware, isDateOnly);
                return AbstractSparkSQLDialect.temporalTruncFromSeconds(secondFromEpoch, rounding, tzAware, isDateOnly);
            }
            case WEEK: {
                String weekStart = "TRUNC(" + inputDateExpression + ", 'week')";
                if (isDateOnly) {
                    return weekStart;
                }
                String weekStartSeconds = this.dateonlyPartExpression(weekStart, DatePart.SECOND_FROM_EPOCH);
                return "TO_TIMESTAMP(" + weekStartSeconds + ")";
            }
            case QUARTER: {
                String castType = this.typeNameForCastAsString();
                String yearPartStr = "CAST(" + this.temporalPartExpression(inputDateExpression, DatePart.YEAR, tzAware, isDateOnly) + " AS " + castType + ")";
                // First month of the quarter: 1, 4, 7 or 10 (the fractional '/' is intended here, CEIL rounds it up).
                String truncatedMonthNumber = "(CEIL(" + this.temporalPartExpression(inputDateExpression, DatePart.MONTH_OF_YEAR, tzAware, isDateOnly) + "/3)-1)*3+1";
                String monthZeroPrefix = "CASE WHEN (" + truncatedMonthNumber + "<10) THEN '0' ELSE '' END";
                String truncatedMonthStr = "CAST(" + truncatedMonthNumber + " AS " + castType + ")";
                if (isDateOnly) {
                    // Keep date-only results typed as DATE, like the other date-only roundings.
                    return "CAST(CONCAT(" + yearPartStr + ",'-'," + monthZeroPrefix + "," + truncatedMonthStr + ",'-01') AS DATE)";
                }
                String concatenated = "CONCAT(" + yearPartStr + ",'-'," + monthZeroPrefix + "," + truncatedMonthStr + ",'-01 00:00:00 +00:00')";
                return "CAST(" + concatenated + " AS TIMESTAMP)";
            }
        }
        throw new UnsupportedOperation("Date truncation is not implemented for '" + String.valueOf(rounding) + "'");
    }

    /**
     * Truncates by formatting epoch seconds with FROM_UNIXTIME.
     *
     * <p>The significant prefix of the pattern comes from {@code rounding}; the remaining fields
     * are filled with constant "start of period" text ({@code 0001-01-01 00:00:00}, plus the
     * {@code Z} offset pattern letter for timestamps). The result is cast back to DATE (date-only
     * inputs, truncated to at most day precision) or TIMESTAMP.
     */
    private static String temporalTruncFromSeconds(String secondFromEpoch, DateRounding rounding, boolean tzAware, boolean isDateOnly) {
        String format = switch (rounding) {
            case YEAR -> "yyyy";
            case MONTH -> "yyyy-MM";
            case DAY -> "yyyy-MM-dd";
            case HOUR -> "yyyy-MM-dd HH";
            case MINUTE -> "yyyy-MM-dd HH:mm";
            case SECOND -> "yyyy-MM-dd HH:mm:ss";
            default -> throw new UnsupportedOperation("Date truncation is not implemented for '" + String.valueOf(rounding) + "'");
        };
        if (isDateOnly) {
            if (format.length() > 10) {
                format = format.substring(0, 10);
            }
            String fixedPart = "0001-01-01".substring(format.length());
            String asTruncatedFormattedDate = "FROM_UNIXTIME(" + secondFromEpoch + ", '" + format + fixedPart + "')";
            return "CAST(" + asTruncatedFormattedDate + " AS DATE)";
        }
        String fixedPart = "0001-01-01 00:00:00 Z".substring(format.length());
        String asTruncatedFormattedDate = "FROM_UNIXTIME(" + secondFromEpoch + ", '" + format + fixedPart + "')";
        return "CAST(" + asTruncatedFormattedDate + " AS TIMESTAMP)";
    }

    @Override
    public boolean supportsInDatabaseCharts() {
        return true;
    }

    /**
     * Extends the inherited abilities with CONCAT, CONCAT_DISTINCT, FIRST_NOTNULL and
     * LAST_NOTNULL, each with all four ability flags set.
     */
    @Override
    public Map<SQLAggregateType, SQLAggregateAbility> getAggregationAbilities() {
        Map<SQLAggregateType, SQLAggregateAbility> abilities = super.getAggregationAbilities();
        abilities.put(SQLAggregateType.CONCAT, new SQLAggregateAbility(true, true, true, true));
        abilities.put(SQLAggregateType.CONCAT_DISTINCT, new SQLAggregateAbility(true, true, true, true));
        abilities.put(SQLAggregateType.FIRST_NOTNULL, new SQLAggregateAbility(true, true, true, true));
        abilities.put(SQLAggregateType.LAST_NOTNULL, new SQLAggregateAbility(true, true, true, true));
        return abilities;
    }
}

