Skip to content

Commit 23ca8f3

Browse files
author
houzhenggang
committed
hashtable save and load
1 parent a20755b commit 23ca8f3

1 file changed

Lines changed: 8 additions & 11 deletions

File tree

fbgemm_gpu/src/dram_kv_embedding_cache/SynchronizedShardedMap.h

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,10 @@ class SynchronizedShardedMap {
7575
out.close();
7676

7777
// save every mempool
78-
for (std::size_t i = 0; i < num_shards; ++i) {
78+
for (std::size_t shard_id = 0; shard_id < getNumShards(); ++shard_id) {
7979
std::string pool_filename = filename + ".pool." + std::to_string(i);
80-
mempools_[i]->serialize(pool_filename);
80+
auto wlock = shards_[shard_id].wlock();
81+
mempools_[shard_id]->serialize(pool_filename);
8182
}
8283
}
8384

@@ -95,18 +96,14 @@ class SynchronizedShardedMap {
9596
throw std::runtime_error("Shard count mismatch between file and map");
9697
}
9798

98-
// first deserialize mempool
99-
for (std::size_t i = 0; i < num_shards; ++i) {
100-
std::string pool_filename = filename + ".pool." + std::to_string(i);
101-
mempools_[i]->deserialize(pool_filename);
102-
}
103-
104-
// load map from mempool
10599
for (std::size_t shard_id = 0; shard_id < getNumShards(); ++shard_id) {
100+
std::string pool_filename = filename + ".pool." + std::to_string(i);
106101
auto wlock = shards_[shard_id].wlock();
107-
auto* mempool = mempools_[shard_id].get();
102+
// first deserialize mempool
103+
mempools_[shard_id]->deserialize(pool_filename);
104+
// load map from mempool
108105
wlock->clear();
109-
mempool->for_each_block([&wlock](void* block) {
106+
mempools_[shard_id]->for_each_block([&wlock](void* block) {
110107
auto key = FixedBlockPool::get_key(block);
111108
wlock->emplace(key, reinterpret_cast<V>(block));
112109
});

0 commit comments

Comments
 (0)