
Commit 1d5972f

A couple of ort fusion fixes (#2136)

* Enable the SDPA fusions, and undo an SDPA fusion when it does not lead to a subsequent final fusion (such as MHA or GQA).
* Fix the handling of constants in functions extracted during fusion.
* Use FastGelu instead of Gelu in the fusion introduced earlier today.

Co-authored-by: Justin Chu <[email protected]>

1 parent 7d800b6 commit 1d5972f

File tree

10 files changed: +99 -26 lines changed

noxfile.py (+1 -2)

@@ -15,8 +15,7 @@
     "beartype==0.17.2",
     "expecttest==0.1.6",
     "hypothesis",
-    'numpy==1.24.4; python_version<"3.9"',
-    'numpy==1.26.4; python_version>="3.9"',
+    "numpy",
     "packaging",
     "parameterized",
     'psutil; sys_platform != "win32"',

onnxscript/rewriter/ort_fusions/_core.py (+24 -5)

@@ -4,7 +4,7 @@
 
 import onnxscript.ir as ir
 from onnxscript.ir.passes.common import shape_inference
-from onnxscript.optimizer import optimize, remove_unused_nodes
+from onnxscript.optimizer import optimize
 from onnxscript.rewriter import rewrite
 from onnxscript.rewriter.ort_fusions import (
     fused_matmul_rule_sets,
@@ -36,7 +36,6 @@
 # TODO: There are some potential redundancies below. Can be targeted for optimization
 # once we have robust fusion.
 def _pre_optimize(model: ir.Model) -> ir.Model:
-    optimize(model)
     # TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
     # extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
     # incorporated in our optimizer.
@@ -45,7 +44,7 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
     return model
 
 
-def fuse_xformers(model: ir.Model) -> None:
+def fuse_xformers(model: ir.Model) -> ir.Model:
     model = _pre_optimize(model)
     fuse_rms_normalization(model)
     fuse_normalization(model)
@@ -55,9 +54,29 @@
     fuse_sdpa(model)
     fuse_mha(model)
     fuse_gelu(model)
-    remove_unused_nodes(model)
+    # Finally: inline any intermediate fusion functions introduced that were not
+    # consumed by other fusions, and eliminate any remaining unused nodes.
+    optimize(model)
+    return model
+
 
+def optimize_for_ort(model: ir.Model, config_name: str | None = None) -> ir.Model:
+    """
+    Optimize the model for ORT backend.
+
+    TODO: config_name is not used yet. It should be used to select the appropriate
+    optimization configuration (for an EP). Currently, a default implementation is used.
+
+    Args:
+        model: The model to optimize.
+        config_name: The name of the configuration to use for optimization.
+            Typically it identifies the Execution Provider (EP) to optimize for.
+            If None, the default configuration will be used.
+
+    Returns:
+        The optimized model.
+    """
 
-def optimize_for_ort(model: ir.Model) -> None:
     fuse_xformers(model)
     rewrite(model, ORT_PATTERN_REWRITE_RULES)
+    return model
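For orientation, a minimal usage sketch of the updated entry point; the file paths are hypothetical, and the ir.serde round-trip helpers are assumed from onnxscript's IR API. The key point of this diff is that fuse_xformers and optimize_for_ort now return the optimized ir.Model instead of None:

    import onnx
    import onnxscript.ir as ir
    from onnxscript.rewriter.ort_fusions._core import optimize_for_ort

    # Load into the IR, apply the ORT-targeted fusions, and serialize back.
    model = ir.serde.deserialize_model(onnx.load("model.onnx"))    # hypothetical path
    model = optimize_for_ort(model)  # now returns the optimized ir.Model
    onnx.save(ir.serde.serialize_model(model), "model_opt.onnx")   # hypothetical path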

onnxscript/rewriter/ort_fusions/_test_utils.py (+1 -1)

@@ -33,7 +33,7 @@ def ort_run(model_name: str, model, inputs):
     return session.run(None, inputs)
 
 
-def assert_allclose(outputs, expected_outputs, rtol=1e-2, atol=1e-2):
+def assert_allclose(outputs, expected_outputs, rtol=1e-4, atol=1e-4):
     for i, (baseline_output, optimized_output) in enumerate(zip(expected_outputs, outputs)):
         try:
             np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
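The tightened defaults are easiest to see with a worked example, assuming the numpy criterion |actual - desired| <= atol + rtol * |desired| that np.testing.assert_allclose uses:

    import numpy as np

    baseline = np.array([1.0])
    optimized = np.array([1.005])  # 0.5% numerical drift

    # Old defaults tolerate the drift: 0.005 <= 1e-2 + 1e-2 * 1.0 = 0.02
    np.testing.assert_allclose(optimized, baseline, rtol=1e-2, atol=1e-2)

    # New defaults reject it: 0.005 > 1e-4 + 1e-4 * 1.0 = 2e-4
    try:
        np.testing.assert_allclose(optimized, baseline, rtol=1e-4, atol=1e-4)
    except AssertionError:
        print("0.5% drift now fails the comparison")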
New test file (+26)

@@ -0,0 +1,26 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+import onnxscript.optimizer
+from onnxscript.rewriter.ort_fusions._core import fuse_xformers
+from onnxscript.rewriter.ort_fusions._smollm_1 import smollm_test_1
+from onnxscript.rewriter.ort_fusions._test_utils import assert_allclose, ort_run
+
+
+class TestFuseXformers(unittest.TestCase):
+    def test_fuse_xformers(self):
+        test = smollm_test_1()
+        model = test.get_onnx_model()
+        onnxscript.optimizer.optimize(model)
+        inputs = test.get_ort_inputs()
+        original_outputs = ort_run("original", model, inputs)
+        model = fuse_xformers(model)
+        new_outputs = ort_run("optimized", model, inputs)
+        assert_allclose(new_outputs, original_outputs)
+
+
+if __name__ == "__main__":
+    unittest.main()

onnxscript/rewriter/ort_fusions/gelu.py (+1 -1)

@@ -25,7 +25,7 @@ def pattern(self, op, x):
         return result
 
     def rewrite(self, op, x):
-        return op.Gelu(x, _domain="com.microsoft")
+        return op.FastGelu(x, _domain="com.microsoft")
 
 
 _rule = GeluTanhFusion.rule()
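The fix matters because the rule is GeluTanhFusion: it matches the tanh approximation of GELU, which in the com.microsoft domain is FastGelu; Gelu there is the exact erf-based form. A reference sketch of the two standard formulations (not code from this repo):

    import math

    def gelu_exact(x: float) -> float:
        # Exact GELU (what com.microsoft Gelu computes): 0.5 * x * (1 + erf(x / sqrt(2)))
        return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

    def gelu_tanh(x: float) -> float:
        # Tanh approximation (what FastGelu computes, and what the pattern matches)
        return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))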

onnxscript/rewriter/ort_fusions/gelu_test.py (+1 -1)

@@ -47,7 +47,7 @@ def gelu_model(x):
         remove_unused_nodes(model)
 
         self.assertEqual(len(model.graph), 1)
-        self.assertEqual(model.graph.node(0).op_type, "Gelu")
+        self.assertEqual(model.graph.node(0).op_type, "FastGelu")
 
         optimized_output = test_utils.ort_run("Optimized", model, input)
         test_utils.assert_allclose(original_output, optimized_output)

onnxscript/rewriter/ort_fusions/rms_normalization.py (+1 -1)

@@ -71,7 +71,7 @@ def check(self, op, x, scale, epsilon, compute_dtype, target_dtype):
     def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):
         stash_dtype = compute_dtype.value if self._cast_input else x.dtype
         # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
-        # No need to use com.microsoft domain here.
+        # No need to use com.microsoft domain here; but this is a custom op in ORT.
         return op.SimplifiedLayerNormalization(
             x,
             scale,
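For reference, SimplifiedLayerNormalization computes RMS normalization, i.e. LayerNormalization without the mean subtraction. A minimal numpy sketch, assuming normalization over the last axis:

    import numpy as np

    def simplified_layer_norm(x, scale, epsilon=1e-5):
        # RMS normalization: scale by the root-mean-square, with no mean subtraction
        mean_square = np.mean(np.square(x), axis=-1, keepdims=True)
        return (x / np.sqrt(mean_square + epsilon)) * scale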

onnxscript/rewriter/ort_fusions/sdpa.py (+1 -1)

@@ -10,7 +10,7 @@
 
 class SDPA(pattern.RewriteRuleClassBase):
     def __init__(self, name: str, *, use_mask: bool, pre_scale: bool, use_mul: bool):
-        super().__init__(name=name)
+        super().__init__(name=name, as_function=True)
         self._use_mask = use_mask
         self._pre_scale = pre_scale
         self._use_mul = use_mul
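With as_function=True, each SDPA match is replaced by a call to a model-local function, which downstream MHA/GQA fusions can then consume; if nothing consumes it, the final optimize() pass in _core.py can inline it back, undoing the fusion as the commit message describes. For orientation, a reference numpy sketch of the computation the SDPA pattern variants correspond to (standard scaled dot-product attention; the pre_scale/use_mul flags only control where and how the scaling appears):

    import numpy as np

    def sdpa_reference(query, key, value, mask=None, scale=None):
        # softmax(Q @ K^T * scale + mask) @ V, with scale defaulting to 1/sqrt(head_dim)
        if scale is None:
            scale = 1.0 / np.sqrt(query.shape[-1])
        scores = (query @ np.swapaxes(key, -1, -2)) * scale
        if mask is not None:
            scores = scores + mask
        weights = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax
        weights = weights / weights.sum(axis=-1, keepdims=True)
        return weights @ value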

onnxscript/rewriter/ort_fusions/sdpa_test.py (+4)

@@ -180,3 +180,7 @@ def test_sdpa_fusion(self, name, script_func):
 
         # new_outputs = ort_run("optimized", model, inputs)
         # assert_allclose(new_outputs, original_outputs)
+
+
+if __name__ == "__main__":
+    unittest.main()

onnxscript/rewriter/pattern.py (+39 -14)
@@ -1428,6 +1428,7 @@ def replace_pattern(new_pattern):
                 self.remove_nodes,
                 self.graph_pre_visitor,
                 self.graph_post_visitor,
+                self.as_function,
             )
 
         return [replace_pattern(p) for p in self._target_pattern.commute()]
@@ -1509,21 +1510,23 @@ class RewriteRuleClassBase:
     @classmethod
     def rule(cls, *args, **kwargs):
         instance = cls(*args, **kwargs)
-        setup = instance.setup if hasattr(instance, "setup") else None
-        cleanup = instance.cleanup if hasattr(instance, "cleanup") else None
         return RewriteRule(
             instance.pattern,
             instance.rewrite,
             instance.check,
             name=instance.name,
             remove_nodes=instance.remove_nodes,
-            graph_pre_visitor=setup,
-            graph_post_visitor=cleanup,
+            graph_pre_visitor=instance.setup,
+            graph_post_visitor=instance.cleanup,
+            as_function=instance.as_function,
         )
 
-    def __init__(self, name: str | None = None, remove_nodes: bool = True) -> None:
+    def __init__(
+        self, name: str | None = None, remove_nodes: bool = True, as_function: bool = False
+    ) -> None:
         self.name = name or self.__class__.__name__
         self.remove_nodes = remove_nodes
+        self.as_function = as_function
 
     def pattern(self, op, *args, **kwargs):
         raise NotImplementedError("Method 'pattern' must be implemented by derived class.")
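To illustrate the new knob, here is a hypothetical rule (not from this repo) that opts in: with as_function=True the rewriter extracts the replacement subgraph into a model-local function and replaces each match with a call to it.

    from onnxscript.rewriter import pattern

    class DoubleRelu(pattern.RewriteRuleClassBase):
        # Hypothetical rule: collapse Relu(Relu(x)) into a single Relu,
        # emitted as a model-local function because of as_function=True.
        def __init__(self):
            super().__init__(name="DoubleRelu", as_function=True)

        def pattern(self, op, x):
            return op.Relu(op.Relu(x))

        def rewrite(self, op, x):
            return op.Relu(x)

    rule = DoubleRelu.rule()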
@@ -1535,30 +1538,52 @@ def check(self, op, *args, **kwargs):
     def rewrite(self, op, *args, **kwargs):
         raise NotImplementedError("Method 'rewrite' must be implemented by derived class.")
 
+    def setup(self):
+        # Optional setup function that can be overridden by derived classes. Used to do
+        # per model/function initialization.
+        pass
+
+    def cleanup(self):
+        # Optional cleanup function that can be overridden by derived classes. Used to do
+        # per model/function cleanup.
+        pass
+
 
 def _copy_for_function(
     inputs: Sequence[ir.Value | None], nodes: Sequence[ir.Node], outputs: Sequence[ir.Value]
 ):
     """Utility function to extract a subgraph out as a function."""
     value_map: dict[ir.Value, ir.Value] = {}
     function_inputs: list[ir.Value] = []
+    constant_nodes: list[ir.Node] = []
     for input in inputs:
         # Create a function input (formal-parameter value) to represent this value:
-        if input is None:
-            raise NotImplementedError("None inputs not supported.")
-        new_value = ir.Value(
-            name=input.name,
-            shape=input.shape,
-            type=input.type,
-            doc_string=input.doc_string,
+        new_value = (
+            ir.Value(
+                name=input.name,
+                shape=input.shape,
+                type=input.type,
+                doc_string=input.doc_string,
+            )
+            if input
+            else ir.Value()  # dummy parameter for a None input
         )
-        value_map[input] = new_value
+        if input is not None:
+            value_map[input] = new_value
         function_inputs.append(new_value)
 
     def copy_value(value: ir.Value | None) -> ir.Value | None:
         if value is None:
             return None
         if value not in value_map:
+            const_value = value.const_value
+            if const_value is not None:
+                # create a Constant node to represent the value
+                value_attr = ir.AttrTensor("value", const_value)
+                const_node = ir.Node("", "Constant", [], [value_attr])
+                constant_nodes.append(const_node)
+                value_map[value] = result = const_node.outputs[0]
+                return result
             raise ValueError(f"Value {value} not found in value_map.")
         return value_map[value]
@@ -1598,7 +1623,7 @@ def copy_node(node: ir.Node) -> ir.Node:
 
     function_nodes = [copy_node(node) for node in nodes]
     function_outputs = [copy_value(v) for v in outputs]
-    return (function_inputs, function_nodes, function_outputs)
+    return (function_inputs, constant_nodes + function_nodes, function_outputs)
 
 
 def _get_new_overload(model: ir.Model, domain: str, name: str) -> str:
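The copy_value change above is what fixes the commit's second bullet: when a matched subgraph is lifted into a function, a value captured from the outer scope whose const_value is known is now materialized as a Constant node inside the function (and prepended to the function body) instead of raising. A small isolated sketch of that mechanism, assuming onnxscript.ir and using ir.tensor just to fabricate a constant:

    import numpy as np
    import onnxscript.ir as ir

    # Fabricate a constant tensor such as a captured initializer might hold.
    weight = ir.tensor(np.array([1.0, 2.0], dtype=np.float32))

    # Materialize it as a Constant node, as copy_value now does for captured constants.
    value_attr = ir.AttrTensor("value", weight)
    const_node = ir.Node("", "Constant", [], [value_attr])  # (domain, op_type, inputs, attributes)
    function_local_value = const_node.outputs[0]  # stands in for the outer-scope reference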
