diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTable.scala b/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTable.scala
index ea0dfe2e489..cfa6a3c87e6 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTable.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTable.scala
@@ -198,6 +198,7 @@ class ServerSidePlannedTable(
 /**
  * ScanBuilder that uses ServerSidePlanningClient to plan the scan.
  * Implements SupportsPushDownFilters to enable WHERE clause pushdown to the server.
+ * Implements SupportsPushDownRequiredColumns to enable column pruning pushdown to the server.
  */
 class ServerSidePlannedScanBuilder(
     spark: SparkSession,
@@ -205,11 +206,14 @@ class ServerSidePlannedScanBuilder(
     tableName: String,
     tableSchema: StructType,
     planningClient: ServerSidePlanningClient)
-  extends ScanBuilder with SupportsPushDownFilters {
+  extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns {

   // Filters that have been pushed down and will be sent to the server
   private var _pushedFilters: Array[Filter] = Array.empty

+  // Required schema (columns) that have been pushed down
+  private var _requiredSchema: StructType = tableSchema
+
   override def pushFilters(filters: Array[Filter]): Array[Filter] = {
     // Store filters to send to catalog, but return all as residuals.
     // Since we don't know what the catalog can handle yet, we conservatively claim we handle
@@ -220,9 +224,13 @@ class ServerSidePlannedScanBuilder(

   override def pushedFilters(): Array[Filter] = _pushedFilters

+  override def pruneColumns(requiredSchema: StructType): Unit = {
+    _requiredSchema = requiredSchema
+  }
+
   override def build(): Scan = {
     new ServerSidePlannedScan(
-      spark, databaseName, tableName, tableSchema, planningClient, _pushedFilters)
+      spark, databaseName, tableName, tableSchema, planningClient, _pushedFilters, _requiredSchema)
   }
 }

@@ -235,7 +243,8 @@ class ServerSidePlannedScan(
     tableName: String,
     tableSchema: StructType,
     planningClient: ServerSidePlanningClient,
-    pushedFilters: Array[Filter]) extends Scan with Batch {
+    pushedFilters: Array[Filter],
+    requiredSchema: StructType) extends Scan with Batch {

   override def readSchema(): StructType = tableSchema

@@ -254,9 +263,18 @@ class ServerSidePlannedScan(
     }
   }

+  // Only pass projection if columns are actually pruned (not SELECT *)
+  private val projection: Option[StructType] = {
+    if (requiredSchema.fieldNames.toSet == tableSchema.fieldNames.toSet) {
+      None
+    } else {
+      Some(requiredSchema)
+    }
+  }
+
   override def planInputPartitions(): Array[InputPartition] = {
     // Call the server-side planning API to get the scan plan
-    val scanPlan = planningClient.planScan(databaseName, tableName, combinedFilter)
+    val scanPlan = planningClient.planScan(databaseName, tableName, combinedFilter, projection)

     // Convert each file to an InputPartition
     scanPlan.files.map { file =>
diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlanningClient.scala b/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlanningClient.scala
index cf97393c5d6..69a9caf1e2b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlanningClient.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlanningClient.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.delta.serverSidePlanning
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
 
 /**
  * Simple data class representing a file to scan.
@@ -37,9 +38,10 @@ case class ScanPlan(
 )
 
 /**
- * Interface for planning table scans via server-side planning (e.g., Iceberg REST catalog).
- * This interface is intentionally simple and has no dependencies
- * on Iceberg libraries, allowing it to live in delta-spark module.
+ * Interface for planning table scans via server-side planning.
+ * This interface uses Spark's standard `org.apache.spark.sql.sources.Filter` as the universal
+ * representation for filter pushdown. This keeps the interface catalog-agnostic while allowing
+ * each server-side planning catalog implementation to convert filters to its own native format.
  */
 trait ServerSidePlanningClient {
   /**
@@ -48,12 +50,14 @@ trait ServerSidePlanningClient {
    * @param databaseName The database or schema name
    * @param table The table name
    * @param filter Optional filter expression to push down to server (Spark Filter format)
+   * @param projection Optional projection (required columns) to push down to server
    * @return ScanPlan containing files to read
    */
   def planScan(
       databaseName: String,
       table: String,
-      filter: Option[Filter] = None): ScanPlan
+      filter: Option[Filter] = None,
+      projection: Option[StructType] = None): ScanPlan
 }
 
 /**
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTableSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTableSuite.scala
index 7a693a237e7..02d68cec64e 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTableSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/ServerSidePlannedTableSuite.scala
@@ -284,4 +284,51 @@ class ServerSidePlannedTableSuite extends QueryTest with DeltaSQLCommandTest {
       assert(capturedFilter.isEmpty, "No filter should be pushed when there's no WHERE clause")
     }
   }
+
+  test("projection pushed when selecting specific columns") {
+    withPushdownCapturingEnabled {
+      sql("SELECT id, name FROM test_db.shared_test").collect()
+
+      val capturedProjection = TestServerSidePlanningClient.getCapturedProjection
+      assert(capturedProjection.isDefined, "Projection should be pushed down")
+      assert(capturedProjection.get.fieldNames.toSet == Set("id", "name"),
+        s"Expected {id, name}, got {${capturedProjection.get.fieldNames.mkString(", ")}}")
+    }
+  }
+
+  test("no projection pushed when selecting all columns") {
+    withPushdownCapturingEnabled {
+      sql("SELECT * FROM test_db.shared_test").collect()
+
+      val capturedProjection = TestServerSidePlanningClient.getCapturedProjection
+      assert(capturedProjection.isEmpty,
+        "No projection should be pushed when selecting all columns")
+    }
+  }
+
+  test("projection and filter pushed together") {
+    withPushdownCapturingEnabled {
+      sql("SELECT id FROM test_db.shared_test WHERE value > 10").collect()
+
+      // Verify projection was pushed with exactly the expected columns:
+      // Spark needs 'id' for the SELECT list and 'value' for the WHERE clause
+      val capturedProjection = TestServerSidePlanningClient.getCapturedProjection
+      assert(capturedProjection.isDefined, "Projection should be pushed down")
+      val projectedFields = capturedProjection.get.fieldNames.toSet
+      assert(projectedFields == Set("id", "value"),
+        s"Expected projection with exactly {id, value}, got {${projectedFields.mkString(", ")}}")
+
+      // Verify filter was also pushed
+      val capturedFilter = TestServerSidePlanningClient.getCapturedFilter
+      assert(capturedFilter.isDefined, "Filter should be pushed down")
+
+      // Verify GreaterThan(value, 10) is in the filter
+      val leafFilters = collectLeafFilters(capturedFilter.get)
+      val gtFilter = leafFilters.collectFirst {
+        case gt: GreaterThan if gt.attribute == "value" => gt
+      }
+      assert(gtFilter.isDefined, "Expected GreaterThan filter on 'value'")
+      assert(gtFilter.get.value == 10, s"Expected GreaterThan value 10, got ${gtFilter.get.value}")
+    }
+  }
 }
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/TestServerSidePlanningClient.scala b/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/TestServerSidePlanningClient.scala
index d18c4b42a45..fa35c6963c9 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/TestServerSidePlanningClient.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/serverSidePlanning/TestServerSidePlanningClient.scala
@@ -21,6 +21,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.functions.input_file_name
 import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType
 
 /**
  * Implementation of ServerSidePlanningClient that uses Spark SQL with input_file_name()
@@ -34,9 +35,11 @@ class TestServerSidePlanningClient(spark: SparkSession) extends ServerSidePlanni
   override def planScan(
       databaseName: String,
       table: String,
-      filter: Option[Filter] = None): ScanPlan = {
-    // Capture filter for test verification
+      filter: Option[Filter] = None,
+      projection: Option[StructType] = None): ScanPlan = {
+    // Capture filter and projection for test verification
     TestServerSidePlanningClient.capturedFilter = filter
+    TestServerSidePlanningClient.capturedProjection = projection
 
     val fullTableName = s"$databaseName.$table"
 
@@ -94,9 +97,14 @@ class TestServerSidePlanningClient(spark: SparkSession) extends ServerSidePlanni
  */
 object TestServerSidePlanningClient {
   private var capturedFilter: Option[Filter] = None
+  private var capturedProjection: Option[StructType] = None
 
   def getCapturedFilter: Option[Filter] = capturedFilter
-  def clearCaptured(): Unit = { capturedFilter = None }
+  def getCapturedProjection: Option[StructType] = capturedProjection
+  def clearCaptured(): Unit = {
+    capturedFilter = None
+    capturedProjection = None
+  }
 }
 
 /**
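
For reference, a catalog-backed ServerSidePlanningClient now receives the pushed projection alongside the filter. The following is a minimal sketch, not part of the patch: the class name ExampleRestPlanningClient, its placement in the serverSidePlanning package (only so ServerSidePlanningClient and ScanPlan resolve), and the idea of forwarding a plain column-name list to a planning endpoint are all assumptions for illustration.

package org.apache.spark.sql.delta.serverSidePlanning

import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

// Hypothetical catalog-backed client; name and endpoint handling are illustrative only.
class ExampleRestPlanningClient extends ServerSidePlanningClient {

  override def planScan(
      databaseName: String,
      table: String,
      filter: Option[Filter] = None,
      projection: Option[StructType] = None): ScanPlan = {
    // None means "all columns" (the SELECT * case), mirroring ServerSidePlannedScan above.
    // A concrete implementation would typically send just the field names to its
    // planning service rather than the full StructType.
    val selectedColumns: Option[Seq[String]] = projection.map(_.fieldNames.toSeq)

    // Request construction and response parsing are catalog-specific and omitted here.
    throw new UnsupportedOperationException(
      s"illustrative only: would plan $databaseName.$table with columns $selectedColumns")
  }
}

As in TestServerSidePlanningClient above, the override restates the default arguments so callers of the concrete class can still omit filter and projection.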