From f1ce3e848ea970d18cd43e43745de90cb1a7d0f3 Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Sun, 7 Jun 2026 06:52:16 -0700 Subject: [PATCH] Update UDF compiler for shared helper layout Signed-off-by: Gera Shegalov --- docs/additional-functionality/rapids-udfs.md | 14 +++--- docs/dev/adaptive-query.md | 2 +- .../com/nvidia/spark/udf/Instruction.scala | 30 +++++++------ .../com/nvidia/spark/udf/Instruction.scala | 30 +++++++------ .../main/scala/com/nvidia/spark/udf/CFG.scala | 32 ++++++++++---- .../spark/udf/CatalystExpressionBuilder.scala | 43 +++++++++++++------ .../com/nvidia/spark/udf/GpuScalaUDF.scala | 16 ++++--- .../nvidia/spark/udf/LogicalPlanRules.scala | 5 +-- .../scala/com/nvidia/spark/udf/State.scala | 6 +-- 9 files changed, 110 insertions(+), 68 deletions(-) diff --git a/docs/additional-functionality/rapids-udfs.md b/docs/additional-functionality/rapids-udfs.md index d498a841ef1..e4144460f0e 100644 --- a/docs/additional-functionality/rapids-udfs.md +++ b/docs/additional-functionality/rapids-udfs.md @@ -152,7 +152,7 @@ The GPU support for Pandas UDF is an experimental feature, and may change at any --- GPU support for Pandas UDF is built on Apache Spark's [Pandas UDF(user defined -function)](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#pandas-udfs-a-k-a-vectorized-udfs), +function)](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#pandas-udfs-a-k-a-vectorized-udfs), and has two features: - **GPU Assignment(Scheduling) in Python Process**: Let the Python process share the same GPU with @@ -201,12 +201,12 @@ Accelerator has a 1-1 mapping support for each of them. |Spark Execution Plan|Data Transfer Accelerated|Use Case| |----------------------|----------|--------| - |ArrowEvalPythonExec|yes|[Series to Series](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#series-to-series), [Iterator of Series to Iterator of Series](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#iterator-of-series-to-iterator-of-series) and [Iterator of Multiple Series to Iterator of Series](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#iterator-of-multiple-series-to-iterator-of-series)| - |MapInPandasExec|yes|[Map](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#map)| - |WindowInPandasExec|yes|[Window](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#series-to-scalar)| - |FlatMapGroupsInPandasExec|yes|[Grouped Map](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#grouped-map)| - |AggregateInPandasExec|yes|[Aggregate](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#series-to-scalar)| - |FlatMapCoGroupsInPandasExec|yes|[Co-grouped Map](https://archive.apache.org/dist/spark/docs/3.2.0/api/python/user_guide/sql/arrow_pandas.html#co-grouped-map)| + |ArrowEvalPythonExec|yes|[Series to Series](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#series-to-series), [Iterator of Series to Iterator of Series](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#iterator-of-series-to-iterator-of-series) and [Iterator of Multiple Series to Iterator of Series](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#iterator-of-multiple-series-to-iterator-of-series)| + |MapInPandasExec|yes|[Map](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#map)| + |WindowInPandasExec|yes|[Window](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#series-to-scalar)| + |FlatMapGroupsInPandasExec|yes|[Grouped Map](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#grouped-map)| + |AggregateInPandasExec|yes|[Aggregate](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#series-to-scalar)| + |FlatMapCoGroupsInPandasExec|yes|[Co-grouped Map](https://spark.apache.org/docs/3.5.7/api/python/user_guide/sql/arrow_pandas.html#co-grouped-map)| ### Other Configuration diff --git a/docs/dev/adaptive-query.md b/docs/dev/adaptive-query.md index c3e5568bfb4..cf9c8c126e4 100644 --- a/docs/dev/adaptive-query.md +++ b/docs/dev/adaptive-query.md @@ -51,7 +51,7 @@ optimizer rules: ```scala extensions.injectColumnar(_ => ColumnarOverrideRules()) -extensions.injectQueryStagePrepRule(_ => GpuQueryStagePrepOverrides()) +extensions.injectQueryStagePrepRule(_ => new GpuQueryStagePrepOverrides) ``` The `ColumnarOverrideRules` are used whether AQE is enabled or not, and the diff --git a/udf-compiler/src/main/scala-2.12/com/nvidia/spark/udf/Instruction.scala b/udf-compiler/src/main/scala-2.12/com/nvidia/spark/udf/Instruction.scala index 7adaff26d99..8ce57b84f33 100644 --- a/udf-compiler/src/main/scala-2.12/com/nvidia/spark/udf/Instruction.scala +++ b/udf-compiler/src/main/scala-2.12/com/nvidia/spark/udf/Instruction.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,14 +21,13 @@ import java.nio.charset.Charset import com.nvidia.spark.rapids.shims.ShimExpression import com.nvidia.spark.udf.CatalystExpressionBuilder.simplify import javassist.bytecode.{CodeIterator, Opcode} +import org.slf4j.LoggerFactory import org.apache.spark.SparkException -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ - private[udf] object Repr { abstract class CompilerInternal(name: String) extends ShimExpression { @@ -154,7 +153,7 @@ private[udf] object Repr { if (elemType == t) { Seq(args.head) } else { - Seq(Cast(args.head, t)) + Seq(new Cast(args.head, t, None)) } } } @@ -216,7 +215,7 @@ private[udf] object Repr { * @param opcode * @param operand */ -case class Instruction(opcode: Int, operand: Int, instructionStr: String) extends Logging { +case class Instruction(opcode: Int, operand: Int, instructionStr: String) { def makeState(lambdaReflection: LambdaReflection, basicBlock: BB, state: State): State = { val st = opcode match { case Opcode.ALOAD_0 | Opcode.DLOAD_0 | Opcode.FLOAD_0 | @@ -321,7 +320,10 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend }) case _ => throw new SparkException("Unsupported instruction: " + instructionStr) } - logDebug(s"[Instruction] ${instructionStr} got new state: ${st} from state: ${state}") + if (Instruction.log.isDebugEnabled) { + Instruction.log.debug( + s"[Instruction] ${instructionStr} got new state: ${st} from state: ${state}") + } st } @@ -441,7 +443,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend state: State, dataType: DataType): State = { val State(locals, top :: rest, cond, expr) = state - State(locals, Cast(top, dataType) :: rest, cond, expr) + State(locals, new Cast(top, dataType, None) :: rest, cond, expr) } private def checkcast(lambdaReflection: LambdaReflection, state: State): State = { @@ -774,13 +776,13 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend EndsWith(args.head, args.last) case "equals" => checkArgs(methodName, List(StringType, StringType), args) - Cast(EqualNullSafe(args.head, args.last), IntegerType) + new Cast(EqualNullSafe(args.head, args.last), IntegerType, None) case "equalsIgnoreCase" => checkArgs(methodName, List(StringType, StringType), args) - Cast(EqualNullSafe(Upper(args.head), Upper(args.last)), IntegerType) + new Cast(EqualNullSafe(Upper(args.head), Upper(args.last)), IntegerType, None) case "isEmpty" => checkArgs(methodName, List(StringType), args) - Cast(EqualTo(Length(args.head), Literal(0)), IntegerType) + new Cast(EqualTo(Length(args.head), Literal(0)), IntegerType, None) case "length" => checkArgs(methodName, List(StringType), args) Length(args.head) @@ -836,7 +838,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend s"String.${methodName}: " + s"${args.head.dataType}") } - Cast(args.head, StringType) + new Cast(args.head, StringType, None) case "indexOf" => if (args.length == 2) { if (args(1).dataType == StringType) { @@ -884,10 +886,10 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend case "getBytes" => if (args.length == 1) { checkArgs(methodName, List(StringType), args) - Encode(args.head, Literal(Charset.defaultCharset.toString)) + new Encode(args.head, Literal(Charset.defaultCharset.toString)) } else if (args.length == 2) { checkArgs(methodName, List(StringType, StringType), args) - Encode(args.head, args.last) + new Encode(args.head, args.last) } else { throw new SparkException( s"String.${methodName} operation expects 1 or 2 argument(s), " + @@ -953,6 +955,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend * Ultimately, every opcode will have to be covered here. */ object Instruction { + private val log = LoggerFactory.getLogger(classOf[Instruction]) + def apply(codeIterator: CodeIterator, offset: Int, instructionStr: String): Instruction = { val opcode: Int = codeIterator.byteAt(offset) val operand: Int = opcode match { diff --git a/udf-compiler/src/main/scala-2.13/com/nvidia/spark/udf/Instruction.scala b/udf-compiler/src/main/scala-2.13/com/nvidia/spark/udf/Instruction.scala index 36fe79da384..fbfe59b6bf7 100644 --- a/udf-compiler/src/main/scala-2.13/com/nvidia/spark/udf/Instruction.scala +++ b/udf-compiler/src/main/scala-2.13/com/nvidia/spark/udf/Instruction.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,14 +21,13 @@ import java.nio.charset.Charset import com.nvidia.spark.rapids.shims.ShimExpression import com.nvidia.spark.udf.CatalystExpressionBuilder.simplify import javassist.bytecode.{CodeIterator, Opcode} +import org.slf4j.LoggerFactory import org.apache.spark.SparkException -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis.TypeCoercion import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ - private[udf] object Repr { abstract class CompilerInternal(name: String) extends ShimExpression { @@ -149,7 +148,7 @@ private[udf] object Repr { if (elemType == t) { Seq(args.head) } else { - Seq(Cast(args.head, t)) + Seq(new Cast(args.head, t, None)) } } } @@ -211,7 +210,7 @@ private[udf] object Repr { * @param opcode * @param operand */ -case class Instruction(opcode: Int, operand: Int, instructionStr: String) extends Logging { +case class Instruction(opcode: Int, operand: Int, instructionStr: String) { def makeState(lambdaReflection: LambdaReflection, basicBlock: BB, state: State): State = { val st = opcode match { case Opcode.ALOAD_0 | Opcode.DLOAD_0 | Opcode.FLOAD_0 | @@ -322,7 +321,10 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend }) case _ => throw new SparkException("Unsupported instruction: " + instructionStr) } - logDebug(s"[Instruction] ${instructionStr} got new state: ${st} from state: ${state}") + if (Instruction.log.isDebugEnabled) { + Instruction.log.debug( + s"[Instruction] ${instructionStr} got new state: ${st} from state: ${state}") + } st } @@ -442,7 +444,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend state: State, dataType: DataType): State = { val State(locals, top :: rest, cond, expr) = state - State(locals, Cast(top, dataType) :: rest, cond, expr) + State(locals, new Cast(top, dataType, None) :: rest, cond, expr) } private def checkcast(lambdaReflection: LambdaReflection, state: State): State = { @@ -800,13 +802,13 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend EndsWith(args.head, args.last) case "equals" => checkArgs(methodName, List(StringType, StringType), args) - Cast(EqualNullSafe(args.head, args.last), IntegerType) + new Cast(EqualNullSafe(args.head, args.last), IntegerType, None) case "equalsIgnoreCase" => checkArgs(methodName, List(StringType, StringType), args) - Cast(EqualNullSafe(Upper(args.head), Upper(args.last)), IntegerType) + new Cast(EqualNullSafe(Upper(args.head), Upper(args.last)), IntegerType, None) case "isEmpty" => checkArgs(methodName, List(StringType), args) - Cast(EqualTo(Length(args.head), Literal(0)), IntegerType) + new Cast(EqualTo(Length(args.head), Literal(0)), IntegerType, None) case "length" => checkArgs(methodName, List(StringType), args) Length(args.head) @@ -862,7 +864,7 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend s"String.${methodName}: " + s"${args.head.dataType}") } - Cast(args.head, StringType) + new Cast(args.head, StringType, None) case "indexOf" => if (args.length == 2) { if (args(1).dataType == StringType) { @@ -910,10 +912,10 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend case "getBytes" => if (args.length == 1) { checkArgs(methodName, List(StringType), args) - Encode(args.head, Literal(Charset.defaultCharset.toString)) + new Encode(args.head, Literal(Charset.defaultCharset.toString)) } else if (args.length == 2) { checkArgs(methodName, List(StringType, StringType), args) - Encode(args.head, args.last) + new Encode(args.head, args.last) } else { throw new SparkException( s"String.${methodName} operation expects 1 or 2 argument(s), " + @@ -979,6 +981,8 @@ case class Instruction(opcode: Int, operand: Int, instructionStr: String) extend * Ultimately, every opcode will have to be covered here. */ object Instruction { + private val log = LoggerFactory.getLogger(classOf[Instruction]) + def apply(codeIterator: CodeIterator, offset: Int, instructionStr: String): Instruction = { val opcode: Int = codeIterator.byteAt(offset) val operand: Int = opcode match { diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala index 34472e890e0..1aa76958675 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CFG.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,9 +22,9 @@ import scala.collection.immutable.{HashMap, SortedMap, SortedSet} import CatalystExpressionBuilder.simplify import javassist.bytecode.{CodeIterator, ConstPool, InstructionPrinter, Opcode} import javassist.bytecode.analysis.Util +import org.slf4j.LoggerFactory import org.apache.spark.SparkException -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ /** @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.expressions._ * * @param instructionTable */ -case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging { +case class BB(instructionTable: SortedMap[Int, Instruction]) { def offset: Int = instructionTable.head._1 def last: (Int, Instruction) = instructionTable.last @@ -54,18 +54,24 @@ case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging { def propagateState(cfg: CFG, states: Map[BB, State]): Map[BB, State] = { val state@State(_, _, cond, expr) = states(this) - logDebug(s"[BB.propagateState] propagating condition: ${cond} from state ${state} " + - s"onto states: ${states}") + if (BB.log.isDebugEnabled) { + BB.log.debug(s"[BB.propagateState] propagating condition: ${cond} from state ${state} " + + s"onto states: ${states}") + } lastInstruction.opcode match { case Opcode.IF_ICMPEQ | Opcode.IF_ICMPNE | Opcode.IF_ICMPLT | Opcode.IF_ICMPGE | Opcode.IF_ICMPGT | Opcode.IF_ICMPLE | Opcode.IFLT | Opcode.IFLE | Opcode.IFGT | Opcode.IFGE | Opcode.IFEQ | Opcode.IFNE | Opcode.IFNULL | Opcode.IFNONNULL => { - logTrace(s"[BB.propagateState] lastInstruction: ${lastInstruction.instructionStr}") + if (BB.log.isTraceEnabled) { + BB.log.trace(s"[BB.propagateState] lastInstruction: ${lastInstruction.instructionStr}") + } // An if statement has both a false and a true successor val (0, falseSucc) :: (1, trueSucc) :: Nil = cfg.successor(this) - logTrace(s"[BB.propagateState] falseSucc ${falseSucc} trueSuccc ${trueSucc}") + if (BB.log.isTraceEnabled) { + BB.log.trace(s"[BB.propagateState] falseSucc ${falseSucc} trueSuccc ${trueSucc}") + } // cond is the entry condition into the condition block, and expr is the // actual condition for IF* (see Instruction.ifOp). @@ -80,7 +86,9 @@ case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging { val falseState = state.copy(cond = simplify(And(cond, Not(expr.get)))) val trueState = state.copy(cond = simplify(And(cond, expr.get))) - logDebug(s"[BB.propagateState] States before: ${states}") + if (BB.log.isDebugEnabled) { + BB.log.debug(s"[BB.propagateState] States before: ${states}") + } // Each successor may already have the state populated if it has // multiple predecessors. @@ -88,7 +96,9 @@ case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging { val newStates = (states + (falseSucc -> falseState.merge(states.get(falseSucc))) + (trueSucc -> trueState.merge(states.get(trueSucc)))) - logDebug(s"[BB.propagateState] States after: ${newStates}") + if (BB.log.isDebugEnabled) { + BB.log.debug(s"[BB.propagateState] States after: ${newStates}") + } newStates } case Opcode.TABLESWITCH | Opcode.LOOKUPSWITCH => @@ -120,6 +130,10 @@ case class BB(instructionTable: SortedMap[Int, Instruction]) extends Logging { } } +object BB { + private val log = LoggerFactory.getLogger(classOf[BB]) +} + /** * The Control Flow Graph object. * diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala index 2628b17457f..919a63b59d9 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/CatalystExpressionBuilder.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,9 +19,9 @@ package com.nvidia.spark.udf import scala.annotation.tailrec import javassist.CtClass +import org.slf4j.LoggerFactory import org.apache.spark.SparkException -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -42,7 +42,7 @@ import org.apache.spark.sql.types._ * * @param function the original Scala UDF provided by the user */ -case class CatalystExpressionBuilder(private val function: AnyRef) extends Logging { +case class CatalystExpressionBuilder(private val function: AnyRef) { final private val lambdaReflection: LambdaReflection = LambdaReflection(function) final private val cfg = CFG(lambdaReflection) @@ -72,23 +72,28 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Loggi // pick first of the Basic Blocks, and start recursing val entryBlock = cfg.basicBlocks.head - logDebug(s"[CatalystExpressionBuilder] Attempting to compile: ${function}, " + - s"with children: ${children}, " + s"entry block: ${entryBlock}, and " + - s"entry state: ${entryState}") + if (CatalystExpressionBuilder.log.isDebugEnabled) { + CatalystExpressionBuilder.log.debug( + s"[CatalystExpressionBuilder] Attempting to compile: ${function}, " + + s"with children: ${children}, " + s"entry block: ${entryBlock}, and " + + s"entry state: ${entryState}") + } // start recursing val compiled = doCompile(List(entryBlock), Map(entryBlock -> entryState)).map { e => if (lambdaReflection.ret == CtClass.booleanType) { // JVM bytecode returns an integer value when the return type is // boolean, hence the cast. - CatalystExpressionBuilder.simplify(Cast(e, BooleanType)) + CatalystExpressionBuilder.simplify(new Cast(e, BooleanType, None)) } else { e } } if (compiled.isEmpty) { - logDebug(s"[CatalystExpressionBuilder] failed to compile") + if (CatalystExpressionBuilder.log.isDebugEnabled) { + CatalystExpressionBuilder.log.debug(s"[CatalystExpressionBuilder] failed to compile") + } } else { val expr = compiled.get val internal = expr.find(_.isInstanceOf[Repr.CompilerInternal]) @@ -96,7 +101,10 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Loggi throw new IllegalStateException( s"compiled UDF has compiler internal expression $e: $expr") } - logDebug(s"[CatalystExpressionBuilder] compiled expression: $expr") + if (CatalystExpressionBuilder.log.isDebugEnabled) { + CatalystExpressionBuilder.log.debug( + s"[CatalystExpressionBuilder] compiled expression: $expr") + } } compiled @@ -156,7 +164,9 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Loggi // find the state associated with this BB val state: State = states(basicBlock) - logTrace(s"States for basic block ${basicBlock} => ${state}") + if (CatalystExpressionBuilder.log.isTraceEnabled) { + CatalystExpressionBuilder.log.trace(s"States for basic block ${basicBlock} => ${state}") + } /** * Iterate through the instruction table for the BB: @@ -274,7 +284,9 @@ case class CatalystExpressionBuilder(private val function: AnyRef) extends Loggi * simplify a directly translated catalyst expression (from bytecode) into something simpler * that the remaining catalyst optimizations can handle. */ -object CatalystExpressionBuilder extends Logging { +object CatalystExpressionBuilder { + private val log = LoggerFactory.getLogger(classOf[CatalystExpressionBuilder]) + /** simplify: given a raw converted catalyst expression, attempt to match patterns to simplify * before handing it over to catalyst optimizers (the LogicalPlan does this later). * @@ -473,8 +485,8 @@ object CatalystExpressionBuilder extends Logging { ce.child match { case If(c, t, f) => simplifyExpr(If(simplifyExpr(c), - simplifyExpr(Cast(t, BooleanType, ce.timeZoneId)), - simplifyExpr(Cast(f, BooleanType, ce.timeZoneId)))) + simplifyExpr(new Cast(t, BooleanType, ce.timeZoneId)), + simplifyExpr(new Cast(f, BooleanType, ce.timeZoneId)))) } case If(c, Repr.ArrayBuffer(t), Repr.ArrayBuffer(f)) => Repr.ArrayBuffer(If(c, t, f)) case If(c, Repr.StringBuilder(t), Repr.StringBuilder(f)) => Repr.StringBuilder(If(c, t, f)) @@ -483,7 +495,10 @@ object CatalystExpressionBuilder extends Logging { case If(c, t, f) => If(simplifyExpr(c), simplifyExpr(t), simplifyExpr(f)) case _ => expr } - logDebug(s"[CatalystExpressionBuilder] simplify: ${expr} ==> ${res}") + if (CatalystExpressionBuilder.log.isDebugEnabled) { + CatalystExpressionBuilder.log.debug( + s"[CatalystExpressionBuilder] simplify: ${expr} ==> ${res}") + } res } diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/GpuScalaUDF.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/GpuScalaUDF.scala index bee2f73a3bf..7ef6ccff0e4 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/GpuScalaUDF.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/GpuScalaUDF.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,15 +19,15 @@ package com.nvidia.spark.udf import scala.util.control.NonFatal import com.nvidia.spark.rapids.shims.ShimExpression +import org.slf4j.LoggerFactory import org.apache.spark.SparkException -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.DataType -case class GpuScalaUDFLogical(udf: ScalaUDF) extends ShimExpression with Logging { +case class GpuScalaUDFLogical(udf: ScalaUDF) extends ShimExpression { override def nullable: Boolean = udf.nullable override def eval(input: InternalRow): Any = { @@ -53,15 +53,21 @@ case class GpuScalaUDFLogical(udf: ScalaUDF) extends ShimExpression with Logging } catch { case e: SparkException => val udfName = udf.udfName.getOrElse("") - logDebug(s"UDF $udfName compilation failure: $e") + if (GpuScalaUDFLogical.log.isDebugEnabled) { + GpuScalaUDFLogical.log.debug(s"UDF $udfName compilation failure: $e") + } if (isTestEnabled) { throw e } udf case NonFatal(e) => val udfName = udf.udfName.getOrElse("") - logWarning(s"Unable to translate UDF $udfName: $e") + GpuScalaUDFLogical.log.warn(s"Unable to translate UDF $udfName: $e") udf } } } + +object GpuScalaUDFLogical { + private val log = LoggerFactory.getLogger(classOf[GpuScalaUDFLogical]) +} diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LogicalPlanRules.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LogicalPlanRules.scala index e9732e064c7..41373b330cf 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/LogicalPlanRules.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/LogicalPlanRules.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,14 +19,13 @@ package com.nvidia.spark.udf import ai.rapids.cudf.{NvtxColor, NvtxRange} import com.nvidia.spark.rapids.RapidsConf -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, ScalaUDF} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.rapids.GpuScalaUDF.getRapidsUDFInstance -case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging { +class LogicalPlanRules extends Rule[LogicalPlan] { def replacePartialFunc(plan: LogicalPlan): PartialFunction[Expression, Expression] = { case d: Expression => { val nvtx = new NvtxRange("replace UDF", NvtxColor.BLUE) diff --git a/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala b/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala index 114469ba019..7d4c76b5365 100644 --- a/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala +++ b/udf-compiler/src/main/scala/com/nvidia/spark/udf/State.scala @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -89,8 +89,8 @@ case class State(locals: IndexedSeq[Expression], val commonType = TypeCoercion.findTightestCommonType(l1.dataType, l2.dataType) commonType.fold(throw new SparkException(s"Conditional type check failure")){ t => simplify(If(cond, - if (t == l1.dataType) l1 else Cast(l1, t), - if (t == l2.dataType) l2 else Cast(l2, t))) + if (t == l1.dataType) l1 else new Cast(l1, t, None), + if (t == l2.dataType) l2 else new Cast(l2, t, None))) } } // At the end of the compliation, the expression at the top of stack is