@@ -198,18 +198,22 @@ class ServerSidePlannedTable(
/**
* ScanBuilder that uses ServerSidePlanningClient to plan the scan.
* Implements SupportsPushDownFilters to enable WHERE clause pushdown to the server.
+ * Implements SupportsPushDownRequiredColumns to enable column pruning pushdown to the server.
*/
class ServerSidePlannedScanBuilder(
spark: SparkSession,
databaseName: String,
tableName: String,
tableSchema: StructType,
planningClient: ServerSidePlanningClient)
-extends ScanBuilder with SupportsPushDownFilters {
+extends ScanBuilder with SupportsPushDownFilters with SupportsPushDownRequiredColumns {

// Filters that have been pushed down and will be sent to the server
private var _pushedFilters: Array[Filter] = Array.empty

+// Required schema (columns) that have been pushed down
+private var _requiredSchema: StructType = tableSchema
+
override def pushFilters(filters: Array[Filter]): Array[Filter] = {
// Store filters to send to catalog, but return all as residuals.
// Since we don't know what the catalog can handle yet, we conservatively claim we handle
@@ -220,9 +224,13 @@ class ServerSidePlannedScanBuilder(

override def pushedFilters(): Array[Filter] = _pushedFilters

+override def pruneColumns(requiredSchema: StructType): Unit = {
+  _requiredSchema = requiredSchema
+}
+
override def build(): Scan = {
new ServerSidePlannedScan(
-spark, databaseName, tableName, tableSchema, planningClient, _pushedFilters)
+spark, databaseName, tableName, tableSchema, planningClient, _pushedFilters, _requiredSchema)
}
}
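
For context, Spark's DSv2 planner drives this builder in a fixed order: pushFilters during predicate pushdown, pruneColumns during column pruning, then build(). Below is a minimal sketch of that sequence, assuming a SparkSession spark, a full table schema fullSchema, and a planning client client are already in scope; the literal filter and column names are illustrative only.

import org.apache.spark.sql.sources.{Filter, GreaterThan}
import org.apache.spark.sql.types.{LongType, StructField, StructType}

val builder = new ServerSidePlannedScanBuilder(
  spark, "test_db", "shared_test", fullSchema, client)

// The builder records the filters but returns them all as residuals,
// so Spark still re-applies them after the scan.
val residuals: Array[Filter] = builder.pushFilters(Array[Filter](GreaterThan("value", 10)))

// Spark then narrows the schema to the columns the query actually touches.
builder.pruneColumns(StructType(Seq(
  StructField("id", LongType),
  StructField("value", LongType))))

val scan = builder.build()  // carries both the pushed filters and the required schema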

@@ -235,7 +243,8 @@ class ServerSidePlannedScan(
tableName: String,
tableSchema: StructType,
planningClient: ServerSidePlanningClient,
-pushedFilters: Array[Filter]) extends Scan with Batch {
+pushedFilters: Array[Filter],
+requiredSchema: StructType) extends Scan with Batch {

override def readSchema(): StructType = tableSchema

@@ -254,9 +263,18 @@
}
}

+// Only pass projection if columns are actually pruned (not SELECT *)
+private val projection: Option[StructType] = {
+  if (requiredSchema.fieldNames.toSet == tableSchema.fieldNames.toSet) {
+    None
+  } else {
+    Some(requiredSchema)
+  }
+}

override def planInputPartitions(): Array[InputPartition] = {
// Call the server-side planning API to get the scan plan
-val scanPlan = planningClient.planScan(databaseName, tableName, combinedFilter)
+val scanPlan = planningClient.planScan(databaseName, tableName, combinedFilter, projection)

// Convert each file to an InputPartition
scanPlan.files.map { file =>
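
The projection guard above means a full-width required schema is reported as None, so the server only ever sees a projection when columns were actually pruned. A small standalone illustration of the same comparison, assuming a hypothetical two-column table schema:

import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

val tableSchema = StructType(Seq(
  StructField("id", LongType),
  StructField("name", StringType)))

// Mirrors the projection val above: None when nothing was pruned.
def toProjection(required: StructType): Option[StructType] =
  if (required.fieldNames.toSet == tableSchema.fieldNames.toSet) None
  else Some(required)

toProjection(tableSchema)                                   // None (SELECT *)
toProjection(StructType(Seq(StructField("id", LongType))))  // Some(id only)
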
@@ -18,6 +18,7 @@ package org.apache.spark.sql.delta.serverSidePlanning

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType

/**
* Simple data class representing a file to scan.
@@ -37,9 +38,10 @@ case class ScanPlan(
)

/**
- * Interface for planning table scans via server-side planning (e.g., Iceberg REST catalog).
- * This interface is intentionally simple and has no dependencies
- * on Iceberg libraries, allowing it to live in delta-spark module.
+ * Interface for planning table scans via server-side planning.
+ * This interface uses Spark's standard `org.apache.spark.sql.sources.Filter` as the universal
+ * representation for filter pushdown. This keeps the interface catalog-agnostic while allowing
+ * each server-side planning catalog implementation to convert filters to their own native format.
*/
trait ServerSidePlanningClient {
/**
@@ -48,12 +50,14 @@ trait ServerSidePlanningClient {
* @param databaseName The database or schema name
* @param table The table name
* @param filter Optional filter expression to push down to server (Spark Filter format)
+ * @param projection Optional projection (required columns) to push down to server
* @return ScanPlan containing files to read
*/
def planScan(
databaseName: String,
table: String,
-filter: Option[Filter] = None): ScanPlan
+filter: Option[Filter] = None,
+projection: Option[StructType] = None): ScanPlan
}

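To make the contract concrete, here is a hedged sketch of one possible implementation. PlanningTransport, RestPlanningClient, and the string predicate format are illustrative inventions for this sketch, not part of this PR; only ServerSidePlanningClient and ScanPlan come from the code above.

import org.apache.spark.sql.sources.{EqualTo, Filter, GreaterThan, LessThan}
import org.apache.spark.sql.types.StructType

// Hypothetical transport layer; stands in for whatever actually talks to the server.
trait PlanningTransport {
  def planTableScan(
      db: String,
      table: String,
      predicate: Option[String],
      columns: Option[Seq[String]]): ScanPlan
}

class RestPlanningClient(transport: PlanningTransport) extends ServerSidePlanningClient {
  override def planScan(
      databaseName: String,
      table: String,
      filter: Option[Filter] = None,
      projection: Option[StructType] = None): ScanPlan = {
    // Convert the Spark Filter tree and StructType into this server's native forms.
    transport.planTableScan(
      databaseName,
      table,
      filter.map(render),
      projection.map(_.fieldNames.toSeq))
  }

  // Deliberately partial: a real implementation would cover the full Filter algebra.
  private def render(f: Filter): String = f match {
    case EqualTo(a, v)     => s"$a = $v"
    case GreaterThan(a, v) => s"$a > $v"
    case LessThan(a, v)    => s"$a < $v"
    case other => throw new UnsupportedOperationException(s"Filter not supported: $other")
  }
}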
@@ -284,4 +284,51 @@ class ServerSidePlannedTableSuite extends QueryTest with DeltaSQLCommandTest {
assert(capturedFilter.isEmpty, "No filter should be pushed when there's no WHERE clause")
}
}

test("projection pushed when selecting specific columns") {
withPushdownCapturingEnabled {
sql("SELECT id, name FROM test_db.shared_test").collect()

val capturedProjection = TestServerSidePlanningClient.getCapturedProjection
assert(capturedProjection.isDefined, "Projection should be pushed down")
assert(capturedProjection.get.fieldNames.toSet == Set("id", "name"),
s"Expected {id, name}, got {${capturedProjection.get.fieldNames.mkString(", ")}}")
}
}

test("no projection pushed when selecting all columns") {
withPushdownCapturingEnabled {
sql("SELECT * FROM test_db.shared_test").collect()

val capturedProjection = TestServerSidePlanningClient.getCapturedProjection
assert(capturedProjection.isEmpty,
"No projection should be pushed when selecting all columns")
}
}

test("projection and filter pushed together") {
withPushdownCapturingEnabled {
sql("SELECT id FROM test_db.shared_test WHERE value > 10").collect()

// Verify projection was pushed with exactly the expected columns
// Spark needs 'id' for SELECT and 'value' for WHERE clause
val capturedProjection = TestServerSidePlanningClient.getCapturedProjection
assert(capturedProjection.isDefined, "Projection should be pushed down")
val projectedFields = capturedProjection.get.fieldNames.toSet
assert(projectedFields == Set("id", "value"),
s"Expected projection with exactly {id, value}, got {${projectedFields.mkString(", ")}}")

// Verify filter was also pushed
val capturedFilter = TestServerSidePlanningClient.getCapturedFilter
assert(capturedFilter.isDefined, "Filter should be pushed down")

// Verify GreaterThan(value, 10) is in the filter
val leafFilters = collectLeafFilters(capturedFilter.get)
val gtFilter = leafFilters.collectFirst {
case gt: GreaterThan if gt.attribute == "value" => gt
}
assert(gtFilter.isDefined, "Expected GreaterThan filter on 'value'")
assert(gtFilter.get.value == 10, s"Expected GreaterThan value 10, got ${gtFilter.get.value}")
}
}
}
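
The third test above calls collectLeafFilters, whose definition sits outside this diff. A plausible sketch of such a suite helper, assuming it simply flattens the boolean combinators of Spark's Filter tree into its leaf predicates:

import org.apache.spark.sql.sources.{And, Filter, Not, Or}

private def collectLeafFilters(filter: Filter): Seq[Filter] = filter match {
  case And(left, right) => collectLeafFilters(left) ++ collectLeafFilters(right)
  case Or(left, right)  => collectLeafFilters(left) ++ collectLeafFilters(right)
  case Not(child)       => collectLeafFilters(child)
  case leaf             => Seq(leaf)  // e.g. GreaterThan, EqualTo, IsNotNull
}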
@@ -21,6 +21,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.functions.input_file_name
import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.types.StructType

/**
* Implementation of ServerSidePlanningClient that uses Spark SQL with input_file_name()
@@ -34,9 +35,11 @@ class TestServerSidePlanningClient(spark: SparkSession) extends ServerSidePlanningClient {
override def planScan(
databaseName: String,
table: String,
-filter: Option[Filter] = None): ScanPlan = {
-// Capture filter for test verification
+filter: Option[Filter] = None,
+projection: Option[StructType] = None): ScanPlan = {
+// Capture filter and projection for test verification
TestServerSidePlanningClient.capturedFilter = filter
+TestServerSidePlanningClient.capturedProjection = projection

val fullTableName = s"$databaseName.$table"

@@ -94,9 +97,14 @@
*/
object TestServerSidePlanningClient {
private var capturedFilter: Option[Filter] = None
+private var capturedProjection: Option[StructType] = None

def getCapturedFilter: Option[Filter] = capturedFilter
-def clearCaptured(): Unit = { capturedFilter = None }
+def getCapturedProjection: Option[StructType] = capturedProjection
+def clearCaptured(): Unit = {
+  capturedFilter = None
+  capturedProjection = None
+}
}

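Finally, a minimal usage sketch of the capture API, calling the test client directly rather than through SQL. The spark session and the test_db.shared_test fixture are assumed from the suite's setup:

import org.apache.spark.sql.sources.GreaterThan
import org.apache.spark.sql.types.{LongType, StructField, StructType}

val client = new TestServerSidePlanningClient(spark)
TestServerSidePlanningClient.clearCaptured()

client.planScan(
  "test_db", "shared_test",
  filter = Some(GreaterThan("value", 10)),
  projection = Some(StructType(Seq(StructField("id", LongType)))))

assert(TestServerSidePlanningClient.getCapturedFilter.contains(GreaterThan("value", 10)))
assert(TestServerSidePlanningClient.getCapturedProjection
  .exists(_.fieldNames.sameElements(Array("id"))))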