Skip to content

Commit 914de39

Browse files
committed
Impl Hash for Id
1 parent f4832a1 commit 914de39

File tree

6 files changed

+91
-11
lines changed

6 files changed

+91
-11
lines changed

src/cache/location_hasher.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
use core::{fmt::Display, ops::BitXor, panic::Location};
22

33
#[derive(Default)]
4-
pub struct LocationHasher {
4+
pub struct FxHasher {
55
hash: u64,
66
}
77

88
const K: u64 = 0x517cc1b727220a95;
99

10-
impl core::hash::Hasher for LocationHasher {
10+
impl core::hash::Hasher for FxHasher {
1111
#[inline]
1212
fn finish(&self) -> u64 {
1313
self.hash
1414
}
1515

1616
#[inline]
1717
fn write(&mut self, _bytes: &[u8]) {
18-
unimplemented!("LocationHasher only hashes u64, (u32 and usize as u64 cast).")
18+
unimplemented!("(this) FxHasher only hashes u64, (u32 and usize as u64 cast).")
1919
}
2020

2121
#[inline]

src/cache/owned_cache.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@ use std::collections::HashMap;
33

44
use std::rc::Rc;
55

6-
use crate::{flag::AllocFlag, Alloc, Device, HashLocation, LocationHasher, ShallowCopy, Shape};
6+
use crate::{flag::AllocFlag, Alloc, Device, HashLocation, FxHasher, ShallowCopy, Shape};
77

88
#[derive(Debug, Clone)]
99
pub struct Cache {
1010
pub nodes:
11-
HashMap<HashLocation<'static>, Rc<dyn core::any::Any>, BuildHasherDefault<LocationHasher>>,
11+
HashMap<HashLocation<'static>, Rc<dyn core::any::Any>, BuildHasherDefault<FxHasher>>,
1212
}
1313

1414
impl Default for Cache {

src/devices/opencl/unified.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use crate::{AllocFlag, DeviceError};
66

77
use super::CLPtr;
88
use crate::{
9-
Base, Buffer, CachedCPU, CachedModule, Device, HashLocation, LocationHasher, OnDropBuffer,
9+
Base, Buffer, CachedCPU, CachedModule, Device, HashLocation, FxHasher, OnDropBuffer,
1010
OpenCL, Shape, UnifiedMemChain, CPU,
1111
};
1212
use min_cl::api::{create_buffer, MemFlags};
@@ -62,7 +62,7 @@ pub unsafe fn to_cached_unified<OclMods, CpuMods, T, S>(
6262
cache: &mut HashMap<
6363
HashLocation<'static>,
6464
Rc<dyn core::any::Any>,
65-
BuildHasherDefault<LocationHasher>,
65+
BuildHasherDefault<FxHasher>,
6666
>,
6767
location: HashLocation<'static>,
6868
) -> crate::Result<*mut c_void>
@@ -126,7 +126,7 @@ pub fn construct_buffer<'a, OclMods: OnDropBuffer, CpuMods: OnDropBuffer, T: 'st
126126
cache: &mut HashMap<
127127
HashLocation<'static>,
128128
Rc<dyn core::any::Any>,
129-
BuildHasherDefault<LocationHasher>,
129+
BuildHasherDefault<FxHasher>,
130130
>,
131131
location: HashLocation<'static>,
132132
) -> crate::Result<Buffer<'a, T, OpenCL<OclMods>, S>> {

src/id.rs

+7
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ pub struct Id {
2020
pub len: usize,
2121
}
2222

23+
impl core::hash::Hash for Id {
24+
#[inline]
25+
fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
26+
self.id.hash(state);
27+
}
28+
}
29+
2330
impl Deref for Id {
2431
type Target = u64;
2532

src/modules/autograd/tape.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use core::{any::Any, fmt::Debug, hash::BuildHasherDefault, panic::Location};
22
use std::collections::{HashMap, HashSet};
33

44
use crate::{
5-
Alloc, Buffer, Buffers, HasId, HashLocation, LazyGraph, LocationHasher, Parents, Shape,
5+
Alloc, Buffer, Buffers, HasId, HashLocation, LazyGraph, FxHasher, Parents, Shape,
66
TapeActions, UpdateArgs, WriteBuf,
77
};
88

@@ -16,7 +16,7 @@ pub struct Tape {
1616
// Caches gradients for each [`Buffer`]'s id ([`Ident`]).
1717
// pub grads: Gradients,
1818
grad_fns: Vec<GradFn>,
19-
grad_fns_loc: HashMap<HashLocation<'static>, GradFn, BuildHasherDefault<LocationHasher>>,
19+
grad_fns_loc: HashMap<HashLocation<'static>, GradFn, BuildHasherDefault<FxHasher>>,
2020
grad_fn_order: Vec<HashLocation<'static>>,
2121

2222
unconsumed_locations: HashSet<HashLocation<'static>>,

src/parents.rs

+74-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
1-
use crate::{HasId, Id, UpdateArg};
1+
use core::hash::Hasher;
2+
3+
use crate::{FxHasher, HasId, Id, UpdateArg};
24

35
pub trait Parents<const N: usize>: AllParents {
46
fn ids(&self) -> [Id; N];
57
fn maybe_ids(&self) -> [Option<Id>; N];
8+
9+
#[inline]
10+
fn hash(&self) -> u64 {
11+
let mut hasher = FxHasher::default();
12+
core::hash::Hash::hash(&self.ids(), &mut hasher);
13+
hasher.finish()
14+
}
615
}
716

817
impl Parents<0> for () {
@@ -103,3 +112,67 @@ impl<T: HasId + Copy, const N: usize> Parents<N> for [T; N] {
103112
impl<T: HasId + Copy, const N: usize> AllParents for [T; N] {}
104113

105114
pub trait AllParents {}
115+
116+
#[cfg(test)]
117+
mod tests {
118+
119+
#[cfg(feature = "std")]
120+
#[ignore = "slow"]
121+
#[test]
122+
fn test_collisions() {
123+
use std::collections::HashSet;
124+
use crate::{Id, Parents};
125+
126+
let handle = std::thread::spawn(|| {
127+
let mut hashes = HashSet::new();
128+
for i in 20000..30000u16 {
129+
for j in 20000..30000 {
130+
let i = Id {
131+
id: i as u64,
132+
len: 0,
133+
};
134+
let j = Id {
135+
id: j,
136+
len: 0,
137+
};
138+
let parents = (i, j);
139+
let hash = parents.hash();
140+
if hashes.contains(&(hash)) {
141+
panic!("collision {}, {}, hash: {hash}", i.id, j.id,);
142+
}
143+
hashes.insert(hash);
144+
}
145+
if i % 1000 == 0 {
146+
println!("i: {}", i);
147+
}
148+
}
149+
hashes
150+
});
151+
let mut hashes = HashSet::new();
152+
153+
for i in 10000..20000 {
154+
for j in 10000..20000 {
155+
let i = Id {
156+
id: i,
157+
len: 0,
158+
};
159+
let j = Id {
160+
id: j,
161+
len: 0,
162+
};
163+
let parents = (i, j);
164+
let hash = parents.hash();
165+
if hashes.contains(&(hash)) {
166+
panic!("collision");
167+
}
168+
hashes.insert(hash);
169+
}
170+
if i % 1000 == 0 {
171+
println!("i: {}", i);
172+
}
173+
}
174+
175+
let other_hashes = handle.join().unwrap();
176+
assert_eq!(hashes.intersection(&other_hashes).count(), 0);
177+
}
178+
}

0 commit comments

Comments
 (0)