Skip to content

Commit 41036a0

Browse files
committed
refactor(kernel): 添加一个显存的全局缓存,避免反复的 h2d 拷贝
Signed-off-by: YdrMaster <[email protected]>
1 parent 20788a3 commit 41036a0

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

src/04kernel/src/graph.cc

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,22 @@
11
#include "kernel/graph.h"
22

3+
namespace refactor {
4+
struct DataKey {
5+
Arc<hardware::Device> dev;
6+
Arc<kernel::Blob> blob;
7+
bool operator==(const DataKey &) const = default;// since C++20
8+
};
9+
}// namespace refactor
10+
11+
template<>
12+
struct std::hash<refactor::DataKey> {
13+
std::size_t operator()(refactor::DataKey const &s) const noexcept {
14+
auto hd = std::hash<decltype(s.dev)>()(s.dev),
15+
hb = std::hash<decltype(s.blob)>()(s.blob);
16+
return hd ^ (hb << 1);
17+
}
18+
};
19+
320
namespace refactor::kernel {
421

522
Graph::Graph(graph_topo::GraphTopo topology,
@@ -31,13 +48,19 @@ namespace refactor::kernel {
3148
_internal.edges,
3249
32);
3350

51+
static std::unordered_map<DataKey, Arc<hardware::Device::Blob>> CACHE;
52+
3453
for (auto i : range0_(edges_.size())) {
3554
auto const &edge = _internal.edges[i];
3655
edges_[i].name = edge.name;
3756
if (edge.data) {
38-
auto blob = device->malloc(edge.size);
39-
blob->copyFromHost(edge.data->get<void>());
40-
edges_[i].blob = std::move(blob);
57+
auto it = CACHE.find({device, edge.data});
58+
if (it == CACHE.end()) {
59+
auto blob = device->malloc(edge.size);
60+
blob->copyFromHost(edge.data->get<void>());
61+
std::tie(it, std::ignore) = CACHE.emplace(DataKey{device, edge.data}, std::move(blob));
62+
}
63+
edges_[i].blob = it->second;
4164
} else if (edges_[i].stackOffset == SIZE_MAX - 1) {
4265
edges_[i].blob = device->malloc(edge.size);
4366
}

0 commit comments

Comments
 (0)