Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,31 +40,12 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids.{ExprChecks, ExprRule, GpuCast, GpuExpression, GpuOverrides, TypeSig, UnaryExprMeta}
import com.nvidia.spark.rapids.ExprRule

import org.apache.spark.sql.catalyst.expressions.{CheckOverflowInTableInsert, Expression}
import org.apache.spark.sql.rapids.GpuCheckOverflowInTableInsert
import org.apache.spark.sql.catalyst.expressions.Expression

trait Spark331PlusNonDBShims extends Spark330PlusNonDBShims {
override def getExprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
val map: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Seq(
// Add expression CheckOverflowInTableInsert starting Spark-3.3.1+
// Accepts all types as input as the child Cast does the type checking and the calculations.
GpuOverrides.expr[CheckOverflowInTableInsert](
"Casting a numeric value as another numeric type in store assignment",
ExprChecks.unaryProjectInputMatchesOutput(
TypeSig.all,
TypeSig.all),
(t, conf, p, r) => new UnaryExprMeta[CheckOverflowInTableInsert](t, conf, p, r) {
override def convertToGpu(child: Expression): GpuExpression = {
child match {
case c: GpuCast => GpuCheckOverflowInTableInsert(c, t.columnName)
case _ =>
throw new IllegalStateException("Expression child is not of Type GpuCast")
}
}
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
super.getExprs ++ map
super.getExprs ++ CheckOverflowInTableInsertShims.exprs
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -24,14 +24,15 @@ package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids._

import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, RunnableCommand}
import org.apache.spark.sql.execution.command.{DataWritingCommand, RunnableCommand}

object SparkShimImpl extends Spark331PlusNonDBShims with AnsiCastRuleShims {
override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],
DataWritingCommandRule[_ <: DataWritingCommand]] = {
Seq(GpuOverrides.dataWriteCmd[CreateDataSourceTableAsSelectCommand](
"Create table with select command",
(a, conf, p, r) => new CreateDataSourceTableAsSelectCommandMeta(a, conf, p, r))
Seq(
GpuOverrides.dataWriteCmdFromShim(
CreateDataSourceTableAsSelectRules.dataWriteCmd,
(a, conf, p, r) => new CreateDataSourceTableAsSelectCommandMeta(a, conf, p, r))
).map(r => (r.getClassFor.asSubclass(classOf[DataWritingCommand]), r)).toMap
}

Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, RunnableCommand}
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.command.{DataWritingCommand, RunnableCommand}

trait Spark332PlusDBShims extends Spark330PlusDBShims {
// AnsiCast is removed from Spark3.4.0
Expand All @@ -47,19 +46,8 @@ trait Spark332PlusDBShims extends Spark330PlusDBShims {
super.getExprs ++ shimExprs
}

private val shimExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = Seq(
GpuOverrides.exec[WriteFilesExec](
"v1 write files",
// WriteFilesExec always has patterns:
// InsertIntoHadoopFsRelationCommand(WriteFilesExec) or InsertIntoHiveTable(WriteFilesExec)
// The parent node of `WriteFilesExec` will check the types, here just let type check pass
ExecChecks(TypeSig.all, TypeSig.all),
(write, conf, p, r) => new GpuWriteFilesMeta(write, conf, p, r)
)
).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap

override def getExecs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] =
super.getExecs ++ shimExecs
super.getExecs ++ WriteFilesExecRule.execs

override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],
DataWritingCommandRule[_ <: DataWritingCommand]] = {
Expand All @@ -71,8 +59,8 @@ trait Spark332PlusDBShims extends Spark330PlusDBShims {
override def getRunnableCmds: Map[Class[_ <: RunnableCommand],
RunnableCommandRule[_ <: RunnableCommand]] = {
Seq(
GpuOverrides.runnableCmd[CreateDataSourceTableAsSelectCommand](
"Write to a data source",
GpuOverrides.runnableCmdFromShim(
CreateDataSourceTableAsSelectRules.runnableCmd,
(a, conf, p, r) => new CreateDataSourceTableAsSelectCommandMeta(a, conf, p, r))
).map(r => (r.getClassFor.asSubclass(classOf[RunnableCommand]), r)).toMap
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2026, NVIDIA CORPORATION.
* Copyright (c) 2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -38,25 +38,25 @@
{"spark": "402"}
{"spark": "411"}
spark-rapids-shim-json-lines ***/
package org.apache.spark.sql.rapids.shims
package com.nvidia.spark.rapids.shims

import org.apache.spark.SparkUpgradeException
import com.nvidia.spark.rapids.{ExecChecks, ExecRule, GpuOverrides, TypeSig}

object SparkUpgradeExceptionShims {
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.datasources.GpuWriteFilesMeta

def newSparkUpgradeException(
version: String,
message: String,
cause: Throwable): SparkUpgradeException = {
new SparkUpgradeException(
"INCONSISTENT_BEHAVIOR_CROSS_VERSION",
Map(version -> message),
cause)
}

// Used in tests to compare the class seen in an exception to
// `SparkUpgradeException` which is private in Spark
def getSparkUpgradeExceptionClass: Class[_] = {
classOf[SparkUpgradeException]
object WriteFilesExecRule {
val execs: Map[Class[_ <: SparkPlan], ExecRule[_ <: SparkPlan]] = {
Seq(
GpuOverrides.execFromShim(
WriteFilesExecShims.exec,
// WriteFilesExec always has patterns:
// InsertIntoHadoopFsRelationCommand(WriteFilesExec) or
// InsertIntoHiveTable(WriteFilesExec)
// The parent node of `WriteFilesExec` will check the types, here just let type check pass.
ExecChecks(TypeSig.all, TypeSig.all),
(write, conf, p, r) => new GpuWriteFilesMeta(write, conf, p, r)
)
).map(r => (r.getClassFor.asSubclass(classOf[SparkPlan]), r)).toMap
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,15 @@ import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
object HiveFileUtil {
private val log = org.slf4j.LoggerFactory.getLogger(HiveFileUtil.getClass)

private def logWarning(msg: => String): Unit = {
if (log.isWarnEnabled) {
log.warn(msg)
}
}

object HiveFileUtil extends Logging {

// prior to Spark 3.4.0, this method was accessible via the SaveAsHiveFile trait, but
// was removed in https://github.com/apache/spark/pull/39277
Expand Down

This file was deleted.

This file was deleted.

Loading
Loading