diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index 5709c6ea40..171d12e63f 100644 --- a/src/main/cpp/CMakeLists.txt +++ b/src/main/cpp/CMakeLists.txt @@ -244,6 +244,7 @@ add_library( src/datetime_truncate.cu src/decimal_utils.cu src/exception_with_row_index_utilities.cu + src/find_in_set.cu src/format_float.cu src/from_json_to_raw_map.cu src/from_json_to_structs.cu diff --git a/src/main/cpp/src/StringUtilsJni.cpp b/src/main/cpp/src/StringUtilsJni.cpp index 051c70d865..a69ec65f8d 100644 --- a/src/main/cpp/src/StringUtilsJni.cpp +++ b/src/main/cpp/src/StringUtilsJni.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ */ #include "cudf_jni_apis.hpp" +#include "find_in_set.hpp" #include "uuid.hpp" extern "C" { @@ -31,4 +32,43 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_StringUtils_randomUUIDs } JNI_CATCH(env, 0); } + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_StringUtils_findInSet(JNIEnv* env, + jclass, + jlong sets, + jstring word) +{ + JNI_NULL_CHECK(env, sets, "sets column is null", 0); + JNI_NULL_CHECK(env, word, "word is null", 0); + + JNI_TRY + { + cudf::jni::auto_set_device(env); + auto const input = reinterpret_cast(sets); + cudf::jni::native_jstring native_word(env, word); + return cudf::jni::release_as_jlong(spark_rapids_jni::find_in_set( + cudf::strings_column_view{*input}, std::string(native_word.get(), native_word.size_bytes()))); + } + JNI_CATCH(env, 0); +} + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_StringUtils_findInSetRepeated( + JNIEnv* env, jclass, jlong sets, jstring word, jint max_distinct_sets) +{ + JNI_NULL_CHECK(env, sets, "sets column is null", 0); + JNI_NULL_CHECK(env, word, "word is null", 0); + + JNI_TRY + { + cudf::jni::auto_set_device(env); + auto const input = reinterpret_cast(sets); + cudf::jni::native_jstring native_word(env, word); + auto result = spark_rapids_jni::find_in_set_repeated( + cudf::strings_column_view{*input}, + std::string(native_word.get(), native_word.size_bytes()), + max_distinct_sets); + return result ? cudf::jni::release_as_jlong(std::move(result)) : 0; + } + JNI_CATCH(env, 0); +} } diff --git a/src/main/cpp/src/find_in_set.cu b/src/main/cpp/src/find_in_set.cu new file mode 100644 index 0000000000..f782c0877a --- /dev/null +++ b/src/main/cpp/src/find_in_set.cu @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "find_in_set.hpp" +#include "nvtx_ranges.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include + +namespace spark_rapids_jni { +namespace { + +__device__ bool token_matches(cudf::string_view set, + cudf::size_type token_start, + cudf::size_type token_size, + cudf::string_view word) +{ + auto const word_size = word.size_bytes(); + if (token_size != word_size) { return false; } + if (word_size == 0) { return true; } + + auto const* token = set.data() + token_start; + auto const* word_data = word.data(); + for (cudf::size_type idx = 0; idx < word_size; ++idx) { + if (token[idx] != word_data[idx]) { return false; } + } + return true; +} + +__device__ cudf::size_type find_token_position(cudf::string_view set, cudf::string_view word) +{ + auto const* set_data = set.data(); + auto const set_size = set.size_bytes(); + cudf::size_type token_pos = 1; + cudf::size_type token_start = 0; + + for (cudf::size_type idx = 0; idx <= set_size; ++idx) { + if (idx == set_size || set_data[idx] == ',') { + if (token_matches(set, token_start, idx - token_start, word)) { return token_pos; } + ++token_pos; + token_start = idx + 1; + } + } + return 0; +} + +} // namespace + +std::unique_ptr find_in_set(cudf::strings_column_view const& sets, + std::string const& word, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + SRJ_FUNC_RANGE(); + + auto const row_count = sets.size(); + if (row_count == 0) { return cudf::make_empty_column(cudf::type_id::INT32); } + + auto results = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32}, + row_count, + cudf::copy_bitmask(sets.parent(), stream, mr), + sets.null_count(), + stream, + mr); + auto const d_results = results->mutable_view().data(); + + if (word.find(',') != std::string::npos) { + thrust::fill_n(rmm::exec_policy(stream), d_results, row_count, cudf::size_type{0}); + results->set_null_count(sets.null_count()); + return results; + } + + auto word_scalar = cudf::make_string_scalar(word, stream); + auto const& word_string_scalar = static_cast(*word_scalar); + auto const d_word = cudf::string_view(word_string_scalar.data(), word_string_scalar.size()); + + auto const sets_column = cudf::column_device_view::create(sets.parent(), stream); + auto const d_sets = *sets_column; + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(row_count), + d_results, + [d_sets, d_word] __device__(cudf::size_type idx) { + if (d_sets.is_null(idx)) { return cudf::size_type{0}; } + return find_token_position(d_sets.element(idx), d_word); + }); + results->set_null_count(sets.null_count()); + return results; +} + +std::unique_ptr find_in_set_repeated(cudf::strings_column_view const& sets, + std::string const& word, + cudf::size_type max_distinct_sets, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + SRJ_FUNC_RANGE(); + + auto const row_count = sets.size(); + if (row_count == 0) { return cudf::make_empty_column(cudf::type_id::INT32); } + + auto make_zero_or_null_result = [&]() { + auto results = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32}, + row_count, + cudf::copy_bitmask(sets.parent(), stream, mr), + sets.null_count(), + stream, + mr); + thrust::fill_n(rmm::exec_policy(stream), + results->mutable_view().data(), + row_count, + cudf::size_type{0}); + results->set_null_count(sets.null_count()); + return results; + }; + + if (word.find(',') != std::string::npos) { return make_zero_or_null_result(); } + + auto dictionary = cudf::dictionary::encode( + sets.parent(), cudf::data_type{cudf::type_id::INT32}, stream, mr); + auto const dictionary_view = cudf::dictionary_column_view{dictionary->view()}; + auto const keys_size = dictionary_view.keys_size(); + if (keys_size > max_distinct_sets) { return nullptr; } + if (keys_size == 0) { return make_zero_or_null_result(); } + + auto key_positions = + find_in_set(cudf::strings_column_view{dictionary_view.keys()}, word, stream, mr); + + auto gather_map = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32}, + row_count, + cudf::mask_state::UNALLOCATED, + stream, + mr); + auto const d_gather_map = gather_map->mutable_view().data(); + auto const d_dictionary = + cudf::column_device_view::create(dictionary_view.parent(), stream); + auto const d_indices = + cudf::column_device_view::create(dictionary_view.indices(), stream); + + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(row_count), + d_gather_map, + [d_dictionary = *d_dictionary, d_indices = *d_indices] __device__( + cudf::size_type idx) { + return d_dictionary.is_null(idx) ? cudf::size_type{0} + : d_indices.element(idx); + }); + + auto gathered_table = cudf::gather(cudf::table_view{{key_positions->view()}}, + gather_map->view(), + cudf::out_of_bounds_policy::DONT_CHECK, + stream, + mr); + auto result = std::move(gathered_table->release()[0]); + result->set_null_mask(cudf::copy_bitmask(sets.parent(), stream, mr), sets.null_count()); + return result; +} + +} // namespace spark_rapids_jni diff --git a/src/main/cpp/src/find_in_set.hpp b/src/main/cpp/src/find_in_set.hpp new file mode 100644 index 0000000000..9925d6bf95 --- /dev/null +++ b/src/main/cpp/src/find_in_set.hpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace spark_rapids_jni { + +std::unique_ptr find_in_set( + cudf::strings_column_view const& sets, + std::string const& word, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); + +std::unique_ptr find_in_set_repeated( + cudf::strings_column_view const& sets, + std::string const& word, + cudf::size_type max_distinct_sets, + rmm::cuda_stream_view stream = cudf::get_default_stream(), + rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref()); + +} // namespace spark_rapids_jni diff --git a/src/main/java/com/nvidia/spark/rapids/jni/StringUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/StringUtils.java index 7bc8238af1..4e31ac5b79 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/StringUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/StringUtils.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,12 +17,19 @@ package com.nvidia.spark.rapids.jni; import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; import ai.rapids.cudf.Cuda; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.NativeDepsLoader; import java.lang.management.ManagementFactory; import java.util.Arrays; +import java.util.Objects; import java.util.concurrent.atomic.AtomicLong; public class StringUtils { + static { + NativeDepsLoader.loadNativeDeps(); + } // Stores the sequence ID of calling generate UUIDs. private static AtomicLong sequence = new AtomicLong(0); @@ -88,5 +95,49 @@ public static ColumnVector randomUUIDsWithSeed(int rowCount, long seed) { return new ColumnVector(randomUUIDs(rowCount, seed)); } + /** + * Return the 1-based position of {@code word} in each comma-delimited row of {@code sets}. + * Missing words produce 0, and null rows remain null. + * + * @param sets String column containing comma-delimited tokens + * @param word Literal token to search for + * @return INT32 ColumnVector containing Spark find_in_set-compatible positions + */ + public static ColumnVector findInSet(ColumnView sets, String word) { + Objects.requireNonNull(sets, "sets"); + Objects.requireNonNull(word, "word"); + if (!sets.getType().equals(DType.STRING)) { + throw new IllegalArgumentException("sets must be a string column"); + } + return new ColumnVector(findInSet(sets.getNativeView(), word)); + } + + /** + * Return the 1-based position of {@code word} in each comma-delimited row of {@code sets}. + * This variant dictionary-encodes repeated {@code sets} values and only scans up to + * {@code maxDistinctSets} distinct set strings. Missing words produce 0, null rows remain null, + * and a null return value means the distinct set count exceeded {@code maxDistinctSets}. + * + * @param sets String column containing comma-delimited tokens + * @param word Literal token to search for + * @param maxDistinctSets Maximum number of distinct set strings to scan + * @return INT32 ColumnVector containing Spark find_in_set-compatible positions, or null if the + * distinct set count exceeds maxDistinctSets + */ + public static ColumnVector findInSetRepeated(ColumnView sets, String word, int maxDistinctSets) { + Objects.requireNonNull(sets, "sets"); + Objects.requireNonNull(word, "word"); + if (!sets.getType().equals(DType.STRING)) { + throw new IllegalArgumentException("sets must be a string column"); + } + if (maxDistinctSets < 0) { + throw new IllegalArgumentException("maxDistinctSets must be non-negative"); + } + long result = findInSetRepeated(sets.getNativeView(), word, maxDistinctSets); + return result == 0 ? null : new ColumnVector(result); + } + private static native long randomUUIDs(int rowCount, long seed); + private static native long findInSet(long sets, String word); + private static native long findInSetRepeated(long sets, String word, int maxDistinctSets); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/StringUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/StringUtilsTest.java index 8199e23e49..e741bdd515 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/StringUtilsTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/StringUtilsTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -93,4 +93,67 @@ void testUuidSameSeed() { assertColumnsAreEqual(round1, round2); } } + + @Test + void testFindInSet() { + try ( + ColumnVector sets = ColumnVector.fromStrings( + "a,b,c", + "b,a,b", + "x,y", + "", + ",", + "a,,b", + null, + "\u00e9,b", + "a,\u00e9", + "aa,a", + "a,aa"); + ColumnVector expectedB = ColumnVector.fromBoxedInts( + 2, 1, 0, 0, 0, 3, null, 2, 0, 0, 0); + ColumnVector expectedEmpty = ColumnVector.fromBoxedInts( + 0, 0, 0, 1, 1, 2, null, 0, 0, 0, 0); + ColumnVector expectedAccent = ColumnVector.fromBoxedInts( + 0, 0, 0, 0, 0, 0, null, 1, 2, 0, 0); + ColumnVector expectedComma = ColumnVector.fromBoxedInts( + 0, 0, 0, 0, 0, 0, null, 0, 0, 0, 0); + ColumnVector actualB = StringUtils.findInSet(sets, "b"); + ColumnVector actualEmpty = StringUtils.findInSet(sets, ""); + ColumnVector actualAccent = StringUtils.findInSet(sets, "\u00e9"); + ColumnVector actualComma = StringUtils.findInSet(sets, "a,b")) { + assertColumnsAreEqual(expectedB, actualB); + assertColumnsAreEqual(expectedEmpty, actualEmpty); + assertColumnsAreEqual(expectedAccent, actualAccent); + assertColumnsAreEqual(expectedComma, actualComma); + } + } + + @Test + void testFindInSetRepeated() { + try ( + ColumnVector sets = ColumnVector.fromStrings( + "a,b,c", + "x,b", + "a,b,c", + null, + "b", + "", + ","); + ColumnVector expectedB = ColumnVector.fromBoxedInts( + 2, 2, 2, null, 1, 0, 0); + ColumnVector expectedEmpty = ColumnVector.fromBoxedInts( + 0, 0, 0, null, 0, 1, 1); + ColumnVector expectedComma = ColumnVector.fromBoxedInts( + 0, 0, 0, null, 0, 0, 0); + ColumnVector actualB = StringUtils.findInSetRepeated(sets, "b", 5); + ColumnVector actualEmpty = StringUtils.findInSetRepeated(sets, "", 5); + ColumnVector actualComma = StringUtils.findInSetRepeated(sets, "a,b", 5)) { + assertColumnsAreEqual(expectedB, actualB); + assertColumnsAreEqual(expectedEmpty, actualEmpty); + assertColumnsAreEqual(expectedComma, actualComma); + try (ColumnVector tooManyDistinct = StringUtils.findInSetRepeated(sets, "b", 4)) { + assertNull(tooManyDistinct); + } + } + } }