Skip to content

Commit

Permalink
chore: Add safety check to CometBuffer (#1050)
Browse files Browse the repository at this point in the history
* chore: Add safety check to CometBuffer

* Add CometColumnarToRowExec

* fix

* fix

* more

* Update plan stability results

* fix

* fix

* fix

* Revert "fix"

This reverts commit 9bad173.

* Revert "Revert "fix""

This reverts commit d527ad1.

* fix BucketedReadWithoutHiveSupportSuite

* fix SparkPlanSuite
  • Loading branch information
viirya authored Jan 3, 2025
1 parent 2e0f00a commit 4333dce
Show file tree
Hide file tree
Showing 830 changed files with 4,101 additions and 3,761 deletions.
22 changes: 0 additions & 22 deletions common/src/main/java/org/apache/comet/parquet/ColumnReader.java
Original file line number Diff line number Diff line change
Expand Up @@ -172,28 +172,6 @@ public void close() {

/** Returns a decoded {@link CometDecodedVector Comet vector}. */
public CometDecodedVector loadVector() {
// Only re-use Comet vector iff:
// 1. if we're not using dictionary encoding, since with dictionary encoding, the native
// side may fall back to plain encoding and the underlying memory address for the vector
// will change as a result.
// 2. if the column type is of fixed width, in other words, string/binary are not supported
// since the native side may resize the vector and therefore change memory address.
// 3. if the last loaded vector contains null values: if values of last vector are all not
// null, Arrow C data API will skip loading the native validity buffer, therefore we
// should not re-use the vector in that case.
// 4. if the last loaded vector doesn't contain any null value, and the values of the current
// vector are also all non-null, which means we can also re-use the loaded vector.
// 5. if the new number of values is the same or smaller
if ((hadNull || currentNumNulls == 0)
&& currentVector != null
&& dictionary == null
&& currentVector.isFixedLength()
&& currentVector.numValues() >= currentNumValues) {
currentVector.setNumNulls(currentNumNulls);
currentVector.setNumValues(currentNumValues);
return currentVector;
}

LOG.debug("Reloading vector");

// Close the previous vector first to release struct memory allocated to import Arrow array &
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ public ConstantColumnReader(

public ConstantColumnReader(
DataType type, ColumnDescriptor descriptor, Object value, boolean useDecimal128) {
super(type, descriptor, useDecimal128);
super(type, descriptor, useDecimal128, true);
this.value = value;
}

ConstantColumnReader(
DataType type, ColumnDescriptor descriptor, int batchSize, boolean useDecimal128) {
super(type, descriptor, useDecimal128);
super(type, descriptor, useDecimal128, true);
this.batchSize = batchSize;
initNative();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ public class MetadataColumnReader extends AbstractColumnReader {
private ArrowArray array = null;
private ArrowSchema schema = null;

public MetadataColumnReader(DataType type, ColumnDescriptor descriptor, boolean useDecimal128) {
private boolean isConstant;

public MetadataColumnReader(
DataType type, ColumnDescriptor descriptor, boolean useDecimal128, boolean isConstant) {
// TODO: should we handle legacy dates & timestamps for metadata columns?
super(type, descriptor, useDecimal128, false);

this.isConstant = isConstant;
}

@Override
Expand All @@ -62,7 +67,7 @@ public void readBatch(int total) {

Native.currentBatch(nativeHandle, arrayAddr, schemaAddr);
FieldVector fieldVector = Data.importVector(allocator, array, schema, null);
vector = new CometPlainVector(fieldVector, useDecimal128);
vector = new CometPlainVector(fieldVector, useDecimal128, false, isConstant);
}

vector.setNumValues(total);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class RowIndexColumnReader extends MetadataColumnReader {
private long offset;

public RowIndexColumnReader(StructField field, int batchSize, long[] indices) {
super(field.dataType(), TypeUtil.convertToParquet(field), false);
super(field.dataType(), TypeUtil.convertToParquet(field), false, false);
this.indices = indices;
setBatchSize(batchSize);
}
Expand Down
16 changes: 16 additions & 0 deletions common/src/main/java/org/apache/comet/vector/CometPlainVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,18 @@ public class CometPlainVector extends CometDecodedVector {
private byte booleanByteCache;
private int booleanByteCacheIndex = -1;

private boolean isReused;

/**
 * Creates a plain vector with {@code isUuid} set to {@code false}.
 *
 * <p>Delegates to the three-argument constructor.
 *
 * @param vector the underlying value vector, forwarded unchanged
 * @param useDecimal128 forwarded unchanged to the three-argument constructor
 */
public CometPlainVector(ValueVector vector, boolean useDecimal128) {
this(vector, useDecimal128, false);
}

/**
 * Creates a plain vector with {@code isReused} set to {@code false}.
 *
 * <p>Delegates to the four-argument constructor.
 *
 * @param vector the underlying value vector, forwarded unchanged
 * @param useDecimal128 forwarded unchanged to the four-argument constructor
 * @param isUuid forwarded unchanged to the four-argument constructor
 */
public CometPlainVector(ValueVector vector, boolean useDecimal128, boolean isUuid) {
this(vector, useDecimal128, isUuid, false);
}

public CometPlainVector(
ValueVector vector, boolean useDecimal128, boolean isUuid, boolean isReused) {
super(vector, vector.getField(), useDecimal128, isUuid);
// NullType doesn't have data buffer.
if (vector instanceof NullVector) {
Expand All @@ -52,6 +59,15 @@ public CometPlainVector(ValueVector vector, boolean useDecimal128, boolean isUui
}

isBaseFixedWidthVector = valueVector instanceof BaseFixedWidthVector;
this.isReused = isReused;
}

/**
 * Returns the current value of the {@code isReused} flag, as set by the constructor or by
 * {@link #setReused(boolean)}.
 *
 * <p>NOTE(review): presumably the flag marks a vector whose underlying buffers may be reused
 * across batches (e.g. constant columns) -- confirm against the callers.
 *
 * @return the value of the {@code isReused} flag
 */
public boolean isReused() {
return isReused;
}

/**
 * Overwrites the {@code isReused} flag that was supplied at construction time.
 *
 * @param isReused the new value of the flag
 */
public void setReused(boolean isReused) {
this.isReused = isReused;
}

@Override
Expand Down
38 changes: 32 additions & 6 deletions dev/diffs/3.4.3.diff
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
diff --git a/pom.xml b/pom.xml
index d3544881af1..bf0e2b53c70 100644
index d3544881af1..26ab186c65d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -148,6 +148,8 @@
Expand Down Expand Up @@ -38,7 +38,7 @@ index d3544881af1..bf0e2b53c70 100644
</dependencyManagement>

diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index b386d135da1..854aec17c2d 100644
index b386d135da1..46449e3f3f1 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -77,6 +77,10 @@
Expand Down Expand Up @@ -1284,6 +1284,27 @@ index 47679ed7865..9ffbaecb98e 100644
}.length == hashAggCount)
assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
index b14f4a405f6..88815fd078f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Deduplicate
+import org.apache.spark.sql.comet.CometColumnarToRowExec
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -131,7 +132,7 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
spark.range(1).write.parquet(path.getAbsolutePath)
val df = spark.read.parquet(path.getAbsolutePath)
val columnarToRowExec =
- df.queryExecution.executedPlan.collectFirst { case p: ColumnarToRowExec => p }.get
+ df.queryExecution.executedPlan.collectFirst { case p: CometColumnarToRowExec => p }.get
try {
spark.range(1).foreach { _ =>
columnarToRowExec.canonicalized
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index ac710c32296..baae214c6ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
Expand Down Expand Up @@ -2281,7 +2302,7 @@ index d083cac48ff..3c11bcde807 100644
import testImplicits._

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 266bb343526..a426d8396be 100644
index 266bb343526..c3e3d155813 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -24,10 +24,11 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
Expand Down Expand Up @@ -2331,7 +2352,7 @@ index 266bb343526..a426d8396be 100644

val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType
val rowsWithInvalidBuckets = fileScan.execute().filter(row => {
@@ -451,28 +461,44 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -451,28 +461,49 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
val joinOperator = if (joined.sqlContext.conf.adaptiveExecutionEnabled) {
val executedPlan =
joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
Expand All @@ -2357,6 +2378,11 @@ index 266bb343526..a426d8396be 100644
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case CometColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
}
Expand Down Expand Up @@ -2384,7 +2410,7 @@ index 266bb343526..a426d8396be 100644
s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}")

// check the output partitioning
@@ -835,11 +861,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -835,11 +866,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")

val scanDF = spark.table("bucketed_table").select("j")
Expand All @@ -2398,7 +2424,7 @@ index 266bb343526..a426d8396be 100644
checkAnswer(aggDF, df1.groupBy("j").agg(max("k")))
}
}
@@ -1026,15 +1052,23 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -1026,15 +1057,23 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
expectedNumShuffles: Int,
expectedCoalescedNumBuckets: Option[Int]): Unit = {
val plan = sql(query).queryExecution.executedPlan
Expand Down
40 changes: 33 additions & 7 deletions dev/diffs/3.5.1.diff
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
diff --git a/pom.xml b/pom.xml
index 0f504dbee85..f6019da888a 100644
index 0f504dbee85..430ec217e59 100644
--- a/pom.xml
+++ b/pom.xml
@@ -152,6 +152,8 @@
Expand Down Expand Up @@ -38,7 +38,7 @@ index 0f504dbee85..f6019da888a 100644
</dependencyManagement>

diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index c46ab7b8fce..d8b99c2c115 100644
index c46ab7b8fce..13357e8c7a6 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -77,6 +77,10 @@
Expand Down Expand Up @@ -1309,8 +1309,29 @@ index 47679ed7865..9ffbaecb98e 100644
}.length == hashAggCount)
assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
index b14f4a405f6..88815fd078f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Deduplicate
+import org.apache.spark.sql.comet.CometColumnarToRowExec
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -131,7 +132,7 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
spark.range(1).write.parquet(path.getAbsolutePath)
val df = spark.read.parquet(path.getAbsolutePath)
val columnarToRowExec =
- df.queryExecution.executedPlan.collectFirst { case p: ColumnarToRowExec => p }.get
+ df.queryExecution.executedPlan.collectFirst { case p: CometColumnarToRowExec => p }.get
try {
spark.range(1).foreach { _ =>
columnarToRowExec.canonicalized
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 5a413c77754..c52f4b3818c 100644
index 5a413c77754..a6f97dccb67 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
Expand Down Expand Up @@ -2270,7 +2291,7 @@ index d083cac48ff..3c11bcde807 100644
import testImplicits._

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 746f289c393..1a2f1f7e3fd 100644
index 746f289c393..0c99d028163 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -25,10 +25,11 @@ import org.apache.spark.sql.catalyst.expressions
Expand Down Expand Up @@ -2320,7 +2341,7 @@ index 746f289c393..1a2f1f7e3fd 100644

val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType
val rowsWithInvalidBuckets = fileScan.execute().filter(row => {
@@ -452,28 +462,44 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -452,28 +462,49 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
val joinOperator = if (joined.sqlContext.conf.adaptiveExecutionEnabled) {
val executedPlan =
joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
Expand All @@ -2346,6 +2367,11 @@ index 746f289c393..1a2f1f7e3fd 100644
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case CometColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
}
Expand Down Expand Up @@ -2373,7 +2399,7 @@ index 746f289c393..1a2f1f7e3fd 100644
s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}")

// check the output partitioning
@@ -836,11 +862,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -836,11 +867,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")

val scanDF = spark.table("bucketed_table").select("j")
Expand All @@ -2387,7 +2413,7 @@ index 746f289c393..1a2f1f7e3fd 100644
checkAnswer(aggDF, df1.groupBy("j").agg(max("k")))
}
}
@@ -1029,15 +1055,21 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -1029,15 +1060,21 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
Seq(true, false).foreach { aqeEnabled =>
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled.toString) {
val plan = sql(query).queryExecution.executedPlan
Expand Down
34 changes: 30 additions & 4 deletions dev/diffs/4.0.0-preview1.diff
Original file line number Diff line number Diff line change
Expand Up @@ -1438,6 +1438,27 @@ index 47679ed7865..9ffbaecb98e 100644
}.length == hashAggCount)
assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
index 966f4e74712..a715193d96d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Deduplicate
+import org.apache.spark.sql.comet.CometColumnarToRowExec
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -134,7 +135,7 @@ class SparkPlanSuite extends QueryTest with SharedSparkSession {
spark.range(1).write.parquet(path.getAbsolutePath)
val df = spark.read.parquet(path.getAbsolutePath)
val columnarToRowExec =
- df.queryExecution.executedPlan.collectFirst { case p: ColumnarToRowExec => p }.get
+ df.queryExecution.executedPlan.collectFirst { case p: CometColumnarToRowExec => p }.get
try {
spark.range(1).foreach { _ =>
columnarToRowExec.canonicalized
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 3aaf61ffba4..4130ece2283 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
Expand Down Expand Up @@ -2562,7 +2583,7 @@ index 6ff07449c0c..9f95cff99e5 100644
import testImplicits._

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 3573bafe482..a21767840a2 100644
index 3573bafe482..11d387110ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -25,10 +25,11 @@ import org.apache.spark.sql.catalyst.expressions
Expand Down Expand Up @@ -2612,7 +2633,7 @@ index 3573bafe482..a21767840a2 100644

val bucketColumnType = bucketedDataFrame.schema.apply(bucketColumnIndex).dataType
val rowsWithInvalidBuckets = fileScan.execute().filter(row => {
@@ -452,28 +462,44 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -452,28 +462,49 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
val joinOperator = if (joined.sparkSession.sessionState.conf.adaptiveExecutionEnabled) {
val executedPlan =
joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
Expand All @@ -2638,6 +2659,11 @@ index 3573bafe482..a21767840a2 100644
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case CometColumnarToRowExec(child) =>
+ child.asInstanceOf[CometSortMergeJoinExec].originalPlan match {
+ case s: SortMergeJoinExec => s
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
+ case o => fail(s"expected SortMergeJoinExec, but found\n$o")
+ }
}
Expand Down Expand Up @@ -2665,7 +2691,7 @@ index 3573bafe482..a21767840a2 100644
s"expected sort in the right child to be $sortRight but found\n${joinOperator.right}")

// check the output partitioning
@@ -836,11 +862,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -836,11 +867,11 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
df1.write.format("parquet").bucketBy(8, "i").saveAsTable("bucketed_table")

val scanDF = spark.table("bucketed_table").select("j")
Expand All @@ -2679,7 +2705,7 @@ index 3573bafe482..a21767840a2 100644
checkAnswer(aggDF, df1.groupBy("j").agg(max("k")))
}
}
@@ -1029,15 +1055,21 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -1029,15 +1060,21 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
Seq(true, false).foreach { aqeEnabled =>
withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqeEnabled.toString) {
val plan = sql(query).queryExecution.executedPlan
Expand Down
Loading

0 comments on commit 4333dce

Please sign in to comment.