From 9b0d29b2208a769a6f73a79d87e49c8c0fe895d8 Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Wed, 10 Jun 2026 08:26:38 -0700 Subject: [PATCH 1/2] Move ORC timezone helper to Java Signed-off-by: Gera Shegalov --- .../spark/rapids/GpuOrcTimezoneUtils.java | 163 ++++++++++++++++++ .../spark/rapids/GpuOrcTimezoneUtils.scala | 151 ---------------- 2 files changed, 163 insertions(+), 151 deletions(-) create mode 100644 sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/GpuOrcTimezoneUtils.java delete mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcTimezoneUtils.scala diff --git a/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/GpuOrcTimezoneUtils.java b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/GpuOrcTimezoneUtils.java new file mode 100644 index 00000000000..34cc1ac221f --- /dev/null +++ b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/GpuOrcTimezoneUtils.java @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2025-2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.nvidia.spark.rapids;
+
+import java.time.LocalDateTime;
+import java.time.ZoneId;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+import ai.rapids.cudf.ColumnVector;
+import ai.rapids.cudf.ColumnView;
+import ai.rapids.cudf.DType;
+import ai.rapids.cudf.Scalar;
+import ai.rapids.cudf.Table;
+
+/**
+ * Rebases ORC timestamp columns by the fixed offset between the JVM default timezone and UTC.
+ */
+public final class GpuOrcTimezoneUtils {
+  private static final ZoneId UTC = ZoneId.of("UTC");
+
+  private GpuOrcTimezoneUtils() {
+  }
+
+  /**
+   * Get the offset in microseconds for 2015-01-01 between JVM timezone and UTC timezone.
+   *
+   * @param jvmTz the JVM timezone to calculate the offset for
+   * @return the epoch of 2015-01-01T00:00 in UTC minus the epoch of the same wall-clock time in
+   *         {@code jvmTz}, in microseconds
+   */
+  private static long getOffsetForJanuaryFirst2015(ZoneId jvmTz) {
+    long t1 = LocalDateTime.of(2015, 1, 1, 0, 0, 0).atZone(jvmTz).toInstant()
+        .getEpochSecond();
+    long t2 = LocalDateTime.of(2015, 1, 1, 0, 0, 0).atZone(UTC).toInstant()
+        .getEpochSecond();
+    return (t2 - t1) * 1000000L;
+  }
+
+  /**
+   * Registers {@code view} in {@code toClose} and returns it, so a view can be created and
+   * tracked for later closing in a single expression.
+   */
+  private static <T extends ColumnView> T addToClose(List<ColumnView> toClose, T view) {
+    toClose.add(view);
+    return view;
+  }
+
+  /**
+   * Recursively rebase timestamp columns in an input column view to the target timezone.
+   * This handles nested list and struct types.
+   *
+   * @param col the view to rebase; it is never closed by this method
+   * @param toClose collects the intermediate child views created while walking {@code col};
+   *                the caller closes them once the result is no longer needed
+   * @param diffMicros the microseconds subtracted from every timestamp value
+   * @return {@code col} itself when nothing had to be rebuilt; otherwise a new view that is
+   *         not registered in {@code toClose} and must be closed by the caller
+   * @throws AssertionError if a time-resolution column is not TIMESTAMP_MICROSECONDS
+   */
+  private static ColumnView rebaseTimestampRecursively(
+      ColumnView col,
+      List<ColumnView> toClose,
+      long diffMicros) {
+    DType dType = col.getType();
+    if (dType.hasTimeResolution()) {
+      // An explicit check instead of a Java assert: asserts are disabled unless the JVM runs
+      // with -ea, and any other type must never be bit-cast and shifted by a micros offset.
+      if (!DType.TIMESTAMP_MICROSECONDS.equals(dType)) {
+        throw new AssertionError("Only TIMESTAMP_MICROSECONDS is supported, but got " + dType);
+      }
+      try (ColumnView longs = col.bitCastTo(DType.INT64);
+           Scalar offsetScalar = Scalar.fromLong(diffMicros);
+           ColumnVector rebased = longs.sub(offsetScalar)) {
+        return rebased.castTo(DType.TIMESTAMP_MICROSECONDS);
+      }
+    } else if (DType.LIST.equals(dType)) {
+      ColumnView child = addToClose(toClose, col.getChildColumnView(0));
+      ColumnView newChild = rebaseTimestampRecursively(child, toClose, diffMicros);
+      if (newChild != child) {
+        return col.replaceListChild(addToClose(toClose, newChild));
+      }
+      return col;
+    } else if (DType.STRUCT.equals(dType)) {
+      ColumnView[] newViews = new ColumnView[col.getNumChildren()];
+      for (int i = 0; i < newViews.length; i++) {
+        ColumnView child = addToClose(toClose, col.getChildColumnView(i));
+        ColumnView newChild = rebaseTimestampRecursively(child, toClose, diffMicros);
+        if (newChild != child) {
+          addToClose(toClose, newChild);
+        }
+        newViews[i] = newChild;
+      }
+      // A struct is always rebuilt, even when none of its children changed.
+      return new ColumnView(col.getType(), col.getRowCount(), Optional.of(col.getNullCount()),
+          col.getValid(), col.getOffsets(), newViews);
+    }
+    return col;
+  }
+
+  /**
+   * Rebase timestamp columns in the input table to the system default timezone. If the system's
+   * default timezone is UTC, or has a zero offset from UTC on 2015-01-01, this returns the input
+   * table as-is. Otherwise the input table is closed before returning.
+   *
+   * @param input the input table
+   * @return a table with timestamp columns rebased
+   */
+  public static Table rebaseTimeZone(Table input) {
+    ZoneId toZoneId = ZoneId.systemDefault();
+
+    if (UTC.equals(toZoneId)) {
+      return input;
+    }
+
+    long diffMicros = getOffsetForJanuaryFirst2015(toZoneId);
+    if (diffMicros == 0L) {
+      // Zone ids such as "Etc/UTC" or "GMT" do not equal ZoneId.of("UTC") but have no offset,
+      // so there is nothing to subtract and no reason to copy every column.
+      return input;
+    }
+
+    try (Table ignored = input) {
+      ColumnVector[] newColumns = new ColumnVector[input.getNumberOfColumns()];
+      try {
+        for (int colIdx = 0; colIdx < newColumns.length; colIdx++) {
+          ColumnVector col = input.getColumn(colIdx);
+          List<ColumnView> toClose = new ArrayList<>();
+          try {
+            ColumnView rebased = rebaseTimestampRecursively(col, toClose, diffMicros);
+            if (col == rebased) {
+              newColumns[colIdx] = col.incRefCount();
+            } else {
+              toClose.add(rebased);
+              newColumns[colIdx] = rebased.copyToColumnVector();
+            }
+          } finally {
+            closeAll(toClose);
+          }
+        }
+        return new Table(newColumns);
+      } finally {
+        // Runs on success and on failure; slots not yet filled are still null.
+        closeAll(newColumns);
+      }
+    }
+  }
+
+  /** Closes every non-null view in {@code views}. */
+  private static void closeAll(ColumnView[] views) {
+    for (ColumnView view : views) {
+      if (view != null) {
+        view.close();
+      }
+    }
+  }
+
+  /**
+   * Closes every view in {@code views}, continuing past failures. The first RuntimeException is
+   * rethrown at the end with any later ones attached as suppressed.
+   */
+  private static void closeAll(List<ColumnView> views) {
+    RuntimeException firstException = null;
+    for (ColumnView view : views) {
+      try {
+        view.close();
+      } catch (RuntimeException e) {
+        if (firstException == null) {
+          firstException = e;
+        } else {
+          firstException.addSuppressed(e);
+        }
+      }
+    }
+    if (firstException != null) {
+      throw firstException;
+    }
+  }
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcTimezoneUtils.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcTimezoneUtils.scala
deleted file mode 100644
index 9ef9d5f6ebb..00000000000
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOrcTimezoneUtils.scala
+++ /dev/null
@@ -1,151 +0,0 @@
-/*
- * Copyright (c) 2025, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.nvidia.spark.rapids - -import ai.rapids.cudf.{ColumnView, DType, Scalar, Table} -import com.nvidia.spark.rapids.Arm.withResource -import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingSeq -import java.time.{LocalDateTime, ZoneId} -import java.util.Optional -import scala.collection.mutable.ArrayBuffer - -object GpuOrcTimezoneUtils { - - /** - * Get the offset in microseconds for 2025-01-01 between JVM timezone and UTC timezone. - * @param jvmTz the JVM timezone to calculate the offset for - * @return the offset in microseconds - * between the JVM timezone and UTC timezone for 2025-01-01 - * This is used to rebase the timestamp columns in the input table. - */ - private def getOffsetForJanuaryFirst2015(jvmTz: ZoneId): Long = { - val t1 = LocalDateTime.of(2015, 1, 1, 0, 0, 0).atZone(jvmTz).toInstant.getEpochSecond - val t2 = LocalDateTime.of(2015, 1, 1, 0, 0, 0).atZone(ZoneId.of("UTC")).toInstant.getEpochSecond - val diffMicros: Long = (t2 - t1) * 1000000L // convert seconds to microseconds - diffMicros - } - - /** - * Recursively rebase the timestamp columns in the input column view to the target timezone. - * It handles nested types: list and struct. - * The rebase logic is simple: just subtract the offset in microseconds between the - * target timezone and UTC timezone. 
- * For more details about the rebase logic, please refer to: - * https://github.com/apache/orc/blob/rel/release-1.9.1/ - * java/core/src/java/org/apache/orc/impl/TreeReaderFactory.java#L1157 - * `TimestampTreeReader.getBaseTimestamp` generates the base timestamp with JVM default timezone. - * `threadLocalDateFormat.get().setTimeZone(writerTimeZone);` - * The above writerTimeZone is not the timezone in the ORC file stripe footer, - * it is the default JVM timezone. - * `TimestampTreeReader.readTimestamp` applies the diff: - * `long millis = (data.next() + base_timestamp)` - * Note: the input timestamp columns are read as in the UTC timezone. - * - */ - private def rebaseTimestampRecursively( - col: ColumnView, - toZoneId: ZoneId, - toClose: ArrayBuffer[ColumnView], - diffMicros: Long): ColumnView = { - - // Util function to add a view to the buffer "toClose". - val addToClose = (v: ColumnView) => { - toClose += v - v - } - - val dType = col.getType - if (dType.hasTimeResolution) { - assert(dType == DType.TIMESTAMP_MICROSECONDS, - s"Only TIMESTAMP_MICROSECONDS is supported, but got $dType") - - // 1. timestamp type, rebase timestamp column - withResource(col.bitCastTo(DType.INT64)) { longs => - withResource(Scalar.fromLong(diffMicros)) { offsetScalar => - withResource(longs.sub(offsetScalar)) { rebased => - rebased.castTo(DType.TIMESTAMP_MICROSECONDS) - } - } - } - } else if (dType == DType.LIST) { - // 2. nest list type - val child = addToClose(col.getChildColumnView(0)) - val newChild = rebaseTimestampRecursively(child, toZoneId, toClose, diffMicros) - if (newChild != child) { - col.replaceListChild(addToClose(newChild)) - } else { - col - } - } else if (dType == DType.STRUCT) { - // 3. 
nest struct type - val newViews = (0 until col.getNumChildren).safeMap { i => - val child = addToClose(col.getChildColumnView(i)) - val newChild = rebaseTimestampRecursively(child, toZoneId, toClose, diffMicros) - if (newChild != child) { - addToClose(newChild) - } - newChild - } - val opNullCount = Optional.of(col.getNullCount.asInstanceOf[java.lang.Long]) - new ColumnView(col.getType, col.getRowCount, opNullCount, col.getValid, - col.getOffsets, newViews.toArray) - } else { - // 4. other types, no need to rebase - col - } - } - - /** - * Rebase the timestamp columns in the input table to the system default timezone. - * If the system's default timezone is UTC, it returns the input table as it is. - * - * @param input the input table, it will be closed after returning - * @return a new table with rebased timestamp columns - */ - def rebaseTimeZone(input: Table): Table = { - val toZoneId = ZoneId.systemDefault() - - if (toZoneId == ZoneId.of("UTC")) { - // UTC timezone, no need to rebase - return input - } - - // get the offset in microseconds for 2015-01-01 between JVM timezone and UTC timezone - val diffMicros = getOffsetForJanuaryFirst2015(toZoneId) - - withResource(input) { _ => - val newColumns = (0 until input.getNumberOfColumns).safeMap { colIdx => - val col = input.getColumn(colIdx) - withResource(new ArrayBuffer[ColumnView]) { toClose => - val rebased = rebaseTimestampRecursively(col, toZoneId, toClose, diffMicros) - if (col == rebased) { - // no change - col.incRefCount() - } else { - // rebased, copy the new column - toClose += rebased - rebased.copyToColumnVector() - } - } - } - - withResource(newColumns) { _ => - new Table(newColumns: _*) - } - } - } -} From 66c2b1bd421c0eb5ca73597609b84760fc16147f Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Wed, 10 Jun 2026 08:27:01 -0700 Subject: [PATCH 2/2] Move columnar runtime config helpers to Java Signed-off-by: Gera Shegalov --- .../spark/source/GpuReaderFactory.scala | 2 +- 
.../spark/rapids/AutoCloseableTargetSize.java | 78 ++++++++++++++++++ .../com/nvidia/spark/rapids/CombineConf.java | 63 +++++++++++++++ .../spark/rapids/DefaultThreadPoolConf.java | 64 +++++++++++++++ .../spark/rapids/DeviceBuffersUtils.java | 78 ++++++++++++++++++ .../nvidia/spark/rapids/ExecutorCache.java | 55 +++++++++++++ .../nvidia/spark/rapids/HostAllocResult.java | 61 ++++++++++++++ .../spark/rapids/MemoryBoundedPoolConf.java | 80 +++++++++++++++++++ .../nvidia/spark/rapids/ThreadPoolConf.java | 31 +++++++ .../nvidia/spark/rapids/ExecutorCache.scala | 51 ------------ .../nvidia/spark/rapids/WithRetrySuite.scala | 8 +- 11 files changed, 515 insertions(+), 56 deletions(-) create mode 100644 sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/AutoCloseableTargetSize.java create mode 100644 sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/CombineConf.java create mode 100644 sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/DefaultThreadPoolConf.java create mode 100644 sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/DeviceBuffersUtils.java create mode 100644 sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/ExecutorCache.java create mode 100644 sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/HostAllocResult.java create mode 100644 sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/MemoryBoundedPoolConf.java create mode 100644 sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/ThreadPoolConf.java delete mode 100644 sql-plugin/src/main/scala/com/nvidia/spark/rapids/ExecutorCache.scala diff --git a/iceberg/common/src/main/scala/org/apache/iceberg/spark/source/GpuReaderFactory.scala b/iceberg/common/src/main/scala/org/apache/iceberg/spark/source/GpuReaderFactory.scala index 0efbee1da56..4e9b2a6ec6d 100644 --- a/iceberg/common/src/main/scala/org/apache/iceberg/spark/source/GpuReaderFactory.scala +++ b/iceberg/common/src/main/scala/org/apache/iceberg/spark/source/GpuReaderFactory.scala @@ -108,7 +108,7 @@ 
class GpuReaderFactory(private val metrics: Map[String, GpuMetric], queryUsesInputFile || hasFilePathMetadata || hasRowPositionMetadata || !hasNoDeletes MultiThread(poolConfBuilder, partition.maxNumParquetFilesParallel, - CombineConf(combineThresholdSize, combineWaitTime), + new CombineConf(combineThresholdSize, combineWaitTime), disableCombining, hasFilePathMetadata, hasRowPositionMetadata) diff --git a/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/AutoCloseableTargetSize.java b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/AutoCloseableTargetSize.java new file mode 100644 index 00000000000..db294308c4f --- /dev/null +++ b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/AutoCloseableTargetSize.java @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2024-2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.nvidia.spark.rapids;
+
+import java.io.Serializable;
+import java.util.Objects;
+
+/**
+ * Immutable holder of a target size, a minimum size and a data size. It implements
+ * {@link AutoCloseable} with a no-op {@link #close()}.
+ */
+public class AutoCloseableTargetSize implements AutoCloseable, Serializable {
+  private static final long serialVersionUID = 1L;
+
+  public final long targetSize;
+  public final long minSize;
+  public final long dataSize;
+
+  /** Creates an instance whose {@code dataSize} is 0. */
+  public AutoCloseableTargetSize(long targetSize, long minSize) {
+    this(targetSize, minSize, 0L);
+  }
+
+  public AutoCloseableTargetSize(long targetSize, long minSize, long dataSize) {
+    this.targetSize = targetSize;
+    this.minSize = minSize;
+    this.dataSize = dataSize;
+  }
+
+  /** Method-style accessor for {@link #targetSize}. */
+  public long targetSize() {
+    return targetSize;
+  }
+
+  /** Method-style accessor for {@link #minSize}. */
+  public long minSize() {
+    return minSize;
+  }
+
+  /** Method-style accessor for {@link #dataSize}. */
+  public long dataSize() {
+    return dataSize;
+  }
+
+  /** Does nothing. */
+  @Override
+  public void close() {
+  }
+
+  /** Two instances are equal when all three sizes match. */
+  @Override
+  public boolean equals(Object other) {
+    if (!(other instanceof AutoCloseableTargetSize)) {
+      return false;
+    }
+    AutoCloseableTargetSize rhs = (AutoCloseableTargetSize) other;
+    return rhs.targetSize == targetSize
+        && rhs.minSize == minSize
+        && rhs.dataSize == dataSize;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(targetSize, minSize, dataSize);
+  }
+
+  @Override
+  public String toString() {
+    return new StringBuilder("AutoCloseableTargetSize(")
+        .append(targetSize).append(',')
+        .append(minSize).append(',')
+        .append(dataSize).append(')')
+        .toString();
+  }
+}
diff --git a/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/CombineConf.java b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/CombineConf.java
new file mode 100644
index 00000000000..1bb19bc9261
--- /dev/null
+++ b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/CombineConf.java
@@ -0,0 +1,63 @@
+/*
+ * Copyright (c) 2024-2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids;
+
+import java.io.Serializable;
+import java.util.Objects;
+
+/**
+ * Immutable, serializable pair of a combine threshold size and a combine wait time.
+ */
+public class CombineConf implements Serializable {
+  private static final long serialVersionUID = 1L;
+
+  private final long combineThresholdSize;
+  private final int combineWaitTime;
+
+  public CombineConf(long combineThresholdSize, int combineWaitTime) {
+    this.combineThresholdSize = combineThresholdSize;
+    this.combineWaitTime = combineWaitTime;
+  }
+
+  /** Returns the threshold size given at construction. */
+  public long combineThresholdSize() {
+    return combineThresholdSize;
+  }
+
+  /** Returns the wait time given at construction. */
+  public int combineWaitTime() {
+    return combineWaitTime;
+  }
+
+  /** Two instances are equal when both settings match. */
+  @Override
+  public boolean equals(Object other) {
+    if (!(other instanceof CombineConf)) {
+      return false;
+    }
+    CombineConf rhs = (CombineConf) other;
+    return rhs.combineThresholdSize == combineThresholdSize
+        && rhs.combineWaitTime == combineWaitTime;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(combineThresholdSize, combineWaitTime);
+  }
+
+  @Override
+  public String toString() {
+    return new StringBuilder("CombineConf(")
+        .append(combineThresholdSize).append(',')
+        .append(combineWaitTime).append(')')
+        .toString();
+  }
+}
diff --git a/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/DefaultThreadPoolConf.java b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/DefaultThreadPoolConf.java
new file mode 100644
index 00000000000..c42edeb8405
--- /dev/null
+++ b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/DefaultThreadPoolConf.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright (c) 2024-2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids;
+
+import java.util.Objects;
+
+/**
+ * Immutable {@link ThreadPoolConf} that carries only the two settings declared by the
+ * interface.
+ */
+public class DefaultThreadPoolConf implements ThreadPoolConf {
+  private static final long serialVersionUID = 1L;
+
+  private final int maxThreadNumber;
+  private final boolean stageLevelPool;
+
+  public DefaultThreadPoolConf(int maxThreadNumber, boolean stageLevelPool) {
+    this.maxThreadNumber = maxThreadNumber;
+    this.stageLevelPool = stageLevelPool;
+  }
+
+  @Override
+  public int maxThreadNumber() {
+    return maxThreadNumber;
+  }
+
+  @Override
+  public boolean stageLevelPool() {
+    return stageLevelPool;
+  }
+
+  /** Two instances are equal when both settings match. */
+  @Override
+  public boolean equals(Object other) {
+    if (!(other instanceof DefaultThreadPoolConf)) {
+      return false;
+    }
+    DefaultThreadPoolConf rhs = (DefaultThreadPoolConf) other;
+    return rhs.maxThreadNumber == maxThreadNumber
+        && rhs.stageLevelPool == stageLevelPool;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(maxThreadNumber, stageLevelPool);
+  }
+
+  @Override
+  public String toString() {
+    return new StringBuilder("DefaultThreadPoolConf(")
+        .append(maxThreadNumber).append(',')
+        .append(stageLevelPool).append(')')
+        .toString();
+  }
+}
diff --git a/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/DeviceBuffersUtils.java b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/DeviceBuffersUtils.java
new file mode 100644
index 00000000000..0e855219887
--- /dev/null
+++ 
b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/DeviceBuffersUtils.java
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2024-2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids;
+
+import ai.rapids.cudf.BaseDeviceMemoryBuffer;
+import ai.rapids.cudf.DeviceMemoryBuffer;
+
+/**
+ * Static helpers for working with arrays of device memory buffers.
+ */
+public final class DeviceBuffersUtils {
+  private DeviceBuffersUtils() {}
+
+  /**
+   * Calls {@code incRefCount()} on every buffer, in order, and returns a new array holding the
+   * same buffers.
+   *
+   * <p>If anything throws part-way, the buffers already processed are closed again (close
+   * failures are attached to the original throwable as suppressed) and the throwable is
+   * rethrown.
+   *
+   * @param bufs the buffers to retain
+   * @return a new array with the same elements as {@code bufs}
+   */
+  public static BaseDeviceMemoryBuffer[] incRefCount(BaseDeviceMemoryBuffer[] bufs) {
+    BaseDeviceMemoryBuffer[] ret = new BaseDeviceMemoryBuffer[bufs.length];
+    int initialized = 0;
+    try {
+      for (BaseDeviceMemoryBuffer buf : bufs) {
+        buf.incRefCount();
+        ret[initialized] = buf;
+        initialized++;
+      }
+      return ret;
+    } catch (Throwable t) {
+      closeAll(ret, initialized, t);
+      throw t;
+    }
+  }
+
+  /**
+   * Allocates one device buffer as large as the sum of {@code bufSizes} and returns
+   * back-to-back slices of it, one per requested size.
+   *
+   * <p>The backing buffer is closed by the try-with-resources when this method exits, on
+   * success as well as on failure. NOTE(review): presumably each slice keeps the underlying
+   * allocation alive on its own; confirm against the cudf buffer documentation.
+   *
+   * <p>If anything throws part-way, the slices already created are closed and the throwable is
+   * rethrown.
+   *
+   * @param bufSizes the length of each slice, in order
+   * @return one slice per entry of {@code bufSizes}
+   */
+  public static DeviceMemoryBuffer[] allocateBuffers(long[] bufSizes) {
+    DeviceMemoryBuffer[] ret = new DeviceMemoryBuffer[bufSizes.length];
+    int initialized = 0;
+    try (DeviceMemoryBuffer singleBuf = DeviceMemoryBuffer.allocate(sum(bufSizes))) {
+      long curPos = 0L;
+      for (long len : bufSizes) {
+        ret[initialized] = singleBuf.slice(curPos, len);
+        initialized++;
+        curPos += len;
+      }
+      return ret;
+    } catch (Throwable t) {
+      closeAll(ret, initialized, t);
+      throw t;
+    }
+  }
+
+  /** Returns the plain (unchecked for overflow) sum of {@code values}; 0 for an empty array. */
+  private static long sum(long[] values) {
+    long ret = 0L;
+    for (long value : values) {
+      ret += value;
+    }
+    return ret;
+  }
+
+  /**
+   * Closes the first {@code count} non-null entries of {@code values}. Anything thrown while
+   * closing is added to {@code cause} as a suppressed throwable instead of propagating.
+   */
+  private static void closeAll(AutoCloseable[] values, int count, Throwable cause) {
+    for (int i = 0; i < count; i++) {
+      AutoCloseable value = values[i];
+      if (value != null) {
+        try {
+          value.close();
+        } catch (Throwable t) {
+          cause.addSuppressed(t);
+        }
+      }
+    }
+  }
+}
diff --git a/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/ExecutorCache.java b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/ExecutorCache.java
new file mode 100644
index 00000000000..d2dba3fef99
--- /dev/null
+++ b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/ExecutorCache.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2025-2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids;
+
+import java.lang.management.ManagementFactory;
+
+import ai.rapids.cudf.Cuda;
+import ai.rapids.cudf.CudaComputeMode;
+
+/**
+ * Caches executor-related information. Values are initialized lazily to match the previous Scala
+ * object semantics.
+ */
+final class ExecutorCache {
+  private ExecutorCache() {
+  }
+
+  /**
+   * Returns the value of {@code Cuda.getComputeMode()} captured the first time this method is
+   * called. The Scala object this class replaces documented that it relies on the executor
+   * being assigned a single device for its whole lifetime, and that it should be called on
+   * the executor side.
+   */
+  static CudaComputeMode getCurrentDeviceComputeMode() {
+    return CurrentDeviceComputeModeHolder.VALUE;
+  }
+
+  /**
+   * Returns the value of {@code Cuda.getGpuUuid()} captured the first time this method is
+   * called. Every call returns the same array instance, so callers must not modify it.
+   */
+  static byte[] getCurrentDeviceUuid() {
+    return CurrentDeviceUuidHolder.VALUE;
+  }
+
+  /**
+   * Returns the runtime MXBean name of this JVM, captured the first time this method is
+   * called.
+   */
+  static String getProcessName() {
+    return ProcessNameHolder.VALUE;
+  }
+
+  /* One holder class per value: each VALUE is computed when its holder is first used. */
+  private static final class CurrentDeviceComputeModeHolder {
+    private static final CudaComputeMode VALUE = Cuda.getComputeMode();
+  }
+
+  private static final class CurrentDeviceUuidHolder {
+    private static final byte[] VALUE = Cuda.getGpuUuid();
+  }
+
+  private static final class ProcessNameHolder {
+    private static final String VALUE = ManagementFactory.getRuntimeMXBean().getName();
+  }
+}
diff --git a/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/HostAllocResult.java b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/HostAllocResult.java
new file mode 100644
index 00000000000..8cb8cf9f0a0
--- /dev/null
+++ b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/HostAllocResult.java
@@ -0,0 +1,61 @@
+/*
+ * Copyright (c) 2023-2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids;
+
+import java.util.Objects;
+
+import ai.rapids.cudf.HostMemoryBuffer;
+
+/**
+ * Pairs a host memory buffer with a flag saying whether it is pinned. This class only holds
+ * the reference; it never closes the buffer.
+ */
+public class HostAllocResult {
+  public final HostMemoryBuffer buffer;
+  public final boolean isPinned;
+
+  public HostAllocResult(HostMemoryBuffer buffer, boolean isPinned) {
+    this.buffer = buffer;
+    this.isPinned = isPinned;
+  }
+
+  /** Method-style accessor for {@link #buffer}. */
+  public HostMemoryBuffer buffer() {
+    return buffer;
+  }
+
+  /** Method-style accessor for {@link #isPinned}. */
+  public boolean isPinned() {
+    return isPinned;
+  }
+
+  /** Two results are equal when the pinned flags match and the buffers are equal. */
+  @Override
+  public boolean equals(Object other) {
+    if (!(other instanceof HostAllocResult)) {
+      return false;
+    }
+    HostAllocResult rhs = (HostAllocResult) other;
+    if (isPinned != rhs.isPinned) {
+      return false;
+    }
+    return Objects.equals(buffer, rhs.buffer);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(buffer, isPinned);
+  }
+
+  @Override
+  public String toString() {
+    return new StringBuilder("HostAllocResult(")
+        .append(buffer).append(',')
+        .append(isPinned).append(')')
+        .toString();
+  }
+}
diff --git a/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/MemoryBoundedPoolConf.java b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/MemoryBoundedPoolConf.java
new file mode 100644
index 00000000000..781d7949132
--- /dev/null
+++ b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/MemoryBoundedPoolConf.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright (c) 2024-2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids;
+
+import java.util.Objects;
+
+/**
+ * Immutable {@link ThreadPoolConf} that additionally carries a memory capacity and a
+ * memory-wait timeout in milliseconds.
+ */
+public class MemoryBoundedPoolConf implements ThreadPoolConf {
+  private static final long serialVersionUID = 1L;
+
+  private final int maxThreadNumber;
+  private final boolean stageLevelPool;
+  private final long memoryCapacity;
+  private final long waitMemTimeoutMs;
+
+  public MemoryBoundedPoolConf(int maxThreadNumber, boolean stageLevelPool,
+      long memoryCapacity, long waitMemTimeoutMs) {
+    this.maxThreadNumber = maxThreadNumber;
+    this.stageLevelPool = stageLevelPool;
+    this.memoryCapacity = memoryCapacity;
+    this.waitMemTimeoutMs = waitMemTimeoutMs;
+  }
+
+  @Override
+  public int maxThreadNumber() {
+    return maxThreadNumber;
+  }
+
+  @Override
+  public boolean stageLevelPool() {
+    return stageLevelPool;
+  }
+
+  /** Returns the memory capacity given at construction. */
+  public long memoryCapacity() {
+    return memoryCapacity;
+  }
+
+  /** Returns the memory-wait timeout given at construction. */
+  public long waitMemTimeoutMs() {
+    return waitMemTimeoutMs;
+  }
+
+  /** Two instances are equal when all four settings match. */
+  @Override
+  public boolean equals(Object other) {
+    if (!(other instanceof MemoryBoundedPoolConf)) {
+      return false;
+    }
+    MemoryBoundedPoolConf rhs = (MemoryBoundedPoolConf) other;
+    return rhs.maxThreadNumber == maxThreadNumber
+        && rhs.stageLevelPool == stageLevelPool
+        && rhs.memoryCapacity == memoryCapacity
+        && rhs.waitMemTimeoutMs == waitMemTimeoutMs;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(maxThreadNumber, stageLevelPool, memoryCapacity, waitMemTimeoutMs);
+  }
+
+  @Override
+  public String toString() {
+    return new StringBuilder("MemoryBoundedPoolConf(")
+        .append(maxThreadNumber).append(',')
+        .append(stageLevelPool).append(',')
+        .append(memoryCapacity).append(',')
+        .append(waitMemTimeoutMs).append(')')
+        .toString();
+  }
+}
diff --git a/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/ThreadPoolConf.java b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/ThreadPoolConf.java
new file mode 100644
index 00000000000..b2bed218f95
--- /dev/null
+++ b/sql-plugin-columnar/src/main/java/com/nvidia/spark/rapids/ThreadPoolConf.java
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2024-2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids;
+
+import java.io.Serializable;
+
+/**
+ * Serializable settings common to every thread pool configuration: a thread-count limit and
+ * whether pools are created per Spark stage.
+ */
+public interface ThreadPoolConf extends Serializable {
+  /**
+   * The maximum number of threads used by the thread pool, not necessarily the final number.
+   */
+  int maxThreadNumber();
+
+  /**
+   * Whether to create pools for each Spark stage, only for testing for now.
+   */
+  boolean stageLevelPool();
+}
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ExecutorCache.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ExecutorCache.scala
deleted file mode 100644
index 4190d52c9ce..00000000000
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/ExecutorCache.scala
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Copyright (c) 2025, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package com.nvidia.spark.rapids - -import java.lang.management.ManagementFactory - -import ai.rapids.cudf.{Cuda, CudaComputeMode} - -/** - * A singleton object to cache executor related information. - * Uses lazy mode to ensure the values are only computed once per executor. - */ -object ExecutorCache { - - /** - * Cache the current device compute mode for current executor. - * It's based on the assumption that executor has been assigned to a single device, - * and will not change during the lifetime of the executor. - * Should be called on executor side. - */ - private[rapids] lazy val getCurrentDeviceComputeMode: CudaComputeMode = Cuda.getComputeMode() - - /** - * Cache the current device UUID for current executor. - * It's based on the assumption that executor has been assigned to a single device, - * and will not change during the lifetime of the executor. - * Should be called on executor side. - */ - private[rapids] lazy val getCurrentDeviceUuid: Array[Byte] = Cuda.getGpuUuid() - - /** - * Cache the current process name for current executor. - * Should be called on executor side. 
- */ - private[rapids] lazy val getProcessName: String = ManagementFactory.getRuntimeMXBean.getName -} - diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala index 5346dbb3754..3bdb06086b3 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/WithRetrySuite.scala @@ -250,7 +250,7 @@ class WithRetrySuite val numSplits = 2 var doThrow = numSplits var lastSplitSize = 0L - val myTarget = AutoCloseableTargetSize(initialValue, minValue) + val myTarget = new AutoCloseableTargetSize(initialValue, minValue) try { withRetry(myTarget, splitTargetSizeInHalfGpu) { attempt => lastSplitSize = attempt.targetSize @@ -274,7 +274,7 @@ class WithRetrySuite val dataSize = 200L // less than targetSize/2=500, so halving targetSize is a no-op var doThrow = true var splitTargetUsed = 0L - val myTarget = AutoCloseableTargetSize(targetSize, minSize, dataSize) + val myTarget = new AutoCloseableTargetSize(targetSize, minSize, dataSize) try { withRetry(myTarget, splitTargetSizeInHalfGpu) { attempt => splitTargetUsed = attempt.targetSize @@ -300,7 +300,7 @@ class WithRetrySuite val childDataSize = 2L // actual bytes in the smaller child; less than targetSize/2=50 var doThrow = true var splitTargetUsed = 0L - val myTarget = AutoCloseableTargetSize(targetSize, minSize, childDataSize) + val myTarget = new AutoCloseableTargetSize(targetSize, minSize, childDataSize) try { withRetry(myTarget, splitTargetSizeInHalfGpu) { attempt => splitTargetUsed = attempt.targetSize @@ -321,7 +321,7 @@ class WithRetrySuite val numSplits = 3 var doThrow = numSplits var lastSplitSize = 0L - val myTarget = AutoCloseableTargetSize(initialValue, minValue) + val myTarget = new AutoCloseableTargetSize(initialValue, minValue) try { assertThrows[GpuSplitAndRetryOOM] { withRetry(myTarget, splitTargetSizeInHalfGpu) { attempt =>