@@ -31,6 +31,7 @@ import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapabil
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.execution.datasources.{FileFormat, PartitionedFile}
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.sources.{And, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

@@ -196,16 +197,32 @@ class ServerSidePlannedTable(

/**
* ScanBuilder that uses ServerSidePlanningClient to plan the scan.
* Implements SupportsPushDownFilters to enable WHERE clause pushdown to the server.
*/
class ServerSidePlannedScanBuilder(
spark: SparkSession,
databaseName: String,
tableName: String,
tableSchema: StructType,
planningClient: ServerSidePlanningClient) extends ScanBuilder {
planningClient: ServerSidePlanningClient)
extends ScanBuilder with SupportsPushDownFilters {

// Filters that have been pushed down and will be sent to the server
private var _pushedFilters: Array[Filter] = Array.empty

override def pushFilters(filters: Array[Filter]): Array[Filter] = {
// Store the filters so they can be sent to the catalog, but return them all as residuals.
// Since we don't know what the catalog can handle yet, we conservatively claim to handle none.
// Even if the catalog applies some filters, Spark will redundantly re-apply them.
_pushedFilters = filters
filters // Return all as residuals
}

override def pushedFilters(): Array[Filter] = _pushedFilters

override def build(): Scan = {
new ServerSidePlannedScan(spark, databaseName, tableName, tableSchema, planningClient)
new ServerSidePlannedScan(
spark, databaseName, tableName, tableSchema, planningClient, _pushedFilters)
}
}
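
For context, here is a minimal sketch of how Spark's DataSource V2 pushdown exercises this builder. The object name PushdownHandshakeSketch, the demoPushdown helper, and the example query/filters are illustrative assumptions, not part of this change; the table names come from the test suite below.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.sources.{EqualTo, Filter, IsNotNull}
import org.apache.spark.sql.types.StructType

object PushdownHandshakeSketch {
  // Drives the SupportsPushDownFilters handshake by hand, roughly the way Spark's V2
  // pushdown rule does for a query like ... WHERE id = 2.
  def demoPushdown(
      spark: SparkSession,
      schema: StructType,
      client: ServerSidePlanningClient): Unit = {
    val builder =
      new ServerSidePlannedScanBuilder(spark, "test_db", "shared_test", schema, client)
    // Spark passes the translatable WHERE predicates (often with an IsNotNull guard added).
    val residuals = builder.pushFilters(Array[Filter](IsNotNull("id"), EqualTo("id", 2)))
    // This builder keeps every filter as a residual, so Spark re-applies them after the scan,
    // while pushedFilters() still reports what will be forwarded to the planning client.
    assert(residuals.length == 2 && builder.pushedFilters().length == 2)
    builder.build() // the pushed filters travel into ServerSidePlannedScan and on to planScan()
  }
}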

@@ -217,16 +234,30 @@
databaseName: String,
tableName: String,
tableSchema: StructType,
planningClient: ServerSidePlanningClient) extends Scan with Batch {
planningClient: ServerSidePlanningClient,
pushedFilters: Array[Filter]) extends Scan with Batch {

override def readSchema(): StructType = tableSchema

override def toBatch: Batch = this

// Call the server-side planning API once and store the result
private val scanPlan = planningClient.planScan(databaseName, tableName)
// Convert pushed filters to a single Spark Filter for the API call.
// If no filters, pass None. If filters exist, combine them into a single filter.
private val combinedFilter: Option[Filter] = {
if (pushedFilters.isEmpty) {
None
} else if (pushedFilters.length == 1) {
Some(pushedFilters.head)
} else {
// Combine multiple filters with And
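// e.g. Array(GreaterThan("id", 1), LessThan("value", 30))
//   reduces to And(GreaterThan("id", 1), LessThan("value", 30))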
Some(pushedFilters.reduce((left, right) => And(left, right)))
}
}

override def planInputPartitions(): Array[InputPartition] = {
// Call the server-side planning API to get the scan plan
val scanPlan = planningClient.planScan(databaseName, tableName, combinedFilter)

// Convert each file to an InputPartition
scanPlan.files.map { file =>
ServerSidePlannedFileInputPartition(file.filePath, file.fileSizeInBytes, file.fileFormat)
@@ -17,6 +17,7 @@
package org.apache.spark.sql.delta.serverSidePlanning

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.sources.Filter

/**
* Simple data class representing a file to scan.
@@ -46,9 +47,13 @@ trait ServerSidePlanningClient {
*
* @param databaseName The database or schema name
* @param table The table name
* @param filter Optional filter expression to push down to the server (Spark Filter format)
* @return ScanPlan containing files to read
*/
def planScan(databaseName: String, table: String): ScanPlan
def planScan(
databaseName: String,
table: String,
filter: Option[Filter] = None): ScanPlan
}
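
For illustration only (nothing below is in this PR), an implementation of this trait might translate the pushed Filter into a textual predicate before calling its planning endpoint. FilterSqlSketch and toPredicateSql are hypothetical names, and real code would need to quote string literals and handle many more filter types.

import org.apache.spark.sql.sources.{And, EqualTo, Filter, GreaterThan, LessThan}

object FilterSqlSketch {
  // Best-effort conversion of a pushed Filter into a SQL-like predicate string.
  // Returns None for anything it cannot translate, in which case the caller would
  // simply request an unfiltered scan plan and rely on Spark's residual filtering.
  def toPredicateSql(filter: Filter): Option[String] = filter match {
    case EqualTo(attr, value) => Some(s"$attr = $value")
    case GreaterThan(attr, value) => Some(s"$attr > $value")
    case LessThan(attr, value) => Some(s"$attr < $value")
    case And(left, right) =>
      for {
        l <- toPredicateSql(left)
        r <- toPredicateSql(right)
      } yield s"($l) AND ($r)"
    case _ => None
  }
}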

/**
@@ -19,6 +19,7 @@ package org.apache.spark.sql.delta.serverSidePlanning
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
import org.apache.spark.sql.sources.{And, EqualTo, Filter, GreaterThan, LessThan}

/**
* Tests for server-side planning with a mock client.
@@ -50,8 +51,30 @@ class ServerSidePlannedTableSuite extends QueryTest with DeltaSQLCommandTest {
* This prevents test pollution from leaked configuration.
*/
private def withServerSidePlanningEnabled(f: => Unit): Unit = {
withServerSidePlanningFactory(new TestServerSidePlanningClientFactory())(f)
}

/**
* Helper method to run tests with pushdown capturing enabled.
* TestServerSidePlanningClient captures the filter passed to planScan().
*/
private def withPushdownCapturingEnabled(f: => Unit): Unit = {
withServerSidePlanningFactory(new TestServerSidePlanningClientFactory()) {
try {
f
} finally {
TestServerSidePlanningClient.clearCaptured()
}
}
}

/**
* Common helper for setting up server-side planning with a specific factory.
*/
private def withServerSidePlanningFactory(factory: ServerSidePlanningClientFactory)
(f: => Unit): Unit = {
val originalConfig = spark.conf.getOption(DeltaSQLConf.ENABLE_SERVER_SIDE_PLANNING.key)
ServerSidePlanningClientFactory.setFactory(new TestServerSidePlanningClientFactory())
ServerSidePlanningClientFactory.setFactory(factory)
spark.conf.set(DeltaSQLConf.ENABLE_SERVER_SIDE_PLANNING.key, "true")
try {
f
@@ -66,6 +89,15 @@ class ServerSidePlannedTableSuite extends QueryTest with DeltaSQLCommandTest {
}
}

/**
* Extract all leaf filters from a filter tree.
* Spark may wrap filters with And and IsNotNull checks, so this flattens the tree.
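* For example, And(And(IsNotNull("id"), GreaterThan("id", 1)), LessThan("value", 30)) flattens
* to Seq(IsNotNull("id"), GreaterThan("id", 1), LessThan("value", 30)).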
*/
private def collectLeafFilters(filter: Filter): Seq[Filter] = filter match {
case And(left, right) => collectLeafFilters(left) ++ collectLeafFilters(right)
case other => Seq(other)
}

test("full query through DeltaCatalog with server-side planning") {
// This test verifies server-side planning works end-to-end by checking:
// (1) DeltaCatalog returns ServerSidePlannedTable (not normal table)
@@ -177,7 +209,7 @@ class ServerSidePlannedTableSuite extends QueryTest with DeltaSQLCommandTest {
}
}

test("fromTable returns metadata with empty defaults for non-UC catalogs") {
test("ServerSidePlanningMetadata.fromTable returns empty defaults for non-UC catalogs") {
import org.apache.spark.sql.connector.catalog.Identifier

// Create a simple identifier for testing
@@ -197,4 +229,59 @@ class ServerSidePlannedTableSuite extends QueryTest with DeltaSQLCommandTest {
assert(metadata.authToken.isEmpty)
assert(metadata.tableProperties.isEmpty)
}

test("simple EqualTo filter pushed to planning client") {
withPushdownCapturingEnabled {
sql("SELECT id, name, value FROM test_db.shared_test WHERE id = 2").collect()

val capturedFilter = TestServerSidePlanningClient.getCapturedFilter
assert(capturedFilter.isDefined, "Filter should be pushed down")

// Extract leaf filters and find the EqualTo filter
val leafFilters = collectLeafFilters(capturedFilter.get)
val eqFilter = leafFilters.collectFirst {
case eq: EqualTo if eq.attribute == "id" => eq
}
assert(eqFilter.isDefined, "Expected EqualTo filter on 'id'")
assert(eqFilter.get.value == 2, s"Expected EqualTo value 2, got ${eqFilter.get.value}")
}
}

test("compound And filter pushed to planning client") {
withPushdownCapturingEnabled {
sql("SELECT id, name, value FROM test_db.shared_test WHERE id > 1 AND value < 30").collect()

val capturedFilter = TestServerSidePlanningClient.getCapturedFilter
assert(capturedFilter.isDefined, "Filter should be pushed down")

val filter = capturedFilter.get
assert(filter.isInstanceOf[And], s"Expected And filter, got ${filter.getClass.getSimpleName}")

// Extract all leaf filters from the And tree (Spark may add IsNotNull checks)
val leafFilters = collectLeafFilters(filter)

// Verify GreaterThan(id, 1) is present
val gtFilter = leafFilters.collectFirst {
case gt: GreaterThan if gt.attribute == "id" => gt
}
assert(gtFilter.isDefined, "Expected GreaterThan filter on 'id'")
assert(gtFilter.get.value == 1, s"Expected GreaterThan value 1, got ${gtFilter.get.value}")

// Verify LessThan(value, 30) is present
val ltFilter = leafFilters.collectFirst {
case lt: LessThan if lt.attribute == "value" => lt
}
assert(ltFilter.isDefined, "Expected LessThan filter on 'value'")
assert(ltFilter.get.value == 30, s"Expected LessThan value 30, got ${ltFilter.get.value}")
}
}

test("no filter pushed when no WHERE clause") {
withPushdownCapturingEnabled {
sql("SELECT id, name, value FROM test_db.shared_test").collect()

val capturedFilter = TestServerSidePlanningClient.getCapturedFilter
assert(capturedFilter.isEmpty, "No filter should be pushed when there's no WHERE clause")
}
}
}
@@ -20,15 +20,24 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.delta.sources.DeltaSQLConf
import org.apache.spark.sql.functions.input_file_name
import org.apache.spark.sql.sources.Filter

/**
* Implementation of ServerSidePlanningClient that uses Spark SQL with input_file_name()
* to discover the list of files in a table. This allows end-to-end testing without
* a real server that can do server-side planning.
*
* Also captures the filter passed to planScan() for test verification via the companion object.
*/
class TestServerSidePlanningClient(spark: SparkSession) extends ServerSidePlanningClient {

override def planScan(databaseName: String, table: String): ScanPlan = {
override def planScan(
databaseName: String,
table: String,
filter: Option[Filter] = None): ScanPlan = {
// Capture filter for test verification
TestServerSidePlanningClient.capturedFilter = filter

val fullTableName = s"$databaseName.$table"

// Temporarily disable server-side planning to avoid infinite recursion
@@ -79,6 +88,17 @@ class TestServerSidePlanningClient(spark: SparkSession) extends ServerSidePlanni
private def getFileFormat(path: Path): String = "parquet"
}

/**
* Companion object for TestServerSidePlanningClient.
* Stores the captured filter pushdown for test verification.
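* Tests should reset this state between runs via clearCaptured() (withPushdownCapturingEnabled
* does so in a finally block) to avoid leaking captured filters across test cases.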
*/
object TestServerSidePlanningClient {
private var capturedFilter: Option[Filter] = None

def getCapturedFilter: Option[Filter] = capturedFilter
def clearCaptured(): Unit = { capturedFilter = None }
}

/**
* Factory for creating TestServerSidePlanningClient instances.
*/