Skip to content

Commit a36ec86

Browse files
authored
Rotary embedding needs function extraction (#2139)
Rotary embedding fusion needs as_function=True.
1 parent 1d5972f commit a36ec86

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

onnxscript/rewriter/ort_fusions/rotary_embedding.py

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ def _rotate_half_pattern(op, x, start1, end1, start2, end2):
2424

2525

2626
class RotaryEmbeddingFusion(pattern.RewriteRuleClassBase):
27+
def __init__(self):
28+
super().__init__(name="RotaryEmbedding", as_function=True)
29+
2730
def pattern(self, op, x, cos, sin, start1, end1, start2, end2):
2831
return x * cos + _rotate_half_pattern(op, x, start1, end1, start2, end2) * sin
2932

0 commit comments

Comments
 (0)