[DRAFT: WIP] First version of attention fusion #1986

Open: wants to merge 24 commits into base: main (changes from all commits)
23 changes: 23 additions & 0 deletions onnxscript/optimizer/_constant_folding.py
@@ -306,6 +306,29 @@ def _get_int_attribute(node: ir.Node, name: str, default: int | None = None) ->
    return default


@register("Reshape")
def reshape(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
"""Replace a Reshape node by Identity when applicable."""
input = _get_input(node, 0)
shape = _get_input(node, 1)
if input is None or shape is None:
return None
input_shape = input.shape
if input_shape is None:
return None
input_shape_dims = list(input_shape.dims)
if any(not isinstance(dim, int) for dim in input_shape_dims):
return None
shape_value = _get_numpy_value(shape)
if shape_value is None:
return None
target_shape_dims = shape_value.tolist()
if input_shape_dims == target_shape_dims:
# No need to check for special values like -1, 0, etc. here
return op.Identity(input)
return None


@register("Cast")
def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
    input = _get_input(node, 0)
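As a standalone sketch (not part of the diff), the folding condition above boils down to an exact match between the statically known input dims and the constant target shape; symbolic dims or special values such as -1 block the rewrite:

def can_fold_reshape(input_dims: list, target_dims: list) -> bool:
    """True when Reshape(input, shape) is a no-op that can become Identity."""
    # Symbolic dims (e.g. "batch") mean equality cannot be proven statically.
    if any(not isinstance(d, int) for d in input_dims):
        return False
    # Exact equality suffices: -1 can never match a static dim, and a matching
    # 0 is a no-op under either allowzero interpretation of Reshape.
    return list(input_dims) == list(target_dims)

assert can_fold_reshape([2, 3], [2, 3])            # folds to Identity
assert not can_fold_reshape(["batch", 3], [2, 3])  # symbolic dim: no fold
assert not can_fold_reshape([2, 3], [-1, 3])       # -1 is left to Reshape itself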
27 changes: 27 additions & 0 deletions onnxscript/rewriter/_ir_utils.py
@@ -2,6 +2,9 @@
# Licensed under the MIT License.
from __future__ import annotations

import math
from typing import Callable

import numpy as np

import onnxscript.ir as ir
@@ -77,3 +80,27 @@ def get_singleton_value(val: ir.Value | None):
    if np_val is not None and np_val.size == 1:
        return np_val.item()
    return None


def is_singleton_value(
    val: ir.Value | None, expected: float | int | Callable, *, rtol: float | None = None
) -> bool:
    """Returns True if the value is a single-element tensor with the given value, and False otherwise."""
    scalar = get_singleton_value(val)
    if scalar is None:
        return False
    if callable(expected):
        return expected(scalar)
    if isinstance(expected, int):
        return expected == scalar
    # rtol must be specified for float comparison
    assert rtol is not None
    return math.isclose(scalar, expected, rel_tol=rtol)


def has_rank(value: ir.Value | None, rank: int) -> bool:
    """Returns True if the value is statically known to have the given rank, and False otherwise."""
    if value is None:
        return False
    shape = value.shape
    return (shape is not None) and (shape.rank() == rank)
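A standalone sketch of what these helpers check in practice (the Shape/Value classes below are hypothetical stand-ins for the ir types, and head_size is an illustrative constant):

import math

class Shape:  # stand-in for ir.Shape
    def __init__(self, dims): self._dims = dims
    def rank(self): return len(self._dims)

class Value:  # stand-in for ir.Value
    def __init__(self, shape): self.shape = shape

# has_rank-style check: symbolic dims are fine, only the rank must be known.
query = Value(Shape(["batch", 8, "seq", 64]))
assert query.shape is not None and query.shape.rank() == 4

# is_singleton_value-style float check: an explicit tolerance is mandatory.
head_size = 64
scale = 1.0 / math.sqrt(head_size)               # typical SDPA scale, 0.125
assert math.isclose(scale, 0.125, rel_tol=1e-6)  # float path
assert (lambda s: 0.0 < s < 1.0)(scale)          # callable path: arbitrary predicate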
18 changes: 18 additions & 0 deletions onnxscript/rewriter/onnxruntime/xformers/__init__.py
@@ -1,3 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnxruntime.xformers.cos_sin_cache import fuse_cos_sin_cache
from onnxscript.rewriter.onnxruntime.xformers.fuse_xformers import fuse_xformers
from onnxscript.rewriter.onnxruntime.xformers.mha import fuse_mha
from onnxscript.rewriter.onnxruntime.xformers.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.onnxruntime.xformers.rotary_embedding import fuse_rotary_embedding
from onnxscript.rewriter.onnxruntime.xformers.sdpa import fuse_sdpa
from onnxscript.rewriter.onnxruntime.xformers.skip_normalization import fuse_normalization

__all__ = [
    "fuse_rms_normalization",
    "fuse_normalization",
    "fuse_rotary_embedding",
    "fuse_cos_sin_cache",
    "fuse_sdpa",
    "fuse_mha",
    "fuse_xformers",
]
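A hedged usage sketch for these exports (assumptions: each pass takes an ir.Model and rewrites it in place, fuse_xformers bundles the individual fusions, and the file names are illustrative):

import onnxscript.ir as ir
from onnxscript.rewriter.onnxruntime.xformers import fuse_xformers

model = ir.load("smollm_layer.onnx")       # illustrative input path
fuse_xformers(model)                       # assumption: applies RMSNorm, skip-norm,
                                           # rotary, cos/sin cache, SDPA, and MHA fusions
ir.save(model, "smollm_layer_fused.onnx")  # illustrative output path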
@@ -2,7 +2,7 @@
# Licensed under the MIT License.

"""
-A one-layer SmolLM model test case.
+A one-layer SmolLM model test case, with inputs: input_ids, attention_mask, and position_ids.
This is an onnxscript version of the model.
"""

@@ -234,7 +234,7 @@ def make_model_with_random_weights():
    return model


-class _SmollmTestData:
+class TestData:
    def get_onnx_model(self):
        if not hasattr(self, "_onnx_model"):
            model_proto = make_model_with_random_weights()
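For context, a plausible way the test helper above is consumed (hypothetical; the class body is truncated in this view, so get_onnx_model's exact return type is an assumption):

# Hypothetical test usage, assuming get_onnx_model caches and returns the
# one-layer SmolLM model built by make_model_with_random_weights():
test_data = TestData()
model = test_data.get_onnx_model()
fuse_xformers(model)  # assumption: the fusion passes under test run on this model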