Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import java.util.regex.Pattern

import com.nvidia.spark.rapids.Arm.withResource

import org.apache.spark.internal.Logging

/**
* Memory allocation kind for retry coverage tracking.
Expand Down Expand Up @@ -62,7 +61,17 @@ object AllocationKind extends Enumeration {
*
* See: https://github.com/NVIDIA/spark-rapids/issues/13672
*/
object AllocationRetryCoverageTracker extends Logging {
object AllocationRetryCoverageTracker {
private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$"))

/**
 * Logs `msg` at WARN level.
 *
 * `msg` is a by-name parameter, so it is only evaluated (and any string
 * interpolation only performed) when WARN logging is actually enabled.
 *
 * @param msg lazily evaluated message to log
 */
private def logWarning(msg: => String): Unit = {
  // Check the level first; passing msg straight to log.warn would force it unconditionally.
  if (log.isWarnEnabled) {
    log.warn(msg)
  }
}

/**
 * Logs `msg` together with `throwable` at ERROR level.
 *
 * `msg` is a by-name parameter, so it is only evaluated when ERROR logging
 * is actually enabled.
 *
 * @param msg lazily evaluated message to log
 * @param throwable the exception to attach to the log record
 */
private def logError(msg: => String, throwable: Throwable): Unit = {
  // Check the level first; passing msg straight to log.error would force it unconditionally.
  if (log.isErrorEnabled) {
    log.error(msg, throwable)
  }
}

import AllocationKind._

// Environment variable to enable retry coverage tracking (debug-only).
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025-2025, NVIDIA CORPORATION.
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,7 @@ package com.nvidia.spark.rapids
import org.apache.spark.sql.catalyst.expressions.{ArrayDistinct, Expression}
import org.apache.spark.sql.rapids.GpuArrayDistinct

case class GpuArrayDistinctMeta(
class GpuArrayDistinctMeta(
expr: ArrayDistinct,
override val conf: RapidsConf,
parentMetaOpt: Option[RapidsMeta[_, _, _]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@

package com.nvidia.spark.rapids

import java.time.LocalDate
import java.time.{Instant, LocalDate}

import scala.collection.mutable.ListBuffer

import ai.rapids.cudf.{DType, Scalar}
import com.nvidia.spark.rapids.VersionUtils.isSpark320OrLater
import com.nvidia.spark.rapids.shims.DateTimeUtilsShims

import org.apache.spark.sql.catalyst.util.DateTimeUtils.localDateToDays
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -53,6 +52,11 @@ object DateUtils {

val ONE_SECOND_MICROSECONDS = 1000000

private def currentTimestampMicros: Long = {
val instant = Instant.now()
instant.getEpochSecond * ONE_SECOND_MICROSECONDS + instant.getNano / 1000
Comment on lines 53 to +57

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Plain arithmetic instead of overflow-safe multiply

Spark's own `instantToMicros` uses `Math.multiplyExact` and `Math.addExact` to guard against overflow. The replacement here uses plain `*` and `+`. For timestamps within a practical range the result is identical, but if `Instant.now()` ever returned a pathological instant, this version would silently return a wrong (wrapped) value, whereas the shim it replaces would have thrown. Worth noting, since the shim previously called through to Spark's overflow-checked path.

}

val ONE_DAY_SECONDS = 86400L

val ONE_DAY_MICROSECONDS = 86400000000L
Expand Down Expand Up @@ -80,7 +84,7 @@ object DateUtils {
Map.empty
} else {
val today = currentDate()
val now = DateTimeUtilsShims.currentTimestamp
val now = currentTimestampMicros
Map(
EPOCH -> 0,
NOW -> now / 1000000L,
Expand All @@ -94,7 +98,7 @@ object DateUtils {
Map.empty
} else {
val today = currentDate()
val now = DateTimeUtilsShims.currentTimestamp
val now = currentTimestampMicros
Map(
EPOCH -> 0,
NOW -> now,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2025, NVIDIA CORPORATION.
* Copyright (c) 2020-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -18,14 +18,15 @@ package com.nvidia.spark.rapids

import ai.rapids.cudf.{ColumnVector, DType, Scalar}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.shims.ShimPredicate

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Predicate}
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DoubleType, FloatType}

case class GpuInSet(
child: Expression,
list: Seq[Any]) extends GpuUnaryExpression with Predicate {
list: Seq[Any]) extends GpuUnaryExpression with ShimPredicate {
require(list != null, "list should not be null")

@transient private[this] lazy val hasNull: Boolean = list.contains(null)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
* Copyright (c) 2021-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -177,7 +177,7 @@ object GpuMapUtils {

}

case class GpuMapFromArraysMeta(expr: MapFromArrays,
class GpuMapFromArraysMeta(expr: MapFromArrays,
override val conf: RapidsConf,
override val parent: Option[RapidsMeta[_, _, _]],
rule: DataFromReplacementRule)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import scala.collection.immutable.TreeMap
import com.nvidia.spark.rapids.metrics.GpuBubbleTimerManager

import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
Expand Down Expand Up @@ -86,7 +85,7 @@ class GpuMetricFactory(metricsConf: MetricsLevel, context: SparkContext) {
createInternal(level, SQLMetrics.createTimingMetric(context, name))
}

object GpuMetric extends Logging {
object GpuMetric {
// Metric names.
val BUFFER_TIME = "bufferTime"
val BUFFER_TIME_BUBBLE = "bufferTimeBubble"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -25,7 +25,6 @@ import scala.util.{Failure, Success, Try}
import com.nvidia.spark.rapids.Arm.withResource

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging

trait MemoryChecker {
def getAvailableMemoryBytes(rapidsConf: RapidsConf): Option[Long]
Expand All @@ -38,7 +37,19 @@ trait MemoryChecker {
* on which it checks corresponding files, env variables, etc. for memory usage
* and limits.
*/
object MemoryCheckerImpl extends MemoryChecker with Logging {
object MemoryCheckerImpl extends MemoryChecker {
private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$"))

/**
 * Logs `msg` at INFO level. The by-name message is only evaluated when
 * INFO logging is enabled.
 */
private def logInfo(msg: => String): Unit =
  if (log.isInfoEnabled) log.info(msg)

/**
 * Logs `msg` at WARN level.
 *
 * `msg` is a by-name parameter, so it is only evaluated when WARN logging
 * is actually enabled.
 *
 * @param msg lazily evaluated message to log
 */
private def logWarning(msg: => String): Unit = {
  // Check the level first; passing msg straight to log.warn would force it unconditionally.
  if (log.isWarnEnabled) {
    log.warn(msg)
  }
}

def main(args: Array[String]): Unit = {
val conf = new RapidsConf(new SparkConf())
println(s"Available memory: ${getAvailableMemoryBytes(conf)} bytes")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,18 +22,17 @@ import scala.collection.mutable

import ai.rapids.cudf.{NvtxColor, NvtxRange}

import org.apache.spark.internal.Logging

object RangeDebugger extends Logging {
object RangeDebugger {
private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$"))
val threadLocalStack = new ThreadLocal[mutable.ArrayStack[NvtxId]] {
override def initialValue(): mutable.ArrayStack[NvtxId] = mutable.ArrayStack[NvtxId]()
}

private def dumpOrderErrorMessage(popped: Option[NvtxId], elem: NvtxId): Unit = {
logError(s"OUT OF ORDER POP of $elem")
logError(s"TOP OF STACK IS ${popped.getOrElse("<nil>")}")
log.error(s"OUT OF ORDER POP of $elem")
log.error(s"TOP OF STACK IS ${popped.getOrElse("<nil>")}")
val stackTrace = Thread.currentThread.getStackTrace
stackTrace.foreach(elem => logError(elem.toString))
stackTrace.foreach(elem => log.error(elem.toString))
}

def push(elem: NvtxId): Unit = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
* Copyright (c) 2019-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -76,11 +76,11 @@ object NvtxIdWithMetrics {
}
}

class MetricRange(val metrics: Seq[GpuMetric], val excludeMetric: Seq[GpuMetric] = Seq.empty)
class MetricRange(val metrics: Seq[GpuMetric], val excludeMetric: Seq[GpuMetric])
extends AutoCloseable {

// add a convenient constructor
def this(metrics: GpuMetric*) = this(metrics.toSeq)
def this(metrics: GpuMetric*) = this(metrics.toSeq, Seq.empty)

val needTracks = metrics.map(_.tryActivateTimer(excludeMetric))
private val start = System.nanoTime()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
* Copyright (c) 2024-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,11 +33,11 @@ class PrioritySemaphore[T](val maxPermits: Long, val maxConcurrentGpuTasksLimit:
private var occupiedSlots: Long = 0
private var currentConcurrentGpuTasksNum: Long = 0

private case class ThreadInfo(priority: T,
condition: Condition,
computeNumPermits: () => Long,
wasOnGpuBefore: () => Boolean,
taskId: Long) {
private class ThreadInfo(val priority: T,
val condition: Condition,
val computeNumPermits: () => Long,
val wasOnGpuBefore: () => Boolean,
val taskId: Long) {
var signaled: Boolean = false
var permitsUsed: Long = 0
}
Expand All @@ -60,7 +60,7 @@ class PrioritySemaphore[T](val maxPermits: Long, val maxConcurrentGpuTasksLimit:
if (waitingQueue.size() > 0 &&
priorityComp.compare(
waitingQueue.peek(),
ThreadInfo(priority, null, () => numPermits, wasOnGpuBefore, taskAttemptId)
new ThreadInfo(priority, null, () => numPermits, wasOnGpuBefore, taskAttemptId)
) < 0) {
false
} else if (!canAcquire(numPermits)) {
Expand All @@ -81,7 +81,8 @@ class PrioritySemaphore[T](val maxPermits: Long, val maxConcurrentGpuTasksLimit:
val numPermitsNow = computePermits()
if (!tryAcquire(numPermitsNow, priority, wasOnGpuBefore, taskAttemptId)) {
val condition = lock.newCondition()
val info = ThreadInfo(priority, condition, computePermits, wasOnGpuBefore, taskAttemptId)
val info = new ThreadInfo(
priority, condition, computePermits, wasOnGpuBefore, taskAttemptId)
try {
waitingQueue.add(info)
// only count tasks that had held semaphore before,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import org.apache.hadoop.fs.Path

import org.apache.spark.TaskContext
import org.apache.spark.api.plugin.PluginContext
import org.apache.spark.internal.Logging
import org.apache.spark.util.SerializableConfiguration

/**
Expand All @@ -55,7 +54,9 @@ import org.apache.spark.util.SerializableConfiguration
*
*/

object AsyncProfilerOnExecutor extends Logging {
object AsyncProfilerOnExecutor {

private val log = org.slf4j.LoggerFactory.getLogger(AsyncProfilerOnExecutor.getClass)

private var asyncProfilerPrefix: Option[String] = None
private var asyncProfiler: Option[AsyncProfiler] = None
Expand Down Expand Up @@ -347,7 +348,7 @@ object AsyncProfilerOnExecutor extends Logging {
val outPath = new Path(asyncProfilerPrefix.get,
if (jfrCompressionEnabled) baseFileName + ".gz" else baseFileName)

val hadoopConf = pluginCtx.ask(ProfileInitMsg(executorId, outPath.toString))
val hadoopConf = pluginCtx.ask(new ProfileInitMsg(executorId, outPath.toString))
.asInstanceOf[SerializableConfiguration].value
val fs = outPath.getFileSystem(hadoopConf)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import com.nvidia.spark.rapids.shims.ShimExpression
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions.{Add, And, ArrayAggregate, Attribute, AttributeReference, AttributeSeq, CaseWhen, Cast, Expression, ExprId, Greatest, If, LambdaFunction, Least, Literal, Multiply, NamedExpression, NamedLambdaVariable, Or}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.rapids.GpuMapDedupPolicy
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, NumericType, ShortType, StructField, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch

Expand Down Expand Up @@ -527,8 +528,7 @@ case class GpuTransformKeys(
override def prettyName: String = "transform_keys"

// Spark 4.1+ returns an enum value instead of String, so use toString first
private def exceptionOnDupKeys =
SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY).toString.toUpperCase == "EXCEPTION"
private def exceptionOnDupKeys = GpuMapDedupPolicy.isException

override lazy val hasSideEffects: Boolean =
function.nullable || exceptionOnDupKeys || super.hasSideEffects
Expand Down Expand Up @@ -1140,13 +1140,13 @@ case object AnyOp extends AggOp {
* @param accVarExprId the accumulator NamedLambdaVariable's exprId
* @param elemVar the element NamedLambdaVariable (used to build the g lambda)
*/
case class ArrayAggregateDecomposition(
op: AggOp,
g: Expression,
accVarExprId: ExprId,
elemVar: NamedLambdaVariable)
class ArrayAggregateDecomposition(
val op: AggOp,
val g: Expression,
val accVarExprId: ExprId,
val elemVar: NamedLambdaVariable) extends Serializable

private case class ExtractedG(g: Expression, hasBareAccBranch: Boolean)
private class ExtractedG(val g: Expression, val hasBareAccBranch: Boolean) extends Serializable


/**
Expand Down Expand Up @@ -1223,7 +1223,7 @@ object ArrayAggregateDecomposer {
"that no-contribution branch into an identity value")
}

Right(ArrayAggregateDecomposition(op, g, accId, elemVar))
Right(new ArrayAggregateDecomposition(op, g, accId, elemVar))
}

/**
Expand All @@ -1246,8 +1246,8 @@ object ArrayAggregateDecomposer {
accId: ExprId,
op: AggOp): Option[ExtractedG] = {
op.matchBinary(unwrapDecimalPatternWrappers(e)).flatMap { case (l, r) =>
if (isAccRef(l, accId) && !containsAccRef(r, accId)) Some(ExtractedG(r, false))
else if (isAccRef(r, accId) && !containsAccRef(l, accId)) Some(ExtractedG(l, false))
if (isAccRef(l, accId) && !containsAccRef(r, accId)) Some(new ExtractedG(r, false))
else if (isAccRef(r, accId) && !containsAccRef(l, accId)) Some(new ExtractedG(l, false))
else None
}
}
Expand All @@ -1264,7 +1264,7 @@ object ArrayAggregateDecomposer {
op: AggOp,
accType: DataType): Option[ExtractedG] = {
if (isAccRef(branch, accId)) {
Some(ExtractedG(op.identityLiteral(accType), true))
Some(new ExtractedG(op.identityLiteral(accType), true))
} else {
extractG(branch, accId, op, accType)
}
Expand All @@ -1279,7 +1279,7 @@ object ArrayAggregateDecomposer {
for {
tG <- extractBranch(t, accId, op, accType)
fG <- extractBranch(f, accId, op, accType)
} yield ExtractedG(If(cond, tG.g, fG.g), tG.hasBareAccBranch || fG.hasBareAccBranch)
} yield new ExtractedG(If(cond, tG.g, fG.g), tG.hasBareAccBranch || fG.hasBareAccBranch)

case CaseWhen(branches, Some(elseValue))
if branches.forall { case (c, _) => !containsAccRef(c, accId) } =>
Expand All @@ -1292,7 +1292,7 @@ object ArrayAggregateDecomposer {
val gBranches = branchDecs.map { case (c, dec) => (c, dec.get.g) }
val hasBareAccBranch = branchDecs.exists(_._2.exists(_.hasBareAccBranch)) ||
elseDec.exists(_.hasBareAccBranch)
Some(ExtractedG(CaseWhen(gBranches, Some(elseDec.get.g)), hasBareAccBranch))
Some(new ExtractedG(CaseWhen(gBranches, Some(elseDec.get.g)), hasBareAccBranch))
}

case _ => None
Expand Down
Loading
Loading