Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions docs/additional-functionality/rapids-udfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ The GPU support for Pandas UDF is an experimental feature, and may change at any
---

GPU support for Pandas UDF is built on Apache Spark's [Pandas UDF (user-defined
function)](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#pandas-udfs-a-k-a-vectorized-udfs),
function)](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#pandas-udfs-a-k-a-vectorized-udfs),
and has two features:

- **GPU Assignment (Scheduling) in Python Process**: Let the Python process share the same GPU with
Expand Down Expand Up @@ -201,12 +201,12 @@ Accelerator has a 1-1 mapping support for each of them.

|Spark Execution Plan|Data Transfer Accelerated|Use Case|
|----------------------|----------|--------|
|ArrowEvalPythonExec|yes|[Series to Series](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#series-to-series), [Iterator of Series to Iterator of Series](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#iterator-of-series-to-iterator-of-series) and [Iterator of Multiple Series to Iterator of Series](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#iterator-of-multiple-series-to-iterator-of-series)|
|MapInPandasExec|yes|[Map](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#map)|
|WindowInPandasExec|yes|[Window](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#series-to-scalar)|
|FlatMapGroupsInPandasExec|yes|[Grouped Map](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#grouped-map)|
|AggregateInPandasExec|yes|[Aggregate](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#series-to-scalar)|
|FlatMapCoGroupsInPandasExec|yes|[Co-grouped Map](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#co-grouped-map)|
|ArrowEvalPythonExec|yes|[Series to Series](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#series-to-series), [Iterator of Series to Iterator of Series](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#iterator-of-series-to-iterator-of-series) and [Iterator of Multiple Series to Iterator of Series](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#iterator-of-multiple-series-to-iterator-of-series)|
|MapInPandasExec|yes|[Map](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#map)|
|WindowInPandasExec|yes|[Window](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#series-to-scalar)|
|FlatMapGroupsInPandasExec|yes|[Grouped Map](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#grouped-map)|
|AggregateInPandasExec|yes|[Aggregate](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#series-to-scalar)|
|FlatMapCoGroupsInPandasExec|yes|[Co-grouped Map](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#co-grouped-map)|


### Other Configuration
Expand Down
2 changes: 1 addition & 1 deletion docs/dev/adaptive-query.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ optimizer rules:

```scala
extensions.injectColumnar(_ => ColumnarOverrideRules())
extensions.injectQueryStagePrepRule(_ => GpuQueryStagePrepOverrides())
extensions.injectQueryStagePrepRule(_ => new GpuQueryStagePrepOverrides)
```

The `ColumnarOverrideRules` are used whether AQE is enabled or not, and the
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,14 +21,13 @@ import java.nio.charset.Charset
import com.nvidia.spark.rapids.shims.ShimExpression
import com.nvidia.spark.udf.CatalystExpressionBuilder.simplify
import javassist.bytecode.{CodeIterator, Opcode}
import org.slf4j.LoggerFactory

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._


private[udf] object Repr {

abstract class CompilerInternal(name: String) extends ShimExpression {
Expand Down Expand Up @@ -154,7 +153,7 @@ private[udf] object Repr {
if (elemType == t) {
Seq(args.head)
} else {
Seq(Cast(args.head, t))
Seq(new Cast(args.head, t, None))
}
}
}
Expand Down Expand Up @@ -216,7 +215,7 @@ private[udf] object Repr {
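* @param opcode the bytecode opcode of this instruction (matched against javassist `Opcode` constants)
* @param operand the operand value derived for this opcode
* @param instructionStr textual form of the instruction, used in log and error messages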
*/
case class Instruction(opcode: Int, operand: Int, instructionStr: String) extends Logging {
case class Instruction(opcode: Int, operand: Int, instructionStr: String) {
def makeState(lambdaReflection: LambdaReflection, basicBlock: BB, state: State): State = {
val st = opcode match {
case Opcode.ALOAD_0 | Opcode.DLOAD_0 | Opcode.FLOAD_0 |
Expand Down Expand Up @@ -321,7 +320,10 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
})
case _ => throw new SparkException("Unsupported instruction: " + instructionStr)
}
logDebug(s"[Instruction] ${instructionStr} got new state: ${st} from state: ${state}")
if (Instruction.log.isDebugEnabled) {
Instruction.log.debug(
s"[Instruction] ${instructionStr} got new state: ${st} from state: ${state}")
}
st
}

Expand Down Expand Up @@ -441,7 +443,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
state: State,
dataType: DataType): State = {
val State(locals, top :: rest, cond, expr) = state
State(locals, Cast(top, dataType) :: rest, cond, expr)
State(locals, new Cast(top, dataType, None) :: rest, cond, expr)
}

private def checkcast(lambdaReflection: LambdaReflection, state: State): State = {
Expand Down Expand Up @@ -774,13 +776,13 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
EndsWith(args.head, args.last)
case "equals" =>
checkArgs(methodName, List(StringType, StringType), args)
Cast(EqualNullSafe(args.head, args.last), IntegerType)
new Cast(EqualNullSafe(args.head, args.last), IntegerType, None)
case "equalsIgnoreCase" =>
checkArgs(methodName, List(StringType, StringType), args)
Cast(EqualNullSafe(Upper(args.head), Upper(args.last)), IntegerType)
new Cast(EqualNullSafe(Upper(args.head), Upper(args.last)), IntegerType, None)
case "isEmpty" =>
checkArgs(methodName, List(StringType), args)
Cast(EqualTo(Length(args.head), Literal(0)), IntegerType)
new Cast(EqualTo(Length(args.head), Literal(0)), IntegerType, None)
case "length" =>
checkArgs(methodName, List(StringType), args)
Length(args.head)
Expand Down Expand Up @@ -836,7 +838,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
s"String.${methodName}: " +
s"${args.head.dataType}")
}
Cast(args.head, StringType)
new Cast(args.head, StringType, None)
case "indexOf" =>
if (args.length == 2) {
if (args(1).dataType == StringType) {
Expand Down Expand Up @@ -884,10 +886,10 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
case "getBytes" =>
if (args.length == 1) {
checkArgs(methodName, List(StringType), args)
Encode(args.head, Literal(Charset.defaultCharset.toString))
new Encode(args.head, Literal(Charset.defaultCharset.toString))
} else if (args.length == 2) {
checkArgs(methodName, List(StringType, StringType), args)
Encode(args.head, args.last)
new Encode(args.head, args.last)
} else {
throw new SparkException(
s"String.${methodName} operation expects 1 or 2 argument(s), " +
Expand Down Expand Up @@ -953,6 +955,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
* Ultimately, every opcode will have to be covered here.
*/
object Instruction {
private val log = LoggerFactory.getLogger(classOf[Instruction])

def apply(codeIterator: CodeIterator, offset: Int, instructionStr: String): Instruction = {
val opcode: Int = codeIterator.byteAt(offset)
val operand: Int = opcode match {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,14 +21,13 @@ import java.nio.charset.Charset
import com.nvidia.spark.rapids.shims.ShimExpression
import com.nvidia.spark.udf.CatalystExpressionBuilder.simplify
import javassist.bytecode.{CodeIterator, Opcode}
import org.slf4j.LoggerFactory

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._


private[udf] object Repr {

abstract class CompilerInternal(name: String) extends ShimExpression {
Expand Down Expand Up @@ -149,7 +148,7 @@ private[udf] object Repr {
if (elemType == t) {
Seq(args.head)
} else {
Seq(Cast(args.head, t))
Seq(new Cast(args.head, t, None))
}
}
}
Expand Down Expand Up @@ -211,7 +210,7 @@ private[udf] object Repr {
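* @param opcode the bytecode opcode of this instruction (matched against javassist `Opcode` constants)
* @param operand the operand value derived for this opcode
* @param instructionStr textual form of the instruction, used in log and error messages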
*/
case class Instruction(opcode: Int, operand: Int, instructionStr: String) extends Logging {
case class Instruction(opcode: Int, operand: Int, instructionStr: String) {
def makeState(lambdaReflection: LambdaReflection, basicBlock: BB, state: State): State = {
val st = opcode match {
case Opcode.ALOAD_0 | Opcode.DLOAD_0 | Opcode.FLOAD_0 |
Expand Down Expand Up @@ -322,7 +321,10 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
})
case _ => throw new SparkException("Unsupported instruction: " + instructionStr)
}
logDebug(s"[Instruction] ${instructionStr} got new state: ${st} from state: ${state}")
if (Instruction.log.isDebugEnabled) {
Instruction.log.debug(
s"[Instruction] ${instructionStr} got new state: ${st} from state: ${state}")
}
st
}

Expand Down Expand Up @@ -442,7 +444,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
state: State,
dataType: DataType): State = {
val State(locals, top :: rest, cond, expr) = state
State(locals, Cast(top, dataType) :: rest, cond, expr)
State(locals, new Cast(top, dataType, None) :: rest, cond, expr)
}

private def checkcast(lambdaReflection: LambdaReflection, state: State): State = {
Expand Down Expand Up @@ -800,13 +802,13 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
EndsWith(args.head, args.last)
case "equals" =>
checkArgs(methodName, List(StringType, StringType), args)
Cast(EqualNullSafe(args.head, args.last), IntegerType)
new Cast(EqualNullSafe(args.head, args.last), IntegerType, None)
case "equalsIgnoreCase" =>
checkArgs(methodName, List(StringType, StringType), args)
Cast(EqualNullSafe(Upper(args.head), Upper(args.last)), IntegerType)
new Cast(EqualNullSafe(Upper(args.head), Upper(args.last)), IntegerType, None)
case "isEmpty" =>
checkArgs(methodName, List(StringType), args)
Cast(EqualTo(Length(args.head), Literal(0)), IntegerType)
new Cast(EqualTo(Length(args.head), Literal(0)), IntegerType, None)
case "length" =>
checkArgs(methodName, List(StringType), args)
Length(args.head)
Expand Down Expand Up @@ -862,7 +864,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
s"String.${methodName}: " +
s"${args.head.dataType}")
}
Cast(args.head, StringType)
new Cast(args.head, StringType, None)
case "indexOf" =>
if (args.length == 2) {
if (args(1).dataType == StringType) {
Expand Down Expand Up @@ -910,10 +912,10 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
case "getBytes" =>
if (args.length == 1) {
checkArgs(methodName, List(StringType), args)
Encode(args.head, Literal(Charset.defaultCharset.toString))
new Encode(args.head, Literal(Charset.defaultCharset.toString))
} else if (args.length == 2) {
checkArgs(methodName, List(StringType, StringType), args)
Encode(args.head, args.last)
new Encode(args.head, args.last)
} else {
throw new SparkException(
s"String.${methodName} operation expects 1 or 2 argument(s), " +
Expand Down Expand Up @@ -979,6 +981,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend
* Ultimately, every opcode will have to be covered here.
*/
object Instruction {
private val log = LoggerFactory.getLogger(classOf[Instruction])

def apply(codeIterator: CodeIterator, offset: Int, instructionStr: String): Instruction = {
val opcode: Int = codeIterator.byteAt(offset)
val operand: Int = opcode match {
Expand Down
32 changes: 23 additions & 9 deletions udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2019-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,9 +22,9 @@ import scala.collection.immutable.{HashMap, SortedMap, SortedSet}
import CatalystExpressionBuilder.simplify
import javassist.bytecode.{CodeIterator, ConstPool, InstructionPrinter, Opcode}
import javassist.bytecode.analysis.Util
import org.slf4j.LoggerFactory

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._

/**
Expand All @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.expressions._
*
* @param instructionTable
*/
case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging {
case class BB(instructionTable: SortedMap[Int, Instruction]) {
def offset: Int = instructionTable.head._1

def last: (Int, Instruction) = instructionTable.last
Expand All @@ -54,18 +54,24 @@ case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging {

def propagateState(cfg: CFG, states: Map[BB, State]): Map[BB, State] = {
val state@State(_, _, cond, expr) = states(this)
logDebug(s"[BB.propagateState] propagating condition: ${cond} from state ${state} " +
s"onto states: ${states}")
if (BB.log.isDebugEnabled) {
BB.log.debug(s"[BB.propagateState] propagating condition: ${cond} from state ${state} " +
s"onto states: ${states}")
}
lastInstruction.opcode match {
case Opcode.IF_ICMPEQ | Opcode.IF_ICMPNE | Opcode.IF_ICMPLT |
Opcode.IF_ICMPGE | Opcode.IF_ICMPGT | Opcode.IF_ICMPLE |
Opcode.IFLT | Opcode.IFLE | Opcode.IFGT | Opcode.IFGE |
Opcode.IFEQ | Opcode.IFNE | Opcode.IFNULL | Opcode.IFNONNULL => {
logTrace(s"[BB.propagateState] lastInstruction: ${lastInstruction.instructionStr}")
if (BB.log.isTraceEnabled) {
BB.log.trace(s"[BB.propagateState] lastInstruction: ${lastInstruction.instructionStr}")
}

// An if statement has both a false and a true successor
val (0, falseSucc) :: (1, trueSucc) :: Nil = cfg.successor(this)
logTrace(s"[BB.propagateState] falseSucc ${falseSucc} trueSuccc ${trueSucc}")
if (BB.log.isTraceEnabled) {
BB.log.trace(s"[BB.propagateState] falseSucc ${falseSucc} trueSuccc ${trueSucc}")
}

// cond is the entry condition into the condition block, and expr is the
// actual condition for IF* (see Instruction.ifOp).
Expand All @@ -80,15 +86,19 @@ case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging {
val falseState = state.copy(cond = simplify(And(cond, Not(expr.get))))
val trueState = state.copy(cond = simplify(And(cond, expr.get)))

logDebug(s"[BB.propagateState] States before: ${states}")
if (BB.log.isDebugEnabled) {
BB.log.debug(s"[BB.propagateState] States before: ${states}")
}

// Each successor may already have the state populated if it has
// multiple predecessors.
// Update the states by merging the new state with the existing state.
val newStates = (states
+ (falseSucc -> falseState.merge(states.get(falseSucc)))
+ (trueSucc -> trueState.merge(states.get(trueSucc))))
logDebug(s"[BB.propagateState] States after: ${newStates}")
if (BB.log.isDebugEnabled) {
BB.log.debug(s"[BB.propagateState] States after: ${newStates}")
}
newStates
}
case Opcode.TABLESWITCH | Opcode.LOOKUPSWITCH =>
Expand Down Expand Up @@ -120,6 +130,10 @@ case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging {
}
}

/**
 * Companion of the [[BB]] case class.
 *
 * Holds the single SLF4J logger that `BB.propagateState` uses for its
 * debug and trace output.
 */
object BB {
  // Named after the BB class so log output is attributed to it.
  private val log: org.slf4j.Logger =
    LoggerFactory.getLogger(classOf[BB])
}

/**
* The Control Flow Graph object.
*
Expand Down
Loading
Loading