@@ -7,7 +7,7 @@ use core::hash::{BuildHasher, Hash};
7
7
use core:: iter:: { Chain , FusedIterator } ;
8
8
use core:: ops:: { BitAnd , BitOr , BitXor , Sub } ;
9
9
10
- use super :: map:: { self , DefaultHashBuilder , HashMap , Keys } ;
10
+ use super :: map:: { self , make_hash , DefaultHashBuilder , HashMap , Keys , RawEntryMut } ;
11
11
use crate :: raw:: { Allocator , Global , RawExtractIf } ;
12
12
13
13
// Future Optimization (FIXME!)
@@ -955,6 +955,12 @@ where
955
955
/// Inserts a value computed from `f` into the set if the given `value` is
956
956
/// not present, then returns a reference to the value in the set.
957
957
///
958
+ /// # Panics
959
+ ///
960
+ /// Panics if the value from the function and the provided lookup value
961
+ /// are not equivalent or have different hashes. See [`Equivalent`]
962
+ /// and [`Hash`] for more information.
963
+ ///
958
964
/// # Examples
959
965
///
960
966
/// ```
@@ -969,20 +975,37 @@ where
969
975
/// assert_eq!(value, pet);
970
976
/// }
971
977
/// assert_eq!(set.len(), 4); // a new "fish" was inserted
978
+ /// assert!(set.contains("fish"));
972
979
/// ```
973
980
#[ cfg_attr( feature = "inline-more" , inline) ]
974
981
pub fn get_or_insert_with < Q : ?Sized , F > ( & mut self , value : & Q , f : F ) -> & T
975
982
where
976
983
Q : Hash + Equivalent < T > ,
977
984
F : FnOnce ( & Q ) -> T ,
978
985
{
986
+ #[ cold]
987
+ #[ inline( never) ]
988
+ fn assert_failed ( ) {
989
+ panic ! (
990
+ "the value from the function and the lookup value \
991
+ must be equivalent and have the same hash"
992
+ ) ;
993
+ }
994
+
979
995
// Although the raw entry gives us `&mut T`, we only return `&T` to be consistent with
980
996
// `get`. Key mutation is "raw" because you're not supposed to affect `Eq` or `Hash`.
981
- self . map
982
- . raw_entry_mut ( )
983
- . from_key ( value)
984
- . or_insert_with ( || ( f ( value) , ( ) ) )
985
- . 0
997
+ let hash = make_hash :: < Q , S > ( & self . map . hash_builder , value) ;
998
+ let raw_entry_builder = self . map . raw_entry_mut ( ) ;
999
+ match raw_entry_builder. from_key_hashed_nocheck ( hash, value) {
1000
+ RawEntryMut :: Occupied ( entry) => entry. into_key ( ) ,
1001
+ RawEntryMut :: Vacant ( entry) => {
1002
+ let insert_value = f ( value) ;
1003
+ if !value. equivalent ( & insert_value) {
1004
+ assert_failed ( ) ;
1005
+ }
1006
+ entry. insert_hashed_nocheck ( hash, insert_value, ( ) ) . 0
1007
+ }
1008
+ }
986
1009
}
987
1010
988
1011
/// Gets the given value's corresponding entry in the set for in-place manipulation.
@@ -2492,7 +2515,7 @@ fn assert_covariance() {
2492
2515
#[ cfg( test) ]
2493
2516
mod test_set {
2494
2517
use super :: super :: map:: DefaultHashBuilder ;
2495
- use super :: HashSet ;
2518
+ use super :: { make_hash , Equivalent , HashSet } ;
2496
2519
use std:: vec:: Vec ;
2497
2520
2498
2521
#[ test]
@@ -2958,4 +2981,57 @@ mod test_set {
2958
2981
// (and without the `map`, it does not).
2959
2982
let mut _set: HashSet < _ > = ( 0 ..3 ) . map ( |_| ( ) ) . collect ( ) ;
2960
2983
}
2984
+
2985
+ #[ test]
2986
+ fn duplicate_insert ( ) {
2987
+ let mut set = HashSet :: new ( ) ;
2988
+ set. insert ( 1 ) ;
2989
+ set. get_or_insert_with ( & 1 , |_| 1 ) ;
2990
+ set. get_or_insert_with ( & 1 , |_| 1 ) ;
2991
+ assert ! ( [ 1 ] . iter( ) . eq( set. iter( ) ) ) ;
2992
+ }
2993
+
2994
+ #[ test]
2995
+ #[ should_panic]
2996
+ fn some_invalid_equivalent ( ) {
2997
+ use core:: hash:: { Hash , Hasher } ;
2998
+ struct Invalid {
2999
+ count : u32 ,
3000
+ other : u32 ,
3001
+ }
3002
+
3003
+ struct InvalidRef {
3004
+ count : u32 ,
3005
+ other : u32 ,
3006
+ }
3007
+
3008
+ impl PartialEq for Invalid {
3009
+ fn eq ( & self , other : & Self ) -> bool {
3010
+ self . count == other. count && self . other == other. other
3011
+ }
3012
+ }
3013
+ impl Eq for Invalid { }
3014
+
3015
+ impl Equivalent < Invalid > for InvalidRef {
3016
+ fn equivalent ( & self , key : & Invalid ) -> bool {
3017
+ self . count == key. count && self . other == key. other
3018
+ }
3019
+ }
3020
+ impl Hash for Invalid {
3021
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
3022
+ self . count . hash ( state) ;
3023
+ }
3024
+ }
3025
+ impl Hash for InvalidRef {
3026
+ fn hash < H : Hasher > ( & self , state : & mut H ) {
3027
+ self . count . hash ( state) ;
3028
+ }
3029
+ }
3030
+ let mut set: HashSet < Invalid > = HashSet :: new ( ) ;
3031
+ let key = InvalidRef { count : 1 , other : 1 } ;
3032
+ let value = Invalid { count : 1 , other : 2 } ;
3033
+ if make_hash ( set. hasher ( ) , & key) == make_hash ( set. hasher ( ) , & value) {
3034
+ set. get_or_insert_with ( & key, |_| value) ;
3035
+ }
3036
+ }
2961
3037
}
0 commit comments