Skip to content

Commit ef46092

Browse files
Add DeduplicateInitializersPass and test covering graph and subgraph
Signed-off-by: Abhishek Herbert Samuel <[email protected]>
1 parent 8001366 commit ef46092

File tree

2 files changed

+82
-4
lines changed

2 files changed

+82
-4
lines changed

src/onnx_ir/passes/common/deduplicate_initializers.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
1-
from onnx_ir._core import Node, Graph
1+
# Copyright (c) ONNX Project Contributors
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Pass for removing duplicated initializer tensors from a graph."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"DeduplicateInitializersPass",
9+
]
10+
11+
from onnx_ir._core import Graph, Node
212
from onnx_ir.traversal import RecursiveGraphIterator
313

414

515
class DeduplicateInitializersPass:
16+
"""Remove duplicated initializer tensors from the graph.
17+
18+
This pass detects initializers with identical shape, dtype, and tensor content,
19+
and replaces all duplicate references with a canonical one. Subgraphs are handled
20+
using RecursiveGraphIterator.
21+
"""
22+
623
def apply(self, graph: Graph) -> Graph:
724
seen = {} # (dtype, shape) → {tobytes: name}
825
name_map = {} # Duplicate name → canonical name
@@ -31,6 +48,3 @@ def apply(self, graph: Graph) -> Graph:
3148
node.replace_input_with(i, replacement)
3249

3350
return graph
34-
35-
36-
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright (c) ONNX Project Contributors
2+
# SPDX-License-Identifier: Apache-2.0
3+
"""Unit tests for the DeduplicateInitializersPass."""
4+
5+
import unittest
6+
import numpy as np
7+
8+
from onnx_ir._core import Tensor, Value, Node, Graph
9+
from onnx_ir.passes.common.deduplicate_initializers import DeduplicateInitializersPass
10+
11+
12+
class DeduplicateInitializersPassTest(unittest.TestCase):
13+
def setUp(self):
14+
# Shared tensor content
15+
self.arr = np.array([1, 2, 3])
16+
self.tensor1 = Tensor(self.arr)
17+
self.tensor2 = Tensor(self.arr.copy()) # Identical but separate object
18+
self.tensor3 = Tensor(self.arr.copy()) # For subgraph
19+
20+
def test_deduplication_in_main_and_subgraph(self):
21+
v1 = Value(name="w1", const_value=self.tensor1)
22+
v2 = Value(name="w2", const_value=self.tensor2)
23+
v3 = Value(name="w3", const_value=self.tensor3)
24+
25+
# Main graph node using w1 and w2
26+
main_node = Node("", "Add", inputs=[v1, v2], outputs=[])
27+
28+
# Subgraph node using w3
29+
sub_node = Node("", "Conv", inputs=[v3], outputs=[])
30+
subgraph = Graph(
31+
inputs=[],
32+
outputs=[],
33+
nodes=[sub_node],
34+
initializers=[v3],
35+
name="subgraph"
36+
)
37+
38+
# Link subgraph to main node
39+
main_node.blocks = [subgraph]
40+
41+
# Main graph with w1 and w2 (duplicates)
42+
main_graph = Graph(
43+
inputs=[],
44+
outputs=[],
45+
nodes=[main_node],
46+
initializers=[v1, v2],
47+
name="main_graph"
48+
)
49+
50+
DeduplicateInitializersPass().apply(main_graph)
51+
52+
# Post conditions
53+
self.assertIn("w1", main_graph.initializers)
54+
self.assertNotIn("w2", main_graph.initializers)
55+
self.assertEqual(main_node.inputs[0].name, "w1")
56+
self.assertEqual(main_node.inputs[1].name, "w1")
57+
58+
# Subgraph should be untouched (no cross-graph deduplication)
59+
self.assertIn("w3", subgraph.initializers)
60+
self.assertEqual(sub_node.inputs[0].name, "w3")
61+
62+
63+
if __name__ == "__main__":
64+
unittest.main()

0 commit comments

Comments
 (0)