117 changes: 63 additions & 54 deletions spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
@@ -411,61 +411,70 @@ case class CometExecRule(session: SparkSession)
* BroadcastQueryStageExec.
*/
private def convertSubqueryBroadcasts(plan: SparkPlan): SparkPlan = {
// CometIcebergNativeScanExec.runtimeFilters is a top-level constructor field visible to
// productIterator, so transformExpressionsUp rewrites it directly. The wrapped @transient
// originalPlan still holds the pre-rewrite runtimeFilters; we don't sync it here because
// CometIcebergNativeScanExec.serializedPartitionData rebuilds originalPlan from the
// top-level runtimeFilters at serialization time (single source of truth).
plan.transformExpressionsUp { case inSub: InSubqueryExec =>
inSub.plan match {
case sub: SubqueryBroadcastExec =>
sub.child match {
case b: BroadcastExchangeExec =>
// The BroadcastExchangeExec child is CometNativeColumnarToRowExec wrapping
// a Comet plan. Strip the row transition to get the columnar Comet plan.
val cometChild = b.child match {
case c2r: CometNativeColumnarToRowExec => c2r.child
case other => other
}
if (cometChild.isInstanceOf[CometNativeExec]) {
logInfo(
"Converting SubqueryBroadcastExec to " +
"CometSubqueryBroadcastExec for DPP exchange reuse")
val cometBroadcast = CometBroadcastExchangeExec(b, b.output, b.mode, cometChild)
val cometSub = CometSubqueryBroadcastExec(
sub.name,
getSubqueryBroadcastExecIndices(sub),
sub.buildKeys,
cometBroadcast)
inSub.withNewPlan(cometSub)
} else {
inSub
}
case _ => inSub
}
case sab: SubqueryAdaptiveBroadcastExec if isSpark35Plus =>
// Wrap SABs to prevent Spark's PlanAdaptiveDynamicPruningFilters from
// converting them to Literal.TrueLiteral. Spark's rule pattern-matches for
// BroadcastHashJoinExec, which Comet replaced with CometBroadcastHashJoinExec.
// Without wrapping, DPP is disabled for both Comet native scans and non-Comet
// scans (e.g., V2 BatchScan). CometPlanAdaptiveDynamicPruningFilters
// (queryStageOptimizerRule, 3.5+) unwraps and converts them later.
//
// On Spark 3.4, injectQueryStageOptimizerRule is unavailable. The isSpark35Plus
// guard leaves SABs unwrapped; CometSpark34AqeDppFallbackRule then tags the
// matching BHJ's build broadcast so Spark's rule can match it natively.
assert(
sab.buildKeys.nonEmpty,
s"SubqueryAdaptiveBroadcastExec '${sab.name}' has empty buildKeys")
logInfo(
s"Wrapping SubqueryAdaptiveBroadcastExec '${sab.name}' in " +
"CometSubqueryAdaptiveBroadcastExec to preserve AQE DPP")
val indices = getSubqueryBroadcastIndices(sab)
val wrapped = CometSubqueryAdaptiveBroadcastExec(
sab.name,
indices,
sab.onlyInBroadcast,
sab.buildPlan,
sab.buildKeys,
sab.child)
inSub.withNewPlan(wrapped)
case _ => inSub
}
rewriteInSubqueryPlan(inSub)
}
}

private def rewriteInSubqueryPlan(inSub: InSubqueryExec): Expression = {
inSub.plan match {
case sub: SubqueryBroadcastExec =>
sub.child match {
case b: BroadcastExchangeExec =>
// The BroadcastExchangeExec child is CometNativeColumnarToRowExec wrapping
// a Comet plan. Strip the row transition to get the columnar Comet plan.
val cometChild = b.child match {
case c2r: CometNativeColumnarToRowExec => c2r.child
case other => other
}
if (cometChild.isInstanceOf[CometNativeExec]) {
logInfo(
"Converting SubqueryBroadcastExec to " +
"CometSubqueryBroadcastExec for DPP exchange reuse")
val cometBroadcast = CometBroadcastExchangeExec(b, b.output, b.mode, cometChild)
val cometSub = CometSubqueryBroadcastExec(
sub.name,
getSubqueryBroadcastExecIndices(sub),
sub.buildKeys,
cometBroadcast)
inSub.withNewPlan(cometSub)
} else {
inSub
}
case _ => inSub
}
case sab: SubqueryAdaptiveBroadcastExec if isSpark35Plus =>
// Wrap SABs to prevent Spark's PlanAdaptiveDynamicPruningFilters from
// converting them to Literal.TrueLiteral. Spark's rule pattern-matches for
// BroadcastHashJoinExec, which Comet replaced with CometBroadcastHashJoinExec.
// Without wrapping, DPP is disabled for both Comet native scans and non-Comet
// scans (e.g., V2 BatchScan). CometPlanAdaptiveDynamicPruningFilters
// (queryStageOptimizerRule, 3.5+) unwraps and converts them later.
//
// On Spark 3.4, injectQueryStageOptimizerRule is unavailable. The isSpark35Plus
// guard leaves SABs unwrapped; CometSpark34AqeDppFallbackRule then tags the
// matching BHJ's build broadcast so Spark's rule can match it natively.
assert(
sab.buildKeys.nonEmpty,
s"SubqueryAdaptiveBroadcastExec '${sab.name}' has empty buildKeys")
logInfo(
s"Wrapping SubqueryAdaptiveBroadcastExec '${sab.name}' in " +
"CometSubqueryAdaptiveBroadcastExec to preserve AQE DPP")
val indices = getSubqueryBroadcastIndices(sab)
val wrapped = CometSubqueryAdaptiveBroadcastExec(
sab.name,
indices,
sab.onlyInBroadcast,
sab.buildPlan,
sab.buildKeys,
sab.child)
inSub.withNewPlan(wrapped)
case _ => inSub
}
}
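
The comments above lean on one idea: hiding SubqueryAdaptiveBroadcastExec behind a Comet-specific wrapper type so Spark's PlanAdaptiveDynamicPruningFilters no longer matches it (and therefore cannot degrade the filter to Literal.TrueLiteral), while a later Comet-aware rule unwraps and converts it. The following is a minimal, self-contained sketch of that pattern using toy types only; none of these names are Comet or Spark classes.

object WrapUnwrapSketch {
  sealed trait Filter
  // Stand-ins for SubqueryAdaptiveBroadcastExec, CometSubqueryAdaptiveBroadcastExec,
  // and Literal.TrueLiteral respectively.
  case class AdaptiveSubquery(name: String) extends Filter
  case class CometWrapped(inner: AdaptiveSubquery) extends Filter
  case object AlwaysTrue extends Filter

  // Spark-style rule: recognizes only the bare node and, failing to find the join
  // shape it expects, replaces it with a trivial always-true filter (DPP lost).
  def sparkStyleRule(f: Filter): Filter = f match {
    case AdaptiveSubquery(_) => AlwaysTrue
    case other               => other // the wrapper passes through untouched
  }

  // Comet-aware rule (analogous to CometPlanAdaptiveDynamicPruningFilters):
  // unwraps the node and performs the real conversion instead of dropping it.
  def cometAwareRule(f: Filter): Filter = f match {
    case CometWrapped(inner) => AdaptiveSubquery(inner.name)
    case other               => other
  }

  def main(args: Array[String]): Unit = {
    val wrapped = CometWrapped(AdaptiveSubquery("dpp#1"))
    // The Spark-style rule leaves the wrapper alone, so the Comet-aware rule can
    // still recover the subquery afterwards instead of losing it to AlwaysTrue.
    println(cometAwareRule(sparkStyleRule(wrapped))) // AdaptiveSubquery(dpp#1)
  }
}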

@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, BindReferences, Dynamic
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometNativeScanExec, CometSubqueryAdaptiveBroadcastExec, CometSubqueryBroadcastExec}
import org.apache.spark.sql.comet._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
@@ -34,7 +34,7 @@ import org.apache.comet.shims.{ShimPrepareExecutedPlan, ShimSubqueryBroadcast}

/**
* Converts CometSubqueryAdaptiveBroadcastExec (wrapped AQE DPP) to CometSubqueryBroadcastExec
* inside CometNativeScanExec's partitionFilters.
* inside CometNativeScanExec's partitionFilters and CometIcebergNativeScanExec's runtimeFilters.
*
* CometExecRule wraps SubqueryAdaptiveBroadcastExec in CometSubqueryAdaptiveBroadcastExec during
* queryStagePreparationRules to prevent Spark's PlanAdaptiveDynamicPruningFilters from replacing
Expand All @@ -49,6 +49,11 @@ import org.apache.comet.shims.{ShimPrepareExecutedPlan, ShimSubqueryBroadcast}
* CometScanExec.partitionFilters are separate InSubqueryExec instances. Both must be converted
* because CometScanExec.dynamicallySelectedPartitions evaluates its own partitionFilters.
*
* For CometIcebergNativeScanExec, runtimeFilters is a top-level constructor field and the single
* source of truth. The Iceberg case rewrites only that field; serializedPartitionData rebuilds
* originalPlan from the top-level runtimeFilters at serialization time, so the wrapper's
* expression tree and the inner BatchScanExec's runtime filters stay aligned without syncing
* originalPlan.runtimeFilters here.
*
* @see
* PlanAdaptiveDynamicPruningFilters (Spark's equivalent for BroadcastHashJoinExec)
* @see
@@ -74,12 +79,32 @@ case object CometPlanAdaptiveDynamicPruningFilters
case nativeScan: CometNativeScanExec if nativeScan.partitionFilters.exists(hasCometSAB) =>
logDebug("Converting AQE DPP for CometNativeScanExec")
convertNativeScanDPP(nativeScan, plan)
case p: SparkPlan if !p.isInstanceOf[CometNativeScanExec] && hasWrappedSAB(p) =>
case icebergScan: CometIcebergNativeScanExec
if icebergScan.runtimeFilters.exists(hasCometSAB) =>
logDebug("Converting AQE DPP for CometIcebergNativeScanExec")
convertIcebergScanDPP(icebergScan, plan)
case p: SparkPlan
if !p.isInstanceOf[CometNativeScanExec]
&& !p.isInstanceOf[CometIcebergNativeScanExec]
&& hasWrappedSAB(p) =>
logDebug(s"Converting AQE DPP for non-Comet node: ${p.nodeName}")
convertNonCometNodeDPP(p, plan)
}
}
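// Hypothetical post-condition, not part of the PR: once this rule has processed a query
// stage, no InSubqueryExec in the stage should still point at a wrapped SAB. Using the
// rule's own hasCometSAB helper (defined further down) against the stage plan being
// transformed, a debug assertion could look like:
//   assert(plan.collect { case p if p.expressions.exists(hasCometSAB) => p }.isEmpty,
//     "CometSubqueryAdaptiveBroadcastExec survived CometPlanAdaptiveDynamicPruningFilters")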

private def convertIcebergScanDPP(
icebergScan: CometIcebergNativeScanExec,
stagePlan: SparkPlan): CometIcebergNativeScanExec = {
val newFilters = icebergScan.runtimeFilters.map(f => convertFilter(f, stagePlan))
if (newFilters == icebergScan.runtimeFilters) return icebergScan
// Top-level runtimeFilters is the single source of truth.
// CometIcebergNativeScanExec.serializedPartitionData rebuilds originalPlan from the top-level
// field at serialization time, so we don't need to sync originalPlan.runtimeFilters here.
val newScan = icebergScan.copy(runtimeFilters = newFilters)
icebergScan.logicalLink.foreach(newScan.setLogicalLink)
newScan
}
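// Hedged sketch, not the actual Comet source: the "single source of truth" comment above assumes
// that serializedPartitionData re-derives the wrapped plan from the top-level field before
// serializing, along these lines (the copy call and helper name are hypothetical):
//   lazy val serializedPartitionData = {
//     val refreshed = originalPlan.copy(runtimeFilters = runtimeFilters) // top-level field wins
//     serializePartitions(refreshed)
//   }
// If that rebuild did not happen, the @transient originalPlan would still carry the pre-rewrite
// InSubqueryExec instances after this rule runs.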

private def convertNativeScanDPP(
nativeScan: CometNativeScanExec,
stagePlan: SparkPlan): CometNativeScanExec = {
@@ -156,6 +181,7 @@ case object CometPlanAdaptiveDynamicPruningFilters
case _ => None
}
}

inSub.plan match {
// ReusedSubqueryExec extends BaseSubqueryExec, so unwrap it before dispatching
// to `BaseSubqueryExec`. The order is load-bearing: if the general case runs
@@ -179,7 +205,7 @@ case object CometPlanAdaptiveDynamicPruningFilters
* (correct results, scans all partitions).
*
* 3. No reusable broadcast + onlyInBroadcast=false: Aggregate SubqueryExec on the build side
* (DPP via separate execution, matching Spark's PlanAdaptiveDynamicPruningFilters lines 68-79).
* (DPP via separate execution, matching Spark's PlanAdaptiveDynamicPruningFilters case 3).
*/
private def convertSAB(
inSub: InSubqueryExec,
@@ -215,10 +241,10 @@ case object CometPlanAdaptiveDynamicPruningFilters
val canReuse = conf.exchangeReuseEnabled && matchingJoin.isDefined

if (canReuse) {
// Case 1: broadcast reuse. Matches Spark's PlanAdaptiveDynamicPruningFilters
// lines 44-64: construct a NEW exchange wrapping adaptivePlan.executedPlan,
// then wrap in a new ASPE. AQE's stageCache ensures the broadcast runs once
// via ReusedExchangeExec (same canonical form as the join's exchange).
// Case 1: broadcast reuse. Mirrors Spark's PlanAdaptiveDynamicPruningFilters case 1:
// construct a fresh exchange wrapping the build subtree, then wrap in a new ASPE.
// AQE's stageCache ensures the broadcast runs once via ReusedExchangeExec (same
// canonical form as the join's exchange).
val (broadcastChild, isComet) = matchingJoin.get
val buildSidePlan = adaptivePlan.executedPlan
logDebug(
@@ -227,7 +253,7 @@ case object CometPlanAdaptiveDynamicPruningFilters
s"${broadcastChild.getClass.getSimpleName}")

// Construct the exchange from buildSidePlan (not from the existing exchange),
// matching Spark's PlanAdaptiveDynamicPruningFilters lines 44-48. The existing
// mirroring Spark's PlanAdaptiveDynamicPruningFilters case 1. The existing
// exchange may belong to a different plan context (e.g., the main query) with
// different attribute IDs than the current SAB's build side (e.g., a scalar
// subquery). Using the existing exchange's output/mode would cause schema
@@ -274,7 +300,7 @@ case object CometPlanAdaptiveDynamicPruningFilters
// Case 3: no reusable broadcast, but the optimizer says DPP is worthwhile
// even without broadcast reuse. Create an aggregate SubqueryExec on the build
// side to get distinct partition key values for pruning.
// Matches Spark's PlanAdaptiveDynamicPruningFilters lines 68-79.
// Matches Spark's PlanAdaptiveDynamicPruningFilters case 3.
val aliases =
sab.indices.map(idx => Alias(sab.buildKeys(idx), sab.buildKeys(idx).toString)())
val aggregate = Aggregate(aliases, aliases, sab.buildPlan)
@@ -425,6 +451,7 @@ case object CometPlanAdaptiveDynamicPruningFilters
* CometNativeScanExec.partitionFilters has CometSubqueryAdaptiveBroadcastExec (wrapped by
* CometExecRule). The inner CometScanExec.partitionFilters may have the original
* SubqueryAdaptiveBroadcastExec (unwrapped, because CometScanExec is @transient).
*/
private def hasCometSAB(e: Expression): Boolean =
@@ -985,8 +985,22 @@ object CometIcebergNativeScan extends CometOperatorSerde[CometBatchScanExec] wit

perPartitionBuilders += partitionBuilder.build()
}
case _ =>
throw new IllegalStateException("Expected DataSourceRDD from BatchScanExec")
case other if other.getClass.getName == "org.apache.spark.rdd.ParallelCollectionRDD" =>
// Spark's BatchScanExec.inputRDD returns sparkContext.parallelize(empty, 1) when
// DPP filtering removes all input partitions. That ParallelCollectionRDD is the only
// non-DataSourceRDD shape its inputRDD produces, so reaching this branch means "DPP
// pruned everything"; emit no per-partition data and let native execution return empty.
// Re-querying scan.toBatch.planInputPartitions() to verify is unreliable because
// Iceberg's Scan state after filter() doesn't always reflect post-DPP partitions on
// a re-call (V2 scan state is one-shot for the materialized inputRDD). Matched by class
// name because ParallelCollectionRDD is private[spark].
logDebug(
"BatchScanExec.inputRDD is ParallelCollectionRDD (DPP pruned all partitions); " +
"skipping per-partition serialization")
case other =>
throw new IllegalStateException(
"Expected DataSourceRDD or ParallelCollectionRDD from BatchScanExec, " +
s"got ${other.getClass.getName}")
}

// Log deduplication summary