Skip to content

Commit

Permalink
Guard par_bridge against work-stealing recursion
Browse files Browse the repository at this point in the history
It doesn't make sense for the `par_bridge` producer to run nested on one
thread, since each split is capable of running the entire iterator to
completion. However, this can still happen if the serial iterator or the
consumer side make any rayon calls that would block, where it may go
into work-stealing and start another split of the `par_bridge`. With the
iterator in particular, this is a problem because we'll already be
holding the mutex, so trying to lock again will deadlock or panic.

We now set a flag in each thread when they start working on the bridge,
and bail out if we re-enter that bridge again on the same thread. The
new `par_bridge_recursion` test would previously hang almost every time
for me, but runs reliably with this new check in place.
  • Loading branch information
cuviper committed Dec 8, 2022
1 parent 78feb98 commit 168d5a7
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 26 deletions.
59 changes: 33 additions & 26 deletions src/iter/par_bridge.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Mutex;

use crate::current_num_threads;
use crate::iter::plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer};
use crate::iter::ParallelIterator;
use crate::{current_num_threads, current_thread_index};

/// Conversion trait to convert an `Iterator` to a `ParallelIterator`.
///
Expand Down Expand Up @@ -75,39 +75,27 @@ where
where
C: UnindexedConsumer<Self::Item>,
{
let split_count = AtomicUsize::new(current_num_threads());

let iter = Mutex::new(self.iter.fuse());
let num_threads = current_num_threads();
let threads_started: Vec<_> = (0..num_threads).map(|_| AtomicBool::new(false)).collect();

bridge_unindexed(
IterParallelProducer {
split_count: &split_count,
iter: &iter,
&IterParallelProducer {
split_count: AtomicUsize::new(num_threads),
iter: Mutex::new(self.iter.fuse()),
threads_started: &threads_started,
},
consumer,
)
}
}

struct IterParallelProducer<'a, Iter: Iterator> {
split_count: &'a AtomicUsize,
iter: &'a Mutex<std::iter::Fuse<Iter>>,
}

// manual clone because T doesn't need to be Clone, but the derive assumes it should be
impl<'a, Iter: Iterator + 'a> Clone for IterParallelProducer<'a, Iter> {
fn clone(&self) -> Self {
IterParallelProducer {
split_count: self.split_count,
iter: self.iter,
}
}
/// Shared state for bridging a serial `Iterator` into a `ParallelIterator`.
///
/// One instance is created per `par_bridge` drive; splits hand out `&` copies
/// of it rather than cloning (which is why `Iter` needs no `Clone` bound).
struct IterParallelProducer<'a, Iter> {
    // Remaining number of splits allowed; seeded from the pool's thread count
    // and decremented (via compare_exchange in `split`) until it reaches zero.
    split_count: AtomicUsize,
    // The serial iterator, fused so exhausted splits see `None` forever.
    // Mutex-guarded because every split pulls items from this one iterator.
    iter: Mutex<std::iter::Fuse<Iter>>,
    // One flag per pool thread, set when that thread starts folding this
    // bridge. Re-entering on the same thread means work-stealing recursion:
    // the mutex above may already be held, so the nested fold must bail out
    // (see the guard in `fold_with`). Indexed by thread index modulo length,
    // so a dynamically grown pool may share flags -- safe, just conservative.
    threads_started: &'a [AtomicBool],
}

impl<'a, Iter: Iterator + Send + 'a> UnindexedProducer for IterParallelProducer<'a, Iter>
where
Iter::Item: Send,
{
impl<Iter: Iterator + Send> UnindexedProducer for &IterParallelProducer<'_, Iter> {
type Item = Iter::Item;

fn split(self) -> (Self, Option<Self>) {
Expand All @@ -122,7 +110,7 @@ where
Ordering::SeqCst,
Ordering::SeqCst,
) {
Ok(_) => return (self.clone(), Some(self)),
Ok(_) => return (self, Some(self)),
Err(last_count) => count = last_count,
}
} else {
Expand All @@ -135,6 +123,25 @@ where
where
F: Folder<Self::Item>,
{
// Guard against work-stealing-induced recursion, in case `Iter::next()`
// calls rayon internally, so we don't deadlock our mutex. We might also
// be recursing via `folder` methods, which doesn't present a mutex hazard,
// but it's lower overhead for us to just check this once, rather than
// updating additional shared state on every mutex lock/unlock.
// (If this isn't a rayon thread, then there's no work-stealing anyway...)
if let Some(i) = current_thread_index() {
// Note: If the number of threads in the pool ever grows dynamically, then
// we'll end up sharing flags and may falsely detect recursion -- that's
// still fine for overall correctness, just not optimal for parallelism.
let thread_started = &self.threads_started[i % self.threads_started.len()];
if thread_started.swap(true, Ordering::Relaxed) {
// We can't make progress with a nested mutex, so just return and let
// the outermost loop continue with the rest of the iterator items.
// eprintln!("par_bridge recursion detected!");
return folder;
}
}

loop {
if let Ok(mut iter) = self.iter.lock() {
if let Some(it) = iter.next() {
Expand Down
30 changes: 30 additions & 0 deletions tests/par_bridge_recursion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use rayon::prelude::*;
use std::iter::once_with;

const N: usize = 100_000;

#[test]
fn par_bridge_recursion() {
    // A modest pool is all it takes: any blocked rayon call inside the
    // serial iterator can work-steal another split of the same bridge.
    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(10)
        .build()
        .unwrap();

    // Reference result computed serially, for comparison after sorting.
    let expected: Vec<_> = (0..N).map(|n| (n, n.to_string())).collect();

    pool.broadcast(|_| {
        let mut actual: Vec<_> = (0..N)
            .into_par_iter()
            .flat_map(|n| {
                let item = once_with(move || {
                    // Using rayon within the serial iterator creates an opportunity for
                    // work-stealing to make par_bridge's mutex accidentally recursive.
                    rayon::join(move || n, move || n.to_string())
                });
                item.par_bridge()
            })
            .collect();
        actual.par_sort_unstable();
        assert_eq!(expected, actual);
    });
}

0 comments on commit 168d5a7

Please sign in to comment.