diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTable.scala b/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTable.scala index bfd3c6516c2..ea0dfe2e489 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTable.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTable.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapabil import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.execution.datasources.{FileFormat, PartitionedFile} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.sources.{And, Filter} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -196,16 +197,32 @@ class ServerSidePlannedTable( /** * ScanBuilder that uses ServerSidePlanningClient to plan the scan. + * Implements SupportsPushDownFilters to enable WHERE clause pushdown to the server. */ class ServerSidePlannedScanBuilder( spark: SparkSession, databaseName: String, tableName: String, tableSchema: StructType, - planningClient: ServerSidePlanningClient) extends ScanBuilder { + planningClient: ServerSidePlanningClient) + extends ScanBuilder with SupportsPushDownFilters { + + // Filters that have been pushed down and will be sent to the server + private var _pushedFilters: Array[Filter] = Array.empty + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + // Store filters to send to catalog, but return all as residuals. + // Since we don't know what the catalog can handle yet, we conservatively claim we handle + // none. Even if the catalog applies some filters, Spark will redundantly re-apply them. + _pushedFilters = filters + filters // Return all as residuals + } + + override def pushedFilters(): Array[Filter] = _pushedFilters override def build(): Scan = { - new ServerSidePlannedScan(spark, databaseName, tableName, tableSchema, planningClient) + new ServerSidePlannedScan( + spark, databaseName, tableName, tableSchema, planningClient, _pushedFilters) } } @@ -217,16 +234,30 @@ class ServerSidePlannedScan( databaseName: String, tableName: String, tableSchema: StructType, - planningClient: ServerSidePlanningClient) extends Scan with Batch { + planningClient: ServerSidePlanningClient, + pushedFilters: Array[Filter]) extends Scan with Batch { override def readSchema(): StructType = tableSchema override def toBatch: Batch = this - // Call the server-side planning API once and store the result - private val scanPlan = planningClient.planScan(databaseName, tableName) + // Convert pushed filters to a single Spark Filter for the API call. + // If no filters, pass None. If filters exist, combine them into a single filter. 
+ private val combinedFilter: Option[Filter] = { + if (pushedFilters.isEmpty) { + None + } else if (pushedFilters.length == 1) { + Some(pushedFilters.head) + } else { + // Combine multiple filters with And + Some(pushedFilters.reduce((left, right) => And(left, right))) + } + } override def planInputPartitions(): Array[InputPartition] = { + // Call the server-side planning API to get the scan plan + val scanPlan = planningClient.planScan(databaseName, tableName, combinedFilter) + // Convert each file to an InputPartition scanPlan.files.map { file => ServerSidePlannedFileInputPartition(file.filePath, file.fileSizeInBytes, file.fileFormat) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlanningClient.scala b/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlanningClient.scala index dc3fc1579c7..cf97393c5d6 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlanningClient.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlanningClient.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.delta.serverSidePlanning import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.sources.Filter /** * Simple data class representing a file to scan. @@ -46,9 +47,13 @@ trait ServerSidePlanningClient { * * @param databaseName The database or schema name * @param table The table name + * @param filter Optional filter expression to push down to server (Spark Filter format) * @return ScanPlan containing files to read */ - def planScan(databaseName: String, table: String): ScanPlan + def planScan( + databaseName: String, + table: String, + filter: Option[Filter] = None): ScanPlan } /** diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTableSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTableSuite.scala index fe8fb428cf7..7a693a237e7 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTableSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTableSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.delta.serverSidePlanning import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.sql.delta.test.DeltaSQLCommandTest +import org.apache.spark.sql.sources.{And, EqualTo, Filter, GreaterThan, LessThan} /** * Tests for server-side planning with a mock client. @@ -50,8 +51,30 @@ class ServerSidePlannedTableSuite extends QueryTest with DeltaSQLCommandTest { * This prevents test pollution from leaked configuration. */ private def withServerSidePlanningEnabled(f: => Unit): Unit = { + withServerSidePlanningFactory(new TestServerSidePlanningClientFactory())(f) + } + + /** + * Helper method to run tests with pushdown capturing enabled. + * TestServerSidePlanningClient captures pushdowns (filter, projection) passed to planScan(). + */ + private def withPushdownCapturingEnabled(f: => Unit): Unit = { + withServerSidePlanningFactory(new TestServerSidePlanningClientFactory()) { + try { + f + } finally { + TestServerSidePlanningClient.clearCaptured() + } + } + } + + /** + * Common helper for setting up server-side planning with a specific factory. 
+ */ + private def withServerSidePlanningFactory(factory: ServerSidePlanningClientFactory) + (f: => Unit): Unit = { val originalConfig = spark.conf.getOption(DeltaSQLConf.ENABLE_SERVER_SIDE_PLANNING.key) - ServerSidePlanningClientFactory.setFactory(new TestServerSidePlanningClientFactory()) + ServerSidePlanningClientFactory.setFactory(factory) spark.conf.set(DeltaSQLConf.ENABLE_SERVER_SIDE_PLANNING.key, "true") try { f @@ -66,6 +89,15 @@ class ServerSidePlannedTableSuite extends QueryTest with DeltaSQLCommandTest { } } + /** + * Extract all leaf filters from a filter tree. + * Spark may wrap filters with And and IsNotNull checks, so this flattens the tree. + */ + private def collectLeafFilters(filter: Filter): Seq[Filter] = filter match { + case And(left, right) => collectLeafFilters(left) ++ collectLeafFilters(right) + case other => Seq(other) + } + test("full query through DeltaCatalog with server-side planning") { // This test verifies server-side planning works end-to-end by checking: // (1) DeltaCatalog returns ServerSidePlannedTable (not normal table) @@ -177,7 +209,7 @@ class ServerSidePlannedTableSuite extends QueryTest with DeltaSQLCommandTest { } } - test("fromTable returns metadata with empty defaults for non-UC catalogs") { + test("ServerSidePlanningMetadata.fromTable returns empty defaults for non-UC catalogs") { import org.apache.spark.sql.connector.catalog.Identifier // Create a simple identifier for testing @@ -197,4 +229,59 @@ class ServerSidePlannedTableSuite extends QueryTest with DeltaSQLCommandTest { assert(metadata.authToken.isEmpty) assert(metadata.tableProperties.isEmpty) } + + test("simple EqualTo filter pushed to planning client") { + withPushdownCapturingEnabled { + sql("SELECT id, name, value FROM test_db.shared_test WHERE id = 2").collect() + + val capturedFilter = TestServerSidePlanningClient.getCapturedFilter + assert(capturedFilter.isDefined, "Filter should be pushed down") + + // Extract leaf filters and find the EqualTo filter + val leafFilters = collectLeafFilters(capturedFilter.get) + val eqFilter = leafFilters.collectFirst { + case eq: EqualTo if eq.attribute == "id" => eq + } + assert(eqFilter.isDefined, "Expected EqualTo filter on 'id'") + assert(eqFilter.get.value == 2, s"Expected EqualTo value 2, got ${eqFilter.get.value}") + } + } + + test("compound And filter pushed to planning client") { + withPushdownCapturingEnabled { + sql("SELECT id, name, value FROM test_db.shared_test WHERE id > 1 AND value < 30").collect() + + val capturedFilter = TestServerSidePlanningClient.getCapturedFilter + assert(capturedFilter.isDefined, "Filter should be pushed down") + + val filter = capturedFilter.get + assert(filter.isInstanceOf[And], s"Expected And filter, got ${filter.getClass.getSimpleName}") + + // Extract all leaf filters from the And tree (Spark may add IsNotNull checks) + val leafFilters = collectLeafFilters(filter) + + // Verify GreaterThan(id, 1) is present + val gtFilter = leafFilters.collectFirst { + case gt: GreaterThan if gt.attribute == "id" => gt + } + assert(gtFilter.isDefined, "Expected GreaterThan filter on 'id'") + assert(gtFilter.get.value == 1, s"Expected GreaterThan value 1, got ${gtFilter.get.value}") + + // Verify LessThan(value, 30) is present + val ltFilter = leafFilters.collectFirst { + case lt: LessThan if lt.attribute == "value" => lt + } + assert(ltFilter.isDefined, "Expected LessThan filter on 'value'") + assert(ltFilter.get.value == 30, s"Expected LessThan value 30, got ${ltFilter.get.value}") + } + } + + test("no filter pushed 
when no WHERE clause") { + withPushdownCapturingEnabled { + sql("SELECT id, name, value FROM test_db.shared_test").collect() + + val capturedFilter = TestServerSidePlanningClient.getCapturedFilter + assert(capturedFilter.isEmpty, "No filter should be pushed when there's no WHERE clause") + } + } } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/TestServerSidePlanningClient.scala b/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/TestServerSidePlanningClient.scala index bf861ac7132..d18c4b42a45 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/TestServerSidePlanningClient.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/TestServerSidePlanningClient.scala @@ -20,15 +20,24 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import org.apache.spark.sql.delta.sources.DeltaSQLConf import org.apache.spark.sql.functions.input_file_name +import org.apache.spark.sql.sources.Filter /** * Implementation of ServerSidePlanningClient that uses Spark SQL with input_file_name() * to discover the list of files in a table. This allows end-to-end testing without * a real server that can do server-side planning. + * + * Also captures filter/projection parameters for test verification via companion object. */ class TestServerSidePlanningClient(spark: SparkSession) extends ServerSidePlanningClient { - override def planScan(databaseName: String, table: String): ScanPlan = { + override def planScan( + databaseName: String, + table: String, + filter: Option[Filter] = None): ScanPlan = { + // Capture filter for test verification + TestServerSidePlanningClient.capturedFilter = filter + val fullTableName = s"$databaseName.$table" // Temporarily disable server-side planning to avoid infinite recursion @@ -79,6 +88,17 @@ class TestServerSidePlanningClient(spark: SparkSession) extends ServerSidePlanni private def getFileFormat(path: Path): String = "parquet" } +/** + * Companion object for TestServerSidePlanningClient. + * Stores captured pushdown parameters (filter, projection) for test verification. + */ +object TestServerSidePlanningClient { + private var capturedFilter: Option[Filter] = None + + def getCapturedFilter: Option[Filter] = capturedFilter + def clearCaptured(): Unit = { capturedFilter = None } +} + /** * Factory for creating TestServerSidePlanningClient instances. */
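
Note, for context on the filter shape (illustrative only, not part of this change): for the compound query exercised in the tests, WHERE id > 1 AND value < 30, Spark typically pushes the split conjuncts plus IsNotNull guards, and ServerSidePlannedScan left-folds them with And. The combined filter handed to planScan therefore looks roughly like the value below; the exact ordering and IsNotNull wrapping depend on Spark's optimizer, which is why the tests flatten the tree with collectLeafFilters instead of asserting an exact structure.

import org.apache.spark.sql.sources.{And, GreaterThan, IsNotNull, LessThan}

// Approximate shape only; ordering and IsNotNull guards vary with the plan.
val approxCombinedFilter =
  And(
    And(
      And(IsNotNull("id"), IsNotNull("value")),
      GreaterThan("id", 1)),
    LessThan("value", 30))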
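
A concrete ServerSidePlanningClient will eventually need to send the pushed-down Filter over the wire. Below is a minimal sketch of how an implementation of the new planScan(databaseName, table, filter) signature might translate the combined filter into a SQL-like predicate string; FilterTranslation and toPredicateString are hypothetical names, not part of this change. Dropping an untranslatable filter is safe here because pushFilters returns every filter as a residual, so Spark re-applies all predicates to the rows it reads back.

import org.apache.spark.sql.sources.{And, EqualTo, Filter, GreaterThan, IsNotNull, LessThan}

// Hypothetical helper: convert a pushed-down Filter tree into a predicate string.
// Returning None for unsupported filters only widens the scan; correctness is
// preserved because Spark re-evaluates all residual filters after the read.
object FilterTranslation {
  def toPredicateString(filter: Filter): Option[String] = filter match {
    case EqualTo(attr, value)     => Some(s"$attr = ${literal(value)}")
    case GreaterThan(attr, value) => Some(s"$attr > ${literal(value)}")
    case LessThan(attr, value)    => Some(s"$attr < ${literal(value)}")
    case IsNotNull(attr)          => Some(s"$attr IS NOT NULL")
    case And(left, right) =>
      (toPredicateString(left), toPredicateString(right)) match {
        case (Some(l), Some(r)) => Some(s"($l) AND ($r)")
        case (Some(l), None)    => Some(l) // drop the side we cannot express
        case (None, Some(r))    => Some(r)
        case (None, None)       => None
      }
    case _ => None
  }

  private def literal(value: Any): String = value match {
    case s: String => s"'${s.replace("'", "''")}'"
    case other     => other.toString
  }
}

An implementation could then build its request with something like filter.flatMap(FilterTranslation.toPredicateString) and fall back to an unfiltered plan when the result is None.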