diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/CudfUnsafeRow.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/CudfUnsafeRow.scala index ddc8f77b162..d6fd8f120b8 100644 --- a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/CudfUnsafeRow.scala +++ b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/CudfUnsafeRow.scala @@ -40,4 +40,9 @@ final class CudfUnsafeRow( } } +// Keep companion line metadata aligned with pre-Spark-4 shims for binary-dedupe. + + + + object CudfUnsafeRow extends CudfUnsafeRowTrait diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/DateTimeUtilsShims.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/DateTimeUtilsShims.scala deleted file mode 100644 index 21254c4b39a..00000000000 --- a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/DateTimeUtilsShims.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2024-2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -/*** spark-rapids-shim-json-lines -{"spark": "400"} -{"spark": "400db173"} -{"spark": "401"} -{"spark": "402"} -{"spark": "411"} -spark-rapids-shim-json-lines ***/ -package com.nvidia.spark.rapids.shims - -import org.apache.spark.sql.catalyst.util.SparkDateTimeUtils - -object DateTimeUtilsShims { - def currentTimestamp: Long = SparkDateTimeUtils.instantToMicros(java.time.Instant.now()) -} \ No newline at end of file diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/GpuOrcDataReader.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/GpuOrcDataReader.scala index 6bd9f9c99ae..1616a6b39b9 100644 --- a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/GpuOrcDataReader.scala +++ b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/GpuOrcDataReader.scala @@ -36,6 +36,19 @@ class GpuOrcDataReader( } +// Keep executable line numbers aligned with pre-Spark-4 shims for binary-dedupe. + + + + + + + + + + + + object GpuOrcDataReader { // File cache is being used, so we want read ranges that can be cached separately val shouldMergeDiskRanges: Boolean = false diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/LogicalPlanShims.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/LogicalPlanShims.scala index f31ca918539..0eaa55f0c6c 100644 --- a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/LogicalPlanShims.scala +++ b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/LogicalPlanShims.scala @@ -26,6 +26,18 @@ package com.nvidia.spark.rapids.shims import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelationWithTable} +// Keep companion line metadata aligned with pre-Spark-4 shims for binary-dedupe. + + + + + + + + + + + object LogicalPlanShims { def getLocations(plan: LogicalPlan): Seq[FileIndex] = plan.collect { case LogicalRelationWithTable(t: HadoopFsRelation, _) => t.location diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/NullIntolerantShim.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/NullIntolerantShim.scala index 842846d6aae..bd6c0e791c0 100644 --- a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/NullIntolerantShim.scala +++ b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/NullIntolerantShim.scala @@ -28,3 +28,24 @@ import org.apache.spark.sql.catalyst.expressions.Expression trait NullIntolerantShim extends Expression { override def nullIntolerant: Boolean = true } + +abstract class GpuLiteralShim extends com.nvidia.spark.rapids.GpuLeafExpression { + def value: Any + def dataType: org.apache.spark.sql.types.DataType + + override protected def jsonFields: List[org.json4s.JsonAST.JField] = { + val jsonValue = (value, dataType) match { + case (null, _) => org.json4s.JsonAST.JNull + case (i: Int, org.apache.spark.sql.types.DateType) => + org.json4s.JsonAST.JString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.toJavaDate(i).toString) + case (l: Long, org.apache.spark.sql.types.TimestampType) => + org.json4s.JsonAST.JString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.toJavaTimestamp(l).toString) + case (other, _) => org.json4s.JsonAST.JString(other.toString) + } + ("value" -> jsonValue) :: + ("dataType" -> org.apache.spark.sql.rapids.execution.TrampolineUtil.jsonValue(dataType) + .asInstanceOf[org.json4s.JsonAST.JValue]) :: Nil + } +} diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/Spark400PlusCommonShims.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/Spark400PlusCommonShims.scala index d3f63a943a1..72a8bbc2f38 100644 --- a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/Spark400PlusCommonShims.scala +++ b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/Spark400PlusCommonShims.scala @@ -40,7 +40,7 @@ trait Spark400PlusCommonShims extends Spark350PlusNonDBShims { "And(GreaterThanOrEqual(ref, lower), LessThanOrEqual(ref, upper); StructToJson is " + "replaced by Invoke(Literal(StructsToJsonEvaluator), evaluate, string_type, arguments)", InvokeCheck, - InvokeExprMeta) + (invoke, conf, p, r) => new InvokeExprMeta(invoke, conf, p, r)) .note("The supported types are not deterministic since it's a dynamic expression") ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap super.getExprs ++ shimExprs diff --git a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/spark400/SparkShimServiceProvider.scala b/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/spark400/SparkShimServiceProvider.scala deleted file mode 100644 index 95525c673f7..00000000000 --- a/sql-plugin/src/main/spark400/scala/com/nvidia/spark/rapids/shims/spark400/SparkShimServiceProvider.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2024-2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "400"} -{"spark": "400db173"} -spark-rapids-shim-json-lines ***/ -package com.nvidia.spark.rapids.shims.spark400 - -import com.nvidia.spark.rapids.SparkShimVersion - -object SparkShimServiceProvider { - val VERSION = SparkShimVersion(4, 0, 0) - val VERSIONNAMES = Seq(s"$VERSION") -} - -class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { - - override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION - - override def matchesVersion(version: String): Boolean = { - SparkShimServiceProvider.VERSIONNAMES.contains(version) - } -} diff --git a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/InvokeExprMeta.scala b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/InvokeExprMeta.scala index d66e879667a..4cb2c60811d 100644 --- a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/InvokeExprMeta.scala +++ b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/InvokeExprMeta.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator import org.apache.spark.sql.rapids.{GpuParseUrl, GpuStructsToJson} import org.apache.spark.sql.types._ -case class InvokeExprMeta( +class InvokeExprMeta( invoke: Invoke, override val conf: RapidsConf, p: Option[RapidsMeta[_, _, _]], diff --git a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/OriginContextShim.scala b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/OriginContextShim.scala deleted file mode 100644 index 7f0892adf61..00000000000 --- a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/OriginContextShim.scala +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "400"} -{"spark": "400db173"} -{"spark": "401"} -{"spark": "402"} -{"spark": "411"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.sql.rapids.shims - -import org.apache.spark.sql.catalyst.trees.{Origin, SQLQueryContext} - -// Spark 4.0 widened `Origin.context` to `org.apache.spark.QueryContext`, while -// `QueryExecutionErrors` still takes the `SQLQueryContext` subtype — narrow here. -object OriginContextShim { - def queryContext(origin: Origin): SQLQueryContext = origin.context match { - case ctx: SQLQueryContext => ctx - case _ => null - } - def contextSummary(origin: Origin): String = origin.context match { - case null => "" - case ctx => ctx.summary - } -} diff --git a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SparkSessionUtils.scala b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SparkSessionUtils.scala deleted file mode 100644 index 6de8f1d6165..00000000000 --- a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/SparkSessionUtils.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2025-2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "400"} -{"spark": "400db173"} -{"spark": "401"} -{"spark": "402"} -{"spark": "411"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.sql.rapids.shims - -import org.apache.spark.sql.classic.SparkSession -import org.apache.spark.sql.execution.SparkPlan - -object SparkSessionUtils { - def sessionFromPlan(plan: SparkPlan): SparkSession = { - plan.session - } - - def leafNodeDefaultParallelism(ss: SparkSession): Int = { - ss.leafNodeDefaultParallelism - } -} diff --git a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/TrampolineConnectShims.scala b/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/TrampolineConnectShims.scala deleted file mode 100644 index 9a753370800..00000000000 --- a/sql-plugin/src/main/spark400/scala/org/apache/spark/sql/rapids/shims/TrampolineConnectShims.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2025-2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -/*** spark-rapids-shim-json-lines -{"spark": "400"} -{"spark": "400db173"} -{"spark": "401"} -{"spark": "402"} -{"spark": "411"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.sql.rapids.shims - -import org.apache.avro.NameValidator -import org.apache.avro.Schema - -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -object TrampolineConnectShims { - - type SparkSession = org.apache.spark.sql.classic.SparkSession - type DataFrame = org.apache.spark.sql.classic.DataFrame - type Dataset = org.apache.spark.sql.classic.Dataset[org.apache.spark.sql.Row] - - def cleanupAnyExistingSession(): Unit = { - org.apache.spark.sql.classic.SparkSession.cleanupAnyExistingSession() - } - - def createDataFrame(spark: SparkSession, plan: LogicalPlan): DataFrame = { - org.apache.spark.sql.classic.Dataset.ofRows(spark, plan) - } - - def getBuilder(): org.apache.spark.sql.classic.SparkSession.Builder = { - org.apache.spark.sql.classic.SparkSession.builder() - } - - def hasActiveSession: Boolean = { - org.apache.spark.sql.classic.SparkSession.getActiveSession.isDefined - } - - def getActiveSession: SparkSession = { - org.apache.spark.sql.classic.SparkSession.getActiveSession.getOrElse( - throw new IllegalStateException("No active SparkSession found") - ) - } - - def createSchemaParser(): Schema.Parser = { - // Spark-4.0+ depends on Avro-1.12.0 where validate() is removed and we need to use - // NameValidator interface instead of validate() method. - new Schema.Parser(NameValidator.NO_VALIDATION).setValidateDefaults(false) - } -} diff --git a/sql-plugin/src/main/spark400db173/scala/com/nvidia/spark/rapids/shims/Spark400PlusDBShims.scala b/sql-plugin/src/main/spark400db173/scala/com/nvidia/spark/rapids/shims/Spark400PlusDBShims.scala index 2b443be563c..22c37269c9b 100644 --- a/sql-plugin/src/main/spark400db173/scala/com/nvidia/spark/rapids/shims/Spark400PlusDBShims.scala +++ b/sql-plugin/src/main/spark400db173/scala/com/nvidia/spark/rapids/shims/Spark400PlusDBShims.scala @@ -35,7 +35,7 @@ trait Spark400PlusDBShims extends Spark341PlusDBShims { "And(GreaterThanOrEqual(ref, lower), LessThanOrEqual(ref, upper); StructToJson is " + "replaced by Invoke(Literal(StructsToJsonEvaluator), evaluate, string_type, arguments)", InvokeCheck, - InvokeExprMeta) + (invoke, conf, p, r) => new InvokeExprMeta(invoke, conf, p, r)) .note("The supported types are not deterministic since it's a dynamic expression") ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap super.getExprs ++ shimExprs diff --git a/sql-plugin/src/main/spark400db173/scala/com/nvidia/spark/rapids/shims/TimeAddShims.scala b/sql-plugin/src/main/spark400db173/scala/com/nvidia/spark/rapids/shims/TimeAddShims.scala index 83d9298bfa2..5163eb14680 100644 --- a/sql-plugin/src/main/spark400db173/scala/com/nvidia/spark/rapids/shims/TimeAddShims.scala +++ b/sql-plugin/src/main/spark400db173/scala/com/nvidia/spark/rapids/shims/TimeAddShims.scala @@ -22,12 +22,40 @@ package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Expression, TimestampAddInterval} +import org.apache.spark.sql.rapids.shims.GpuTimestampAddInterval +import org.apache.spark.sql.types.{CalendarIntervalType, DayTimeIntervalType} +import org.apache.spark.unsafe.types.CalendarInterval /** - * Empty TimeAddShims for Spark 4.1.0+ and Databricks 17.3. - * TimeAdd was renamed to TimestampAddInterval and is handled by DayTimeIntervalShims. + * TimestampAddInterval support for Spark 4.1.0+ and Databricks 17.3. + * TimeAdd was renamed to TimestampAddInterval in Spark 4.1. */ object TimeAddShims { - val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Map.empty + val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq( + GpuOverrides.expr[TimestampAddInterval]( + "Adds interval to timestamp", + ExprChecks.binaryProject(TypeSig.TIMESTAMP, TypeSig.TIMESTAMP, + ("start", TypeSig.TIMESTAMP, TypeSig.TIMESTAMP), + ("interval", TypeSig.DAYTIME + TypeSig.lit(TypeEnum.CALENDAR) + .withPsNote(TypeEnum.CALENDAR, "month intervals are not supported"), + TypeSig.DAYTIME + TypeSig.CALENDAR)), + (timeAdd, conf, p, r) => new BinaryExprMeta[TimestampAddInterval](timeAdd, conf, p, r) { + override def tagExprForGpu(): Unit = { + GpuOverrides.extractLit(timeAdd.interval).foreach { lit => + lit.dataType match { + case CalendarIntervalType => + val intvl = lit.value.asInstanceOf[CalendarInterval] + if (intvl.months != 0) { + willNotWorkOnGpu("interval months isn't supported") + } + case _: DayTimeIntervalType => + } + } + } + + override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression = + GpuTimestampAddInterval(lhs, rhs) + }) + ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap } diff --git a/sql-plugin/src/main/spark400db173/scala/com/nvidia/spark/rapids/shims/spark400db173/SparkShimServiceProvider.scala b/sql-plugin/src/main/spark400db173/scala/com/nvidia/spark/rapids/shims/spark400db173/SparkShimServiceProvider.scala deleted file mode 100644 index 8f793d05165..00000000000 --- a/sql-plugin/src/main/spark400db173/scala/com/nvidia/spark/rapids/shims/spark400db173/SparkShimServiceProvider.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "400db173"} -spark-rapids-shim-json-lines ***/ -package com.nvidia.spark.rapids.shims.spark400db173 - -import com.nvidia.spark.rapids._ - -import org.apache.spark.SparkEnv - -object SparkShimServiceProvider { - // DB version should conform to "major.minor" and has no patch version. - // Refer to VersionUtils.getVersionForJni - val VERSION = DatabricksShimVersion(4, 0, 0, "17.3") -} - -class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { - - override def getShimVersion: ShimVersion = SparkShimServiceProvider.VERSION - - def matchesVersion(version: String): Boolean = { - val shimEnabledProp = "spark.rapids.shims.spark400db173" + ".enabled" - val shimEnabled = Option(SparkEnv.get) - .flatMap(_.conf.getOption(shimEnabledProp).map(_.toBoolean)) - .getOrElse(true) - - DatabricksShimServiceProvider.matchesVersion( - dbrVersion = "17.3.x", - shimMatchEnabled = shimEnabled, - disclaimer = "Development of support for Databricks 17.3.x is still in progress: " + - "https://github.com/NVIDIA/spark-rapids/issues/14015" - ) - } -} diff --git a/sql-plugin/src/main/spark400db173/scala/org/apache/spark/sql/rapids/ShuffleManagerShims.scala b/sql-plugin/src/main/spark400db173/scala/org/apache/spark/sql/rapids/ShuffleManagerShims.scala deleted file mode 100644 index 2f4f8b5feda..00000000000 --- a/sql-plugin/src/main/spark400db173/scala/org/apache/spark/sql/rapids/ShuffleManagerShims.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "400db173"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.sql.rapids - -import org.apache.spark.TaskContext -import org.apache.spark.shuffle.{ShuffleHandle, ShuffleManager, ShuffleReader, ShuffleReadMetricsReporter} - -object ShuffleManagerShims { - def getReader[K, C]( - manager: ShuffleManager, - handle: ShuffleHandle, - startMapIndex: Int, - endMapIndex: Int, - startPartition: Int, - endPartition: Int, - context: TaskContext, - metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { - manager.getReader(handle, startMapIndex, endMapIndex, startPartition, - endPartition, context, metrics, false) - } -} diff --git a/sql-plugin/src/main/spark400db173/scala/org/apache/spark/sql/rapids/execution/python/shims/WindowInPandasExecTypeShim.scala b/sql-plugin/src/main/spark400db173/scala/org/apache/spark/sql/rapids/execution/python/shims/WindowInPandasExecTypeShim.scala index 8c5ed60a4c0..d8c84ed87aa 100644 --- a/sql-plugin/src/main/spark400db173/scala/org/apache/spark/sql/rapids/execution/python/shims/WindowInPandasExecTypeShim.scala +++ b/sql-plugin/src/main/spark400db173/scala/org/apache/spark/sql/rapids/execution/python/shims/WindowInPandasExecTypeShim.scala @@ -22,6 +22,26 @@ package org.apache.spark.sql.rapids.execution.python.shims import org.apache.spark.sql.execution.python.ArrowWindowPythonExec +// Keep executable line numbers aligned with pre-Spark-4 shims for binary-dedupe. + + + + + + + + + + + + + + + + + + + /** * WindowInPandasExec was renamed to ArrowWindowPythonExec in Spark 4.1.0. * Use the new class name as the type alias. diff --git a/sql-plugin/src/main/spark400db173/scala/org/apache/spark/storage/ShuffleClientShims.scala b/sql-plugin/src/main/spark400db173/scala/org/apache/spark/storage/ShuffleClientShims.scala deleted file mode 100644 index aedfb30f7e9..00000000000 --- a/sql-plugin/src/main/spark400db173/scala/org/apache/spark/storage/ShuffleClientShims.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "400db173"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.storage - -import org.apache.spark.network.shuffle.BlockStoreClient -import org.apache.spark.network.shuffle.checksum.Cause - -object ShuffleClientShims { - def diagnoseCorruption( - client: BlockStoreClient, - host: String, - port: Int, - execId: String, - blockId: BlockId, - checksum: Long, - algorithm: String): Cause = { - client.diagnoseCorruption(host, port, execId, blockId.name, checksum, algorithm) - } -} - diff --git a/sql-plugin/src/main/spark401/scala/com/nvidia/spark/rapids/shims/spark401/SparkShimServiceProvider.scala b/sql-plugin/src/main/spark401/scala/com/nvidia/spark/rapids/shims/spark401/SparkShimServiceProvider.scala deleted file mode 100644 index 1779ea4f649..00000000000 --- a/sql-plugin/src/main/spark401/scala/com/nvidia/spark/rapids/shims/spark401/SparkShimServiceProvider.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2025-2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "401"} -spark-rapids-shim-json-lines ***/ -package com.nvidia.spark.rapids.shims.spark401 - -import com.nvidia.spark.rapids.SparkShimVersion - -object SparkShimServiceProvider { - val VERSION = SparkShimVersion(4, 0, 1) - val VERSIONNAMES = Seq(s"$VERSION") -} - -class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { - - override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION - - override def matchesVersion(version: String): Boolean = { - SparkShimServiceProvider.VERSIONNAMES.contains(version) - } -} diff --git a/sql-plugin/src/main/spark402/scala/com/nvidia/spark/rapids/shims/spark402/SparkShimServiceProvider.scala b/sql-plugin/src/main/spark402/scala/com/nvidia/spark/rapids/shims/spark402/SparkShimServiceProvider.scala deleted file mode 100644 index 11650cec175..00000000000 --- a/sql-plugin/src/main/spark402/scala/com/nvidia/spark/rapids/shims/spark402/SparkShimServiceProvider.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "402"} -spark-rapids-shim-json-lines ***/ -package com.nvidia.spark.rapids.shims.spark402 - -import com.nvidia.spark.rapids.SparkShimVersion - -object SparkShimServiceProvider { - val VERSION = SparkShimVersion(4, 0, 2) - val VERSIONNAMES = Seq(s"$VERSION") -} - -class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { - - override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION - - override def matchesVersion(version: String): Boolean = { - SparkShimServiceProvider.VERSIONNAMES.contains(version) - } -} diff --git a/sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/spark411/SparkShimServiceProvider.scala b/sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/spark411/SparkShimServiceProvider.scala deleted file mode 100644 index bd5a848206a..00000000000 --- a/sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/spark411/SparkShimServiceProvider.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2025-2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "411"} -spark-rapids-shim-json-lines ***/ -package com.nvidia.spark.rapids.shims.spark411 - -import com.nvidia.spark.rapids.SparkShimVersion - -object SparkShimServiceProvider { - val VERSION = SparkShimVersion(4, 1, 1) - val VERSIONNAMES = Seq(s"$VERSION") -} - -class SparkShimServiceProvider extends com.nvidia.spark.rapids.SparkShimServiceProvider { - - override def getShimVersion: SparkShimVersion = SparkShimServiceProvider.VERSION - - override def matchesVersion(version: String): Boolean = { - SparkShimServiceProvider.VERSIONNAMES.contains(version) - } -} diff --git a/sql-plugin/src/main/spark411/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala b/sql-plugin/src/main/spark411/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala index e51f5bc5b42..9d04faddd95 100644 --- a/sql-plugin/src/main/spark411/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala +++ b/sql-plugin/src/main/spark411/scala/org/apache/spark/sql/rapids/execution/python/shims/GpuGroupedPythonRunnerFactory.scala @@ -34,14 +34,14 @@ import org.apache.spark.sql.vectorized.ColumnarBatch * - Create new Arrow stream for each batch * - Send 0 to indicate end of data */ -case class GpuGroupedPythonRunnerFactory( +class GpuGroupedPythonRunnerFactory( conf: org.apache.spark.sql.internal.SQLConf, chainedFunc: Seq[(ChainedPythonFunctions, Long)], argOffsets: Array[Array[Int]], dedupAttrs: StructType, pythonOutputSchema: StructType, evalType: Int, - argNames: Option[Array[Array[Option[String]]]] = None) { + argNames: Option[Array[Array[Option[String]]]]) extends Serializable { val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf) diff --git a/sql-plugin/src/main/spark411/scala/org/apache/spark/sql/rapids/shims/FileCommitProtocolShims.scala b/sql-plugin/src/main/spark411/scala/org/apache/spark/sql/rapids/shims/FileCommitProtocolShims.scala deleted file mode 100644 index 9f664bdb705..00000000000 --- a/sql-plugin/src/main/spark411/scala/org/apache/spark/sql/rapids/shims/FileCommitProtocolShims.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*** spark-rapids-shim-json-lines -{"spark": "411"} -spark-rapids-shim-json-lines ***/ -package org.apache.spark.sql.rapids.shims - -import org.apache.hadoop.mapreduce.TaskAttemptContext - -import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec} - -/** - * Shim for FileCommitProtocol.newTaskTempFile API in Spark 4.1.0+. - * Uses the new (spec: FileNameSpec) signature instead of deprecated (ext: String). - */ -object FileCommitProtocolShims { - def newTaskTempFile( - committer: FileCommitProtocol, - taskContext: TaskAttemptContext, - dir: Option[String], - ext: String): String = { - // FileNameSpec(prefix, suffix) - we put ext as suffix with empty prefix - committer.newTaskTempFile(taskContext, dir, FileNameSpec("", ext)) - } - - def newTaskTempFileAbsPath( - committer: FileCommitProtocol, - taskContext: TaskAttemptContext, - absoluteDir: String, - ext: String): String = { - // FileNameSpec(prefix, suffix) - we put ext as suffix with empty prefix - committer.newTaskTempFileAbsPath(taskContext, absoluteDir, FileNameSpec("", ext)) - } -} diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/GpuSortRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/GpuSortRetrySuite.scala index 1e9be0c55d9..e08c3d6b178 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/GpuSortRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/GpuSortRetrySuite.scala @@ -179,11 +179,11 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar { val eachBatchIter = new GpuSortEachBatchIterator( batchIter(2), gpuSorter, - false, - NoopMetric, - NoopMetric, - NoopMetric, - NoopMetric) + singleBatch = false, + opTime = NoopMetric, + sortTime = NoopMetric, + outputBatches = NoopMetric, + outputRows = NoopMetric) RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 2, RmmSpark.OomInjectionType.GPU.ordinal, 0) while (eachBatchIter.hasNext) { @@ -208,11 +208,11 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar { val eachBatchIter = new GpuSortEachBatchIterator( inputIter, gpuSorter, - false, - NoopMetric, - NoopMetric, - NoopMetric, - NoopMetric) + singleBatch = false, + opTime = NoopMetric, + sortTime = NoopMetric, + outputBatches = NoopMetric, + outputRows = NoopMetric) RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1, RmmSpark.OomInjectionType.GPU.ordinal, 0) assertThrows[GpuSplitAndRetryOOM] { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/window/GpuUnboundedToUnboundedAggWindowSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/window/GpuUnboundedToUnboundedAggWindowSuite.scala index fbdda7e2de9..8a61fe51169 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/window/GpuUnboundedToUnboundedAggWindowSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/window/GpuUnboundedToUnboundedAggWindowSuite.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024-2025, NVIDIA CORPORATION. + * Copyright (c) 2024-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ class GpuUnboundedToUnboundedAggWindowSuite extends RmmSparkRetrySuiteBase { val finalProject = GpuUnboundedToUnboundedAggWindowIterator.computeFinalProject( rideAlongOutput, repeatOutput, repeatOutput ++ rideAlongOutput, Map.empty) - val conf = GpuUnboundedToUnboundedAggStages(Seq.empty, Seq.empty, Seq.empty, + val conf = new GpuUnboundedToUnboundedAggStages(Seq.empty, Seq.empty, Seq.empty, Seq.empty, finalProject) def makeRepeatCb(): SpillableColumnarBatch = { @@ -89,7 +89,7 @@ class GpuUnboundedToUnboundedAggWindowSuite extends RmmSparkRetrySuiteBase { rowsRemaining -= rowsToAdd rideAlongList.add(makeRideAlongCb(rowsToAdd.toInt)) } - val inputIter = Seq(SecondPassAggResult(rideAlongList, makeRepeatCb())).toIterator + val inputIter = Seq(new SecondPassAggResult(rideAlongList, makeRepeatCb())).toIterator val splitIter = new GpuUnboundedToUnboundedAggSliceBySizeIterator(inputIter, conf, targetSizeBytes, NoopMetric) val repeatIter = new GpuUnboundedToUnboundedAggFinalIterator(splitIter, conf,