Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/main/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ add_library(
src/datetime_truncate.cu
src/decimal_utils.cu
src/exception_with_row_index_utilities.cu
src/find_in_set.cu
src/format_float.cu
src/from_json_to_raw_map.cu
src/from_json_to_structs.cu
Expand Down
42 changes: 41 additions & 1 deletion src/main/cpp/src/StringUtilsJni.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
*/

#include "cudf_jni_apis.hpp"
#include "find_in_set.hpp"
#include "uuid.hpp"

extern "C" {
Expand All @@ -31,4 +32,43 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_StringUtils_randomUUIDs
}
JNI_CATCH(env, 0);
}

// JNI bridge for StringUtils.findInSet(long, String).
// `sets` is the native address of a cudf::column_view and `word` is the token to look up.
// Returns the address of the newly created result column, or 0 after raising a Java exception.
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_StringUtils_findInSet(JNIEnv* env,
                                                                               jclass,
                                                                               jlong sets,
                                                                               jstring word)
{
  JNI_NULL_CHECK(env, sets, "sets column is null", 0);
  JNI_NULL_CHECK(env, word, "word is null", 0);

  JNI_TRY
  {
    cudf::jni::auto_set_device(env);

    auto const* sets_view = reinterpret_cast<cudf::column_view const*>(sets);

    // Copy the Java string's bytes into a std::string owned by this frame.
    cudf::jni::native_jstring word_chars(env, word);
    std::string const needle(word_chars.get(), word_chars.size_bytes());

    auto positions = spark_rapids_jni::find_in_set(cudf::strings_column_view{*sets_view}, needle);
    return cudf::jni::release_as_jlong(std::move(positions));
  }
  JNI_CATCH(env, 0);
}

// JNI bridge for StringUtils.findInSetRepeated(long, String, int).
// Returns the address of the result column, or 0 when the native call produced no column
// (find_in_set_repeated returned nullptr) or when a Java exception was raised.
JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_StringUtils_findInSetRepeated(
  JNIEnv* env, jclass, jlong sets, jstring word, jint max_distinct_sets)
{
  JNI_NULL_CHECK(env, sets, "sets column is null", 0);
  JNI_NULL_CHECK(env, word, "word is null", 0);

  JNI_TRY
  {
    cudf::jni::auto_set_device(env);

    auto const* sets_view = reinterpret_cast<cudf::column_view const*>(sets);

    // Copy the Java string's bytes into a std::string owned by this frame.
    cudf::jni::native_jstring word_chars(env, word);
    std::string const needle(word_chars.get(), word_chars.size_bytes());

    auto positions = spark_rapids_jni::find_in_set_repeated(
      cudf::strings_column_view{*sets_view}, needle, max_distinct_sets);

    // A null pointer is reported to Java as handle 0.
    if (!positions) { return 0; }
    return cudf::jni::release_as_jlong(std::move(positions));
  }
  JNI_CATCH(env, 0);
}
}
187 changes: 187 additions & 0 deletions src/main/cpp/src/find_in_set.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "find_in_set.hpp"
#include "nvtx_ranges.hpp"

#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/copying.hpp>
#include <cudf/dictionary/dictionary_column_view.hpp>
#include <cudf/dictionary/encode.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/strings/string_view.cuh>

#include <rmm/exec_policy.hpp>

#include <thrust/fill.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>

namespace spark_rapids_jni {
namespace {

// Returns true when the `token_size` bytes of `set` starting at byte offset `token_start`
// are byte-for-byte identical to `word`. An empty token matches an empty word.
__device__ bool token_matches(cudf::string_view set,
                              cudf::size_type token_start,
                              cudf::size_type token_size,
                              cudf::string_view word)
{
  auto const expected_bytes = word.size_bytes();

  // Differing lengths can never match; two empty ranges always do.
  if (expected_bytes != token_size) { return false; }
  if (expected_bytes == 0) { return true; }

  auto const* lhs = set.data() + token_start;
  auto const* rhs = word.data();
  auto const* end = rhs + expected_bytes;
  while (rhs != end) {
    if (*lhs != *rhs) { return false; }
    ++lhs;
    ++rhs;
  }
  return true;
}

// Scans `set` as a comma-delimited list and returns the 1-based ordinal of the first token
// equal to `word`, or 0 when no token matches. The end of the string closes the final token,
// so an empty `set` is treated as a single empty token.
__device__ cudf::size_type find_token_position(cudf::string_view set, cudf::string_view word)
{
  auto const* bytes    = set.data();
  auto const num_bytes = set.size_bytes();

  cudf::size_type ordinal = 1;  // 1-based index of the token currently being scanned
  cudf::size_type begin   = 0;  // byte offset where the current token starts
  cudf::size_type cursor  = 0;

  // `cursor == num_bytes` is visited on purpose: it terminates the last token.
  while (cursor <= num_bytes) {
    bool const at_boundary = (cursor == num_bytes) || (bytes[cursor] == ',');
    if (at_boundary) {
      if (token_matches(set, begin, cursor - begin, word)) { return ordinal; }
      ++ordinal;
      begin = cursor + 1;
    }
    ++cursor;
  }
  return 0;
}

} // namespace

// Computes, for every row of `sets`, the 1-based position of the comma-delimited token that
// equals `word`, or 0 when no token matches. The result is an INT32 column with one value per
// input row; its validity mask and null count are copied from `sets`. An empty input produces
// an empty INT32 column.
std::unique_ptr<cudf::column> find_in_set(cudf::strings_column_view const& sets,
std::string const& word,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
SRJ_FUNC_RANGE();

auto const row_count = sets.size();
if (row_count == 0) { return cudf::make_empty_column(cudf::type_id::INT32); }

// Allocate the output up front, carrying over the input's validity mask and null count.
auto results = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32},
row_count,
cudf::copy_bitmask(sets.parent(), stream, mr),
sets.null_count(),
stream,
mr);
auto const d_results = results->mutable_view().data<cudf::size_type>();

// Tokens are split on ',' so a word containing a comma can never equal a token:
// every row's value is 0 and no per-row scan is needed.
if (word.find(',') != std::string::npos) {
thrust::fill_n(rmm::exec_policy(stream), d_results, row_count, cudf::size_type{0});
results->set_null_count(sets.null_count());
return results;
}
Comment on lines +93 to +97

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Redundant set_null_count after make_numeric_column

cudf::make_numeric_column already accepts and stores the null count from its constructor argument (sets.null_count()). The subsequent results->set_null_count(sets.null_count()) call (repeated in both the comma fast path here and inside make_zero_or_null_result in find_in_set_repeated) is a no-op. The same redundancy appears on line 114 in the normal kernel path. These extra calls are harmless but create noise; removing them improves clarity.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


// Wrap `word` in a string scalar; `d_word` is a non-owning view built from the scalar's
// data pointer and size, so `word_scalar` has to stay alive while `d_word` is in use.
auto word_scalar = cudf::make_string_scalar(word, stream);
auto const& word_string_scalar = static_cast<cudf::string_scalar const&>(*word_scalar);
auto const d_word = cudf::string_view(word_string_scalar.data(), word_string_scalar.size());

auto const sets_column = cudf::column_device_view::create(sets.parent(), stream);
auto const d_sets = *sets_column;

// One output element per row: null rows are written as 0 (they stay masked out by the copied
// validity mask), every other row gets the position reported by find_token_position.
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<cudf::size_type>(0),
thrust::make_counting_iterator<cudf::size_type>(row_count),
d_results,
[d_sets, d_word] __device__(cudf::size_type idx) {
if (d_sets.is_null(idx)) { return cudf::size_type{0}; }
return find_token_position(d_sets.element<cudf::string_view>(idx), d_word);
});
Comment on lines +99 to +113

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 d_word device pointer outlives its owning scalar in async execution

word_scalar is a unique_ptr destroyed when find_in_set returns. The lambda captures d_word — a cudf::string_view that holds a raw pointer into the scalar's device buffer — and that buffer will be freed when word_scalar is destroyed. Since thrust::transform is asynchronous on the stream, the kernel may still be accessing d_word.data() after word_scalar's destructor runs.

This is safe only if the device memory resource uses stream-ordered deallocation (the cuDF convention with cuda_async_memory_resource). If a caller supplies a synchronous mr (or the RMM pool is configured otherwise), this is a use-after-free. Consider keeping word_scalar alive by either extending its scope past the function or storing its device buffer in the results and syncing explicitly before the scalar is freed.

results->set_null_count(sets.null_count());
return results;
}

// Variant of find_in_set that dictionary-encodes `sets` first, runs find_in_set once per
// distinct key, and then gathers the per-key positions back to one value per input row.
// Returns nullptr when the number of dictionary keys exceeds `max_distinct_sets`; otherwise
// returns an INT32 column whose validity mask and null count are copied from `sets`.
// An empty input produces an empty INT32 column.
std::unique_ptr<cudf::column> find_in_set_repeated(cudf::strings_column_view const& sets,
std::string const& word,
cudf::size_type max_distinct_sets,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
SRJ_FUNC_RANGE();

auto const row_count = sets.size();
if (row_count == 0) { return cudf::make_empty_column(cudf::type_id::INT32); }

// Builds an INT32 column of `row_count` zeros that keeps the input's validity mask.
auto make_zero_or_null_result = [&]() {
auto results = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32},
row_count,
cudf::copy_bitmask(sets.parent(), stream, mr),
sets.null_count(),
stream,
mr);
thrust::fill_n(rmm::exec_policy(stream),
results->mutable_view().data<cudf::size_type>(),
row_count,
cudf::size_type{0});
results->set_null_count(sets.null_count());
return results;
};

// Same comma shortcut as find_in_set: every row's value is 0.
if (word.find(',') != std::string::npos) { return make_zero_or_null_result(); }

auto dictionary = cudf::dictionary::encode(
sets.parent(), cudf::data_type{cudf::type_id::INT32}, stream, mr);
auto const dictionary_view = cudf::dictionary_column_view{dictionary->view()};
auto const keys_size = dictionary_view.keys_size();
// Too many distinct sets: signal the caller with nullptr instead of computing a result.
if (keys_size > max_distinct_sets) { return nullptr; }
// No keys at all. This guard also guarantees key index 0 exists for the gather below.
if (keys_size == 0) { return make_zero_or_null_result(); }

// Scan each distinct set string exactly once.
auto key_positions =
find_in_set(cudf::strings_column_view{dictionary_view.keys()}, word, stream, mr);

auto gather_map = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32},
row_count,
cudf::mask_state::UNALLOCATED,
stream,
mr);
auto const d_gather_map = gather_map->mutable_view().data<cudf::size_type>();
auto const d_dictionary =
cudf::column_device_view::create(dictionary_view.parent(), stream);
auto const d_indices =
cudf::column_device_view::create(dictionary_view.indices(), stream);

// Build the gather map: each valid row uses its dictionary index, null rows use index 0.
// The value gathered for a null row is masked out when the input null mask is re-applied.
thrust::transform(rmm::exec_policy(stream),
thrust::make_counting_iterator<cudf::size_type>(0),
thrust::make_counting_iterator<cudf::size_type>(row_count),
d_gather_map,
[d_dictionary = *d_dictionary, d_indices = *d_indices] __device__(
cudf::size_type idx) {
return d_dictionary.is_null(idx) ? cudf::size_type{0}
: d_indices.element<cudf::size_type>(idx);
});
Comment on lines +167 to +175

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Null-row gather index silently relies on keys_size >= 1

For null rows, the lambda returns cudf::size_type{0}, which becomes the gather index into key_positions. This is safe because if (keys_size == 0) { return make_zero_or_null_result(); } on line 151 guarantees at least one key when we reach the gather. However, the safety of DONT_CHECK at the gather call depends entirely on this ordering — there is no bounds assertion or comment to document the invariant. If the early-return guard is ever moved or removed, null rows will silently gather out of bounds with DONT_CHECK. A brief comment coupling the two guards would make the invariant explicit.


// Bounds checking is skipped (DONT_CHECK); this relies on the keys_size == 0 early return
// above so that index 0, used for null rows, always refers to an existing key.
auto gathered_table = cudf::gather(cudf::table_view{{key_positions->view()}},
gather_map->view(),
cudf::out_of_bounds_policy::DONT_CHECK,
stream,
mr);
auto result = std::move(gathered_table->release()[0]);
// Replace whatever mask the gather produced with a copy of the input's validity mask.
result->set_null_mask(cudf::copy_bitmask(sets.parent(), stream, mr), sets.null_count());
return result;
}

} // namespace spark_rapids_jni
42 changes: 42 additions & 0 deletions src/main/cpp/src/find_in_set.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudf/column/column.hpp>
#include <cudf/strings/strings_column_view.hpp>
#include <cudf/utilities/default_stream.hpp>
#include <cudf/utilities/memory_resource.hpp>

#include <memory>
#include <string>

namespace spark_rapids_jni {

/**
 * @brief Returns the 1-based position of `word` within each comma-delimited row of `sets`.
 *
 * Rows in which no token equals `word` produce 0. Null rows stay null in the output.
 *
 * @param sets Strings column whose rows are comma-delimited token lists
 * @param word Token to search for
 * @param stream CUDA stream on which the work is issued
 * @param mr Device memory resource used for the returned column's allocations
 * @return INT32 column with one position per input row
 */
std::unique_ptr<cudf::column> find_in_set(
cudf::strings_column_view const& sets,
std::string const& word,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

/**
 * @brief Variant of find_in_set that scans each distinct `sets` value only once.
 *
 * Rows in which no token equals `word` produce 0. Null rows stay null in the output.
 *
 * @param sets Strings column whose rows are comma-delimited token lists
 * @param word Token to search for
 * @param max_distinct_sets Largest number of distinct set strings the call will process
 * @param stream CUDA stream on which the work is issued
 * @param mr Device memory resource used for the returned column's allocations
 * @return INT32 column with one position per input row, or nullptr when the number of
 *         distinct set strings exceeds `max_distinct_sets`
 */
std::unique_ptr<cudf::column> find_in_set_repeated(
cudf::strings_column_view const& sets,
std::string const& word,
cudf::size_type max_distinct_sets,
rmm::cuda_stream_view stream = cudf::get_default_stream(),
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref());

} // namespace spark_rapids_jni
53 changes: 52 additions & 1 deletion src/main/java/com/nvidia/spark/rapids/jni/StringUtils.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,12 +17,19 @@
package com.nvidia.spark.rapids.jni;

import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.ColumnView;
import ai.rapids.cudf.Cuda;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.NativeDepsLoader;
import java.lang.management.ManagementFactory;
import java.util.Arrays;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicLong;

public class StringUtils {
static {
NativeDepsLoader.loadNativeDeps();
}

// Stores the sequence ID of calling generate UUIDs.
private static AtomicLong sequence = new AtomicLong(0);
Expand Down Expand Up @@ -88,5 +95,49 @@ public static ColumnVector randomUUIDsWithSeed(int rowCount, long seed) {
return new ColumnVector(randomUUIDs(rowCount, seed));
}

/**
 * Locate {@code word} inside every comma-delimited row of {@code sets}.
 * Each output row holds the 1-based index of the matching token, or 0 when the word is
 * absent. Null input rows produce null output rows.
 *
 * @param sets String column containing comma-delimited tokens
 * @param word Literal token to search for
 * @return INT32 ColumnVector containing Spark find_in_set-compatible positions
 * @throws NullPointerException if {@code sets} or {@code word} is null
 * @throws IllegalArgumentException if {@code sets} is not a string column
 */
public static ColumnVector findInSet(ColumnView sets, String word) {
  Objects.requireNonNull(sets, "sets");
  Objects.requireNonNull(word, "word");

  final boolean isStringColumn = sets.getType().equals(DType.STRING);
  if (!isStringColumn) {
    throw new IllegalArgumentException("sets must be a string column");
  }

  final long resultHandle = findInSet(sets.getNativeView(), word);
  return new ColumnVector(resultHandle);
}

/**
 * Locate {@code word} inside every comma-delimited row of {@code sets}, for inputs where the
 * same set string repeats across many rows. The native side dictionary-encodes {@code sets}
 * and only proceeds when the number of distinct set strings is at most
 * {@code maxDistinctSets}. Each output row holds the 1-based index of the matching token, or
 * 0 when the word is absent. Null input rows produce null output rows.
 *
 * @param sets String column containing comma-delimited tokens
 * @param word Literal token to search for
 * @param maxDistinctSets Maximum number of distinct set strings to scan
 * @return INT32 ColumnVector containing Spark find_in_set-compatible positions, or null if the
 *         distinct set count exceeds maxDistinctSets
 * @throws NullPointerException if {@code sets} or {@code word} is null
 * @throws IllegalArgumentException if {@code sets} is not a string column or
 *         {@code maxDistinctSets} is negative
 */
public static ColumnVector findInSetRepeated(ColumnView sets, String word, int maxDistinctSets) {
  Objects.requireNonNull(sets, "sets");
  Objects.requireNonNull(word, "word");

  final boolean isStringColumn = sets.getType().equals(DType.STRING);
  if (!isStringColumn) {
    throw new IllegalArgumentException("sets must be a string column");
  }
  if (maxDistinctSets < 0) {
    throw new IllegalArgumentException("maxDistinctSets must be non-negative");
  }

  final long resultHandle = findInSetRepeated(sets.getNativeView(), word, maxDistinctSets);
  // A zero handle means the native call produced no column.
  if (resultHandle == 0) {
    return null;
  }
  return new ColumnVector(resultHandle);
}

private static native long randomUUIDs(int rowCount, long seed);
private static native long findInSet(long sets, String word);
private static native long findInSetRepeated(long sets, String word, int maxDistinctSets);
}
Loading