From c0c834885c29fe7ac642f026c17df347acc11529 Mon Sep 17 00:00:00 2001 From: Niranjan Artal Date: Thu, 25 Jun 2026 14:13:32 -0700 Subject: [PATCH 1/5] Add JNI Variant extraction support Add native and Java bindings for cuDF Variant field extraction. Signed-off-by: Niranjan Artal --- src/main/cpp/CMakeLists.txt | 1 + src/main/cpp/src/VariantUtilsJni.cpp | 89 +++++++ .../nvidia/spark/rapids/jni/VariantUtils.java | 100 ++++++++ .../spark/rapids/jni/VariantUtilsTest.java | 220 ++++++++++++++++++ 4 files changed, 410 insertions(+) create mode 100644 src/main/cpp/src/VariantUtilsJni.cpp create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java create mode 100644 src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index d2459d690d..8c5e281bc3 100644 --- a/src/main/cpp/CMakeLists.txt +++ b/src/main/cpp/CMakeLists.txt @@ -230,6 +230,7 @@ add_library( src/StringUtilsJni.cpp src/SubStringIndexJni.cpp src/TaskPriorityJni.cpp + src/VariantUtilsJni.cpp src/ZOrderJni.cpp src/iceberg/IcebergBucketJni.cpp src/iceberg/IcebergDateTimeUtilJni.cpp diff --git a/src/main/cpp/src/VariantUtilsJni.cpp b/src/main/cpp/src/VariantUtilsJni.cpp new file mode 100644 index 0000000000..2a61fb3a5d --- /dev/null +++ b/src/main/cpp/src/VariantUtilsJni.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cudf_jni_apis.hpp" +#include "jni_utils.hpp" + +#include +#include +#include +#include + +extern "C" { + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_VariantUtils_getVariantFieldValue( + JNIEnv* env, jclass, jlong variant_struct_handle, jstring j_path) +{ + JNI_NULL_CHECK(env, variant_struct_handle, "variant struct column is null", 0); + JNI_NULL_CHECK(env, j_path, "path is null", 0); + JNI_TRY + { + cudf::jni::auto_set_device(env); + auto const& variant_struct = *reinterpret_cast(variant_struct_handle); + cudf::jni::native_jstring path(env, j_path); + return cudf::jni::release_as_jlong( + cudf::io::parquet::experimental::get_variant_field(variant_struct, + path.get(), + cudf::get_default_stream(), + cudf::get_current_device_resource_ref())); + } + JNI_CATCH(env, 0); +} + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_VariantUtils_castVariantValue( + JNIEnv* env, jclass, jlong value_bytes_handle, jint cudf_type_id) +{ + JNI_NULL_CHECK(env, value_bytes_handle, "value bytes column is null", 0); + JNI_TRY + { + cudf::jni::auto_set_device(env); + auto const& value_bytes = *reinterpret_cast(value_bytes_handle); + return cudf::jni::release_as_jlong(cudf::io::parquet::experimental::cast_variant( + value_bytes, + cudf::data_type{static_cast(cudf_type_id)}, + cudf::get_default_stream(), + cudf::get_current_device_resource_ref())); + } + JNI_CATCH(env, 0); +} + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_VariantUtils_extractVariantField( + JNIEnv* env, jclass, jlong variant_struct_handle, jstring j_path, jint cudf_type_id) +{ + JNI_NULL_CHECK(env, variant_struct_handle, "variant struct column is null", 0); + JNI_NULL_CHECK(env, j_path, "path is null", 0); + JNI_TRY + { + cudf::jni::auto_set_device(env); + auto const& variant_struct = *reinterpret_cast(variant_struct_handle); + cudf::jni::native_jstring path(env, j_path); + return cudf::jni::release_as_jlong(cudf::io::parquet::experimental::extract_variant_field( + variant_struct, + path.get(), + cudf::data_type{static_cast(cudf_type_id)}, + cudf::get_default_stream(), + cudf::get_current_device_resource_ref())); + } + JNI_CATCH(env, 0); +} + +JNIEXPORT jboolean JNICALL Java_com_nvidia_spark_rapids_jni_VariantUtils_isAvailableNative(JNIEnv*, + jclass) +{ + return JNI_TRUE; +} + +} // extern "C" diff --git a/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java new file mode 100644 index 0000000000..7960fee982 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.NativeDepsLoader; + +import java.util.Objects; + +/** + * JNI bridge to cuDF's experimental Parquet Variant field extraction APIs. + */ +public class VariantUtils { + static { + NativeDepsLoader.loadNativeDeps(); + } + + private VariantUtils() {} + + private static void validateTargetType(DType targetType) { + Objects.requireNonNull(targetType, "targetType"); + if (targetType != DType.STRING && targetType != DType.INT8 && targetType != DType.INT16 && + targetType != DType.INT32 && targetType != DType.INT64) { + throw new IllegalArgumentException("unsupported Variant target type: " + targetType + + "; supported types are STRING, INT8, INT16, INT32, and INT64"); + } + } + + /** + * Extract raw Variant-encoded value bytes at {@code path} from a Variant struct column. + * + * @param variantStruct Variant materialization: STRUCT(metadata LIST<UINT8>, + * value LIST<UINT8>, optional shredded children...) + * @param path JSONPath-like path accepted by cuDF's Variant extractor. Paths are expected to + * be ASCII object-field paths like {@code x}, {@code $.x}, or {@code $.x.y}. + * @return LIST<UINT8> column of raw encoded Variant values + */ + public static ColumnVector getVariantFieldValue(ColumnView variantStruct, String path) { + return new ColumnVector(getVariantFieldValue(variantStruct.getNativeView(), path)); + } + + /** + * Decode raw Variant-encoded value bytes into {@code targetType}. Supported target types are + * {@link DType#STRING}, {@link DType#INT8}, {@link DType#INT16}, {@link DType#INT32}, and + * {@link DType#INT64}. + */ + public static ColumnVector castVariantValue(ColumnView valueBytes, DType targetType) { + validateTargetType(targetType); + return new ColumnVector(castVariantValue( + valueBytes.getNativeView(), targetType.getTypeId().getNativeId())); + } + + /** + * Extract a Variant field and decode it into {@code targetType} in one native call. + * Supported target types are {@link DType#STRING}, {@link DType#INT8}, {@link DType#INT16}, + * {@link DType#INT32}, and {@link DType#INT64}. + */ + public static ColumnVector extractVariantField( + ColumnView variantStruct, String path, DType targetType) { + validateTargetType(targetType); + return new ColumnVector(extractVariantField( + variantStruct.getNativeView(), path, targetType.getTypeId().getNativeId())); + } + + /** + * Returns true when this JNI library was built against cuDF with Variant extraction APIs. + */ + public static boolean isAvailable() { + try { + return isAvailableNative(); + } catch (UnsatisfiedLinkError e) { + return false; + } + } + + private static native long getVariantFieldValue(long variantStructHandle, String path); + + private static native long castVariantValue(long valueBytesHandle, int cudfTypeId); + + private static native long extractVariantField( + long variantStructHandle, String path, int cudfTypeId); + + private static native boolean isAvailableNative(); +} diff --git a/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java new file mode 100644 index 0000000000..6d2b00b5cf --- /dev/null +++ b/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java @@ -0,0 +1,220 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.CudfException; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.HostColumnVector.BasicType; +import ai.rapids.cudf.HostColumnVector.ListType; +import ai.rapids.cudf.HostColumnVector.StructData; +import ai.rapids.cudf.HostColumnVector.StructType; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class VariantUtilsTest { + private static final ListType BINARY_TYPE = + new ListType(true, new BasicType(false, DType.UINT8)); + private static final StructType VARIANT_TYPE = + new StructType(true, Arrays.asList(BINARY_TYPE, BINARY_TYPE)); + + private static List bytes(int... values) { + List list = new ArrayList<>(values.length); + for (int value : values) { + list.add((byte) value); + } + return list; + } + + private static StructData variant(List metadata, List value) { + return new StructData(metadata, value); + } + + private static ColumnVector makeApacheObjectNestedVariantColumn() { + List metadata = bytes( + 0x01, 0x0a, 0x00, 0x02, 0x09, 0x0d, 0x17, 0x22, 0x26, 0x2e, 0x33, 0x3e, + 0x46, 'i', 'd', 's', 'p', 'e', 'c', 'i', 'e', 's', 'n', 'a', 'm', 'e', + 'p', 'o', 'p', 'u', 'l', 'a', 't', 'i', 'o', 'n', 'o', 'b', 's', 'e', + 'r', 'v', 'a', 't', 'i', 'o', 'n', 't', 'i', 'm', 'e', 'l', 'o', 'c', + 'a', 't', 'i', 'o', 'n', 'v', 'a', 'l', 'u', 'e', 't', 'e', 'm', 'p', + 'e', 'r', 'a', 't', 'u', 'r', 'e', 'h', 'u', 'm', 'i', 'd', 'i', 't', + 'y'); + List value = bytes( + 0x02, 0x03, 0x00, 0x04, 0x01, 0x00, 0x19, 0x02, 0x46, 0x0c, 0x01, + 0x02, 0x02, 0x02, 0x03, 0x00, 0x0d, 0x10, 0x31, 'l', 'a', 'v', 'a', + ' ', 'm', 'o', 'n', 's', 't', 'e', 'r', 0x10, 0x85, 0x1a, 0x02, 0x03, + 0x06, 0x05, 0x07, 0x09, 0x00, 0x18, 0x24, 0x21, '1', '2', ':', '3', + '4', ':', '5', '6', 0x39, 'I', 'n', ' ', 't', 'h', 'e', ' ', 'V', + 'o', 'l', 'c', 'a', 'n', 'o', 0x02, 0x02, 0x09, 0x08, 0x02, 0x00, + 0x05, 0x0c, 0x7b, 0x10, 0xc8, 0x01); + return ColumnVector.fromStructs(VARIANT_TYPE, variant(metadata, value)); + } + + private static ColumnVector makeXyzVariantColumn() { + List m1 = bytes(0x01, 0x02, 0x00, 0x01, 0x02, 'x', 'y'); + List v1 = bytes( + 0x02, 0x02, 0x00, 0x01, 0x00, 0x05, 0x08, + 0x14, 0x07, 0x00, 0x00, 0x00, + 0x09, 'h', 'i'); + List m2 = bytes(0x01, 0x02, 0x00, 0x01, 0x02, 'x', 'z'); + List v2 = bytes( + 0x02, 0x02, 0x00, 0x01, 0x00, 0x05, 0x0a, + 0x14, 0x2a, 0x00, 0x00, 0x00, + 0x14, 0x63, 0x00, 0x00, 0x00); + List m3 = bytes(0x01, 0x01, 0x00, 0x01, 'y'); + List v3 = bytes(0x02, 0x01, 0x00, 0x00, 0x04, 0x0d, 'z', 'z', 'z'); + + return ColumnVector.fromStructs( + VARIANT_TYPE, + variant(m1, v1), + variant(m2, v2), + variant(m3, v3)); + } + + @Test + void isAvailable() { + assertTrue(VariantUtils.isAvailable()); + } + + @Test + void extractStringField() { + try (ColumnVector variant = makeXyzVariantColumn(); + ColumnVector result = VariantUtils.extractVariantField(variant, "y", DType.STRING); + ColumnVector expected = ColumnVector.fromStrings("hi", null, "zzz")) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void extractIntField() { + try (ColumnVector variant = makeXyzVariantColumn(); + ColumnVector result = VariantUtils.extractVariantField(variant, "x", DType.INT32); + ColumnVector expected = ColumnVector.fromBoxedInts(7, 42, null)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void extractDollarPrefixedPath() { + try (ColumnVector variant = makeXyzVariantColumn(); + ColumnVector result = VariantUtils.extractVariantField(variant, "$.x", DType.INT32); + ColumnVector expected = ColumnVector.fromBoxedInts(7, 42, null)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void extractNestedStringField() { + try (ColumnVector variant = makeApacheObjectNestedVariantColumn(); + ColumnVector result = VariantUtils.extractVariantField( + variant, "$.species.name", DType.STRING); + ColumnVector expected = ColumnVector.fromStrings("lava monster")) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void extractNestedInt16Field() { + try (ColumnVector variant = makeApacheObjectNestedVariantColumn(); + ColumnVector result = VariantUtils.extractVariantField( + variant, "$.species.population", DType.INT16); + ColumnVector expected = ColumnVector.fromBoxedShorts((short) 6789)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void getThenCastFieldValue() { + try (ColumnVector variant = makeXyzVariantColumn(); + ColumnVector valueBytes = VariantUtils.getVariantFieldValue(variant, "z"); + ColumnVector result = VariantUtils.castVariantValue(valueBytes, DType.INT32); + ColumnVector expected = ColumnVector.fromBoxedInts(null, 99, null)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void emptyInputProducesEmptyOutput() { + try (ColumnVector variant = ColumnVector.fromStructs(VARIANT_TYPE); + ColumnVector result = VariantUtils.extractVariantField(variant, "x", DType.INT32); + ColumnVector expected = ColumnVector.fromBoxedInts()) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void emptyPathThrows() { + try (ColumnVector variant = makeXyzVariantColumn()) { + assertThrows(CudfException.class, () -> VariantUtils.getVariantFieldValue(variant, "")); + assertThrows(CudfException.class, + () -> VariantUtils.extractVariantField(variant, "", DType.INT32)); + } + } + + @Test + void malformedPathThrows() { + try (ColumnVector variant = makeXyzVariantColumn()) { + assertThrows(CudfException.class, + () -> VariantUtils.getVariantFieldValue(variant, "$.x[0]")); + assertThrows(CudfException.class, + () -> VariantUtils.extractVariantField(variant, "$.x[0]", DType.INT32)); + } + } + + @Test + void nullPathThrows() { + try (ColumnVector variant = makeXyzVariantColumn()) { + assertThrows(NullPointerException.class, + () -> VariantUtils.getVariantFieldValue(variant, null)); + assertThrows(NullPointerException.class, + () -> VariantUtils.extractVariantField(variant, null, DType.INT32)); + } + } + + @Test + void unsupportedTargetTypeThrows() { + try (ColumnVector variant = makeXyzVariantColumn(); + ColumnVector valueBytes = VariantUtils.getVariantFieldValue(variant, "x")) { + assertThrows(IllegalArgumentException.class, + () -> VariantUtils.castVariantValue(valueBytes, DType.FLOAT64)); + assertThrows(IllegalArgumentException.class, + () -> VariantUtils.extractVariantField(variant, "x", DType.FLOAT64)); + } + } + + @Test + void parentStructNullIsPreserved() { + List metadata = bytes(0x01, 0x01, 0x00, 0x01, 'x'); + List value = bytes(0x02, 0x01, 0x00, 0x00, 0x05, 0x14, 0x07, 0x00, 0x00, 0x00); + try (ColumnVector variant = ColumnVector.fromStructs( + VARIANT_TYPE, + variant(metadata, value), + null); + ColumnVector result = VariantUtils.extractVariantField(variant, "x", DType.INT32); + ColumnVector expected = ColumnVector.fromBoxedInts(7, null)) { + assertColumnsAreEqual(expected, result); + } + } +} From 6948c21321befe96f14b1a5e8a26bcc8933e7232 Mon Sep 17 00:00:00 2001 From: Niranjan Artal Date: Fri, 26 Jun 2026 10:45:28 -0700 Subject: [PATCH 2/5] addressed review comments Signed-off-by: Niranjan Artal --- .../java/com/nvidia/spark/rapids/jni/VariantUtils.java | 7 +++++-- .../java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java | 5 +++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java index 7960fee982..565306b724 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java @@ -35,8 +35,9 @@ private VariantUtils() {} private static void validateTargetType(DType targetType) { Objects.requireNonNull(targetType, "targetType"); - if (targetType != DType.STRING && targetType != DType.INT8 && targetType != DType.INT16 && - targetType != DType.INT32 && targetType != DType.INT64) { + if (!targetType.equals(DType.STRING) && !targetType.equals(DType.INT8) && + !targetType.equals(DType.INT16) && !targetType.equals(DType.INT32) && + !targetType.equals(DType.INT64)) { throw new IllegalArgumentException("unsupported Variant target type: " + targetType + "; supported types are STRING, INT8, INT16, INT32, and INT64"); } @@ -52,6 +53,7 @@ private static void validateTargetType(DType targetType) { * @return LIST<UINT8> column of raw encoded Variant values */ public static ColumnVector getVariantFieldValue(ColumnView variantStruct, String path) { + Objects.requireNonNull(path, "path"); return new ColumnVector(getVariantFieldValue(variantStruct.getNativeView(), path)); } @@ -73,6 +75,7 @@ public static ColumnVector castVariantValue(ColumnView valueBytes, DType targetT */ public static ColumnVector extractVariantField( ColumnView variantStruct, String path, DType targetType) { + Objects.requireNonNull(path, "path"); validateTargetType(targetType); return new ColumnVector(extractVariantField( variantStruct.getNativeView(), path, targetType.getTypeId().getNativeId())); diff --git a/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java index 6d2b00b5cf..ce145c98c7 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java @@ -204,6 +204,11 @@ void unsupportedTargetTypeThrows() { } } + @Test + void nullTargetTypeThrows() { + assertThrows(NullPointerException.class, () -> VariantUtils.castVariantValue(null, null)); + } + @Test void parentStructNullIsPreserved() { List metadata = bytes(0x01, 0x01, 0x00, 0x01, 'x'); From 7e1ec7bc26a330daa74d0386fc36661eb963e7f9 Mon Sep 17 00:00:00 2001 From: Niranjan Artal Date: Fri, 26 Jun 2026 14:47:52 -0700 Subject: [PATCH 3/5] refactored code Signed-off-by: Niranjan Artal --- src/main/cpp/src/VariantUtilsJni.cpp | 6 ---- .../nvidia/spark/rapids/jni/VariantUtils.java | 24 +++++-------- .../spark/rapids/jni/VariantUtilsTest.java | 35 ++++++++++++++++--- 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/src/main/cpp/src/VariantUtilsJni.cpp b/src/main/cpp/src/VariantUtilsJni.cpp index 2a61fb3a5d..2c9fab6237 100644 --- a/src/main/cpp/src/VariantUtilsJni.cpp +++ b/src/main/cpp/src/VariantUtilsJni.cpp @@ -80,10 +80,4 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_VariantUtils_extractVar JNI_CATCH(env, 0); } -JNIEXPORT jboolean JNICALL Java_com_nvidia_spark_rapids_jni_VariantUtils_isAvailableNative(JNIEnv*, - jclass) -{ - return JNI_TRUE; -} - } // extern "C" diff --git a/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java index 565306b724..43fef99b0c 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java @@ -21,6 +21,8 @@ import ai.rapids.cudf.DType; import ai.rapids.cudf.NativeDepsLoader; +import java.util.Arrays; +import java.util.List; import java.util.Objects; /** @@ -31,15 +33,16 @@ public class VariantUtils { NativeDepsLoader.loadNativeDeps(); } + private static final List SUPPORTED_TYPES = Arrays.asList( + DType.STRING, DType.INT8, DType.INT16, DType.INT32, DType.INT64); + private VariantUtils() {} private static void validateTargetType(DType targetType) { Objects.requireNonNull(targetType, "targetType"); - if (!targetType.equals(DType.STRING) && !targetType.equals(DType.INT8) && - !targetType.equals(DType.INT16) && !targetType.equals(DType.INT32) && - !targetType.equals(DType.INT64)) { + if (!SUPPORTED_TYPES.contains(targetType)) { throw new IllegalArgumentException("unsupported Variant target type: " + targetType + - "; supported types are STRING, INT8, INT16, INT32, and INT64"); + "; supported types are " + SUPPORTED_TYPES); } } @@ -63,6 +66,7 @@ public static ColumnVector getVariantFieldValue(ColumnView variantStruct, String * {@link DType#INT64}. */ public static ColumnVector castVariantValue(ColumnView valueBytes, DType targetType) { + Objects.requireNonNull(valueBytes, "valueBytes"); validateTargetType(targetType); return new ColumnVector(castVariantValue( valueBytes.getNativeView(), targetType.getTypeId().getNativeId())); @@ -81,17 +85,6 @@ public static ColumnVector extractVariantField( variantStruct.getNativeView(), path, targetType.getTypeId().getNativeId())); } - /** - * Returns true when this JNI library was built against cuDF with Variant extraction APIs. - */ - public static boolean isAvailable() { - try { - return isAvailableNative(); - } catch (UnsatisfiedLinkError e) { - return false; - } - } - private static native long getVariantFieldValue(long variantStructHandle, String path); private static native long castVariantValue(long valueBytesHandle, int cudfTypeId); @@ -99,5 +92,4 @@ public static boolean isAvailable() { private static native long extractVariantField( long variantStructHandle, String path, int cudfTypeId); - private static native boolean isAvailableNative(); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java index ce145c98c7..7f041d7e93 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java @@ -32,7 +32,6 @@ import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; public class VariantUtilsTest { private static final ListType BINARY_TYPE = @@ -93,9 +92,13 @@ private static ColumnVector makeXyzVariantColumn() { variant(m3, v3)); } - @Test - void isAvailable() { - assertTrue(VariantUtils.isAvailable()); + private static ColumnVector makeExactWidthIntVariantColumn() { + List metadata = bytes(0x01, 0x02, 0x00, 0x01, 0x02, 'b', 'l'); + List value = bytes( + 0x02, 0x02, 0x00, 0x01, 0x00, 0x02, 0x0b, + 0x0c, 0x2a, + 0x18, 0x15, 0x81, 0xe9, 0x7d, 0xf4, 0x10, 0x22, 0x11); + return ColumnVector.fromStructs(VARIANT_TYPE, variant(metadata, value)); } @Test @@ -145,6 +148,24 @@ void extractNestedInt16Field() { } } + @Test + void extractInt8Field() { + try (ColumnVector variant = makeExactWidthIntVariantColumn(); + ColumnVector result = VariantUtils.extractVariantField(variant, "b", DType.INT8); + ColumnVector expected = ColumnVector.fromBoxedBytes((byte) 42)) { + assertColumnsAreEqual(expected, result); + } + } + + @Test + void extractInt64Field() { + try (ColumnVector variant = makeExactWidthIntVariantColumn(); + ColumnVector result = VariantUtils.extractVariantField(variant, "l", DType.INT64); + ColumnVector expected = ColumnVector.fromBoxedLongs(1234567890123456789L)) { + assertColumnsAreEqual(expected, result); + } + } + @Test void getThenCastFieldValue() { try (ColumnVector variant = makeXyzVariantColumn(); @@ -205,8 +226,12 @@ void unsupportedTargetTypeThrows() { } @Test - void nullTargetTypeThrows() { + void nullCastArgumentsThrow() { assertThrows(NullPointerException.class, () -> VariantUtils.castVariantValue(null, null)); + assertThrows(NullPointerException.class, + () -> VariantUtils.castVariantValue(null, DType.INT32)); + assertThrows(NullPointerException.class, + () -> VariantUtils.castVariantValue(null, DType.FLOAT64)); } @Test From 1612eddefbaa6e91217e358b71ae4760df0b3159 Mon Sep 17 00:00:00 2001 From: Niranjan Artal Date: Fri, 26 Jun 2026 15:12:49 -0700 Subject: [PATCH 4/5] address review comments Signed-off-by: Niranjan Artal --- src/main/cpp/src/VariantUtilsJni.cpp | 6 ++++++ .../com/nvidia/spark/rapids/jni/VariantUtils.java | 12 ++++++++++++ .../nvidia/spark/rapids/jni/VariantUtilsTest.java | 6 ++++++ 3 files changed, 24 insertions(+) diff --git a/src/main/cpp/src/VariantUtilsJni.cpp b/src/main/cpp/src/VariantUtilsJni.cpp index 2c9fab6237..2a61fb3a5d 100644 --- a/src/main/cpp/src/VariantUtilsJni.cpp +++ b/src/main/cpp/src/VariantUtilsJni.cpp @@ -80,4 +80,10 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_VariantUtils_extractVar JNI_CATCH(env, 0); } +JNIEXPORT jboolean JNICALL Java_com_nvidia_spark_rapids_jni_VariantUtils_isAvailableNative(JNIEnv*, + jclass) +{ + return JNI_TRUE; +} + } // extern "C" diff --git a/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java index 43fef99b0c..d92eaec8d3 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java @@ -85,6 +85,17 @@ public static ColumnVector extractVariantField( variantStruct.getNativeView(), path, targetType.getTypeId().getNativeId())); } + /** + * Returns true when the loaded JNI library exposes the Variant extraction entry points. + */ + public static boolean isAvailable() { + try { + return isAvailableNative(); + } catch (UnsatisfiedLinkError e) { + return false; + } + } + private static native long getVariantFieldValue(long variantStructHandle, String path); private static native long castVariantValue(long valueBytesHandle, int cudfTypeId); @@ -92,4 +103,5 @@ public static ColumnVector extractVariantField( private static native long extractVariantField( long variantStructHandle, String path, int cudfTypeId); + private static native boolean isAvailableNative(); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java index 7f041d7e93..21f96b3385 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java @@ -32,6 +32,7 @@ import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; public class VariantUtilsTest { private static final ListType BINARY_TYPE = @@ -92,6 +93,11 @@ private static ColumnVector makeXyzVariantColumn() { variant(m3, v3)); } + @Test + void isAvailable() { + assertTrue(VariantUtils.isAvailable()); + } + private static ColumnVector makeExactWidthIntVariantColumn() { List metadata = bytes(0x01, 0x02, 0x00, 0x01, 0x02, 'b', 'l'); List value = bytes( From 68ccc32477ba5a92580c705e4768f3fd80469690 Mon Sep 17 00:00:00 2001 From: Niranjan Artal Date: Fri, 26 Jun 2026 15:19:40 -0700 Subject: [PATCH 5/5] addressed review comment Signed-off-by: Niranjan Artal --- .../java/com/nvidia/spark/rapids/jni/VariantUtils.java | 2 ++ .../com/nvidia/spark/rapids/jni/VariantUtilsTest.java | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java index d92eaec8d3..94b74e3cba 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java @@ -56,6 +56,7 @@ private static void validateTargetType(DType targetType) { * @return LIST<UINT8> column of raw encoded Variant values */ public static ColumnVector getVariantFieldValue(ColumnView variantStruct, String path) { + Objects.requireNonNull(variantStruct, "variantStruct"); Objects.requireNonNull(path, "path"); return new ColumnVector(getVariantFieldValue(variantStruct.getNativeView(), path)); } @@ -79,6 +80,7 @@ public static ColumnVector castVariantValue(ColumnView valueBytes, DType targetT */ public static ColumnVector extractVariantField( ColumnView variantStruct, String path, DType targetType) { + Objects.requireNonNull(variantStruct, "variantStruct"); Objects.requireNonNull(path, "path"); validateTargetType(targetType); return new ColumnVector(extractVariantField( diff --git a/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java index 21f96b3385..5d4e7a2813 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java @@ -220,6 +220,14 @@ void nullPathThrows() { } } + @Test + void nullVariantStructThrows() { + assertThrows(NullPointerException.class, + () -> VariantUtils.getVariantFieldValue(null, "x")); + assertThrows(NullPointerException.class, + () -> VariantUtils.extractVariantField(null, "x", DType.INT32)); + } + @Test void unsupportedTargetTypeThrows() { try (ColumnVector variant = makeXyzVariantColumn();