Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 24 additions & 23 deletions sql-plugin/src/main/scala/com/nvidia/spark/FunctionsImpl.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,87 +35,87 @@ class FunctionsImpl extends Functions {
* nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: Function0[Column]): UserDefinedFunction =
sp_udf(DFUDF0(f), LongType)
sp_udf(new DFUDF0(f), LongType)

/**
* Defines a Scala closure of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to
* nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: Function1[Column, Column]): UserDefinedFunction =
sp_udf(DFUDF1(f), LongType)
sp_udf(new DFUDF1(f), LongType)

/**
* Defines a Scala closure of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to
* nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: Function2[Column, Column, Column]): UserDefinedFunction =
sp_udf(DFUDF2(f), LongType)
sp_udf(new DFUDF2(f), LongType)

/**
* Defines a Scala closure of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to
* nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: Function3[Column, Column, Column, Column]): UserDefinedFunction =
sp_udf(DFUDF3(f), LongType)
sp_udf(new DFUDF3(f), LongType)

/**
* Defines a Scala closure of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to
* nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: Function4[Column, Column, Column, Column, Column]): UserDefinedFunction =
sp_udf(DFUDF4(f), LongType)
sp_udf(new DFUDF4(f), LongType)

/**
* Defines a Scala closure of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to
* nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: Function5[Column, Column, Column, Column, Column,
Column]): UserDefinedFunction = sp_udf(DFUDF5(f), LongType)
Column]): UserDefinedFunction = sp_udf(new DFUDF5(f), LongType)

/**
* Defines a Scala closure of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to
* nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: Function6[Column, Column, Column, Column, Column, Column,
Column]): UserDefinedFunction = sp_udf(DFUDF6(f), LongType)
Column]): UserDefinedFunction = sp_udf(new DFUDF6(f), LongType)

/**
* Defines a Scala closure of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to
* nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: Function7[Column, Column, Column, Column, Column, Column,
Column, Column]): UserDefinedFunction = sp_udf(DFUDF7(f), LongType)
Column, Column]): UserDefinedFunction = sp_udf(new DFUDF7(f), LongType)

/**
* Defines a Scala closure of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to
* nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: Function8[Column, Column, Column, Column, Column, Column,
Column, Column, Column]): UserDefinedFunction = sp_udf(DFUDF8(f), LongType)
Column, Column, Column]): UserDefinedFunction = sp_udf(new DFUDF8(f), LongType)

/**
* Defines a Scala closure of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to
* nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: Function9[Column, Column, Column, Column, Column, Column,
Column, Column, Column, Column]): UserDefinedFunction = sp_udf(DFUDF9(f), LongType)
Column, Column, Column, Column]): UserDefinedFunction = sp_udf(new DFUDF9(f), LongType)

/**
* Defines a Scala closure of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to
* nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: Function10[Column, Column, Column, Column, Column, Column,
Column, Column, Column, Column, Column]): UserDefinedFunction = sp_udf(DFUDF10(f), LongType)
Column, Column, Column, Column, Column]): UserDefinedFunction = sp_udf(new DFUDF10(f), LongType)


//////////////////////////////////////////////////////////////////////////////////////////////
Expand All @@ -128,85 +128,86 @@ class FunctionsImpl extends Functions {
* API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: UDF0[Column]): UserDefinedFunction =
sp_udf(JDFUDF0(f), LongType)
sp_udf(new JDFUDF0(f), LongType)

/**
* Defines a Java UDF instance of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to nondeterministic, call the
* API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: UDF1[Column, Column]): UserDefinedFunction =
sp_udf(JDFUDF1(f), LongType)
sp_udf(new JDFUDF1(f), LongType)

/**
* Defines a Java UDF instance of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to nondeterministic, call the
* API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: UDF2[Column, Column, Column]): UserDefinedFunction =
sp_udf(JDFUDF2(f), LongType)
sp_udf(new JDFUDF2(f), LongType)

/**
* Defines a Java UDF instance of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to nondeterministic, call the
* API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: UDF3[Column, Column, Column, Column]): UserDefinedFunction =
sp_udf(JDFUDF3(f), LongType)
sp_udf(new JDFUDF3(f), LongType)

/**
* Defines a Java UDF instance of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to nondeterministic, call the
* API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: UDF4[Column, Column, Column, Column, Column]): UserDefinedFunction =
sp_udf(JDFUDF4(f), LongType)
sp_udf(new JDFUDF4(f), LongType)

/**
* Defines a Java UDF instance of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to nondeterministic, call the
* API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: UDF5[Column, Column, Column, Column, Column,
Column]): UserDefinedFunction = sp_udf(JDFUDF5(f), LongType)
Column]): UserDefinedFunction = sp_udf(new JDFUDF5(f), LongType)

/**
* Defines a Java UDF instance of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to nondeterministic, call the
* API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: UDF6[Column, Column, Column, Column, Column, Column,
Column]): UserDefinedFunction = sp_udf(JDFUDF6(f), LongType)
Column]): UserDefinedFunction = sp_udf(new JDFUDF6(f), LongType)

/**
* Defines a Java UDF instance of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to nondeterministic, call the
* API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: UDF7[Column, Column, Column, Column, Column, Column,
Column, Column]): UserDefinedFunction = sp_udf(JDFUDF7(f), LongType)
Column, Column]): UserDefinedFunction = sp_udf(new JDFUDF7(f), LongType)

/**
* Defines a Java UDF instance of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to nondeterministic, call the
* API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: UDF8[Column, Column, Column, Column, Column, Column,
Column, Column, Column]): UserDefinedFunction = sp_udf(JDFUDF8(f), LongType)
Column, Column, Column]): UserDefinedFunction = sp_udf(new JDFUDF8(f), LongType)

/**
* Defines a Java UDF instance of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to nondeterministic, call the
* API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: UDF9[Column, Column, Column, Column, Column, Column,
Column, Column, Column, Column]): UserDefinedFunction = sp_udf(JDFUDF9(f), LongType)
Column, Column, Column, Column]): UserDefinedFunction = sp_udf(new JDFUDF9(f), LongType)

/**
* Defines a Java UDF instance of Columns as user-defined function (UDF).
* By default the returned UDF is deterministic. To change it to nondeterministic, call the
* API `UserDefinedFunction.asNondeterministic()`.
*/
override def df_udf(f: UDF10[Column, Column, Column, Column, Column, Column,
Column, Column, Column, Column, Column]): UserDefinedFunction = sp_udf(JDFUDF10(f), LongType)
Column, Column, Column, Column, Column]): UserDefinedFunction =
sp_udf(new JDFUDF10(f), LongType)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import com.nvidia.spark.rapids.RmmRapidsRetryIterator.{splitTargetSizeInHalfGpu,
import com.nvidia.spark.rapids.ScalableTaskCompletion.onTaskCompletion

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.{InnerLike, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.rapids.execution.GatherMapsResult
Expand Down Expand Up @@ -148,7 +147,7 @@ abstract class AbstractGpuJoinIterator(
// less from the gatherer, but because the gatherer tracks how much is used, the
// next call to this function will start in the right place.
val estimatedDataSize = (gather.numRowsLeft * gather.realCheapPerRowSizeEstimate).toLong
val targetSizeWrapper = AutoCloseableTargetSize(targetSize, minTargetSize,
val targetSizeWrapper = new AutoCloseableTargetSize(targetSize, minTargetSize,
estimatedDataSize)
gather.checkpoint()
withRetry(targetSizeWrapper, splitTargetSizeInHalfGpu) { attempt =>
Expand Down Expand Up @@ -199,7 +198,7 @@ abstract class SplittableJoinIterator(
targetSize,
sizeEstimateThreshold,
opTime = opTime,
joinTime = joinTime) with Logging {
joinTime = joinTime) with RapidsLocalLog {
// For some join types even if there is no stream data we might output something
private var isInitialJoin = true
// If the join explodes this holds batches from the stream side split into smaller pieces.
Expand Down Expand Up @@ -364,7 +363,7 @@ abstract class SplittableJoinIterator(
case None if joinType == RightOuter && rightData.numCols > 0 =>
// Distinct right outer joins only produce a single gather map since right table rows
// are not rearranged by the join.
MultiJoinGather(leftGatherer, new JoinGathererSameTable(rightData))
new MultiJoinGather(leftGatherer, new JoinGathererSameTable(rightData))
case None =>
// When there isn't a `rightMap` we are in either LeftSemi or LeftAnti joins.
// In these cases, the map and the table are both the left side, and everything in the map
Expand All @@ -383,7 +382,7 @@ abstract class SplittableJoinIterator(
}
val lazyRightMap = LazySpillableGatherMap(right, "right_map")
val rightGatherer = JoinGatherer(lazyRightMap, rightData, rightOutOfBoundsPolicy)
MultiJoinGather(leftGatherer, rightGatherer)
new MultiJoinGather(leftGatherer, rightGatherer)
}
if (gatherer.isDone) {
// Nothing matched...
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,28 +28,6 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
* Wrapper class that specifies how many rows to replicate
* the partition value.
*/
case class PartitionRowData(rowValue: InternalRow, rowNum: Int)

object PartitionRowData {
def from(rowValues: Array[InternalRow], rowNums: Array[Int]): Array[PartitionRowData] = {
rowValues.zip(rowNums).map {
case (rowValue, rowNum) => PartitionRowData(rowValue, rowNum)
}
}

def from(rowValues: Array[InternalRow], rowNums: Array[Long]): Array[PartitionRowData] = {
rowValues.zip(rowNums).map {
case (rowValue, rowNum) =>
require(rowNum <= Integer.MAX_VALUE, s"Row number $rowNum exceeds max value of an integer.")
PartitionRowData(rowValue, rowNum.toInt)
}
}
}

/**
* Class to wrap columnar batch and partition rows data and utility functions to merge them.
*
Expand All @@ -59,10 +37,10 @@ object PartitionRowData {
* rows to replicate the partition value.
* @param partitionSchema Schema of the partitioned data.
*/
case class BatchWithPartitionData(
inputBatch: SpillableColumnarBatch,
partitionedRowsData: Array[PartitionRowData],
partitionSchema: StructType) extends AutoCloseable {
class BatchWithPartitionData(
val inputBatch: SpillableColumnarBatch,
val partitionedRowsData: Array[PartitionRowData],
val partitionSchema: StructType) extends AutoCloseable {

/**
* Merges the partitioned data with the input ColumnarBatch.
Expand Down Expand Up @@ -98,7 +76,9 @@ case class BatchWithPartitionData(
val dataType = field.dataType
// Create an array to hold the individual columns for each partition.
val singlePartCols = partitionedRowsData.safeMap {
case PartitionRowData(valueRow, rowNum) =>
partitionRowData =>
val valueRow = partitionRowData.rowValue
val rowNum = partitionRowData.rowNum
val singleValue = valueRow.get(colIndex, dataType)
withResource(GpuScalar.from(singleValue, dataType)) { singleScalar =>
// Create a column vector from the GPU scalar, associated with the row number.
Expand Down Expand Up @@ -272,14 +252,14 @@ object BatchWithPartitionDataUtils {
// Splitting occurs if for any column, maximum rows we can fit is less than rows in partition.
splitOccurred = maxRows < rowsInPartition
if (splitOccurred) {
currentBatch.append(PartitionRowData(valuesInPartition, maxRows))
currentBatch.append(new PartitionRowData(valuesInPartition, maxRows))
resultBatches.append(currentBatch.toArray)
currentBatch.clear()
java.util.Arrays.fill(sizeOfBatch, 0)
rowsInPartition -= maxRows
} else {
// If there was no split, all rows can fit in current batch.
currentBatch.append(PartitionRowData(valuesInPartition, rowsInPartition))
currentBatch.append(new PartitionRowData(valuesInPartition, rowsInPartition))
val partitionSizes = calculatePartitionSizes(rowsInPartition, valuesInPartition, partSchema)
sizeOfBatch.indices.foreach(i => sizeOfBatch(i) += partitionSizes(i))
}
Expand Down Expand Up @@ -364,7 +344,7 @@ object BatchWithPartitionDataUtils {
// Combine the split GPU ColumnVectors with partition ColumnVectors.
splitColumnarBatches.zip(listOfPartitionedRowsData).map {
case (spillableBatch, partitionedRowsData) =>
BatchWithPartitionData(spillableBatch, partitionedRowsData, partitionSchema)
new BatchWithPartitionData(spillableBatch, partitionedRowsData, partitionSchema)
}
}
}
Expand Down Expand Up @@ -397,9 +377,7 @@ object BatchWithPartitionDataUtils {
listOfPartitionedRowsData: Array[Array[PartitionRowData]]): Seq[Int] = {
// Calculate the row counts for each batch
val rowCountsForEachBatch = listOfPartitionedRowsData.map(partitionData =>
partitionData.map {
case PartitionRowData(_, rowNum) => rowNum
}.sum
partitionData.map(_.rowNum).sum
)
// Calculate split indices using cumulative sum
rowCountsForEachBatch.scanLeft(0)(_ + _).drop(1).dropRight(1)
Expand Down Expand Up @@ -479,13 +457,13 @@ object BatchWithPartitionDataUtils {
if (remainingRows > 0) {
// Add rows to the left partition, up to the remaining rows available
val rowsToAddToLeft = Math.min(partitionRow.rowNum, remainingRows)
leftHalf += partitionRow.copy(rowNum = rowsToAddToLeft)
leftHalf += new PartitionRowData(partitionRow.rowValue, rowsToAddToLeft)
rowsAddedToLeft += rowsToAddToLeft
remainingRows -= rowsToAddToLeft
if (remainingRows <= 0) {
// Add remaining rows to the right partition
val rowsToAddToRight = partitionRow.rowNum - rowsToAddToLeft
rightHalf += partitionRow.copy(rowNum = rowsToAddToRight)
rightHalf += new PartitionRowData(partitionRow.rowValue, rowsToAddToRight)
rowsAddedToRight += rowsToAddToRight
}
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,7 +15,6 @@
*/
package com.nvidia.spark.rapids

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.rapids.GpuFileSourceScanExec
Expand All @@ -28,7 +27,7 @@ import org.apache.spark.sql.rapids.GpuFileSourceScanExec
*
* NOTE: This is postShimPlanRule which should be applied after GpuOverrides.
*/
object BucketJoinTwoSidesPrefetch extends Rule[SparkPlan] {
object BucketJoinTwoSidesPrefetch {

// Traverse through the plan tree and enable IO prefetch for all GpuFileSourceScanExec
// which are directly connected to this join node without any shuffle.
Expand All @@ -44,7 +43,7 @@ object BucketJoinTwoSidesPrefetch extends Rule[SparkPlan] {
}
}

override def apply(plan: SparkPlan): SparkPlan = {
def apply(plan: SparkPlan): SparkPlan = {
// Enable IO prefetch by a mutable operation on target nodes instead of re-generating
// the plan tree. By doing so, it saves a lot of trouble.
if (RapidsConf.BUCKET_JOIN_IO_PREFETCH.get(plan.conf)) {
Expand Down
Loading
Loading