Skip to content

Commit 6b3e0b7

Browse files
Finalize DeduplicateInitializersPass implementation and test coverage
Signed-off-by: Abhishek Herbert Samuel <[email protected]>
1 parent 653a03e commit 6b3e0b7

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

src/onnx_ir/passes/common/deduplicate_initializers.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,17 @@ class DeduplicateInitializersPass(onnx_ir.passes.InPlacePass):
2727

2828
def call(self, model: onnx_ir.Model) -> onnx_ir.passes.PassResult:
2929
graph = model.graph
30-
seen = {} # (dtype, shape) → {hash: [(name, tobytes)]}
30+
seen: dict[tuple[str, tuple[int, ...]], dict[int, list[tuple[str, bytes]]]] = {}
31+
3132
name_map = {} # Duplicate name → canonical name
3233

3334
for initializer in list(graph.initializers.values()):
34-
dtype = initializer.const_value.dtype
35-
shape = tuple(initializer.const_value.shape)
36-
content = initializer.const_value.tobytes()
35+
const_val = initializer.const_value
36+
if const_val is None:
37+
continue # Skip if initializer has no constant value
38+
dtype = const_val.dtype
39+
shape = tuple(const_val.shape)
40+
content = const_val.tobytes()
3741
content_hash = hashlib.sha256(content).hexdigest()
3842

3943
key = (dtype, shape)
@@ -46,9 +50,11 @@ def call(self, model: onnx_ir.Model) -> onnx_ir.passes.PassResult:
4650
for existing_name, existing_bytes in group[content_hash]:
4751
if existing_bytes == content:
4852
name_map[initializer.name] = existing_name
49-
graph.initializers.pop(initializer.name)
50-
break
53+
if initializer.name is not None:
54+
graph.initializers.pop(initializer.name)
55+
break # only break when deduplication is successful
5156
else:
57+
# no matching content found: append as a new entry
5258
group[content_hash].append((initializer.name, content))
5359
else:
5460
group[content_hash] = [(initializer.name, content)]

src/onnx_ir/passes/common/deduplicate_initializers_test.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
"""Unit tests for the DeduplicateInitializersPass."""
44

55
import unittest
6+
67
import onnx
7-
import numpy as np
88

99
import onnx_ir as ir
1010
import onnx_ir.passes.common.deduplicate_initializers as dedup_pass
@@ -32,7 +32,6 @@ def test_deduplicates_identical_initializers(self):
3232
add_node = new_model.graph.node[0]
3333
self.assertEqual(add_node.input[0], add_node.input[1])
3434

35-
3635
def test_initializers_with_different_shapes_not_deduplicated(self):
3736
model = onnx.parser.parse_model(
3837
"""
@@ -99,6 +98,6 @@ def test_unique_values_not_deduplicated(self):
9998
new_model = self.apply_pass(model)
10099
self.assertEqual(len(new_model.graph.initializer), 2)
101100

102-
101+
103102
if __name__ == "__main__":
104103
unittest.main()

0 commit comments

Comments
 (0)