diff --git a/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/AggregateModeInfo.java b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/AggregateModeInfo.java new file mode 100644 index 00000000000..80834a02f5c --- /dev/null +++ b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/AggregateModeInfo.java @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids; + +import java.io.Serializable; +import java.util.Objects; + +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode; +import org.apache.spark.sql.catalyst.expressions.aggregate.Complete$; +import org.apache.spark.sql.catalyst.expressions.aggregate.Final$; +import org.apache.spark.sql.catalyst.expressions.aggregate.Partial$; +import org.apache.spark.sql.catalyst.expressions.aggregate.PartialMerge$; + +import scala.collection.Seq; + +/** + * Information on the aggregation modes being used. 
+ */ +public class AggregateModeInfo implements Serializable { + private static final long serialVersionUID = 1L; + + private final Seq uniqueModes; + private final boolean hasPartialMode; + private final boolean hasPartialMergeMode; + private final boolean hasFinalMode; + private final boolean hasCompleteMode; + + public AggregateModeInfo( + Seq uniqueModes, + boolean hasPartialMode, + boolean hasPartialMergeMode, + boolean hasFinalMode, + boolean hasCompleteMode) { + this.uniqueModes = uniqueModes; + this.hasPartialMode = hasPartialMode; + this.hasPartialMergeMode = hasPartialMergeMode; + this.hasFinalMode = hasFinalMode; + this.hasCompleteMode = hasCompleteMode; + } + + public static AggregateModeInfo from(Seq uniqueModes) { + return new AggregateModeInfo( + uniqueModes, + uniqueModes.contains(Partial$.MODULE$), + uniqueModes.contains(PartialMerge$.MODULE$), + uniqueModes.contains(Final$.MODULE$), + uniqueModes.contains(Complete$.MODULE$)); + } + + public Seq uniqueModes() { + return uniqueModes; + } + + public boolean hasPartialMode() { + return hasPartialMode; + } + + public boolean hasPartialMergeMode() { + return hasPartialMergeMode; + } + + public boolean hasFinalMode() { + return hasFinalMode; + } + + public boolean hasCompleteMode() { + return hasCompleteMode; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof AggregateModeInfo)) { + return false; + } + AggregateModeInfo that = (AggregateModeInfo) other; + return hasPartialMode == that.hasPartialMode + && hasPartialMergeMode == that.hasPartialMergeMode + && hasFinalMode == that.hasFinalMode + && hasCompleteMode == that.hasCompleteMode + && Objects.equals(uniqueModes, that.uniqueModes); + } + + @Override + public int hashCode() { + return Objects.hash( + uniqueModes, hasPartialMode, hasPartialMergeMode, hasFinalMode, hasCompleteMode); + } + + @Override + public String toString() { + return "AggregateModeInfo(" + uniqueModes + "," + 
hasPartialMode + "," + + hasPartialMergeMode + "," + hasFinalMode + "," + hasCompleteMode + ")"; + } +} diff --git a/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/io/async/AsyncMetrics.java b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/io/async/AsyncMetrics.java new file mode 100644 index 00000000000..bea87901510 --- /dev/null +++ b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/io/async/AsyncMetrics.java @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.io.async; + +import java.io.Serializable; +import java.util.Objects; + +/** + * Scheduling and execution timings for an async task. 
+ */ +public class AsyncMetrics implements Serializable { + private static final long serialVersionUID = 1L; + + private final long scheduleTimeMs; + private final long executionTimeMs; + + public AsyncMetrics(long scheduleTimeMs, long executionTimeMs) { + this.scheduleTimeMs = scheduleTimeMs; + this.executionTimeMs = executionTimeMs; + } + + public long scheduleTimeMs() { + return scheduleTimeMs; + } + + public long executionTimeMs() { + return executionTimeMs; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof AsyncMetrics)) { + return false; + } + AsyncMetrics that = (AsyncMetrics) other; + return scheduleTimeMs == that.scheduleTimeMs + && executionTimeMs == that.executionTimeMs; + } + + @Override + public int hashCode() { + return Objects.hash(scheduleTimeMs, executionTimeMs); + } + + @Override + public String toString() { + return "AsyncMetrics(" + scheduleTimeMs + "," + executionTimeMs + ")"; + } +} diff --git a/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/io/async/ThrottlingExecutorStats.java b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/io/async/ThrottlingExecutorStats.java new file mode 100644 index 00000000000..40b86a6ad3e --- /dev/null +++ b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/io/async/ThrottlingExecutorStats.java @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Mutable throttling counters updated by ThrottlingExecutor.
 *
 * <p>All four counters are public, non-final fields. {@link #equals(Object)} and
 * {@link #hashCode()} read their current values, so the hash of an instance changes whenever a
 * counter is modified.
 */
public class ThrottlingExecutorStats implements Serializable {
  private static final long serialVersionUID = 1L;

  public int numTasksScheduled;
  public long accumulatedThrottleTimeNs;
  public long minThrottleTimeNs;
  public long maxThrottleTimeNs;

  /**
   * Creates a counter set with the given starting values; each argument is stored in the field
   * of the same name.
   */
  public ThrottlingExecutorStats(
      int numTasksScheduled,
      long accumulatedThrottleTimeNs,
      long minThrottleTimeNs,
      long maxThrottleTimeNs) {
    this.maxThrottleTimeNs = maxThrottleTimeNs;
    this.minThrottleTimeNs = minThrottleTimeNs;
    this.accumulatedThrottleTimeNs = accumulatedThrottleTimeNs;
    this.numTasksScheduled = numTasksScheduled;
  }

  /** Equal when the other object is a {@code ThrottlingExecutorStats} with the same counters. */
  @Override
  public boolean equals(Object other) {
    if (other == this) {
      return true;
    }
    if (other instanceof ThrottlingExecutorStats) {
      ThrottlingExecutorStats rhs = (ThrottlingExecutorStats) other;
      boolean sameCount = numTasksScheduled == rhs.numTasksScheduled;
      boolean sameTotal = accumulatedThrottleTimeNs == rhs.accumulatedThrottleTimeNs;
      boolean sameBounds = minThrottleTimeNs == rhs.minThrottleTimeNs
          && maxThrottleTimeNs == rhs.maxThrottleTimeNs;
      return sameCount && sameTotal && sameBounds;
    }
    return false;
  }

  @Override
  public int hashCode() {
    // Same 31-based fold that Objects.hash performs over the four counters, in field order.
    int result = 1;
    result = 31 * result + Integer.hashCode(numTasksScheduled);
    result = 31 * result + Long.hashCode(accumulatedThrottleTimeNs);
    result = 31 * result + Long.hashCode(minThrottleTimeNs);
    result = 31 * result + Long.hashCode(maxThrottleTimeNs);
    return result;
  }

  /** Renders as {@code ThrottlingExecutorStats(numTasks,accumulated,min,max)}. */
  @Override
  public String toString() {
    StringBuilder text = new StringBuilder("ThrottlingExecutorStats(");
    text.append(numTasksScheduled).append(",");
    text.append(accumulatedThrottleTimeNs).append(",");
    text.append(minThrottleTimeNs).append(",");
    text.append(maxThrottleTimeNs).append(")");
    return text.toString();
  }
}
CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shuffle; + +import java.util.Objects; + +/** Byte range for a block. */ +public final class BlockRange { + private final T block; + private final long rangeStart; + private final long rangeEnd; + + public BlockRange(T block, long rangeStart, long rangeEnd) { + if (rangeStart >= rangeEnd) { + throw new IllegalArgumentException( + "requirement failed: Instantiated a BlockRange with invalid boundaries: " + + rangeStart + " to " + rangeEnd); + } + this.block = block; + this.rangeStart = rangeStart; + this.rangeEnd = rangeEnd; + } + + public T block() { + return block; + } + + public long rangeStart() { + return rangeStart; + } + + public long rangeEnd() { + return rangeEnd; + } + + public long rangeSize() { + return rangeEnd - rangeStart; + } + + public boolean isComplete() { + return rangeEnd == block.size(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof BlockRange)) { + return false; + } + BlockRange other = (BlockRange) obj; + return rangeStart == other.rangeStart && + rangeEnd == other.rangeEnd && + Objects.equals(block, other.block); + } + + @Override + public int hashCode() { + return Objects.hash(block, rangeStart, rangeEnd); + } + + @Override + public String toString() { + return "BlockRange(" + block + "," + rangeStart + "," + rangeEnd + ")"; + } +} diff --git 
/**
 * Contract for a block-like value that is able to report how many bytes it occupies.
 */
public interface BlockWithSize {
  /**
   * Reports the size of this block.
   *
   * @return the size in bytes
   */
  long size();
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.shuffle; + +import java.util.Objects; + +/** Statistics for a shuffle transaction. */ +public final class TransactionStats { + private final double txTimeMs; + private final long sendSize; + private final long receiveSize; + private final double sendThroughput; + private final double recvThroughput; + + public TransactionStats(double txTimeMs, long sendSize, long receiveSize, + double sendThroughput, double recvThroughput) { + this.txTimeMs = txTimeMs; + this.sendSize = sendSize; + this.receiveSize = receiveSize; + this.sendThroughput = sendThroughput; + this.recvThroughput = recvThroughput; + } + + public double txTimeMs() { + return txTimeMs; + } + + public long sendSize() { + return sendSize; + } + + public long receiveSize() { + return receiveSize; + } + + public double sendThroughput() { + return sendThroughput; + } + + public double recvThroughput() { + return recvThroughput; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof TransactionStats)) { + return false; + } + TransactionStats other = (TransactionStats) obj; + return Double.compare(txTimeMs, other.txTimeMs) == 0 && + sendSize == other.sendSize && + receiveSize == other.receiveSize && + Double.compare(sendThroughput, other.sendThroughput) == 0 && + Double.compare(recvThroughput, other.recvThroughput) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(txTimeMs, sendSize, receiveSize, sendThroughput, recvThroughput); + } + + @Override + public String toString() 
{ + return "TransactionStats(" + txTimeMs + "," + sendSize + "," + receiveSize + "," + + sendThroughput + "," + recvThroughput + ")"; + } +} diff --git a/sql-plugin-columnar/src/main/java/org/apache/spark/sql/rapids/execution/JoinCardinalityStats.java b/sql-plugin-columnar/src/main/java/org/apache/spark/sql/rapids/execution/JoinCardinalityStats.java new file mode 100644 index 00000000000..8cd8a0c4cfc --- /dev/null +++ b/sql-plugin-columnar/src/main/java/org/apache/spark/sql/rapids/execution/JoinCardinalityStats.java @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution; + +import java.io.Serializable; +import java.util.Objects; + +import org.apache.spark.sql.types.DataType; + +import scala.collection.Seq; + +/** Statistics for join cardinality logging to help diagnose performance issues. 
*/ +public final class JoinCardinalityStats implements Serializable { + private static final long serialVersionUID = 0L; + + private final long leftRowCount; + private final long rightRowCount; + private final long leftDistinctCount; + private final long rightDistinctCount; + private final Seq leftNullCounts; + private final Seq rightNullCounts; + private final Seq leftKeyTypes; + private final Seq rightKeyTypes; + + public JoinCardinalityStats( + long leftRowCount, + long rightRowCount, + long leftDistinctCount, + long rightDistinctCount, + Seq leftNullCounts, + Seq rightNullCounts, + Seq leftKeyTypes, + Seq rightKeyTypes) { + this.leftRowCount = leftRowCount; + this.rightRowCount = rightRowCount; + this.leftDistinctCount = leftDistinctCount; + this.rightDistinctCount = rightDistinctCount; + this.leftNullCounts = leftNullCounts; + this.rightNullCounts = rightNullCounts; + this.leftKeyTypes = leftKeyTypes; + this.rightKeyTypes = rightKeyTypes; + } + + public long leftRowCount() { + return leftRowCount; + } + + public long rightRowCount() { + return rightRowCount; + } + + public long leftDistinctCount() { + return leftDistinctCount; + } + + public long rightDistinctCount() { + return rightDistinctCount; + } + + public Seq leftNullCounts() { + return leftNullCounts; + } + + public Seq rightNullCounts() { + return rightNullCounts; + } + + public Seq leftKeyTypes() { + return leftKeyTypes; + } + + public Seq rightKeyTypes() { + return rightKeyTypes; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof JoinCardinalityStats)) { + return false; + } + JoinCardinalityStats that = (JoinCardinalityStats) other; + return leftRowCount == that.leftRowCount + && rightRowCount == that.rightRowCount + && leftDistinctCount == that.leftDistinctCount + && rightDistinctCount == that.rightDistinctCount + && Objects.equals(leftNullCounts, that.leftNullCounts) + && Objects.equals(rightNullCounts, that.rightNullCounts) 
+ && Objects.equals(leftKeyTypes, that.leftKeyTypes) + && Objects.equals(rightKeyTypes, that.rightKeyTypes); + } + + @Override + public int hashCode() { + return Objects.hash( + leftRowCount, + rightRowCount, + leftDistinctCount, + rightDistinctCount, + leftNullCounts, + rightNullCounts, + leftKeyTypes, + rightKeyTypes); + } + + @Override + public String toString() { + return "JoinCardinalityStats(" + leftRowCount + "," + rightRowCount + "," + + leftDistinctCount + "," + rightDistinctCount + "," + leftNullCounts + "," + + rightNullCounts + "," + leftKeyTypes + "," + rightKeyTypes + ")"; + } +} diff --git a/sql-plugin-columnar/src/main/java/org/apache/spark/sql/rapids/execution/JoinOptions.java b/sql-plugin-columnar/src/main/java/org/apache/spark/sql/rapids/execution/JoinOptions.java new file mode 100644 index 00000000000..be0487ecde7 --- /dev/null +++ b/sql-plugin-columnar/src/main/java/org/apache/spark/sql/rapids/execution/JoinOptions.java @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.rapids.execution; + +import java.io.Serializable; +import java.util.Objects; + +import scala.Enumeration.Value; + +/** Options to control join behavior. 
*/ +public final class JoinOptions implements Serializable { + private static final long serialVersionUID = 0L; + + private final Value strategy; + private final Value buildSideSelection; + private final long targetSize; + private final boolean logCardinalityEnabled; + private final double sizeEstimateThreshold; + + public JoinOptions( + Value strategy, + Value buildSideSelection, + long targetSize, + boolean logCardinalityEnabled, + double sizeEstimateThreshold) { + this.strategy = strategy; + this.buildSideSelection = buildSideSelection; + this.targetSize = targetSize; + this.logCardinalityEnabled = logCardinalityEnabled; + this.sizeEstimateThreshold = sizeEstimateThreshold; + } + + public Value strategy() { + return strategy; + } + + public Value buildSideSelection() { + return buildSideSelection; + } + + public long targetSize() { + return targetSize; + } + + public boolean logCardinalityEnabled() { + return logCardinalityEnabled; + } + + public double sizeEstimateThreshold() { + return sizeEstimateThreshold; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof JoinOptions)) { + return false; + } + JoinOptions that = (JoinOptions) other; + return targetSize == that.targetSize + && logCardinalityEnabled == that.logCardinalityEnabled + && Double.compare(that.sizeEstimateThreshold, sizeEstimateThreshold) == 0 + && Objects.equals(strategy, that.strategy) + && Objects.equals(buildSideSelection, that.buildSideSelection); + } + + @Override + public int hashCode() { + return Objects.hash( + strategy, buildSideSelection, targetSize, logCardinalityEnabled, sizeEstimateThreshold); + } + + @Override + public String toString() { + return "JoinOptions(" + strategy + "," + buildSideSelection + "," + targetSize + "," + + logCardinalityEnabled + "," + sizeEstimateThreshold + ")"; + } +} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/WindowedBlockIteratorSuite.scala 
b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/WindowedBlockIteratorSuite.scala index 47d32633062..e8a952b3ef5 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/WindowedBlockIteratorSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/shuffle/WindowedBlockIteratorSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,10 +28,10 @@ class WindowedBlockIteratorSuite extends RapidsShuffleTestHelper { } test ("1-byte+ ranges are allowed, but 0-byte or negative ranges are not") { - assertResult(1)(BlockRange(null, 123, 124).rangeSize()) - assertResult(2)(BlockRange(null, 123, 125).rangeSize()) - assertThrows[IllegalArgumentException](BlockRange(null, 123, 123)) - assertThrows[IllegalArgumentException](BlockRange(null, 123, 122)) + assertResult(1)(new BlockRange(null, 123, 124).rangeSize()) + assertResult(2)(new BlockRange(null, 123, 125).rangeSize()) + assertThrows[IllegalArgumentException](new BlockRange(null, 123, 123)) + assertThrows[IllegalArgumentException](new BlockRange(null, 123, 122)) } test ("0-byte blocks are not allowed") {