Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import scala.annotation.nowarn
import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.GpuOverrides.exec

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Average
Expand Down Expand Up @@ -51,7 +50,7 @@ import org.apache.spark.sql.rapids.shims.TrampolineConnectShims.SparkSession
* Shim base class that can be compiled with every supported 3.2.0+
*/
trait Spark320PlusShims extends SparkShims with RebaseShims
with WindowInPandasShims with Logging {
with WindowInPandasShims {


override final def aqeShuffleReaderExec: ExecRule[_ <: SparkPlan] = exec[AQEShuffleReadExec](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ package com.nvidia.spark.rapids.shims

import org.apache.parquet.schema.MessageType

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters

/**
* Shim base class that can be compiled with every supported 3.2.1+
*/
trait Spark321PlusShims extends Spark320PlusShims with RebaseShims with Logging {
trait Spark321PlusShims extends Spark320PlusShims with RebaseShims {
override def getParquetFilters(
schema: MessageType,
pushDownDate: Boolean,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ object GpuWindowUtil {
case GpuLiteral(value, _: DayTimeIntervalType) =>
var x = value.asInstanceOf[Long]
if (x == Long.MinValue) x = Long.MaxValue
ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
new ParsedBoundary(isUnbounded = false, RangeBoundaryValue.long(Math.abs(x)))
case anything => throw new UnsupportedOperationException("Unsupported window frame" +
s" expression $anything")
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy

// Keep executable line numbers aligned with the newer shim so binary-dedupe
// can recognize the common module class.
object LegacyBehaviorPolicyShim {
val CORRECTED_STR: String = LegacyBehaviorPolicy.CORRECTED.toString
val EXCEPTION_STR: String = LegacyBehaviorPolicy.EXCEPTION.toString
val CORRECTED_STR: String = "CORRECTED"
val EXCEPTION_STR: String = "EXCEPTION"

def isLegacyTimeParserPolicy(): Boolean = {
SQLConf.get.legacyTimeParserPolicy == LegacyBehaviorPolicy.LEGACY
SQLConf.get.legacyTimeParserPolicy.toString == "LEGACY"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,27 @@ package com.nvidia.spark.rapids.shims

import org.apache.spark.sql.catalyst.expressions.NullIntolerant

trait NullIntolerantShim extends NullIntolerant
trait NullIntolerantShim extends NullIntolerant {
def nullIntolerant: Boolean = true
}

abstract class GpuLiteralShim extends com.nvidia.spark.rapids.GpuLeafExpression {
def value: Any
def dataType: org.apache.spark.sql.types.DataType

override protected def jsonFields: List[org.json4s.JsonAST.JField] = {
val jsonValue = (value, dataType) match {
case (null, _) => org.json4s.JsonAST.JNull
case (i: Int, org.apache.spark.sql.types.DateType) =>
org.json4s.JsonAST.JString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.toJavaDate(i).toString)
case (l: Long, org.apache.spark.sql.types.TimestampType) =>
org.json4s.JsonAST.JString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.toJavaTimestamp(l).toString)
case (other, _) => org.json4s.JsonAST.JString(other.toString)
}
("value" -> jsonValue) ::
("dataType" -> org.apache.spark.sql.rapids.execution.TrampolineUtil.jsonValue(dataType)
.asInstanceOf[org.json4s.JsonAST.JValue]) :: Nil
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,13 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import org.apache.commons.io.output.NullOutputStream
import java.io.OutputStream

// Keep executable line numbers aligned with newer shims for binary-dedupe.

object NullOutputStreamShim {
def INSTANCE = NullOutputStream.NULL_OUTPUT_STREAM
val INSTANCE: OutputStream = new OutputStream {
override def write(b: Int): Unit = {}
override def write(b: Array[Byte], off: Int, len: Int): Unit = {}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,85 @@
{"spark": "333"}
{"spark": "334"}
spark-rapids-shim-json-lines ***/
// Keep executable line numbers aligned with newer shims for binary-dedupe.














package com.nvidia.spark.rapids.shims

import com.google.protobuf.{AbstractMessage, CodedOutputStream}
import java.io.OutputStream
import java.lang.reflect.Method

import org.apache.orc.impl.OutStream

class OrcProtoWriterShim(orcOutStream: OutStream) {
val proxied = CodedOutputStream.newInstance(orcOutStream)
def writeAndFlush(obj: Any): Unit = obj match {
case m: AbstractMessage =>
m.writeTo(proxied)
proxied.flush()
orcOutStream.flush()
case _ =>
require(obj.isInstanceOf[AbstractMessage],
s"Unexpected protobuf message type: $obj")
import OrcProtoWriterShim.ProtoApi

private[this] var proxiedApi: ProtoApi = _
private[this] var proxied: AnyRef = _

private def proxiedFor(api: ProtoApi): AnyRef = {
if (proxiedApi != api) {
proxiedApi = api
proxied = api.newInstance.invoke(null, orcOutStream.asInstanceOf[OutputStream])
}
proxied
}

def writeAndFlush(obj: Any): Unit = {
val api = OrcProtoWriterShim.apiFor(obj).getOrElse {
throw new IllegalArgumentException(
s"requirement failed: Unexpected protobuf message type: $obj")
}
val currentProxied = proxiedFor(api)
api.writeTo.invoke(obj.asInstanceOf[AnyRef], currentProxied)
api.flush.invoke(currentProxied)
orcOutStream.flush()
}
}

object OrcProtoWriterShim {
private case class ProtoApi(
messageClass: Class[_],
newInstance: Method,
writeTo: Method,
flush: Method)

private val protoClassNames = Seq(
("org.apache.orc.protobuf.AbstractMessage",
"org.apache.orc.protobuf.CodedOutputStream"),
("com.google.protobuf.AbstractMessage",
"com.google.protobuf.CodedOutputStream"))

private lazy val protoApis: Seq[ProtoApi] = protoClassNames.flatMap { case (msg, out) =>
try {
val messageClass = Class.forName(msg)
val codedOutputStreamClass = Class.forName(out)
Some(ProtoApi(
messageClass,
codedOutputStreamClass.getMethod("newInstance", classOf[OutputStream]),
messageClass.getMethod("writeTo", codedOutputStreamClass),
codedOutputStreamClass.getMethod("flush")))
} catch {
case _: ReflectiveOperationException => None
}
}

private def apiFor(obj: Any): Option[ProtoApi] = {
protoApis.find(_.messageClass.isInstance(obj))
}

def apply(orcOutStream: OutStream) = {
new OrcProtoWriterShim(orcOutStream)
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ package com.nvidia.spark.rapids.shims

import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, REBALANCE_PARTITIONS_BY_COL, REBALANCE_PARTITIONS_BY_NONE, REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleOrigin}

// Keep executable line numbers aligned with newer shims for binary-dedupe.












object ShuffleOriginUtil {
private val knownOrigins: Set[ShuffleOrigin] = Set(ENSURE_REQUIREMENTS,
REPARTITION_BY_COL, REPARTITION_BY_NUM, REBALANCE_PARTITIONS_BY_NONE,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,14 +21,15 @@ package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids._

import org.apache.spark.sql.execution.command.{CreateDataSourceTableAsSelectCommand, DataWritingCommand, RunnableCommand}
import org.apache.spark.sql.execution.command.{DataWritingCommand, RunnableCommand}

object SparkShimImpl extends Spark330PlusShims with AnsiCastRuleShims {
override def getDataWriteCmds: Map[Class[_ <: DataWritingCommand],
DataWritingCommandRule[_ <: DataWritingCommand]] = {
Seq(GpuOverrides.dataWriteCmd[CreateDataSourceTableAsSelectCommand](
"Create table with select command",
(a, conf, p, r) => new CreateDataSourceTableAsSelectCommandMeta(a, conf, p, r))
Seq(
GpuOverrides.dataWriteCmdFromShim(
CreateDataSourceTableAsSelectRules.dataWriteCmd,
(a, conf, p, r) => new CreateDataSourceTableAsSelectCommandMeta(a, conf, p, r))
).map(r => (r.getClassFor.asSubclass(classOf[DataWritingCommand]), r)).toMap
}

Expand Down

This file was deleted.

Loading
Loading