-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[IR] Implement pass to remove unused nodes in graph #1841
base: main
Are you sure you want to change the base?
Changes from 4 commits
d8dac5e
26afad6
d57ec54
6d23912
a82023f
b69dd83
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
"""Utilities for removing unused nodes the IR graph.""" | ||
|
||
from __future__ import annotations | ||
|
||
from collections import deque | ||
from typing import Union | ||
|
||
import onnxscript.ir as ir | ||
from onnxscript.ir import Graph, Node, Value | ||
|
||
|
||
class RemoveUnused: | ||
def __init__(self, graph_like: Union[Graph, ir.GraphView]): | ||
self._graph = graph_like | ||
|
||
def purge(self) -> None: | ||
"""Remove unused nodes in this graph and all its subgraphs that do not contribute to any graph_outputs.""" | ||
# 1. Initialize | ||
# Gather all nodes from the graph and its subgraphs using a recursive iterator. | ||
# Identify all subgraphs by checking the graph of each node. | ||
# Initialize sets to keep track of visited values and nodes. | ||
# 2. BFS traversal: | ||
# Create a queue initialized with all output values from the subgraphs. | ||
# While there are values in the queue: | ||
# - Dequeue a value and retrieve its producer node. | ||
# - Skip processing if the producer node is already visited or doesn't exist. | ||
# - Mark the producer node as visited. | ||
# - Enqueue all input values of this producer node for further exploration, if they haven't been visited. | ||
# 3. Remove: | ||
# Remove any node from its graph if it has not been marked as visited during the BFS traversal. | ||
|
||
# Initialize | ||
all_nodes: list[Node] = list(ir.traversal.RecursiveGraphIterator(self._graph)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: now this can be moved down to line 79, which is where it is used, I think ... |
||
subgraphs: set[Graph] = {node.graph for node in all_nodes if node.graph} | ||
visited_values: set[Value] = set() | ||
visited_nodes: set[Node] = set() | ||
|
||
# BFS Traversal | ||
value_queue: deque[Value] = deque( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it would be better if a node's subgraphs are processed only after the node is itself determined to be useful (that is, added to
Here, if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree. Do we really need all subgraphs' outputs? @justinchuby There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I adopted @gramalingam 's idea, modified the code and added a testcase. |
||
output for graph in subgraphs for output in graph.outputs if output | ||
) | ||
while value_queue: | ||
current_value = value_queue.popleft() | ||
producer_node = current_value.producer() | ||
if not producer_node or producer_node in visited_nodes: | ||
continue | ||
visited_nodes.add(producer_node) | ||
for input_value in producer_node.inputs: | ||
if input_value and input_value not in visited_values: | ||
visited_values.add(input_value) | ||
value_queue.append(input_value) | ||
|
||
# Remove | ||
for node in all_nodes: | ||
if node not in visited_nodes: | ||
node.graph.remove(node) # type: ignore[union-attr] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
import unittest | ||
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved
Hide resolved
|
||
|
||
from onnxscript import ir | ||
from onnxscript.ir.passes._remove_unused import RemoveUnused | ||
|
||
|
||
class RemoveUnusedTest(unittest.TestCase): | ||
def test_purge_empty(self): | ||
graph = ir.Graph( | ||
inputs=(), | ||
outputs=(), | ||
nodes=(), | ||
opset_imports={"": 1}, | ||
) | ||
remove_unused = RemoveUnused(graph) | ||
remove_unused.purge() | ||
self.assertEqual(tuple(graph), ()) | ||
|
||
def test_purge_a_single_node(self): | ||
v0 = ir.Value(name="v0") | ||
node0 = ir.Node("", "Node0", inputs=(v0,), num_outputs=1) | ||
node1 = ir.Node("", "Node1", inputs=(v0,), num_outputs=1) | ||
node2 = ir.Node("", "Node2", inputs=(v0,), num_outputs=0) | ||
node3 = ir.Node("", "Node3", inputs=(), num_outputs=1) | ||
node4 = ir.Node("", "Node4", inputs=(None,), num_outputs=1) | ||
graph = ir.Graph( | ||
(v0,), | ||
(node0.outputs[0], node3.outputs[0], node4.outputs[0]), | ||
nodes=(node0, node1, node2, node3, node4), | ||
opset_imports={"": 1}, | ||
) | ||
remove_unused = RemoveUnused(graph) | ||
remove_unused.purge() | ||
self.assertEqual(tuple(graph), (node0, node3, node4)) | ||
|
||
def test_purge_a_tree(self): | ||
v0 = ir.Value(name="v0") | ||
node0 = ir.Node("", "Node0", inputs=(v0,), num_outputs=1) | ||
node1 = ir.Node("", "Node1", inputs=(node0.outputs[0],), num_outputs=1) | ||
node2 = ir.Node("", "Node2", inputs=(node0.outputs[0],), num_outputs=1) | ||
graph = ir.Graph( | ||
(v0,), | ||
(), | ||
nodes=(node0, node1, node2), | ||
opset_imports={"": 1}, | ||
) | ||
remove_unused = RemoveUnused(graph) | ||
remove_unused.purge() | ||
self.assertEqual(tuple(graph), ()) | ||
|
||
def test_purge_subgraph(self): | ||
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved
Hide resolved
|
||
v0 = ir.Value(name="va") | ||
v1 = ir.Value(name="vb") | ||
v2 = ir.Value(name="vc") | ||
v3 = ir.Value(name="vd") | ||
node0 = ir.Node("", "a", inputs=(v0,), num_outputs=1) | ||
node1 = ir.Node("", "b", inputs=(v1,), num_outputs=1) | ||
node2 = ir.Node("", "c", inputs=(v2,), num_outputs=1) | ||
node3 = ir.Node("", "d", inputs=(v3,), num_outputs=1) | ||
node4 = ir.Node("", "sub", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1) | ||
node5 = ir.Node("", "add", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1) | ||
node6 = ir.Node("", ">", inputs=(node0.outputs[0], node1.outputs[0]), num_outputs=1) | ||
then_graph = ir.Graph( | ||
inputs=(node2.outputs[0], node3.outputs[0]), | ||
outputs=(node4.outputs[0],), | ||
nodes=(node4,), | ||
name="then_graph", | ||
) | ||
else_graph = ir.Graph( | ||
inputs=(node2.outputs[0], node3.outputs[0]), | ||
outputs=(), | ||
nodes=(node5,), | ||
name="else_graph", | ||
) | ||
|
||
node7 = ir.Node( | ||
"", | ||
"if", | ||
inputs=(node6.outputs[0],), | ||
num_outputs=1, | ||
attributes=[ | ||
ir.AttrGraph("then_branch", then_graph), | ||
ir.AttrGraph("else_branch", else_graph), | ||
], | ||
) | ||
main_graph = ir.Graph( | ||
inputs=(v0, v1, v2, v3), | ||
outputs=(node7.outputs[0],), | ||
nodes=(node0, node1, node2, node3, node6, node7), | ||
name="main_graph", | ||
opset_imports={"": 1}, | ||
) | ||
remove_unused = RemoveUnused(main_graph) | ||
remove_unused.purge() | ||
self.assertEqual(tuple(main_graph), (node0, node1, node2, node3, node6, node7)) | ||
self.assertEqual(tuple(then_graph), (node4,)) | ||
self.assertEqual(tuple(else_graph), ()) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() | ||
Check warning
Code scanning / lintrunner
RUFF-FORMAT/format Warning