-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[IR] Implement pass to remove unused nodes in graph #1841
base: main
Are you sure you want to change the base?
Changes from all commits
d8dac5e
26afad6
d57ec54
6d23912
a82023f
b69dd83
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""Utilities for removing unused nodes the IR graph.""" | ||
|
||
from __future__ import annotations | ||
|
||
from collections import deque | ||
|
||
import onnxscript.ir as ir | ||
from onnxscript.ir import Attr, Graph, Node, Value, _enums | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Import modules only |
||
|
||
|
||
class RemoveUnused: | ||
def __init__(self, graph_like: Graph): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ir.Graph? |
||
self._graph = graph_like | ||
|
||
def purge(self) -> None: | ||
"""Remove unused nodes in this graph (and all subgraphs) that do not contribute to main graph outputs.""" | ||
# 1. Initialize: | ||
# Gather all nodes from the graph and its subgraphs. | ||
# Initialize sets to keep track of visited graphs, values, and nodes. | ||
# 2. BFS traversal: | ||
# Create a queue initialized with all output values of the main graph. | ||
# While there are values in the queue: | ||
# - Dequeue a value and retrieve its producer node. | ||
# - Mark the producer node as visited, if it hasn't been visited. | ||
# - Enqueue all output values of the attribute subgraphs of the producer node, | ||
# if they haven't been visited. | ||
# - Enqueue all input values of the producer node, if they haven't been visited. | ||
# 3. Remove: | ||
# Remove all nodes that have not been marked as visited during the BFS traversal. | ||
|
||
# Initialize | ||
all_nodes: list[Node] = list(ir.traversal.RecursiveGraphIterator(self._graph)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: now this can be moved down to line 79, which is where it is used, I think ... |
||
visited_graphs: set[Graph] = set() | ||
visited_values: set[Value] = set() | ||
visited_nodes: set[Node] = set() | ||
|
||
# BFS Traversal | ||
queue: deque[Value] = deque() | ||
|
||
def add_graph_output_values_to_queue(graph: Graph | None) -> None: | ||
"""Helper function to add all output values of a graph to the queue.""" | ||
if not graph or graph in visited_graphs: | ||
return | ||
visited_graphs.add(graph) | ||
for output in graph.outputs: | ||
if not output: | ||
continue | ||
queue.append(output) | ||
visited_values.add(output) | ||
|
||
add_graph_output_values_to_queue(self._graph) | ||
|
||
while queue: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It should be possible to avoid the queue by looping over all nodes in the backward order (assuming the ir preserve a consistent order on the nodes +@justinchuby). |
||
# Dequeue a value and retrieve its producer_node | ||
# Add producer_node to visited_nodes | ||
current_value = queue.popleft() | ||
producer_node = current_value.producer() | ||
if not producer_node or producer_node in visited_nodes: | ||
continue | ||
visited_nodes.add(producer_node) | ||
# Add producer_node's subgraphs to visited_graphs | ||
# Add subgraphs' output values to queue | ||
for attr in producer_node.attributes.values(): | ||
if not isinstance(attr, Attr): | ||
continue | ||
if attr.type == _enums.AttributeType.GRAPH: | ||
add_graph_output_values_to_queue(attr.value) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some subgraphs use intermediate results declared in the main graph. You need to loop over nodes inside subgraphs as well. You'll have to handle inputs/outputs with the same name. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The RecursiveGraphIterator will loop over all nodes in subgraphs. So There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the code handles Xavier's case, but not because of the recursive graph iterator. (That seems to be used only in the later loop below to remove nodes). The code above goes from value to the producer of the value: this should go from a use inside a subgraph to a producer outside the subgraph (as long as the IR is constructed correctly.) |
||
elif attr.type == _enums.AttributeType.GRAPHS: | ||
for subgraph in attr.value: | ||
add_graph_output_values_to_queue(subgraph) | ||
# Add producer_node's input values to queue | ||
for input_value in producer_node.inputs: | ||
if input_value and input_value not in visited_values: | ||
visited_values.add(input_value) | ||
queue.append(input_value) | ||
|
||
# Remove | ||
for node in all_nodes: | ||
if node not in visited_nodes: # type: ignore[union-attr]` | ||
Check failure Code scanning / lintrunner MYPY/syntax Error
Invalid "type: ignore" comment
To disable, use # type: ignore[syntax]
|
||
node.graph.remove(node) | ||
Check failure Code scanning / lintrunner MYPY/union-attr Error
Item "None" of "Graph | None" has no attribute "remove"
To disable, use # type: ignore[union-attr]
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Acknowledged. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
import unittest | ||
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved
Hide resolved
|
||
|
||
from onnxscript import ir | ||
from onnxscript.ir.passes._remove_unused import RemoveUnused | ||
|
||
|
||
class RemoveUnusedTest(unittest.TestCase): | ||
def test_purge_empty(self): | ||
graph = ir.Graph( | ||
inputs=(), | ||
outputs=(), | ||
nodes=(), | ||
opset_imports={"": 1}, | ||
) | ||
remove_unused = RemoveUnused(graph) | ||
remove_unused.purge() | ||
self.assertEqual(tuple(graph), ()) | ||
|
||
def test_purge_a_single_node(self): | ||
v0 = ir.Value(name="v0") | ||
node0 = ir.Node("", "Node0", inputs=(v0,), num_outputs=1) | ||
node1 = ir.Node("", "Node1", inputs=(v0,), num_outputs=1) | ||
node2 = ir.Node("", "Node2", inputs=(v0,), num_outputs=0) | ||
node3 = ir.Node("", "Node3", inputs=(), num_outputs=1) | ||
node4 = ir.Node("", "Node4", inputs=(None,), num_outputs=1) | ||
graph = ir.Graph( | ||
(v0,), | ||
(node0.outputs[0], node3.outputs[0], node4.outputs[0]), | ||
nodes=(node0, node1, node2, node3, node4), | ||
opset_imports={"": 1}, | ||
) | ||
remove_unused = RemoveUnused(graph) | ||
remove_unused.purge() | ||
self.assertEqual(tuple(graph), (node0, node3, node4)) | ||
|
||
def test_purge_a_tree(self): | ||
v0 = ir.Value(name="v0") | ||
node0 = ir.Node("", "Node0", inputs=(v0,), num_outputs=1) | ||
node1 = ir.Node("", "Node1", inputs=(node0.outputs[0],), num_outputs=1) | ||
node2 = ir.Node("", "Node2", inputs=(node0.outputs[0],), num_outputs=1) | ||
graph = ir.Graph( | ||
(v0,), | ||
(), | ||
nodes=(node0, node1, node2), | ||
opset_imports={"": 1}, | ||
) | ||
remove_unused = RemoveUnused(graph) | ||
remove_unused.purge() | ||
self.assertEqual(tuple(graph), ()) | ||
|
||
def test_purge_subgraph_partial(self): | ||
v0 = ir.Value(name="va") | ||
v1 = ir.Value(name="vb") | ||
v2 = ir.Value(name="vc") | ||
v3 = ir.Value(name="vd") | ||
node0 = ir.Node("", "a", inputs=(v0,), num_outputs=1) | ||
node1 = ir.Node("", "b", inputs=(v1,), num_outputs=1) | ||
node2 = ir.Node("", "c", inputs=(v2,), num_outputs=1) | ||
node3 = ir.Node("", "d", inputs=(v3,), num_outputs=1) | ||
node4 = ir.Node("", "sub", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1) | ||
node5 = ir.Node("", "add", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1) | ||
node6 = ir.Node("", ">", inputs=(node0.outputs[0], node1.outputs[0]), num_outputs=1) | ||
then_graph = ir.Graph( | ||
inputs=(node2.outputs[0], node3.outputs[0]), | ||
outputs=(node4.outputs[0],), | ||
nodes=(node4,), | ||
name="then_graph", | ||
) | ||
else_graph = ir.Graph( | ||
inputs=(node2.outputs[0], node3.outputs[0]), | ||
outputs=(), | ||
nodes=(node5,), | ||
name="else_graph", | ||
) | ||
|
||
node7 = ir.Node( | ||
"", | ||
"if", | ||
inputs=(node6.outputs[0],), | ||
num_outputs=1, | ||
attributes=[ | ||
ir.AttrGraphs("subgraphs", [then_graph, else_graph]), | ||
], | ||
) | ||
main_graph = ir.Graph( | ||
inputs=(v0, v1, v2, v3), | ||
outputs=(node7.outputs[0],), | ||
nodes=(node0, node1, node2, node3, node6, node7), | ||
name="main_graph", | ||
opset_imports={"": 1}, | ||
) | ||
remove_unused = RemoveUnused(main_graph) | ||
remove_unused.purge() | ||
self.assertEqual(tuple(main_graph), (node0, node1, node2, node3, node6, node7)) | ||
self.assertEqual(tuple(then_graph), (node4,)) | ||
self.assertEqual(tuple(else_graph), ()) | ||
|
||
def test_purge_subgraph_all(self): | ||
v0 = ir.Value(name="v0") | ||
node0 = ir.Node("", "c", inputs=(v0,), num_outputs=1) | ||
node1 = ir.Node("", "sub", inputs=(node0.outputs[0],), num_outputs=1) | ||
node2 = ir.Node("", ">", inputs=(v0,), num_outputs=1) | ||
then_graph = ir.Graph( | ||
inputs=(node0.outputs[0],), | ||
outputs=(node1.outputs[0],), | ||
nodes=(node1,), | ||
name="then_graph", | ||
) | ||
node4 = ir.Node( | ||
"", | ||
"if", | ||
inputs=(node2.outputs[0],), | ||
num_outputs=1, | ||
attributes=[ | ||
ir.AttrGraph("then_graph", then_graph), | ||
], | ||
) | ||
main_graph = ir.Graph( | ||
inputs=(v0,), | ||
outputs=(), | ||
nodes=(node0, node2, node4), | ||
name="main_graph", | ||
) | ||
remove_unused = RemoveUnused(main_graph) | ||
remove_unused.purge() | ||
self.assertEqual(tuple(main_graph), ()) | ||
self.assertEqual(tuple(then_graph), ()) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() | ||
Check warning
Code scanning / lintrunner
RUFF-FORMAT/format Warning