Skip to content

Commit f99fa0c

Browse files
Address feedback: optimize tensor fingerprinting and traverse subgraphs
1 parent 191ddb4 commit f99fa0c

File tree

1 file changed

+25
-26
lines changed

1 file changed

+25
-26
lines changed
Lines changed: 25 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,49 +1,48 @@
11
from onnx_ir import ir
22
from onnx_ir.passes.base import GraphTransformPass
3+
from onnx_ir.traversal import iterate_graph
34

45

56
class DeduplicateInitializersPass(GraphTransformPass):
67
"""
7-
This pass removes duplicate initializer tensors from the graph.
8+
Graph transformation pass to remove duplicate initializer tensors.
89
9-
It identifies duplicates based on a content-based fingerprint consisting of:
10-
- Tensor byte content (`tobytes()`)
11-
- Data type (`dtype`)
10+
Identifies duplicates based on:
11+
- Data type
1212
- Shape
13+
- Byte content (used only if dtype and shape match)
1314
14-
All duplicates are replaced with the first (canonical) occurrence, and node
15-
inputs referring to redundant initializers are updated accordingly.
15+
Updates all node inputs (including subgraphs) to refer to the canonical tensor.
1616
"""
1717

1818
def apply(self, graph: ir.Graph) -> ir.Graph:
19-
seen = {} # Maps (tobytes, dtype, shape) -> canonical initializer name
20-
name_map = {} # Maps duplicate initializer name -> canonical name
19+
seen = {} # (dtype, shape) → {tobytes: name}
20+
name_map = {} # Duplicate name → canonical name
2121

22-
# Iterate over all initializers in the graph
22+
# Iterate through initializers and group by dtype and shape first
2323
for initializer in list(graph.initializers.values()):
24-
key = (
25-
initializer.const_value.tobytes(), # Content fingerprint
26-
initializer.const_value.dtype, # Data type
27-
tuple(initializer.const_value.shape), # Shape tuple
28-
)
29-
30-
if key in seen:
31-
# Found a duplicate: store the name mapping and remove it from graph
32-
canonical_name = seen[key]
24+
dtype = initializer.const_value.dtype
25+
shape = tuple(initializer.const_value.shape)
26+
content = initializer.const_value.tobytes()
27+
28+
if (dtype, shape) not in seen:
29+
seen[(dtype, shape)] = {}
30+
31+
group = seen[(dtype, shape)]
32+
if content in group:
33+
# Duplicate found
34+
canonical_name = group[content]
3335
name_map[initializer.name] = canonical_name
3436
graph.initializers.pop(initializer.name)
3537
else:
36-
# First time seeing this tensor → keep it
37-
seen[key] = initializer.name
38+
group[content] = initializer.name
3839

39-
# Update node inputs to use the canonical initializer names
40-
for node in graph:
40+
# Update all node inputs (including subgraphs)
41+
for node in iterate_graph(graph):
4142
for i, input_value in enumerate(node.inputs):
4243
if input_value is not None and input_value.name in name_map:
43-
# Replace input with the deduplicated initializer
44-
new_name = name_map[input_value.name]
45-
replacement = graph.initializers[new_name]
44+
canonical = name_map[input_value.name]
45+
replacement = graph.initializers[canonical]
4646
node.replace_input_with(i, replacement)
4747

4848
return graph
49-

0 commit comments

Comments
 (0)