Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,44 @@ jobs:
- name: Run Tests
run: HIVE_AVAILABLE=true mvn -Dtest=TestHiveTLP test

# DBMS tests against a local Spark Thrift JDBC/ODBC server.
# NOTE(review): GitHub Actions service containers only support the keys
# image / credentials / env / ports / volumes / options — there is no
# `command:` key, so the original `services:` block would never have
# launched the Thrift Server. Start the container explicitly instead.
spark:
  name: DBMS Tests (Spark)
  runs-on: ubuntu-latest

  steps:
    - uses: actions/checkout@v3
      with:
        fetch-depth: 0

    - name: Set up JDK 11
      uses: actions/setup-java@v3
      with:
        distribution: 'temurin'
        java-version: '11'
        cache: 'maven'

    - name: Start Spark Thrift Server
      run: |
        docker run -d --name spark-thrift -p 10000:10000 apache/spark:3.5.1 \
          /opt/spark/bin/spark-submit \
          --class org.apache.spark.sql.hive.thriftserver.HiveThriftServer2 \
          --name "Thrift JDBC/ODBC Server" \
          --master 'local[*]' \
          --driver-memory 4g \
          --conf spark.hive.server2.thrift.port=10000 \
          --conf spark.sql.warehouse.dir=/tmp/spark-warehouse \
          spark-internal

    - name: Wait for Thrift Server
      run: |
        # Poll the JDBC port until the server accepts connections (max ~2 min);
        # the tests connect on startup and would otherwise fail with a race.
        for i in $(seq 1 60); do
          if nc -z localhost 10000; then exit 0; fi
          sleep 2
        done
        echo "Spark Thrift Server did not become ready in time" >&2
        docker logs spark-thrift >&2 || true
        exit 1

    - name: Build SQLancer
      run: mvn -B package -DskipTests=true

    - name: Run Tests
      run: SPARK_AVAILABLE=true mvn -Dtest=TestSparkTLP test

hsqldb:
name: DBMS Tests (HSQLB)
runs-on: ubuntu-latest
Expand Down
9 changes: 7 additions & 2 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId> slf4j-simple</artifactId>
<artifactId>slf4j-simple</artifactId>
<version>2.0.6</version>
</dependency>
<dependency>
Expand Down Expand Up @@ -381,7 +381,7 @@
<dependency>
<groupId>org.apache.hive</groupId>
<artifactId>hive-jdbc</artifactId>
<version>4.0.1</version>
<version>3.1.2</version>
</dependency>
<dependency>
<groupId>org.apache.hive</groupId>
Expand All @@ -393,6 +393,11 @@
<artifactId>hive-cli</artifactId>
<version>4.0.1</version>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-common</artifactId>
<version>3.2.4</version>
</dependency>
</dependencies>
<reporting>
<plugins>
Expand Down
2 changes: 2 additions & 0 deletions src/sqlancer/Main.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import sqlancer.tidb.TiDBProvider;
import sqlancer.yugabyte.ycql.YCQLProvider;
import sqlancer.yugabyte.ysql.YSQLProvider;
import sqlancer.spark.SparkProvider;

public final class Main {

Expand Down Expand Up @@ -756,6 +757,7 @@ private static void checkForIssue799(List<DatabaseProvider<?, ?, ?>> providers)
providers.add(new DuckDBProvider());
providers.add(new H2Provider());
providers.add(new HiveProvider());
providers.add(new SparkProvider());
providers.add(new HSQLDBProvider());
providers.add(new MariaDBProvider());
providers.add(new MaterializeProvider());
Expand Down
67 changes: 67 additions & 0 deletions src/sqlancer/spark/SparkErrors.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package sqlancer.spark;

import java.util.ArrayList;
import java.util.List;

import sqlancer.common.query.ExpectedErrors;

/**
 * Collects error-message fragments that Spark SQL is expected to emit for
 * randomly generated statements, so that SQLancer can classify such failures
 * as expected rather than reporting them as bugs.
 */
public final class SparkErrors {

    private SparkErrors() {
        // Utility class; never instantiated.
    }

    /**
     * Returns the error fragments expected while evaluating generated expressions
     * (analysis errors, type mismatches, arithmetic errors, grouping errors).
     */
    public static List<String> getExpressionErrors() {
        List<String> expected = new ArrayList<>();

        // --- Analysis / resolution errors ---
        expected.add("cannot resolve");
        expected.add("AnalysisException");
        expected.add("data type mismatch");
        expected.add("undefined function");
        expected.add("mismatched input");
        expected.add("due to data type mismatch");

        // --- Invalid Literals
        expected.add("The value of the typed literal");

        expected.add("DATATYPE_MISMATCH");
        expected.add("cannot be cast to");

        // --- Arithmetic errors ---
        expected.add("Overflow");
        expected.add("Divide by zero"); // Common if spark.sql.ansi.enabled is true
        expected.add("division by zero");

        // --- Group By / Aggregation errors ---
        expected.add("grouping expressions");
        expected.add("expression is neither present in the group by");
        expected.add("is not a valid grouping expression");
        expected.add("is not contained in either an aggregate function or the GROUP BY clause");
        expected.add("PARSE_SYNTAX_ERROR");
        expected.add("Syntax error");

        return expected;
    }

    /** Registers all expected expression errors on the given collection. */
    public static void addExpressionErrors(ExpectedErrors errors) {
        errors.addAll(getExpressionErrors());
    }

    /**
     * Returns the error fragments expected while executing generated INSERT
     * statements (schema mismatches, cast failures, analysis errors).
     */
    public static List<String> getInsertErrors() {
        List<String> expected = new ArrayList<>();

        expected.add("not enough data columns");
        expected.add("cannot write to");
        expected.add("incompatible types");
        expected.add("too many data columns");
        expected.add("cannot be cast to");
        expected.add("Error running query");
        expected.add("The value of the typed literal");
        expected.add("Cannot safely cast"); // Found in logs: Decimal -> Date
        expected.add("AnalysisException"); // Spark throws this for almost all insert failures

        return expected;
    }

    /** Registers all expected INSERT errors on the given collection. */
    public static void addInsertErrors(ExpectedErrors errors) {
        errors.addAll(getInsertErrors());
    }
}
11 changes: 11 additions & 0 deletions src/sqlancer/spark/SparkGlobalState.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package sqlancer.spark;

import sqlancer.SQLGlobalState;

/**
 * Per-run global state for the Spark SQL provider; ties the generic SQL state
 * to the Spark-specific {@link SparkOptions} and {@link SparkSchema} types.
 */
public class SparkGlobalState extends SQLGlobalState<SparkOptions, SparkSchema> {

    @Override
    protected SparkSchema readSchema() throws Exception {
        // Delegates schema construction to SparkSchema, using the current
        // connection and the name of the database under test.
        return SparkSchema.fromConnection(getConnection(), getDatabaseName());
    }
}
43 changes: 43 additions & 0 deletions src/sqlancer/spark/SparkOptions.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package sqlancer.spark;

import java.sql.SQLException;
import java.util.Arrays;
import java.util.List;

import com.beust.jcommander.Parameter;
import com.beust.jcommander.Parameters;

import sqlancer.DBMSSpecificOptions;
import sqlancer.OracleFactory;
import sqlancer.common.oracle.TLPWhereOracle;
import sqlancer.common.oracle.TestOracle;
import sqlancer.common.query.ExpectedErrors;
import sqlancer.spark.gen.SparkExpressionGenerator;

@Parameters(separators = "=", commandDescription = "Spark SQL (default port: " + SparkOptions.DEFAULT_PORT
        + ", default host: " + SparkOptions.DEFAULT_HOST + ")")
public class SparkOptions implements DBMSSpecificOptions<SparkOptions.SparkOracleFactory> {

    /** Host used when no --host option is supplied. */
    public static final String DEFAULT_HOST = "localhost";

    /** Thrift JDBC port used when no --port option is supplied. */
    public static final int DEFAULT_PORT = 10000;

    /** Test oracles to run; defaults to TLP on the WHERE clause. */
    @Parameter(names = "--oracle")
    public List<SparkOracleFactory> oracle = Arrays.asList(SparkOracleFactory.TLPWhere);

    /** Factories for the test oracles supported by the Spark SQL provider. */
    public enum SparkOracleFactory implements OracleFactory<SparkGlobalState> {
        TLPWhere {
            @Override
            public TestOracle<SparkGlobalState> create(SparkGlobalState state) throws SQLException {
                // Known Spark analysis/runtime error messages are registered as
                // expected so the oracle does not report them as bugs.
                ExpectedErrors knownErrors = ExpectedErrors.newErrors().with(SparkErrors.getExpressionErrors())
                        .build();
                return new TLPWhereOracle<>(state, new SparkExpressionGenerator(state), knownErrors);
            }
        };
    }

    @Override
    public List<SparkOracleFactory> getTestOracleFactory() {
        return oracle;
    }
}
122 changes: 122 additions & 0 deletions src/sqlancer/spark/SparkProvider.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package sqlancer.spark;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.Statement;

import com.google.auto.service.AutoService;

import sqlancer.AbstractAction;
import sqlancer.DatabaseProvider;
import sqlancer.IgnoreMeException;
import sqlancer.MainOptions;
import sqlancer.Randomly;
import sqlancer.SQLConnection;
import sqlancer.SQLProviderAdapter;
import sqlancer.StatementExecutor;
import sqlancer.common.query.SQLQueryAdapter;
import sqlancer.common.query.SQLQueryProvider;
import sqlancer.spark.gen.SparkInsertGenerator;
import sqlancer.spark.gen.SparkTableGenerator;

@AutoService(DatabaseProvider.class)
public class SparkProvider extends SQLProviderAdapter<SparkGlobalState, SparkOptions> {

    public SparkProvider() {
        super(SparkGlobalState.class, SparkOptions.class);
    }

    /** Statement kinds executed after the tables have been created. */
    public enum Action implements AbstractAction<SparkGlobalState> {
        INSERT(SparkInsertGenerator::getQuery);

        private final SQLQueryProvider<SparkGlobalState> sqlQueryProvider;

        Action(SQLQueryProvider<SparkGlobalState> sqlQueryProvider) {
            this.sqlQueryProvider = sqlQueryProvider;
        }

        @Override
        public SQLQueryAdapter getQuery(SparkGlobalState state) throws Exception {
            return sqlQueryProvider.getQuery(state);
        }
    }

    /**
     * Determines how many statements of kind {@code a} to execute for this run.
     */
    private static int mapActions(SparkGlobalState globalState, Action a) {
        Randomly r = globalState.getRandomly();
        switch (a) {
        case INSERT:
            return r.getInteger(0, globalState.getOptions().getMaxNumberInserts());
        default:
            throw new AssertionError(a);
        }
    }

    /**
     * Creates one or two random tables (retrying on failure) and then executes
     * a random number of INSERT statements against them.
     *
     * @throws IgnoreMeException if no table could be observed in the schema
     */
    @Override
    public void generateDatabase(SparkGlobalState globalState) throws Exception {
        for (int i = 0; i < Randomly.fromOptions(1, 2); i++) {
            boolean success;
            do {
                String tableName = globalState.getSchema().getFreeTableName();
                SQLQueryAdapter qt = SparkTableGenerator.generate(globalState, tableName);
                success = globalState.executeStatement(qt);
            } while (!success);
        }

        if (globalState.getSchema().getDatabaseTables().isEmpty()) {
            throw new IgnoreMeException();
        }

        StatementExecutor<SparkGlobalState, Action> se = new StatementExecutor<>(globalState, Action.values(),
                SparkProvider::mapActions, (q) -> {
                    if (globalState.getSchema().getDatabaseTables().isEmpty()) {
                        throw new IgnoreMeException();
                    }
                });
        se.executeStatements();
    }

    /**
     * Drops and re-creates the fuzzing database, then returns a connection to
     * it with ANSI mode disabled. The returned connection is owned by the
     * caller; the bootstrap connection is always closed here.
     */
    @Override
    public SQLConnection createDatabase(SparkGlobalState globalState) throws SQLException {
        String username = globalState.getOptions().getUserName();
        String password = globalState.getOptions().getPassword();
        String host = globalState.getOptions().getHost();
        int port = globalState.getOptions().getPort();

        if (host == null) {
            host = SparkOptions.DEFAULT_HOST;
        }
        if (port == MainOptions.NO_SET_PORT) {
            port = SparkOptions.DEFAULT_PORT;
        }

        String databaseName = globalState.getDatabaseName();

        // Spark's Thrift server speaks the HiveServer2 protocol, so the Hive JDBC driver is used.
        String url = String.format("jdbc:hive2://%s:%d/%s", host, port, "default");

        // Connect to the default database to (re-)create the fuzzing database.
        // try-with-resources closes this bootstrap connection even when
        // DROP/CREATE fails (the previous version leaked it on error).
        try (Connection bootstrap = DriverManager.getConnection(url, username, password)) {
            try (Statement s = bootstrap.createStatement()) {
                s.execute("DROP DATABASE IF EXISTS " + databaseName + " CASCADE");
            }
            try (Statement s = bootstrap.createStatement()) {
                s.execute("CREATE DATABASE " + databaseName);
            }
        }

        // Connect to the specific fuzzing DB; this connection is handed to the caller.
        Connection con = DriverManager.getConnection(String.format("jdbc:hive2://%s:%d/%s", host, port, databaseName),
                username, password);
        try (Statement s = con.createStatement()) {
            // This allows casting things like BOOLEAN to DATE/TIMESTAMP, which the generator loves to do.
            s.execute("SET spark.sql.ansi.enabled=false");
        } catch (SQLException e) {
            con.close(); // do not leak the connection if session setup fails
            throw e;
        }
        return new SQLConnection(con);
    }

    @Override
    public String getDBMSName() {
        return "spark";
    }
}
Loading