Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Implement pass to remove unused nodes in graph #1841

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions onnxscript/ir/passes/_remove_unused.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) Microsoft Corporation.

Check warning

Code scanning / lintrunner

RUFF-FORMAT/format Warning

Run lintrunner -a to apply this patch.
# Licensed under the MIT License.
"""Utilities for removing unused nodes the IR graph."""

from __future__ import annotations

from collections import deque
from typing import Union

import onnxscript.ir as ir
from onnxscript.ir import Graph, Node, Value


class RemoveUnused:
def __init__(self, graph_like: Union[Graph, ir.GraphView]):
self._graph = graph_like

def purge(self) -> None:
"""Remove unused nodes in this graph and all its subgraphs that do not contribute to any graph_outputs."""
# 1. Initialize
# Gather all nodes from the graph and its subgraphs using a recursive iterator.
# Identify all subgraphs by checking the graph of each node.
# Initialize sets to keep track of visited values and nodes.
# 2. BFS traversal:
# Create a queue initialized with all output values from the subgraphs.
# While there are values in the queue:
# - Dequeue a value and retrieve its producer node.
# - Skip processing if the producer node is already visited or doesn't exist.
# - Mark the producer node as visited.
# - Enqueue all input values of this producer node for further exploration, if they haven't been visited.
# 3. Remove:
# Remove any node from its graph if it has not been marked as visited during the BFS traversal.

# Initialize
all_nodes: list[Node] = list(ir.traversal.RecursiveGraphIterator(self._graph))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: now this can be moved down to line 79, which is where it is used, I think ...

subgraphs: set[Graph] = {node.graph for node in all_nodes if node.graph}
visited_values: set[Value] = set()
visited_nodes: set[Node] = set()

# BFS Traversal
value_queue: deque[Value] = deque(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be better if a node's subgraphs are processed only after the node is itself determined to be useful (that is, added to visited_nodes. This will handle examples such as the one below better:

   x = ...
   y = If ( cond, ... x ..., ...)

Here, if y is not used, then we may not need x either. But the current logic will, I believe, mark x as visited since it is used to compute the output of the If's then subgraph's output.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. Do we really need all subgraphs' outputs? @justinchuby

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I adopted @gramalingam 's idea, modified the code and added a testcase.

output for graph in subgraphs for output in graph.outputs if output
)
while value_queue:
current_value = value_queue.popleft()
producer_node = current_value.producer()
if not producer_node or producer_node in visited_nodes:
continue
visited_nodes.add(producer_node)
for input_value in producer_node.inputs:
if input_value and input_value not in visited_values:
visited_values.add(input_value)
value_queue.append(input_value)

# Remove
for node in all_nodes:
if node not in visited_nodes:
node.graph.remove(node) # type: ignore[union-attr]
103 changes: 103 additions & 0 deletions onnxscript/ir/passes/_remove_unused_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import unittest
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved Hide resolved

from onnxscript import ir
from onnxscript.ir.passes._remove_unused import RemoveUnused


class RemoveUnusedTest(unittest.TestCase):
def test_purge_empty(self):
graph = ir.Graph(
inputs=(),
outputs=(),
nodes=(),
opset_imports={"": 1},
)
remove_unused = RemoveUnused(graph)
remove_unused.purge()
self.assertEqual(tuple(graph), ())

def test_purge_a_single_node(self):
v0 = ir.Value(name="v0")
node0 = ir.Node("", "Node0", inputs=(v0,), num_outputs=1)
node1 = ir.Node("", "Node1", inputs=(v0,), num_outputs=1)
node2 = ir.Node("", "Node2", inputs=(v0,), num_outputs=0)
node3 = ir.Node("", "Node3", inputs=(), num_outputs=1)
node4 = ir.Node("", "Node4", inputs=(None,), num_outputs=1)
graph = ir.Graph(
(v0,),
(node0.outputs[0], node3.outputs[0], node4.outputs[0]),
nodes=(node0, node1, node2, node3, node4),
opset_imports={"": 1},
)
remove_unused = RemoveUnused(graph)
remove_unused.purge()
self.assertEqual(tuple(graph), (node0, node3, node4))

def test_purge_a_tree(self):
v0 = ir.Value(name="v0")
node0 = ir.Node("", "Node0", inputs=(v0,), num_outputs=1)
node1 = ir.Node("", "Node1", inputs=(node0.outputs[0],), num_outputs=1)
node2 = ir.Node("", "Node2", inputs=(node0.outputs[0],), num_outputs=1)
graph = ir.Graph(
(v0,),
(),
nodes=(node0, node1, node2),
opset_imports={"": 1},
)
remove_unused = RemoveUnused(graph)
remove_unused.purge()
self.assertEqual(tuple(graph), ())

def test_purge_subgraph(self):
github-advanced-security[bot] marked this conversation as resolved.
Fixed
Show resolved Hide resolved
v0 = ir.Value(name="va")
v1 = ir.Value(name="vb")
v2 = ir.Value(name="vc")
v3 = ir.Value(name="vd")
node0 = ir.Node("", "a", inputs=(v0,), num_outputs=1)
node1 = ir.Node("", "b", inputs=(v1,), num_outputs=1)
node2 = ir.Node("", "c", inputs=(v2,), num_outputs=1)
node3 = ir.Node("", "d", inputs=(v3,), num_outputs=1)
node4 = ir.Node("", "sub", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1)
node5 = ir.Node("", "add", inputs=(node2.outputs[0], node3.outputs[0]), num_outputs=1)
node6 = ir.Node("", ">", inputs=(node0.outputs[0], node1.outputs[0]), num_outputs=1)
then_graph = ir.Graph(
inputs=(node2.outputs[0], node3.outputs[0]),
outputs=(node4.outputs[0],),
nodes=(node4,),
name="then_graph",
)
else_graph = ir.Graph(
inputs=(node2.outputs[0], node3.outputs[0]),
outputs=(),
nodes=(node5,),
name="else_graph",
)

node7 = ir.Node(
"",
"if",
inputs=(node6.outputs[0],),
num_outputs=1,
attributes=[
ir.AttrGraph("then_branch", then_graph),
ir.AttrGraph("else_branch", else_graph),
],
)
main_graph = ir.Graph(
inputs=(v0, v1, v2, v3),
outputs=(node7.outputs[0],),
nodes=(node0, node1, node2, node3, node6, node7),
name="main_graph",
opset_imports={"": 1},
)
remove_unused = RemoveUnused(main_graph)
remove_unused.purge()
self.assertEqual(tuple(main_graph), (node0, node1, node2, node3, node6, node7))
self.assertEqual(tuple(then_graph), (node4,))
self.assertEqual(tuple(else_graph), ())


if __name__ == "__main__":
unittest.main()

Check warning on line 103 in onnxscript/ir/passes/_remove_unused_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/_remove_unused_test.py#L103

Added line #L103 was not covered by tests
Loading