From 6180d977ed4d53f92e3a02b504018246e33c89ae Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 19 Jan 2026 12:04:33 -0700 Subject: [PATCH 1/2] feat: add experimental cost-based optimizer (CBO) for Comet vs Spark execution This PR introduces an experimental lightweight cost-based optimizer that estimates whether a Comet query plan will be faster than a Spark plan, falling back to Spark when Comet execution is estimated to be slower. **Key Features:** - Heuristic-based cost model with configurable weights for different operator types (scan, filter, project, aggregate, join, sort) - Configurable speedup factors for each Comet operator type - Transition penalty for columnar<->row conversions - Cardinality estimation using Spark's logical plan statistics - CBO analysis included in EXPLAIN output when enabled **Configuration:** - `spark.comet.cbo.enabled` (default: false) - Enable/disable CBO - `spark.comet.cbo.speedupThreshold` (default: 1.0) - Minimum speedup required - `spark.comet.cbo.explain.enabled` (default: false) - Log CBO decisions **Important Notes:** - This is an EXPERIMENTAL feature, disabled by default - CBO only affects operator conversion (filter, project, aggregate, etc.), not scan conversion which is handled by CometScanRule - Default parameters are initial estimates and should be tuned with benchmarks Co-Authored-By: Claude Opus 4.5 --- .../scala/org/apache/comet/CometConf.scala | 118 +++++++ .../apache/comet/ExtendedExplainInfo.scala | 11 +- .../comet/rules/CometCostEstimator.scala | 289 ++++++++++++++++++ .../apache/comet/rules/CometExecRule.scala | 31 +- .../apache/comet/rules/CometCBOSuite.scala | 251 +++++++++++++++ 5 files changed, 696 insertions(+), 4 deletions(-) create mode 100644 spark/src/main/scala/org/apache/comet/rules/CometCostEstimator.scala create mode 100644 spark/src/test/scala/org/apache/comet/rules/CometCBOSuite.scala diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 89dbb6468d..cccebef5f3 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -65,6 +65,7 @@ object CometConf extends ShimCometConf { private val CATEGORY_SHUFFLE = "shuffle" private val CATEGORY_TUNING = "tuning" private val CATEGORY_TESTING = "testing" + private val CATEGORY_CBO = "cbo" def register(conf: ConfigEntry[_]): Unit = { assert(conf.category.nonEmpty, s"${conf.key} does not have a category defined") @@ -772,6 +773,123 @@ object CometConf extends ShimCometConf { .booleanConf .createWithEnvVarOrDefault("ENABLE_COMET_STRICT_TESTING", false) + // CBO Configuration Options + + val COMET_CBO_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.cbo.enabled") + .category(CATEGORY_CBO) + .doc( + "Enable cost-based optimizer to decide Comet vs Spark execution. " + + "When enabled, Comet estimates whether native execution will be faster " + + "and falls back to Spark if not. 
Note: This only affects operator conversion " + + "(filter, project, aggregate, etc.), not scan conversion which is handled separately.") + .booleanConf + .createWithDefault(false) + + val COMET_CBO_EXPLAIN_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.cbo.explain.enabled") + .category(CATEGORY_CBO) + .doc("Log CBO decision details for debugging.") + .booleanConf + .createWithDefault(false) + + val COMET_CBO_SPEEDUP_THRESHOLD: ConfigEntry[Double] = conf("spark.comet.cbo.speedupThreshold") + .category(CATEGORY_CBO) + .doc("Minimum estimated speedup ratio required to use Comet. " + + "Values less than 1.0 allow Comet even when estimated slightly slower.") + .doubleConf + .checkValue(_ > 0, "Threshold must be positive") + .createWithDefault(1.0) + + val COMET_CBO_DEFAULT_ROW_COUNT: ConfigEntry[Long] = conf("spark.comet.cbo.defaultRowCount") + .category(CATEGORY_CBO) + .internal() + .doc("Default row count estimate when statistics unavailable.") + .longConf + .createWithDefault(1000000L) + + val COMET_CBO_TRANSITION_COST: ConfigEntry[Double] = conf("spark.comet.cbo.cost.transition") + .category(CATEGORY_CBO) + .internal() + .doc("Cost penalty per row for columnar<->row transitions.") + .doubleConf + .createWithDefault(0.001) + + val COMET_CBO_SCAN_WEIGHT: ConfigEntry[Double] = conf("spark.comet.cbo.weight.scan") + .category(CATEGORY_CBO) + .internal() + .doc("Weight for scan operators in cost calculation.") + .doubleConf + .createWithDefault(1.0) + + val COMET_CBO_FILTER_WEIGHT: ConfigEntry[Double] = conf("spark.comet.cbo.weight.filter") + .category(CATEGORY_CBO) + .internal() + .doc("Weight for filter operators in cost calculation.") + .doubleConf + .createWithDefault(0.1) + + val COMET_CBO_PROJECT_WEIGHT: ConfigEntry[Double] = conf("spark.comet.cbo.weight.project") + .category(CATEGORY_CBO) + .internal() + .doc("Weight for project operators in cost calculation.") + .doubleConf + .createWithDefault(0.1) + + val COMET_CBO_AGGREGATE_WEIGHT: ConfigEntry[Double] = conf("spark.comet.cbo.weight.aggregate") + .category(CATEGORY_CBO) + .internal() + .doc("Weight for aggregate operators in cost calculation.") + .doubleConf + .createWithDefault(2.0) + + val COMET_CBO_JOIN_WEIGHT: ConfigEntry[Double] = conf("spark.comet.cbo.weight.join") + .category(CATEGORY_CBO) + .internal() + .doc("Weight for join operators in cost calculation.") + .doubleConf + .createWithDefault(5.0) + + val COMET_CBO_SORT_WEIGHT: ConfigEntry[Double] = conf("spark.comet.cbo.weight.sort") + .category(CATEGORY_CBO) + .internal() + .doc("Weight for sort operators in cost calculation.") + .doubleConf + .createWithDefault(1.5) + + val COMET_CBO_SCAN_SPEEDUP: ConfigEntry[Double] = conf("spark.comet.cbo.speedup.scan") + .category(CATEGORY_CBO) + .internal() + .doc("Expected speedup factor for Comet scan operators vs Spark.") + .doubleConf + .createWithDefault(2.0) + + val COMET_CBO_FILTER_SPEEDUP: ConfigEntry[Double] = conf("spark.comet.cbo.speedup.filter") + .category(CATEGORY_CBO) + .internal() + .doc("Expected speedup factor for Comet filter operators vs Spark.") + .doubleConf + .createWithDefault(3.0) + + val COMET_CBO_AGGREGATE_SPEEDUP: ConfigEntry[Double] = conf("spark.comet.cbo.speedup.aggregate") + .category(CATEGORY_CBO) + .internal() + .doc("Expected speedup factor for Comet aggregate operators vs Spark.") + .doubleConf + .createWithDefault(2.5) + + val COMET_CBO_JOIN_SPEEDUP: ConfigEntry[Double] = conf("spark.comet.cbo.speedup.join") + .category(CATEGORY_CBO) + .internal() + .doc("Expected speedup factor for Comet join 
operators vs Spark.") + .doubleConf + .createWithDefault(2.0) + + val COMET_CBO_SORT_SPEEDUP: ConfigEntry[Double] = conf("spark.comet.cbo.speedup.sort") + .category(CATEGORY_CBO) + .internal() + .doc("Expected speedup factor for Comet sort operators vs Spark.") + .doubleConf + .createWithDefault(2.0) + /** Create a config to enable a specific operator */ private def createExecEnabledConfig( exec: String, diff --git a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala index 56ae64ed68..37d5e1fbd5 100644 --- a/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala +++ b/spark/src/main/scala/org/apache/comet/ExtendedExplainInfo.scala @@ -29,13 +29,14 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffl import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.comet.CometExplainInfo.getActualPlan +import org.apache.comet.rules.CometCBOInfo class ExtendedExplainInfo extends ExtendedExplainGenerator { override def title: String = "Comet" def generateExtendedInfo(plan: SparkPlan): String = { - CometConf.COMET_EXTENDED_EXPLAIN_FORMAT.get() match { + val baseInfo = CometConf.COMET_EXTENDED_EXPLAIN_FORMAT.get() match { case CometConf.COMET_EXTENDED_EXPLAIN_FORMAT_VERBOSE => // Generates the extended info in a verbose manner, printing each node along with the // extended information in a tree display. @@ -47,6 +48,14 @@ class ExtendedExplainInfo extends ExtendedExplainGenerator { // Generates the extended info as a list of fallback reasons getFallbackReasons(plan).mkString("\n").trim } + + // Add CBO info if available + val cboInfo = getActualPlan(plan) + .getTagValue(CometCBOInfo.TAG) + .map(analysis => s"\n${analysis.toExplainString}") + .getOrElse("") + + baseInfo + cboInfo } def getFallbackReasons(plan: SparkPlan): Seq[String] = { diff --git a/spark/src/main/scala/org/apache/comet/rules/CometCostEstimator.scala b/spark/src/main/scala/org/apache/comet/rules/CometCostEstimator.scala new file mode 100644 index 0000000000..0c9beb1145 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/rules/CometCostEstimator.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.rules + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.comet._ +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive._ +import org.apache.spark.sql.execution.aggregate._ +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.internal.SQLConf + +import org.apache.comet.CometConf + +/** + * Cost analysis result containing all metrics for the CBO decision. + */ +case class CostAnalysis( + cometOperatorCount: Int, + sparkOperatorCount: Int, + transitionCount: Int, + estimatedRowCount: Option[Long], + estimatedSizeBytes: Option[Long], + sparkCost: Double, + cometCost: Double, + estimatedSpeedup: Double, + shouldUseComet: Boolean) { + + def toExplainString: String = { + f"""CBO Analysis: + | Decision: ${if (shouldUseComet) "Use Comet" else "Fall back to Spark"} + | Estimated Speedup: $estimatedSpeedup%.2fx + | Comet Operators: $cometOperatorCount + | Spark Operators: $sparkOperatorCount + | Transitions: $transitionCount + | Estimated Rows: ${estimatedRowCount.map(_.toString).getOrElse("unknown")} + | Spark Cost: $sparkCost%.2f + | Comet Cost: $cometCost%.2f""".stripMargin + } +} + +/** + * Statistics collected from plan traversal. + */ +case class PlanStatistics( + cometOps: Int = 0, + sparkOps: Int = 0, + transitions: Int = 0, + cometScans: Int = 0, + cometFilters: Int = 0, + cometProjects: Int = 0, + cometAggregates: Int = 0, + cometJoins: Int = 0, + cometSorts: Int = 0, + sparkScans: Int = 0, + sparkFilters: Int = 0, + sparkProjects: Int = 0, + sparkAggregates: Int = 0, + sparkJoins: Int = 0, + sparkSorts: Int = 0) + +/** + * Tag for attaching CBO info to plan nodes for EXPLAIN output. + */ +object CometCBOInfo { + val TAG: TreeNodeTag[CostAnalysis] = new TreeNodeTag[CostAnalysis]("CometCBOInfo") +} + +/** + * Cost estimator for comparing Comet vs Spark execution plans. + * + * The estimator uses a heuristic-based cost model with configurable weights for different + * operator types and transition penalties. It estimates whether running a query with Comet will + * be faster than running it with Spark. + */ +object CometCostEstimator extends Logging { + + /** + * Analyze a Comet plan and determine if it should be used over Spark. 
+ */ + def analyze(cometPlan: SparkPlan, conf: SQLConf): CostAnalysis = { + val stats = collectStats(cometPlan) + val rowCount = extractRowCount(cometPlan) + val sizeBytes = extractSizeBytes(cometPlan) + + val sparkCost = calculateSparkCost(stats, rowCount, conf) + val cometCost = calculateCometCost(stats, rowCount, conf) + + val speedup = if (cometCost > 0) sparkCost / cometCost else Double.MaxValue + val threshold = CometConf.COMET_CBO_SPEEDUP_THRESHOLD.get(conf) + + CostAnalysis( + cometOperatorCount = stats.cometOps, + sparkOperatorCount = stats.sparkOps, + transitionCount = stats.transitions, + estimatedRowCount = rowCount, + estimatedSizeBytes = sizeBytes, + sparkCost = sparkCost, + cometCost = cometCost, + estimatedSpeedup = speedup, + shouldUseComet = speedup >= threshold) + } + + private def collectStats(plan: SparkPlan): PlanStatistics = { + var stats = PlanStatistics() + + plan.foreach { + // Transitions - these are expensive + case _: CometColumnarToRowExec | _: CometSparkToColumnarExec | _: ColumnarToRowExec | + _: RowToColumnarExec => + stats = stats.copy(transitions = stats.transitions + 1) + + // Comet scans + case _: CometScanExec | _: CometBatchScanExec | _: CometNativeScanExec | + _: CometIcebergNativeScanExec | _: CometLocalTableScanExec => + stats = stats.copy(cometOps = stats.cometOps + 1, cometScans = stats.cometScans + 1) + + // Comet filters + case _: CometFilterExec => + stats = stats.copy(cometOps = stats.cometOps + 1, cometFilters = stats.cometFilters + 1) + + // Comet projects + case _: CometProjectExec => + stats = stats.copy(cometOps = stats.cometOps + 1, cometProjects = stats.cometProjects + 1) + + // Comet aggregates + case _: CometHashAggregateExec => + stats = + stats.copy(cometOps = stats.cometOps + 1, cometAggregates = stats.cometAggregates + 1) + + // Comet joins + case _: CometBroadcastHashJoinExec | _: CometHashJoinExec | _: CometSortMergeJoinExec => + stats = stats.copy(cometOps = stats.cometOps + 1, cometJoins = stats.cometJoins + 1) + + // Comet sorts + case _: CometSortExec => + stats = stats.copy(cometOps = stats.cometOps + 1, cometSorts = stats.cometSorts + 1) + + // Other Comet operators (shuffle, window, etc.) 
+ case _: CometShuffleExchangeExec | _: CometBroadcastExchangeExec | _: CometWindowExec | + _: CometCoalesceExec | _: CometCollectLimitExec | _: CometTakeOrderedAndProjectExec | + _: CometUnionExec | _: CometExpandExec | _: CometExplodeExec | _: CometLocalLimitExec | + _: CometGlobalLimitExec => + stats = stats.copy(cometOps = stats.cometOps + 1) + + // Spark scans + case _: FileSourceScanExec | _: BatchScanExec => + stats = stats.copy(sparkOps = stats.sparkOps + 1, sparkScans = stats.sparkScans + 1) + + // Spark filters + case _: FilterExec => + stats = stats.copy(sparkOps = stats.sparkOps + 1, sparkFilters = stats.sparkFilters + 1) + + // Spark projects + case _: ProjectExec => + stats = stats.copy(sparkOps = stats.sparkOps + 1, sparkProjects = stats.sparkProjects + 1) + + // Spark aggregates + case _: HashAggregateExec | _: ObjectHashAggregateExec => + stats = + stats.copy(sparkOps = stats.sparkOps + 1, sparkAggregates = stats.sparkAggregates + 1) + + // Spark joins + case _: BroadcastHashJoinExec | _: ShuffledHashJoinExec | _: SortMergeJoinExec => + stats = stats.copy(sparkOps = stats.sparkOps + 1, sparkJoins = stats.sparkJoins + 1) + + // Spark sorts + case _: SortExec => + stats = stats.copy(sparkOps = stats.sparkOps + 1, sparkSorts = stats.sparkSorts + 1) + + // Ignore wrapper/internal nodes + case _: AdaptiveSparkPlanExec | _: InputAdapter | _: QueryStageExec | + _: WholeStageCodegenExec | _: ReusedExchangeExec | _: AQEShuffleReadExec => + // Don't count these + + case _ => + // Other operators not specifically categorized + } + stats + } + + private def extractRowCount(plan: SparkPlan): Option[Long] = { + // Try logical plan stats first + plan.logicalLink.flatMap(_.stats.rowCount.map(_.toLong)).orElse { + // Fallback: estimate from size (assume ~100 bytes per row) + extractSizeBytes(plan).map(_ / 100) + } + } + + private def extractSizeBytes(plan: SparkPlan): Option[Long] = { + plan.logicalLink.map(_.stats.sizeInBytes.toLong).orElse { + // Fallback: look for scan-level size info + plan.collectLeaves().collectFirst { + case scan: CometScanExec => scan.relation.sizeInBytes + case scan: FileSourceScanExec => scan.relation.sizeInBytes + } + } + } + + private def calculateSparkCost( + stats: PlanStatistics, + rowCount: Option[Long], + conf: SQLConf): Double = { + val rows = rowCount.getOrElse(CometConf.COMET_CBO_DEFAULT_ROW_COUNT.get(conf)).toDouble + + // Base cost for each operator type (calculate as if all operators ran in Spark) + val scanCost = + (stats.sparkScans + stats.cometScans) * CometConf.COMET_CBO_SCAN_WEIGHT.get(conf) * rows + val filterCost = + (stats.sparkFilters + stats.cometFilters) * CometConf.COMET_CBO_FILTER_WEIGHT.get( + conf) * rows + val projectCost = + (stats.sparkProjects + stats.cometProjects) * CometConf.COMET_CBO_PROJECT_WEIGHT + .get(conf) * rows + val aggCost = + (stats.sparkAggregates + stats.cometAggregates) * CometConf.COMET_CBO_AGGREGATE_WEIGHT + .get(conf) * rows + val joinCost = + (stats.sparkJoins + stats.cometJoins) * CometConf.COMET_CBO_JOIN_WEIGHT.get(conf) * rows + val sortCost = (stats.sparkSorts + stats.cometSorts) * CometConf.COMET_CBO_SORT_WEIGHT.get( + conf) * rows * Math.log(rows + 1) + + scanCost + filterCost + projectCost + aggCost + joinCost + sortCost + } + + private def calculateCometCost( + stats: PlanStatistics, + rowCount: Option[Long], + conf: SQLConf): Double = { + val rows = rowCount.getOrElse(CometConf.COMET_CBO_DEFAULT_ROW_COUNT.get(conf)).toDouble + + // Comet operators cost less (divided by speedup factor) + val 
cometScanCost = stats.cometScans * CometConf.COMET_CBO_SCAN_WEIGHT.get( + conf) * rows / CometConf.COMET_CBO_SCAN_SPEEDUP.get(conf) + val cometFilterCost = stats.cometFilters * CometConf.COMET_CBO_FILTER_WEIGHT.get( + conf) * rows / CometConf.COMET_CBO_FILTER_SPEEDUP.get(conf) + val cometProjectCost = stats.cometProjects * CometConf.COMET_CBO_PROJECT_WEIGHT.get( + conf) * rows / CometConf.COMET_CBO_FILTER_SPEEDUP.get(conf) + val cometAggCost = stats.cometAggregates * CometConf.COMET_CBO_AGGREGATE_WEIGHT.get( + conf) * rows / CometConf.COMET_CBO_AGGREGATE_SPEEDUP.get(conf) + val cometJoinCost = stats.cometJoins * CometConf.COMET_CBO_JOIN_WEIGHT.get( + conf) * rows / CometConf.COMET_CBO_JOIN_SPEEDUP.get(conf) + val cometSortCost = stats.cometSorts * CometConf.COMET_CBO_SORT_WEIGHT.get(conf) * rows * Math + .log(rows + 1) / CometConf.COMET_CBO_SORT_SPEEDUP.get(conf) + + val cometOpCost = cometScanCost + cometFilterCost + cometProjectCost + + cometAggCost + cometJoinCost + cometSortCost + + // Spark operators that couldn't be converted still have full cost + val sparkScanCost = stats.sparkScans * CometConf.COMET_CBO_SCAN_WEIGHT.get(conf) * rows + val sparkFilterCost = stats.sparkFilters * CometConf.COMET_CBO_FILTER_WEIGHT.get(conf) * rows + val sparkProjectCost = + stats.sparkProjects * CometConf.COMET_CBO_PROJECT_WEIGHT.get(conf) * rows + val sparkAggCost = + stats.sparkAggregates * CometConf.COMET_CBO_AGGREGATE_WEIGHT.get(conf) * rows + val sparkJoinCost = stats.sparkJoins * CometConf.COMET_CBO_JOIN_WEIGHT.get(conf) * rows + val sparkSortCost = + stats.sparkSorts * CometConf.COMET_CBO_SORT_WEIGHT.get(conf) * rows * Math.log(rows + 1) + + val sparkOpCost = sparkScanCost + sparkFilterCost + sparkProjectCost + + sparkAggCost + sparkJoinCost + sparkSortCost + + // Transition penalty + val transitionCost = stats.transitions * CometConf.COMET_CBO_TRANSITION_COST.get(conf) * rows + + cometOpCost + sparkOpCost + transitionCost + } +} diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index bb4ce879d7..a2b5d901f2 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -387,6 +387,26 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { var newPlan = transform(planWithJoinRewritten) + // CBO decision point: analyze the transformed plan and decide whether to use Comet + // Store the analysis to attach to the final plan later (since transforms create new nodes) + var cboAnalysis: Option[CostAnalysis] = None + if (CometConf.COMET_CBO_ENABLED.get(conf)) { + val costAnalysis = CometCostEstimator.analyze(newPlan, conf) + cboAnalysis = Some(costAnalysis) + + if (CometConf.COMET_CBO_EXPLAIN_ENABLED.get(conf)) { + logInfo(s"Comet CBO Analysis:\n${costAnalysis.toExplainString}") + } + + if (!costAnalysis.shouldUseComet) { + logInfo( + s"Comet CBO: Falling back to Spark " + + f"(speedup=${costAnalysis.estimatedSpeedup}%.2f, threshold=" + + f"${CometConf.COMET_CBO_SPEEDUP_THRESHOLD.get(conf)}%.2f)") + return planWithJoinRewritten // Return original Spark plan + } + } + // if the plan cannot be run fully natively then explain why (when appropriate // config is enabled) if (CometConf.COMET_EXPLAIN_FALLBACK_ENABLED.get()) { @@ -442,9 +462,9 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { // Convert native execution block by linking consecutive native operators. 
var firstNativeOp = true - newPlan.transformDown { + val finalPlan = newPlan.transformDown { case op: CometNativeExec => - val newPlan = if (firstNativeOp) { + val transformedOp = if (firstNativeOp) { firstNativeOp = false op.convertBlock() } else { @@ -468,11 +488,16 @@ case class CometExecRule(session: SparkSession) extends Rule[SparkPlan] { firstNativeOp = true } - newPlan + transformedOp case op => firstNativeOp = true op } + + // Attach CBO info to the final plan for EXPLAIN output + cboAnalysis.foreach(analysis => finalPlan.setTagValue(CometCBOInfo.TAG, analysis)) + + finalPlan } } diff --git a/spark/src/test/scala/org/apache/comet/rules/CometCBOSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometCBOSuite.scala new file mode 100644 index 0000000000..9560643d70 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/rules/CometCBOSuite.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.rules + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.comet._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + +import org.apache.comet.{CometConf, ExtendedExplainInfo} + +/** + * Test suite for the Comet Cost-Based Optimizer (CBO). + * + * Note: CBO only affects operator conversion (filter, project, aggregate, etc.), not scan + * conversion which is handled by CometScanRule. 
+ */ +class CometCBOSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + /** Helper to check if a plan contains any Comet exec operators (excluding scans) */ + private def containsCometExecOperators(plan: SparkPlan): Boolean = { + stripAQEPlan(plan).find { + case _: CometFilterExec | _: CometProjectExec | _: CometHashAggregateExec | + _: CometSortExec | _: CometBroadcastHashJoinExec | _: CometHashJoinExec | + _: CometSortMergeJoinExec => + true + case _ => false + }.isDefined + } + + /** Helper to check if a plan contains any Comet operators (including scans) */ + private def containsCometPlan(plan: SparkPlan): Boolean = { + stripAQEPlan(plan).find(_.isInstanceOf[CometPlan]).isDefined + } + + test("CBO prefers Comet for fully native plans") { + withParquetTable((0 until 1000).map(i => (i, s"val_$i")), "t") { + withSQLConf( + CometConf.COMET_CBO_ENABLED.key -> "true", + CometConf.COMET_CBO_SPEEDUP_THRESHOLD.key -> "1.0") { + val df = spark.sql("SELECT * FROM t WHERE _1 > 500") + val plan = df.queryExecution.executedPlan + assert(containsCometExecOperators(plan), "Expected Comet exec operators in the plan") + } + } + } + + test("CBO respects speedup threshold - high threshold disables Comet exec operators") { + withParquetTable((0 until 100).map(i => (i, s"val_$i")), "t") { + // With very high threshold, CBO should fall back to Spark for exec operators + // Note: Scans may still be Comet since they're converted by CometScanRule + withSQLConf( + CometConf.COMET_CBO_ENABLED.key -> "true", + CometConf.COMET_CBO_SPEEDUP_THRESHOLD.key -> "100.0") { + val df = spark.sql("SELECT * FROM t WHERE _1 > 50") + val plan = df.queryExecution.executedPlan + assert( + !containsCometExecOperators(plan), + "Expected no Comet exec operators with high threshold") + } + } + } + + test("CBO respects speedup threshold - low threshold enables Comet") { + withParquetTable((0 until 100).map(i => (i, s"val_$i")), "t") { + // With low threshold, always use Comet + withSQLConf( + CometConf.COMET_CBO_ENABLED.key -> "true", + CometConf.COMET_CBO_SPEEDUP_THRESHOLD.key -> "0.1") { + val df = spark.sql("SELECT * FROM t WHERE _1 > 50") + val plan = df.queryExecution.executedPlan + assert( + containsCometExecOperators(plan), + "Expected Comet exec operators with low threshold") + } + } + } + + test("CBO disabled returns Comet plan unconditionally") { + withParquetTable((0 until 100).map(i => (i, s"val_$i")), "t") { + // CBO is disabled by default, so Comet operators should be used + withSQLConf(CometConf.COMET_CBO_ENABLED.key -> "false") { + val df = spark.sql("SELECT * FROM t WHERE _1 > 50") + val plan = df.queryExecution.executedPlan + // Should use Comet regardless of cost when CBO is disabled + assert(containsCometPlan(plan), "Expected Comet operators when CBO is disabled") + } + } + } + + test("CBO analysis is computed when enabled") { + withParquetTable((0 until 100).map(i => (i, s"val_$i")), "t") { + withSQLConf( + CometConf.COMET_CBO_ENABLED.key -> "true", + CometConf.COMET_CBO_EXPLAIN_ENABLED.key -> "true", + CometConf.COMET_CBO_SPEEDUP_THRESHOLD.key -> "1.0") { + val df = spark.sql("SELECT * FROM t WHERE _1 > 50") + // Just verify the query runs successfully with CBO enabled + val result = df.collect() + assert(result.length > 0, "Query should return results") + } + } + } + + test("results are correct regardless of CBO decision") { + withParquetTable((0 until 100).map(i => (i, s"val_$i")), "t") { + // Get expected results with Comet disabled + var expected: Array[org.apache.spark.sql.Row] = Array.empty + 
withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + expected = spark.sql("SELECT * FROM t WHERE _1 > 50").collect() + } + + // Test with CBO enabled + withSQLConf(CometConf.COMET_CBO_ENABLED.key -> "true") { + val actual = spark.sql("SELECT * FROM t WHERE _1 > 50").collect() + assert(actual.toSet == expected.toSet, "Results should match regardless of CBO decision") + } + } + } + + test("CostAnalysis correctly computes estimated speedup") { + // Create a simple analysis and verify computation + val analysis = CostAnalysis( + cometOperatorCount = 5, + sparkOperatorCount = 0, + transitionCount = 0, + estimatedRowCount = Some(1000L), + estimatedSizeBytes = Some(100000L), + sparkCost = 1000.0, + cometCost = 500.0, + estimatedSpeedup = 2.0, + shouldUseComet = true) + + assert(analysis.estimatedSpeedup == 2.0) + assert(analysis.shouldUseComet) + assert(analysis.cometOperatorCount == 5) + assert(analysis.sparkOperatorCount == 0) + } + + test("CostAnalysis toExplainString format") { + val analysis = CostAnalysis( + cometOperatorCount = 3, + sparkOperatorCount = 1, + transitionCount = 1, + estimatedRowCount = Some(1000L), + estimatedSizeBytes = Some(100000L), + sparkCost = 1000.0, + cometCost = 600.0, + estimatedSpeedup = 1.67, + shouldUseComet = true) + + val explainString = analysis.toExplainString + + assert(explainString.contains("CBO Analysis")) + assert(explainString.contains("Decision: Use Comet")) + assert(explainString.contains("Comet Operators: 3")) + assert(explainString.contains("Spark Operators: 1")) + assert(explainString.contains("Transitions: 1")) + assert(explainString.contains("Estimated Rows: 1000")) + } + + test("CostAnalysis reflects fall back decision") { + val analysis = CostAnalysis( + cometOperatorCount = 1, + sparkOperatorCount = 5, + transitionCount = 3, + estimatedRowCount = Some(1000L), + estimatedSizeBytes = Some(100000L), + sparkCost = 500.0, + cometCost = 800.0, + estimatedSpeedup = 0.625, + shouldUseComet = false) + + val explainString = analysis.toExplainString + + assert(explainString.contains("Decision: Fall back to Spark")) + assert(!analysis.shouldUseComet) + } + + test("CBO handles aggregation queries") { + withParquetTable((0 until 1000).map(i => (i % 10, s"val_$i")), "t") { + withSQLConf( + CometConf.COMET_CBO_ENABLED.key -> "true", + CometConf.COMET_CBO_SPEEDUP_THRESHOLD.key -> "1.0") { + val df = spark.sql("SELECT _1, COUNT(*) as cnt FROM t GROUP BY _1") + checkSparkAnswer(df) + } + } + } + + test("CBO handles join queries") { + withParquetTable((0 until 100).map(i => (i, s"val_$i")), "t1") { + withParquetTable((0 until 100).map(i => (i, s"other_$i")), "t2") { + withSQLConf( + CometConf.COMET_CBO_ENABLED.key -> "true", + CometConf.COMET_CBO_SPEEDUP_THRESHOLD.key -> "1.0") { + val df = spark.sql("SELECT t1._1, t2._2 FROM t1 JOIN t2 ON t1._1 = t2._1") + checkSparkAnswer(df) + } + } + } + } + + test("CBO handles sort queries") { + withParquetTable((0 until 100).map(i => (i, s"val_$i")), "t") { + withSQLConf( + CometConf.COMET_CBO_ENABLED.key -> "true", + CometConf.COMET_CBO_SPEEDUP_THRESHOLD.key -> "1.0") { + val df = spark.sql("SELECT * FROM t ORDER BY _1 DESC") + checkSparkAnswer(df) + } + } + } + + test("PlanStatistics correctly counts operators") { + // Test the PlanStatistics case class + val stats = PlanStatistics( + cometOps = 5, + sparkOps = 2, + transitions = 1, + cometScans = 1, + cometFilters = 2, + cometProjects = 2, + sparkScans = 1, + sparkFilters = 1) + + assert(stats.cometOps == 5) + assert(stats.sparkOps == 2) + assert(stats.transitions == 
1) + assert(stats.cometScans == 1) + assert(stats.cometFilters == 2) + assert(stats.cometProjects == 2) + } +} From f12a21f79730224586a827ae4e8aa3cdd1a8f074 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 19 Jan 2026 12:13:35 -0700 Subject: [PATCH 2/2] feat: add expression-based costing to CBO Enhance the CBO cost model to consider the cost of individual expressions in projections and filters, rather than using a fixed cost per operator. Key changes: - Add DEFAULT_EXPR_COSTS map with cost multipliers for common expressions (e.g., AttributeReference=0.1 since Comet just clones arrays) - Add dynamic config override via spark.comet.cbo.exprCost. - Update cost calculation to sum expression costs in filters/projects - Add CometCBOSuite to CI workflows (Linux and macOS) - Add 6 new tests for expression-based costing Expression cost multipliers: - < 1.0 means Comet is faster for this expression - > 1.0 means Spark is faster for this expression - 1.0 means they are equivalent Example config override: spark.conf.set("spark.comet.cbo.exprCost.MyCustomExpr", "1.5") Co-Authored-By: Claude Opus 4.5 --- .github/workflows/pr_build_linux.yml | 1 + .github/workflows/pr_build_macos.yml | 1 + .../scala/org/apache/comet/CometConf.scala | 36 ++++ .../comet/rules/CometCostEstimator.scala | 166 ++++++++++++++++-- .../apache/comet/rules/CometCBOSuite.scala | 64 +++++++ 5 files changed, 251 insertions(+), 17 deletions(-) diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index 8e4dc5124b..b1bd8599db 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -150,6 +150,7 @@ jobs: org.apache.spark.CometPluginsUnifiedModeOverrideSuite org.apache.comet.rules.CometScanRuleSuite org.apache.comet.rules.CometExecRuleSuite + org.apache.comet.rules.CometCBOSuite org.apache.spark.sql.CometTPCDSQuerySuite org.apache.spark.sql.CometTPCDSQueryTestSuite org.apache.spark.sql.CometTPCHQuerySuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index f94071dbc7..74c72a9164 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -113,6 +113,7 @@ jobs: org.apache.spark.CometPluginsUnifiedModeOverrideSuite org.apache.comet.rules.CometScanRuleSuite org.apache.comet.rules.CometExecRuleSuite + org.apache.comet.rules.CometCBOSuite org.apache.spark.sql.CometTPCDSQuerySuite org.apache.spark.sql.CometTPCDSQueryTestSuite org.apache.spark.sql.CometTPCHQuerySuite diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index cccebef5f3..ce0bdab82a 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -940,6 +940,42 @@ object CometConf extends ShimCometConf { def getBooleanConf(name: String, defaultValue: Boolean, conf: SQLConf): Boolean = { conf.getConfString(name, defaultValue.toString).toLowerCase(Locale.ROOT) == "true" } + + // CBO expression cost configuration helpers + + /** Config key prefix for CBO expression costs */ + val COMET_CBO_EXPR_COST_PREFIX = "spark.comet.cbo.exprCost" + + /** + * Get the config key for an expression's cost multiplier. Example: + * spark.comet.cbo.exprCost.AttributeReference + */ + def getExprCostConfigKey(exprName: String): String = { + s"$COMET_CBO_EXPR_COST_PREFIX.$exprName" + } + + /** + * Get the config key for an expression class's cost multiplier. 
+ */ + def getExprCostConfigKey(exprClass: Class[_]): String = { + getExprCostConfigKey(exprClass.getSimpleName) + } + + /** + * Get the cost multiplier for an expression from config, with fallback to default. A cost + * multiplier < 1.0 means Comet is faster, > 1.0 means Spark is faster. + */ + def getExprCost(exprName: String, defaultCost: Double, conf: SQLConf): Double = { + getDoubleConf(getExprCostConfigKey(exprName), defaultCost, conf) + } + + def getDoubleConf(name: String, defaultValue: Double, conf: SQLConf): Double = { + try { + conf.getConfString(name, defaultValue.toString).toDouble + } catch { + case _: NumberFormatException => defaultValue + } + } } object ConfigHelpers { diff --git a/spark/src/main/scala/org/apache/comet/rules/CometCostEstimator.scala b/spark/src/main/scala/org/apache/comet/rules/CometCostEstimator.scala index 0c9beb1145..d11259ead5 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometCostEstimator.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometCostEstimator.scala @@ -20,6 +20,7 @@ package org.apache.comet.rules import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec @@ -78,7 +79,10 @@ case class PlanStatistics( sparkProjects: Int = 0, sparkAggregates: Int = 0, sparkJoins: Int = 0, - sparkSorts: Int = 0) + sparkSorts: Int = 0, + // Expression costs for Comet and Spark projects/filters + cometExprCost: Double = 0.0, + sparkExprCost: Double = 0.0) /** * Tag for attaching CBO info to plan nodes for EXPLAIN output. @@ -96,6 +100,105 @@ object CometCBOInfo { */ object CometCostEstimator extends Logging { + /** + * Default expression cost multipliers for Comet vs Spark. A value < 1.0 means Comet is faster + * for this expression type. A value > 1.0 means Spark is faster for this expression type. A + * value of 1.0 means they are equivalent. + * + * These can be overridden via config: spark.comet.cbo.exprCost. + * + * For example, AttributeReference is very fast in Comet (just array cloning) while complex + * string operations might be slower. 
+ */ + val DEFAULT_EXPR_COSTS: Map[String, Double] = Map( + // Very fast in Comet - just array/reference operations + "AttributeReference" -> 0.1, // Array cloning vs row writing + "BoundReference" -> 0.1, + "Literal" -> 0.5, + "Alias" -> 0.3, + // Arithmetic - generally fast in Comet with vectorization + "Add" -> 0.4, + "Subtract" -> 0.4, + "Multiply" -> 0.4, + "Divide" -> 0.5, + "Remainder" -> 0.5, + "UnaryMinus" -> 0.3, + "Abs" -> 0.4, + // Comparisons - fast with SIMD + "EqualTo" -> 0.4, + "EqualNullSafe" -> 0.5, + "LessThan" -> 0.4, + "LessThanOrEqual" -> 0.4, + "GreaterThan" -> 0.4, + "GreaterThanOrEqual" -> 0.4, + // Logical - fast + "And" -> 0.3, + "Or" -> 0.3, + "Not" -> 0.3, + // Null handling - fast + "IsNull" -> 0.3, + "IsNotNull" -> 0.3, + "Coalesce" -> 0.5, + "If" -> 0.5, + "CaseWhen" -> 0.6, + // Cast - depends on types but generally comparable + "Cast" -> 0.8, + // String operations - some are slower in Comet + "Upper" -> 0.9, + "Lower" -> 0.9, + "Substring" -> 0.8, + "StringTrim" -> 0.9, + "StringTrimLeft" -> 0.9, + "StringTrimRight" -> 0.9, + "Concat" -> 1.0, + "Length" -> 0.6, + "Like" -> 1.0, + "Contains" -> 0.9, + "StartsWith" -> 0.8, + "EndsWith" -> 0.8, + // Date/Time - comparable + "Year" -> 0.7, + "Month" -> 0.7, + "DayOfMonth" -> 0.7, + "Hour" -> 0.7, + "Minute" -> 0.7, + "Second" -> 0.7, + // Aggregation expressions + "Sum" -> 0.5, + "Count" -> 0.4, + "Min" -> 0.5, + "Max" -> 0.5, + "Average" -> 0.5) + + /** Default cost for expressions not in the map */ + val DEFAULT_UNKNOWN_EXPR_COST: Double = 0.7 + + /** + * Get the cost multiplier for an expression, checking config override first. + */ + def getExprCost(exprName: String, conf: SQLConf): Double = { + val defaultCost = DEFAULT_EXPR_COSTS.getOrElse(exprName, DEFAULT_UNKNOWN_EXPR_COST) + CometConf.getExprCost(exprName, defaultCost, conf) + } + + /** + * Calculate the total expression cost for a sequence of expressions. + */ + def calculateExpressionCost(expressions: Seq[Expression], conf: SQLConf): Double = { + expressions.map(calculateSingleExprCost(_, conf)).sum + } + + /** + * Calculate the cost for a single expression tree. + */ + private def calculateSingleExprCost(expr: Expression, conf: SQLConf): Double = { + // Base cost for this expression + val baseCost = getExprCost(expr.getClass.getSimpleName, conf) + // Recursively add cost of child expressions + val childCost = expr.children.map(calculateSingleExprCost(_, conf)).sum + baseCost + childCost + } + /** * Analyze a Comet plan and determine if it should be used over Spark. 
*/ @@ -123,6 +226,10 @@ object CometCostEstimator extends Logging { } private def collectStats(plan: SparkPlan): PlanStatistics = { + collectStatsWithConf(plan, SQLConf.get) + } + + private def collectStatsWithConf(plan: SparkPlan, conf: SQLConf): PlanStatistics = { var stats = PlanStatistics() plan.foreach { @@ -136,13 +243,21 @@ object CometCostEstimator extends Logging { _: CometIcebergNativeScanExec | _: CometLocalTableScanExec => stats = stats.copy(cometOps = stats.cometOps + 1, cometScans = stats.cometScans + 1) - // Comet filters - case _: CometFilterExec => - stats = stats.copy(cometOps = stats.cometOps + 1, cometFilters = stats.cometFilters + 1) - - // Comet projects - case _: CometProjectExec => - stats = stats.copy(cometOps = stats.cometOps + 1, cometProjects = stats.cometProjects + 1) + // Comet filters - also calculate expression cost + case f: CometFilterExec => + val exprCost = calculateExpressionCost(Seq(f.condition), conf) + stats = stats.copy( + cometOps = stats.cometOps + 1, + cometFilters = stats.cometFilters + 1, + cometExprCost = stats.cometExprCost + exprCost) + + // Comet projects - calculate expression cost for all project expressions + case p: CometProjectExec => + val exprCost = calculateExpressionCost(p.projectList, conf) + stats = stats.copy( + cometOps = stats.cometOps + 1, + cometProjects = stats.cometProjects + 1, + cometExprCost = stats.cometExprCost + exprCost) // Comet aggregates case _: CometHashAggregateExec => @@ -168,13 +283,21 @@ object CometCostEstimator extends Logging { case _: FileSourceScanExec | _: BatchScanExec => stats = stats.copy(sparkOps = stats.sparkOps + 1, sparkScans = stats.sparkScans + 1) - // Spark filters - case _: FilterExec => - stats = stats.copy(sparkOps = stats.sparkOps + 1, sparkFilters = stats.sparkFilters + 1) - - // Spark projects - case _: ProjectExec => - stats = stats.copy(sparkOps = stats.sparkOps + 1, sparkProjects = stats.sparkProjects + 1) + // Spark filters - also calculate expression cost + case f: FilterExec => + val exprCost = calculateExpressionCost(Seq(f.condition), conf) + stats = stats.copy( + sparkOps = stats.sparkOps + 1, + sparkFilters = stats.sparkFilters + 1, + sparkExprCost = stats.sparkExprCost + exprCost) + + // Spark projects - calculate expression cost for all project expressions + case p: ProjectExec => + val exprCost = calculateExpressionCost(p.projectList, conf) + stats = stats.copy( + sparkOps = stats.sparkOps + 1, + sparkProjects = stats.sparkProjects + 1, + sparkExprCost = stats.sparkExprCost + exprCost) // Spark aggregates case _: HashAggregateExec | _: ObjectHashAggregateExec => @@ -241,7 +364,11 @@ object CometCostEstimator extends Logging { val sortCost = (stats.sparkSorts + stats.cometSorts) * CometConf.COMET_CBO_SORT_WEIGHT.get( conf) * rows * Math.log(rows + 1) - scanCost + filterCost + projectCost + aggCost + joinCost + sortCost + // Expression cost: in Spark, expression cost multiplier is 1.0 (baseline) + // Total expression cost is the sum of all expression costs from filters and projects + val totalExprCost = (stats.sparkExprCost + stats.cometExprCost) * rows + + scanCost + filterCost + projectCost + aggCost + joinCost + sortCost + totalExprCost } private def calculateCometCost( @@ -281,9 +408,14 @@ object CometCostEstimator extends Logging { val sparkOpCost = sparkScanCost + sparkFilterCost + sparkProjectCost + sparkAggCost + sparkJoinCost + sparkSortCost + // Expression cost: Comet expression costs already have the multiplier applied + // (values < 1.0 mean faster in Comet). 
Spark expressions run at baseline cost (1.0). + val cometExprCost = stats.cometExprCost * rows + val sparkExprCost = stats.sparkExprCost * rows + // Transition penalty val transitionCost = stats.transitions * CometConf.COMET_CBO_TRANSITION_COST.get(conf) * rows - cometOpCost + sparkOpCost + transitionCost + cometOpCost + sparkOpCost + cometExprCost + sparkExprCost + transitionCost } } diff --git a/spark/src/test/scala/org/apache/comet/rules/CometCBOSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometCBOSuite.scala index 9560643d70..05a98de69b 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometCBOSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometCBOSuite.scala @@ -248,4 +248,68 @@ class CometCBOSuite extends CometTestBase with AdaptiveSparkPlanHelper { assert(stats.cometFilters == 2) assert(stats.cometProjects == 2) } + + // Expression-based costing tests + + test("default expression costs favor Comet for simple expressions") { + // AttributeReference should be very fast in Comet (0.1) + assert(CometCostEstimator.DEFAULT_EXPR_COSTS("AttributeReference") < 0.5) + // Arithmetic should be fast in Comet + assert(CometCostEstimator.DEFAULT_EXPR_COSTS("Add") < 1.0) + assert(CometCostEstimator.DEFAULT_EXPR_COSTS("Multiply") < 1.0) + // Comparisons should be fast with SIMD + assert(CometCostEstimator.DEFAULT_EXPR_COSTS("LessThan") < 1.0) + } + + test("expression cost can be overridden via config") { + withSQLConf(CometConf.COMET_CBO_EXPR_COST_PREFIX + ".TestExpr" -> "2.5") { + val cost = CometConf.getExprCost("TestExpr", 1.0, spark.sessionState.conf) + assert(cost == 2.5, "Config override should take precedence") + } + } + + test("expression cost uses default when config not set") { + val cost = CometConf.getExprCost( + "NonExistentExpr", + CometCostEstimator.DEFAULT_UNKNOWN_EXPR_COST, + spark.sessionState.conf) + assert(cost == CometCostEstimator.DEFAULT_UNKNOWN_EXPR_COST) + } + + test("CBO considers expression costs in projections") { + withParquetTable((0 until 100).map(i => (i, i * 2, s"val_$i")), "t") { + withSQLConf( + CometConf.COMET_CBO_ENABLED.key -> "true", + CometConf.COMET_CBO_SPEEDUP_THRESHOLD.key -> "1.0") { + // Query with simple projections (should favor Comet) + val df = spark.sql("SELECT _1, _2, _1 + _2 as sum FROM t") + checkSparkAnswer(df) + } + } + } + + test("CBO considers expression costs in filters") { + withParquetTable((0 until 100).map(i => (i, s"val_$i")), "t") { + withSQLConf( + CometConf.COMET_CBO_ENABLED.key -> "true", + CometConf.COMET_CBO_SPEEDUP_THRESHOLD.key -> "1.0") { + // Query with comparison filter (should favor Comet) + val df = spark.sql("SELECT * FROM t WHERE _1 > 50 AND _1 < 80") + checkSparkAnswer(df) + } + } + } + + test("PlanStatistics tracks expression costs") { + val stats = PlanStatistics( + cometOps = 2, + sparkOps = 1, + cometFilters = 1, + cometProjects = 1, + cometExprCost = 0.5, + sparkExprCost = 1.0) + + assert(stats.cometExprCost == 0.5) + assert(stats.sparkExprCost == 1.0) + } }
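
A minimal usage sketch (not part of the patch itself), assuming only the configuration keys introduced above; the application name, table name, threshold value, and expression-cost override are illustrative:

```scala
import org.apache.spark.sql.SparkSession

// Sketch: enabling the experimental Comet CBO and tuning it at runtime,
// using the config keys added in this patch. Values shown are examples only.
object CometCboExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("comet-cbo-example") // illustrative app name
      .getOrCreate()

    // Turn on the cost-based optimizer and log its decisions.
    spark.conf.set("spark.comet.cbo.enabled", "true")
    spark.conf.set("spark.comet.cbo.explain.enabled", "true")

    // Require at least a 1.2x estimated speedup before choosing Comet
    // (illustrative value; the default threshold is 1.0).
    spark.conf.set("spark.comet.cbo.speedupThreshold", "1.2")

    // Override the cost multiplier for a specific expression class, as
    // described in the second commit message. "Upper" is just an example;
    // values > 1.0 bias the estimate toward Spark for that expression.
    spark.conf.set("spark.comet.cbo.exprCost.Upper", "1.1")

    // Queries run after this point are subject to the CBO decision; when the
    // feature is enabled, the analysis is also included in Comet's extended
    // EXPLAIN output. "people" is an illustrative table name.
    spark.sql("SELECT upper(name) FROM people WHERE id > 100").explain("extended")

    spark.stop()
  }
}
```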