From 05bb28ce7a6f8e1152293fc8d080a94f9e4dc480 Mon Sep 17 00:00:00 2001 From: Aman Date: Sun, 7 Dec 2025 12:33:03 +0530 Subject: [PATCH] Add Spark support for TLP Oracle --- .github/workflows/main.yml | 38 ++ pom.xml | 9 +- src/sqlancer/Main.java | 2 + src/sqlancer/spark/SparkErrors.java | 67 ++++ src/sqlancer/spark/SparkGlobalState.java | 11 + src/sqlancer/spark/SparkOptions.java | 43 +++ src/sqlancer/spark/SparkProvider.java | 122 +++++++ src/sqlancer/spark/SparkSchema.java | 114 ++++++ src/sqlancer/spark/SparkToStringVisitor.java | 120 +++++++ .../spark/ast/SparkBetweenOperation.java | 10 + .../spark/ast/SparkBinaryOperation.java | 11 + .../spark/ast/SparkCaseOperation.java | 13 + .../spark/ast/SparkCastOperation.java | 25 ++ .../spark/ast/SparkColumnReference.java | 11 + src/sqlancer/spark/ast/SparkConstant.java | 194 ++++++++++ src/sqlancer/spark/ast/SparkExpression.java | 7 + src/sqlancer/spark/ast/SparkFunction.java | 13 + src/sqlancer/spark/ast/SparkInOperation.java | 12 + src/sqlancer/spark/ast/SparkJoin.java | 46 +++ src/sqlancer/spark/ast/SparkOrderingTerm.java | 10 + src/sqlancer/spark/ast/SparkSelect.java | 42 +++ .../spark/ast/SparkTableReference.java | 13 + .../spark/ast/SparkUnaryPostfixOperation.java | 13 + .../spark/ast/SparkUnaryPrefixOperation.java | 12 + .../spark/gen/SparkExpressionGenerator.java | 336 ++++++++++++++++++ .../spark/gen/SparkInsertGenerator.java | 47 +++ .../spark/gen/SparkTableGenerator.java | 100 ++++++ test/sqlancer/dbms/TestConfig.java | 1 + test/sqlancer/dbms/TestSparkTLP.java | 20 ++ 29 files changed, 1460 insertions(+), 2 deletions(-) create mode 100644 src/sqlancer/spark/SparkErrors.java create mode 100644 src/sqlancer/spark/SparkGlobalState.java create mode 100644 src/sqlancer/spark/SparkOptions.java create mode 100644 src/sqlancer/spark/SparkProvider.java create mode 100644 src/sqlancer/spark/SparkSchema.java create mode 100644 src/sqlancer/spark/SparkToStringVisitor.java create mode 100644 src/sqlancer/spark/ast/SparkBetweenOperation.java create mode 100644 src/sqlancer/spark/ast/SparkBinaryOperation.java create mode 100644 src/sqlancer/spark/ast/SparkCaseOperation.java create mode 100644 src/sqlancer/spark/ast/SparkCastOperation.java create mode 100644 src/sqlancer/spark/ast/SparkColumnReference.java create mode 100644 src/sqlancer/spark/ast/SparkConstant.java create mode 100644 src/sqlancer/spark/ast/SparkExpression.java create mode 100644 src/sqlancer/spark/ast/SparkFunction.java create mode 100644 src/sqlancer/spark/ast/SparkInOperation.java create mode 100644 src/sqlancer/spark/ast/SparkJoin.java create mode 100644 src/sqlancer/spark/ast/SparkOrderingTerm.java create mode 100644 src/sqlancer/spark/ast/SparkSelect.java create mode 100644 src/sqlancer/spark/ast/SparkTableReference.java create mode 100644 src/sqlancer/spark/ast/SparkUnaryPostfixOperation.java create mode 100644 src/sqlancer/spark/ast/SparkUnaryPrefixOperation.java create mode 100644 src/sqlancer/spark/gen/SparkExpressionGenerator.java create mode 100644 src/sqlancer/spark/gen/SparkInsertGenerator.java create mode 100644 src/sqlancer/spark/gen/SparkTableGenerator.java create mode 100644 test/sqlancer/dbms/TestSparkTLP.java diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 5c53192aa..41bd92d3c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -333,6 +333,44 @@ jobs: - name: Run Tests run: HIVE_AVAILABLE=true mvn -Dtest=TestHiveTLP test + spark: + name: DBMS Tests (Spark) + runs-on: ubuntu-latest + + services: + spark: + image: 
apache/spark:3.5.1 + ports: + - 10000:10000 + + command: >- + /opt/spark/bin/spark-submit + --class org.apache.spark.sql.hive.thriftserver.HiveThriftServer2 + --name "Thrift JDBC/ODBC Server" + --master local[*] + --driver-memory 4g + --conf spark.hive.server2.thrift.port=10000 + --conf spark.sql.warehouse.dir=/tmp/spark-warehouse + spark-internal + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up JDK 11 + uses: actions/setup-java@v3 + with: + distribution: 'temurin' + java-version: '11' + cache: 'maven' + + - name: Build SQLancer + run: mvn -B package -DskipTests=true + + - name: Run Tests + run: SPARK_AVAILABLE=true mvn -Dtest=TestSparkTLP test + hsqldb: name: DBMS Tests (HSQLB) runs-on: ubuntu-latest diff --git a/pom.xml b/pom.xml index 2037b71ce..c7a38c9aa 100644 --- a/pom.xml +++ b/pom.xml @@ -329,7 +329,7 @@ org.slf4j - slf4j-simple + slf4j-simple 2.0.6 @@ -381,7 +381,7 @@ org.apache.hive hive-jdbc - 4.0.1 + 3.1.2 org.apache.hive @@ -393,6 +393,11 @@ hive-cli 4.0.1 + + org.apache.hadoop + hadoop-common + 3.2.4 + diff --git a/src/sqlancer/Main.java b/src/sqlancer/Main.java index 1f2642f95..f778bd7da 100644 --- a/src/sqlancer/Main.java +++ b/src/sqlancer/Main.java @@ -48,6 +48,7 @@ import sqlancer.tidb.TiDBProvider; import sqlancer.yugabyte.ycql.YCQLProvider; import sqlancer.yugabyte.ysql.YSQLProvider; +import sqlancer.spark.SparkProvider; public final class Main { @@ -756,6 +757,7 @@ private static void checkForIssue799(List> providers) providers.add(new DuckDBProvider()); providers.add(new H2Provider()); providers.add(new HiveProvider()); + providers.add(new SparkProvider()); providers.add(new HSQLDBProvider()); providers.add(new MariaDBProvider()); providers.add(new MaterializeProvider()); diff --git a/src/sqlancer/spark/SparkErrors.java b/src/sqlancer/spark/SparkErrors.java new file mode 100644 index 000000000..97c8056a3 --- /dev/null +++ b/src/sqlancer/spark/SparkErrors.java @@ -0,0 +1,67 @@ +package sqlancer.spark; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.common.query.ExpectedErrors; + +public final class SparkErrors { + + private SparkErrors() { + } + + public static List getExpressionErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("cannot resolve"); + errors.add("AnalysisException"); + errors.add("data type mismatch"); + errors.add("undefined function"); + errors.add("mismatched input"); + errors.add("due to data type mismatch"); + + // --- Invalid Literals + errors.add("The value of the typed literal"); + + errors.add("DATATYPE_MISMATCH"); + errors.add("cannot be cast to"); + + errors.add("Overflow"); + errors.add("Divide by zero"); // Common if spark.sql.ansi.enabled is true + errors.add("division by zero"); + + // --- Group By / Aggregation errors --- + errors.add("grouping expressions"); + errors.add("expression is neither present in the group by"); + errors.add("is not a valid grouping expression"); + errors.add("is not contained in either an aggregate function or the GROUP BY clause"); + errors.add("PARSE_SYNTAX_ERROR"); + errors.add("Syntax error"); + + return errors; + } + + public static void addExpressionErrors(ExpectedErrors errors) { + errors.addAll(getExpressionErrors()); + } + + public static List getInsertErrors() { + ArrayList errors = new ArrayList<>(); + + errors.add("not enough data columns"); + errors.add("cannot write to"); + errors.add("incompatible types"); + errors.add("too many data columns"); + errors.add("cannot be cast to"); + errors.add("Error running query"); + 
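+        // Randomly typed literals rarely match the target column types, so cast and
+        // typed-literal failures dominate the remainder of this list.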
+        errors.add("The value of the typed literal");
+        errors.add("Cannot safely cast"); // Found in logs: Decimal -> Date
+        errors.add("AnalysisException"); // Spark throws this for almost all insert failures
+
+        return errors;
+    }
+
+    public static void addInsertErrors(ExpectedErrors errors) {
+        errors.addAll(getInsertErrors());
+    }
+}
\ No newline at end of file
diff --git a/src/sqlancer/spark/SparkGlobalState.java b/src/sqlancer/spark/SparkGlobalState.java
new file mode 100644
index 000000000..e79826332
--- /dev/null
+++ b/src/sqlancer/spark/SparkGlobalState.java
@@ -0,0 +1,11 @@
+package sqlancer.spark;
+
+import sqlancer.SQLGlobalState;
+
+public class SparkGlobalState extends SQLGlobalState<SparkOptions, SparkSchema> {
+
+    @Override
+    protected SparkSchema readSchema() throws Exception {
+        return SparkSchema.fromConnection(getConnection(), getDatabaseName());
+    }
+}
\ No newline at end of file
diff --git a/src/sqlancer/spark/SparkOptions.java b/src/sqlancer/spark/SparkOptions.java
new file mode 100644
index 000000000..c9422a910
--- /dev/null
+++ b/src/sqlancer/spark/SparkOptions.java
@@ -0,0 +1,43 @@
+package sqlancer.spark;
+
+import java.sql.SQLException;
+import java.util.Arrays;
+import java.util.List;
+
+import com.beust.jcommander.Parameter;
+import com.beust.jcommander.Parameters;
+
+import sqlancer.DBMSSpecificOptions;
+import sqlancer.OracleFactory;
+import sqlancer.common.oracle.TLPWhereOracle;
+import sqlancer.common.oracle.TestOracle;
+import sqlancer.common.query.ExpectedErrors;
+import sqlancer.spark.gen.SparkExpressionGenerator;
+
+@Parameters(separators = "=", commandDescription = "Spark SQL (default port: " + SparkOptions.DEFAULT_PORT
+        + ", default host: " + SparkOptions.DEFAULT_HOST + ")")
+public class SparkOptions implements DBMSSpecificOptions<SparkOracleFactory> {
+    public static final String DEFAULT_HOST = "localhost";
+    public static final int DEFAULT_PORT = 10000;
+
+    @Parameter(names = "--oracle")
+    public List<SparkOracleFactory> oracle = Arrays.asList(SparkOracleFactory.TLPWhere);
+
+    public enum SparkOracleFactory implements OracleFactory<SparkGlobalState> {
+        TLPWhere {
+            @Override
+            public TestOracle<SparkGlobalState> create(SparkGlobalState globalState) throws SQLException {
+                SparkExpressionGenerator gen = new SparkExpressionGenerator(globalState);
+                ExpectedErrors expectedErrors = ExpectedErrors.newErrors().with(SparkErrors.getExpressionErrors())
+                        .build();
+
+                return new TLPWhereOracle<>(globalState, gen, expectedErrors);
+            }
+        };
+    }
+
+    @Override
+    public List<SparkOracleFactory> getTestOracleFactory() {
+        return oracle;
+    }
+}
\ No newline at end of file
diff --git a/src/sqlancer/spark/SparkProvider.java b/src/sqlancer/spark/SparkProvider.java
new file mode 100644
index 000000000..f53ca10a8
--- /dev/null
+++ b/src/sqlancer/spark/SparkProvider.java
@@ -0,0 +1,122 @@
+package sqlancer.spark;
+
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.SQLException;
+import java.sql.Statement;
+
+import com.google.auto.service.AutoService;
+
+import sqlancer.AbstractAction;
+import sqlancer.DatabaseProvider;
+import sqlancer.IgnoreMeException;
+import sqlancer.MainOptions;
+import sqlancer.Randomly;
+import sqlancer.SQLConnection;
+import sqlancer.SQLProviderAdapter;
+import sqlancer.StatementExecutor;
+import sqlancer.common.query.SQLQueryAdapter;
+import sqlancer.common.query.SQLQueryProvider;
+import sqlancer.spark.gen.SparkInsertGenerator;
+import sqlancer.spark.gen.SparkTableGenerator;
+
+@AutoService(DatabaseProvider.class)
+public class SparkProvider extends SQLProviderAdapter<SparkGlobalState, SparkOptions> {
+
+    public SparkProvider() {
+        super(SparkGlobalState.class, SparkOptions.class);
+    }
+
+    public enum Action implements AbstractAction<SparkGlobalState> {
+        INSERT(SparkInsertGenerator::getQuery);
+
+        private final SQLQueryProvider<SparkGlobalState> sqlQueryProvider;
+
+        Action(SQLQueryProvider<SparkGlobalState> sqlQueryProvider) {
+            this.sqlQueryProvider = sqlQueryProvider;
+        }
+
+        @Override
+        public SQLQueryAdapter getQuery(SparkGlobalState state) throws Exception {
+            return sqlQueryProvider.getQuery(state);
+        }
+    }
+
+    private static int mapActions(SparkGlobalState globalState, Action a) {
+        Randomly r = globalState.getRandomly();
+        switch (a) {
+        case INSERT:
+            return r.getInteger(0, globalState.getOptions().getMaxNumberInserts());
+        default:
+            throw new AssertionError(a);
+        }
+    }
+
+    @Override
+    public void generateDatabase(SparkGlobalState globalState) throws Exception {
+        for (int i = 0; i < Randomly.fromOptions(1, 2); i++) {
+            boolean success;
+            do {
+                String tableName = globalState.getSchema().getFreeTableName();
+                SQLQueryAdapter qt = SparkTableGenerator.generate(globalState, tableName);
+                success = globalState.executeStatement(qt);
+            } while (!success);
+        }
+
+        if (globalState.getSchema().getDatabaseTables().isEmpty()) {
+            throw new IgnoreMeException();
+        }
+
+        StatementExecutor<SparkGlobalState, Action> se = new StatementExecutor<>(globalState, Action.values(),
+                SparkProvider::mapActions, (q) -> {
+                    if (globalState.getSchema().getDatabaseTables().isEmpty()) {
+                        throw new IgnoreMeException();
+                    }
+                });
+        se.executeStatements();
+    }
+
+    @Override
+    public SQLConnection createDatabase(SparkGlobalState globalState) throws SQLException {
+        String username = globalState.getOptions().getUserName();
+        String password = globalState.getOptions().getPassword();
+        String host = globalState.getOptions().getHost();
+        int port = globalState.getOptions().getPort();
+
+        if (host == null) {
+            host = SparkOptions.DEFAULT_HOST;
+        }
+        if (port == MainOptions.NO_SET_PORT) {
+            port = SparkOptions.DEFAULT_PORT;
+        }
+
+        String databaseName = globalState.getDatabaseName();
+
+        // Spark's Thrift Server speaks the HiveServer2 protocol, so we connect with the Hive JDBC driver.
+        String url = String.format("jdbc:hive2://%s:%d/%s", host, port, "default");
+
+        // Connect to the default database first to (re-)create the fuzzing database.
+        Connection con = DriverManager.getConnection(url, username, password);
+        try (Statement s = con.createStatement()) {
+            s.execute("DROP DATABASE IF EXISTS " + databaseName + " CASCADE");
+        }
+        try (Statement s = con.createStatement()) {
+            s.execute("CREATE DATABASE " + databaseName);
+        }
+        con.close();
+
+        // Reconnect, this time to the newly created fuzzing database.
+        con = DriverManager.getConnection(String.format("jdbc:hive2://%s:%d/%s", host, port, databaseName), username,
+                password);
+        try (Statement s = con.createStatement()) {
+            // Disabling ANSI mode permits casts between otherwise incompatible types (e.g., BOOLEAN to
+            // DATE/TIMESTAMP), which the expression generator emits frequently.
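+            // With ANSI mode off, such invalid casts evaluate to NULL instead of raising runtime
+            // errors; SparkErrors still lists "Divide by zero" etc. for runs against ANSI-enabled servers.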
+ s.execute("SET spark.sql.ansi.enabled=false"); + } + return new SQLConnection(con); + } + + @Override + public String getDBMSName() { + return "spark"; + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/SparkSchema.java b/src/sqlancer/spark/SparkSchema.java new file mode 100644 index 000000000..849652b19 --- /dev/null +++ b/src/sqlancer/spark/SparkSchema.java @@ -0,0 +1,114 @@ +package sqlancer.spark; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.SQLConnection; +import sqlancer.common.schema.AbstractRelationalTable; +import sqlancer.common.schema.AbstractSchema; +import sqlancer.common.schema.AbstractTableColumn; +import sqlancer.common.schema.AbstractTables; +import sqlancer.common.schema.TableIndex; +import sqlancer.spark.SparkSchema.SparkTable; + +public class SparkSchema extends AbstractSchema { + + public enum SparkDataType { + STRING, INTEGER, DOUBLE, BOOLEAN, TIMESTAMP, DATE; + + public static SparkDataType getRandomType() { + return Randomly.fromList(Arrays.asList(values())); + } + } + + public static class SparkColumn extends AbstractTableColumn { + public SparkColumn(String name, SparkTable table, SparkDataType type) { + super(name, table, type); + } + } + + public static class SparkTables extends AbstractTables { + public SparkTables(List tables) { + super(tables); + } + } + + public static class SparkTable extends AbstractRelationalTable { + public SparkTable(String name, List columns, boolean isView) { + super(name, columns, Collections.emptyList(), isView); + } + } + + public SparkSchema(List databaseTables) { + super(databaseTables); + } + + public static SparkSchema fromConnection(SQLConnection con, String databaseName) throws SQLException { + List databaseTables = new ArrayList<>(); + List tableNames = getTableNames(con); + for (String tableName : tableNames) { + List databaseColumns = getTableColumns(con, tableName); + boolean isView = tableName.toLowerCase().startsWith("v"); + SparkTable t = new SparkTable(tableName, databaseColumns, isView); + for (SparkColumn c : databaseColumns) { + c.setTable(t); + } + databaseTables.add(t); + } + return new SparkSchema(databaseTables); + } + + private static List getTableNames(SQLConnection con) throws SQLException { + List tableNames = new ArrayList<>(); + try (Statement s = con.createStatement()) { + ResultSet tableRs = s.executeQuery("SHOW TABLES"); + while (tableRs.next()) { + // Spark SHOW TABLES output: database, tableName, isTemporary + String tableName = tableRs.getString("tableName"); + tableNames.add(tableName); + } + } + return tableNames; + } + + private static List getTableColumns(SQLConnection con, String tableName) throws SQLException { + List columns = new ArrayList<>(); + try (Statement s = con.createStatement()) { + try (ResultSet rs = s.executeQuery(String.format("DESCRIBE %s", tableName))) { + while (rs.next()) { + String columnName = rs.getString("col_name"); + String dataType = rs.getString("data_type"); + // Filter out Spark partition info or comments usually at bottom of describe + if (columnName.startsWith("#") || columnName.isEmpty()) + continue; + + columns.add(new SparkColumn(columnName, null, getColumnType(dataType))); + } + } + } + return columns; + } + + private static SparkDataType getColumnType(String typeString) { + String upper = typeString.toUpperCase(); + if 
(upper.startsWith("STRING") || upper.startsWith("VARCHAR") || upper.startsWith("CHAR")) + return SparkDataType.STRING; + if (upper.startsWith("INT") || upper.startsWith("BIGINT") || upper.startsWith("SMALLINT")) + return SparkDataType.INTEGER; + if (upper.startsWith("DOUBLE") || upper.startsWith("FLOAT") || upper.startsWith("DECIMAL")) + return SparkDataType.DOUBLE; + if (upper.startsWith("BOOLEAN")) + return SparkDataType.BOOLEAN; + if (upper.startsWith("TIMESTAMP")) + return SparkDataType.TIMESTAMP; + if (upper.startsWith("DATE")) + return SparkDataType.DATE; + return SparkDataType.STRING; // Fallback + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/SparkToStringVisitor.java b/src/sqlancer/spark/SparkToStringVisitor.java new file mode 100644 index 000000000..91f47e32c --- /dev/null +++ b/src/sqlancer/spark/SparkToStringVisitor.java @@ -0,0 +1,120 @@ +package sqlancer.spark; + +import sqlancer.common.ast.newast.NewToStringVisitor; +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.spark.ast.SparkCastOperation; +import sqlancer.spark.ast.SparkConstant; +import sqlancer.spark.ast.SparkExpression; +import sqlancer.spark.ast.SparkJoin; +import sqlancer.spark.ast.SparkSelect; + +public class SparkToStringVisitor extends NewToStringVisitor { + + @Override + public void visitSpecific(SparkExpression expr) { + if (expr instanceof SparkConstant) { + visit((SparkConstant) expr); + } else if (expr instanceof SparkSelect) { + visit((SparkSelect) expr); + } else if (expr instanceof SparkJoin) { + visit((SparkJoin) expr); + } else if (expr instanceof SparkCastOperation) { + visit((SparkCastOperation) expr); + } else { + throw new AssertionError(expr.getClass()); + } + } + + private void visit(SparkConstant constant) { + sb.append(constant.toString()); + } + + private void visit(SparkSelect select) { + sb.append("SELECT "); + if (select.isDistinct()) { + sb.append("DISTINCT "); + } + visit(select.getFetchColumns()); + sb.append(" FROM "); + visit(select.getFromList()); + if (!select.getFromList().isEmpty() && !select.getJoinList().isEmpty()) { + sb.append(", "); + } + if (!select.getJoinList().isEmpty()) { + visit(select.getJoinList()); + } + if (select.getWhereClause() != null) { + sb.append(" WHERE "); + visit(select.getWhereClause()); + } + if (!select.getGroupByExpressions().isEmpty()) { + sb.append(" GROUP BY "); + visit(select.getGroupByExpressions()); + } + if (select.getHavingClause() != null) { + sb.append(" HAVING "); + visit(select.getHavingClause()); + } + if (!select.getOrderByClauses().isEmpty()) { + sb.append(" ORDER BY "); + visit(select.getOrderByClauses()); + } + if (select.getLimitClause() != null) { + sb.append(" LIMIT "); + visit(select.getLimitClause()); + } + // Spark supports OFFSET, though strictly usually with LIMIT or in newer versions + if (select.getOffsetClause() != null) { + sb.append(" OFFSET "); + visit(select.getOffsetClause()); + } + } + + private void visit(SparkJoin join) { + switch (join.getJoinType()) { + case INNER: + sb.append(" INNER JOIN "); + break; + case LEFT_OUTER: + sb.append(" LEFT JOIN "); + break; + case RIGHT_OUTER: + sb.append(" RIGHT JOIN "); + break; + case FULL_OUTER: + sb.append(" FULL JOIN "); + break; + case LEFT_SEMI: + sb.append(" LEFT SEMI JOIN "); + break; + // Spark also supports LEFT ANTI, which Hive might lack in some older versions + case LEFT_ANTI: + sb.append(" LEFT ANTI JOIN "); + break; + case CROSS: + sb.append(" CROSS JOIN "); + break; + default: + throw new UnsupportedOperationException("Join 
type not supported in Spark visitor: " + join.getJoinType()); + } + visit((TableReferenceNode) join.getRightTable()); + if (join.getOnClause() != null) { + sb.append(" ON "); + visit(join.getOnClause()); + } + } + + private void visit(SparkCastOperation cast) { + sb.append("CAST("); + visit(cast.getExpression()); + sb.append(" AS "); + sb.append(cast.getType()); + sb.append(")"); + } + + public static String asString(SparkExpression expr) { + SparkToStringVisitor visitor = new SparkToStringVisitor(); + visitor.visit(expr); + return visitor.get(); + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkBetweenOperation.java b/src/sqlancer/spark/ast/SparkBetweenOperation.java new file mode 100644 index 000000000..f229c1c7c --- /dev/null +++ b/src/sqlancer/spark/ast/SparkBetweenOperation.java @@ -0,0 +1,10 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.newast.NewBetweenOperatorNode; + +public class SparkBetweenOperation extends NewBetweenOperatorNode implements SparkExpression { + + public SparkBetweenOperation(SparkExpression left, SparkExpression middle, SparkExpression right, boolean isTrue) { + super(left, middle, right, isTrue); + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkBinaryOperation.java b/src/sqlancer/spark/ast/SparkBinaryOperation.java new file mode 100644 index 000000000..04af0ec4c --- /dev/null +++ b/src/sqlancer/spark/ast/SparkBinaryOperation.java @@ -0,0 +1,11 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewBinaryOperatorNode; + +public class SparkBinaryOperation extends NewBinaryOperatorNode implements SparkExpression { + + public SparkBinaryOperation(SparkExpression left, SparkExpression right, Operator op) { + super(left, right, op); + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkCaseOperation.java b/src/sqlancer/spark/ast/SparkCaseOperation.java new file mode 100644 index 000000000..fb1ee0cd8 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkCaseOperation.java @@ -0,0 +1,13 @@ +package sqlancer.spark.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewCaseOperatorNode; + +public class SparkCaseOperation extends NewCaseOperatorNode implements SparkExpression { + + public SparkCaseOperation(SparkExpression switchCondition, List conditions, + List expressions, SparkExpression elseExpr) { + super(switchCondition, conditions, expressions, elseExpr); + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkCastOperation.java b/src/sqlancer/spark/ast/SparkCastOperation.java new file mode 100644 index 000000000..3bc5eb30d --- /dev/null +++ b/src/sqlancer/spark/ast/SparkCastOperation.java @@ -0,0 +1,25 @@ +package sqlancer.spark.ast; + +import sqlancer.spark.SparkSchema.SparkDataType; + +public class SparkCastOperation implements SparkExpression { + + private final SparkExpression expression; + private final SparkDataType type; + + public SparkCastOperation(SparkExpression expression, SparkDataType type) { + if (expression == null) { + throw new AssertionError(); + } + this.expression = expression; + this.type = type; + } + + public SparkExpression getExpression() { + return expression; + } + + public SparkDataType getType() { + return type; + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkColumnReference.java b/src/sqlancer/spark/ast/SparkColumnReference.java new file mode 100644 index 000000000..75e92d267 --- /dev/null +++ 
b/src/sqlancer/spark/ast/SparkColumnReference.java @@ -0,0 +1,11 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.newast.ColumnReferenceNode; +import sqlancer.spark.SparkSchema.SparkColumn; + +public class SparkColumnReference extends ColumnReferenceNode implements SparkExpression { + + public SparkColumnReference(SparkColumn column) { + super(column); + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkConstant.java b/src/sqlancer/spark/ast/SparkConstant.java new file mode 100644 index 000000000..9f73af59f --- /dev/null +++ b/src/sqlancer/spark/ast/SparkConstant.java @@ -0,0 +1,194 @@ +package sqlancer.spark.ast; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.text.SimpleDateFormat; + +public abstract class SparkConstant implements SparkExpression { + + public boolean isNull() { + return false; + } + + public static class SparkNullConstant extends SparkConstant { + + @Override + public boolean isNull() { + return true; + } + + @Override + public String toString() { + return "NULL"; + } + } + + public static class SparkIntConstant extends SparkConstant { + + private final long value; + + public SparkIntConstant(long value) { + this.value = value; + } + + public long getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + } + + public static class SparkDoubleConstant extends SparkConstant { + + private final double value; + + public SparkDoubleConstant(double value) { + this.value = value; + } + + public double getValue() { + return value; + } + + @Override + public String toString() { + if (value == Double.POSITIVE_INFINITY) { + return "CAST('Infinity' AS DOUBLE)"; + } else if (value == Double.NEGATIVE_INFINITY) { + return "CAST('-Infinity' AS DOUBLE)"; + } else if (Double.isNaN(value)) { + return "CAST('NaN' AS DOUBLE)"; + } + return String.valueOf(value); + } + } + + public static class SparkDecimalConstant extends SparkConstant { + + private final BigDecimal value; + + public SparkDecimalConstant(BigDecimal value) { + this.value = value; + } + + public BigDecimal getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + } + + public static class SparkTimestampConstant extends SparkConstant { + + private final String textRepr; + + public SparkTimestampConstant(long value) { + Timestamp timestamp = new Timestamp(value); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); // Spark prefers full timestamp + this.textRepr = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("TIMESTAMP '%s'", textRepr); + } + } + + public static class SparkDateConstant extends SparkConstant { + + private final String textRepr; + + public SparkDateConstant(long value) { + Timestamp timestamp = new Timestamp(value); + SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd"); + this.textRepr = dateFormat.format(timestamp); + } + + public String getValue() { + return textRepr; + } + + @Override + public String toString() { + return String.format("DATE '%s'", textRepr); + } + } + + public static class SparkStringConstant extends SparkConstant { + + private final String value; + + public SparkStringConstant(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + @Override + public String toString() { + return "'" + value.replace("'", "''").replace("\\", "\\\\") + "'"; + 
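+            // Note: quotes are doubled before backslashes are escaped; the two replacements touch
+            // disjoint characters, so their order cannot corrupt the emitted literal.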
} + } + + public static class SparkBooleanConstant extends SparkConstant { + + private final boolean value; + + public SparkBooleanConstant(boolean value) { + this.value = value; + } + + public boolean getValue() { + return value; + } + + @Override + public String toString() { + return String.valueOf(value); + } + } + + public static SparkConstant createNullConstant() { + return new SparkNullConstant(); + } + + public static SparkConstant createIntConstant(long value) { + return new SparkIntConstant(value); + } + + public static SparkConstant createDoubleConstant(double value) { + return new SparkDoubleConstant(value); + } + + public static SparkConstant createDecimalConstant(BigDecimal value) { + return new SparkDecimalConstant(value); + } + + public static SparkConstant createTimestampConstant(long value) { + return new SparkTimestampConstant(value); + } + + public static SparkConstant createDateConstant(long value) { + return new SparkDateConstant(value); + } + + public static SparkConstant createStringConstant(String value) { + return new SparkStringConstant(value); + } + + public static SparkConstant createBooleanConstant(boolean value) { + return new SparkBooleanConstant(value); + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkExpression.java b/src/sqlancer/spark/ast/SparkExpression.java new file mode 100644 index 000000000..a130096e3 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkExpression.java @@ -0,0 +1,7 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.newast.Expression; +import sqlancer.spark.SparkSchema.SparkColumn; + +public interface SparkExpression extends Expression { +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkFunction.java b/src/sqlancer/spark/ast/SparkFunction.java new file mode 100644 index 000000000..d5740ee36 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkFunction.java @@ -0,0 +1,13 @@ +package sqlancer.spark.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewFunctionNode; + +public class SparkFunction extends NewFunctionNode implements SparkExpression { + + public SparkFunction(List args, F func) { + super(args, func); + } + +} diff --git a/src/sqlancer/spark/ast/SparkInOperation.java b/src/sqlancer/spark/ast/SparkInOperation.java new file mode 100644 index 000000000..37a80e3ff --- /dev/null +++ b/src/sqlancer/spark/ast/SparkInOperation.java @@ -0,0 +1,12 @@ +package sqlancer.spark.ast; + +import java.util.List; + +import sqlancer.common.ast.newast.NewInOperatorNode; + +public class SparkInOperation extends NewInOperatorNode implements SparkExpression { + + public SparkInOperation(SparkExpression left, List right, boolean isNegated) { + super(left, right, isNegated); + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkJoin.java b/src/sqlancer/spark/ast/SparkJoin.java new file mode 100644 index 000000000..44da7fba4 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkJoin.java @@ -0,0 +1,46 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.newast.Join; +import sqlancer.spark.SparkSchema.SparkColumn; +import sqlancer.spark.SparkSchema.SparkTable; + +public class SparkJoin implements SparkExpression, Join { + + private final SparkTableReference leftTable; + private final SparkTableReference rightTable; + private final JoinType joinType; + private SparkExpression onClause; + + public enum JoinType { + INNER, LEFT_OUTER, RIGHT_OUTER, FULL_OUTER, LEFT_SEMI, LEFT_ANTI, CROSS; + } + + public SparkJoin(SparkTableReference leftTable, SparkTableReference 
rightTable, JoinType joinType, + SparkExpression onClause) { + this.leftTable = leftTable; + this.rightTable = rightTable; + this.joinType = joinType; + this.onClause = onClause; + } + + public SparkTableReference getLeftTable() { + return leftTable; + } + + public SparkTableReference getRightTable() { + return rightTable; + } + + public JoinType getJoinType() { + return joinType; + } + + public SparkExpression getOnClause() { + return onClause; + } + + @Override + public void setOnClause(SparkExpression onClause) { + this.onClause = onClause; + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkOrderingTerm.java b/src/sqlancer/spark/ast/SparkOrderingTerm.java new file mode 100644 index 000000000..824801c00 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkOrderingTerm.java @@ -0,0 +1,10 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.newast.NewOrderingTerm; + +public class SparkOrderingTerm extends NewOrderingTerm implements SparkExpression { + + public SparkOrderingTerm(SparkExpression expr, Ordering ordering) { + super(expr, ordering); + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkSelect.java b/src/sqlancer/spark/ast/SparkSelect.java new file mode 100644 index 000000000..0986ce0a6 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkSelect.java @@ -0,0 +1,42 @@ +package sqlancer.spark.ast; + +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.common.ast.SelectBase; +import sqlancer.common.ast.newast.Select; +import sqlancer.spark.SparkSchema.SparkColumn; +import sqlancer.spark.SparkSchema.SparkTable; +import sqlancer.spark.SparkToStringVisitor; + +public class SparkSelect extends SelectBase + implements Select, SparkExpression { + + private boolean isDistinct; + + public void setDistinct(boolean isDistinct) { + this.isDistinct = isDistinct; + } + + public boolean isDistinct() { + return isDistinct; + } + + @Override + public void setJoinClauses(List joinStatements) { + List expressions = joinStatements.stream().map(e -> (SparkExpression) e) + .collect(Collectors.toList()); + setJoinList(expressions); + } + + @Override + public List getJoinClauses() { + return getJoinList().stream().map(e -> (SparkJoin) e).collect(Collectors.toList()); + } + + @Override + public String asString() { + return SparkToStringVisitor.asString(this); + } + +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkTableReference.java b/src/sqlancer/spark/ast/SparkTableReference.java new file mode 100644 index 000000000..92a59ad3d --- /dev/null +++ b/src/sqlancer/spark/ast/SparkTableReference.java @@ -0,0 +1,13 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.newast.TableReferenceNode; +import sqlancer.spark.SparkSchema; + +public class SparkTableReference extends TableReferenceNode + implements SparkExpression { + + public SparkTableReference(SparkSchema.SparkTable table) { + super(table); + } + +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkUnaryPostfixOperation.java b/src/sqlancer/spark/ast/SparkUnaryPostfixOperation.java new file mode 100644 index 000000000..f1082a655 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkUnaryPostfixOperation.java @@ -0,0 +1,13 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewUnaryPostfixOperatorNode; + +public class SparkUnaryPostfixOperation extends NewUnaryPostfixOperatorNode + implements SparkExpression { + + public SparkUnaryPostfixOperation(SparkExpression 
expr, Operator op) { + super(expr, op); + } + +} \ No newline at end of file diff --git a/src/sqlancer/spark/ast/SparkUnaryPrefixOperation.java b/src/sqlancer/spark/ast/SparkUnaryPrefixOperation.java new file mode 100644 index 000000000..d1bd94ab4 --- /dev/null +++ b/src/sqlancer/spark/ast/SparkUnaryPrefixOperation.java @@ -0,0 +1,12 @@ +package sqlancer.spark.ast; + +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewUnaryPrefixOperatorNode; + +public class SparkUnaryPrefixOperation extends NewUnaryPrefixOperatorNode implements SparkExpression { + + public SparkUnaryPrefixOperation(SparkExpression expr, Operator op) { + super(expr, op); + } + +} \ No newline at end of file diff --git a/src/sqlancer/spark/gen/SparkExpressionGenerator.java b/src/sqlancer/spark/gen/SparkExpressionGenerator.java new file mode 100644 index 000000000..faf8a07f0 --- /dev/null +++ b/src/sqlancer/spark/gen/SparkExpressionGenerator.java @@ -0,0 +1,336 @@ +package sqlancer.spark.gen; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import sqlancer.Randomly; +import sqlancer.common.ast.BinaryOperatorNode.Operator; +import sqlancer.common.ast.newast.NewOrderingTerm.Ordering; +import sqlancer.common.gen.TLPWhereGenerator; +import sqlancer.common.gen.UntypedExpressionGenerator; +import sqlancer.common.schema.AbstractTables; +import sqlancer.spark.SparkGlobalState; +import sqlancer.spark.SparkSchema.SparkColumn; +import sqlancer.spark.SparkSchema.SparkDataType; +import sqlancer.spark.SparkSchema.SparkTable; +import sqlancer.spark.ast.SparkBetweenOperation; +import sqlancer.spark.ast.SparkBinaryOperation; +import sqlancer.spark.ast.SparkCaseOperation; +import sqlancer.spark.ast.SparkCastOperation; +import sqlancer.spark.ast.SparkColumnReference; +import sqlancer.spark.ast.SparkConstant; +import sqlancer.spark.ast.SparkExpression; +import sqlancer.spark.ast.SparkFunction; +import sqlancer.spark.ast.SparkInOperation; +import sqlancer.spark.ast.SparkJoin; +import sqlancer.spark.ast.SparkOrderingTerm; +import sqlancer.spark.ast.SparkSelect; +import sqlancer.spark.ast.SparkTableReference; +import sqlancer.spark.ast.SparkUnaryPostfixOperation; +import sqlancer.spark.ast.SparkUnaryPrefixOperation; + +public class SparkExpressionGenerator extends UntypedExpressionGenerator + implements TLPWhereGenerator { + + private final SparkGlobalState globalState; + private List tables; + + private enum Expression { + UNARY_PREFIX, UNARY_POSTFIX, BINARY_COMPARISON, BINARY_LOGICAL, BINARY_ARITHMETIC, CAST, FUNC, BETWEEN, IN, + CASE; + } + + public SparkExpressionGenerator(SparkGlobalState globalState) { + this.globalState = globalState; + } + + @Override + public SparkExpression negatePredicate(SparkExpression predicate) { + return new SparkUnaryPrefixOperation(predicate, SparkUnaryPrefixOperator.NOT); + } + + @Override + public SparkExpression isNull(SparkExpression expr) { + return new SparkUnaryPostfixOperation(expr, SparkUnaryPostfixOperator.IS_NULL); + } + + @Override + protected SparkExpression generateExpression(int depth) { + return generateExpressionInternal(depth); + } + + private SparkExpression generateExpressionInternal(int depth) throws AssertionError { + if (depth >= globalState.getOptions().getMaxExpressionDepth() + || Randomly.getBooleanWithRatherLowProbability()) { + return generateLeafNode(); + } + if (allowAggregates && Randomly.getBooleanWithRatherLowProbability()) { + allowAggregates = false; // aggregate 
function calls cannot be nested + SparkAggregateFunction aggregate = SparkAggregateFunction.getRandom(); + return new SparkFunction<>(generateExpressions(aggregate.getNrArgs(), depth + 1), aggregate); + } + + List possibleOptions = new ArrayList<>(Arrays.asList(Expression.values())); + Expression expr = Randomly.fromList(possibleOptions); + + switch (expr) { + case UNARY_PREFIX: + return new SparkUnaryPrefixOperation(generateExpression(depth + 1), SparkUnaryPrefixOperator.getRandom()); + case UNARY_POSTFIX: + return new SparkUnaryPostfixOperation(generateExpression(depth + 1), SparkUnaryPostfixOperator.getRandom()); + case BINARY_COMPARISON: + Operator op = SparkBinaryComparisonOperator.getRandom(); + return new SparkBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), op); + case BINARY_LOGICAL: + op = SparkBinaryLogicalOperator.getRandom(); + return new SparkBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), op); + case BINARY_ARITHMETIC: + return new SparkBinaryOperation(generateExpression(depth + 1), generateExpression(depth + 1), + SparkBinaryArithmeticOperator.getRandom()); + case CAST: + return new SparkCastOperation(generateExpression(depth + 1), SparkDataType.getRandomType()); + case FUNC: + SparkFunc func = SparkFunc.getRandom(); + return new SparkFunction<>(generateExpressions(func.getNrArgs()), func); + case BETWEEN: + return new SparkBetweenOperation(generateExpression(depth + 1), generateExpression(depth + 1), + generateExpression(depth + 1), Randomly.getBoolean()); + case IN: + return new SparkInOperation(generateExpression(depth + 1), + generateExpressions(Randomly.smallNumber() + 1, depth + 1), Randomly.getBoolean()); + case CASE: + int nr = Randomly.smallNumber() + 1; + return new SparkCaseOperation(generateExpression(depth + 1), generateExpressions(nr, depth + 1), + generateExpressions(nr, depth + 1), generateExpression(depth + 1)); + default: + throw new AssertionError(expr); + } + } + + @Override + public SparkExpression generateConstant() { + if (Randomly.getBooleanWithRatherLowProbability()) { + return SparkConstant.createNullConstant(); + } + SparkDataType[] values = SparkDataType.values(); + SparkDataType constantType = Randomly.fromOptions(values); + switch (constantType) { + case STRING: + return SparkConstant.createStringConstant(globalState.getRandomly().getString()); + case INTEGER: + return SparkConstant.createIntConstant(globalState.getRandomly().getInteger()); + case DOUBLE: + return SparkConstant.createDoubleConstant(globalState.getRandomly().getDouble()); + case BOOLEAN: + return SparkConstant.createBooleanConstant(Randomly.getBoolean()); + case TIMESTAMP: + return SparkConstant.createTimestampConstant(globalState.getRandomly().getInteger()); + case DATE: + return SparkConstant.createDateConstant(globalState.getRandomly().getInteger()); + default: + throw new AssertionError(constantType); + } + } + + @Override + protected SparkExpression generateColumn() { + SparkColumn column = Randomly.fromList(columns); + return new SparkColumnReference(column); + } + + @Override + public List generateOrderBys() { + List expr = super.generateOrderBys(); + List newExpr = new ArrayList<>(expr.size()); + for (SparkExpression curExpr : expr) { + if (Randomly.getBoolean()) { + curExpr = new SparkOrderingTerm(curExpr, Ordering.getRandom()); + } + newExpr.add(curExpr); + } + return newExpr; + } + + @Override + public SparkExpressionGenerator setTablesAndColumns(AbstractTables tables) { + this.columns = tables.getColumns(); + 
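+        // The table list is kept as well so that getTableRefs() can build the FROM-clause references.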
this.tables = tables.getTables(); + return this; + } + + @Override + public SparkExpression generateBooleanExpression() { + return generateExpression(); + } + + @Override + public SparkSelect generateSelect() { + return new SparkSelect(); + } + + @Override + public List getTableRefs() { + return tables.stream().map(t -> new SparkTableReference(t)).collect(Collectors.toList()); + } + + @Override + public List generateFetchColumns(boolean allowAggregates) { + if (Randomly.getBoolean()) { + return List.of(new SparkColumnReference(new SparkColumn("*", null, null))); + } + return Randomly.nonEmptySubset(columns).stream().map(c -> new SparkColumnReference(c)) + .collect(Collectors.toList()); + } + + @Override + public List getRandomJoinClauses() { + return List.of(); + } + + public enum SparkUnaryPrefixOperator implements Operator { + NOT("NOT"), PLUS("+"), MINUS("-"), BITWISE_NOT("~"); + + private String textRepr; + + SparkUnaryPrefixOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static SparkUnaryPrefixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum SparkUnaryPostfixOperator implements Operator { + IS_NULL("IS NULL"), IS_NOT_NULL("IS NOT NULL"); + + private String textRepr; + + SparkUnaryPostfixOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static SparkUnaryPostfixOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum SparkBinaryComparisonOperator implements Operator { + EQUALS("="), GREATER(">"), GREATER_EQUALS(">="), SMALLER("<"), SMALLER_EQUALS("<="), NOT_EQUALS("!="), + LIKE("LIKE"), NOT_LIKE("NOT LIKE"), RLIKE("RLIKE"); + + private String textRepr; + + SparkBinaryComparisonOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static SparkBinaryComparisonOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum SparkBinaryLogicalOperator implements Operator { + AND("AND"), OR("OR"); + + private String textRepr; + + SparkBinaryLogicalOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static SparkBinaryLogicalOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum SparkBinaryArithmeticOperator implements Operator { + // Spark supports || for concat, and bitwise operators &, |, ^ + CONCAT("||"), ADD("+"), SUB("-"), MULT("*"), DIV("/"), MOD("%"), BITWISE_AND("&"), BITWISE_OR("|"), + BITWISE_XOR("^"); + + private String textRepr; + + SparkBinaryArithmeticOperator(String textRepr) { + this.textRepr = textRepr; + } + + public static SparkBinaryArithmeticOperator getRandom() { + return Randomly.fromOptions(values()); + } + + @Override + public String getTextRepresentation() { + return textRepr; + } + } + + public enum SparkAggregateFunction { + COUNT(1), SUM(1), AVG(1), MIN(1), MAX(1), VARIANCE(1), VAR_SAMP(1), STDDEV_POP(1), STDDEV_SAMP(1), COVAR_POP(2), + COVAR_SAMP(2), CORR(2); + + private int nrArgs; + + SparkAggregateFunction(int nrArgs) { + this.nrArgs = nrArgs; + } + + public static SparkAggregateFunction getRandom() { + return Randomly.fromOptions(values()); + } + + public int getNrArgs() { + return nrArgs; + } + } + + public enum SparkFunc { + ROUND(2), FLOOR(1), 
ABS(1), CEIL(1); + + private int nrArgs; + private boolean isVariadic; + + SparkFunc(int nrArgs) { + this(nrArgs, false); + } + + SparkFunc(int nrArgs, boolean isVariadic) { + this.nrArgs = nrArgs; + this.isVariadic = isVariadic; + } + + public static SparkFunc getRandom() { + return Randomly.fromOptions(values()); + } + + public int getNrArgs() { + if (isVariadic) { + return Randomly.smallNumber() + nrArgs; + } else { + return nrArgs; + } + } + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/gen/SparkInsertGenerator.java b/src/sqlancer/spark/gen/SparkInsertGenerator.java new file mode 100644 index 000000000..29232fdb2 --- /dev/null +++ b/src/sqlancer/spark/gen/SparkInsertGenerator.java @@ -0,0 +1,47 @@ +package sqlancer.spark.gen; + +import java.util.List; + +import sqlancer.common.gen.AbstractInsertGenerator; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.spark.SparkErrors; +import sqlancer.spark.SparkGlobalState; +import sqlancer.spark.SparkSchema.SparkColumn; +import sqlancer.spark.SparkSchema.SparkTable; +import sqlancer.spark.SparkToStringVisitor; + +public class SparkInsertGenerator extends AbstractInsertGenerator { + + private final SparkGlobalState globalState; + private final ExpectedErrors errors = new ExpectedErrors(); + private final SparkExpressionGenerator gen; + + public SparkInsertGenerator(SparkGlobalState globalState) { + this.globalState = globalState; + this.gen = new SparkExpressionGenerator(globalState); + } + + public static SQLQueryAdapter getQuery(SparkGlobalState globalState) { + return new SparkInsertGenerator(globalState).generate(); + } + + @Override + protected void insertValue(SparkColumn column) { + sb.append(SparkToStringVisitor.asString(gen.generateConstant())); + } + + private SQLQueryAdapter generate() { + sb.append("INSERT INTO "); + SparkTable table = globalState.getSchema().getRandomTable(t -> !t.isView()); + sb.append(table.getName()); + + sb.append(" VALUES "); + + List columns = table.getColumns(); + insertColumns(columns); + + SparkErrors.addInsertErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, false, false); + } +} \ No newline at end of file diff --git a/src/sqlancer/spark/gen/SparkTableGenerator.java b/src/sqlancer/spark/gen/SparkTableGenerator.java new file mode 100644 index 000000000..68cafdafb --- /dev/null +++ b/src/sqlancer/spark/gen/SparkTableGenerator.java @@ -0,0 +1,100 @@ +package sqlancer.spark.gen; + +import java.util.ArrayList; +import java.util.List; + +import sqlancer.Randomly; +import sqlancer.common.DBMSCommon; +import sqlancer.common.query.ExpectedErrors; +import sqlancer.common.query.SQLQueryAdapter; +import sqlancer.spark.SparkErrors; +import sqlancer.spark.SparkGlobalState; +import sqlancer.spark.SparkSchema; +import sqlancer.spark.SparkSchema.SparkColumn; +import sqlancer.spark.SparkSchema.SparkDataType; +import sqlancer.spark.SparkSchema.SparkTable; +import sqlancer.spark.SparkToStringVisitor; + +public class SparkTableGenerator { + + private enum ColumnConstraints { + NOT_NULL, DEFAULT + // PRIMARY KEY and UNIQUE are often not supported in standard Spark file sources (Parquet/ORC) + // without specific catalogs (like Delta/Iceberg), so we limit to constraints Spark SQL widely accepts. 
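+        // DEFAULT is assumed to be available here: Spark added column DEFAULT values in 3.4 (the CI
+        // image above is 3.5.1), gated by spark.sql.defaultColumn.enabled, which is on by default.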
+ } + + private final SparkGlobalState globalState; + private final String tableName; + private final StringBuilder sb = new StringBuilder(); + private final SparkExpressionGenerator gen; + private final SparkTable table; + private final List columnsToBeAdded = new ArrayList<>(); + + public SparkTableGenerator(SparkGlobalState globalState, String tableName) { + this.tableName = tableName; + this.globalState = globalState; + this.table = new SparkTable(tableName, columnsToBeAdded, false); + this.gen = new SparkExpressionGenerator(globalState).setColumns(columnsToBeAdded); + } + + public static SQLQueryAdapter generate(SparkGlobalState globalState, String tableName) { + SparkTableGenerator generator = new SparkTableGenerator(globalState, tableName); + return generator.create(); + } + + private SQLQueryAdapter create() { + ExpectedErrors errors = new ExpectedErrors(); + + sb.append("CREATE TABLE "); + sb.append(globalState.getDatabaseName()); + sb.append("."); + sb.append(tableName); + sb.append(" ("); + for (int i = 0; i < Randomly.smallNumber() + 1; i++) { + if (i != 0) { + sb.append(", "); + } + appendColumn(i); + } + sb.append(")"); + sb.append(" USING PARQUET"); + + // TODO: implement PARTITION BY clause + // TODO: implement CLUSTERED BY clauses + // TODO: implement ROW FORMAT and STORED AS clauses + // TODO: randomly add some predefined TABLEPROPERTIES + + SparkErrors.addExpressionErrors(errors); + return new SQLQueryAdapter(sb.toString(), errors, true, false); + } + + private void appendColumn(int columnId) { + String columnName = DBMSCommon.createColumnName(columnId); + sb.append(columnName); + sb.append(" "); + SparkDataType randType = SparkSchema.SparkDataType.getRandomType(); + sb.append(randType); + columnsToBeAdded.add(new SparkColumn(columnName, table, randType)); + appendColumnConstraint(); + } + + private void appendColumnConstraint() { + if (Randomly.getBoolean()) { + return; + } + + ColumnConstraints constraint = Randomly.fromOptions(ColumnConstraints.values()); + switch (constraint) { + case NOT_NULL: + sb.append(" NOT NULL"); + break; + case DEFAULT: + sb.append(" DEFAULT "); + sb.append(SparkToStringVisitor.asString(gen.generateConstant())); + sb.append(" "); + break; + default: + throw new AssertionError(constraint); + } + } +} \ No newline at end of file diff --git a/test/sqlancer/dbms/TestConfig.java b/test/sqlancer/dbms/TestConfig.java index f5aeefa12..f6be45648 100644 --- a/test/sqlancer/dbms/TestConfig.java +++ b/test/sqlancer/dbms/TestConfig.java @@ -11,6 +11,7 @@ public class TestConfig { public static final String DATAFUSION_ENV = "DATAFUSION_AVAILABLE"; public static final String DORIS_ENV = "DORIS_AVAILABLE"; public static final String HIVE_ENV = "HIVE_AVAILABLE"; + public static final String SPARK_ENV = "SPARK_AVAILABLE"; public static final String MARIADB_ENV = "MARIADB_AVAILABLE"; public static final String MATERIALIZE_ENV = "MATERIALIZE_AVAILABLE"; public static final String MYSQL_ENV = "MYSQL_AVAILABLE"; diff --git a/test/sqlancer/dbms/TestSparkTLP.java b/test/sqlancer/dbms/TestSparkTLP.java new file mode 100644 index 000000000..83302ceff --- /dev/null +++ b/test/sqlancer/dbms/TestSparkTLP.java @@ -0,0 +1,20 @@ +package sqlancer.dbms; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +import org.junit.jupiter.api.Test; + +import sqlancer.Main; + +public class TestSparkTLP { + + @Test + public void testSparkTLPWhere() { + 
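+        // Skipped unless SPARK_AVAILABLE=true; expects a Spark Thrift Server on localhost:10000,
+        // which the spark job in .github/workflows/main.yml starts via HiveThriftServer2.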
assumeTrue(TestConfig.isEnvironmentTrue(TestConfig.SPARK_ENV)); + assertEquals(0, + Main.executeMain(new String[] { "--canonicalize-sql-strings", "false", "--random-seed", "0", + "--timeout-seconds", TestConfig.SECONDS, "--num-threads", "1", "--num-queries", + TestConfig.NUM_QUERIES, "spark", "--oracle", "TLPWhere" })); + } +} \ No newline at end of file
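
Note for reviewers: TLPWhere is based on ternary logic partitioning. Under SQL's
three-valued logic, every row satisfies exactly one of p, NOT p, and p IS NULL for any
predicate p, so the original query must return the same multiset of rows as its three
partitions combined. A minimal sketch of the queries the oracle derives, where t0 and p
stand in for whatever SparkTableGenerator and SparkExpressionGenerator happen to produce:

    SELECT * FROM t0;
    -- must equal the combined results of:
    SELECT * FROM t0 WHERE p;
    SELECT * FROM t0 WHERE NOT (p);
    SELECT * FROM t0 WHERE (p) IS NULL;

A mismatch is reported as a logic bug in Spark's predicate evaluation; errors matched by
SparkErrors are treated as expected and skipped instead.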