Skip to content

Commit faf1891

Browse files
committed
Auto merge of #104818 - scottmcm:refactor-extend-func, r=the8472
Stop peeling the last iteration of the loop in `Vec::resize_with` `resize_with` uses the `ExtendWith` code that peels the last iteration: https://github.com/rust-lang/rust/blob/341d8b8a2c290b4535e965867e876b095461ff6e/library/alloc/src/vec/mod.rs#L2525-L2529 But that's kinda weird for `ExtendFunc` because it does the same thing on the last iteration anyway: https://github.com/rust-lang/rust/blob/341d8b8a2c290b4535e965867e876b095461ff6e/library/alloc/src/vec/mod.rs#L2494-L2502 So this just has it use the normal `extend`-from-`TrustedLen` code instead. r? `@ghost`
2 parents c0e9c86 + 9d68a1a commit faf1891

File tree

7 files changed

+106
-44
lines changed

7 files changed

+106
-44
lines changed

library/alloc/src/vec/mod.rs

+35-11
Original file line numberDiff line numberDiff line change
@@ -2163,7 +2163,7 @@ impl<T, A: Allocator> Vec<T, A> {
21632163
{
21642164
let len = self.len();
21652165
if new_len > len {
2166-
self.extend_with(new_len - len, ExtendFunc(f));
2166+
self.extend_trusted(iter::repeat_with(f).take(new_len - len));
21672167
} else {
21682168
self.truncate(new_len);
21692169
}
@@ -2491,16 +2491,6 @@ impl<T: Clone> ExtendWith<T> for ExtendElement<T> {
24912491
}
24922492
}
24932493

2494-
struct ExtendFunc<F>(F);
2495-
impl<T, F: FnMut() -> T> ExtendWith<T> for ExtendFunc<F> {
2496-
fn next(&mut self) -> T {
2497-
(self.0)()
2498-
}
2499-
fn last(mut self) -> T {
2500-
(self.0)()
2501-
}
2502-
}
2503-
25042494
impl<T, A: Allocator> Vec<T, A> {
25052495
#[cfg(not(no_global_oom_handling))]
25062496
/// Extend the vector by `n` values, using the given generator.
@@ -2870,6 +2860,40 @@ impl<T, A: Allocator> Vec<T, A> {
28702860
}
28712861
}
28722862

2863+
// specific extend for `TrustedLen` iterators, called both by the specializations
2864+
// and internal places where resolving specialization makes compilation slower
2865+
#[cfg(not(no_global_oom_handling))]
2866+
fn extend_trusted(&mut self, iterator: impl iter::TrustedLen<Item = T>) {
2867+
let (low, high) = iterator.size_hint();
2868+
if let Some(additional) = high {
2869+
debug_assert_eq!(
2870+
low,
2871+
additional,
2872+
"TrustedLen iterator's size hint is not exact: {:?}",
2873+
(low, high)
2874+
);
2875+
self.reserve(additional);
2876+
unsafe {
2877+
let ptr = self.as_mut_ptr();
2878+
let mut local_len = SetLenOnDrop::new(&mut self.len);
2879+
iterator.for_each(move |element| {
2880+
ptr::write(ptr.add(local_len.current_len()), element);
2881+
// Since the loop executes user code which can panic we have to update
2882+
// the length every step to correctly drop what we've written.
2883+
// NB can't overflow since we would have had to alloc the address space
2884+
local_len.increment_len(1);
2885+
});
2886+
}
2887+
} else {
2888+
// Per TrustedLen contract a `None` upper bound means that the iterator length
2889+
// truly exceeds usize::MAX, which would eventually lead to a capacity overflow anyway.
2890+
// Since the other branch already panics eagerly (via `reserve()`) we do the same here.
2891+
// This avoids additional codegen for a fallback code path which would eventually
2892+
// panic anyway.
2893+
panic!("capacity overflow");
2894+
}
2895+
}
2896+
28732897
/// Creates a splicing iterator that replaces the specified range in the vector
28742898
/// with the given `replace_with` iterator and yields the removed items.
28752899
/// `replace_with` does not need to be the same length as `range`.

library/alloc/src/vec/set_len_on_drop.rs

+5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ impl<'a> SetLenOnDrop<'a> {
1818
pub(super) fn increment_len(&mut self, increment: usize) {
1919
self.local_len += increment;
2020
}
21+
22+
#[inline]
23+
pub(super) fn current_len(&self) -> usize {
24+
self.local_len
25+
}
2126
}
2227

2328
impl Drop for SetLenOnDrop<'_> {

library/alloc/src/vec/spec_extend.rs

+2-32
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
use crate::alloc::Allocator;
22
use core::iter::TrustedLen;
3-
use core::ptr::{self};
43
use core::slice::{self};
54

6-
use super::{IntoIter, SetLenOnDrop, Vec};
5+
use super::{IntoIter, Vec};
76

87
// Specialization trait used for Vec::extend
98
pub(super) trait SpecExtend<T, I> {
@@ -24,36 +23,7 @@ where
2423
I: TrustedLen<Item = T>,
2524
{
2625
default fn spec_extend(&mut self, iterator: I) {
27-
// This is the case for a TrustedLen iterator.
28-
let (low, high) = iterator.size_hint();
29-
if let Some(additional) = high {
30-
debug_assert_eq!(
31-
low,
32-
additional,
33-
"TrustedLen iterator's size hint is not exact: {:?}",
34-
(low, high)
35-
);
36-
self.reserve(additional);
37-
unsafe {
38-
let mut ptr = self.as_mut_ptr().add(self.len());
39-
let mut local_len = SetLenOnDrop::new(&mut self.len);
40-
iterator.for_each(move |element| {
41-
ptr::write(ptr, element);
42-
ptr = ptr.add(1);
43-
// Since the loop executes user code which can panic we have to bump the pointer
44-
// after each step.
45-
// NB can't overflow since we would have had to alloc the address space
46-
local_len.increment_len(1);
47-
});
48-
}
49-
} else {
50-
// Per TrustedLen contract a `None` upper bound means that the iterator length
51-
// truly exceeds usize::MAX, which would eventually lead to a capacity overflow anyway.
52-
// Since the other branch already panics eagerly (via `reserve()`) we do the same here.
53-
// This avoids additional codegen for a fallback code path which would eventually
54-
// panic anyway.
55-
panic!("capacity overflow");
56-
}
26+
self.extend_trusted(iterator)
5727
}
5828
}
5929

library/core/src/iter/adapters/take.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ where
7575
#[inline]
7676
fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
7777
where
78-
Self: Sized,
7978
Fold: FnMut(Acc, Self::Item) -> R,
8079
R: Try<Output = Acc>,
8180
{
@@ -100,6 +99,26 @@ where
10099

101100
impl_fold_via_try_fold! { fold -> try_fold }
102101

102+
#[inline]
103+
fn for_each<F: FnMut(Self::Item)>(mut self, f: F) {
104+
// The default implementation would use a unit accumulator, so we can
105+
// avoid a stateful closure by folding over the remaining number
106+
// of items we wish to return instead.
107+
fn check<'a, Item>(
108+
mut action: impl FnMut(Item) + 'a,
109+
) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
110+
move |more, x| {
111+
action(x);
112+
more.checked_sub(1)
113+
}
114+
}
115+
116+
let remaining = self.n;
117+
if remaining > 0 {
118+
self.iter.try_fold(remaining - 1, check(f));
119+
}
120+
}
121+
103122
#[inline]
104123
#[rustc_inherit_overflow_checks]
105124
fn advance_by(&mut self, n: usize) -> Result<(), usize> {

library/core/src/iter/sources/repeat_with.rs

+17
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::iter::{FusedIterator, TrustedLen};
2+
use crate::ops::Try;
23

34
/// Creates a new iterator that repeats elements of type `A` endlessly by
45
/// applying the provided closure, the repeater, `F: FnMut() -> A`.
@@ -89,6 +90,22 @@ impl<A, F: FnMut() -> A> Iterator for RepeatWith<F> {
8990
fn size_hint(&self) -> (usize, Option<usize>) {
9091
(usize::MAX, None)
9192
}
93+
94+
#[inline]
95+
fn try_fold<Acc, Fold, R>(&mut self, mut init: Acc, mut fold: Fold) -> R
96+
where
97+
Fold: FnMut(Acc, Self::Item) -> R,
98+
R: Try<Output = Acc>,
99+
{
100+
// This override isn't strictly needed, but avoids the need to optimize
101+
// away the `next`-always-returns-`Some` and emphasizes that the `?`
102+
// is the only way to exit the loop.
103+
104+
loop {
105+
let item = (self.repeater)();
106+
init = fold(init, item)?;
107+
}
108+
}
92109
}
93110

94111
#[stable(feature = "iterator_repeat_with", since = "1.28.0")]

library/core/tests/iter/adapters/take.rs

+20
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,23 @@ fn test_take_try_folds() {
146146
assert_eq!(iter.try_for_each(Err), Err(2));
147147
assert_eq!(iter.try_for_each(Err), Ok(()));
148148
}
149+
150+
#[test]
151+
fn test_byref_take_consumed_items() {
152+
let mut inner = 10..90;
153+
154+
let mut count = 0;
155+
inner.by_ref().take(0).for_each(|_| count += 1);
156+
assert_eq!(count, 0);
157+
assert_eq!(inner, 10..90);
158+
159+
let mut count = 0;
160+
inner.by_ref().take(10).for_each(|_| count += 1);
161+
assert_eq!(count, 10);
162+
assert_eq!(inner, 20..90);
163+
164+
let mut count = 0;
165+
inner.by_ref().take(100).for_each(|_| count += 1);
166+
assert_eq!(count, 70);
167+
assert_eq!(inner, 90..90);
168+
}

src/test/codegen/repeat-trusted-len.rs

+7
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,10 @@ pub fn repeat_take_collect() -> Vec<u8> {
1111
// CHECK: call void @llvm.memset.{{.+}}({{i8\*|ptr}} {{.*}}align 1{{.*}} %{{[0-9]+}}, i8 42, i{{[0-9]+}} 100000, i1 false)
1212
iter::repeat(42).take(100000).collect()
1313
}
14+
15+
// CHECK-LABEL: @repeat_with_take_collect
16+
#[no_mangle]
17+
pub fn repeat_with_take_collect() -> Vec<u8> {
18+
// CHECK: call void @llvm.memset.{{.+}}({{i8\*|ptr}} {{.*}}align 1{{.*}} %{{[0-9]+}}, i8 13, i{{[0-9]+}} 12345, i1 false)
19+
iter::repeat_with(|| 13).take(12345).collect()
20+
}

0 commit comments

Comments
 (0)