Skip to content

Commit

Permalink
Merge pull request numpy#28361 from eendebakpt/nonzero_unit_tests
Browse files Browse the repository at this point in the history
BUG: Make np.nonzero threading safe
  • Loading branch information
ngoldbaum authored Feb 22, 2025
2 parents dc8f46d + d1c7b4a commit a1f2d58
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 6 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/compiler_sanitizers.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ jobs:
- name: Test
run: |
# pass -s to pytest to see ASAN errors and warnings, otherwise pytest captures them
ASAN_OPTIONS=detect_leaks=0:symbolize=1:strict_init_order=true:allocator_may_return_null=1:halt_on_error=1 \
ASAN_OPTIONS=detect_leaks=0:symbolize=1:strict_init_order=true:allocator_may_return_null=1 \
python -m spin test -- -v -s --timeout=600 --durations=10
clang_TSAN:
Expand Down Expand Up @@ -121,7 +121,7 @@ jobs:
- name: Test
run: |
# These tests are slow, so only run tests in files that do "import threading" to make them count
TSAN_OPTIONS=allocator_may_return_null=1:halt_on_error=1 \
TSAN_OPTIONS="allocator_may_return_null=1:suppressions=$GITHUB_WORKSPACE/tools/ci/tsan_suppressions.txt" \
python -m spin test \
`find numpy -name "test*.py" | xargs grep -l "import threading" | tr '\n' ' '` \
-- -v -s --timeout=600 --durations=10
8 changes: 4 additions & 4 deletions numpy/_core/src/multiarray/item_selection.c
Original file line number Diff line number Diff line change
Expand Up @@ -2893,10 +2893,11 @@ PyArray_Nonzero(PyArrayObject *self)
* running the fast bool count followed by this sparse path is faster
* than combining the two loops, even for larger arrays
*/
npy_intp * multi_index_end = multi_index + nonzero_count;
if (((double)nonzero_count / count) <= 0.1) {
npy_intp subsize;
npy_intp j = 0;
while (1) {
while (multi_index < multi_index_end) {
npy_memchr(data + j * stride, 0, stride, count - j,
&subsize, 1);
j += subsize;
Expand All @@ -2911,11 +2912,10 @@ PyArray_Nonzero(PyArrayObject *self)
* stalls that are very expensive on most modern processors.
*/
else {
npy_intp *multi_index_end = multi_index + nonzero_count;
npy_intp j = 0;

/* Manually unroll for GCC and maybe other compilers */
while (multi_index + 4 < multi_index_end) {
while (multi_index + 4 < multi_index_end && (j < count - 4) ) {
*multi_index = j;
multi_index += data[0] != 0;
*multi_index = j + 1;
Expand All @@ -2928,7 +2928,7 @@ PyArray_Nonzero(PyArrayObject *self)
j += 4;
}

while (multi_index < multi_index_end) {
while (multi_index < multi_index_end && (j < count) ) {
*multi_index = j;
multi_index += *data != 0;
data += stride;
Expand Down
24 changes: 24 additions & 0 deletions numpy/_core/tests/test_multithreading.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,27 @@ def closure(b):
# Reducing the number of threads means the test doesn't trigger the
# bug. Better to skip on some platforms than add a useless test.
pytest.skip("Couldn't spawn enough threads to run the test")

@pytest.mark.parametrize("dtype", [bool, int, float])
def test_nonzero(dtype):
    # Regression test for gh-28361.
    #
    # np.nonzero first uses np.count_nonzero to size its output array and
    # then, in a second pass, collects the indices of the non-zero
    # elements. If another thread mutates the array between the two
    # passes, the number of non-zero elements can differ from the count.
    #
    # The test deliberately triggers that data race (it is suppressed in
    # the TSAN CI). Its purpose is to make sure np.nonzero does not
    # produce a segmentation fault; a RuntimeError carrying the expected
    # message is an accepted outcome.
    shared = np.random.randint(4, size=10_000).astype(dtype)
    expected_message = (
        'number of non-zero array elements changed during function execution'
    )

    def worker(index):
        for _ in range(10):
            if index == 0:
                # Thread 0 is the writer: overwrite every other element
                # with a single random 0/1 value.
                shared[::2] = np.random.randint(2)
                continue

            # Every other thread is a reader racing against the writer.
            try:
                np.nonzero(shared)
            except RuntimeError as err:
                assert expected_message in str(err)

    run_threaded(
        worker,
        max_workers=10,
        pass_count=True,
        outer_iterations=50,
    )

11 changes: 11 additions & 0 deletions tools/ci/tsan_suppressions.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# This file contains suppressions for the TSAN tool
#
# Reference: https://github.com/google/sanitizers/wiki/ThreadSanitizerSuppressions

# For np.nonzero, see gh-28361
race:PyArray_Nonzero
race:count_nonzero_int
race:count_nonzero_bool
race:count_nonzero_float
race:DOUBLE_nonzero

0 comments on commit a1f2d58

Please sign in to comment.