Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Enable Comet by default except some tests in SparkSessionExtensionSuite #1201

Merged
merged 2 commits into from
Jan 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 55 additions & 15 deletions dev/diffs/3.4.3.diff
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
diff --git a/pom.xml b/pom.xml
index d3544881af1..bf0e2b53c70 100644
index d3544881af1..26ab186c65d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -148,6 +148,8 @@
Expand Down Expand Up @@ -38,7 +38,7 @@ index d3544881af1..bf0e2b53c70 100644
</dependencyManagement>

diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index b386d135da1..854aec17c2d 100644
index b386d135da1..46449e3f3f1 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -77,6 +77,10 @@
Expand All @@ -53,7 +53,7 @@ index b386d135da1..854aec17c2d 100644
<!--
This spark-tags test-dep is needed even though it isn't used in this module, otherwise testing-cmds that exclude
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index c595b50950b..6b60213e775 100644
index c595b50950b..3abb6cb9441 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -102,7 +102,7 @@ class SparkSession private(
Expand All @@ -79,7 +79,7 @@ index c595b50950b..6b60213e775 100644
}

+ private def loadCometExtension(sparkContext: SparkContext): Seq[String] = {
+ if (sparkContext.getConf.getBoolean("spark.comet.enabled", false)) {
+ if (sparkContext.getConf.getBoolean("spark.comet.enabled", isCometEnabled)) {
+ Seq("org.apache.comet.CometSparkSessionExtensions")
+ } else {
+ Seq.empty
Expand All @@ -100,6 +100,19 @@ index c595b50950b..6b60213e775 100644
try {
val extensionConfClass = Utils.classForName(extensionConfClassName)
val extensionConf = extensionConfClass.getConstructor().newInstance()
@@ -1323,4 +1333,12 @@ object SparkSession extends Logging {
}
}
}
+
+ /**
+ * Whether Comet extension is enabled
+ */
+ def isCometEnabled: Boolean = {
+ val v = System.getenv("ENABLE_COMET")
+ v == null || v.toBoolean
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index db587dd9868..aac7295a53d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
Expand Down Expand Up @@ -957,6 +970,37 @@ index 525d97e4998..8a3e7457618 100644
AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 48ad10992c5..51d1ee65422 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -221,6 +221,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
withSession(extensions) { session =>
session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true)
session.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
+ // https://github.com/apache/datafusion-comet/issues/1197
+ session.conf.set("spark.comet.enabled", false)
assert(session.sessionState.columnarRules.contains(
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
import session.sqlContext.implicits._
@@ -279,6 +281,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
}
withSession(extensions) { session =>
session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE)
+ // https://github.com/apache/datafusion-comet/issues/1197
+ session.conf.set("spark.comet.enabled", false)
assert(session.sessionState.columnarRules.contains(
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
import session.sqlContext.implicits._
@@ -317,6 +321,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper {
val session = SparkSession.builder()
.master("local[1]")
.config(COLUMN_BATCH_SIZE.key, 2)
+ // https://github.com/apache/datafusion-comet/issues/1197
+ .config("spark.comet.enabled", false)
.withExtensions { extensions =>
extensions.injectColumnar(session =>
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 75eabcb96f2..36e3318ad7e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Expand Down Expand Up @@ -2720,7 +2764,7 @@ index abe606ad9c1..2d930b64cca 100644
val tblTargetName = "tbl_target"
val tblSourceQualified = s"default.$tblSourceName"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index dd55fcfe42c..aa9b0be8e68 100644
index dd55fcfe42c..2702f87c1f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
Expand All @@ -2744,17 +2788,14 @@ index dd55fcfe42c..aa9b0be8e68 100644
}
}

@@ -242,6 +247,41 @@ private[sql] trait SQLTestUtilsBase
@@ -242,6 +247,38 @@ private[sql] trait SQLTestUtilsBase
protected override def _sqlContext: SQLContext = self.spark.sqlContext
}

+ /**
+ * Whether Comet extension is enabled
+ */
+ protected def isCometEnabled: Boolean = {
+ val v = System.getenv("ENABLE_COMET")
+ v != null && v.toBoolean
+ }
+ protected def isCometEnabled: Boolean = SparkSession.isCometEnabled
+
+ /**
+ * Whether to enable ansi mode This is only effective when
Expand Down Expand Up @@ -2786,7 +2827,7 @@ index dd55fcfe42c..aa9b0be8e68 100644
protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
SparkSession.setActiveSession(spark)
super.withSQLConf(pairs: _*)(f)
@@ -434,6 +474,8 @@ private[sql] trait SQLTestUtilsBase
@@ -434,6 +471,8 @@ private[sql] trait SQLTestUtilsBase
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
Expand Down Expand Up @@ -2884,10 +2925,10 @@ index 1966e1e64fd..cde97a0aafe 100644
spark.sql(
"""
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 07361cfdce9..6673c141c9a 100644
index 07361cfdce9..e40c59a4207 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -55,25 +55,53 @@ object TestHive
@@ -55,25 +55,52 @@ object TestHive
new SparkContext(
System.getProperty("spark.sql.test.master", "local[1]"),
"TestSQLContext",
Expand Down Expand Up @@ -2929,8 +2970,7 @@ index 07361cfdce9..6673c141c9a 100644
+ // ConstantPropagation etc.
+ .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+
+ val v = System.getenv("ENABLE_COMET")
+ if (v != null && v.toBoolean) {
+ if (SparkSession.isCometEnabled) {
+ conf
+ .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
+ .set("spark.comet.enabled", "true")
Expand Down
80 changes: 60 additions & 20 deletions dev/diffs/3.5.1.diff
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
diff --git a/pom.xml b/pom.xml
index 0f504dbee85..f6019da888a 100644
index 0f504dbee85..430ec217e59 100644
--- a/pom.xml
+++ b/pom.xml
@@ -152,6 +152,8 @@
Expand Down Expand Up @@ -38,7 +38,7 @@ index 0f504dbee85..f6019da888a 100644
</dependencyManagement>

diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index c46ab7b8fce..d8b99c2c115 100644
index c46ab7b8fce..13357e8c7a6 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -77,6 +77,10 @@
Expand All @@ -53,19 +53,19 @@ index c46ab7b8fce..d8b99c2c115 100644
<!--
This spark-tags test-dep is needed even though it isn't used in this module, otherwise testing-cmds that exclude
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 27ae10b3d59..064cbc252ea 100644
index 27ae10b3d59..78e69902dfd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -1353,6 +1353,14 @@ object SparkSession extends Logging {
}
}

+ private def loadCometExtension(sparkContext: SparkContext): Seq[String] = {
+ if (sparkContext.getConf.getBoolean("spark.comet.enabled", false)) {
+ Seq("org.apache.comet.CometSparkSessionExtensions")
+ } else {
+ Seq.empty
+ }
+ if (sparkContext.getConf.getBoolean("spark.comet.enabled", isCometEnabled)) {
+ Seq("org.apache.comet.CometSparkSessionExtensions")
+ } else {
+ Seq.empty
+ }
+ }
+
/**
Expand All @@ -79,6 +79,19 @@ index 27ae10b3d59..064cbc252ea 100644
extensionConfClassNames.foreach { extensionConfClassName =>
try {
val extensionConfClass = Utils.classForName(extensionConfClassName)
@@ -1396,4 +1405,12 @@ object SparkSession extends Logging {
}
}
}
+
+ /**
+ * Whether Comet extension is enabled
+ */
+ def isCometEnabled: Boolean = {
+ val v = System.getenv("ENABLE_COMET")
+ v == null || v.toBoolean
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index db587dd9868..aac7295a53d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
Expand Down Expand Up @@ -959,6 +972,37 @@ index cfeccbdf648..803d8734cc4 100644
AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "external sort") {
sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC").collect()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 8b4ac474f87..3f79f20822f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -223,6 +223,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
withSession(extensions) { session =>
session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, true)
session.conf.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
+ // https://github.com/apache/datafusion-comet/issues/1197
+ session.conf.set("spark.comet.enabled", false)
assert(session.sessionState.columnarRules.contains(
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
import session.sqlContext.implicits._
@@ -281,6 +283,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
}
withSession(extensions) { session =>
session.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED, enableAQE)
+ // https://github.com/apache/datafusion-comet/issues/1197
+ session.conf.set("spark.comet.enabled", false)
assert(session.sessionState.columnarRules.contains(
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())))
import session.sqlContext.implicits._
@@ -319,6 +323,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite with SQLHelper with Adapt
val session = SparkSession.builder()
.master("local[1]")
.config(COLUMN_BATCH_SIZE.key, 2)
+ // https://github.com/apache/datafusion-comet/issues/1197
+ .config("spark.comet.enabled", false)
.withExtensions { extensions =>
extensions.injectColumnar(session =>
MyColumnarRule(PreRuleReplaceAddWithBrokenVersion(), MyPostRule())) }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index fbc256b3396..0821999c7c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
Expand Down Expand Up @@ -1310,7 +1354,7 @@ index 47679ed7865..9ffbaecb98e 100644
assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 5a413c77754..c52f4b3818c 100644
index 5a413c77754..a6f97dccb67 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
Expand Down Expand Up @@ -2705,7 +2749,7 @@ index abe606ad9c1..2d930b64cca 100644
val tblTargetName = "tbl_target"
val tblSourceQualified = s"default.$tblSourceName"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index dd55fcfe42c..aa9b0be8e68 100644
index dd55fcfe42c..2702f87c1f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
Expand All @@ -2729,17 +2773,14 @@ index dd55fcfe42c..aa9b0be8e68 100644
}
}

@@ -242,6 +247,41 @@ private[sql] trait SQLTestUtilsBase
@@ -242,6 +247,38 @@ private[sql] trait SQLTestUtilsBase
protected override def _sqlContext: SQLContext = self.spark.sqlContext
}

+ /**
+ * Whether Comet extension is enabled
+ */
+ protected def isCometEnabled: Boolean = {
+ val v = System.getenv("ENABLE_COMET")
+ v != null && v.toBoolean
+ }
+ protected def isCometEnabled: Boolean = SparkSession.isCometEnabled
+
+ /**
+ * Whether to enable ansi mode This is only effective when
Expand Down Expand Up @@ -2771,7 +2812,7 @@ index dd55fcfe42c..aa9b0be8e68 100644
protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
SparkSession.setActiveSession(spark)
super.withSQLConf(pairs: _*)(f)
@@ -434,6 +474,8 @@ private[sql] trait SQLTestUtilsBase
@@ -434,6 +471,8 @@ private[sql] trait SQLTestUtilsBase
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
case FilterExec(_, child) => child
Expand Down Expand Up @@ -2869,10 +2910,10 @@ index dc8b184fcee..dd69a989d40 100644
spark.sql(
"""
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 9284b35fb3e..e8984be5ebc 100644
index 9284b35fb3e..2a0269bdc16 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -53,25 +53,53 @@ object TestHive
@@ -53,25 +53,52 @@ object TestHive
new SparkContext(
System.getProperty("spark.sql.test.master", "local[1]"),
"TestSQLContext",
Expand Down Expand Up @@ -2914,8 +2955,7 @@ index 9284b35fb3e..e8984be5ebc 100644
+ // ConstantPropagation etc.
+ .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
+
+ val v = System.getenv("ENABLE_COMET")
+ if (v != null && v.toBoolean) {
+ if (SparkSession.isCometEnabled) {
+ conf
+ .set("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
+ .set("spark.comet.enabled", "true")
Expand Down
Loading
Loading