-
Notifications
You must be signed in to change notification settings - Fork 7
Add deduplication pass for initializer tensors (#66) #67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add deduplication pass for initializer tensors (#66) #67
Conversation
# Iterate over all initializers in the graph | ||
for initializer in list(graph.initializers.values()): | ||
key = ( | ||
initializer.const_value.tobytes(), # Content fingerprint |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is memory consuming and thus highly inefficient. Consider comparing the dtype and shape first, and only compare values when you need to
from onnx_ir.passes.base import GraphTransformPass | ||
|
||
|
||
class DeduplicateInitializersPass(GraphTransformPass): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I understand this may be generated by some AIs. Please ensure the class names etc. are correct, and follow coding style from other files in this directory.
seen[key] = initializer.name | ||
|
||
# Update node inputs to use the canonical initializer names | ||
for node in graph: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you probably need to check nodes from the subgraphs too. You may use the ir.traversal recursive iterator for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @justinchuby,
I’ve addressed your feedback in the latest commit:
Optimized memory usage by grouping by (dtype, shape) before comparing tobytes()
Used iterate_graph(graph) to handle nodes in subgraphs as well
Let me know if any further changes are needed. Thanks again for the thoughtful review!
Signed-off-by: Abhishek Herbert Samuel <[email protected]>
Signed-off-by: Abhishek Herbert Samuel <[email protected]>
f99fa0c
to
ae8f078
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It’s fine to use an AI for contribution. Please ensure however that the code actually works
Thank you for the feedback, Justin. I'll check if it works and then only send it here. |
…bgraph traversal Address reviewer feedback: - Optimized memory by grouping by dtype and shape before comparing values - Used iterate_graph to handle subgraphs - Validated on normal and subgraph models; deduplication works as expected Signed-off-by: Abhishek Herbert Samuel <[email protected]>
Signed-off-by: Abhishek Herbert Samuel <[email protected]>
Hi @justinchuby, I've pushed the finalized implementation and test as separate, signed commits. The following have been addressed: DeduplicateInitializersPass: Added under passes/common, follows repo conventions, uses (dtype, shape) → {tobytes: name} grouping for memory efficiency, and traverses all subgraphs via RecursiveGraphIterator. Test coverage: A dedicated unittest verifies correct deduplication in the main graph and ensures subgraphs remain isolated. Coding standards: Followed the structure and documentation style of other passes (e.g., topological_sort.py). Commit signed: Used -s with a clean message summarizing the functionality. I have also attached a screenshot of the unit test which passed successfully on my local copy of this repository. Please let me know if any final changes are needed. Thanks again for your guidance and mentorship throughout this PR! |
from onnx_ir.traversal import RecursiveGraphIterator | ||
|
||
|
||
class DeduplicateInitializersPass: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class DeduplicateInitializersPass: | |
class DeduplicateInitializersPass(ir.passes.InPlacePass): |
please subclass ir.passes.InPlacePass
. You can use https://github.com/AbhishekHerbertSamuel/ir-py/blob/ef46092b5f10303bb9fe126eef0f5b44585e3b16/src/onnx_ir/passes/common/constant_manipulation.py#L23 as an example.
using RecursiveGraphIterator. | ||
""" | ||
|
||
def apply(self, graph: Graph) -> Graph: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please implement the call method. The first argument should be an ir.Model
. You may use other passes in this directory as examples. Be sure to import modules only: https://google.github.io/styleguide/pyguide.html#224-decision
import unittest | ||
import numpy as np | ||
|
||
from onnx_ir._core import Tensor, Value, Node, Graph |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please import modules only. You may use https://github.com/onnx/ir-py/blob/main/src/onnx_ir/passes/common/unused_removal_test.py as an example
Please feel free to ask questions when you are going through the code base or need help understanding parts of the code. It would be helpful to take a look at other existing passes and usages to ensure they are implemented in a similar style. |
My concern with this pass in particular is that we are using the full bytes in the look up table. This is memory intensive. I wonder if there is a good (efficient) hash method that can be apply to the bytes content, and use the hash value in the look up table. Only when the hash matches do we compare the actual bytes. |
Hi @justinchuby, I’ll update the class to inherit from ir.passes.InPlacePass as suggested and move the main logic into the call method, following the repo’s conventions (like in constant_manipulation.py). Regarding the memory concern: Will push the changes shortly. Please let me know if I missed anything else. Appreciate your guidance! Warm regards, |
- Implemented DeduplicateInitializersPass to remove redundant initializers with identical shape, dtype, and values within individual graphs. - Ensured deduplication is confined to the same graph scope (no cross-subgraph merging). - Added unit tests covering: - Exact duplicates - Different shapes/dtypes - Scalars - Multiple duplicates - Non-deduplicable distinct values - Removed subgraph-related tests due to ONNX serialization behavior omitting their initializers. Signed-off-by: Abhishek Herbert Samuel <[email protected]>
Hi @justinchuby, Tests involving subgraph initializers were removed, as ONNX drops those during serialization, making them unreliable to assert against. Let me know if you'd like a different strategy for subgraph coverage. Thanks again for your guidance throughout! Warm regards, |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #67 +/- ##
==========================================
+ Coverage 73.57% 73.71% +0.14%
==========================================
Files 37 38 +1
Lines 4492 4543 +51
Branches 902 915 +13
==========================================
+ Hits 3305 3349 +44
- Misses 858 861 +3
- Partials 329 333 +4 ☔ View full report in Codecov by Sentry. |
Signed-off-by: Abhishek Herbert Samuel <[email protected]>
a00be10
to
6b3e0b7
Compare
Signed-off-by: Abhishek Herbert Samuel <[email protected]>
Signed-off-by: Abhishek Herbert Samuel <[email protected]>
Hi @justinchuby, tried to make the remaining changes based on the CI workflow's results. While my tests ran locally and are fine (6/6 correct), the codecov bot and workflows were not triggered here automatically as it was the previous time. Is that ok or a sign of an error? With regards, |
break # only break when deduplication is successful | ||
else: | ||
# no matching content found: append as a new entry | ||
group[content_hash].append((initializer.name, content)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You may store the values instead in the hash table, and use const_val.tobytes() for comparison. This way the bytes do not stay in the memory or take up space?
group[content_hash].append((initializer.name, content)) | |
group[content_hash].append(const_val) |
if initializer.name is not None: | ||
graph.initializers.pop(initializer.name) | ||
break # only break when deduplication is successful | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
indent?
Thank you! could you also take a look at the lint errors? |
Hi @justinchuby, will work on it and send the corrected code here:) |
Signed-off-by: Abhishek Herbert Samuel <[email protected]>
Signed-off-by: Abhishek Herbert Samuel <[email protected]>
Do let me know if there are any other changes :) |
@@ -27,16 +27,16 @@ class DeduplicateInitializersPass(onnx_ir.passes.InPlacePass): | |||
|
|||
def call(self, model: onnx_ir.Model) -> onnx_ir.passes.PassResult: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you import onnx_ir as ir and use it as such, to stay consistent with the rest of the code base?
Thanks! I will do a more detailed review soon |
Sure @justinchuby, will fix it and maintain code consistency :) |
continue # Skip if initializer has no constant value | ||
dtype = const_val.dtype.name | ||
shape = tuple(int(dim) if isinstance(dim, int) else -1 for dim in const_val.shape) | ||
content = const_val.tobytes() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is going to be very slow on big tensors. Is there a way to compare rawdata directly to save some time?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After discussion it seems a good idea to avoid comparing big tensors at all. @AbhishekHerbertSamuel Could you limit the size to 1024 values? You can find the element count of the tensor with tensor.size
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Threshold 1024 is a very small number in production, there is case where model size is reduced from 150MB to 100MB by tying weights, maybe we can parameterize this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And when comparing those constants, we can compare shape and dtype first, instead of comparing the raw data directly, this will save tremendous time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I agree that’s a good idea
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
Adds a new graph transformation pass to remove duplicate initializer tensors by hashing their content and updating node inputs to the canonical tensor.
- Introduces
DeduplicateInitializersPass
to group initializers by dtype, shape, and content hash, remove duplicates, and rewrite node inputs. - Implements content‐based deduplication with SHA-256 fingerprinting and exact byte comparison to avoid collisions.
- Provides unit tests covering identical, shape/dtype differences, scalar, multiple duplicates, and unique-value scenarios.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
File | Description |
---|---|
src/onnx_ir/passes/common/deduplicate_initializers.py | Implements new pass to deduplicate initializer tensors |
src/onnx_ir/passes/common/deduplicate_initializers_test.py | Adds unit tests for various deduplication scenarios |
Comments suppressed due to low confidence (3)
src/onnx_ir/passes/common/deduplicate_initializers.py:31
- [nitpick] The name 'seen' is ambiguous; consider renaming it to something more descriptive like 'initializer_groups' or 'fingerprint_groups'.
seen: dict[tuple[str, tuple[int, ...]], dict[str, list[onnx_ir.Value]]] = {}
src/onnx_ir/passes/common/deduplicate_initializers.py:33
- [nitpick] The variable 'name_map' could be renamed to 'duplicate_to_canonical' or similar to clarify its purpose.
name_map = {}
src/onnx_ir/passes/common/deduplicate_initializers.py:67
- There are no tests covering deduplication behavior within nested subgraphs; consider adding a test case to verify that duplicate initializers in subgraphs are also handled.
for node in onnx_ir.traversal.RecursiveGraphIterator(graph):
Thank you @xadupre @inisis @justinchuby for the feedback. Will make the requested changes and ensure that the PR is ready to be merged. |
…nd size limit - Avoids comparing large tensors >1024 elements to reduce performance overhead - Compares shape and dtype before accessing tensor content - Adds test coverage for subgraph deduplication (If node branches) - Passes all linters: ruff, mypy, editorconfig Signed-off-by: Abhishek Herbert Samuel <[email protected]>
@xadupre @justinchuby @inisis I have made the requested changes. Please check and let me know if it's ready for merging or if other changes need to be made prior to that. Thank you once again :) |
Summary
This PR adds a new graph transformation pass:
DeduplicateInitializersPass
.It removes duplicate initializer tensors (typically model weights) based on a unique fingerprint derived from:
tobytes()
)dtype
)All redundant initializers are removed, and nodes referencing them are updated to use the canonical (first-seen) tensor.
Implementation Details
(tobytes, dtype, shape) → name
graph.initializers.pop(...)
node.replace_input_with(...)
for correctness and safetyBenefits
File Added
src/onnx_ir/passes/common/deduplicate_initializers.py
Closes
Closes #66