|
1 | 1 | from onnx_ir import ir
|
2 | 2 | from onnx_ir.passes.base import GraphTransformPass
|
| 3 | +from onnx_ir.traversal import iterate_graph |
3 | 4 |
|
4 | 5 |
|
5 | 6 | class DeduplicateInitializersPass(GraphTransformPass):
|
6 | 7 | """
|
7 |
| - This pass removes duplicate initializer tensors from the graph. |
| 8 | + Graph transformation pass to remove duplicate initializer tensors. |
8 | 9 |
|
9 |
| - It identifies duplicates based on a content-based fingerprint consisting of: |
10 |
| - - Tensor byte content (`tobytes()`) |
11 |
| - - Data type (`dtype`) |
| 10 | + Identifies duplicates based on: |
| 11 | + - Data type |
12 | 12 | - Shape
|
| 13 | + - Byte content (used only if dtype and shape match) |
13 | 14 |
|
14 |
| - All duplicates are replaced with the first (canonical) occurrence, and node |
15 |
| - inputs referring to redundant initializers are updated accordingly. |
| 15 | + Updates all node inputs (including subgraphs) to refer to the canonical tensor. |
16 | 16 | """
|
17 | 17 |
|
18 | 18 | def apply(self, graph: ir.Graph) -> ir.Graph:
|
19 |
| - seen = {} # Maps (tobytes, dtype, shape) -> canonical initializer name |
20 |
| - name_map = {} # Maps duplicate initializer name -> canonical name |
| 19 | + seen = {} # (dtype, shape) → {tobytes: name} |
| 20 | + name_map = {} # Duplicate name → canonical name |
21 | 21 |
|
22 |
| - # Iterate over all initializers in the graph |
| 22 | + # Iterate through initializers and group by dtype and shape first |
23 | 23 | for initializer in list(graph.initializers.values()):
|
24 |
| - key = ( |
25 |
| - initializer.const_value.tobytes(), # Content fingerprint |
26 |
| - initializer.const_value.dtype, # Data type |
27 |
| - tuple(initializer.const_value.shape), # Shape tuple |
28 |
| - ) |
29 |
| - |
30 |
| - if key in seen: |
31 |
| - # Found a duplicate: store the name mapping and remove it from graph |
32 |
| - canonical_name = seen[key] |
| 24 | + dtype = initializer.const_value.dtype |
| 25 | + shape = tuple(initializer.const_value.shape) |
| 26 | + content = initializer.const_value.tobytes() |
| 27 | + |
| 28 | + if (dtype, shape) not in seen: |
| 29 | + seen[(dtype, shape)] = {} |
| 30 | + |
| 31 | + group = seen[(dtype, shape)] |
| 32 | + if content in group: |
| 33 | + # Duplicate found |
| 34 | + canonical_name = group[content] |
33 | 35 | name_map[initializer.name] = canonical_name
|
34 | 36 | graph.initializers.pop(initializer.name)
|
35 | 37 | else:
|
36 |
| - # First time seeing this tensor → keep it |
37 |
| - seen[key] = initializer.name |
| 38 | + group[content] = initializer.name |
38 | 39 |
|
39 |
| - # Update node inputs to use the canonical initializer names |
40 |
| - for node in graph: |
| 40 | + # Update all node inputs (including subgraphs) |
| 41 | + for node in iterate_graph(graph): |
41 | 42 | for i, input_value in enumerate(node.inputs):
|
42 | 43 | if input_value is not None and input_value.name in name_map:
|
43 |
| - # Replace input with the deduplicated initializer |
44 |
| - new_name = name_map[input_value.name] |
45 |
| - replacement = graph.initializers[new_name] |
| 44 | + canonical = name_map[input_value.name] |
| 45 | + replacement = graph.initializers[canonical] |
46 | 46 | node.replace_input_with(i, replacement)
|
47 | 47 |
|
48 | 48 | return graph
|
49 |
| - |
0 commit comments