Skip to content

Commit db83742

Browse files
committed
Fix HashSet::get_or_insert_with
1 parent 97c2140 commit db83742

File tree

1 file changed

+83
-7
lines changed

1 file changed

+83
-7
lines changed

src/set.rs

+83-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use core::hash::{BuildHasher, Hash};
77
use core::iter::{Chain, FusedIterator};
88
use core::ops::{BitAnd, BitOr, BitXor, Sub};
99

10-
use super::map::{self, DefaultHashBuilder, HashMap, Keys};
10+
use super::map::{self, make_hash, DefaultHashBuilder, HashMap, Keys, RawEntryMut};
1111
use crate::raw::{Allocator, Global, RawExtractIf};
1212

1313
// Future Optimization (FIXME!)
@@ -955,6 +955,12 @@ where
955955
/// Inserts a value computed from `f` into the set if the given `value` is
956956
/// not present, then returns a reference to the value in the set.
957957
///
958+
/// # Panics
959+
///
960+
/// Panics if the value from the function and the provided lookup value
961+
/// are not equivalent or have different hashes. See [`Equivalent`]
962+
/// and [`Hash`] for more information.
963+
///
958964
/// # Examples
959965
///
960966
/// ```
@@ -969,20 +975,37 @@ where
969975
/// assert_eq!(value, pet);
970976
/// }
971977
/// assert_eq!(set.len(), 4); // a new "fish" was inserted
978+
/// assert!(set.contains("fish"));
972979
/// ```
973980
#[cfg_attr(feature = "inline-more", inline)]
974981
pub fn get_or_insert_with<Q: ?Sized, F>(&mut self, value: &Q, f: F) -> &T
975982
where
976983
Q: Hash + Equivalent<T>,
977984
F: FnOnce(&Q) -> T,
978985
{
986+
#[cold]
987+
#[inline(never)]
988+
fn assert_failed() {
989+
panic!(
990+
"the value from the function and the lookup value \
991+
must be equivalent and have the same hash"
992+
);
993+
}
994+
979995
// Although the raw entry gives us `&mut T`, we only return `&T` to be consistent with
980996
// `get`. Key mutation is "raw" because you're not supposed to affect `Eq` or `Hash`.
981-
self.map
982-
.raw_entry_mut()
983-
.from_key(value)
984-
.or_insert_with(|| (f(value), ()))
985-
.0
997+
let hash = make_hash::<Q, S>(&self.map.hash_builder, value);
998+
let raw_entry_builder = self.map.raw_entry_mut();
999+
match raw_entry_builder.from_key_hashed_nocheck(hash, value) {
1000+
RawEntryMut::Occupied(entry) => entry.into_key(),
1001+
RawEntryMut::Vacant(entry) => {
1002+
let insert_value = f(value);
1003+
if !value.equivalent(&insert_value) {
1004+
assert_failed();
1005+
}
1006+
entry.insert_hashed_nocheck(hash, insert_value, ()).0
1007+
}
1008+
}
9861009
}
9871010

9881011
/// Gets the given value's corresponding entry in the set for in-place manipulation.
@@ -2492,7 +2515,7 @@ fn assert_covariance() {
24922515
#[cfg(test)]
24932516
mod test_set {
24942517
use super::super::map::DefaultHashBuilder;
2495-
use super::HashSet;
2518+
use super::{make_hash, Equivalent, HashSet};
24962519
use std::vec::Vec;
24972520

24982521
#[test]
@@ -2958,4 +2981,57 @@ mod test_set {
29582981
// (and without the `map`, it does not).
29592982
let mut _set: HashSet<_> = (0..3).map(|_| ()).collect();
29602983
}
2984+
2985+
#[test]
2986+
fn duplicate_insert() {
2987+
let mut set = HashSet::new();
2988+
set.insert(1);
2989+
set.get_or_insert_with(&1, |_| 1);
2990+
set.get_or_insert_with(&1, |_| 1);
2991+
assert!([1].iter().eq(set.iter()));
2992+
}
2993+
2994+
#[test]
2995+
#[should_panic]
2996+
fn some_invalid_equivalent() {
2997+
use core::hash::{Hash, Hasher};
2998+
struct Invalid {
2999+
count: u32,
3000+
other: u32,
3001+
}
3002+
3003+
struct InvalidRef {
3004+
count: u32,
3005+
other: u32,
3006+
}
3007+
3008+
impl PartialEq for Invalid {
3009+
fn eq(&self, other: &Self) -> bool {
3010+
self.count == other.count && self.other == other.other
3011+
}
3012+
}
3013+
impl Eq for Invalid {}
3014+
3015+
impl Equivalent<Invalid> for InvalidRef {
3016+
fn equivalent(&self, key: &Invalid) -> bool {
3017+
self.count == key.count && self.other == key.other
3018+
}
3019+
}
3020+
impl Hash for Invalid {
3021+
fn hash<H: Hasher>(&self, state: &mut H) {
3022+
self.count.hash(state);
3023+
}
3024+
}
3025+
impl Hash for InvalidRef {
3026+
fn hash<H: Hasher>(&self, state: &mut H) {
3027+
self.count.hash(state);
3028+
}
3029+
}
3030+
let mut set: HashSet<Invalid> = HashSet::new();
3031+
let key = InvalidRef { count: 1, other: 1 };
3032+
let value = Invalid { count: 1, other: 2 };
3033+
if make_hash(set.hasher(), &key) == make_hash(set.hasher(), &value) {
3034+
set.get_or_insert_with(&key, |_| value);
3035+
}
3036+
}
29613037
}

0 commit comments

Comments
 (0)