diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt
index d2459d690d..8c5e281bc3 100644
--- a/src/main/cpp/CMakeLists.txt
+++ b/src/main/cpp/CMakeLists.txt
@@ -230,6 +230,7 @@ add_library(
   src/StringUtilsJni.cpp
   src/SubStringIndexJni.cpp
   src/TaskPriorityJni.cpp
+  src/VariantUtilsJni.cpp
   src/ZOrderJni.cpp
   src/iceberg/IcebergBucketJni.cpp
   src/iceberg/IcebergDateTimeUtilJni.cpp
diff --git a/src/main/cpp/src/VariantUtilsJni.cpp b/src/main/cpp/src/VariantUtilsJni.cpp
new file mode 100644
index 0000000000..2a61fb3a5d
--- /dev/null
+++ b/src/main/cpp/src/VariantUtilsJni.cpp
@@ -0,0 +1,98 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "cudf_jni_apis.hpp"
+#include "jni_utils.hpp"
+
+// NOTE(review): the four <...> header names were stripped from this patch text; the ones below
+// are inferred from the symbols used here (the variant header path is a guess) -- confirm.
+#include <cudf/column/column_view.hpp>
+#include <cudf/io/experimental/variant.hpp>
+#include <cudf/utilities/default_stream.hpp>
+#include <cudf/utilities/memory_resource.hpp>
+
+extern "C" {
+
+// JNI entry point for VariantUtils.getVariantFieldValue: runs get_variant_field on the column
+// view behind `variant_struct_handle` and returns the result through release_as_jlong.
+JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_VariantUtils_getVariantFieldValue(
+  JNIEnv* env, jclass, jlong variant_struct_handle, jstring j_path)
+{
+  JNI_NULL_CHECK(env, variant_struct_handle, "variant struct column is null", 0);
+  JNI_NULL_CHECK(env, j_path, "path is null", 0);
+  JNI_TRY
+  {
+    cudf::jni::auto_set_device(env);
+    auto const& variant_struct = *reinterpret_cast<cudf::column_view const*>(variant_struct_handle);
+    cudf::jni::native_jstring path(env, j_path);
+    return cudf::jni::release_as_jlong(
+      cudf::io::parquet::experimental::get_variant_field(variant_struct,
+                                                         path.get(),
+                                                         cudf::get_default_stream(),
+                                                         cudf::get_current_device_resource_ref()));
+  }
+  JNI_CATCH(env, 0);
+}
+
+// JNI entry point for VariantUtils.castVariantValue: runs cast_variant on the column view
+// behind `value_bytes_handle`, with `cudf_type_id` converted to the target cudf::data_type.
+JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_VariantUtils_castVariantValue(
+  JNIEnv* env, jclass, jlong value_bytes_handle, jint cudf_type_id)
+{
+  JNI_NULL_CHECK(env, value_bytes_handle, "value bytes column is null", 0);
+  JNI_TRY
+  {
+    cudf::jni::auto_set_device(env);
+    auto const& value_bytes = *reinterpret_cast<cudf::column_view const*>(value_bytes_handle);
+    return cudf::jni::release_as_jlong(cudf::io::parquet::experimental::cast_variant(
+      value_bytes,
+      cudf::data_type{static_cast<cudf::type_id>(cudf_type_id)},
+      cudf::get_default_stream(),
+      cudf::get_current_device_resource_ref()));
+  }
+  JNI_CATCH(env, 0);
+}
+
+// JNI entry point for VariantUtils.extractVariantField: runs extract_variant_field on the
+// column view behind `variant_struct_handle` for `j_path` and the requested `cudf_type_id`.
+JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_VariantUtils_extractVariantField(
+  JNIEnv* env, jclass, jlong variant_struct_handle, jstring j_path, jint cudf_type_id)
+{
+  JNI_NULL_CHECK(env, variant_struct_handle, "variant struct column is null", 0);
+  JNI_NULL_CHECK(env, j_path, "path is null", 0);
+  JNI_TRY
+  {
+    cudf::jni::auto_set_device(env);
+    auto const& variant_struct = *reinterpret_cast<cudf::column_view const*>(variant_struct_handle);
+    cudf::jni::native_jstring path(env, j_path);
+    return cudf::jni::release_as_jlong(cudf::io::parquet::experimental::extract_variant_field(
+      variant_struct,
+      path.get(),
+      cudf::data_type{static_cast<cudf::type_id>(cudf_type_id)},
+      cudf::get_default_stream(),
+      cudf::get_current_device_resource_ref()));
+  }
+  JNI_CATCH(env, 0);
+}
+
+// Returns JNI_TRUE unconditionally; the Java side treats a link failure as "not available".
+JNIEXPORT jboolean JNICALL Java_com_nvidia_spark_rapids_jni_VariantUtils_isAvailableNative(JNIEnv*,
+                                                                                           jclass)
+{
+  return JNI_TRUE;
+}
+
+}  // extern "C"
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java
new file mode 100644
index 0000000000..94b74e3cba
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/VariantUtils.java
@@ -0,0 +1,124 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.jni;
+
+import ai.rapids.cudf.ColumnVector;
+import ai.rapids.cudf.ColumnView;
+import ai.rapids.cudf.DType;
+import ai.rapids.cudf.NativeDepsLoader;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * JNI bridge to cuDF's experimental Parquet Variant field extraction APIs.
+ */
+public class VariantUtils {
+  static {
+    NativeDepsLoader.loadNativeDeps();
+  }
+
+  private static final List<DType> SUPPORTED_TYPES = Arrays.asList(
+      DType.STRING, DType.INT8, DType.INT16, DType.INT32, DType.INT64);
+
+  private VariantUtils() {}
+
+  private static void validateTargetType(DType targetType) {
+    Objects.requireNonNull(targetType, "targetType");
+    if (!SUPPORTED_TYPES.contains(targetType)) {
+      throw new IllegalArgumentException("unsupported Variant target type: " + targetType
+          + "; supported types are " + SUPPORTED_TYPES);
+    }
+  }
+
+  /**
+   * Extract raw Variant-encoded value bytes at {@code path} from a Variant struct column.
+   *
+   * @param variantStruct Variant materialization: STRUCT(metadata LIST&lt;UINT8&gt;,
+   *                      value LIST&lt;UINT8&gt;, optional shredded children...)
+   * @param path JSONPath-like path accepted by cuDF's Variant extractor. Paths are expected to
+   *             be ASCII object-field paths like {@code x}, {@code $.x}, or {@code $.x.y}.
+   * @return LIST&lt;UINT8&gt; column of raw encoded Variant values
+   * @throws NullPointerException if {@code variantStruct} or {@code path} is null
+   */
+  public static ColumnVector getVariantFieldValue(ColumnView variantStruct, String path) {
+    Objects.requireNonNull(variantStruct, "variantStruct");
+    Objects.requireNonNull(path, "path");
+    return new ColumnVector(getVariantFieldValue(variantStruct.getNativeView(), path));
+  }
+
+  /**
+   * Decode raw Variant-encoded value bytes into {@code targetType}. Supported target types are
+   * {@link DType#STRING}, {@link DType#INT8}, {@link DType#INT16}, {@link DType#INT32}, and
+   * {@link DType#INT64}.
+   *
+   * @param valueBytes column of raw Variant-encoded values, e.g. the result of
+   *                   {@link #getVariantFieldValue(ColumnView, String)}
+   * @param targetType type to decode into; must be one of the supported types
+   * @return column wrapping the native result of the cast
+   * @throws NullPointerException if {@code valueBytes} or {@code targetType} is null
+   * @throws IllegalArgumentException if {@code targetType} is not a supported type
+   */
+  public static ColumnVector castVariantValue(ColumnView valueBytes, DType targetType) {
+    Objects.requireNonNull(valueBytes, "valueBytes");
+    validateTargetType(targetType);
+    return new ColumnVector(castVariantValue(
+        valueBytes.getNativeView(), targetType.getTypeId().getNativeId()));
+  }
+
+  /**
+   * Extract a Variant field and decode it into {@code targetType} in one native call.
+   * Supported target types are {@link DType#STRING}, {@link DType#INT8}, {@link DType#INT16},
+   * {@link DType#INT32}, and {@link DType#INT64}.
+   *
+   * @param variantStruct Variant struct column to read from
+   * @param path path of the field to extract
+   * @param targetType type to decode into; must be one of the supported types
+   * @return column wrapping the native result of the extraction
+   * @throws NullPointerException if any argument is null
+   * @throws IllegalArgumentException if {@code targetType} is not a supported type
+   */
+  public static ColumnVector extractVariantField(
+      ColumnView variantStruct, String path, DType targetType) {
+    Objects.requireNonNull(variantStruct, "variantStruct");
+    Objects.requireNonNull(path, "path");
+    validateTargetType(targetType);
+    return new ColumnVector(extractVariantField(
+        variantStruct.getNativeView(), path, targetType.getTypeId().getNativeId()));
+  }
+
+  /**
+   * Returns true when the loaded JNI library exposes the Variant extraction entry points.
+   */
+  public static boolean isAvailable() {
+    try {
+      return isAvailableNative();
+    } catch (UnsatisfiedLinkError e) {
+      return false;
+    }
+  }
+
+  private static native long getVariantFieldValue(long variantStructHandle, String path);
+
+  private static native long castVariantValue(long valueBytesHandle, int cudfTypeId);
+
+  private static native long extractVariantField(
+      long variantStructHandle, String path, int cudfTypeId);
+
+  private static native boolean isAvailableNative();
+}
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java
new file mode 100644
index 0000000000..5d4e7a2813
--- /dev/null
+++ b/src/test/java/com/nvidia/spark/rapids/jni/VariantUtilsTest.java
@@ -0,0 +1,264 @@
+/*
+ * Copyright (c) 2026, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.jni;
+
+import ai.rapids.cudf.ColumnVector;
+import ai.rapids.cudf.CudfException;
+import ai.rapids.cudf.DType;
+import ai.rapids.cudf.HostColumnVector.BasicType;
+import ai.rapids.cudf.HostColumnVector.ListType;
+import ai.rapids.cudf.HostColumnVector.StructData;
+import ai.rapids.cudf.HostColumnVector.StructType;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class VariantUtilsTest {
+  private static final ListType BINARY_TYPE =
+      new ListType(true, new BasicType(false, DType.UINT8));
+  private static final StructType VARIANT_TYPE =
+      new StructType(true, Arrays.asList(BINARY_TYPE, BINARY_TYPE));
+
+  private static List<Byte> bytes(int... values) {
+    List<Byte> list = new ArrayList<>(values.length);
+    for (int value : values) {
+      list.add((byte) value);
+    }
+    return list;
+  }
+
+  private static StructData variant(List<Byte> metadata, List<Byte> value) {
+    return new StructData(metadata, value);
+  }
+
+  private static ColumnVector makeApacheObjectNestedVariantColumn() {
+    List<Byte> metadata = bytes(
+        0x01, 0x0a, 0x00, 0x02, 0x09, 0x0d, 0x17, 0x22, 0x26, 0x2e, 0x33, 0x3e,
+        0x46, 'i', 'd', 's', 'p', 'e', 'c', 'i', 'e', 's', 'n', 'a', 'm', 'e',
+        'p', 'o', 'p', 'u', 'l', 'a', 't', 'i', 'o', 'n', 'o', 'b', 's', 'e',
+        'r', 'v', 'a', 't', 'i', 'o', 'n', 't', 'i', 'm', 'e', 'l', 'o', 'c',
+        'a', 't', 'i', 'o', 'n', 'v', 'a', 'l', 'u', 'e', 't', 'e', 'm', 'p',
+        'e', 'r', 'a', 't', 'u', 'r', 'e', 'h', 'u', 'm', 'i', 'd', 'i', 't',
+        'y');
+    List<Byte> value = bytes(
+        0x02, 0x03, 0x00, 0x04, 0x01, 0x00, 0x19, 0x02, 0x46, 0x0c, 0x01,
+        0x02, 0x02, 0x02, 0x03, 0x00, 0x0d, 0x10, 0x31, 'l', 'a', 'v', 'a',
+        ' ', 'm', 'o', 'n', 's', 't', 'e', 'r', 0x10, 0x85, 0x1a, 0x02, 0x03,
+        0x06, 0x05, 0x07, 0x09, 0x00, 0x18, 0x24, 0x21, '1', '2', ':', '3',
+        '4', ':', '5', '6', 0x39, 'I', 'n', ' ', 't', 'h', 'e', ' ', 'V',
+        'o', 'l', 'c', 'a', 'n', 'o', 0x02, 0x02, 0x09, 0x08, 0x02, 0x00,
+        0x05, 0x0c, 0x7b, 0x10, 0xc8, 0x01);
+    return ColumnVector.fromStructs(VARIANT_TYPE, variant(metadata, value));
+  }
+
+  private static ColumnVector makeXyzVariantColumn() {
+    List<Byte> m1 = bytes(0x01, 0x02, 0x00, 0x01, 0x02, 'x', 'y');
+    List<Byte> v1 = bytes(
+        0x02, 0x02, 0x00, 0x01, 0x00, 0x05, 0x08,
+        0x14, 0x07, 0x00, 0x00, 0x00,
+        0x09, 'h', 'i');
+    List<Byte> m2 = bytes(0x01, 0x02, 0x00, 0x01, 0x02, 'x', 'z');
+    List<Byte> v2 = bytes(
+        0x02, 0x02, 0x00, 0x01, 0x00, 0x05, 0x0a,
+        0x14, 0x2a, 0x00, 0x00, 0x00,
+        0x14, 0x63, 0x00, 0x00, 0x00);
+    List<Byte> m3 = bytes(0x01, 0x01, 0x00, 0x01, 'y');
+    List<Byte> v3 = bytes(0x02, 0x01, 0x00, 0x00, 0x04, 0x0d, 'z', 'z', 'z');
+
+    return ColumnVector.fromStructs(
+        VARIANT_TYPE,
+        variant(m1, v1),
+        variant(m2, v2),
+        variant(m3, v3));
+  }
+
+  @Test
+  void isAvailable() {
+    assertTrue(VariantUtils.isAvailable());
+  }
+
+  private static ColumnVector makeExactWidthIntVariantColumn() {
+    List<Byte> metadata = bytes(0x01, 0x02, 0x00, 0x01, 0x02, 'b', 'l');
+    List<Byte> value = bytes(
+        0x02, 0x02, 0x00, 0x01, 0x00, 0x02, 0x0b,
+        0x0c, 0x2a,
+        0x18, 0x15, 0x81, 0xe9, 0x7d, 0xf4, 0x10, 0x22, 0x11);
+    return ColumnVector.fromStructs(VARIANT_TYPE, variant(metadata, value));
+  }
+
+  @Test
+  void extractStringField() {
+    try (ColumnVector variant = makeXyzVariantColumn();
+         ColumnVector result = VariantUtils.extractVariantField(variant, "y", DType.STRING);
+         ColumnVector expected = ColumnVector.fromStrings("hi", null, "zzz")) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+
+  @Test
+  void extractIntField() {
+    try (ColumnVector variant = makeXyzVariantColumn();
+         ColumnVector result = VariantUtils.extractVariantField(variant, "x", DType.INT32);
+         ColumnVector expected = ColumnVector.fromBoxedInts(7, 42, null)) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+
+  @Test
+  void extractDollarPrefixedPath() {
+    try (ColumnVector variant = makeXyzVariantColumn();
+         ColumnVector result = VariantUtils.extractVariantField(variant, "$.x", DType.INT32);
+         ColumnVector expected = ColumnVector.fromBoxedInts(7, 42, null)) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+
+  @Test
+  void extractNestedStringField() {
+    try (ColumnVector variant = makeApacheObjectNestedVariantColumn();
+         ColumnVector result = VariantUtils.extractVariantField(
+             variant, "$.species.name", DType.STRING);
+         ColumnVector expected = ColumnVector.fromStrings("lava monster")) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+
+  @Test
+  void extractNestedInt16Field() {
+    try (ColumnVector variant = makeApacheObjectNestedVariantColumn();
+         ColumnVector result = VariantUtils.extractVariantField(
+             variant, "$.species.population", DType.INT16);
+         ColumnVector expected = ColumnVector.fromBoxedShorts((short) 6789)) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+
+  @Test
+  void extractInt8Field() {
+    try (ColumnVector variant = makeExactWidthIntVariantColumn();
+         ColumnVector result = VariantUtils.extractVariantField(variant, "b", DType.INT8);
+         ColumnVector expected = ColumnVector.fromBoxedBytes((byte) 42)) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+
+  @Test
+  void extractInt64Field() {
+    try (ColumnVector variant = makeExactWidthIntVariantColumn();
+         ColumnVector result = VariantUtils.extractVariantField(variant, "l", DType.INT64);
+         ColumnVector expected = ColumnVector.fromBoxedLongs(1234567890123456789L)) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+
+  @Test
+  void getThenCastFieldValue() {
+    try (ColumnVector variant = makeXyzVariantColumn();
+         ColumnVector valueBytes = VariantUtils.getVariantFieldValue(variant, "z");
+         ColumnVector result = VariantUtils.castVariantValue(valueBytes, DType.INT32);
+         ColumnVector expected = ColumnVector.fromBoxedInts(null, 99, null)) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+
+  @Test
+  void emptyInputProducesEmptyOutput() {
+    try (ColumnVector variant = ColumnVector.fromStructs(VARIANT_TYPE);
+         ColumnVector result = VariantUtils.extractVariantField(variant, "x", DType.INT32);
+         ColumnVector expected = ColumnVector.fromBoxedInts()) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+
+  @Test
+  void emptyPathThrows() {
+    try (ColumnVector variant = makeXyzVariantColumn()) {
+      assertThrows(CudfException.class, () -> VariantUtils.getVariantFieldValue(variant, ""));
+      assertThrows(CudfException.class,
+          () -> VariantUtils.extractVariantField(variant, "", DType.INT32));
+    }
+  }
+
+  @Test
+  void malformedPathThrows() {
+    try (ColumnVector variant = makeXyzVariantColumn()) {
+      assertThrows(CudfException.class,
+          () -> VariantUtils.getVariantFieldValue(variant, "$.x[0]"));
+      assertThrows(CudfException.class,
+          () -> VariantUtils.extractVariantField(variant, "$.x[0]", DType.INT32));
+    }
+  }
+
+  @Test
+  void nullPathThrows() {
+    try (ColumnVector variant = makeXyzVariantColumn()) {
+      assertThrows(NullPointerException.class,
+          () -> VariantUtils.getVariantFieldValue(variant, null));
+      assertThrows(NullPointerException.class,
+          () -> VariantUtils.extractVariantField(variant, null, DType.INT32));
+    }
+  }
+
+  @Test
+  void nullVariantStructThrows() {
+    assertThrows(NullPointerException.class,
+        () -> VariantUtils.getVariantFieldValue(null, "x"));
+    assertThrows(NullPointerException.class,
+        () -> VariantUtils.extractVariantField(null, "x", DType.INT32));
+  }
+
+  @Test
+  void unsupportedTargetTypeThrows() {
+    try (ColumnVector variant = makeXyzVariantColumn();
+         ColumnVector valueBytes = VariantUtils.getVariantFieldValue(variant, "x")) {
+      assertThrows(IllegalArgumentException.class,
+          () -> VariantUtils.castVariantValue(valueBytes, DType.FLOAT64));
+      assertThrows(IllegalArgumentException.class,
+          () -> VariantUtils.extractVariantField(variant, "x", DType.FLOAT64));
+    }
+  }
+
+  @Test
+  void nullCastArgumentsThrow() {
+    assertThrows(NullPointerException.class, () -> VariantUtils.castVariantValue(null, null));
+    assertThrows(NullPointerException.class,
+        () -> VariantUtils.castVariantValue(null, DType.INT32));
+    assertThrows(NullPointerException.class,
+        () -> VariantUtils.castVariantValue(null, DType.FLOAT64));
+  }
+
+  @Test
+  void parentStructNullIsPreserved() {
+    List<Byte> metadata = bytes(0x01, 0x01, 0x00, 0x01, 'x');
+    List<Byte> value = bytes(0x02, 0x01, 0x00, 0x00, 0x05, 0x14, 0x07, 0x00, 0x00, 0x00);
+    try (ColumnVector variant = ColumnVector.fromStructs(
+             VARIANT_TYPE,
+             variant(metadata, value),
+             null);
+         ColumnVector result = VariantUtils.extractVariantField(variant, "x", DType.INT32);
+         ColumnVector expected = ColumnVector.fromBoxedInts(7, null)) {
+      assertColumnsAreEqual(expected, result);
+    }
+  }
+}