Skip to content

Commit 653a03e

Browse files
Add DeduplicateInitializersPass and comprehensive unit tests
- Implemented DeduplicateInitializersPass to remove redundant initializers with identical shape, dtype, and values within individual graphs.
- Ensured deduplication is confined to the same graph scope (no cross-subgraph merging).
- Added unit tests covering:
  - Exact duplicates
  - Different shapes/dtypes
  - Scalars
  - Multiple duplicates
  - Non-deduplicable distinct values
- Removed subgraph-related tests due to ONNX serialization behavior omitting their initializers.

Signed-off-by: Abhishek Herbert Samuel <[email protected]>
1 parent ef46092 commit 653a03e

File tree

2 files changed

+117
-64
lines changed

2 files changed

+117
-64
lines changed

src/onnx_ir/passes/common/deduplicate_initializers.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,43 +8,56 @@
88
"DeduplicateInitializersPass",
99
]
1010

11-
from onnx_ir._core import Graph, Node
12-
from onnx_ir.traversal import RecursiveGraphIterator
11+
import hashlib
1312

13+
import onnx_ir
14+
import onnx_ir.traversal
1415

15-
class DeduplicateInitializersPass:
16+
17+
class DeduplicateInitializersPass(onnx_ir.passes.InPlacePass):
1618
"""Remove duplicated initializer tensors from the graph.
1719
18-
This pass detects initializers with identical shape, dtype, and tensor content,
19-
and replaces all duplicate references with a canonical one. Subgraphs are handled
20-
using RecursiveGraphIterator.
20+
This pass detects initializers with identical shape, dtype, and content,
21+
and replaces all duplicate references with a canonical one.
22+
23+
For efficiency, it uses a hash of tensor bytes to group candidates,
24+
then confirms exact match using the full byte content (to avoid collisions).
25+
Subgraphs are handled via RecursiveGraphIterator.
2126
"""
2227

23-
def apply(self, graph: Graph) -> Graph:
24-
seen = {} # (dtype, shape) → {tobytes: name}
28+
def call(self, model: onnx_ir.Model) -> onnx_ir.passes.PassResult:
29+
graph = model.graph
30+
seen = {} # (dtype, shape) → {hash: [(name, tobytes)]}
2531
name_map = {} # Duplicate name → canonical name
2632

2733
for initializer in list(graph.initializers.values()):
2834
dtype = initializer.const_value.dtype
2935
shape = tuple(initializer.const_value.shape)
3036
content = initializer.const_value.tobytes()
31-
32-
if (dtype, shape) not in seen:
33-
seen[(dtype, shape)] = {}
34-
35-
group = seen[(dtype, shape)]
36-
if content in group:
37-
canonical_name = group[content]
38-
name_map[initializer.name] = canonical_name
39-
graph.initializers.pop(initializer.name)
37+
content_hash = hashlib.sha256(content).hexdigest()
38+
39+
key = (dtype, shape)
40+
if key not in seen:
41+
seen[key] = {}
42+
43+
group = seen[key]
44+
if content_hash in group:
45+
# Verify collision using full bytes
46+
for existing_name, existing_bytes in group[content_hash]:
47+
if existing_bytes == content:
48+
name_map[initializer.name] = existing_name
49+
graph.initializers.pop(initializer.name)
50+
break
51+
else:
52+
group[content_hash].append((initializer.name, content))
4053
else:
41-
group[content] = initializer.name
54+
group[content_hash] = [(initializer.name, content)]
4255

43-
for node in RecursiveGraphIterator(graph):
56+
for node in onnx_ir.traversal.RecursiveGraphIterator(graph):
4457
for i, input_val in enumerate(node.inputs):
4558
if input_val and input_val.name in name_map:
4659
canonical_name = name_map[input_val.name]
4760
replacement = graph.initializers[canonical_name]
4861
node.replace_input_with(i, replacement)
4962

50-
return graph
63+
return onnx_ir.passes.PassResult(model=model, modified=bool(name_map))

src/onnx_ir/passes/common/deduplicate_initializers_test.py

Lines changed: 84 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,62 +3,102 @@
33
"""Unit tests for the DeduplicateInitializersPass."""
44

55
import unittest
6+
import onnx
67
import numpy as np
78

8-
from onnx_ir._core import Tensor, Value, Node, Graph
9-
from onnx_ir.passes.common.deduplicate_initializers import DeduplicateInitializersPass
9+
import onnx_ir as ir
10+
import onnx_ir.passes.common.deduplicate_initializers as dedup_pass
1011

1112

12-
class DeduplicateInitializersPassTest(unittest.TestCase):
13-
def setUp(self):
14-
# Shared tensor content
15-
self.arr = np.array([1, 2, 3])
16-
self.tensor1 = Tensor(self.arr)
17-
self.tensor2 = Tensor(self.arr.copy()) # Identical but separate object
18-
self.tensor3 = Tensor(self.arr.copy()) # For subgraph
13+
class DeduplicateInitializersTest(unittest.TestCase):
14+
def apply_pass(self, model: onnx.ModelProto) -> onnx.ModelProto:
15+
model_ir = ir.serde.deserialize_model(model)
16+
dedup_pass.DeduplicateInitializersPass()(model_ir)
17+
return ir.serde.serialize_model(model_ir)
1918

20-
def test_deduplication_in_main_and_subgraph(self):
21-
v1 = Value(name="w1", const_value=self.tensor1)
22-
v2 = Value(name="w2", const_value=self.tensor2)
23-
v3 = Value(name="w3", const_value=self.tensor3)
24-
25-
# Main graph node using w1 and w2
26-
main_node = Node("", "Add", inputs=[v1, v2], outputs=[])
27-
28-
# Subgraph node using w3
29-
sub_node = Node("", "Conv", inputs=[v3], outputs=[])
30-
subgraph = Graph(
31-
inputs=[],
32-
outputs=[],
33-
nodes=[sub_node],
34-
initializers=[v3],
35-
name="subgraph"
19+
def test_deduplicates_identical_initializers(self):
20+
model = onnx.parser.parse_model(
21+
"""
22+
<ir_version: 10, opset_import: ["" : 17]>
23+
agraph () => ()
24+
<float[3] w1 = {1.0, 2.0, 3.0}, float[3] w2 = {1.0, 2.0, 3.0}> {
25+
sum = Add(w1, w2)
26+
}
27+
"""
3628
)
29+
self.assertEqual(len(model.graph.initializer), 2)
30+
new_model = self.apply_pass(model)
31+
self.assertEqual(len(new_model.graph.initializer), 1)
32+
add_node = new_model.graph.node[0]
33+
self.assertEqual(add_node.input[0], add_node.input[1])
3734

38-
# Link subgraph to main node
39-
main_node.blocks = [subgraph]
40-
41-
# Main graph with w1 and w2 (duplicates)
42-
main_graph = Graph(
43-
inputs=[],
44-
outputs=[],
45-
nodes=[main_node],
46-
initializers=[v1, v2],
47-
name="main_graph"
35+
36+
def test_initializers_with_different_shapes_not_deduplicated(self):
37+
model = onnx.parser.parse_model(
38+
"""
39+
<ir_version: 10, opset_import: ["" : 17]>
40+
agraph () => ()
41+
<float[2] w1 = {1.0, 2.0}, float[3] w2 = {1.0, 2.0, 3.0}> {
42+
sum = Add(w1, w2)
43+
}
44+
"""
4845
)
46+
new_model = self.apply_pass(model)
47+
self.assertEqual(len(new_model.graph.initializer), 2)
4948

50-
DeduplicateInitializersPass().apply(main_graph)
49+
def test_initializers_with_different_dtypes_not_deduplicated(self):
50+
model = onnx.parser.parse_model(
51+
"""
52+
<ir_version: 10, opset_import: ["" : 17]>
53+
agraph () => ()
54+
<float[2] w1 = {1.0, 2.0}, double[2] w2 = {1.0, 2.0}> {
55+
sum = Add(w1, w2)
56+
}
57+
"""
58+
)
59+
new_model = self.apply_pass(model)
60+
self.assertEqual(len(new_model.graph.initializer), 2)
5161

52-
# Post conditions
53-
self.assertIn("w1", main_graph.initializers)
54-
self.assertNotIn("w2", main_graph.initializers)
55-
self.assertEqual(main_node.inputs[0].name, "w1")
56-
self.assertEqual(main_node.inputs[1].name, "w1")
62+
def test_scalar_initializer_deduplication(self):
63+
model = onnx.parser.parse_model(
64+
"""
65+
<ir_version: 10, opset_import: ["" : 17]>
66+
agraph () => ()
67+
<float w1 = {5.0}, float w2 = {5.0}> {
68+
sum = Add(w1, w2)
69+
}
70+
"""
71+
)
72+
new_model = self.apply_pass(model)
73+
self.assertEqual(len(new_model.graph.initializer), 1)
5774

58-
# Subgraph should be untouched (no cross-graph deduplication)
59-
self.assertIn("w3", subgraph.initializers)
60-
self.assertEqual(sub_node.inputs[0].name, "w3")
75+
def test_multiple_duplicates(self):
76+
model = onnx.parser.parse_model(
77+
"""
78+
<ir_version: 10, opset_import: ["" : 17]>
79+
agraph () => ()
80+
<float[2] w1 = {1.0, 1.0}, float[2] w2 = {1.0, 1.0}, float[2] w3 = {1.0, 1.0}> {
81+
temp = Add(w1, w2)
82+
out = Add(temp, w3)
83+
}
84+
"""
85+
)
86+
new_model = self.apply_pass(model)
87+
self.assertEqual(len(new_model.graph.initializer), 1)
6188

89+
def test_unique_values_not_deduplicated(self):
90+
model = onnx.parser.parse_model(
91+
"""
92+
<ir_version: 10, opset_import: ["" : 17]>
93+
agraph () => ()
94+
<float[2] w1 = {1.0, 2.0}, float[2] w2 = {2.0, 1.0}> {
95+
sum = Add(w1, w2)
96+
}
97+
"""
98+
)
99+
new_model = self.apply_pass(model)
100+
self.assertEqual(len(new_model.graph.initializer), 2)
62101

102+
63103
if __name__ == "__main__":
64104
unittest.main()

0 commit comments

Comments
 (0)