Skip to content

Commit 8001366

Browse files
Optimize DeduplicateInitializersPass with shape/dtype grouping and subgraph traversal
Address reviewer feedback: - Optimized memory by grouping by dtype and shape before comparing values - Used iterate_graph to handle subgraphs - Validated on normal and subgraph models; deduplication works as expected Signed-off-by: Abhishek Herbert Samuel <[email protected]>
1 parent ae8f078 commit 8001366

File tree

1 file changed

+12
-24
lines changed

1 file changed

+12
-24
lines changed
Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,12 @@
1-
from onnx_ir import ir
2-
from onnx_ir.passes.base import GraphTransformPass
3-
from onnx_ir.traversal import iterate_graph
1+
from onnx_ir._core import Node, Graph
2+
from onnx_ir.traversal import RecursiveGraphIterator
43

54

6-
class DeduplicateInitializersPass(GraphTransformPass):
7-
"""
8-
Graph transformation pass to remove duplicate initializer tensors.
9-
10-
Identifies duplicates based on:
11-
- Data type
12-
- Shape
13-
- Byte content (used only if dtype and shape match)
14-
15-
Updates all node inputs (including subgraphs) to refer to the canonical tensor.
16-
"""
17-
18-
def apply(self, graph: ir.Graph) -> ir.Graph:
5+
class DeduplicateInitializersPass:
6+
def apply(self, graph: Graph) -> Graph:
197
seen = {} # (dtype, shape) → {tobytes: name}
208
name_map = {} # Duplicate name → canonical name
219

22-
# Iterate through initializers and group by dtype and shape first
2310
for initializer in list(graph.initializers.values()):
2411
dtype = initializer.const_value.dtype
2512
shape = tuple(initializer.const_value.shape)
@@ -30,19 +17,20 @@ def apply(self, graph: ir.Graph) -> ir.Graph:
3017

3118
group = seen[(dtype, shape)]
3219
if content in group:
33-
# Duplicate found
3420
canonical_name = group[content]
3521
name_map[initializer.name] = canonical_name
3622
graph.initializers.pop(initializer.name)
3723
else:
3824
group[content] = initializer.name
3925

40-
# Update all node inputs (including subgraphs)
41-
for node in iterate_graph(graph):
42-
for i, input_value in enumerate(node.inputs):
43-
if input_value is not None and input_value.name in name_map:
44-
canonical = name_map[input_value.name]
45-
replacement = graph.initializers[canonical]
26+
for node in RecursiveGraphIterator(graph):
27+
for i, input_val in enumerate(node.inputs):
28+
if input_val and input_val.name in name_map:
29+
canonical_name = name_map[input_val.name]
30+
replacement = graph.initializers[canonical_name]
4631
node.replace_input_with(i, replacement)
4732

4833
return graph
34+
35+
36+

0 commit comments

Comments
 (0)