|
1 | 1 | #include "kernel/graph.h" |
2 | 2 |
|
| 3 | +namespace refactor { |
| 4 | + struct DataKey { |
| 5 | + Arc<hardware::Device> dev; |
| 6 | + Arc<kernel::Blob> blob; |
| 7 | + bool operator==(const DataKey &) const = default;// since C++20 |
| 8 | + }; |
| 9 | +}// namespace refactor |
| 10 | + |
| 11 | +template<> |
| 12 | +struct std::hash<refactor::DataKey> { |
| 13 | + std::size_t operator()(refactor::DataKey const &s) const noexcept { |
| 14 | + auto hd = std::hash<decltype(s.dev)>()(s.dev), |
| 15 | + hb = std::hash<decltype(s.blob)>()(s.blob); |
| 16 | + return hd ^ (hb << 1); |
| 17 | + } |
| 18 | +}; |
| 19 | + |
3 | 20 | namespace refactor::kernel { |
4 | 21 |
|
5 | 22 | Graph::Graph(graph_topo::GraphTopo topology, |
@@ -31,13 +48,19 @@ namespace refactor::kernel { |
31 | 48 | _internal.edges, |
32 | 49 | 32); |
33 | 50 |
|
| 51 | + static std::unordered_map<DataKey, Arc<hardware::Device::Blob>> CACHE; |
| 52 | + |
34 | 53 | for (auto i : range0_(edges_.size())) { |
35 | 54 | auto const &edge = _internal.edges[i]; |
36 | 55 | edges_[i].name = edge.name; |
37 | 56 | if (edge.data) { |
38 | | - auto blob = device->malloc(edge.size); |
39 | | - blob->copyFromHost(edge.data->get<void>()); |
40 | | - edges_[i].blob = std::move(blob); |
| 57 | + auto it = CACHE.find({device, edge.data}); |
| 58 | + if (it == CACHE.end()) { |
| 59 | + auto blob = device->malloc(edge.size); |
| 60 | + blob->copyFromHost(edge.data->get<void>()); |
| 61 | + std::tie(it, std::ignore) = CACHE.emplace(DataKey{device, edge.data}, std::move(blob)); |
| 62 | + } |
| 63 | + edges_[i].blob = it->second; |
41 | 64 | } else if (edges_[i].stackOffset == SIZE_MAX - 1) { |
42 | 65 | edges_[i].blob = device->malloc(edge.size); |
43 | 66 | } |
|
0 commit comments