// Review notes for this patch (async-query-core: SQLQueryUtils.java and SQLQueryUtilsTest.java).
// These lines sit ahead of the first "diff --git" header; the patch text below is unchanged.
//
// SQLQueryUtils.java
//   - extractFullyQualifiedTableNames(sqlQuery) becomes a thin wrapper: it calls the new
//     extractFullyQualifiedTableNamesWithMetadata(sqlQuery) and returns only its table-name list.
//   - extractFullyQualifiedTableNamesWithMetadata(sqlQuery) builds the parser through
//     getBaseParser(sqlQuery), parses a single statement, walks it with SparkSqlTableNameVisitor,
//     then copies the visitor's names into a new LinkedList, skipping any entry the list already
//     contains (first occurrence wins, encounter order is kept). The result is wrapped in a
//     TableExtractionResult together with visitor.isCreateTable().
//   - SparkSqlTableNameVisitor gains a boolean isCreateTable field (initially false) that
//     visitCreateTable sets to true before delegating to super.visitCreateTable(ctx).
//   - TableExtractionResult is a new nested static class with two final fields assigned in its
//     constructor; the list reference is stored as passed in (no copy is made).
//
// SQLQueryUtilsTest.java
//   - testExtractFullyQualifiedTableNamesWithMetadata checks two CREATE statements (flag true),
//     then SELECT, DROP TABLE, DESCRIBE TABLE and a two-table JOIN (flag false), asserting the
//     extracted datasource/schema/table parts and the number of names in each case.
//
// NOTE(review): the removed code returned the visitor's list directly; because the old entry
//   point now goes through the new method, existing callers receive a de-duplicated list.
// NOTE(review): the removed code attached a SyntaxAnalysisErrorListener inline; getBaseParser is
//   not part of this patch, so whether it registers the same listener cannot be seen here - confirm.
// NOTE(review): de-duplication uses List.contains, i.e. it relies on
//   FullyQualifiedTableName.equals and does a linear scan per name - presumably fine for the
//   handful of tables in one query; confirm equals is value-based.
// NOTE(review): the first CREATE EXTERNAL TABLE test string contains bracketed, documentation-
//   style placeholders ("[ PARTITIONED BY ... ]"); the test appears to depend on the parser still
//   reaching the create-table rule for such input - confirm this is intended.
// NOTE(review): generic type arguments look missing in this copy of the patch (for example
//   "List extractFullyQualifiedTableNames" and "extends SqlBaseParserBaseVisitor
//   fullyQualifiedTableNames"); presumably lost during extraction - verify against upstream.
diff --git a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java index f1ec148c622..de14eee54e6 100644 --- a/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java +++ b/async-query-core/src/main/java/org/opensearch/sql/spark/utils/SQLQueryUtils.java @@ -43,14 +43,24 @@ public class SQLQueryUtils { private static final Logger logger = LogManager.getLogger(SQLQueryUtils.class); public static List extractFullyQualifiedTableNames(String sqlQuery) { - SqlBaseParser sqlBaseParser = - new SqlBaseParser( - new CommonTokenStream(new SqlBaseLexer(new CaseInsensitiveCharStream(sqlQuery)))); - sqlBaseParser.addErrorListener(new SyntaxAnalysisErrorListener()); + return extractFullyQualifiedTableNamesWithMetadata(sqlQuery).getFullyQualifiedTableNames(); + } + + public static TableExtractionResult extractFullyQualifiedTableNamesWithMetadata(String sqlQuery) { + SqlBaseParser sqlBaseParser = getBaseParser(sqlQuery); StatementContext statement = sqlBaseParser.statement(); - SparkSqlTableNameVisitor sparkSqlTableNameVisitor = new SparkSqlTableNameVisitor(); - statement.accept(sparkSqlTableNameVisitor); - return sparkSqlTableNameVisitor.getFullyQualifiedTableNames(); + SparkSqlTableNameVisitor visitor = new SparkSqlTableNameVisitor(); + statement.accept(visitor); + + // Remove duplicate table names + List uniqueFullyQualifiedTableNames = new LinkedList<>(); + for (FullyQualifiedTableName fullyQualifiedTableName : visitor.getFullyQualifiedTableNames()) { + if (!uniqueFullyQualifiedTableNames.contains(fullyQualifiedTableName)) { + uniqueFullyQualifiedTableNames.add(fullyQualifiedTableName); + } + } + + return new TableExtractionResult(uniqueFullyQualifiedTableNames, visitor.isCreateTable()); } public static IndexQueryDetails extractIndexDetails(String sqlQuery) { @@ -92,6 +102,8 @@ public static class SparkSqlTableNameVisitor extends 
SqlBaseParserBaseVisitor fullyQualifiedTableNames = new LinkedList<>(); + @Getter private boolean isCreateTable = false; + public SparkSqlTableNameVisitor() {} @Override @@ -130,6 +142,12 @@ public Void visitCreateTableHeader(SqlBaseParser.CreateTableHeaderContext ctx) { } return super.visitCreateTableHeader(ctx); } + + @Override + public Void visitCreateTable(SqlBaseParser.CreateTableContext ctx) { + isCreateTable = true; + return super.visitCreateTable(ctx); + } } public static class FlintSQLIndexDetailsVisitor extends FlintSparkSqlExtensionsBaseVisitor { @@ -380,4 +398,15 @@ public String removeUnwantedQuotes(String input) { return input.replaceAll("^\"|\"$", ""); } } + + public static class TableExtractionResult { + @Getter private final List fullyQualifiedTableNames; + @Getter private final boolean isCreateTableQuery; + + public TableExtractionResult( + List fullyQualifiedTableNames, boolean isCreateTableQuery) { + this.fullyQualifiedTableNames = fullyQualifiedTableNames; + this.isCreateTableQuery = isCreateTableQuery; + } + } } diff --git a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java index f860c6a3bc9..1a251b32fdc 100644 --- a/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java +++ b/async-query-core/src/test/java/org/opensearch/sql/spark/utils/SQLQueryUtilsTest.java @@ -25,6 +25,7 @@ import org.opensearch.sql.spark.dispatcher.model.IndexQueryActionType; import org.opensearch.sql.spark.dispatcher.model.IndexQueryDetails; import org.opensearch.sql.spark.flint.FlintIndexType; +import org.opensearch.sql.spark.utils.SQLQueryUtils.TableExtractionResult; @ExtendWith(MockitoExtension.class) public class SQLQueryUtilsTest { @@ -444,6 +445,69 @@ void testRecoverIndex() { assertEquals(IndexQueryActionType.RECOVER, indexDetails.getIndexQueryActionType()); } + @Test + void 
testExtractFullyQualifiedTableNamesWithMetadata() { + // Test CREATE TABLE queries + String createTableQuery = + "CREATE EXTERNAL TABLE\n" + + "myS3.default.alb_logs\n" + + "[ PARTITIONED BY (col_name [, … ] ) ]\n" + + "[ ROW FORMAT DELIMITED row_format ]\n" + + "STORED AS file_format\n" + + "LOCATION { 's3://bucket/folder/' }"; + + TableExtractionResult result = + SQLQueryUtils.extractFullyQualifiedTableNamesWithMetadata(createTableQuery); + assertTrue(result.isCreateTableQuery()); + assertEquals(1, result.getFullyQualifiedTableNames().size()); + assertFullyQualifiedTableName( + "myS3", "default", "alb_logs", result.getFullyQualifiedTableNames().get(0)); + + String createTableQuery2 = + "CREATE TABLE myS3.default.new_table (id INT, name STRING) USING PARQUET"; + result = SQLQueryUtils.extractFullyQualifiedTableNamesWithMetadata(createTableQuery2); + assertTrue(result.isCreateTableQuery()); + assertEquals(1, result.getFullyQualifiedTableNames().size()); + assertFullyQualifiedTableName( + "myS3", "default", "new_table", result.getFullyQualifiedTableNames().get(0)); + + // Test SELECT queries + String selectQuery = "SELECT * FROM myS3.default.alb_logs"; + result = SQLQueryUtils.extractFullyQualifiedTableNamesWithMetadata(selectQuery); + assertFalse(result.isCreateTableQuery()); + assertEquals(1, result.getFullyQualifiedTableNames().size()); + assertFullyQualifiedTableName( + "myS3", "default", "alb_logs", result.getFullyQualifiedTableNames().get(0)); + + // Test DROP TABLE queries + String dropTableQuery = "DROP TABLE myS3.default.alb_logs"; + result = SQLQueryUtils.extractFullyQualifiedTableNamesWithMetadata(dropTableQuery); + assertFalse(result.isCreateTableQuery()); + assertEquals(1, result.getFullyQualifiedTableNames().size()); + assertFullyQualifiedTableName( + "myS3", "default", "alb_logs", result.getFullyQualifiedTableNames().get(0)); + + // Test DESCRIBE TABLE queries + String describeTableQuery = "DESCRIBE TABLE myS3.default.alb_logs"; + result = 
SQLQueryUtils.extractFullyQualifiedTableNamesWithMetadata(describeTableQuery); + assertFalse(result.isCreateTableQuery()); + assertEquals(1, result.getFullyQualifiedTableNames().size()); + assertFullyQualifiedTableName( + "myS3", "default", "alb_logs", result.getFullyQualifiedTableNames().get(0)); + + // Test JOIN queries + String joinQuery = + "SELECT * FROM myS3.default.alb_logs JOIN myS3.default.http_logs ON alb_logs.id =" + + " http_logs.id"; + result = SQLQueryUtils.extractFullyQualifiedTableNamesWithMetadata(joinQuery); + assertFalse(result.isCreateTableQuery()); + assertEquals(2, result.getFullyQualifiedTableNames().size()); + assertFullyQualifiedTableName( + "myS3", "default", "alb_logs", result.getFullyQualifiedTableNames().get(0)); + assertFullyQualifiedTableName( + "myS3", "default", "http_logs", result.getFullyQualifiedTableNames().get(1)); + } + @Getter protected static class IndexQuery { private String query;