Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
* Copyright (c) 2022-2026, NVIDIA CORPORATION.
*
* This file was derived from CheckDeltaInvariant.scala in the
* Delta Lake project at https://github.com/delta-io/delta.
Expand Down Expand Up @@ -132,8 +132,8 @@ object GpuCheckDeltaInvariant extends Logging {
ExprChecks.projectOnly(
TypeSig.all,
TypeSig.all,
paramCheck = Seq(ParamCheck("input", TypeSig.all, TypeSig.all)),
repeatingParamCheck = Some(RepeatingParamCheck("extra", TypeSig.all, TypeSig.all))
paramCheck = Seq(new ParamCheck("input", TypeSig.all, TypeSig.all)),
repeatingParamCheck = Some(new RepeatingParamCheck("extra", TypeSig.all, TypeSig.all))
),
(c, conf, p, r) => new GpuCheckDeltaInvariantMeta(c, conf, p, r))

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -110,7 +110,7 @@ object AstUtil {
val gpuExpr = expr.convertToGpu()

// Check if we've already processed this expression (for deduplication)
processed.get(GpuExpressionEquals(gpuExpr)) match {
processed.get(new GpuExpressionEquals(gpuExpr)) match {
case Some(replacement) =>
replacement
case None =>
Expand All @@ -135,7 +135,7 @@ object AstUtil {
// Create an AttributeReference explicitly to avoid issues with unresolved aliases
val attributeRef = AttributeReference(alias.name, gpuExpr.dataType,
gpuExpr.nullable, alias.metadata)(alias.exprId, alias.qualifier)
processed.put(GpuExpressionEquals(gpuExpr), attributeRef)
processed.put(new GpuExpressionEquals(gpuExpr), attributeRef)
attributeRef
}
} else {
Expand Down
599 changes: 397 additions & 202 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
* phase by `SparkSessionExtensions.injectPostHocResolutionRule`. As its name suggests, it will
* be applied after the logical plan has been resolved.
*/
case class GpuPostHocResolutionOverrides(spark: SparkSession) extends Rule[LogicalPlan] {
class GpuPostHocResolutionOverrides(val spark: SparkSession)
extends Rule[LogicalPlan] with Serializable {

@transient private val rapidsConf = new RapidsConf(spark.sessionState.conf)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
* Copyright (c) 2021-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,7 +23,6 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.ShimExpression

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UserDefinedExpression}
import org.apache.spark.sql.rapids.execution.TrampolineUtil
Expand Down Expand Up @@ -91,7 +90,13 @@ object GpuUserDefinedFunction {
* and do the processing on CPU.
*/
trait GpuRowBasedUserDefinedFunction extends GpuExpression
with ShimExpression with UserDefinedExpression with Serializable with Logging {
with ShimExpression with UserDefinedExpression with Serializable {

@transient private lazy val log = org.slf4j.LoggerFactory.getLogger(
classOf[GpuRowBasedUserDefinedFunction])

private def logDebug(msg: => String): Unit = if (log.isDebugEnabled) log.debug(msg)

/** name of the UDF function */
val name: String

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -83,14 +83,14 @@ object HashExprChecks {

val murmur3ProjectChecks: ExprChecks = ExprChecks.projectOnly(
TypeSig.INT, TypeSig.INT,
repeatingParamCheck = Some(RepeatingParamCheck(
repeatingParamCheck = Some(new RepeatingParamCheck(
"input",
murmur3InputTypes,
TypeSig.all)))

val xxhash64ProjectChecks: ExprChecks = ExprChecks.projectOnly(
TypeSig.LONG, TypeSig.LONG,
repeatingParamCheck = Some(RepeatingParamCheck(
repeatingParamCheck = Some(new RepeatingParamCheck(
"input",
XxHash64Shims.supportedTypes,
TypeSig.all)))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
* Copyright (c) 2020-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,7 +16,7 @@

package com.nvidia.spark.rapids

import java.nio.{ByteBuffer, ByteOrder}
import java.nio.{Buffer, ByteBuffer, ByteOrder}

import scala.collection.mutable.ArrayBuffer

Expand All @@ -25,7 +25,6 @@ import com.google.flatbuffers.FlatBufferBuilder
import com.nvidia.spark.rapids.Arm.withResource
import com.nvidia.spark.rapids.format._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.storage.ShuffleBlockBatchId
Expand Down Expand Up @@ -117,9 +116,9 @@ object MetaUtils {
packedMeta: ByteBuffer,
numRows: Long): TableMeta = {
val vectorBuffer = fbb.createUnintializedVector(1, packedMeta.remaining(), 1)
packedMeta.mark()
packedMeta.asInstanceOf[Buffer].mark()
vectorBuffer.put(packedMeta)
packedMeta.reset()
packedMeta.asInstanceOf[Buffer].reset()
val packedMetaOffset = fbb.endVector()

TableMeta.startTableMeta(fbb)
Expand Down Expand Up @@ -262,7 +261,7 @@ class DirectByteBufferFactory extends FlatBufferBuilder.ByteBufferFactory {
}
}

object ShuffleMetadata extends Logging{
object ShuffleMetadata {

val bbFactory = new DirectByteBufferFactory

Expand Down
Loading