@@ -27,13 +27,17 @@ class DeduplicateInitializersPass(onnx_ir.passes.InPlacePass):
27
27
28
28
def call (self , model : onnx_ir .Model ) -> onnx_ir .passes .PassResult :
29
29
graph = model .graph
30
- seen = {} # (dtype, shape) → {hash: [(name, tobytes)]}
30
+ seen : dict [tuple [str , tuple [int , ...]], dict [int , list [tuple [str , bytes ]]]] = {}
31
+
31
32
name_map = {} # Duplicate name → canonical name
32
33
33
34
for initializer in list (graph .initializers .values ()):
34
- dtype = initializer .const_value .dtype
35
- shape = tuple (initializer .const_value .shape )
36
- content = initializer .const_value .tobytes ()
35
+ const_val = initializer .const_value
36
+ if const_val is None :
37
+ continue # Skip if initializer has no constant value
38
+ dtype = const_val .dtype
39
+ shape = tuple (const_val .shape )
40
+ content = const_val .tobytes ()
37
41
content_hash = hashlib .sha256 (content ).hexdigest ()
38
42
39
43
key = (dtype , shape )
@@ -46,9 +50,11 @@ def call(self, model: onnx_ir.Model) -> onnx_ir.passes.PassResult:
46
50
for existing_name , existing_bytes in group [content_hash ]:
47
51
if existing_bytes == content :
48
52
name_map [initializer .name ] = existing_name
49
- graph .initializers .pop (initializer .name )
50
- break
53
+ if initializer .name is not None :
54
+ graph .initializers .pop (initializer .name )
55
+ break # only break when deduplication is successful
51
56
else :
57
+ # no matching content found: append as a new entry
52
58
group [content_hash ].append ((initializer .name , content ))
53
59
else :
54
60
group [content_hash ] = [(initializer .name , content )]
0 commit comments