Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.errors

import java.lang.reflect.InvocationTargetException

object ConvUtils {
private val queryExecutionErrorsCompanion =
"org.apache.spark.sql.errors.QueryExecutionErrors$"

def overflowInConvError(): Unit = {
try {
val companion = Class.forName(queryExecutionErrorsCompanion).getField("MODULE$").get(null)
val method = companion.getClass.getMethods.find { method =>
method.getName == "overflowInConvError" && method.getParameterCount == 1
}.getOrElse {
throw new UnsupportedOperationException()
}
throw method.invoke(companion, null.asInstanceOf[AnyRef]).asInstanceOf[Throwable]
} catch {
case _: ClassNotFoundException | _: NoSuchFieldException =>
throw new UnsupportedOperationException()
case e: InvocationTargetException =>
throw e.getCause
}
Comment on lines +33 to +38

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Hardcoded null passed to unknown parameter type

The code finds any overflowInConvError overload with exactly one parameter and invokes it with null. If that parameter is a primitive-boxed type, JVM unboxing will throw a NullPointerException inside invoke, masking the intended error. Either matching the exact parameter type in the find predicate or adding a comment documenting the nullable assumption would reduce the ambiguity.

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids.shims

import java.lang.reflect.InvocationTargetException

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.SparkPlan

object SparkSessionUtils {

def sessionFromPlan(plan: SparkPlan): SparkSession = {
invokeNoArg(plan, "session").asInstanceOf[SparkSession]
}

def leafNodeDefaultParallelism(ss: SparkSession): Int = {
invokeNoArg(ss, "leafNodeDefaultParallelism").asInstanceOf[Int]
}

private def invokeNoArg(target: AnyRef, methodName: String): AnyRef = {
try {
target.getClass.getMethod(methodName).invoke(target)
} catch {
case e: InvocationTargetException =>
throw e.getCause
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "333"}
{"spark": "334"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids.ShimDataWritingCommandRule

import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand

object CreateDataSourceTableAsSelectRules {
val dataWriteCmd: ShimDataWritingCommandRule[CreateDataSourceTableAsSelectCommand] =
ShimDataWritingCommandRule[CreateDataSourceTableAsSelectCommand](
"Create table with select command")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2024-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "340"}
{"spark": "341"}
{"spark": "341db"}
{"spark": "350"}
{"spark": "350db143"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

trait SequenceSizeTooLongErrorBuilder {

def getTooLongSequenceErrorString(sequenceSize: Int, functionName: String): String = {
// For these Spark versions, the sequence length and function name
// do not appear in the exception message.
s"Too long sequence found. Should be <= $MAX_ROUNDED_ARRAY_LENGTH"
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Copyright (c) 2021-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims.spark330

import com.nvidia.spark.rapids.SparkShimVersion

object SparkShimServiceProvider {
val VERSION = SparkShimVersion(3, 3, 0)
val VERSIONNAMES = Seq(s"$VERSION")
}

class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider {

override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION

def matchesVersion(version: String): Boolean = {
SparkShimServiceProvider.VERSIONNAMES.contains(version)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "334"}
{"spark": "340"}
{"spark": "341"}
{"spark": "341db"}
{"spark": "342"}
{"spark": "343"}
{"spark": "344"}
{"spark": "350"}
{"spark": "350db143"}
{"spark": "351"}
{"spark": "352"}
{"spark": "353"}
{"spark": "354"}
{"spark": "355"}
{"spark": "356"}
{"spark": "357"}
{"spark": "358"}
{"spark": "400"}
{"spark": "401"}
{"spark": "402"}
{"spark": "411"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids

import org.apache.spark.TaskContext
import org.apache.spark.shuffle.{ShuffleHandle, ShuffleManager, ShuffleReader, ShuffleReadMetricsReporter}

/**
* Shim object to handle version-specific differences in ShuffleManager APIs.
*/
object ShuffleManagerShims {
/**
* Call ShuffleManager.getReader with the appropriate signature for this Spark version.
* This method is overridden in version-specific shims.
*/
def getReader[K, C](
manager: ShuffleManager,
handle: ShuffleHandle,
startMapIndex: Int,
endMapIndex: Int,
startPartition: Int,
endPartition: Int,
context: TaskContext,
metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
manager.getReader(handle, startMapIndex, endMapIndex, startPartition,
endPartition, context, metrics)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "334"}
{"spark": "340"}
{"spark": "341"}
{"spark": "341db"}
{"spark": "342"}
{"spark": "343"}
{"spark": "344"}
{"spark": "350"}
{"spark": "350db143"}
{"spark": "351"}
{"spark": "352"}
{"spark": "353"}
{"spark": "354"}
{"spark": "355"}
{"spark": "356"}
{"spark": "357"}
{"spark": "358"}
{"spark": "400"}
{"spark": "400db173"}
{"spark": "401"}
{"spark": "402"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.internal.io.FileCommitProtocol

/**
* Shim for FileCommitProtocol.newTaskTempFile API.
* In Spark <= 4.0.x, we use the deprecated (ext: String) signature.
* In Spark 4.1.0+, we use the new (spec: FileNameSpec) signature.
*/
object FileCommitProtocolShims {
@scala.annotation.nowarn(
"msg=method newTaskTempFile in class FileCommitProtocol is deprecated"
)
def newTaskTempFile(
committer: FileCommitProtocol,
taskContext: TaskAttemptContext,
dir: Option[String],
ext: String): String = {
committer.newTaskTempFile(taskContext, dir, ext)
}

@scala.annotation.nowarn(
"msg=method newTaskTempFileAbsPath in class FileCommitProtocol is deprecated"
)
def newTaskTempFileAbsPath(
committer: FileCommitProtocol,
taskContext: TaskAttemptContext,
absoluteDir: String,
ext: String): String = {
committer.newTaskTempFileAbsPath(taskContext, absoluteDir, ext)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "331"}
{"spark": "332"}
{"spark": "333"}
{"spark": "334"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims

import org.apache.spark.sql.catalyst.trees.Origin

// Apache Spark 3.3.x carries SPARK-39175 with `Origin.context: String`.
object OriginContextShim {
def queryContext(origin: Origin): String = origin.context
def contextSummary(origin: Origin): String = origin.context
}
Loading