diff --git a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala
index ecc0823d60..b41c1dc386 100644
--- a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.QueryStageExec
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 
 import org.apache.comet.CometConf
 
@@ -112,6 +113,7 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa
   private def hasCometNativeChild(op: SparkPlan): Boolean = {
     op match {
       case c: QueryStageExec => hasCometNativeChild(c.plan)
+      case c: ReusedExchangeExec => hasCometNativeChild(c.child)
       case _ => op.exists(_.isInstanceOf[CometPlan])
     }
   }
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
index 0391a1c3b3..95e03ca69c 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometColumnarToRowExec.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.comet.util.{Utils => CometUtils}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{CodegenSupport, ColumnarToRowTransition, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.BroadcastQueryStageExec
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, WritableColumnVector}
 import org.apache.spark.sql.types._
@@ -172,6 +173,7 @@ case class CometColumnarToRowExec(child: SparkPlan)
     op match {
       case b: CometBroadcastExchangeExec => Some(b)
       case b: BroadcastQueryStageExec => findCometBroadcastExchange(b.plan)
+      case b: ReusedExchangeExec => findCometBroadcastExchange(b.child)
      case _ => op.children.collectFirst(Function.unlift(findCometBroadcastExchange))
     }
   }