From 7bd391db06a3f057ee8bb58153439cdda963ab0b Mon Sep 17 00:00:00 2001 From: Tommaso Date: Thu, 24 Jul 2025 02:35:00 +0000 Subject: [PATCH 01/26] Added Causal Mask Pattern Fusion for LongRoPe Models --- onnxscript/rewriter/ort_fusions/gqa.py | 157 ++++++++++++++++++++++++- 1 file changed, 156 insertions(+), 1 deletion(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 99852f712a..2b7c314b3e 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -7,6 +7,7 @@ import numpy as np import onnx_ir as ir +import onnxscript.onnx_types as _onnx_types import onnxscript.rewriter._fusion_utils as _fusion_utils from onnxscript.rewriter import _basics, _ir_utils, pattern @@ -354,9 +355,163 @@ def rewrite( _outputs=3, ) +class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("LongRoPeGQACausalMask", remove_nodes=False) + self._mask_cache = {} + + def _get_mask_key(self, attention_mask): + """ + Generate a unique key for the mask based on input_ids and past_kv_cache. + This is used to cache the mask to avoid recomputation. + """ + return (id(attention_mask)) + + def compute_mask(self, op, attention_mask : _onnx_types.INT64['batch', 'seq_len']): + mask_key = self._get_mask_key(attention_mask) + + if mask_key in self._mask_cache: + total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key] + + else: + # Construct total_seq_length_int32 and seqlens_k + attention_shape = op.Shape(attention_mask, _outputs=["seq_len"]) + total_seq_length = op.Gather(attention_shape, op.Constant(value=ir.tensor(1, ir.DataType.INT64)), axis=0, _outputs=["total_seq_length"]) + reduced_attention = op.ReduceSum(attention_mask, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["reduced_attention"]) + sub_reduced_attention = op.Sub(reduced_attention, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["sub_reduced_attention"]) + total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32, _outputs=["total_seq_length_int32"]) + seqlens_k_int32 = op.Cast(sub_reduced_attention, to=ir.DataType.INT32, _outputs=["seqlens_k_int32"]) + self._mask_cache[mask_key] = (total_seq_length_int32, seqlens_k_int32) + + return self._mask_cache[mask_key] + + + def pattern( + self, + op, + mask, + input_ids, + past_kv_cache_1, + past_kv_cache_2, + attention_mask, + past_seq_length, + total_seq_length, + ): + seq_len = op.Shape(input_ids, end=2, start=1, _outputs=["seq_len"]) + seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"]) + past_seq_len = op.Shape(past_kv_cache_1, end=3, start=2, _outputs=["past_seq_len"]) + past_seq_len_0D = op.Squeeze(past_seq_len, _outputs=["past_seq_len_0D"]) + total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D, _outputs=["total_seq_len_0D"]) + + # All of the Add node's outputs + current_range_A = op.Range(past_seq_len_0D, total_seq_len_0D, 1, _outputs=["current_range_A"]) + total_seq_len_A = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_A"]) + current_range_B = op.Range(0, total_seq_len_0D, 1, _outputs=["current_range_B"]) + total_seq_len_B = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_B"]) + total_seq_len_C = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_C"]) + + total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"]) + + # EXPAND BRANCH A + batch_size = op.Shape(past_kv_cache_2, end=1, 
start=0, _outputs=["batch_size"]) + mask_shape_A = op.Concat(batch_size, [1], seq_len, total_seq_len_A, axis=0, _outputs=["mask_shape_A"]) + mask_shape_A_abs = op.Abs(mask_shape_A, _outputs=["mask_shape_A_abs"]) + reshaped_range_A = op.Reshape(current_range_A, [1, 1, -1, 1], allowzero=1, _outputs=["reshaped_range_A"]) + mask_expanded_A = op.Expand(reshaped_range_A, mask_shape_A_abs, _outputs=["mask_expanded_A"]) + + # EXPAND BRANCH B + mask_shape_B = op.Concat(batch_size, [1], seq_len, total_seq_len_B, axis=0, _outputs=["mask_shape_B"]) + mask_shape_B_abs = op.Abs(mask_shape_B, _outputs=["mask_shape_B_abs"]) + reshaped_range_B = op.Reshape(current_range_B, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_range_B"]) + mask_expanded_B = op.Expand(reshaped_range_B, mask_shape_B_abs, _outputs=["mask_expanded_B"]) + + # EXPAND BRANCH C + mask_shape_C = op.Concat(batch_size, [1], seq_len, total_seq_len_C, axis=0, _outputs=["mask_shape_C"]) + mask_shape_C_abs = op.Abs(mask_shape_C, _outputs=["mask_shape_C_abs"]) + batch_size_squeezed = op.Squeeze(batch_size, _outputs=["batch_size_squeezed"]) + batch_range = op.Range(0, batch_size_squeezed, 1, _outputs=["batch_range"]) + reshaped_range_C = op.Reshape(batch_range, [-1, 1, 1, 1], allowzero=1, _outputs=["reshaped_range_C"]) + mask_expanded_C = op.Expand(reshaped_range_C, mask_shape_C_abs, _outputs=["mask_expanded_C"]) + + # EXPAND A/B TO AND + mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"]) + mask_A_B_greater = op.Greater(mask_expanded_B, mask_expanded_A_sub, _outputs=["mask_A_B_greater"]) + mask_A_B_greater_bitwise = op.And(True, mask_A_B_greater, _outputs=["mask_A_B_greater_bitwise"]) + mask_A_B_less = op.LessOrEqual(mask_expanded_B, mask_expanded_A, _outputs=["mask_A_B_less"]) + mask_A_B_combined = op.And(mask_A_B_greater_bitwise, mask_A_B_less, _outputs=["mask_A_B_combined"]) + mask_A_B_combined_bitwise = op.And(True, mask_A_B_combined, _outputs=["mask_A_B_combined_bitwise"]) + + # EXPAND B/C TO AND + unsqueezed_mask_expanded_B = op.Unsqueeze(mask_expanded_B, [-1], _outputs=["unsqueezed_mask_expanded_B"]) + unsqueezed_mask_expanded_C = op.Unsqueeze(mask_expanded_C, [-1], _outputs=["unsqueezed_mask_expanded_C"]) + mask_B_C_concat = op.Concat(unsqueezed_mask_expanded_C, unsqueezed_mask_expanded_B, axis=-1, _outputs=["mask_B_C_concat"]) + attention_mask_bool = op.Cast(attention_mask, to=ir.DataType.BOOL, _outputs=["attention_mask_bool"]) + mask_gatherND = op.GatherND(attention_mask_bool, mask_B_C_concat, batch_dims=0, _outputs=["mask_gatherND"]) + + mask_A_B_C_combined = op.And(mask_A_B_combined_bitwise, mask_gatherND, _outputs=["mask_A_B_C_combined"]) + mask_A_B_C_negated = op.Not(mask_A_B_C_combined, _outputs=["mask_A_B_C_negated"]) + mask_A_B_C_fp32 = op.Cast(mask_A_B_C_negated, to=ir.DataType.FLOAT, _outputs=["mask_A_B_C_fp32"]) + mask_A_B_C_scaled = op.Mul(mask_A_B_C_fp32, pattern.ANY_VALUE) + # Propagation to GQA + mask_sliced = op.Slice(mask_A_B_C_scaled, [0], pattern.ANY_VALUE, [3], [1], _outputs=["mask_sliced"]) + + #mask_where = op.Where(mask_sliced, pattern.ANY_VALUE, pattern.ANY_VALUE, _outputs=["mask_where"]) + + return op.GQA( + mask_sliced, + pattern.ANY_VALUE, # position_ids_k + pattern.ANY_VALUE, # position_ids_q + pattern.ANY_VALUE, # query + pattern.ANY_VALUE, # key + pattern.ANY_VALUE, # value + pattern.ANY_VALUE, # past_key + pattern.ANY_VALUE, # past_value + pattern.ANY_VALUE, # seqlens_k (optional) + pattern.ANY_VALUE, # total_seq_length (optional) + pattern.ANY_VALUE, # cos + 
pattern.ANY_VALUE, # sin + _allow_other_inputs=True, + _domain="ai.onnxruntime._fusion", + _outputs=["attn_output", "key_seq", "value_seq"], + ) + + + def rewrite( + self, + op, + attention_mask, + attn_output, + **_, + ): + # Compute total_seq_length_int32 and seqlens_k_int32 + total_seq_length_int32, seqlens_k_int32 = self.compute_mask(op, attention_mask) + + gqa_node = attn_output.producer() + assert len(gqa_node.inputs) == 12, ( + f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}" + ) + query, key, value, past_key, past_value = gqa_node.inputs[3:8] + cos, sin = gqa_node.inputs[10:12] + updated_inputs = [ + query, + key, + value, + past_key, + past_value, + seqlens_k_int32, + total_seq_length_int32, + cos, + sin, + ] + attributes = gqa_node.attributes + return op.GroupQueryAttention( + *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3 + ) _basic_gqa_rule = GroupQueryAttention.rule() +_longrope_gqa_causal_mask_rule = LongRoPeGQACausalMask.rule() gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]) +gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _longrope_gqa_causal_mask_rule]) -fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) +fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) \ No newline at end of file From f0f41a80c88adcde972fe662d656c4760c257384 Mon Sep 17 00:00:00 2001 From: Tommaso Date: Thu, 31 Jul 2025 23:28:48 +0000 Subject: [PATCH 02/26] Added Phi4-mini-reasoning cache insertion and position Id deletion logic --- .../phi4_mini_reasoning_post_processor.py | 821 ++++++++++++++++++ 1 file changed, 821 insertions(+) create mode 100644 onnxscript/rewriter/phi4_mini_reasoning_post_processor.py diff --git a/onnxscript/rewriter/phi4_mini_reasoning_post_processor.py b/onnxscript/rewriter/phi4_mini_reasoning_post_processor.py new file mode 100644 index 0000000000..3832a01498 --- /dev/null +++ b/onnxscript/rewriter/phi4_mini_reasoning_post_processor.py @@ -0,0 +1,821 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +import onnx +from onnxscript import ir +import onnx.helper +import numpy as np +import logging +import torch +import math + +from transformers import AutoConfig +from dataclasses import dataclass, field +from typing import Optional, Tuple, List + +class Phi4MiniReasoningPostProcessor: + def __init__(self, config: AutoConfig, io_dtype: ir.DataType = ir.DataType.FLOAT): + self.config = config + self.original_max_position_embeddings = getattr(config, "original_max_position_embeddings", 4096) + self.max_position_embeddings = getattr(config, "max_position_embeddings", 131072) + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_size = self.hidden_size // self.num_attention_heads + self.io_dtype: ir.DataType = ir.DataType(io_dtype) + + # Torch dtype mapping for ONNX IR DataType + self.to_torch_dtype = { + ir.DataType.FLOAT: torch.float32, + ir.DataType.FLOAT16: torch.float16, + ir.DataType.BFLOAT16: torch.bfloat16, + ir.DataType.DOUBLE: torch.float64, + ir.DataType.INT64: torch.int64, + ir.DataType.INT32: torch.int32, + } + + # Initialize rotary embedding attributes + position_scale = getattr(config, "rope_position_scale", 1.0) + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + rotemb_dim = int(self.head_size * partial_rotary_factor) if partial_rotary_factor != 1.0 else 0 + rope_theta = getattr(config, "rope_theta", getattr(config, "rope_embedding_base", 10000.0)) + + self.rotemb_attrs = { + "create_caches": True, # Create cos/sin caches for rotary embeddings + "save_caches": True, # Auto-save cos/sin caches for rotary embeddings after creation + "cache_length": self.max_position_embeddings, # Cache length to use when creating cos/sin caches for rotary embeddings + "theta": rope_theta, # Base value if calculating cos/sin caches from scratch + "partial_rotary_factor": partial_rotary_factor, # Factor for partial rotary embeddings + "interleaved": 0, # Interleave the rotary embeddings (e.g. [0, 0, 0, 1, 1, 1] to [0, 1, 0, 1, 0, 1], RotaryEmbedding kernel expects a default value of 0) + "rotary_embedding_dim": rotemb_dim, # For partial rotary embeddings (RotaryEmbedding kernel expects a default value of 0) + "rescale_factors": 1.0, # Rescale factors when calculating `inv_freq` in rotary embeddings + "t_dtype": torch.int64, # Torch dtype when calculating `t` in rotary embeddings + "position_scale": position_scale, # Scale value when calculating `t` in rotary embeddings + "mscale": 1.0, # Magnitude scaling factor when scaling `emb.cos()/emb.sin()` in rotary embeddings + "mscale_policy": "", # Magnitude scaling policy when scaling `emb.cos()/emb.sin()` in rotary embeddings + } + + # Handle rope scaling configuration for multi-cache scenarios + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + if "short_factor" in config.rope_scaling: + # For models with multiple rotary embedding caches (e.g. 
Phi-3 mini 128K) + self.rotemb_attrs["mscale_policy"] = config.rope_scaling.get("type", "") + short_factor = torch.tensor(config.rope_scaling["short_factor"], dtype=torch.float32) + long_factor = torch.tensor(config.rope_scaling["long_factor"], dtype=torch.float32) + + short_mscale = config.rope_scaling.get("short_mscale", 0) + long_mscale = config.rope_scaling.get("long_mscale", 0) + short_mscale = short_mscale if short_mscale > 0 else self.make_mscale(self.max_position_embeddings / self.original_max_position_embeddings) + long_mscale = long_mscale if long_mscale > 0 else self.make_mscale(self.max_position_embeddings / self.original_max_position_embeddings) + + self.rotemb_attrs["multi_cache"] = { + "short_factor": short_factor, # Short factor when calculating `inv_freq` in rotary embeddings + "long_factor": long_factor, # Long factor when calculating `inv_freq` in rotary embeddings + "short_mscale": short_mscale, # Magnitude scaling for short factor when scaling `emb.cos()/emb.sin()` in rotary embeddings + "long_mscale": long_mscale, # Magnitude scaling for long factor when scaling `emb.cos()/emb.sin()` in rotary embeddings + } + + @dataclass + class PatternNodes: + """Container for the nodes found in the old Cos/Sin value generation pattern.""" + gather_value: Optional[ir.Value] = None + matmul_node: Optional[ir.Node] = None + cos_node: Optional[ir.Node] = None + sin_node: Optional[ir.Node] = None + + @dataclass + class CacheData: + """Container for generated cache data.""" + cos_large: np.ndarray + sin_large: np.ndarray + cos_small: np.ndarray + sin_small: np.ndarray + + @dataclass + class IfNodeComponents: + """Container for If node components.""" + threshold_const_node: ir.Node + greater_node: ir.Node + if_node: ir.Node + cos_output: ir.Value + sin_output: ir.Value + + @dataclass + class ProcessingChainNodes: + """Container for position processing chain nodes.""" + position_ids_input: Optional[ir.Value] = None + reduce_max_node: Optional[ir.Node] = None + add_node: Optional[ir.Node] = None + range_node: Optional[ir.Node] = None + reshape_node: Optional[ir.Node] = None + cast_node: Optional[ir.Node] = None + constant_nodes: List[ir.Node] = field(default_factory=list) + + def make_mscale(self, mscale: float) -> float: + """Calculate magnitude scaling factor for RoPE.""" + if mscale <= 1.0: + return 1.0 + return math.sqrt(1 + math.log(mscale) / math.log(self.original_max_position_embeddings)) + + def calculate_rotary_embedding_caches(self): + """Generate cos/sin caches from scratch using the current rotemb_attrs.""" + if self.rotemb_attrs["rotary_embedding_dim"] > 0: + dim = self.rotemb_attrs["rotary_embedding_dim"] + else: + dim = int(self.rotemb_attrs["partial_rotary_factor"] * self.head_size) + + inv_freq, attention_factor = self._compute_longrope_parameters( + cache_length=self.rotemb_attrs["cache_length"], + dim=dim + ) + + cache_length = self.rotemb_attrs["cache_length"] + position_ids = torch.arange(cache_length, dtype=torch.int64).unsqueeze(0) # Shape: (1, cache_length) + + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) # (1, dim//2, 1) + position_ids_expanded = position_ids[:, None, :].float() # (1, 1, cache_length) + + device_type = "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) # (1, cache_length, dim//2) + emb = torch.cat((freqs, freqs), dim=-1) # (1, cache_length, dim) + cos_cache = emb.cos() * attention_factor # (1, 
cache_length, dim) + sin_cache = emb.sin() * attention_factor # (1, cache_length, dim) + + return cos_cache, sin_cache + + def _compute_longrope_parameters(self, cache_length: int, dim: int) -> tuple: + """ + Computes the inverse frequencies with LongRoPE scaling for Phi-4. + Based on the official transformers implementation. + """ + base = self.rotemb_attrs["theta"] + + # Check if we have multi_cache configuration (LongRoPE) + if "multi_cache" in self.rotemb_attrs: + long_factor = self.rotemb_attrs["multi_cache"]["long_factor"] + short_factor = self.rotemb_attrs["multi_cache"]["short_factor"] + + # Select factor based on cache length vs original max position embeddings + if cache_length > self.original_max_position_embeddings: + ext_factors = torch.tensor(long_factor, dtype=torch.float32, device="cpu") + attention_factor = self.rotemb_attrs["multi_cache"]["long_mscale"] + else: + ext_factors = torch.tensor(short_factor, dtype=torch.float32, device="cpu") + attention_factor = self.rotemb_attrs["multi_cache"]["short_mscale"] + + inv_freq_shape = torch.arange(0, dim, 2, dtype=torch.int64, device="cpu").float() / dim + inv_freq = 1.0 / (ext_factors * base**inv_freq_shape) + + if "rescale_inv_freq" in self.rotemb_attrs: + inv_freq = self.make_inv_freq_rescaled(inv_freq) + + return inv_freq, attention_factor + + def reformat_rotary_embedding_caches(self): + """Generate and format cos/sin caches for the current configuration.""" + cos_cache, sin_cache = self.calculate_rotary_embedding_caches() + + # Convert to the target dtype + cos_cache = cos_cache.to(self.to_torch_dtype[self.io_dtype]) + sin_cache = sin_cache.to(self.to_torch_dtype[self.io_dtype]) + + # Slice cos/sin caches from (M, H) to (M, H/2) + hidden_dim = cos_cache.shape[-1] + cos_cache = cos_cache.squeeze()[:, : (hidden_dim // 2)] + cos_cache = cos_cache.to(self.to_torch_dtype[self.io_dtype]) + sin_cache = sin_cache.squeeze()[:, : (hidden_dim // 2)] + sin_cache = sin_cache.to(self.to_torch_dtype[self.io_dtype]) + + # Slice cos/sin caches from (M, H/2) to (M, R/2) if partial rotary embeddings are used + if self.rotemb_attrs["partial_rotary_factor"] != 1.0: + cos_cache = cos_cache[:, : (self.rotemb_attrs["rotary_embedding_dim"] // 2)] + sin_cache = sin_cache[:, : (self.rotemb_attrs["rotary_embedding_dim"] // 2)] + + return cos_cache, sin_cache + + def make_inv_freq_rescaled(self, inv_freq): + scale_factor = self.rotemb_attrs["rescale_inv_freq"]["factor"] + low_freq_factor = self.rotemb_attrs["rescale_inv_freq"]["low_freq_factor"] + high_freq_factor = self.rotemb_attrs["rescale_inv_freq"]["high_freq_factor"] + old_context_len = self.original_max_position_embeddings + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in inv_freq: + wavelen = 2 * torch.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + + return torch.tensor(new_freqs, dtype=inv_freq.dtype) + + def delete_position_processing_nodes(self, model: ir.Model) -> ir.Model: + """ + Delete the position processing nodes from the ONNX IR graph. 
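+        (These nodes are assumed dead after fusion, since the fused
+        GroupQueryAttention op derives its sequence lengths from the attention
+        mask rather than from position_ids.)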
+ This removes the sequence: position_ids -> ReduceMax -> Add -> Range -> Reshape -> Cast + + Args: + model: ONNX IR Model to modify + + Returns: + Modified ONNX IR Model with nodes removed + """ + graph = model.graph + + # Step 1: Find position processing chain nodes + chain_nodes = self._find_position_processing_chain(graph) + if not self._validate_processing_chain(chain_nodes): + return model + + # Step 2: Find constants that feed the chain + self._find_chain_feeding_constants(graph, chain_nodes) + + # Step 3: Remove the processing chain nodes + self._remove_processing_chain_nodes(graph, chain_nodes) + + # Step 4: Clean up position_ids input if unused + self._cleanup_position_ids_input(graph, chain_nodes.position_ids_input) + + return model + + def _find_position_processing_chain(self, graph) -> ProcessingChainNodes: + """Find the position processing chain nodes in the graph.""" + chain = self.ProcessingChainNodes() + + # Find position_ids input + chain.position_ids_input = self._find_position_ids_input(graph) + if not chain.position_ids_input: + return chain + + # Find processing nodes in sequence + chain.reduce_max_node = self._find_reduce_max_node(graph, chain.position_ids_input) + + if chain.reduce_max_node: + chain.add_node = self._find_add_node(graph, chain.reduce_max_node) + + if chain.add_node: + chain.range_node = self._find_range_node(graph, chain.add_node) + + if chain.range_node: + chain.reshape_node = self._find_reshape_node(graph, chain.range_node) + + if chain.reshape_node: + chain.cast_node = self._find_cast_node(graph, chain.reshape_node) + + return chain + + def _find_position_ids_input(self, graph) -> Optional[ir.Value]: + """Find the position_ids input in the graph.""" + for input_val in graph.inputs: + if "position_ids" in input_val.name: + logging.info(f"Found position_ids input: {input_val.name}") + return input_val + + logging.warning("position_ids input not found") + return None + + def _find_reduce_max_node(self, graph, position_ids_input: ir.Value) -> Optional[ir.Node]: + """Find ReduceMax node that processes position_ids.""" + for node in graph: + if node.op_type == "ReduceMax": + if any(input_val == position_ids_input for input_val in node.inputs): + logging.info(f"Found ReduceMax node: {node.name}") + return node + return None + + def _find_add_node(self, graph, reduce_max_node: ir.Node) -> Optional[ir.Node]: + """Find Add node that follows ReduceMax.""" + reduce_max_outputs = reduce_max_node.outputs + for node in graph: + if node.op_type == "Add": + if any(input_val in reduce_max_outputs for input_val in node.inputs): + logging.info(f"Found Add node following ReduceMax: {node.name}") + return node + return None + + def _find_range_node(self, graph, add_node: ir.Node) -> Optional[ir.Node]: + """Find Range node that follows Add.""" + add_outputs = add_node.outputs + for node in graph: + if node.op_type == "Range": + if any(input_val in add_outputs for input_val in node.inputs): + logging.info(f"Found Range node following Add: {node.name}") + return node + return None + + def _find_reshape_node(self, graph, range_node: ir.Node) -> Optional[ir.Node]: + """Find Reshape node that follows Range.""" + range_outputs = range_node.outputs + for node in graph: + if node.op_type == "Reshape": + if any(input_val in range_outputs for input_val in node.inputs): + logging.info(f"Found Reshape node following Range: {node.name}") + return node + return None + + def _find_cast_node(self, graph, reshape_node: ir.Node) -> Optional[ir.Node]: + """Find Cast node that follows 
Reshape.""" + reshape_outputs = reshape_node.outputs + for node in graph: + if node.op_type == "Cast": + if any(input_val in reshape_outputs for input_val in node.inputs): + logging.info(f"Found Cast node following Reshape: {node.name}") + return node + return None + + def _validate_processing_chain(self, chain_nodes: ProcessingChainNodes) -> bool: + """Validate that sufficient chain nodes were found for deletion.""" + if not chain_nodes.position_ids_input: + logging.warning("Cannot delete processing chain: position_ids input not found") + return False + + # We need at least the reduce_max_node to proceed + if not chain_nodes.reduce_max_node: + logging.warning("Cannot delete processing chain: ReduceMax node not found") + return False + + # Log found nodes + found_nodes = [] + if chain_nodes.reduce_max_node: + found_nodes.append(f"ReduceMax: {chain_nodes.reduce_max_node.name}") + if chain_nodes.add_node: + found_nodes.append(f"Add: {chain_nodes.add_node.name}") + if chain_nodes.range_node: + found_nodes.append(f"Range: {chain_nodes.range_node.name}") + if chain_nodes.reshape_node: + found_nodes.append(f"Reshape: {chain_nodes.reshape_node.name}") + if chain_nodes.cast_node: + found_nodes.append(f"Cast: {chain_nodes.cast_node.name}") + + logging.info(f"Found position processing chain: {', '.join(found_nodes)}") + return True + + def _find_chain_feeding_constants(self, graph, chain_nodes: ProcessingChainNodes) -> None: + """Find constant nodes that exclusively feed the processing chain.""" + chain_node_list = [ + node for node in [ + chain_nodes.reduce_max_node, + chain_nodes.add_node, + chain_nodes.range_node, + chain_nodes.reshape_node, + chain_nodes.cast_node + ] if node is not None + ] + + for node in graph: + if node.op_type == "Constant": + constant_output = node.outputs[0] if node.outputs else None + if constant_output and self._constant_feeds_chain_exclusively( + graph, constant_output, chain_node_list, node + ): + chain_nodes.constant_nodes.append(node) + logging.info(f"Found constant node feeding chain: {node.name}") + + def _constant_feeds_chain_exclusively( + self, + graph, + constant_output: ir.Value, + chain_nodes: List[ir.Node], + constant_node: ir.Node + ) -> bool: + """Check if a constant exclusively feeds the processing chain.""" + # Check if constant feeds any chain node + feeds_chain = any( + any(input_val == constant_output for input_val in chain_node.inputs) + for chain_node in chain_nodes + ) + + if not feeds_chain: + return False + + # Check if constant is used by any non-chain nodes + for node in graph: + if node not in chain_nodes and node != constant_node: + if any(input_val == constant_output for input_val in node.inputs): + return False + + return True + + def _remove_processing_chain_nodes(self, graph, chain_nodes: ProcessingChainNodes) -> None: + """Remove all processing chain nodes from the graph.""" + nodes_to_delete = [ + node for node in [ + chain_nodes.reduce_max_node, + chain_nodes.add_node, + chain_nodes.range_node, + chain_nodes.reshape_node, + chain_nodes.cast_node + ] if node is not None + ] + nodes_to_delete.extend(chain_nodes.constant_nodes) + + if nodes_to_delete: + self._delete_nodes_from_graph(graph, nodes_to_delete) + else: + logging.warning("No processing chain nodes found to delete") + + def _delete_nodes_from_graph(self, graph, nodes_to_delete: List[ir.Node]) -> None: + """Delete nodes from the graph with error handling.""" + try: + graph.remove(nodes_to_delete) + logging.info(f"Successfully deleted {len(nodes_to_delete)} processing chain 
nodes") + except Exception as e: + logging.error(f"Error deleting nodes in batch: {e}") + # Try deleting nodes one by one + self._delete_nodes_individually(graph, nodes_to_delete) + """ + def _delete_nodes_individually(self, graph, nodes_to_delete: List[ir.Node]) -> None: + Delete nodes individually with error handling. + for node in nodes_to_delete: + try: + graph.remove([node]) + logging.info(f"Successfully deleted node: {node.name}") + except Exception as e: + logging.error(f"Failed to delete node {node.name}: {e}") + """ + def _cleanup_position_ids_input(self, graph, position_ids_input: Optional[ir.Value]) -> None: + """Remove position_ids input if it's no longer used.""" + if not position_ids_input: + return + + # Check if position_ids is still used by any remaining nodes + if self._input_still_used(graph, position_ids_input): + logging.info(f"position_ids input {position_ids_input.name} is still in use") + return + + try: + graph.inputs.remove(position_ids_input) + logging.info(f"Removed unused position_ids input: {position_ids_input.name}") + except Exception as e: + logging.warning(f"Could not remove position_ids input: {e}") + + def _input_still_used(self, graph, input_value: ir.Value) -> bool: + """Check if an input value is still used by any nodes in the graph.""" + return any( + any(input_val == input_value for input_val in node.inputs) + for node in graph + ) + + def insert_rotary_embedding_caches(self, model: ir.Model, threshold: int = 4096) -> ir.Model: + """ + Replaces the current Cos/Sin value generation with an control flow node containing + cached Cos/Sin values. + + Args: + model: ONNX IR Model to modify + threshold: Threshold value for Phi-4-mini-reasoning cache selection (default: 4096) + + Returns: + Modified ONNX IR Model with MatMul→Cos/Sin replaced by cache-enabled If node + """ + graph = model.graph + + # Step 1: Find pattern nodes + pattern = self._find_pattern_nodes(graph) + if not self._validate_pattern_nodes(pattern): + return model + + # Step 2: Generate cache data + cache_data = self._generate_cache_data() + + # Step 3: Create If node with caches + if_components = self._create_if_node_with_caches(cache_data, threshold, pattern.gather_value) + + # Step 4: Replace pattern with If node + self._replace_pattern_with_if_node(graph, pattern, if_components) + + # Step 5: Clean up old nodes + self._remove_old_nodes(graph, pattern) + + return model + + + def _find_pattern_nodes(self, graph) -> PatternNodes: + """Find the MatMul→Cos/Sin pattern nodes in the graph.""" + pattern = self.PatternNodes() + + # Find attention mask gather chain + pattern.gather_value = self._find_attention_mask_gather_value(graph) + + # Find MatMul→Cos/Sin pattern + matmul_cos_sin = self._find_matmul_cos_sin_nodes(graph) + pattern.matmul_node = matmul_cos_sin[0] + pattern.cos_node = matmul_cos_sin[1] + pattern.sin_node = matmul_cos_sin[2] + + return pattern + + def _find_attention_mask_gather_value(self, graph) -> Optional[ir.Value]: + """ + Find the gather value from the attention mask processing chain. 
+ Chain: attention_mask → Shape → Gather + """ + ATTENTION_MASK_NAME = "attention_mask" + + # Find Shape node that processes attention_mask + shape_output_name = None + for node in graph: + if node.op_type == "Shape": + for input_value in node.inputs: + if ATTENTION_MASK_NAME in input_value.name: + shape_output_name = node.outputs[0].name if node.outputs else None + break + if shape_output_name: + break + + if not shape_output_name: + return None + + # Find Gather node that follows the Shape + for node in graph: + if node.op_type == "Gather": + for input_value in node.inputs: + if input_value.name == shape_output_name: + return node.outputs[0] if node.outputs else None + + return None + + def _find_matmul_cos_sin_nodes(self, graph) -> Tuple[Optional[ir.Node], Optional[ir.Node], Optional[ir.Node]]: + """ + Find MatMul node that feeds into both Cos and Sin nodes. + + Returns: + Tuple of (matmul_node, cos_node, sin_node) + """ + for node in graph: + if node.op_type == "MatMul": + matmul_output = node.outputs[0] if node.outputs else None + if matmul_output: + cos_node, sin_node = self._find_cos_sin_consumers(graph, matmul_output) + + if cos_node and sin_node: + logging.info(f"Found target MatMul node '{node.name}' that feeds into Cos and Sin nodes") + return node, cos_node, sin_node + + return None, None, None + + def _find_cos_sin_consumers(self, graph, matmul_output: ir.Value) -> Tuple[Optional[ir.Node], Optional[ir.Node]]: + """Find Cos and Sin nodes that consume the MatMul output.""" + cos_node = None + sin_node = None + + for consumer_node in graph: + if consumer_node.op_type == "Cos": + if self._node_consumes_value(consumer_node, matmul_output): + cos_node = consumer_node + elif consumer_node.op_type == "Sin": + if self._node_consumes_value(consumer_node, matmul_output): + sin_node = consumer_node + + return cos_node, sin_node + + def _node_consumes_value(self, node: ir.Node, value: ir.Value) -> bool: + """Check if a node consumes the given value as input.""" + return any(input_val == value for input_val in node.inputs) + + def _validate_pattern_nodes(self, pattern: PatternNodes) -> bool: + """Validate that all required pattern nodes were found.""" + if not pattern.gather_value: + logging.warning("Error: Could not find attention mask gather node") + return False + + if not pattern.matmul_node: + logging.warning("Error: Could not find MatMul node that feeds into Cos and Sin nodes") + return False + + if not pattern.cos_node or not pattern.sin_node: + logging.warning("Error: Could not find both Cos and Sin nodes fed by the MatMul") + return False + + # Log found pattern + logging.info(f"Found MatMul→Cos/Sin pattern:") + logging.info(f"MatMul: {pattern.matmul_node.name}") + logging.info(f"Cos: {pattern.cos_node.name}") + logging.info(f"Sin: {pattern.sin_node.name}") + + return True + + def _generate_cache_data(self) -> CacheData: + """Generate cos/sin cache data for both large and small scenarios.""" + original_cache_length = self.rotemb_attrs["cache_length"] + + try: + # Generate large cache (for long sequences) + self.rotemb_attrs["cache_length"] = self.max_position_embeddings + if "multi_cache" in self.rotemb_attrs: + self.rotemb_attrs["rescale_factors"] = self.rotemb_attrs["multi_cache"]["long_factor"] + self.rotemb_attrs["mscale"] = self.rotemb_attrs["multi_cache"]["long_mscale"] + cos_cache_large, sin_cache_large = self.reformat_rotary_embedding_caches() + + # Generate small cache (for short sequences) + self.rotemb_attrs["cache_length"] = self.original_max_position_embeddings + if 
"multi_cache" in self.rotemb_attrs: + self.rotemb_attrs["rescale_factors"] = self.rotemb_attrs["multi_cache"]["short_factor"] + self.rotemb_attrs["mscale"] = self.rotemb_attrs["multi_cache"]["short_mscale"] + cos_cache_small, sin_cache_small = self.reformat_rotary_embedding_caches() + + # Convert to numpy arrays for ONNX + cache_data = self.CacheData( + cos_large=cos_cache_large.detach().cpu().numpy(), + sin_large=sin_cache_large.detach().cpu().numpy(), + cos_small=cos_cache_small.detach().cpu().numpy(), + sin_small=sin_cache_small.detach().cpu().numpy() + ) + + logging.info(f"Generated caches - Large: {cache_data.cos_large.shape}, Small: {cache_data.cos_small.shape}") + return cache_data + + finally: + # Restore original cache length + self.rotemb_attrs["cache_length"] = original_cache_length + + def _create_if_node_with_caches(self, cache_data: CacheData, threshold: int, gather_value: ir.Value) -> IfNodeComponents: + """Create the If node with cache branches.""" + # Create threshold comparison + threshold_const_node, greater_node = self._create_threshold_comparison(threshold, gather_value) + + # Create cache branches + then_branch = self._create_cache_branch(cache_data.cos_large, cache_data.sin_large, "large") + else_branch = self._create_cache_branch(cache_data.cos_small, cache_data.sin_small, "small") + + # Create If node outputs + if_cos_output = ir.Value( + name="cos_cache", + type=ir.TensorType(self.io_dtype), + shape=ir.Shape(["max_sequence_length", "head_dim / 2"]) + ) + + if_sin_output = ir.Value( + name="sin_cache", + type=ir.TensorType(self.io_dtype), + shape=ir.Shape(["max_sequence_length", "head_dim / 2"]) + ) + + # Create the If node + if_node = ir.node( + "If", + inputs=[greater_node.outputs[0]], + outputs=[if_cos_output, if_sin_output], + name="cos_sin_cache_if", + attributes={ + "then_branch": ir.Attr("then_branch", ir.AttributeType.GRAPH, then_branch), + "else_branch": ir.Attr("else_branch", ir.AttributeType.GRAPH, else_branch) + } + ) + + return self.IfNodeComponents( + threshold_const_node=threshold_const_node, + greater_node=greater_node, + if_node=if_node, + cos_output=if_cos_output, + sin_output=if_sin_output + ) + + def _create_threshold_comparison(self, threshold: int, gather_value: ir.Value) -> Tuple[ir.Node, ir.Node]: + """Create threshold constant and greater comparison nodes.""" + # Create threshold constant + threshold_const_name = f"threshold_const_{threshold}" + threshold_value = ir.Value( + name=threshold_const_name, + type=ir.TensorType(ir.DataType.INT64), + shape=ir.Shape([]) + ) + threshold_value.const_value = ir.tensor(threshold, dtype=ir.DataType.INT64) + + threshold_const_node = ir.node( + "Constant", + inputs=[], + outputs=[threshold_value], + name=f"Constant_{threshold}", + attributes={"value": ir.tensor(threshold, dtype=ir.DataType.INT64)} + ) + + # Create Greater node + greater_output_value = ir.Value( + name=f"greater_output_{threshold}", + type=ir.TensorType(ir.DataType.BOOL), + shape=ir.Shape([]) + ) + + greater_node = ir.node( + "Greater", + inputs=[gather_value, threshold_value], + outputs=[greater_output_value], + name=f"Greater_{threshold}" + ) + + return threshold_const_node, greater_node + + def _create_cache_branch(self, cos_cache: np.ndarray, sin_cache: np.ndarray, branch_type: str) -> ir.Graph: + """Create a cache branch for the If node.""" + # Create cache constant values and nodes + cos_cache_value = ir.Value( + name=f"cos_cache_{branch_type}", + type=ir.TensorType(self.io_dtype), + shape=ir.Shape(cos_cache.shape) + ) + 
cos_cache_node = ir.node( + "Constant", + inputs=[], + outputs=[cos_cache_value], + name=f"{branch_type}_cos_cache_Constant", + attributes={"value": ir.tensor(cos_cache, dtype=self.io_dtype)} + ) + + sin_cache_value = ir.Value( + name=f"sin_cache_{branch_type}", + type=ir.TensorType(self.io_dtype), + shape=ir.Shape(sin_cache.shape) + ) + sin_cache_node = ir.node( + "Constant", + inputs=[], + outputs=[sin_cache_value], + name=f"{branch_type}_sin_cache_Constant", + attributes={"value": ir.tensor(sin_cache, dtype=self.io_dtype)} + ) + + # Create subgraph + return ir.Graph( + inputs=[], + outputs=[cos_cache_value, sin_cache_value], + nodes=[cos_cache_node, sin_cache_node], + name=f"{branch_type}_rotemb_caches_graph", + ) + + def _replace_pattern_with_if_node(self, graph, pattern: PatternNodes, if_components: IfNodeComponents) -> None: + """Replace the pattern nodes with the If node.""" + # Find all consumers of the original Cos and Sin outputs + cos_consumers = self._find_value_consumers(graph, pattern.cos_node.outputs[0]) + sin_consumers = self._find_value_consumers(graph, pattern.sin_node.outputs[0]) + + # Replace references to original outputs with If node outputs + self._update_consumers(cos_consumers, if_components.cos_output) + self._update_consumers(sin_consumers, if_components.sin_output) + + # Update GroupQueryAttention nodes if present + self._update_group_query_attention_nodes(graph, if_components) + + # Add new nodes to the graph + graph.append(if_components.threshold_const_node) + graph.append(if_components.greater_node) + graph.append(if_components.if_node) + + def _find_value_consumers(self, graph, value: ir.Value) -> List[Tuple[ir.Node, int]]: + """Find all nodes that consume a given value.""" + consumers = [] + for node in graph: + for i, input_val in enumerate(node.inputs): + if input_val == value: + consumers.append((node, i)) + return consumers + + def _update_consumers(self, consumers: List[Tuple[ir.Node, int]], new_value: ir.Value) -> None: + """Update consumer nodes to use a new value.""" + for node, input_idx in consumers: + try: + ir.Node.replace_input_with(node, input_idx, new_value) + except Exception as e: + logging.warning(f"Warning: Could not update {node.name or 'unnamed_node'} input[{input_idx}]: {e}") + + def _update_group_query_attention_nodes(self, graph, if_components: IfNodeComponents) -> None: + """Update GroupQueryAttention nodes to use cache inputs.""" + gqa_nodes = [node for node in graph if node.op_type == "GroupQueryAttention"] + + for gqa_node in gqa_nodes: + node_name = gqa_node.name or "GroupQueryAttention_node" + try: + # Replace cos_cache at position 7 and sin_cache at position 8 + if len(gqa_node.inputs) > 7: + ir.Node.replace_input_with(gqa_node, 7, if_components.cos_output) + + if len(gqa_node.inputs) > 8: + ir.Node.replace_input_with(gqa_node, 8, if_components.sin_output) + + except Exception as e: + logging.warning(f"Warning: Could not update {node_name} inputs: {e}") + + def _remove_old_nodes(self, graph, pattern: PatternNodes) -> None: + """Remove the old MatMul, Cos, and Sin nodes.""" + nodes_to_remove = [pattern.matmul_node, pattern.cos_node, pattern.sin_node] + + try: + graph.remove(nodes_to_remove) + logging.info(f"Successfully removed MatMul→Cos/Sin sequence") + except Exception as e: + logging.warning(f"Warning: Could not remove some nodes: {e}") + # Try removing nodes one by one + for node in nodes_to_remove: + try: + graph.remove([node]) + logging.info(f"Removed {node.op_type} node: {node.name}") + except Exception as e2: + 
logging.warning(f"Could not remove {node.op_type} node {node.name}: {e2}") From 758e92d751f552a68a5d243e427f0a58b80ff1b6 Mon Sep 17 00:00:00 2001 From: Tommaso Date: Fri, 1 Aug 2025 16:00:59 +0000 Subject: [PATCH 03/26] Removed whitespace from gqa longrope fusion --- onnxscript/rewriter/ort_fusions/gqa.py | 282 +++++++++++-------------- 1 file changed, 127 insertions(+), 155 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 2b7c314b3e..7b509d4840 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -7,9 +7,8 @@ import numpy as np import onnx_ir as ir -import onnxscript.onnx_types as _onnx_types import onnxscript.rewriter._fusion_utils as _fusion_utils -from onnxscript.rewriter import _basics, _ir_utils, pattern +from onnxscript.rewriter import _ir_utils, pattern """ GroupQueryAttention: This generalizes MHA by allowing the number of heads to be different @@ -33,20 +32,7 @@ Dim = Union[int, ir.SymbolicDim] -def _is_model_input(value: ir.Value, name: str, model: ir.Model) -> bool: - return value in model.graph.inputs and value.name == name - - -def _causal_mask( - op, - input_ids, - past_kv_cache, - shape_B111, - min_val, - window_size, - dtype, -): - """Defines a pattern for a pure causal mask, with optional sliding window support.""" +def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111): seq_len = op.Shape(input_ids, end=2, start=1) seq_len_0D = op.Squeeze(seq_len) @@ -56,93 +42,28 @@ def _causal_mask( total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D) total_seq_len = op.Reshape(total_seq_len_0D, [-1]) + # The Phi modeling code generates the following +1 as the target-length, which seems + # unnecessary in this context. But using it for pattern-matching against + # generated onnx model. + total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1) + total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1]) + current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1) - mask_shape = op.Concat(seq_len, total_seq_len, axis=0) - mask_all_min_expand = op.Expand(min_val, mask_shape) - # The following Trilu is optional: not used in Phi models, but used in LLama. 
- mask_all_min_trilu = op.Trilu(mask_all_min_expand, 1, upper=1) - mask_all_min = pattern.OrValue([mask_all_min_expand, mask_all_min_trilu]) - total_range_as_row = op.Range(0, total_seq_len_0D, 1) + mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0) + min_float32 = float(np.finfo(np.float32).min) + mask_all_min = op.Expand(min_float32, mask_shape) + total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1) current_range_as_column = op.Reshape(current_range, [-1, 1]) - - non_causal = op.Greater(total_range_as_row, current_range_as_column) - - # sliding window support: - current_range_minus_window = op.Sub(current_range_as_column, window_size) - out_of_sliding_window = op.LessOrEqual(total_range_as_row, current_range_minus_window) - non_causal_sliding_window = op.Or(non_causal, out_of_sliding_window) - - boolean_mask = pattern.OrValue([non_causal, non_causal_sliding_window]) - - float_0_1_mask = op.Cast(boolean_mask, to=dtype) + boolean_mask = op.Greater(total_range_as_row, current_range_as_column) + float_0_1_mask = op.Cast(boolean_mask, to=1) float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask) - mask_4d_11ST = op.Unsqueeze(float_0_min_mask, [0, 1]) - mask_4d_B1ST = op.Expand(mask_4d_11ST, shape_B111) - - return mask_4d_B1ST - - -class _CausalMaskPattern(pattern.PatternBase): - def pattern( - self, - op, - input_ids, - past_kv_cache, - shape_B111, - min_val, - window_size, - dtype1, - attn_mask_2d, - dtype2, - ): - causal_mask = _causal_mask( - op, - input_ids, - past_kv_cache, - shape_B111, - min_val, - window_size, - dtype1, - ) - - attn_mask_4d = op.Unsqueeze(attn_mask_2d, [1, 2]) - attn_mask_4d_cast = op.Cast(attn_mask_4d, to=dtype2) - - sum = op.Add(causal_mask, attn_mask_4d_cast) - sum_fp32 = op.Cast(sum, to=ir.DataType.FLOAT) - # The cast is optional, and may be absent if the sum is already in float32. - sum_fp32 = pattern.OrValue([sum_fp32, sum]) - is_zero = op.Equal(sum_fp32, 0.0) - result = op.Where(is_zero, min_val, causal_mask) - return result - - def check(self, context, dtype1, dtype2, min_val, attn_mask_2d, sliding_window=None, **_): - # Check that attn_mask_2d is the model input "attention_mask" - if not _is_model_input(attn_mask_2d, "attention_mask", context.model): - return pattern.MatchResult().fail("Invalid attention_mask input", attn_mask_2d) - - if dtype1.as_int() != dtype2.as_int(): - return pattern.MatchResult().fail("Dtype mismatch", [dtype1, dtype2]) - - # Check that min_val is a constant and matches the expected minimum value for the dtype. - min_value = _ir_utils.get_singleton_value(min_val) - if min_value is None: - return pattern.MatchResult().fail("Minval is not a constant.", min_val) - expected_min_value = np.finfo(min_val.dtype.numpy()).min - if min_value != expected_min_value: - return pattern.MatchResult().fail( - f"Expected min value {expected_min_value}, got {min_value}", min_val - ) - - # TODO(rama) Sliding window: not yet supported. - if sliding_window: - return pattern.MatchResult().fail( - "Sliding window not yet supported", sliding_window - ) - return True - + mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1]) + mask_B1ST_plus = op.Expand(mask_4d, shape_B111) -_causal_mask_pattern = _CausalMaskPattern() + # Get rid of the extra +1 added above: total_seq_len is enough, no + # need for total_seq_len+1. 
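+    # (Sketch of the Slice semantics below: with total_seq_len = T, the mask built
+    # above has last-axis length T+1; Slice(mask, starts=[0], ends=total_seq_len,
+    # axes=[3], steps=[1]) keeps columns 0..T-1.)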
+ mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1]) + return mask_B1ST class GroupQueryAttention(pattern.RewriteRuleClassBase): @@ -157,7 +78,8 @@ def pattern( value_BSDkv, past_key, past_value, - position_ids, + position_ids_q, + position_ids_k, cos, sin, mask, @@ -179,7 +101,7 @@ def pattern( query_BHSDh_rope = op.RotaryEmbedding( query_BHSDh, - position_ids, + position_ids_q, cos, sin, _domain="com.microsoft", @@ -187,7 +109,7 @@ def pattern( ) key_BHkvSDh_rope = op.RotaryEmbedding( key_BHkvSDh, - position_ids, + position_ids_k, cos, sin, _domain="com.microsoft", @@ -232,7 +154,7 @@ def pattern( def check( self, - context: _basics.MatchContext, + op, query_BSD, key_BSDkv, value_BSDkv, @@ -242,7 +164,6 @@ def check( key_BHkvSDh_rope, query_BSHDh, key_BSHkvDh, - mask, **_, ): bindings: dict[str, Dim] = {} @@ -289,20 +210,6 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool: ) self._interleaved = query_interleaved - # Check mask: - mask_node = mask.producer() - if mask_node is None: - return pattern.MatchResult().fail("Unhandled mask pattern", mask) - mask_match_result = _causal_mask_pattern.match( - context.model, - context.graph_or_function, - mask_node, - check_nodes_are_removable=False, - ) - if mask_match_result is None: - return pattern.MatchResult().fail("Mask does not match causal mask pattern", mask) - # TODO: handle sliding window support in mask - return True def rewrite( @@ -313,37 +220,24 @@ def rewrite( value_BSDkv, past_key, past_value, - position_ids, + position_ids_q, + position_ids_k, cos, sin, mask, **_, ): - # Note that the following optimization is specific to current ORT GenAI attention-mask - # usage. Specifically, it assumes that the model-input "attention_mask" is a 2D - # mask with shape [batch_size, sequence_length], and that the mask is a 0/1 mask - # that is used only to indicate the current tokens. Hence, the input attention_mask - # is redundant as long as past-sequence-length and current-sequence-length can be - # computed. 
- - # Construct seqlens_k and total_seq_length_int32 from position_ids - # seqlens_k : int32[batch_size] indicates total_sequence-length-1 for each batch - # position_ids: int64[batch_size, sequence_length] indicates the position of each token - one_int32_0d = op.Constant(value=ir.tensor(1, dtype=ir.DataType.INT32)) - one_int64_1d = op.Constant(value=ir.tensor([1], dtype=ir.DataType.INT64)) - zero_int64_1d = op.Constant(value=ir.tensor([0], dtype=ir.DataType.INT64)) - seqlens_k_int64 = op.ReduceMax(position_ids, one_int64_1d, keepdims=0) - seqlens_k = op.Cast(seqlens_k_int64, to=ir.DataType.INT32) - max_seq_length = op.ReduceMax(seqlens_k, zero_int64_1d, keepdims=0) - total_seq_length_int32 = op.Add(max_seq_length, one_int32_0d) - return op.GroupQueryAttention( + return op.GQA( + mask, + position_ids_k, + position_ids_q, query_BSD, key_BSDkv, value_BSDkv, past_key, past_value, - seqlens_k, - total_seq_length_int32, + None, # seqlens_k, + None, # total_seq_length_int32, cos, sin, num_heads=self.num_heads, @@ -351,23 +245,101 @@ def rewrite( do_rotary=1, rotary_interleaved=self._interleaved, # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap - _domain="com.microsoft", + _domain="ai.onnxruntime._fusion", _outputs=3, ) + +class GQACausalMask(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("GQACausalMask", remove_nodes=False) + + def pattern( + self, + op, + mask, + input_ids, + some_kv_cache, + shape_B111, + past_seq_length, + total_seq_length, + ): + mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111) + position_ids = op.Range(past_seq_length, total_seq_length, 1) + position_ids_q = op.Unsqueeze(position_ids, [0]) + position_ids_k = op.Unsqueeze(position_ids, [0]) + return op.GQA( + mask, + position_ids_k, + position_ids_q, + _allow_other_inputs=True, + _domain="ai.onnxruntime._fusion", + _outputs=["attn_output", "key_seq", "value_seq"], + ) + + def rewrite( + self, + op, + total_seq_length, + attn_output, + **_, + ): + # Construct total_seq_length_int32 and seqlens_k + total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32) + one_0D = op.Constant(value_int=1) + one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32) + seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32) + zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1]) + seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D) + + gqa_node = attn_output.producer() + assert len(gqa_node.inputs) == 12, ( + f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}" + ) + query, key, value, past_key, past_value = gqa_node.inputs[3:8] + cos, sin = gqa_node.inputs[10:12] + updated_inputs = [ + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_seq_length_int32, + cos, + sin, + ] + attributes = gqa_node.attributes + return op.GroupQueryAttention( + *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3 + ) + + +_basic_gqa_rule = GroupQueryAttention.rule() +_gqa_causal_mask_rule = GQACausalMask.rule() + +gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule]) + +fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) + + class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase): def __init__(self): super().__init__("LongRoPeGQACausalMask", remove_nodes=False) self._mask_cache = {} - + def _get_mask_key(self, attention_mask): """ Generate a unique key for the mask based on input_ids and past_kv_cache. This is used to cache the mask to avoid recomputation. 
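        (Currently the key is simply id(attention_mask), so a cache hit requires
        the exact same mask Value object.)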
""" return (id(attention_mask)) - - def compute_mask(self, op, attention_mask : _onnx_types.INT64['batch', 'seq_len']): + + def compute_mask(self, op, attention_mask): + """ + Computes the total_seq_length_int32 and seqlens_k_int32 based on the attention_mask, + caching results to avoid recomputation at each layer. + """ mask_key = self._get_mask_key(attention_mask) if mask_key in self._mask_cache: @@ -377,14 +349,14 @@ def compute_mask(self, op, attention_mask : _onnx_types.INT64['batch', 'seq_len' # Construct total_seq_length_int32 and seqlens_k attention_shape = op.Shape(attention_mask, _outputs=["seq_len"]) total_seq_length = op.Gather(attention_shape, op.Constant(value=ir.tensor(1, ir.DataType.INT64)), axis=0, _outputs=["total_seq_length"]) - reduced_attention = op.ReduceSum(attention_mask, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["reduced_attention"]) + reduced_attention = op.ReduceSum(attention_mask, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["reduced_attention"]) sub_reduced_attention = op.Sub(reduced_attention, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["sub_reduced_attention"]) total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32, _outputs=["total_seq_length_int32"]) seqlens_k_int32 = op.Cast(sub_reduced_attention, to=ir.DataType.INT32, _outputs=["seqlens_k_int32"]) self._mask_cache[mask_key] = (total_seq_length_int32, seqlens_k_int32) - + return self._mask_cache[mask_key] - + def pattern( self, @@ -409,9 +381,9 @@ def pattern( current_range_B = op.Range(0, total_seq_len_0D, 1, _outputs=["current_range_B"]) total_seq_len_B = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_B"]) total_seq_len_C = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_C"]) - + total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"]) - + # EXPAND BRANCH A batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"]) mask_shape_A = op.Concat(batch_size, [1], seq_len, total_seq_len_A, axis=0, _outputs=["mask_shape_A"]) @@ -424,7 +396,7 @@ def pattern( mask_shape_B_abs = op.Abs(mask_shape_B, _outputs=["mask_shape_B_abs"]) reshaped_range_B = op.Reshape(current_range_B, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_range_B"]) mask_expanded_B = op.Expand(reshaped_range_B, mask_shape_B_abs, _outputs=["mask_expanded_B"]) - + # EXPAND BRANCH C mask_shape_C = op.Concat(batch_size, [1], seq_len, total_seq_len_C, axis=0, _outputs=["mask_shape_C"]) mask_shape_C_abs = op.Abs(mask_shape_C, _outputs=["mask_shape_C_abs"]) @@ -455,12 +427,12 @@ def pattern( # Propagation to GQA mask_sliced = op.Slice(mask_A_B_C_scaled, [0], pattern.ANY_VALUE, [3], [1], _outputs=["mask_sliced"]) - #mask_where = op.Where(mask_sliced, pattern.ANY_VALUE, pattern.ANY_VALUE, _outputs=["mask_where"]) + gqa_input = pattern.OrValue([mask_sliced, mask_A_B_C_scaled]) return op.GQA( - mask_sliced, + gqa_input, pattern.ANY_VALUE, # position_ids_k - pattern.ANY_VALUE, # position_ids_q + pattern.ANY_VALUE, # position_ids_q pattern.ANY_VALUE, # query pattern.ANY_VALUE, # key pattern.ANY_VALUE, # value @@ -509,9 +481,9 @@ def rewrite( ) _basic_gqa_rule = GroupQueryAttention.rule() +_gqa_causal_mask_rule = GQACausalMask.rule() _longrope_gqa_causal_mask_rule = LongRoPeGQACausalMask.rule() -gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]) -gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _longrope_gqa_causal_mask_rule]) +gqa_rules = 
pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule, _longrope_gqa_causal_mask_rule]) -fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) \ No newline at end of file +fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) From d4a8c57117f752c84a45ecdc4281864fc5ef38ba Mon Sep 17 00:00:00 2001 From: Tommaso Date: Fri, 1 Aug 2025 16:07:21 +0000 Subject: [PATCH 04/26] Added docstrings to GQA pattern method --- onnxscript/rewriter/ort_fusions/gqa.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 7b509d4840..4a23d78a5c 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -324,6 +324,12 @@ def rewrite( class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase): + """ + LongRoPeGQACausalMask is a specialized version of GQACausalMask that handles + the LongRoPe GQA fusion. It computes the causal mask for Group Query Attention + with LongRoPe (Long Range Rotary Position Embedding) and caches the mask to + avoid recomputation at each layer. + """ def __init__(self): super().__init__("LongRoPeGQACausalMask", remove_nodes=False) self._mask_cache = {} @@ -369,6 +375,12 @@ def pattern( past_seq_length, total_seq_length, ): + """ + Pattern for LongRoPe GQA Causal Mask. + This pattern computes the causal mask for Group Query Attention with LongRoPe. + It constructs the mask based on input_ids and past_kv_cache, and handles the + expansion of the mask across the batch and sequence dimensions. + """ seq_len = op.Shape(input_ids, end=2, start=1, _outputs=["seq_len"]) seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"]) past_seq_len = op.Shape(past_kv_cache_1, end=3, start=2, _outputs=["past_seq_len"]) @@ -455,6 +467,11 @@ def rewrite( attn_output, **_, ): + """ + Rewrite the GQA node with the new mask information. + This method computes the total sequence length and seqlens_k based on the + attention_mask and rewrites the GQA node to use these values. 
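+
+        A sketch of the values derived from a 0/1 attention_mask of shape
+        [batch, total_seq_len], mirroring compute_mask:
+            total_seq_length_int32 = Cast(Shape(attention_mask)[1], to=int32)
+            seqlens_k_int32 = Cast(ReduceSum(attention_mask, axes=[1]) - 1, to=int32)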
+ """ # Compute total_seq_length_int32 and seqlens_k_int32 total_seq_length_int32, seqlens_k_int32 = self.compute_mask(op, attention_mask) From 30faab7cff48a49024b364c47fd5d9e56882ea8b Mon Sep 17 00:00:00 2001 From: Tommaso Date: Fri, 1 Aug 2025 16:32:28 +0000 Subject: [PATCH 05/26] Renamed pattern branches to match kv_range, query_range, and batch_range computation --- onnxscript/rewriter/ort_fusions/gqa.py | 103 +++++++++++-------------- 1 file changed, 47 insertions(+), 56 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 4a23d78a5c..4bc05f17de 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -387,73 +387,64 @@ def pattern( past_seq_len_0D = op.Squeeze(past_seq_len, _outputs=["past_seq_len_0D"]) total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D, _outputs=["total_seq_len_0D"]) - # All of the Add node's outputs - current_range_A = op.Range(past_seq_len_0D, total_seq_len_0D, 1, _outputs=["current_range_A"]) - total_seq_len_A = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_A"]) - current_range_B = op.Range(0, total_seq_len_0D, 1, _outputs=["current_range_B"]) - total_seq_len_B = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_B"]) - total_seq_len_C = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_C"]) + # Create ranges for different dimensions + kv_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1, _outputs=["kv_range"]) + total_seq_len_for_kv = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_kv"]) + query_range = op.Range(0, total_seq_len_0D, 1, _outputs=["query_range"]) + total_seq_len_for_query = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_query"]) + total_seq_len_for_batch = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_batch"]) - total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"]) + #total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"]) - # EXPAND BRANCH A + # BRANCH A: KV Range - Creates tensor with KV positions [1, 1, seq_len, 1] batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"]) - mask_shape_A = op.Concat(batch_size, [1], seq_len, total_seq_len_A, axis=0, _outputs=["mask_shape_A"]) - mask_shape_A_abs = op.Abs(mask_shape_A, _outputs=["mask_shape_A_abs"]) - reshaped_range_A = op.Reshape(current_range_A, [1, 1, -1, 1], allowzero=1, _outputs=["reshaped_range_A"]) - mask_expanded_A = op.Expand(reshaped_range_A, mask_shape_A_abs, _outputs=["mask_expanded_A"]) - - # EXPAND BRANCH B - mask_shape_B = op.Concat(batch_size, [1], seq_len, total_seq_len_B, axis=0, _outputs=["mask_shape_B"]) - mask_shape_B_abs = op.Abs(mask_shape_B, _outputs=["mask_shape_B_abs"]) - reshaped_range_B = op.Reshape(current_range_B, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_range_B"]) - mask_expanded_B = op.Expand(reshaped_range_B, mask_shape_B_abs, _outputs=["mask_expanded_B"]) - - # EXPAND BRANCH C - mask_shape_C = op.Concat(batch_size, [1], seq_len, total_seq_len_C, axis=0, _outputs=["mask_shape_C"]) - mask_shape_C_abs = op.Abs(mask_shape_C, _outputs=["mask_shape_C_abs"]) + kv_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_kv, axis=0, _outputs=["kv_mask_shape"]) + kv_mask_shape_abs = op.Abs(kv_mask_shape, _outputs=["kv_mask_shape_abs"]) + 
reshaped_kv_range = op.Reshape(kv_range, [1, 1, -1, 1], allowzero=1, _outputs=["reshaped_kv_range"]) + expanded_kv_range = op.Expand(reshaped_kv_range, kv_mask_shape_abs, _outputs=["expanded_kv_range"]) + + # BRANCH B: Query Range - Creates tensor with query positions [1, 1, 1, total_seq_len] + query_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_query, axis=0, _outputs=["query_mask_shape"]) + query_mask_shape_abs = op.Abs(query_mask_shape, _outputs=["query_mask_shape_abs"]) + reshaped_query_range = op.Reshape(query_range, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_query_range"]) + expanded_query_range = op.Expand(reshaped_query_range, query_mask_shape_abs, _outputs=["expanded_query_range"]) + + # BRANCH C: Batch Range - Creates tensor with batch indices [batch_size, 1, 1, 1] + batch_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_batch, axis=0, _outputs=["batch_mask_shape"]) + batch_mask_shape_abs = op.Abs(batch_mask_shape, _outputs=["batch_mask_shape_abs"]) batch_size_squeezed = op.Squeeze(batch_size, _outputs=["batch_size_squeezed"]) batch_range = op.Range(0, batch_size_squeezed, 1, _outputs=["batch_range"]) - reshaped_range_C = op.Reshape(batch_range, [-1, 1, 1, 1], allowzero=1, _outputs=["reshaped_range_C"]) - mask_expanded_C = op.Expand(reshaped_range_C, mask_shape_C_abs, _outputs=["mask_expanded_C"]) - - # EXPAND A/B TO AND - mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"]) - mask_A_B_greater = op.Greater(mask_expanded_B, mask_expanded_A_sub, _outputs=["mask_A_B_greater"]) - mask_A_B_greater_bitwise = op.And(True, mask_A_B_greater, _outputs=["mask_A_B_greater_bitwise"]) - mask_A_B_less = op.LessOrEqual(mask_expanded_B, mask_expanded_A, _outputs=["mask_A_B_less"]) - mask_A_B_combined = op.And(mask_A_B_greater_bitwise, mask_A_B_less, _outputs=["mask_A_B_combined"]) - mask_A_B_combined_bitwise = op.And(True, mask_A_B_combined, _outputs=["mask_A_B_combined_bitwise"]) - - # EXPAND B/C TO AND - unsqueezed_mask_expanded_B = op.Unsqueeze(mask_expanded_B, [-1], _outputs=["unsqueezed_mask_expanded_B"]) - unsqueezed_mask_expanded_C = op.Unsqueeze(mask_expanded_C, [-1], _outputs=["unsqueezed_mask_expanded_C"]) - mask_B_C_concat = op.Concat(unsqueezed_mask_expanded_C, unsqueezed_mask_expanded_B, axis=-1, _outputs=["mask_B_C_concat"]) + reshaped_batch_range = op.Reshape(batch_range, [-1, 1, 1, 1], allowzero=1, _outputs=["reshaped_batch_range"]) + expanded_batch_range = op.Expand(reshaped_batch_range, batch_mask_shape_abs, _outputs=["expanded_batch_range"]) + + # Combine KV/Query Ranges for Sliding Window Mask + kv_range_offset = op.Sub(expanded_kv_range, 262144, _outputs=["kv_range_offset"]) + query_gt_kv_offset = op.Greater(expanded_query_range, kv_range_offset, _outputs=["query_gt_kv_offset"]) + query_gt_kv_offset_mask = op.And(True, query_gt_kv_offset, _outputs=["query_gt_kv_offset_mask"]) + query_le_kv = op.LessOrEqual(expanded_query_range, expanded_kv_range, _outputs=["query_le_kv"]) + sliding_window_mask = op.And(query_gt_kv_offset_mask, query_le_kv, _outputs=["sliding_window_mask"]) + sliding_window_mask_final = op.And(True, sliding_window_mask, _outputs=["sliding_window_mask_final"]) + + # Combine Query/Batch Ranges for Attention Mask Lookup + unsqueezed_query_range = op.Unsqueeze(expanded_query_range, [-1], _outputs=["unsqueezed_query_range"]) + unsqueezed_batch_range = op.Unsqueeze(expanded_batch_range, [-1], _outputs=["unsqueezed_batch_range"]) + batch_query_indices = op.Concat(unsqueezed_batch_range, 
unsqueezed_query_range, axis=-1, _outputs=["batch_query_indices"])
         attention_mask_bool = op.Cast(attention_mask, to=ir.DataType.BOOL, _outputs=["attention_mask_bool"])
-        mask_gatherND = op.GatherND(attention_mask_bool, mask_B_C_concat, batch_dims=0, _outputs=["mask_gatherND"])
-
-        mask_A_B_C_combined = op.And(mask_A_B_combined_bitwise, mask_gatherND, _outputs=["mask_A_B_C_combined"])
-        mask_A_B_C_negated = op.Not(mask_A_B_C_combined, _outputs=["mask_A_B_C_negated"])
-        mask_A_B_C_fp32 = op.Cast(mask_A_B_C_negated, to=ir.DataType.FLOAT, _outputs=["mask_A_B_C_fp32"])
-        mask_A_B_C_scaled = op.Mul(mask_A_B_C_fp32, pattern.ANY_VALUE)
+        attention_lookup = op.GatherND(attention_mask_bool, batch_query_indices, batch_dims=0, _outputs=["attention_lookup"])
+
+        # Final Mask Combination
+        final_attention_mask = op.And(sliding_window_mask_final, attention_lookup, _outputs=["final_attention_mask"])
+        inverted_mask = op.Not(final_attention_mask, _outputs=["inverted_mask"])
+        mask_fp32 = op.Cast(inverted_mask, to=ir.DataType.FLOAT, _outputs=["mask_fp32"])
+        scaled_mask = op.Mul(mask_fp32, pattern.ANY_VALUE)
+
         # Propagation to GQA
-        mask_sliced = op.Slice(mask_A_B_C_scaled, [0], pattern.ANY_VALUE, [3], [1], _outputs=["mask_sliced"])
+        sliced_mask = op.Slice(scaled_mask, [0], pattern.ANY_VALUE, [3], [1], _outputs=["sliced_mask"])

-        gqa_input = pattern.OrValue([mask_sliced, mask_A_B_C_scaled])
+        gqa_input = pattern.OrValue([sliced_mask, scaled_mask])

         return op.GQA(
             gqa_input,
-            pattern.ANY_VALUE,  # position_ids_k
-            pattern.ANY_VALUE,  # position_ids_q
-            pattern.ANY_VALUE,  # query
-            pattern.ANY_VALUE,  # key
-            pattern.ANY_VALUE,  # value
-            pattern.ANY_VALUE,  # past_key
-            pattern.ANY_VALUE,  # past_value
-            pattern.ANY_VALUE,  # seqlens_k (optional)
-            pattern.ANY_VALUE,  # total_seq_length (optional)
-            pattern.ANY_VALUE,  # cos
-            pattern.ANY_VALUE,  # sin
             _allow_other_inputs=True,
             _domain="ai.onnxruntime._fusion",
             _outputs=["attn_output", "key_seq", "value_seq"],

From 912a80b308bdeb64eee66f4ba8bd630b0c0bc1ab Mon Sep 17 00:00:00 2001
From: Tommaso
Date: Fri, 1 Aug 2025 16:33:04 +0000
Subject: [PATCH 06/26] Removed unnecessary pattern variable

---
 onnxscript/rewriter/ort_fusions/gqa.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py
index 4bc05f17de..04edb3c74e 100644
--- a/onnxscript/rewriter/ort_fusions/gqa.py
+++ b/onnxscript/rewriter/ort_fusions/gqa.py
@@ -394,8 +394,6 @@ def pattern(
         total_seq_len_for_query = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_query"])
         total_seq_len_for_batch = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_batch"])

-        #total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"])
-
         # BRANCH A: KV Range - Creates tensor with KV positions [1, 1, seq_len, 1]
         batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"])
         kv_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_kv, axis=0, _outputs=["kv_mask_shape"])
@@ -437,7 +435,7 @@ def pattern(
         inverted_mask = op.Not(final_attention_mask, _outputs=["inverted_mask"])
         mask_fp32 = op.Cast(inverted_mask, to=ir.DataType.FLOAT, _outputs=["mask_fp32"])
         scaled_mask = op.Mul(mask_fp32, pattern.ANY_VALUE)
-
+
         # Propagation to GQA
         sliced_mask = op.Slice(scaled_mask, [0], pattern.ANY_VALUE, [3], [1], _outputs=["sliced_mask"])

From fd957192b56cd68172d797f3081a031b24bf95fb Mon Sep 17 00:00:00 2001
From: Tommaso
Date: Fri, 1 Aug 2025 16:37:07 +0000 Subject: [PATCH 07/26] Added snake casing for variable names --- onnxscript/rewriter/ort_fusions/gqa.py | 32 +++++++++++++------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 04edb3c74e..2c704188ad 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -34,25 +34,25 @@ def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111): seq_len = op.Shape(input_ids, end=2, start=1) - seq_len_0D = op.Squeeze(seq_len) + seq_len_0d = op.Squeeze(seq_len) past_seq_len = op.Shape(past_kv_cache, end=3, start=2) - past_seq_len_0D = op.Squeeze(past_seq_len) + past_seq_len_0d = op.Squeeze(past_seq_len) - total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D) - total_seq_len = op.Reshape(total_seq_len_0D, [-1]) + total_seq_len_0d = op.Add(past_seq_len_0d, seq_len_0d) + total_seq_len = op.Reshape(total_seq_len_0d, [-1]) # The Phi modeling code generates the following +1 as the target-length, which seems # unnecessary in this context. But using it for pattern-matching against # generated onnx model. - total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1) - total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1]) + total_seq_len_plus_1_0d = op.Add(total_seq_len_0d, 1) + total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0d, [-1]) - current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1) + current_range = op.Range(past_seq_len_0d, total_seq_len_0d, 1) mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0) min_float32 = float(np.finfo(np.float32).min) mask_all_min = op.Expand(min_float32, mask_shape) - total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1) + total_range_as_row = op.Range(0, total_seq_len_plus_1_0d, 1) current_range_as_column = op.Reshape(current_range, [-1, 1]) boolean_mask = op.Greater(total_range_as_row, current_range_as_column) float_0_1_mask = op.Cast(boolean_mask, to=1) @@ -382,17 +382,17 @@ def pattern( expansion of the mask across the batch and sequence dimensions. 
""" seq_len = op.Shape(input_ids, end=2, start=1, _outputs=["seq_len"]) - seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"]) + seq_len_0d = op.Squeeze(seq_len, _outputs=["seq_len_0d"]) past_seq_len = op.Shape(past_kv_cache_1, end=3, start=2, _outputs=["past_seq_len"]) - past_seq_len_0D = op.Squeeze(past_seq_len, _outputs=["past_seq_len_0D"]) - total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D, _outputs=["total_seq_len_0D"]) + past_seq_len_0d = op.Squeeze(past_seq_len, _outputs=["past_seq_len_0d"]) + total_seq_len_0d = op.Add(past_seq_len_0d, seq_len_0d, _outputs=["total_seq_len_0d"]) # Create ranges for different dimensions - kv_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1, _outputs=["kv_range"]) - total_seq_len_for_kv = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_kv"]) - query_range = op.Range(0, total_seq_len_0D, 1, _outputs=["query_range"]) - total_seq_len_for_query = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_query"]) - total_seq_len_for_batch = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_batch"]) + kv_range = op.Range(past_seq_len_0d, total_seq_len_0d, 1, _outputs=["kv_range"]) + total_seq_len_for_kv = op.Reshape(total_seq_len_0d, [-1], allowzero=0, _outputs=["total_seq_len_for_kv"]) + query_range = op.Range(0, total_seq_len_0d, 1, _outputs=["query_range"]) + total_seq_len_for_query = op.Reshape(total_seq_len_0d, [-1], allowzero=0, _outputs=["total_seq_len_for_query"]) + total_seq_len_for_batch = op.Reshape(total_seq_len_0d, [-1], allowzero=0, _outputs=["total_seq_len_for_batch"]) # BRANCH A: KV Range - Creates tensor with KV positions [1, 1, seq_len, 1] batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"]) From 19d26568dad50b178f6f7d181a861b91ed25cc3e Mon Sep 17 00:00:00 2001 From: Tommaso Date: Fri, 1 Aug 2025 16:37:36 +0000 Subject: [PATCH 08/26] Added more snake casing and removed uneeded code --- onnxscript/rewriter/ort_fusions/gqa.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 2c704188ad..6a4069984c 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -286,11 +286,11 @@ def rewrite( ): # Construct total_seq_length_int32 and seqlens_k total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32) - one_0D = op.Constant(value_int=1) - one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32) - seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32) - zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1]) - seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D) + one_0d = op.Constant(value_int=1) + one_0d_int32 = op.Cast(one_0d, to=ir.DataType.INT32) + seqlens_k_0d = op.Sub(total_seq_length_int32, one_0d_int32) + zero_1d = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1]) + seqlens_k = op.Unsqueeze(seqlens_k_0d, zero_1d) gqa_node = attn_output.producer() assert len(gqa_node.inputs) == 12, ( @@ -314,15 +314,6 @@ def rewrite( *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3 ) - -_basic_gqa_rule = GroupQueryAttention.rule() -_gqa_causal_mask_rule = GQACausalMask.rule() - -gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule]) - -fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) - - class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase): """ LongRoPeGQACausalMask is a specialized version of 
GQACausalMask that handles From 0742db228ceca66afcd383c7e3943e8f3fb39306 Mon Sep 17 00:00:00 2001 From: Tommaso Date: Fri, 1 Aug 2025 17:06:06 +0000 Subject: [PATCH 09/26] Moved get_mask_key method to module level and used IR value directly --- onnxscript/rewriter/ort_fusions/gqa.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 6a4069984c..90db62d24b 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -31,7 +31,6 @@ Dim = Union[int, ir.SymbolicDim] - def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111): seq_len = op.Shape(input_ids, end=2, start=1) seq_len_0d = op.Squeeze(seq_len) @@ -314,6 +313,13 @@ def rewrite( *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3 ) +def _get_mask_key(attention_mask): + """ + Generate a unique key for the mask based on input_ids and past_kv_cache. + This is used to cache the mask to avoid recomputation. + """ + return attention_mask + class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase): """ LongRoPeGQACausalMask is a specialized version of GQACausalMask that handles @@ -325,19 +331,12 @@ def __init__(self): super().__init__("LongRoPeGQACausalMask", remove_nodes=False) self._mask_cache = {} - def _get_mask_key(self, attention_mask): - """ - Generate a unique key for the mask based on input_ids and past_kv_cache. - This is used to cache the mask to avoid recomputation. - """ - return (id(attention_mask)) - def compute_mask(self, op, attention_mask): """ Computes the total_seq_length_int32 and seqlens_k_int32 based on the attention_mask, caching results to avoid recomputation at each layer. """ - mask_key = self._get_mask_key(attention_mask) + mask_key = _get_mask_key(attention_mask) if mask_key in self._mask_cache: total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key] From 2772f77aa17285e7ad407887dc386c4fa6bc0941 Mon Sep 17 00:00:00 2001 From: Tommaso Date: Fri, 1 Aug 2025 17:09:32 +0000 Subject: [PATCH 10/26] Added cleanup method for the attention mask cache --- onnxscript/rewriter/ort_fusions/gqa.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 90db62d24b..c25ee56799 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -331,6 +331,9 @@ def __init__(self): super().__init__("LongRoPeGQACausalMask", remove_nodes=False) self._mask_cache = {} + def cleanup(self): + self._mask_cache.clear() + def compute_mask(self, op, attention_mask): """ Computes the total_seq_length_int32 and seqlens_k_int32 based on the attention_mask, From 87a0464405dce730240909feeeff1cdb1cbd34b1 Mon Sep 17 00:00:00 2001 From: Tommaso Date: Fri, 1 Aug 2025 17:25:53 +0000 Subject: [PATCH 11/26] Added LongRoPE GQA Causal Mask Fusion Separately --- .../rewriter/ort_fusions/longrope_gqa.py | 485 ++++++++++++++++++ 1 file changed, 485 insertions(+) create mode 100644 onnxscript/rewriter/ort_fusions/longrope_gqa.py diff --git a/onnxscript/rewriter/ort_fusions/longrope_gqa.py b/onnxscript/rewriter/ort_fusions/longrope_gqa.py new file mode 100644 index 0000000000..effc023291 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/longrope_gqa.py @@ -0,0 +1,485 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
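+#
+# This module mirrors the GroupQueryAttention fusion rules in gqa.py and adds
+# the LongRoPE causal-mask rule, so that the LongRoPE-specific pattern can be
+# maintained separately from the base GQA fusion.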
+from __future__ import annotations + +from typing import Sequence, Union + +import numpy as np +import onnx_ir as ir + +import onnxscript.rewriter._fusion_utils as _fusion_utils +from onnxscript.rewriter import _ir_utils, pattern + +""" +GroupQueryAttention: This generalizes MHA by allowing the number of heads to be different +for query and key/value. + +We use the following abbreviations for the dimensions: +B: Batch size +S: Sequence length (for current query/key/value) + +Hkv: number of heads for key/value +G = number of groups +H: number of heads = G * Hkv + +Dh: head size or embedding dimension per head +D: input embedding dimension (hidden size) = H * Dh +Dkv: key/value hidden size = Hkv * Dh + +T: total sequence length (after concatenation of past and current key/value) +""" + +Dim = Union[int, ir.SymbolicDim] + +def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111): + seq_len = op.Shape(input_ids, end=2, start=1) + seq_len_0d = op.Squeeze(seq_len) + + past_seq_len = op.Shape(past_kv_cache, end=3, start=2) + past_seq_len_0d = op.Squeeze(past_seq_len) + + total_seq_len_0d = op.Add(past_seq_len_0d, seq_len_0d) + total_seq_len = op.Reshape(total_seq_len_0d, [-1]) + + # The Phi modeling code generates the following +1 as the target-length, which seems + # unnecessary in this context. But using it for pattern-matching against + # generated onnx model. + total_seq_len_plus_1_0d = op.Add(total_seq_len_0d, 1) + total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0d, [-1]) + + current_range = op.Range(past_seq_len_0d, total_seq_len_0d, 1) + mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0) + min_float32 = float(np.finfo(np.float32).min) + mask_all_min = op.Expand(min_float32, mask_shape) + total_range_as_row = op.Range(0, total_seq_len_plus_1_0d, 1) + current_range_as_column = op.Reshape(current_range, [-1, 1]) + boolean_mask = op.Greater(total_range_as_row, current_range_as_column) + float_0_1_mask = op.Cast(boolean_mask, to=1) + float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask) + mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1]) + mask_B1ST_plus = op.Expand(mask_4d, shape_B111) + + # Get rid of the extra +1 added above: total_seq_len is enough, no + # need for total_seq_len+1. 
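+    # The Slice keeps columns [0, total_seq_len) along axis 3, dropping the
+    # trailing column introduced by the +1 above.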
+ mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1]) + return mask_B1ST + + +class GroupQueryAttention(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("GQA", remove_nodes=False) + + def pattern( + self, + op, + query_BSD, + key_BSDkv, + value_BSDkv, + past_key, + past_value, + position_ids_q, + position_ids_k, + cos, + sin, + mask, + ): + # Reshape query from (B, S, D) to (B, S, H, D/H) + query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"]) + # Transpose from (B, S, H, D/H) to (B, H, S, D/H) + query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3]) + + # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H) + key_BSHkvDh = op.Reshape(key_BSDkv, pattern.ANY_VALUE, _outputs=["key_BSHkvDh"]) + # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) + key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3]) + + # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H) + value_BSHkvDh = op.Reshape(value_BSDkv, pattern.ANY_VALUE, _outputs=["value_BSHkvDh"]) + # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H) + value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3]) + + query_BHSDh_rope = op.RotaryEmbedding( + query_BHSDh, + position_ids_q, + cos, + sin, + _domain="com.microsoft", + _outputs=["query_BHSDh_rope"], + ) + key_BHkvSDh_rope = op.RotaryEmbedding( + key_BHkvSDh, + position_ids_k, + cos, + sin, + _domain="com.microsoft", + _outputs=["key_BHkvSDh_rope"], + ) + + # Concatenate past_key cache and current key, expand across heads + # that share key/value. + + key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2) + key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2) + key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE) + key_seq_BHTDh = op.Reshape( + key_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["key_seq_BHTDh"] + ) + + # Concatenate past_value cache and current value, expand across heads + # that share key/value. 
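+        # (For example, with H = 32 query heads and Hkv = 8 key/value heads,
+        # assumed values, each key/value head is shared by G = H / Hkv = 4
+        # query heads via the Unsqueeze/Expand/Reshape below.)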
+ value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2) + value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2) + value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE) + value_seq_BHTDh = op.Reshape( + value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"] + ) + + attention_BHSDh = op.SDPA( + query_BHSDh_rope, + key_seq_BHTDh, + value_seq_BHTDh, + mask, + key_format="BHSd", + _domain="ai.onnxruntime._fusion", + ) + + # Transpose attention back to (B, S, H, D/H) + attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3]) + # Reshape back to (B, S, D) + attention_BSD = op.Reshape( + attention_BSHDh, pattern.ANY_VALUE, _outputs=["attention_BSD"] + ) + return attention_BSD, key_seq_BHkvTDh, value_seq_BHkvTDh + + def check( + self, + op, + query_BSD, + key_BSDkv, + value_BSDkv, + past_key, + past_value, + query_BHSDh_rope, + key_BHkvSDh_rope, + query_BSHDh, + key_BSHkvDh, + **_, + ): + bindings: dict[str, Dim] = {} + + def no_match(val: ir.Value, dims: Sequence[str]) -> bool: + return not _fusion_utils._check_shape(bindings, val, dims) + + if no_match(query_BSD, ["B", "S", "D"]): + return False + if no_match(key_BSDkv, ["B", "S", "Dkv"]): + return False + if no_match(value_BSDkv, ["B", "S", "Dkv"]): + return False + + if no_match(past_key, ["B", "Hkv", "P", "Dh"]): + return False + if no_match(past_value, ["B", "Hkv", "P", "Dv"]): + return False + + # TODO: verify Reshapes: + # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]: + # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]: + # or check Reshape's shape-input value + + result = pattern.MatchResult() + num_heads = _ir_utils.get_dim(query_BSHDh, 2) + kv_num_heads = _ir_utils.get_dim(key_BSHkvDh, 2) + if not isinstance(num_heads, int): + return result.fail("Unable to determine num_heads value", query_BSHDh) + if not isinstance(kv_num_heads, int): + return result.fail("Unable to determine kv_num_heads value", key_BSHkvDh) + self.num_heads = num_heads + self.kv_num_heads = kv_num_heads + + # Rotary embedding attributes + query_rotary_attributes = query_BHSDh_rope.producer().attributes + key_rotary_attributes = key_BHkvSDh_rope.producer().attributes + query_interleaved = query_rotary_attributes.get_int("interleaved", 0) + key_interleaved = key_rotary_attributes.get_int("interleaved", 0) + if query_interleaved != key_interleaved: + return pattern.MatchResult().fail( + "Rotary embedding interleaved attribute mismatch", + [query_BHSDh_rope.producer(), key_BHkvSDh_rope.producer()], + ) + self._interleaved = query_interleaved + + return True + + def rewrite( + self, + op, + query_BSD, + key_BSDkv, + value_BSDkv, + past_key, + past_value, + position_ids_q, + position_ids_k, + cos, + sin, + mask, + **_, + ): + return op.GQA( + mask, + position_ids_k, + position_ids_q, + query_BSD, + key_BSDkv, + value_BSDkv, + past_key, + past_value, + None, # seqlens_k, + None, # total_seq_length_int32, + cos, + sin, + num_heads=self.num_heads, + kv_num_heads=self.kv_num_heads, + do_rotary=1, + rotary_interleaved=self._interleaved, + # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap + _domain="ai.onnxruntime._fusion", + _outputs=3, + ) + + +class GQACausalMask(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("GQACausalMask", remove_nodes=False) + + def pattern( + self, + op, + mask, + input_ids, + some_kv_cache, + shape_B111, + past_seq_length, + total_seq_length, + ): + mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111) + 
position_ids = op.Range(past_seq_length, total_seq_length, 1)
+        position_ids_q = op.Unsqueeze(position_ids, [0])
+        position_ids_k = op.Unsqueeze(position_ids, [0])
+        return op.GQA(
+            mask,
+            position_ids_k,
+            position_ids_q,
+            _allow_other_inputs=True,
+            _domain="ai.onnxruntime._fusion",
+            _outputs=["attn_output", "key_seq", "value_seq"],
+        )
+
+    def rewrite(
+        self,
+        op,
+        total_seq_length,
+        attn_output,
+        **_,
+    ):
+        # Construct total_seq_length_int32 and seqlens_k
+        total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
+        one_0d = op.Constant(value_int=1)
+        one_0d_int32 = op.Cast(one_0d, to=ir.DataType.INT32)
+        seqlens_k_0d = op.Sub(total_seq_length_int32, one_0d_int32)
+        zero_1d = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
+        seqlens_k = op.Unsqueeze(seqlens_k_0d, zero_1d)
+
+        gqa_node = attn_output.producer()
+        assert len(gqa_node.inputs) == 12, (
+            f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}"
+        )
+        query, key, value, past_key, past_value = gqa_node.inputs[3:8]
+        cos, sin = gqa_node.inputs[10:12]
+        updated_inputs = [
+            query,
+            key,
+            value,
+            past_key,
+            past_value,
+            seqlens_k,
+            total_seq_length_int32,
+            cos,
+            sin,
+        ]
+        attributes = gqa_node.attributes
+        return op.GroupQueryAttention(
+            *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
+        )
+
+def _get_mask_key(attention_mask):
+    """
+    Generate a unique key for the mask based on the attention_mask value.
+    This is used to cache the mask to avoid recomputation.
+    """
+    return attention_mask
+
+class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
+    """
+    LongRoPeGQACausalMask is a specialized version of GQACausalMask that handles
+    the LongRoPe GQA fusion. It computes the causal mask for Group Query Attention
+    with LongRoPe (Long Range Rotary Position Embedding) and caches the mask to
+    avoid recomputation at each layer.
+    """
+    def __init__(self):
+        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
+        self._mask_cache = {}
+
+    def cleanup(self):
+        self._mask_cache.clear()
+
+    def compute_mask(self, op, attention_mask):
+        """
+        Computes the total_seq_length_int32 and seqlens_k_int32 based on the attention_mask,
+        caching results to avoid recomputation at each layer.
+        """
+        mask_key = _get_mask_key(attention_mask)
+
+        if mask_key in self._mask_cache:
+            total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key]
+
+        else:
+            # Construct total_seq_length_int32 and seqlens_k
+            attention_shape = op.Shape(attention_mask, _outputs=["seq_len"])
+            total_seq_length = op.Gather(attention_shape, op.Constant(value=ir.tensor(1, ir.DataType.INT64)), axis=0, _outputs=["total_seq_length"])
+            reduced_attention = op.ReduceSum(attention_mask, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["reduced_attention"])
+            sub_reduced_attention = op.Sub(reduced_attention, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["sub_reduced_attention"])
+            total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32, _outputs=["total_seq_length_int32"])
+            seqlens_k_int32 = op.Cast(sub_reduced_attention, to=ir.DataType.INT32, _outputs=["seqlens_k_int32"])
+            self._mask_cache[mask_key] = (total_seq_length_int32, seqlens_k_int32)
+
+        return self._mask_cache[mask_key]
+
+
+    def pattern(
+        self,
+        op,
+        input_ids,
+        past_kv_cache_1,
+        past_kv_cache_2,
+        attention_mask,
+    ):
+        """
+        Pattern for LongRoPe GQA Causal Mask.
+        This pattern computes the causal mask for Group Query Attention with LongRoPe.
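+        The matched subgraph also encodes a sliding-window bound: key/value
+        positions more than 262144 steps behind a query position are masked
+        out (262144 appears to be the window size baked into the exported
+        long-context model).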
+ It constructs the mask based on input_ids and past_kv_cache, and handles the + expansion of the mask across the batch and sequence dimensions. + """ + seq_len = op.Shape(input_ids, end=2, start=1, _outputs=["seq_len"]) + seq_len_0d = op.Squeeze(seq_len, _outputs=["seq_len_0d"]) + past_seq_len = op.Shape(past_kv_cache_1, end=3, start=2, _outputs=["past_seq_len"]) + past_seq_len_0d = op.Squeeze(past_seq_len, _outputs=["past_seq_len_0d"]) + total_seq_len_0d = op.Add(past_seq_len_0d, seq_len_0d, _outputs=["total_seq_len_0d"]) + + # Create ranges for different dimensions + kv_range = op.Range(past_seq_len_0d, total_seq_len_0d, 1, _outputs=["kv_range"]) + total_seq_len_for_kv = op.Reshape(total_seq_len_0d, [-1], allowzero=0, _outputs=["total_seq_len_for_kv"]) + query_range = op.Range(0, total_seq_len_0d, 1, _outputs=["query_range"]) + total_seq_len_for_query = op.Reshape(total_seq_len_0d, [-1], allowzero=0, _outputs=["total_seq_len_for_query"]) + total_seq_len_for_batch = op.Reshape(total_seq_len_0d, [-1], allowzero=0, _outputs=["total_seq_len_for_batch"]) + + # BRANCH A: KV Range - Creates tensor with KV positions [1, 1, seq_len, 1] + batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"]) + kv_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_kv, axis=0, _outputs=["kv_mask_shape"]) + kv_mask_shape_abs = op.Abs(kv_mask_shape, _outputs=["kv_mask_shape_abs"]) + reshaped_kv_range = op.Reshape(kv_range, [1, 1, -1, 1], allowzero=1, _outputs=["reshaped_kv_range"]) + expanded_kv_range = op.Expand(reshaped_kv_range, kv_mask_shape_abs, _outputs=["expanded_kv_range"]) + + # BRANCH B: Query Range - Creates tensor with query positions [1, 1, 1, total_seq_len] + query_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_query, axis=0, _outputs=["query_mask_shape"]) + query_mask_shape_abs = op.Abs(query_mask_shape, _outputs=["query_mask_shape_abs"]) + reshaped_query_range = op.Reshape(query_range, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_query_range"]) + expanded_query_range = op.Expand(reshaped_query_range, query_mask_shape_abs, _outputs=["expanded_query_range"]) + + # BRANCH C: Batch Range - Creates tensor with batch indices [batch_size, 1, 1, 1] + batch_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_batch, axis=0, _outputs=["batch_mask_shape"]) + batch_mask_shape_abs = op.Abs(batch_mask_shape, _outputs=["batch_mask_shape_abs"]) + batch_size_squeezed = op.Squeeze(batch_size, _outputs=["batch_size_squeezed"]) + batch_range = op.Range(0, batch_size_squeezed, 1, _outputs=["batch_range"]) + reshaped_batch_range = op.Reshape(batch_range, [-1, 1, 1, 1], allowzero=1, _outputs=["reshaped_batch_range"]) + expanded_batch_range = op.Expand(reshaped_batch_range, batch_mask_shape_abs, _outputs=["expanded_batch_range"]) + + # Combine KV/Query Ranges for Sliding Window Mask + kv_range_offset = op.Sub(expanded_kv_range, 262144, _outputs=["kv_range_offset"]) + query_gt_kv_offset = op.Greater(expanded_query_range, kv_range_offset, _outputs=["query_gt_kv_offset"]) + query_gt_kv_offset_mask = op.And(True, query_gt_kv_offset, _outputs=["query_gt_kv_offset_mask"]) + query_le_kv = op.LessOrEqual(expanded_query_range, expanded_kv_range, _outputs=["query_le_kv"]) + sliding_window_mask = op.And(query_gt_kv_offset_mask, query_le_kv, _outputs=["sliding_window_mask"]) + sliding_window_mask_final = op.And(True, sliding_window_mask, _outputs=["sliding_window_mask_final"]) + + # Combine Query/Batch Ranges for Attention Mask Lookup + unsqueezed_query_range 
= op.Unsqueeze(expanded_query_range, [-1], _outputs=["unsqueezed_query_range"]) + unsqueezed_batch_range = op.Unsqueeze(expanded_batch_range, [-1], _outputs=["unsqueezed_batch_range"]) + batch_query_indices = op.Concat(unsqueezed_batch_range, unsqueezed_query_range, axis=-1, _outputs=["batch_query_indices"]) + attention_mask_bool = op.Cast(attention_mask, to=ir.DataType.BOOL, _outputs=["attention_mask_bool"]) + attention_lookup = op.GatherND(attention_mask_bool, batch_query_indices, batch_dims=0, _outputs=["attention_lookup"]) + + # Final Mask Combination + final_attention_mask = op.And(sliding_window_mask_final, attention_lookup, _outputs=["final_attention_mask"]) + inverted_mask = op.Not(final_attention_mask, _outputs=["inverted_mask"]) + mask_fp32 = op.Cast(inverted_mask, to=ir.DataType.FLOAT, _outputs=["mask_fp32"]) + scaled_mask = op.Mul(mask_fp32, pattern.ANY_VALUE) + + # Propagation to GQA + sliced_mask = op.Slice(scaled_mask, [0], pattern.ANY_VALUE, [3], [1], _outputs=["sliced_mask"]) + + gqa_input = pattern.OrValue([sliced_mask, scaled_mask]) + + return op.GQA( + gqa_input, + _allow_other_inputs=True, + _domain="ai.onnxruntime._fusion", + _outputs=["attn_output", "key_seq", "value_seq"], + ) + + + def rewrite( + self, + op, + attention_mask, + attn_output, + **_, + ): + """ + Rewrite the GQA node with the new mask information. + This method computes the total sequence length and seqlens_k based on the + attention_mask and rewrites the GQA node to use these values. + """ + # Compute total_seq_length_int32 and seqlens_k_int32 + total_seq_length_int32, seqlens_k_int32 = self.compute_mask(op, attention_mask) + + gqa_node = attn_output.producer() + assert len(gqa_node.inputs) == 12, ( + f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}" + ) + query, key, value, past_key, past_value = gqa_node.inputs[3:8] + cos, sin = gqa_node.inputs[10:12] + updated_inputs = [ + query, + key, + value, + past_key, + past_value, + seqlens_k_int32, + total_seq_length_int32, + cos, + sin, + ] + attributes = gqa_node.attributes + return op.GroupQueryAttention( + *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3 + ) + +_basic_gqa_rule = GroupQueryAttention.rule() +_gqa_causal_mask_rule = GQACausalMask.rule() +_longrope_gqa_causal_mask_rule = LongRoPeGQACausalMask.rule() + +gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule, _longrope_gqa_causal_mask_rule]) + +fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) From f12630cbbff6afcd57ee785355538a9194c565e8 Mon Sep 17 00:00:00 2001 From: Tommaso Date: Fri, 1 Aug 2025 16:00:59 +0000 Subject: [PATCH 12/26] Removed whitespace from gqa longrope fusion --- onnxscript/rewriter/ort_fusions/gqa.py | 282 +++++++++++-------------- 1 file changed, 127 insertions(+), 155 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 2b7c314b3e..7b509d4840 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -7,9 +7,8 @@ import numpy as np import onnx_ir as ir -import onnxscript.onnx_types as _onnx_types import onnxscript.rewriter._fusion_utils as _fusion_utils -from onnxscript.rewriter import _basics, _ir_utils, pattern +from onnxscript.rewriter import _ir_utils, pattern """ GroupQueryAttention: This generalizes MHA by allowing the number of heads to be different @@ -33,20 +32,7 @@ Dim = Union[int, ir.SymbolicDim] -def _is_model_input(value: ir.Value, name: str, model: ir.Model) -> bool: - return value in 
model.graph.inputs and value.name == name - - -def _causal_mask( - op, - input_ids, - past_kv_cache, - shape_B111, - min_val, - window_size, - dtype, -): - """Defines a pattern for a pure causal mask, with optional sliding window support.""" +def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111): seq_len = op.Shape(input_ids, end=2, start=1) seq_len_0D = op.Squeeze(seq_len) @@ -56,93 +42,28 @@ def _causal_mask( total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D) total_seq_len = op.Reshape(total_seq_len_0D, [-1]) + # The Phi modeling code generates the following +1 as the target-length, which seems + # unnecessary in this context. But using it for pattern-matching against + # generated onnx model. + total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1) + total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1]) + current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1) - mask_shape = op.Concat(seq_len, total_seq_len, axis=0) - mask_all_min_expand = op.Expand(min_val, mask_shape) - # The following Trilu is optional: not used in Phi models, but used in LLama. - mask_all_min_trilu = op.Trilu(mask_all_min_expand, 1, upper=1) - mask_all_min = pattern.OrValue([mask_all_min_expand, mask_all_min_trilu]) - total_range_as_row = op.Range(0, total_seq_len_0D, 1) + mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0) + min_float32 = float(np.finfo(np.float32).min) + mask_all_min = op.Expand(min_float32, mask_shape) + total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1) current_range_as_column = op.Reshape(current_range, [-1, 1]) - - non_causal = op.Greater(total_range_as_row, current_range_as_column) - - # sliding window support: - current_range_minus_window = op.Sub(current_range_as_column, window_size) - out_of_sliding_window = op.LessOrEqual(total_range_as_row, current_range_minus_window) - non_causal_sliding_window = op.Or(non_causal, out_of_sliding_window) - - boolean_mask = pattern.OrValue([non_causal, non_causal_sliding_window]) - - float_0_1_mask = op.Cast(boolean_mask, to=dtype) + boolean_mask = op.Greater(total_range_as_row, current_range_as_column) + float_0_1_mask = op.Cast(boolean_mask, to=1) float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask) - mask_4d_11ST = op.Unsqueeze(float_0_min_mask, [0, 1]) - mask_4d_B1ST = op.Expand(mask_4d_11ST, shape_B111) - - return mask_4d_B1ST - - -class _CausalMaskPattern(pattern.PatternBase): - def pattern( - self, - op, - input_ids, - past_kv_cache, - shape_B111, - min_val, - window_size, - dtype1, - attn_mask_2d, - dtype2, - ): - causal_mask = _causal_mask( - op, - input_ids, - past_kv_cache, - shape_B111, - min_val, - window_size, - dtype1, - ) - - attn_mask_4d = op.Unsqueeze(attn_mask_2d, [1, 2]) - attn_mask_4d_cast = op.Cast(attn_mask_4d, to=dtype2) - - sum = op.Add(causal_mask, attn_mask_4d_cast) - sum_fp32 = op.Cast(sum, to=ir.DataType.FLOAT) - # The cast is optional, and may be absent if the sum is already in float32. 
- sum_fp32 = pattern.OrValue([sum_fp32, sum]) - is_zero = op.Equal(sum_fp32, 0.0) - result = op.Where(is_zero, min_val, causal_mask) - return result - - def check(self, context, dtype1, dtype2, min_val, attn_mask_2d, sliding_window=None, **_): - # Check that attn_mask_2d is the model input "attention_mask" - if not _is_model_input(attn_mask_2d, "attention_mask", context.model): - return pattern.MatchResult().fail("Invalid attention_mask input", attn_mask_2d) - - if dtype1.as_int() != dtype2.as_int(): - return pattern.MatchResult().fail("Dtype mismatch", [dtype1, dtype2]) - - # Check that min_val is a constant and matches the expected minimum value for the dtype. - min_value = _ir_utils.get_singleton_value(min_val) - if min_value is None: - return pattern.MatchResult().fail("Minval is not a constant.", min_val) - expected_min_value = np.finfo(min_val.dtype.numpy()).min - if min_value != expected_min_value: - return pattern.MatchResult().fail( - f"Expected min value {expected_min_value}, got {min_value}", min_val - ) - - # TODO(rama) Sliding window: not yet supported. - if sliding_window: - return pattern.MatchResult().fail( - "Sliding window not yet supported", sliding_window - ) - return True - + mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1]) + mask_B1ST_plus = op.Expand(mask_4d, shape_B111) -_causal_mask_pattern = _CausalMaskPattern() + # Get rid of the extra +1 added above: total_seq_len is enough, no + # need for total_seq_len+1. + mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1]) + return mask_B1ST class GroupQueryAttention(pattern.RewriteRuleClassBase): @@ -157,7 +78,8 @@ def pattern( value_BSDkv, past_key, past_value, - position_ids, + position_ids_q, + position_ids_k, cos, sin, mask, @@ -179,7 +101,7 @@ def pattern( query_BHSDh_rope = op.RotaryEmbedding( query_BHSDh, - position_ids, + position_ids_q, cos, sin, _domain="com.microsoft", @@ -187,7 +109,7 @@ def pattern( ) key_BHkvSDh_rope = op.RotaryEmbedding( key_BHkvSDh, - position_ids, + position_ids_k, cos, sin, _domain="com.microsoft", @@ -232,7 +154,7 @@ def pattern( def check( self, - context: _basics.MatchContext, + op, query_BSD, key_BSDkv, value_BSDkv, @@ -242,7 +164,6 @@ def check( key_BHkvSDh_rope, query_BSHDh, key_BSHkvDh, - mask, **_, ): bindings: dict[str, Dim] = {} @@ -289,20 +210,6 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool: ) self._interleaved = query_interleaved - # Check mask: - mask_node = mask.producer() - if mask_node is None: - return pattern.MatchResult().fail("Unhandled mask pattern", mask) - mask_match_result = _causal_mask_pattern.match( - context.model, - context.graph_or_function, - mask_node, - check_nodes_are_removable=False, - ) - if mask_match_result is None: - return pattern.MatchResult().fail("Mask does not match causal mask pattern", mask) - # TODO: handle sliding window support in mask - return True def rewrite( @@ -313,37 +220,24 @@ def rewrite( value_BSDkv, past_key, past_value, - position_ids, + position_ids_q, + position_ids_k, cos, sin, mask, **_, ): - # Note that the following optimization is specific to current ORT GenAI attention-mask - # usage. Specifically, it assumes that the model-input "attention_mask" is a 2D - # mask with shape [batch_size, sequence_length], and that the mask is a 0/1 mask - # that is used only to indicate the current tokens. Hence, the input attention_mask - # is redundant as long as past-sequence-length and current-sequence-length can be - # computed. 
- - # Construct seqlens_k and total_seq_length_int32 from position_ids - # seqlens_k : int32[batch_size] indicates total_sequence-length-1 for each batch - # position_ids: int64[batch_size, sequence_length] indicates the position of each token - one_int32_0d = op.Constant(value=ir.tensor(1, dtype=ir.DataType.INT32)) - one_int64_1d = op.Constant(value=ir.tensor([1], dtype=ir.DataType.INT64)) - zero_int64_1d = op.Constant(value=ir.tensor([0], dtype=ir.DataType.INT64)) - seqlens_k_int64 = op.ReduceMax(position_ids, one_int64_1d, keepdims=0) - seqlens_k = op.Cast(seqlens_k_int64, to=ir.DataType.INT32) - max_seq_length = op.ReduceMax(seqlens_k, zero_int64_1d, keepdims=0) - total_seq_length_int32 = op.Add(max_seq_length, one_int32_0d) - return op.GroupQueryAttention( + return op.GQA( + mask, + position_ids_k, + position_ids_q, query_BSD, key_BSDkv, value_BSDkv, past_key, past_value, - seqlens_k, - total_seq_length_int32, + None, # seqlens_k, + None, # total_seq_length_int32, cos, sin, num_heads=self.num_heads, @@ -351,23 +245,101 @@ def rewrite( do_rotary=1, rotary_interleaved=self._interleaved, # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap - _domain="com.microsoft", + _domain="ai.onnxruntime._fusion", _outputs=3, ) + +class GQACausalMask(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("GQACausalMask", remove_nodes=False) + + def pattern( + self, + op, + mask, + input_ids, + some_kv_cache, + shape_B111, + past_seq_length, + total_seq_length, + ): + mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111) + position_ids = op.Range(past_seq_length, total_seq_length, 1) + position_ids_q = op.Unsqueeze(position_ids, [0]) + position_ids_k = op.Unsqueeze(position_ids, [0]) + return op.GQA( + mask, + position_ids_k, + position_ids_q, + _allow_other_inputs=True, + _domain="ai.onnxruntime._fusion", + _outputs=["attn_output", "key_seq", "value_seq"], + ) + + def rewrite( + self, + op, + total_seq_length, + attn_output, + **_, + ): + # Construct total_seq_length_int32 and seqlens_k + total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32) + one_0D = op.Constant(value_int=1) + one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32) + seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32) + zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1]) + seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D) + + gqa_node = attn_output.producer() + assert len(gqa_node.inputs) == 12, ( + f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}" + ) + query, key, value, past_key, past_value = gqa_node.inputs[3:8] + cos, sin = gqa_node.inputs[10:12] + updated_inputs = [ + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_seq_length_int32, + cos, + sin, + ] + attributes = gqa_node.attributes + return op.GroupQueryAttention( + *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3 + ) + + +_basic_gqa_rule = GroupQueryAttention.rule() +_gqa_causal_mask_rule = GQACausalMask.rule() + +gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule]) + +fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) + + class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase): def __init__(self): super().__init__("LongRoPeGQACausalMask", remove_nodes=False) self._mask_cache = {} - + def _get_mask_key(self, attention_mask): """ Generate a unique key for the mask based on input_ids and past_kv_cache. This is used to cache the mask to avoid recomputation. 
""" return (id(attention_mask)) - - def compute_mask(self, op, attention_mask : _onnx_types.INT64['batch', 'seq_len']): + + def compute_mask(self, op, attention_mask): + """ + Computes the total_seq_length_int32 and seqlens_k_int32 based on the attention_mask, + caching results to avoid recomputation at each layer. + """ mask_key = self._get_mask_key(attention_mask) if mask_key in self._mask_cache: @@ -377,14 +349,14 @@ def compute_mask(self, op, attention_mask : _onnx_types.INT64['batch', 'seq_len' # Construct total_seq_length_int32 and seqlens_k attention_shape = op.Shape(attention_mask, _outputs=["seq_len"]) total_seq_length = op.Gather(attention_shape, op.Constant(value=ir.tensor(1, ir.DataType.INT64)), axis=0, _outputs=["total_seq_length"]) - reduced_attention = op.ReduceSum(attention_mask, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["reduced_attention"]) + reduced_attention = op.ReduceSum(attention_mask, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["reduced_attention"]) sub_reduced_attention = op.Sub(reduced_attention, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["sub_reduced_attention"]) total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32, _outputs=["total_seq_length_int32"]) seqlens_k_int32 = op.Cast(sub_reduced_attention, to=ir.DataType.INT32, _outputs=["seqlens_k_int32"]) self._mask_cache[mask_key] = (total_seq_length_int32, seqlens_k_int32) - + return self._mask_cache[mask_key] - + def pattern( self, @@ -409,9 +381,9 @@ def pattern( current_range_B = op.Range(0, total_seq_len_0D, 1, _outputs=["current_range_B"]) total_seq_len_B = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_B"]) total_seq_len_C = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_C"]) - + total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"]) - + # EXPAND BRANCH A batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"]) mask_shape_A = op.Concat(batch_size, [1], seq_len, total_seq_len_A, axis=0, _outputs=["mask_shape_A"]) @@ -424,7 +396,7 @@ def pattern( mask_shape_B_abs = op.Abs(mask_shape_B, _outputs=["mask_shape_B_abs"]) reshaped_range_B = op.Reshape(current_range_B, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_range_B"]) mask_expanded_B = op.Expand(reshaped_range_B, mask_shape_B_abs, _outputs=["mask_expanded_B"]) - + # EXPAND BRANCH C mask_shape_C = op.Concat(batch_size, [1], seq_len, total_seq_len_C, axis=0, _outputs=["mask_shape_C"]) mask_shape_C_abs = op.Abs(mask_shape_C, _outputs=["mask_shape_C_abs"]) @@ -455,12 +427,12 @@ def pattern( # Propagation to GQA mask_sliced = op.Slice(mask_A_B_C_scaled, [0], pattern.ANY_VALUE, [3], [1], _outputs=["mask_sliced"]) - #mask_where = op.Where(mask_sliced, pattern.ANY_VALUE, pattern.ANY_VALUE, _outputs=["mask_where"]) + gqa_input = pattern.OrValue([mask_sliced, mask_A_B_C_scaled]) return op.GQA( - mask_sliced, + gqa_input, pattern.ANY_VALUE, # position_ids_k - pattern.ANY_VALUE, # position_ids_q + pattern.ANY_VALUE, # position_ids_q pattern.ANY_VALUE, # query pattern.ANY_VALUE, # key pattern.ANY_VALUE, # value @@ -509,9 +481,9 @@ def rewrite( ) _basic_gqa_rule = GroupQueryAttention.rule() +_gqa_causal_mask_rule = GQACausalMask.rule() _longrope_gqa_causal_mask_rule = LongRoPeGQACausalMask.rule() -gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]) -gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _longrope_gqa_causal_mask_rule]) +gqa_rules = 
pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule, _longrope_gqa_causal_mask_rule]) -fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) \ No newline at end of file +fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) From 75196535a46cc51bbc25f46a39664105cfe64cfa Mon Sep 17 00:00:00 2001 From: Tommaso Date: Fri, 1 Aug 2025 16:07:21 +0000 Subject: [PATCH 13/26] Added docstrings to GQA pattern method --- onnxscript/rewriter/ort_fusions/gqa.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 7b509d4840..4a23d78a5c 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -324,6 +324,12 @@ def rewrite( class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase): + """ + LongRoPeGQACausalMask is a specialized version of GQACausalMask that handles + the LongRoPe GQA fusion. It computes the causal mask for Group Query Attention + with LongRoPe (Long Range Rotary Position Embedding) and caches the mask to + avoid recomputation at each layer. + """ def __init__(self): super().__init__("LongRoPeGQACausalMask", remove_nodes=False) self._mask_cache = {} @@ -369,6 +375,12 @@ def pattern( past_seq_length, total_seq_length, ): + """ + Pattern for LongRoPe GQA Causal Mask. + This pattern computes the causal mask for Group Query Attention with LongRoPe. + It constructs the mask based on input_ids and past_kv_cache, and handles the + expansion of the mask across the batch and sequence dimensions. + """ seq_len = op.Shape(input_ids, end=2, start=1, _outputs=["seq_len"]) seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"]) past_seq_len = op.Shape(past_kv_cache_1, end=3, start=2, _outputs=["past_seq_len"]) @@ -455,6 +467,11 @@ def rewrite( attn_output, **_, ): + """ + Rewrite the GQA node with the new mask information. + This method computes the total sequence length and seqlens_k based on the + attention_mask and rewrites the GQA node to use these values. 
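+        For illustration (assumed values, not taken from a real model run): given
+        an attention_mask of shape [batch, total_seq] such as
+        [[1, 1, 1, 0], [1, 1, 1, 1]], compute_mask yields
+        total_seq_length_int32 = 4 and, from the row sums minus one (keeping
+        dims), seqlens_k_int32 = [[2], [3]].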
+ """ # Compute total_seq_length_int32 and seqlens_k_int32 total_seq_length_int32, seqlens_k_int32 = self.compute_mask(op, attention_mask) From e59cb83787a5aa87296b904c45d933535c86f21b Mon Sep 17 00:00:00 2001 From: Tommaso Date: Fri, 1 Aug 2025 16:32:28 +0000 Subject: [PATCH 14/26] Renamed pattern branches to match kv_range, query_range, and batch_range computation --- onnxscript/rewriter/ort_fusions/gqa.py | 103 +++++++++++-------------- 1 file changed, 47 insertions(+), 56 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 4a23d78a5c..4bc05f17de 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -387,73 +387,64 @@ def pattern( past_seq_len_0D = op.Squeeze(past_seq_len, _outputs=["past_seq_len_0D"]) total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D, _outputs=["total_seq_len_0D"]) - # All of the Add node's outputs - current_range_A = op.Range(past_seq_len_0D, total_seq_len_0D, 1, _outputs=["current_range_A"]) - total_seq_len_A = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_A"]) - current_range_B = op.Range(0, total_seq_len_0D, 1, _outputs=["current_range_B"]) - total_seq_len_B = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_B"]) - total_seq_len_C = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_C"]) + # Create ranges for different dimensions + kv_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1, _outputs=["kv_range"]) + total_seq_len_for_kv = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_kv"]) + query_range = op.Range(0, total_seq_len_0D, 1, _outputs=["query_range"]) + total_seq_len_for_query = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_query"]) + total_seq_len_for_batch = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_batch"]) - total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"]) + #total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"]) - # EXPAND BRANCH A + # BRANCH A: KV Range - Creates tensor with KV positions [1, 1, seq_len, 1] batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"]) - mask_shape_A = op.Concat(batch_size, [1], seq_len, total_seq_len_A, axis=0, _outputs=["mask_shape_A"]) - mask_shape_A_abs = op.Abs(mask_shape_A, _outputs=["mask_shape_A_abs"]) - reshaped_range_A = op.Reshape(current_range_A, [1, 1, -1, 1], allowzero=1, _outputs=["reshaped_range_A"]) - mask_expanded_A = op.Expand(reshaped_range_A, mask_shape_A_abs, _outputs=["mask_expanded_A"]) - - # EXPAND BRANCH B - mask_shape_B = op.Concat(batch_size, [1], seq_len, total_seq_len_B, axis=0, _outputs=["mask_shape_B"]) - mask_shape_B_abs = op.Abs(mask_shape_B, _outputs=["mask_shape_B_abs"]) - reshaped_range_B = op.Reshape(current_range_B, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_range_B"]) - mask_expanded_B = op.Expand(reshaped_range_B, mask_shape_B_abs, _outputs=["mask_expanded_B"]) - - # EXPAND BRANCH C - mask_shape_C = op.Concat(batch_size, [1], seq_len, total_seq_len_C, axis=0, _outputs=["mask_shape_C"]) - mask_shape_C_abs = op.Abs(mask_shape_C, _outputs=["mask_shape_C_abs"]) + kv_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_kv, axis=0, _outputs=["kv_mask_shape"]) + kv_mask_shape_abs = op.Abs(kv_mask_shape, _outputs=["kv_mask_shape_abs"]) + 
reshaped_kv_range = op.Reshape(kv_range, [1, 1, -1, 1], allowzero=1, _outputs=["reshaped_kv_range"]) + expanded_kv_range = op.Expand(reshaped_kv_range, kv_mask_shape_abs, _outputs=["expanded_kv_range"]) + + # BRANCH B: Query Range - Creates tensor with query positions [1, 1, 1, total_seq_len] + query_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_query, axis=0, _outputs=["query_mask_shape"]) + query_mask_shape_abs = op.Abs(query_mask_shape, _outputs=["query_mask_shape_abs"]) + reshaped_query_range = op.Reshape(query_range, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_query_range"]) + expanded_query_range = op.Expand(reshaped_query_range, query_mask_shape_abs, _outputs=["expanded_query_range"]) + + # BRANCH C: Batch Range - Creates tensor with batch indices [batch_size, 1, 1, 1] + batch_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_batch, axis=0, _outputs=["batch_mask_shape"]) + batch_mask_shape_abs = op.Abs(batch_mask_shape, _outputs=["batch_mask_shape_abs"]) batch_size_squeezed = op.Squeeze(batch_size, _outputs=["batch_size_squeezed"]) batch_range = op.Range(0, batch_size_squeezed, 1, _outputs=["batch_range"]) - reshaped_range_C = op.Reshape(batch_range, [-1, 1, 1, 1], allowzero=1, _outputs=["reshaped_range_C"]) - mask_expanded_C = op.Expand(reshaped_range_C, mask_shape_C_abs, _outputs=["mask_expanded_C"]) - - # EXPAND A/B TO AND - mask_expanded_A_sub = op.Sub(mask_expanded_A, 262144, _outputs=["mask_expanded_A_sub"]) - mask_A_B_greater = op.Greater(mask_expanded_B, mask_expanded_A_sub, _outputs=["mask_A_B_greater"]) - mask_A_B_greater_bitwise = op.And(True, mask_A_B_greater, _outputs=["mask_A_B_greater_bitwise"]) - mask_A_B_less = op.LessOrEqual(mask_expanded_B, mask_expanded_A, _outputs=["mask_A_B_less"]) - mask_A_B_combined = op.And(mask_A_B_greater_bitwise, mask_A_B_less, _outputs=["mask_A_B_combined"]) - mask_A_B_combined_bitwise = op.And(True, mask_A_B_combined, _outputs=["mask_A_B_combined_bitwise"]) - - # EXPAND B/C TO AND - unsqueezed_mask_expanded_B = op.Unsqueeze(mask_expanded_B, [-1], _outputs=["unsqueezed_mask_expanded_B"]) - unsqueezed_mask_expanded_C = op.Unsqueeze(mask_expanded_C, [-1], _outputs=["unsqueezed_mask_expanded_C"]) - mask_B_C_concat = op.Concat(unsqueezed_mask_expanded_C, unsqueezed_mask_expanded_B, axis=-1, _outputs=["mask_B_C_concat"]) + reshaped_batch_range = op.Reshape(batch_range, [-1, 1, 1, 1], allowzero=1, _outputs=["reshaped_batch_range"]) + expanded_batch_range = op.Expand(reshaped_batch_range, batch_mask_shape_abs, _outputs=["expanded_batch_range"]) + + # Combine KV/Query Ranges for Sliding Window Mask + kv_range_offset = op.Sub(expanded_kv_range, 262144, _outputs=["kv_range_offset"]) + query_gt_kv_offset = op.Greater(expanded_query_range, kv_range_offset, _outputs=["query_gt_kv_offset"]) + query_gt_kv_offset_mask = op.And(True, query_gt_kv_offset, _outputs=["query_gt_kv_offset_mask"]) + query_le_kv = op.LessOrEqual(expanded_query_range, expanded_kv_range, _outputs=["query_le_kv"]) + sliding_window_mask = op.And(query_gt_kv_offset_mask, query_le_kv, _outputs=["sliding_window_mask"]) + sliding_window_mask_final = op.And(True, sliding_window_mask, _outputs=["sliding_window_mask_final"]) + + # Combine Query/Batch Ranges for Attention Mask Lookup + unsqueezed_query_range = op.Unsqueeze(expanded_query_range, [-1], _outputs=["unsqueezed_query_range"]) + unsqueezed_batch_range = op.Unsqueeze(expanded_batch_range, [-1], _outputs=["unsqueezed_batch_range"]) + batch_query_indices = op.Concat(unsqueezed_batch_range, 
unsqueezed_query_range, axis=-1, _outputs=["batch_query_indices"]) attention_mask_bool = op.Cast(attention_mask, to=ir.DataType.BOOL, _outputs=["attention_mask_bool"]) - mask_gatherND = op.GatherND(attention_mask_bool, mask_B_C_concat, batch_dims=0, _outputs=["mask_gatherND"]) - - mask_A_B_C_combined = op.And(mask_A_B_combined_bitwise, mask_gatherND, _outputs=["mask_A_B_C_combined"]) - mask_A_B_C_negated = op.Not(mask_A_B_C_combined, _outputs=["mask_A_B_C_negated"]) - mask_A_B_C_fp32 = op.Cast(mask_A_B_C_negated, to=ir.DataType.FLOAT, _outputs=["mask_A_B_C_fp32"]) - mask_A_B_C_scaled = op.Mul(mask_A_B_C_fp32, pattern.ANY_VALUE) + attention_lookup = op.GatherND(attention_mask_bool, batch_query_indices, batch_dims=0, _outputs=["attention_lookup"]) + + # Final Mask Combination + final_attention_mask = op.And(sliding_window_mask_final, attention_lookup, _outputs=["final_attention_mask"]) + inverted_mask = op.Not(final_attention_mask, _outputs=["inverted_mask"]) + mask_fp32 = op.Cast(inverted_mask, to=ir.DataType.FLOAT, _outputs=["mask_fp32"]) + scaled_mask = op.Mul(mask_fp32, pattern.ANY_VALUE) + # Propagation to GQA - mask_sliced = op.Slice(mask_A_B_C_scaled, [0], pattern.ANY_VALUE, [3], [1], _outputs=["mask_sliced"]) + sliced_mask = op.Slice(scaled_mask, [0], pattern.ANY_VALUE, [3], [1], _outputs=["sliced_mask"]) - gqa_input = pattern.OrValue([mask_sliced, mask_A_B_C_scaled]) + gqa_input = pattern.OrValue([sliced_mask, scaled_mask]) return op.GQA( gqa_input, - pattern.ANY_VALUE, # position_ids_k - pattern.ANY_VALUE, # position_ids_q - pattern.ANY_VALUE, # query - pattern.ANY_VALUE, # key - pattern.ANY_VALUE, # value - pattern.ANY_VALUE, # past_key - pattern.ANY_VALUE, # past_value - pattern.ANY_VALUE, # seqlens_k (optional) - pattern.ANY_VALUE, # total_seq_length (optional) - pattern.ANY_VALUE, # cos - pattern.ANY_VALUE, # sin _allow_other_inputs=True, _domain="ai.onnxruntime._fusion", _outputs=["attn_output", "key_seq", "value_seq"], From bad78117223da4db8da955c77ab3074610d1c0bf Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 24 Jul 2025 12:52:38 -0700 Subject: [PATCH 15/26] Remove DORT related tests since it was removed from PyTorch (#2465) Signed-off-by: Justin Chu --- onnxscript/tools/training_helper.py | 47 ------------------- .../tools/transformers_models/llama_test.py | 29 ++---------- .../tools/transformers_models/mistral_test.py | 31 ++---------- .../tools/transformers_models/phi3_test.py | 31 ++---------- .../tools/transformers_models/phi_test.py | 29 ------------ 5 files changed, 15 insertions(+), 152 deletions(-) delete mode 100644 onnxscript/tools/training_helper.py diff --git a/onnxscript/tools/training_helper.py b/onnxscript/tools/training_helper.py deleted file mode 100644 index bd791ae8e6..0000000000 --- a/onnxscript/tools/training_helper.py +++ /dev/null @@ -1,47 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- -import torch -from torch.onnx import _OrtBackend, _OrtBackendOptions - - -def make_aot_ort(): - """Implements an autograd backend for torch.compile based on onnxrt backend.""" - options = _OrtBackendOptions() - ort_backend = _OrtBackend(options=options) - return ort_backend - - -def train_loop(model, *args, loss_fn=None, optimizer=None): - """Implements a training loop to be used in tests.""" - - if loss_fn is None: - loss_fn = torch.nn.MSELoss() - if optimizer is None: - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - - # Set the model to training mode - important for batch normalization and dropout layers - # Unnecessary in this situation but added for best practices - model.train() - - # Compute prediction and loss - pred = model(*args) - if isinstance(pred, tuple): - v = pred[0] - elif hasattr(pred, "last_hidden_state"): - v = pred.last_hidden_state - else: - v = pred - loss = loss_fn(v, torch.ones_like(v)) - - # Backpropagation - loss.backward() - optimizer.step() - # skip that part to retrieve the gradients - # optimizer.zero_grad() - - # returns the gradients - res = tuple(p.grad for p in model.parameters() if p.grad is not None) - assert len(res) > 0, f"No gradient, loss is {loss}" - return res diff --git a/onnxscript/tools/transformers_models/llama_test.py b/onnxscript/tools/transformers_models/llama_test.py index 7f8d42050b..5cb3159600 100644 --- a/onnxscript/tools/transformers_models/llama_test.py +++ b/onnxscript/tools/transformers_models/llama_test.py @@ -9,7 +9,6 @@ import onnxruntime import torch -import onnxscript.tools.training_helper import onnxscript.tools.transformers_models import onnxscript.tools.transformers_models.llama from onnxscript._internal.version_utils import ( @@ -34,13 +33,7 @@ def test_llama_export_cpu(self): ) input_tensors = input_tensors_many[0] expected = model(*input_tensors) - try: - proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) - except torch._export.verifier.SpecViolationError as e: # pylint: disable=protected-access - # see https://github.com/pytorch/pytorch/issues/128394 - if "Node.meta _enter_autocast is missing val field." in str(e): - raise unittest.SkipTest(str(e)) - raise + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) names = [i.name for i in proto.graph.input] np_input_tensors = [x.numpy() for x in input_tensors] feeds = dict(zip(names, np_input_tensors)) @@ -63,15 +56,9 @@ def test_llama_export_cpu_export_api(self): ) input_tensors = input_tensors_many[0] expected = model(*input_tensors) - try: - proto = onnxscript.tools.transformers_models.export_to_onnx( - model, *input_tensors, export_api=True - ) - except torch._export.verifier.SpecViolationError as e: # pylint: disable=protected-access - # see https://github.com/pytorch/pytorch/issues/128394 - if "Node.meta _enter_autocast is missing val field." 
in str(e): - raise unittest.SkipTest(str(e)) - raise + proto = onnxscript.tools.transformers_models.export_to_onnx( + model, *input_tensors, export_api=True + ) names = [i.name for i in proto.graph.input] np_input_tensors = [x.numpy() for x in input_tensors] feeds = dict(zip(names, np_input_tensors)) @@ -94,13 +81,7 @@ def test_llama_export_cuda(self): model = model.to("cuda") input_tensors = [i.to("cuda") for i in input_tensors_cpu] expected = model(*input_tensors) - try: - proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) - except torch._export.verifier.SpecViolationError as e: # pylint: disable=protected-access - # see https://github.com/pytorch/pytorch/issues/128394 - if "Node.meta _enter_autocast is missing val field." in str(e): - raise unittest.SkipTest(str(e)) - raise + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) names = [i.name for i in proto.graph.input] np_input_tensors = [x.detach().cpu().numpy() for x in input_tensors] feeds = dict(zip(names, np_input_tensors)) diff --git a/onnxscript/tools/transformers_models/mistral_test.py b/onnxscript/tools/transformers_models/mistral_test.py index fb06ecbd57..2883fbd32e 100644 --- a/onnxscript/tools/transformers_models/mistral_test.py +++ b/onnxscript/tools/transformers_models/mistral_test.py @@ -9,9 +9,6 @@ import onnxruntime import torch -import onnxscript.optimizer -import onnxscript.rewriter -import onnxscript.tools.training_helper import onnxscript.tools.transformers_models import onnxscript.tools.transformers_models.mistral from onnxscript._internal.version_utils import ( @@ -36,13 +33,7 @@ def test_mistral_export_cpu(self): ) input_tensors = input_tensors_many[0] expected = model(*input_tensors) - try: - proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) - except torch._export.verifier.SpecViolationError as e: # pylint: disable=protected-access - # see https://github.com/pytorch/pytorch/issues/128394 - if "Node.meta _enter_autocast is missing val field." in str(e): - raise unittest.SkipTest(str(e)) - raise + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) names = [i.name for i in proto.graph.input] np_input_tensors = [x.numpy() for x in input_tensors] feeds = dict(zip(names, np_input_tensors)) @@ -65,15 +56,9 @@ def test_mistral_export_cpu_export_api(self): ) input_tensors = input_tensors_many[0] expected = model(*input_tensors) - try: - proto = onnxscript.tools.transformers_models.export_to_onnx( - model, *input_tensors, export_api=True - ) - except torch._export.verifier.SpecViolationError as e: # pylint: disable=protected-access - # see https://github.com/pytorch/pytorch/issues/128394 - if "Node.meta _enter_autocast is missing val field." in str(e): - raise unittest.SkipTest(str(e)) - raise + proto = onnxscript.tools.transformers_models.export_to_onnx( + model, *input_tensors, export_api=True + ) names = [i.name for i in proto.graph.input] np_input_tensors = [x.numpy() for x in input_tensors] feeds = dict(zip(names, np_input_tensors)) @@ -95,13 +80,7 @@ def test_phi_export_cuda(self): model = model.to("cuda") input_tensors = [i.to("cuda") for i in input_tensors_cpu] expected = model(*input_tensors) - try: - proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) - except torch._export.verifier.SpecViolationError as e: # pylint: disable=protected-access - # see https://github.com/pytorch/pytorch/issues/128394 - if "Node.meta _enter_autocast is missing val field." 
in str(e): - raise unittest.SkipTest(str(e)) - raise + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) names = [i.name for i in proto.graph.input] np_input_tensors = [x.detach().cpu().numpy() for x in input_tensors] feeds = dict(zip(names, np_input_tensors)) diff --git a/onnxscript/tools/transformers_models/phi3_test.py b/onnxscript/tools/transformers_models/phi3_test.py index ac03f487d5..db47b7d1f1 100644 --- a/onnxscript/tools/transformers_models/phi3_test.py +++ b/onnxscript/tools/transformers_models/phi3_test.py @@ -9,9 +9,6 @@ import onnxruntime import torch -import onnxscript.optimizer -import onnxscript.rewriter -import onnxscript.tools.training_helper import onnxscript.tools.transformers_models import onnxscript.tools.transformers_models.phi3 from onnxscript._internal.version_utils import ( @@ -35,13 +32,7 @@ def test_phi3_export_cpu(self): ) input_tensors = input_tensors_many[0] expected = model(*input_tensors) - try: - proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) - except torch._export.verifier.SpecViolationError as e: # pylint: disable=protected-access - # see https://github.com/pytorch/pytorch/issues/128394 - if "Node.meta _enter_autocast is missing val field." in str(e): - raise unittest.SkipTest(str(e)) - raise + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) names = [i.name for i in proto.graph.input] np_input_tensors = [x.numpy() for x in input_tensors] feeds = dict(zip(names, np_input_tensors)) @@ -62,15 +53,9 @@ def test_phi3_export_cpu_export_api(self): ) input_tensors = input_tensors_many[0] expected = model(*input_tensors) - try: - proto = onnxscript.tools.transformers_models.export_to_onnx( - model, *input_tensors, export_api=True - ) - except torch._export.verifier.SpecViolationError as e: # pylint: disable=protected-access - # see https://github.com/pytorch/pytorch/issues/128394 - if "Node.meta _enter_autocast is missing val field." in str(e): - raise unittest.SkipTest(str(e)) - raise + proto = onnxscript.tools.transformers_models.export_to_onnx( + model, *input_tensors, export_api=True + ) names = [i.name for i in proto.graph.input] np_input_tensors = [x.numpy() for x in input_tensors] feeds = dict(zip(names, np_input_tensors)) @@ -93,13 +78,7 @@ def test_phi3_export_cuda(self): model = model.to("cuda") input_tensors = [i.to("cuda") for i in input_tensors_cpu] expected = model(*input_tensors) - try: - proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) - except torch._export.verifier.SpecViolationError as e: # pylint: disable=protected-access - # see https://github.com/pytorch/pytorch/issues/128394 - if "Node.meta _enter_autocast is missing val field." in str(e): - raise unittest.SkipTest(str(e)) - raise + proto = onnxscript.tools.transformers_models.export_to_onnx(model, *input_tensors) names = [i.name for i in proto.graph.input] np_input_tensors = [x.detach().cpu().numpy() for x in input_tensors] feeds = dict(zip(names, np_input_tensors)) diff --git a/onnxscript/tools/transformers_models/phi_test.py b/onnxscript/tools/transformers_models/phi_test.py index f2b5f9ff8f..9b88203084 100644 --- a/onnxscript/tools/transformers_models/phi_test.py +++ b/onnxscript/tools/transformers_models/phi_test.py @@ -2,7 +2,6 @@ # Licensed under the MIT License. 
# pylint: disable=not-callable -import copy import sys import unittest @@ -10,7 +9,6 @@ import onnxruntime import torch -import onnxscript.tools.training_helper import onnxscript.tools.transformers_models import onnxscript.tools.transformers_models.phi from onnxscript._internal.version_utils import ( @@ -79,33 +77,6 @@ def test_phi_export_cuda(self): results = sess.run(None, feeds) np.testing.assert_allclose(expected[0].detach().cpu().numpy(), results[0], atol=1e-5) - @unittest.skipIf(sys.platform == "win32", reason="not supported yet on Windows") - @unittest.skipIf(not has_transformers(), reason="transformers is missing") - @unittest.skipIf( - not hasattr(onnxruntime, "training"), reason="ORT training removed since 1.22" - ) - @ignore_warnings(UserWarning) - def test_phi_dort_static(self): - model, input_tensors_many, _ = onnxscript.tools.transformers_models.phi.get_phi_model() - input_tensors = input_tensors_many[0] - expected = model(*input_tensors) - - local_aot_ort = onnxscript.tools.training_helper.make_aot_ort() - - compiled_model = torch.compile( - copy.deepcopy(model), - backend=local_aot_ort, - dynamic=False, - fullgraph=True, - ) - - results = compiled_model(*input_tensors) - torch.testing.assert_close(expected[0], results[0], atol=1e-5, rtol=1e-5) - - expected_gradients = onnxscript.tools.training_helper.train_loop(model, *input_tensors) - gradients = onnxscript.tools.training_helper.train_loop(compiled_model, *input_tensors) - torch.testing.assert_close(expected_gradients[0], gradients[0], atol=1e-5, rtol=1e-5) - if __name__ == "__main__": unittest.main(verbosity=2) From 19f5e65baf576153161ecab601f58fd1b0db1ba2 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Fri, 25 Jul 2025 16:13:34 -0700 Subject: [PATCH 16/26] Handle matching against None explicitly (#2460) Provide a way to indicate that a pattern-variable can match successfully against a None-valued input. Cleanup current handling which was inconsistent in one place. Add test cases. --------- Signed-off-by: Ganesan Ramalingam Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxscript/rewriter/_matcher.py | 24 +++++++---- onnxscript/rewriter/_pattern_ir.py | 30 ++++++++++++-- onnxscript/rewriter/ort_fusions/attention.py | 5 +-- .../rewriter/ort_fusions/fuse_mha_bias.py | 6 +-- onnxscript/rewriter/pattern.py | 2 + onnxscript/rewriter/pattern_test.py | 41 ++++++++++++++++++- 6 files changed, 89 insertions(+), 19 deletions(-) diff --git a/onnxscript/rewriter/_matcher.py b/onnxscript/rewriter/_matcher.py index ab278ef573..4993fe8232 100644 --- a/onnxscript/rewriter/_matcher.py +++ b/onnxscript/rewriter/_matcher.py @@ -149,18 +149,21 @@ def _match_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node) -> b match.bind_node(pattern_node, node) # TODO: Revisit this to handle optional trailing inputs better. - if pattern_node.allow_other_inputs: - if len(node.inputs) < len(pattern_node.inputs): + + if len(node.inputs) > len(pattern_node.inputs): + if not pattern_node.allow_other_inputs: return self.fail( - f"Number of inputs ({len(node.inputs)}) is less than expected ({len(pattern_node.inputs)})" + f"Number of inputs ({len(node.inputs)}) is greater than expected ({len(pattern_node.inputs)})" ) + checked_inputs = zip(node.inputs, pattern_node.inputs) else: - if len(node.inputs) != len(pattern_node.inputs): - return self.fail( - f"Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}" - ) + # In ONNX, trailing Nones can be omitted in the inputs of a node. 
So, we extend actual + # node inputs with None values to match the pattern node inputs length when zipping. + checked_inputs = itertools.zip_longest( + node.inputs, pattern_node.inputs, fillvalue=None + ) - for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs): + for arg_value, arg_pattern in checked_inputs: # arg_pattern could be a Var, if it's the original arg. if arg_pattern is None: if arg_value is None: @@ -216,6 +219,11 @@ def _match_value( if pattern_value.tag_var is not None: self._match.bind(pattern_value.tag_var, i) return result + # Default case: a plain pattern variable (ValuePattern) + if value is None and not pattern_value.can_match_none: + return self.fail( + f"Mismatch: pattern variable {pattern_value} does not match None." + ) return True def _match_node_output( diff --git a/onnxscript/rewriter/_pattern_ir.py b/onnxscript/rewriter/_pattern_ir.py index 8fd283f0f0..1687897737 100644 --- a/onnxscript/rewriter/_pattern_ir.py +++ b/onnxscript/rewriter/_pattern_ir.py @@ -123,12 +123,16 @@ def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) -> """Represents promotion of values allowed as keyword-arguments in a pattern-builder call to an AttrPattern.""" if isinstance(value, AttrPattern): return value - if type(value) is ValuePattern: - # This is a hack. Currently, when we create pattern-variables, we create them as ValuePattern, + if isinstance(value, Var): + # This is a hack. Currently, when we create pattern-variables, we create them as Var, # and change them to AttrPattern if/when used in an attribute context. We could use type # annotations to distinguish between ValuePattern and AttrPattern, but forces users to # use these type annotations. # TODO: check for misuse at rule-creation time. (Currently will be caught by matcher at match-time.) + if value.can_match_none or value.check_method is not None: + raise ValueError( + "Pattern variables used in attributes must not have can_match_none or check_method set." + ) return AttrPattern(value.name) if isinstance(value, (int, float, str)): return AttrConstantPattern(value) @@ -320,9 +324,12 @@ class ValuePattern: operations, so that we can write patterns like `x + 1` and `1 + x`. """ - def __init__(self, name: str | None, *, check: Callable | None = None) -> None: + def __init__( + self, name: str | None, *, check: Callable | None = None, can_match_none: bool = False + ) -> None: self._name = name self._check = check + self._can_match_none = can_match_none # Note: uses will be computed only when the full graph-pattern is constructed. 
self._uses: list[tuple[NodePattern, int]] = [] @@ -338,6 +345,11 @@ def name(self) -> str | None: def check_method(self) -> Callable | None: return self._check + @property + def can_match_none(self) -> bool: + """Indicates whether this variable can match a None input.""" + return self._can_match_none + def producer(self) -> NodePattern | None: return None @@ -547,7 +559,17 @@ def producer(self) -> NodePattern: return self._producer -Var = ValuePattern +class Var(ValuePattern): + """Represents a pattern-variable.""" + + def __init__( + self, name: str | None, *, check: Callable | None = None, can_match_none: bool = False + ) -> None: + super().__init__(name, check=check, can_match_none=can_match_none) + + def clone(self, node_map: dict[NodePattern, NodePattern]) -> Var: + """Clones the pattern-variable, preserving its name and check method.""" + return Var(self.name, check=self.check_method, can_match_none=self.can_match_none) class AnyValue(ValuePattern): diff --git a/onnxscript/rewriter/ort_fusions/attention.py b/onnxscript/rewriter/ort_fusions/attention.py index 284258bd6f..ffbe131233 100644 --- a/onnxscript/rewriter/ort_fusions/attention.py +++ b/onnxscript/rewriter/ort_fusions/attention.py @@ -34,7 +34,6 @@ def pattern( qkv_bias, # mask_index, past, - attention_bias, num_heads, # scale, start1, @@ -106,7 +105,7 @@ def pattern( value_BSD, qkv_bias, None, # key_padding_mask - attention_bias, + pattern.Var("attention_bias", can_match_none=True), past_key, past_value, num_heads=num_heads, @@ -127,7 +126,7 @@ def pattern( value_BSD, qkv_bias, None, # key_padding_mask - attention_bias, + pattern.Var("attention_bias", can_match_none=True), None, # past_key None, # past_value num_heads=num_heads, diff --git a/onnxscript/rewriter/ort_fusions/fuse_mha_bias.py b/onnxscript/rewriter/ort_fusions/fuse_mha_bias.py index fdb8f08cf8..c152cecbc1 100644 --- a/onnxscript/rewriter/ort_fusions/fuse_mha_bias.py +++ b/onnxscript/rewriter/ort_fusions/fuse_mha_bias.py @@ -52,9 +52,9 @@ def pattern( value_BSD, None, # bias None, # key padding mask - mask, # attention mask/bias - past_key, - past_value, + pattern.Var("mask", can_match_none=True), # attention mask/bias + pattern.Var("past_key", can_match_none=True), + pattern.Var("past_value", can_match_none=True), num_heads=num_heads, # scale=scale, _domain="com.microsoft", diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 29caa52aef..68c1654f5c 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -10,6 +10,7 @@ Constant, OpsetPatternBuilder, OrValue, + Var, pattern_builder, torch_module_op, ) @@ -41,4 +42,5 @@ "PatternMatcher", "SimplePatternMatcher", "torch_module_op", + "Var", ] diff --git a/onnxscript/rewriter/pattern_test.py b/onnxscript/rewriter/pattern_test.py index ec0db97d11..bf5940e97c 100644 --- a/onnxscript/rewriter/pattern_test.py +++ b/onnxscript/rewriter/pattern_test.py @@ -450,8 +450,9 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]: self.assertEqual(model.graph.node(1).op_type, "Original") def test_match_optional_input(self): - def none_pattern(op, optional_input, x): + def none_pattern(op, x): # match against a call to Original where the first input may or may not be None + optional_input = pattern.Var("optional_input", can_match_none=True) return op.Original(optional_input, x) def replacement(op, optional_input, x): @@ -478,6 +479,44 @@ def test_model(x: FLOAT[1024]) -> FLOAT[1024]: self.assertEqual(model.graph.node(0).op_type, "ReplacedNone") 
self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone") + def test_mismatched_number_of_inputs(self): + def var_length_pattern(op): + # match against a call to Original where the first input may or may not be None + input1 = pattern.Var("input1", can_match_none=False) + input2 = pattern.Var("input2", can_match_none=True) + return op.Original(input1, input2) + + def replacement(op, input1, input2): + return op.Replaced(input1, input2) + + rule = pattern.RewriteRule(var_length_pattern, replacement) + + @script() + def test_model(x: FLOAT[1024], y: FLOAT[1024], z: FLOAT[1024]) -> FLOAT[1024]: + # Pattern should NOT match following 2 calls, since pattern requires first input to be non-None + t0 = op.Original() + t1 = op.Original(None, x) + + # Pattern should match following 3 calls, since second input can be None + t2 = op.Original(x) + t3 = op.Original(x, None) + t4 = op.Original(x, y) + + # Pattern should NOT match following call, since it has more than 2 inputs + t5 = op.Original(x, y, z) + return op.All(t0, t1, t2, t3, t4, t5) + + model_proto = test_model.to_model_proto() + model = ir.serde.deserialize_model(model_proto) + + count = rule.apply_to_model(model) + self.assertEqual(count, 3) + self.assertEqual(len(model.graph), 7) + self.assertEqual( + [n.op_type for n in model.graph], + ["Original", "Original", "Replaced", "Replaced", "Replaced", "Original", "All"], + ) + def test_graph_visitor(self): class ReplaceFoo(pattern.RewriteRuleClassBase): def __init__(self): From 17c117fa6a34ba7670d74078e517523d4884bd8d Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Tue, 29 Jul 2025 17:23:58 -0700 Subject: [PATCH 17/26] [docs] Document rewriter pattern options (#2406) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds comprehensive documentation for the rewriter pattern options that were previously undocumented. 
The rewriter pattern system supports four key options for controlling pattern matching and replacement behavior:

## New Documentation Added

### `_allow_other_inputs` option
- **File**: `docs/tutorial/rewriter/allow_other_inputs.md`
- **Purpose**: Controls whether patterns can match nodes with additional inputs beyond those specified
- **Default**: `False` (exact input matching)
- **Example**: Matching `Conv` operations that may have optional bias inputs

```python
def conv_pattern(op, input, weight):
    # Matches Conv with 2 or 3 inputs (weight + optional bias)
    return op.Conv(input, weight, _allow_other_inputs=True)
```

### `_domain` option
- **File**: `docs/tutorial/rewriter/domain_option.md`
- **Purpose**: Specifies operator domains for pattern matching and replacement
- **Use cases**: Domain-specific rewrites, migrating between operator domains
- **Example**: Targeting operations from specific domains like "com.microsoft"

```python
def custom_relu_pattern(op, input):
    # Only matches Relu from custom domain
    return op.Relu(input, _domain="custom.domain")
```

### `_outputs` option
- **File**: `docs/tutorial/rewriter/outputs_option.md`
- **Purpose**: Specifies number and names of operation outputs
- **Formats**: Integer count (`_outputs=2`) or named list (`_outputs=["first", "second"]`)
- **Example**: Handling multi-output operations like `Split`

```python
def split_pattern(op, input):
    # Matches Split operations with exactly 2 outputs
    return op.Split(input, num_outputs=2, axis=0, _outputs=2)
```

### Enhanced `_allow_other_attributes` documentation
- **File**: `docs/tutorial/rewriter/attributes.md` (improved formatting)
- **Already documented**: Controls whether patterns match nodes with additional attributes
- **Default**: `True` (allows extra attributes)

## Documentation Structure Improvements

- Added "Pattern Options" section to main rewriter documentation
- Integrated all option docs into the tutorial flow
- Created working code examples for each option
- Followed existing documentation patterns and style
- All examples compile and run successfully
- Documentation builds correctly with Sphinx

The documentation now provides complete coverage of all rewriter pattern options with practical examples showing real-world usage patterns.

Fixes #2405.
---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
Co-authored-by: gramalingam <10075881+gramalingam@users.noreply.github.com>
---
 docs/tutorial/rewriter/allow_other_inputs.md | 27 ++++++
 docs/tutorial/rewriter/attributes.md         |  1 +
 docs/tutorial/rewriter/domain_option.md      | 38 ++++++++
 .../rewriter/examples/allow_other_inputs.py  | 71 +++++++++++++++
 .../rewriter/examples/domain_option.py       | 86 +++++++++++++++++++
 .../rewriter/examples/outputs_option.py      | 76 ++++++++++++++++
 docs/tutorial/rewriter/outputs_option.md     | 43 ++++++++++
 docs/tutorial/rewriter/rewrite_patterns.md   | 20 +++++
 8 files changed, 362 insertions(+)
 create mode 100644 docs/tutorial/rewriter/allow_other_inputs.md
 create mode 100644 docs/tutorial/rewriter/domain_option.md
 create mode 100644 docs/tutorial/rewriter/examples/allow_other_inputs.py
 create mode 100644 docs/tutorial/rewriter/examples/domain_option.py
 create mode 100644 docs/tutorial/rewriter/examples/outputs_option.py
 create mode 100644 docs/tutorial/rewriter/outputs_option.md

diff --git a/docs/tutorial/rewriter/allow_other_inputs.md b/docs/tutorial/rewriter/allow_other_inputs.md
new file mode 100644
index 0000000000..29ccabca03
--- /dev/null
+++ b/docs/tutorial/rewriter/allow_other_inputs.md
@@ -0,0 +1,27 @@
+# Specifying variable inputs in the pattern
+
+This section demonstrates the use of the `_allow_other_inputs` option in pattern-based rewriting.
+The `_allow_other_inputs` option allows the pattern to match nodes that have additional inputs
+beyond those specified in the pattern. If it is set to `False` (the default), then the node must
+have exactly the specified inputs for a successful match. If set to `True`, the pattern will
+match nodes that have the specified inputs plus any number of additional inputs.
+
+This is particularly useful when matching operations like `Conv` that can have optional inputs
+(such as bias), or when creating generic patterns that should work with various input configurations.
+
+```{literalinclude} examples/allow_other_inputs.py
+:pyobject: conv_pattern
+```
+
+```{literalinclude} examples/allow_other_inputs.py
+:pyobject: conv_replacement
+```
+
+```{literalinclude} examples/allow_other_inputs.py
+:pyobject: apply_rewrite
+```
+
+In this example, the pattern matches `Conv` operations with any number of inputs. A `Conv` operation
+might have 2 inputs (input and weight) or 3 inputs (input, weight, and bias). By setting
+`_allow_other_inputs=True`, our pattern will match both cases even though we only specify 2 inputs
+in the pattern definition.
diff --git a/docs/tutorial/rewriter/attributes.md b/docs/tutorial/rewriter/attributes.md
index 12f1834241..ba72cc5ade 100644
--- a/docs/tutorial/rewriter/attributes.md
+++ b/docs/tutorial/rewriter/attributes.md
@@ -4,6 +4,7 @@ This section demonstrates the use of attribute values in pattern-based rewriting
 First, write a target pattern and replacement pattern in a similar way to the previous examples.
 The example pattern below will match successfully only against Dropout nodes with the
 attribute value `training_mode` set to `False`.
+
 The `_allow_other_attributes` option allows the pattern to match nodes that have additional attributes
 not specified in the pattern.
If it is set to `False`, then the node must have only the specified attribute values, and no other attributes, for a successful match. The default value for this diff --git a/docs/tutorial/rewriter/domain_option.md b/docs/tutorial/rewriter/domain_option.md new file mode 100644 index 0000000000..30a7384b59 --- /dev/null +++ b/docs/tutorial/rewriter/domain_option.md @@ -0,0 +1,38 @@ +# Specifying domains in the pattern + +This section demonstrates the use of the `_domain` option in pattern-based rewriting. +The `_domain` option allows you to specify which operator domain the pattern should match against, +and also allows you to create replacement operations in specific domains. + +ONNX operators can belong to different domains: +- The default ONNX domain (empty string or "ai.onnx") +- Custom domains like "com.microsoft" for Microsoft-specific operations +- User-defined domains for custom operations + +## Matching operations from a specific domain + +```{literalinclude} examples/domain_option.py +:pyobject: custom_relu_pattern +``` + +In this pattern, `_domain="custom.domain"` ensures that only `Relu` operations from the +"custom.domain" domain will be matched, not standard ONNX `Relu` operations. + +## Creating replacement operations in a specific domain + +```{literalinclude} examples/domain_option.py +:pyobject: microsoft_relu_replacement +``` + +Here, the replacement operation is created in the "com.microsoft" domain, which might +provide optimized implementations of standard operations. + +## Complete rewrite example + +```{literalinclude} examples/domain_option.py +:pyobject: apply_rewrite +``` + +This example shows how domain-specific pattern matching can be used to migrate operations +between different operator domains, such as replacing custom domain operations with +standard ONNX operations or vice versa. diff --git a/docs/tutorial/rewriter/examples/allow_other_inputs.py b/docs/tutorial/rewriter/examples/allow_other_inputs.py new file mode 100644 index 0000000000..cc3a3d926f --- /dev/null +++ b/docs/tutorial/rewriter/examples/allow_other_inputs.py @@ -0,0 +1,71 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""ONNX Pattern Rewriting with variable number of inputs + +This script shows how to define a rewriting rule based on patterns that +can match nodes with additional inputs beyond those specified in the pattern. 
+""" + +import onnx + +import onnxscript +from onnxscript import FLOAT, opset18, script +from onnxscript.rewriter import pattern + + +@script() +def original_model(A: FLOAT[2, 2], B: FLOAT[2, 2], C: FLOAT[2, 2]) -> FLOAT[2, 2]: + # Conv with bias - has 3 inputs: input, weight, bias + result = opset18.Conv(A, B, C) + return result + + +_model = original_model.to_model_proto() +onnx.checker.check_model(_model) + + +#################################### +# The target pattern +# ===================== + + +def conv_pattern(op, input, weight): + # Pattern to match Conv operations, allowing additional inputs like bias + # _allow_other_inputs=True allows the pattern to match Conv with bias (3 inputs) + # even though we only specify 2 inputs in the pattern + return op.Conv(input, weight, _allow_other_inputs=True) + + +#################################### +# The replacement pattern +# ===================== + + +def conv_replacement(op, input, weight, **_): + # Replace with a custom operation in a different domain + return op.OptimizedConv(input, weight, _domain="custom.domain") + + +#################################### +# Create Rewrite Rule and Apply to Model +# ===================== + + +def apply_rewrite(model): + # Create rewrite rules + conv_rule = pattern.RewriteRule( + conv_pattern, # target pattern + conv_replacement, # replacement pattern + ) + # Create a Rewrite Rule Set + rewrite_rule_set = pattern.RewriteRuleSet([conv_rule]) + # Apply rewrite + model_with_rewrite = onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=rewrite_rule_set, + ) + return model_with_rewrite + + +_model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(_model_with_rewrite) diff --git a/docs/tutorial/rewriter/examples/domain_option.py b/docs/tutorial/rewriter/examples/domain_option.py new file mode 100644 index 0000000000..7018c04719 --- /dev/null +++ b/docs/tutorial/rewriter/examples/domain_option.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""ONNX Pattern Rewriting with domain specification + +This script shows how to define a rewriting rule that targets operations +from specific domains and replaces them with operations in other domains. 
+""" + +import onnx + +import onnxscript +from onnxscript import script +from onnxscript.rewriter import pattern +from onnxscript.values import Opset + +# Create an opset for the custom domain +opset = Opset("custom.domain", 1) + + +@script(opset) +def create_model_with_custom_domain(input: onnxscript.FLOAT[2, 2]) -> onnxscript.FLOAT[2, 2]: + """Create a model with a Relu operation in a custom domain.""" + return opset.Relu(input) + + +_model = create_model_with_custom_domain.to_model_proto() +_model = onnx.shape_inference.infer_shapes(_model) +onnx.checker.check_model(_model) + + +#################################### +# The target pattern +# ===================== + + +def custom_relu_pattern(op, input): + # Pattern to match Relu operations from a specific domain + # _domain="custom.domain" specifies we only want to match operations from this domain + return op.Relu(input, _domain="custom.domain") + + +#################################### +# The replacement pattern +# ===================== + + +def standard_relu_replacement(op, input, **_): + # Replace with standard ONNX Relu (default domain) + return op.Relu(input) + + +#################################### +# Alternative: Replace with operation in different domain +# ===================== + + +def microsoft_relu_replacement(op, input, **_): + # Replace with operation in Microsoft's domain + return op.OptimizedRelu(input, _domain="com.microsoft") + + +#################################### +# Create Rewrite Rule and Apply to Model +# ===================== + + +def apply_rewrite(model): + # Create rewrite rules + relu_rule = pattern.RewriteRule( + custom_relu_pattern, # target pattern - matches custom domain operations + standard_relu_replacement, # replacement pattern - uses standard domain + ) + # Create a Rewrite Rule Set + rewrite_rule_set = pattern.RewriteRuleSet([relu_rule]) + # Apply rewrite + model_with_rewrite = onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=rewrite_rule_set, + ) + return model_with_rewrite + + +# The rewrite rule will now match the Relu operation in the custom domain +# and replace it with a standard ONNX Relu operation +_model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(_model_with_rewrite) diff --git a/docs/tutorial/rewriter/examples/outputs_option.py b/docs/tutorial/rewriter/examples/outputs_option.py new file mode 100644 index 0000000000..88483385dc --- /dev/null +++ b/docs/tutorial/rewriter/examples/outputs_option.py @@ -0,0 +1,76 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""ONNX Pattern Rewriting with output specification + +This script shows how to define a rewriting rule that specifies +the number and names of outputs from operations. 
+""" + +import onnx + +import onnxscript +from onnxscript import FLOAT, opset18, script +from onnxscript.rewriter import pattern + + +@script() +def original_model(A: FLOAT[4, 4]) -> FLOAT[2, 4]: + # Split operation that produces 2 outputs + result1, _result2 = opset18.Split(A, num_outputs=2, axis=0) + # We only return the first output for simplicity + return result1 + + +_model = original_model.to_model_proto() +onnx.checker.check_model(_model) + + +#################################### +# The target pattern with multiple outputs +# ===================== + + +def split_pattern(op, input): + # Pattern to match Split operations with 2 outputs + # num_outputs=2 corresponds to the attribute of the ONNX Split op + # _outputs=2 is an option controlling the pattern constructor + return op.Split(input, num_outputs=2, axis=0, _outputs=2) + + +#################################### +# The replacement pattern with named outputs +# ===================== + + +def custom_split_replacement(op, input, **_): + # Replace with a custom split operation using named outputs + # _outputs=["first_half", "second_half"] assigns names to the outputs + # IMPORTANT: The number of outputs must match the pattern (2 outputs) + return op.CustomSplit( + input, _domain="custom.domain", _outputs=["first_half", "second_half"] + ) + + +#################################### +# Create Rewrite Rule and Apply to Model +# ===================== + + +def apply_rewrite(model): + # Create rewrite rules + split_rule = pattern.RewriteRule( + split_pattern, # target pattern - matches Split with 2 outputs + custom_split_replacement, # replacement pattern - uses named outputs + ) + # Create a Rewrite Rule Set + rewrite_rule_set = pattern.RewriteRuleSet([split_rule]) + # Apply rewrite + model_with_rewrite = onnxscript.rewriter.rewrite( + model, + pattern_rewrite_rules=rewrite_rule_set, + ) + return model_with_rewrite + + +_model_with_rewrite = apply_rewrite(_model) +onnx.checker.check_model(_model_with_rewrite) diff --git a/docs/tutorial/rewriter/outputs_option.md b/docs/tutorial/rewriter/outputs_option.md new file mode 100644 index 0000000000..cc73bcc561 --- /dev/null +++ b/docs/tutorial/rewriter/outputs_option.md @@ -0,0 +1,43 @@ +# Specifying outputs in the pattern + +This section demonstrates the use of the `_outputs` option in pattern-based rewriting. +The `_outputs` option allows you to specify the number of outputs an operation produces +and optionally assign names to those outputs for easier reference in replacement patterns. + +The `_outputs` option can be specified in two ways: +- As an integer: `_outputs=2` specifies that the operation produces 2 unnamed outputs +- As a list of strings/None: `_outputs=["first", "second"]` specifies 2 named outputs + +## Matching operations with multiple outputs + +```{literalinclude} examples/outputs_option.py +:pyobject: split_pattern +``` + +This pattern matches `Split` operations that produce exactly 2 outputs. The `_outputs=2` +specification ensures the pattern only matches operations with this specific output count. + +## Creating replacement operations with named outputs + +```{literalinclude} examples/outputs_option.py +:pyobject: custom_split_replacement +``` + +In the replacement, `_outputs=["first_half", "second_half"]` creates two outputs with +descriptive names. This can make the replacement pattern more readable and maintainable. + +**Important**: The number of outputs in the replacement pattern must match the number of +outputs in the target pattern. 
Since the pattern specifies `_outputs=2`, the replacement
+must also produce exactly 2 outputs.
+
+## Complete rewrite example
+
+```{literalinclude} examples/outputs_option.py
+:pyobject: apply_rewrite
+```
+
+The `_outputs` option is particularly important when:
+- Working with operations that have variable numbers of outputs (like `Split`)
+- Creating custom operations that need specific output configurations
+- Ensuring pattern matching precision by specifying exact output counts
+- Improving code readability by naming outputs in replacement patterns
diff --git a/docs/tutorial/rewriter/rewrite_patterns.md b/docs/tutorial/rewriter/rewrite_patterns.md
index d4556fe871..50615945d1 100644
--- a/docs/tutorial/rewriter/rewrite_patterns.md
+++ b/docs/tutorial/rewriter/rewrite_patterns.md
@@ -10,12 +10,32 @@ There are three main components needed when rewriting patterns in the graph:
 2. `replacement_pattern` : Pattern to replace the original pattern with. This pattern is also written as a function using ONNXScript-like operators.
 3. `match_condition` (optional) : Pattern rewrite will occur only if the match condition is satisfied.
 
+## Pattern Options
+
+When defining patterns, you can use several special options to control how patterns match and what they produce:
+
+- `_allow_other_attributes`: Controls whether the pattern allows additional attributes not specified in the pattern (default: True)
+- `_allow_other_inputs`: Controls whether the pattern allows additional inputs beyond those specified (default: False)
+- `_domain`: Specifies the operator domain for matching or creating operations
+- `_outputs`: Specifies the number and optionally names of outputs from an operation
+
+These options are documented in detail in the following sections.
+
 ```{include} simple_example.md
 ```
 
 ```{include} attributes.md
 ```
 
+```{include} allow_other_inputs.md
+```
+
+```{include} domain_option.md
+```
+
+```{include} outputs_option.md
+```
+
 ```{include} conditional_rewrite.md
 ```

From 3fb87c07882519ce6a863ab6769586a12eb802b1 Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 30 Jul 2025 17:49:48 -0700
Subject: [PATCH 18/26] Update requirements-ort-nightly.txt (#2471)

---
 requirements/ci/requirements-ort-nightly.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/requirements/ci/requirements-ort-nightly.txt b/requirements/ci/requirements-ort-nightly.txt
index 918fd21118..4ed908b4e2 100644
--- a/requirements/ci/requirements-ort-nightly.txt
+++ b/requirements/ci/requirements-ort-nightly.txt
@@ -1,3 +1,3 @@
 # https://aiinfra.visualstudio.com/PublicPackages/_artifacts/feed/ORT-Nightly/PyPI/onnxruntime/overview
 --index-url=https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/
-onnxruntime==1.22.0.dev20250402004
+onnxruntime==1.23.0.dev20250517001

From 127aee81694c5d28b62bf1be71e0a154e7c81b6a Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Wed, 30 Jul 2025 18:44:45 -0700
Subject: [PATCH 19/26] Fix logic for converting np array to text (#2470)

In onnx2script, `nan`, `inf`, etc. were converted to plain text, which causes
evaluation to fail because those names don't exist in the generated script. I
updated the logic to replace them with the corresponding `np.nan`/`np.inf`
values.
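For reference, a minimal sketch (not part of the patch; the tensor value is
illustrative) of the failure mode and of the string substitution the diff
below applies:

```python
# Sketch: why repr() of non-finite values breaks the generated script,
# and what the nan/inf -> np.nan/np.inf substitution produces instead.
import numpy as np

value = np.array([1.0, np.nan, np.inf, -np.inf])

broken = repr(value.tolist())  # '[1.0, nan, inf, -inf]' -- bare nan/inf are not Python names
fixed = broken.replace("nan", "np.nan").replace("inf", "np.inf")
print(fixed)  # '[1.0, np.nan, np.inf, -np.inf]' -- evaluates once numpy is imported as np
```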
--------- Signed-off-by: Justin Chu --- onnxscript/backend/onnx_export.py | 16 ++++++++-------- onnxscript/backend/onnx_export_test.py | 13 ++++--------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/onnxscript/backend/onnx_export.py b/onnxscript/backend/onnx_export.py index 1b79998e12..c6b6abb56e 100644 --- a/onnxscript/backend/onnx_export.py +++ b/onnxscript/backend/onnx_export.py @@ -4,7 +4,7 @@ from typing import Any, Optional, Sequence -import numpy +import numpy as np import onnx from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, ValueInfoProto @@ -384,17 +384,17 @@ def _translate_attributes(self, node): if isinstance(value, str): attributes.append((at.name, f"{value!r}")) continue - if isinstance(value, numpy.ndarray): + if isinstance(value, np.ndarray): onnx_dtype = at.t.data_type if len(value.shape) == 0: text = ( f'make_tensor("value", {onnx_dtype}, dims=[], ' - f"vals=[{value.tolist()!r}])" + f"vals=[{repr(value.tolist()).replace('nan', 'np.nan').replace('inf', 'np.inf')}])" ) else: text = ( f'make_tensor("value", {onnx_dtype}, dims={list(value.shape)!r}, ' - f"vals={value.ravel().tolist()!r})" + f"vals={repr(value.ravel().tolist()).replace('nan', 'np.nan').replace('inf', 'np.inf')})" ) attributes.append((at.name, text)) continue @@ -738,7 +738,7 @@ def generate_rand(name: str, value: TensorProto) -> str: raise NotImplementedError( f"Unable to generate random initializer for data type {value.data_type}." ) - return f"{__}{name} = numpy.random.rand({shape}).astype(numpy.float32)" + return f"{__}{name} = np.random.rand({shape}).astype(np.float32)" random_initializer_values = "\n".join( generate_rand(key, value) for key, value in self.skipped_initializers.items() @@ -793,7 +793,7 @@ def add(line: str) -> None: result.append(line) # Generic imports. - add("import numpy") + add("import numpy as np") add("from onnx import TensorProto") add("from onnx.helper import make_tensor") add("from onnxscript import script, external_tensor") @@ -873,11 +873,11 @@ def export2python( .. 
runpython:: :showcode: :process: - import numpy + import numpy as np from sklearn.cluster import KMeans from mlprodict.onnx_conv import to_onnx from mlprodict.onnx_tools.onnx_export import export2python - X = numpy.arange(20).reshape(10, 2).astype(numpy.float32) + X = np.arange(20).reshape(10, 2).astype(np.float32) tr = KMeans(n_clusters=2) tr.fit(X) onx = to_onnx(tr, X, target_opset=14) diff --git a/onnxscript/backend/onnx_export_test.py b/onnxscript/backend/onnx_export_test.py index 1d05428a2c..bee20b47ba 100644 --- a/onnxscript/backend/onnx_export_test.py +++ b/onnxscript/backend/onnx_export_test.py @@ -45,14 +45,8 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True): SKIP_TESTS = ( - skip( - r"^test_ai_onnx_ml_array_feature_extractor", - "ImportError: cannot import name 'opset' from 'onnxscript.onnx_opset'", - ), - skip( - r"^test_ai_onnx_ml_binarizer", - "ImportError: cannot import name 'opset' from 'onnxscript.onnx_opset'", - ), + skip(r"^test_ai_onnx_ml_array_feature_extractor", "ORT doesn't support this op"), + skip(r"^test_ai_onnx_ml_binarizer", "ORT doesn't support this op"), skip(r"^test_center_crop_pad_crop_negative_axes_hwc", "fixme: ORT segfaults"), skip(r"_scan_", "Operator Scan is not supported by onnxscript"), skip(r"^test_scan", "Operator Scan is not supported by onnxscript"), @@ -89,6 +83,7 @@ def skip(pattern: str | Pattern, reason: str, *, condition: bool = True): "Change when the converter supports support something like 'while i < n and cond:'", ), skip(r"^test_ai_onnx_ml_label_encoder", "ONNX Runtime does not support Opset 21 at 1.17"), + skip(r"^test_ai_onnx_ml_tree_ensemble", "Opset 23 is not supported"), ) if sys.platform == "win32": @@ -160,7 +155,7 @@ class TestOnnxBackEnd(unittest.TestCase): test_folder = root_folder / "tests" / "onnx_backend_test_code" temp_folder = root_folder / "tests" / "export" - def _proto_to_os_and_back(self, proto: onnxscript.FunctionProto, **export_options): + def _proto_to_os_and_back(self, proto: onnx.FunctionProto, **export_options): """Convert a proto to onnxscript code and convert it back to a proto.""" code = onnx_export.export2python(proto, **export_options) map = extract_functions(proto.name, code, TestOnnxBackEnd.temp_folder) From 131e4970965d01e7805999f49e55ad5ee5f270fb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 31 Jul 2025 10:10:56 -0700 Subject: [PATCH 20/26] [torchlib] Improves aten_chunk conversion (#2469) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Simplify implementation for `aten_chunk` and allow it to work on all data types. 
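As a quick illustration (not part of the patch), the simplification relies on
ONNX `Split` with `num_outputs` sharing `torch.chunk`'s uneven-split semantics:

```python
# Sketch: torch.chunk and ONNX Split(num_outputs=...) both use
# ceil(size / chunks) elements per chunk, leaving a smaller final chunk.
import torch

t = torch.arange(10)
print([c.tolist() for c in torch.chunk(t, 3)])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] -- sizes 4, 4, 2
```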
Original author: @xadupre Updated: Conditionally use the new implementation when torch>=2.7 --------- Signed-off-by: Justin Chu Co-authored-by: Xavier Dupré --- .../function_libs/torch_lib/ops/core.py | 58 +++++++++++-------- tests/function_libs/torch_lib/ops_test.py | 1 - .../function_libs/torch_lib/ops_test_data.py | 15 +---- 3 files changed, 38 insertions(+), 36 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 92b8abb36d..595f4a758a 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -36,6 +36,7 @@ graph, ir, ) +from onnxscript._internal import version_utils from onnxscript.function_libs.torch_lib.ops import common as common_ops from onnxscript.function_libs.torch_lib.registration import torch_op from onnxscript.function_libs.torch_lib.tensor_typing import ( @@ -1647,29 +1648,40 @@ def aten_choose_qparams_optimized( raise NotImplementedError() -@torch_op("aten::chunk") -def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: - """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]""" - # This will create a Sequence of tensors - neg_1 = op.Constant(value_ints=[-1]) - # Get size of specified dim - self_shape = op.Shape(self) - dim_size = op.Gather(self_shape, dim, axis=0) - # Compute size/chunk to get the number of data in one chunk - num_per_chunk = op.Div(dim_size, chunks) - num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator] - - # Compute real chunk number - num_chunk = op.Div(dim_size, num_per_chunk) - # Get something like [n, n, n, n, ...], total num_chunk - list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1)) - - remainder = op.Mod(dim_size, num_per_chunk) - if remainder > 0: # type: ignore[operator] - # Append the remainder to the [n, n, n, n, ..., r] - list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0) - - return op.SplitToSequence(self, list_split, axis=dim) +if version_utils.torch_older_than("2.7.0"): + # PyTorch <2.7 does not support determining the number of outputs for the Split op + # https://github.com/pytorch/pytorch/commit/9a1eac6704671c72a2e85c9138db57eb3a80bfb6 + @torch_op("aten::chunk") + def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: + """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]""" + # This will create a Sequence of tensors + neg_1 = op.Constant(value_ints=[-1]) + # Get size of specified dim + self_shape = op.Shape(self) + dim_size = op.Gather(self_shape, dim, axis=0) + # Compute size/chunk to get the number of data in one chunk + num_per_chunk = op.Div(dim_size, chunks) + num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk # type: ignore[operator] + + # Compute real chunk number + num_chunk = op.Div(dim_size, num_per_chunk) + # Get something like [n, n, n, n, ...], total num_chunk + list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1)) + + remainder = op.Mod(dim_size, num_per_chunk) + if remainder > 0: # type: ignore[operator] + # Append the remainder to the [n, n, n, n, ..., r] + list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0) + + return op.SplitToSequence(self, list_split, axis=dim) +else: + + @torch_op("aten::chunk", trace_only=True) + def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: + """chunk(Tensor(a -> *) self, int chunks, int dim=0) 
-> Tensor(a)[]""" + if chunks == 1: + return op.Identity(self) + return op.Split(self, axis=dim, num_outputs=chunks) @torch_op("aten::clamp", trace_only=True) diff --git a/tests/function_libs/torch_lib/ops_test.py b/tests/function_libs/torch_lib/ops_test.py index 59e6c98c9f..7ba6f9d37f 100644 --- a/tests/function_libs/torch_lib/ops_test.py +++ b/tests/function_libs/torch_lib/ops_test.py @@ -200,7 +200,6 @@ def run_test_output_match( reference_torch_outputs, _ = pytree.tree_flatten(torch_output) if ( op.name.startswith("split") - or op.name.startswith("chunk") or op.name.startswith("unbind") or op.name in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"} diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 73ea68116c..cd2d933309 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -694,18 +694,9 @@ def _where_input_wrangler( reason="fixme: ORT aborts with zero-dim tensors. https://github.com/microsoft/onnxruntime/issues/16619", ), TorchLibOpInfo("ceil", core_ops.aten_ceil), - TorchLibOpInfo( - "chunk", - core_ops.aten_chunk, - ) - .xfail( - dtypes=(torch.float16,), - enabled_if=version_utils.onnxruntime_older_than("1.17"), - reason="fixme: SplitToSequence op inference failed. https://github.com/microsoft/onnxruntime/issues/16006", - ) - .xfail( - dtypes=(torch.bool,), - reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", + TorchLibOpInfo("chunk", core_ops.aten_chunk).skip( + enabled_if=version_utils.torch_older_than("2.7"), + reason="Test for chunk is not configured for torch<2.7", ), TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip( reason="Size 0 inputs are not handled by design", From acdfd1bfe718400f5bafe4c8f075a02a1e6e023c Mon Sep 17 00:00:00 2001 From: Tommaso Date: Fri, 1 Aug 2025 16:33:04 +0000 Subject: [PATCH 21/26] Removed unecessary pattern variable --- onnxscript/rewriter/ort_fusions/gqa.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py index 4bc05f17de..04edb3c74e 100644 --- a/onnxscript/rewriter/ort_fusions/gqa.py +++ b/onnxscript/rewriter/ort_fusions/gqa.py @@ -394,8 +394,6 @@ def pattern( total_seq_len_for_query = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_query"]) total_seq_len_for_batch = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_batch"]) - #total_seq_len_final = op.Reshape(total_seq_len_0D, pattern.ANY_VALUE, allowzero=0, _outputs=["total_seq_len_final"]) - # BRANCH A: KV Range - Creates tensor with KV positions [1, 1, seq_len, 1] batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"]) kv_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_kv, axis=0, _outputs=["kv_mask_shape"]) @@ -437,7 +435,7 @@ def pattern( inverted_mask = op.Not(final_attention_mask, _outputs=["inverted_mask"]) mask_fp32 = op.Cast(inverted_mask, to=ir.DataType.FLOAT, _outputs=["mask_fp32"]) scaled_mask = op.Mul(mask_fp32, pattern.ANY_VALUE) - + # Propagation to GQA sliced_mask = op.Slice(scaled_mask, [0], pattern.ANY_VALUE, [3], [1], _outputs=["sliced_mask"]) From 76624ad8925f84adf1d2d032c6d70b6dff992521 Mon Sep 17 00:00:00 2001 From: Tommaso Date: Fri, 1 Aug 2025 16:37:07 +0000 Subject: [PATCH 22/26] Added snake casing for variable names --- 
 onnxscript/rewriter/ort_fusions/gqa.py | 32 +++++++++++++-------------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py
index 04edb3c74e..2c704188ad 100644
--- a/onnxscript/rewriter/ort_fusions/gqa.py
+++ b/onnxscript/rewriter/ort_fusions/gqa.py
@@ -34,25 +34,25 @@
 def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111):
     seq_len = op.Shape(input_ids, end=2, start=1)
-    seq_len_0D = op.Squeeze(seq_len)
+    seq_len_0d = op.Squeeze(seq_len)
 
     past_seq_len = op.Shape(past_kv_cache, end=3, start=2)
-    past_seq_len_0D = op.Squeeze(past_seq_len)
+    past_seq_len_0d = op.Squeeze(past_seq_len)
 
-    total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D)
-    total_seq_len = op.Reshape(total_seq_len_0D, [-1])
+    total_seq_len_0d = op.Add(past_seq_len_0d, seq_len_0d)
+    total_seq_len = op.Reshape(total_seq_len_0d, [-1])
 
     # The Phi modeling code generates the following +1 as the target-length, which seems
     # unnecessary in this context. But using it for pattern-matching against
     # generated onnx model.
-    total_seq_len_plus_1_0D = op.Add(total_seq_len_0D, 1)
-    total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0D, [-1])
+    total_seq_len_plus_1_0d = op.Add(total_seq_len_0d, 1)
+    total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0d, [-1])
 
-    current_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1)
+    current_range = op.Range(past_seq_len_0d, total_seq_len_0d, 1)
     mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0)
     min_float32 = float(np.finfo(np.float32).min)
     mask_all_min = op.Expand(min_float32, mask_shape)
-    total_range_as_row = op.Range(0, total_seq_len_plus_1_0D, 1)
+    total_range_as_row = op.Range(0, total_seq_len_plus_1_0d, 1)
     current_range_as_column = op.Reshape(current_range, [-1, 1])
     boolean_mask = op.Greater(total_range_as_row, current_range_as_column)
     float_0_1_mask = op.Cast(boolean_mask, to=1)
@@ -382,17 +382,17 @@ def pattern(
         expansion of the mask across the batch and sequence dimensions.
         """
         seq_len = op.Shape(input_ids, end=2, start=1, _outputs=["seq_len"])
-        seq_len_0D = op.Squeeze(seq_len, _outputs=["seq_len_0D"])
+        seq_len_0d = op.Squeeze(seq_len, _outputs=["seq_len_0d"])
         past_seq_len = op.Shape(past_kv_cache_1, end=3, start=2, _outputs=["past_seq_len"])
-        past_seq_len_0D = op.Squeeze(past_seq_len, _outputs=["past_seq_len_0D"])
-        total_seq_len_0D = op.Add(past_seq_len_0D, seq_len_0D, _outputs=["total_seq_len_0D"])
+        past_seq_len_0d = op.Squeeze(past_seq_len, _outputs=["past_seq_len_0d"])
+        total_seq_len_0d = op.Add(past_seq_len_0d, seq_len_0d, _outputs=["total_seq_len_0d"])
 
         # Create ranges for different dimensions
-        kv_range = op.Range(past_seq_len_0D, total_seq_len_0D, 1, _outputs=["kv_range"])
-        total_seq_len_for_kv = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_kv"])
-        query_range = op.Range(0, total_seq_len_0D, 1, _outputs=["query_range"])
-        total_seq_len_for_query = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_query"])
-        total_seq_len_for_batch = op.Reshape(total_seq_len_0D, [-1], allowzero=0, _outputs=["total_seq_len_for_batch"])
+        kv_range = op.Range(past_seq_len_0d, total_seq_len_0d, 1, _outputs=["kv_range"])
+        total_seq_len_for_kv = op.Reshape(total_seq_len_0d, [-1], allowzero=0, _outputs=["total_seq_len_for_kv"])
+        query_range = op.Range(0, total_seq_len_0d, 1, _outputs=["query_range"])
+        total_seq_len_for_query = op.Reshape(total_seq_len_0d, [-1], allowzero=0, _outputs=["total_seq_len_for_query"])
+        total_seq_len_for_batch = op.Reshape(total_seq_len_0d, [-1], allowzero=0, _outputs=["total_seq_len_for_batch"])
 
         # BRANCH A: KV Range - Creates tensor with KV positions [1, 1, seq_len, 1]
         batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"])

From fbb191a4c396803d5864ee13fa6e6ab08d80f6e1 Mon Sep 17 00:00:00 2001
From: Tommaso
Date: Fri, 1 Aug 2025 16:37:36 +0000
Subject: [PATCH 23/26] Added more snake casing and removed unneeded code

---
 onnxscript/rewriter/ort_fusions/gqa.py | 19 +++++--------------
 1 file changed, 5 insertions(+), 14 deletions(-)

diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py
index 2c704188ad..6a4069984c 100644
--- a/onnxscript/rewriter/ort_fusions/gqa.py
+++ b/onnxscript/rewriter/ort_fusions/gqa.py
@@ -286,11 +286,11 @@ def rewrite(
     ):
         # Construct total_seq_length_int32 and seqlens_k
         total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
-        one_0D = op.Constant(value_int=1)
-        one_0D_int32 = op.Cast(one_0D, to=ir.DataType.INT32)
-        seqlens_k_0D = op.Sub(total_seq_length_int32, one_0D_int32)
-        zero_1D = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
-        seqlens_k = op.Unsqueeze(seqlens_k_0D, zero_1D)
+        one_0d = op.Constant(value_int=1)
+        one_0d_int32 = op.Cast(one_0d, to=ir.DataType.INT32)
+        seqlens_k_0d = op.Sub(total_seq_length_int32, one_0d_int32)
+        zero_1d = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
+        seqlens_k = op.Unsqueeze(seqlens_k_0d, zero_1d)
 
         gqa_node = attn_output.producer()
         assert len(gqa_node.inputs) == 12, (
@@ -314,15 +314,6 @@ def rewrite(
             *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
         )
 
-
-_basic_gqa_rule = GroupQueryAttention.rule()
-_gqa_causal_mask_rule = GQACausalMask.rule()
-
-gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule])
-
-fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
-
-
 class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
     """
     LongRoPeGQACausalMask is a specialized version of GQACausalMask that handles

From f295bc5dde501ea210c0c5060145d6c7f4b674cd Mon Sep 17 00:00:00 2001
From: Tommaso
Date: Fri, 1 Aug 2025 17:06:06 +0000
Subject: [PATCH 24/26] Moved get_mask_key method to module level and used IR value directly

---
 onnxscript/rewriter/ort_fusions/gqa.py | 17 ++++++++---------
 1 file changed, 8 insertions(+), 9 deletions(-)

diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py
index 6a4069984c..90db62d24b 100644
--- a/onnxscript/rewriter/ort_fusions/gqa.py
+++ b/onnxscript/rewriter/ort_fusions/gqa.py
@@ -31,7 +31,6 @@
 Dim = Union[int, ir.SymbolicDim]
 
-
 def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111):
     seq_len = op.Shape(input_ids, end=2, start=1)
     seq_len_0d = op.Squeeze(seq_len)
@@ -314,6 +313,13 @@ def rewrite(
             *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
         )
 
+def _get_mask_key(attention_mask):
+    """
+    Generate a cache key for the given attention_mask.
+    The ir.Value itself is used as the key (it hashes by object identity),
+    so all layers that share the same attention_mask reuse one cache entry.
+    """
+    return attention_mask
+
 class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
     """
     LongRoPeGQACausalMask is a specialized version of GQACausalMask that handles
@@ -325,19 +331,12 @@ def __init__(self):
         super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
         self._mask_cache = {}
 
-    def _get_mask_key(self, attention_mask):
-        """
-        Generate a unique key for the mask based on input_ids and past_kv_cache.
-        This is used to cache the mask to avoid recomputation.
-        """
-        return (id(attention_mask))
-
     def compute_mask(self, op, attention_mask):
         """
         Computes the total_seq_length_int32 and seqlens_k_int32 based on the attention_mask,
         caching results to avoid recomputation at each layer.
         """
-        mask_key = self._get_mask_key(attention_mask)
+        mask_key = _get_mask_key(attention_mask)
 
         if mask_key in self._mask_cache:
             total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key]

From 0334bb1de5a10ccaf752c4aad9a56b5e9ccf8b32 Mon Sep 17 00:00:00 2001
From: Tommaso
Date: Fri, 1 Aug 2025 17:09:32 +0000
Subject: [PATCH 25/26] Added cleanup method for the attention mask cache

---
 onnxscript/rewriter/ort_fusions/gqa.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/onnxscript/rewriter/ort_fusions/gqa.py b/onnxscript/rewriter/ort_fusions/gqa.py
index 90db62d24b..c25ee56799 100644
--- a/onnxscript/rewriter/ort_fusions/gqa.py
+++ b/onnxscript/rewriter/ort_fusions/gqa.py
@@ -331,6 +331,9 @@ def __init__(self):
         super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
         self._mask_cache = {}
 
+    def cleanup(self):
+        self._mask_cache.clear()
+
     def compute_mask(self, op, attention_mask):

From 74e8e246bdcdb6a7497d59eccf6d5857d60a84f2 Mon Sep 17 00:00:00 2001
From: Tommaso
Date: Fri, 1 Aug 2025 17:25:53 +0000
Subject: [PATCH 26/26] Added LongRoPE GQA Causal Mask Fusion as a Separate
 Module

---
 .../rewriter/ort_fusions/longrope_gqa.py | 485 ++++++++++++++++++
 1 file changed, 485 insertions(+)
 create mode 100644 onnxscript/rewriter/ort_fusions/longrope_gqa.py

diff --git a/onnxscript/rewriter/ort_fusions/longrope_gqa.py b/onnxscript/rewriter/ort_fusions/longrope_gqa.py
new file mode 100644
index 0000000000..effc023291
--- /dev/null
+++ b/onnxscript/rewriter/ort_fusions/longrope_gqa.py
@@ -0,0 +1,485 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+from typing import Sequence, Union
+
+import numpy as np
+import onnx_ir as ir
+
+import onnxscript.rewriter._fusion_utils as _fusion_utils
+from onnxscript.rewriter import _ir_utils, pattern
+
+"""
+GroupQueryAttention: This generalizes MHA by allowing the number of heads to be different
+for query and key/value.
+
+We use the following abbreviations for the dimensions:
+B: Batch size
+S: Sequence length (for current query/key/value)
+
+Hkv: number of heads for key/value
+G: number of groups
+H: number of heads = G * Hkv
+
+Dh: head size or embedding dimension per head
+D: input embedding dimension (hidden size) = H * Dh
+Dkv: key/value hidden size = Hkv * Dh
+
+T: total sequence length (after concatenation of past and current key/value)
+"""
+
+Dim = Union[int, ir.SymbolicDim]
+
+def causal_mask_pattern(op, input_ids, past_kv_cache, shape_B111):
+    seq_len = op.Shape(input_ids, end=2, start=1)
+    seq_len_0d = op.Squeeze(seq_len)
+
+    past_seq_len = op.Shape(past_kv_cache, end=3, start=2)
+    past_seq_len_0d = op.Squeeze(past_seq_len)
+
+    total_seq_len_0d = op.Add(past_seq_len_0d, seq_len_0d)
+    total_seq_len = op.Reshape(total_seq_len_0d, [-1])
+
+    # The Phi modeling code generates the following +1 as the target-length, which seems
+    # unnecessary in this context. But using it for pattern-matching against
+    # generated onnx model.
+    total_seq_len_plus_1_0d = op.Add(total_seq_len_0d, 1)
+    total_seq_len_plus_1 = op.Reshape(total_seq_len_plus_1_0d, [-1])
+
+    current_range = op.Range(past_seq_len_0d, total_seq_len_0d, 1)
+    mask_shape = op.Concat(seq_len, total_seq_len_plus_1, axis=0)
+    min_float32 = float(np.finfo(np.float32).min)
+    mask_all_min = op.Expand(min_float32, mask_shape)
+    total_range_as_row = op.Range(0, total_seq_len_plus_1_0d, 1)
+    current_range_as_column = op.Reshape(current_range, [-1, 1])
+    boolean_mask = op.Greater(total_range_as_row, current_range_as_column)
+    float_0_1_mask = op.Cast(boolean_mask, to=1)
+    float_0_min_mask = op.Mul(mask_all_min, float_0_1_mask)
+    mask_4d = op.Unsqueeze(float_0_min_mask, [0, 1])
+    mask_B1ST_plus = op.Expand(mask_4d, shape_B111)
+
+    # Get rid of the extra +1 added above: total_seq_len is enough, no
+    # need for total_seq_len+1.
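+    # The Slice keeps columns [0, total_seq_len) along axis 3 (step 1),
+    # dropping the extra final column introduced by the +1 above.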
+    mask_B1ST = op.Slice(mask_B1ST_plus, [0], total_seq_len, [3], [1])
+    return mask_B1ST
+
+
+class GroupQueryAttention(pattern.RewriteRuleClassBase):
+    def __init__(self):
+        super().__init__("GQA", remove_nodes=False)
+
+    def pattern(
+        self,
+        op,
+        query_BSD,
+        key_BSDkv,
+        value_BSDkv,
+        past_key,
+        past_value,
+        position_ids_q,
+        position_ids_k,
+        cos,
+        sin,
+        mask,
+    ):
+        # Reshape query from (B, S, D) to (B, S, H, D/H)
+        query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"])
+        # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
+        query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
+
+        # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H)
+        key_BSHkvDh = op.Reshape(key_BSDkv, pattern.ANY_VALUE, _outputs=["key_BSHkvDh"])
+        # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
+        key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
+
+        # Reshape value from (B, S, Dkv) to (B, S, Hkv, D/H)
+        value_BSHkvDh = op.Reshape(value_BSDkv, pattern.ANY_VALUE, _outputs=["value_BSHkvDh"])
+        # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
+        value_BHkvSDh = op.Transpose(value_BSHkvDh, perm=[0, 2, 1, 3])
+
+        query_BHSDh_rope = op.RotaryEmbedding(
+            query_BHSDh,
+            position_ids_q,
+            cos,
+            sin,
+            _domain="com.microsoft",
+            _outputs=["query_BHSDh_rope"],
+        )
+        key_BHkvSDh_rope = op.RotaryEmbedding(
+            key_BHkvSDh,
+            position_ids_k,
+            cos,
+            sin,
+            _domain="com.microsoft",
+            _outputs=["key_BHkvSDh_rope"],
+        )
+
+        # Concatenate past_key cache and current key, expand across heads
+        # that share key/value.
+
+        key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
+        key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2)
+        key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE)
+        key_seq_BHTDh = op.Reshape(
+            key_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["key_seq_BHTDh"]
+        )
+
+        # Concatenate past_value cache and current value, expand across heads
+        # that share key/value.
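+        # Shape walk-through of the group expansion (illustrative numbers, not
+        # taken from the pattern itself: B=1, Hkv=8, G=4, Dh=96, so H = G*Hkv = 32):
+        #   Concat:    (B, Hkv, P, Dh) + (B, Hkv, S, Dh) -> (B, Hkv, T, Dh)
+        #   Unsqueeze: (B, Hkv, T, Dh)                   -> (B, Hkv, 1, T, Dh)
+        #   Expand:    (B, Hkv, 1, T, Dh)                -> (B, Hkv, G, T, Dh)
+        #   Reshape:   (B, Hkv, G, T, Dh)                -> (B, H, T, Dh)
+        # Each key/value head is thus replicated across its G query-head group.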
+        value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
+        value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2)
+        value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE)
+        value_seq_BHTDh = op.Reshape(
+            value_seq_BHkvGTDh, pattern.ANY_VALUE, _outputs=["value_seq_BHTDh"]
+        )
+
+        attention_BHSDh = op.SDPA(
+            query_BHSDh_rope,
+            key_seq_BHTDh,
+            value_seq_BHTDh,
+            mask,
+            key_format="BHSd",
+            _domain="ai.onnxruntime._fusion",
+        )
+
+        # Transpose attention back to (B, S, H, D/H)
+        attention_BSHDh = op.Transpose(attention_BHSDh, perm=[0, 2, 1, 3])
+        # Reshape back to (B, S, D)
+        attention_BSD = op.Reshape(
+            attention_BSHDh, pattern.ANY_VALUE, _outputs=["attention_BSD"]
+        )
+        return attention_BSD, key_seq_BHkvTDh, value_seq_BHkvTDh
+
+    def check(
+        self,
+        op,
+        query_BSD,
+        key_BSDkv,
+        value_BSDkv,
+        past_key,
+        past_value,
+        query_BHSDh_rope,
+        key_BHkvSDh_rope,
+        query_BSHDh,
+        key_BSHkvDh,
+        **_,
+    ):
+        bindings: dict[str, Dim] = {}
+
+        def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
+            return not _fusion_utils._check_shape(bindings, val, dims)
+
+        if no_match(query_BSD, ["B", "S", "D"]):
+            return False
+        if no_match(key_BSDkv, ["B", "S", "Dkv"]):
+            return False
+        if no_match(value_BSDkv, ["B", "S", "Dkv"]):
+            return False
+
+        if no_match(past_key, ["B", "Hkv", "P", "Dh"]):
+            return False
+        if no_match(past_value, ["B", "Hkv", "P", "Dv"]):
+            return False
+
+        # TODO: verify Reshapes:
+        # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]:
+        # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]:
+        # or check Reshape's shape-input value
+
+        result = pattern.MatchResult()
+        num_heads = _ir_utils.get_dim(query_BSHDh, 2)
+        kv_num_heads = _ir_utils.get_dim(key_BSHkvDh, 2)
+        if not isinstance(num_heads, int):
+            return result.fail("Unable to determine num_heads value", query_BSHDh)
+        if not isinstance(kv_num_heads, int):
+            return result.fail("Unable to determine kv_num_heads value", key_BSHkvDh)
+        self.num_heads = num_heads
+        self.kv_num_heads = kv_num_heads
+
+        # Rotary embedding attributes
+        query_rotary_attributes = query_BHSDh_rope.producer().attributes
+        key_rotary_attributes = key_BHkvSDh_rope.producer().attributes
+        query_interleaved = query_rotary_attributes.get_int("interleaved", 0)
+        key_interleaved = key_rotary_attributes.get_int("interleaved", 0)
+        if query_interleaved != key_interleaved:
+            return pattern.MatchResult().fail(
+                "Rotary embedding interleaved attribute mismatch",
+                [query_BHSDh_rope.producer(), key_BHkvSDh_rope.producer()],
+            )
+        self._interleaved = query_interleaved
+
+        return True
+
+    def rewrite(
+        self,
+        op,
+        query_BSD,
+        key_BSDkv,
+        value_BSDkv,
+        past_key,
+        past_value,
+        position_ids_q,
+        position_ids_k,
+        cos,
+        sin,
+        mask,
+        **_,
+    ):
+        return op.GQA(
+            mask,
+            position_ids_k,
+            position_ids_q,
+            query_BSD,
+            key_BSDkv,
+            value_BSDkv,
+            past_key,
+            past_value,
+            None,  # seqlens_k
+            None,  # total_seq_length_int32
+            cos,
+            sin,
+            num_heads=self.num_heads,
+            kv_num_heads=self.kv_num_heads,
+            do_rotary=1,
+            rotary_interleaved=self._interleaved,
+            # skipped optional attributes: local_window_size, scale, smooth_softmax, softcap
+            _domain="ai.onnxruntime._fusion",
+            _outputs=3,
+        )
+
+
+class GQACausalMask(pattern.RewriteRuleClassBase):
+    def __init__(self):
+        super().__init__("GQACausalMask", remove_nodes=False)
+
+    def pattern(
+        self,
+        op,
+        mask,
+        input_ids,
+        some_kv_cache,
+        shape_B111,
+        past_seq_length,
+        total_seq_length,
+    ):
+        mask = causal_mask_pattern(op, input_ids, some_kv_cache, shape_B111)
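+        # Position ids for the new tokens span [past_seq_length, total_seq_length);
+        # Unsqueeze lifts the 1-D range to shape (1, S) for the query/key position ids.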
+        position_ids = op.Range(past_seq_length, total_seq_length, 1)
+        position_ids_q = op.Unsqueeze(position_ids, [0])
+        position_ids_k = op.Unsqueeze(position_ids, [0])
+        return op.GQA(
+            mask,
+            position_ids_k,
+            position_ids_q,
+            _allow_other_inputs=True,
+            _domain="ai.onnxruntime._fusion",
+            _outputs=["attn_output", "key_seq", "value_seq"],
+        )
+
+    def rewrite(
+        self,
+        op,
+        total_seq_length,
+        attn_output,
+        **_,
+    ):
+        # Construct total_seq_length_int32 and seqlens_k
+        total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32)
+        one_0d = op.Constant(value_int=1)
+        one_0d_int32 = op.Cast(one_0d, to=ir.DataType.INT32)
+        seqlens_k_0d = op.Sub(total_seq_length_int32, one_0d_int32)
+        zero_1d = op.Constant(value_int=0, dtype=ir.DataType.INT64, shape=[1])
+        seqlens_k = op.Unsqueeze(seqlens_k_0d, zero_1d)
+
+        gqa_node = attn_output.producer()
+        assert len(gqa_node.inputs) == 12, (
+            f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}"
+        )
+        query, key, value, past_key, past_value = gqa_node.inputs[3:8]
+        cos, sin = gqa_node.inputs[10:12]
+        updated_inputs = [
+            query,
+            key,
+            value,
+            past_key,
+            past_value,
+            seqlens_k,
+            total_seq_length_int32,
+            cos,
+            sin,
+        ]
+        attributes = gqa_node.attributes
+        return op.GroupQueryAttention(
+            *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
+        )
+
+def _get_mask_key(attention_mask):
+    """
+    Generate a cache key for the given attention_mask.
+    The ir.Value itself is used as the key (it hashes by object identity),
+    so all layers that share the same attention_mask reuse one cache entry.
+    """
+    return attention_mask
+
+class LongRoPeGQACausalMask(pattern.RewriteRuleClassBase):
+    """
+    LongRoPeGQACausalMask is a specialized version of GQACausalMask that handles
+    the LongRoPe GQA fusion. It computes the causal mask for Group Query Attention
+    with LongRoPe (Long Range Rotary Position Embedding) and caches the mask to
+    avoid recomputation at each layer.
+    """
+    def __init__(self):
+        super().__init__("LongRoPeGQACausalMask", remove_nodes=False)
+        self._mask_cache = {}
+
+    def cleanup(self):
+        self._mask_cache.clear()
+
+    def compute_mask(self, op, attention_mask):
+        """
+        Computes the total_seq_length_int32 and seqlens_k_int32 based on the attention_mask,
+        caching results to avoid recomputation at each layer.
+        """
+        mask_key = _get_mask_key(attention_mask)
+
+        if mask_key in self._mask_cache:
+            total_seq_length_int32, seqlens_k_int32 = self._mask_cache[mask_key]
+        else:
+            # Construct total_seq_length_int32 and seqlens_k.
+            # For a left-to-right mask row like [1, 1, 1, 0]: total_seq_length = 4 and
+            # ReduceSum = 3, so seqlens_k = 3 - 1 = 2, the index of the last valid key.
+            attention_shape = op.Shape(attention_mask, _outputs=["attention_shape"])
+            total_seq_length = op.Gather(attention_shape, op.Constant(value=ir.tensor(1, ir.DataType.INT64)), axis=0, _outputs=["total_seq_length"])
+            reduced_attention = op.ReduceSum(attention_mask, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["reduced_attention"])
+            sub_reduced_attention = op.Sub(reduced_attention, op.Constant(value=ir.tensor([1], ir.DataType.INT64)), _outputs=["sub_reduced_attention"])
+            total_seq_length_int32 = op.Cast(total_seq_length, to=ir.DataType.INT32, _outputs=["total_seq_length_int32"])
+            seqlens_k_int32 = op.Cast(sub_reduced_attention, to=ir.DataType.INT32, _outputs=["seqlens_k_int32"])
+            self._mask_cache[mask_key] = (total_seq_length_int32, seqlens_k_int32)
+
+        return self._mask_cache[mask_key]
+
+    def pattern(
+        self,
+        op,
+        input_ids,
+        past_kv_cache_1,
+        past_kv_cache_2,
+        attention_mask,
+    ):
+        """
+        Pattern for LongRoPe GQA Causal Mask.
+        This pattern computes the causal mask for Group Query Attention with LongRoPe.
+        It constructs the mask based on input_ids and past_kv_cache, and handles the
+        expansion of the mask across the batch and sequence dimensions.
+        """
+        seq_len = op.Shape(input_ids, end=2, start=1, _outputs=["seq_len"])
+        seq_len_0d = op.Squeeze(seq_len, _outputs=["seq_len_0d"])
+        past_seq_len = op.Shape(past_kv_cache_1, end=3, start=2, _outputs=["past_seq_len"])
+        past_seq_len_0d = op.Squeeze(past_seq_len, _outputs=["past_seq_len_0d"])
+        total_seq_len_0d = op.Add(past_seq_len_0d, seq_len_0d, _outputs=["total_seq_len_0d"])
+
+        # Create ranges for different dimensions
+        kv_range = op.Range(past_seq_len_0d, total_seq_len_0d, 1, _outputs=["kv_range"])
+        total_seq_len_for_kv = op.Reshape(total_seq_len_0d, [-1], allowzero=0, _outputs=["total_seq_len_for_kv"])
+        query_range = op.Range(0, total_seq_len_0d, 1, _outputs=["query_range"])
+        total_seq_len_for_query = op.Reshape(total_seq_len_0d, [-1], allowzero=0, _outputs=["total_seq_len_for_query"])
+        total_seq_len_for_batch = op.Reshape(total_seq_len_0d, [-1], allowzero=0, _outputs=["total_seq_len_for_batch"])
+
+        # BRANCH A: KV Range - Creates tensor with KV positions [1, 1, seq_len, 1]
+        batch_size = op.Shape(past_kv_cache_2, end=1, start=0, _outputs=["batch_size"])
+        kv_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_kv, axis=0, _outputs=["kv_mask_shape"])
+        kv_mask_shape_abs = op.Abs(kv_mask_shape, _outputs=["kv_mask_shape_abs"])
+        reshaped_kv_range = op.Reshape(kv_range, [1, 1, -1, 1], allowzero=1, _outputs=["reshaped_kv_range"])
+        expanded_kv_range = op.Expand(reshaped_kv_range, kv_mask_shape_abs, _outputs=["expanded_kv_range"])
+
+        # BRANCH B: Query Range - Creates tensor with query positions [1, 1, 1, total_seq_len]
+        query_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_query, axis=0, _outputs=["query_mask_shape"])
+        query_mask_shape_abs = op.Abs(query_mask_shape, _outputs=["query_mask_shape_abs"])
+        reshaped_query_range = op.Reshape(query_range, [1, 1, 1, -1], allowzero=1, _outputs=["reshaped_query_range"])
+        expanded_query_range = op.Expand(reshaped_query_range, query_mask_shape_abs, _outputs=["expanded_query_range"])
+
+        # BRANCH C: Batch Range - Creates tensor with batch indices [batch_size, 1, 1, 1]
+        batch_mask_shape = op.Concat(batch_size, [1], seq_len, total_seq_len_for_batch, axis=0, _outputs=["batch_mask_shape"])
+        batch_mask_shape_abs = op.Abs(batch_mask_shape, _outputs=["batch_mask_shape_abs"])
+        batch_size_squeezed = op.Squeeze(batch_size, _outputs=["batch_size_squeezed"])
+        batch_range = op.Range(0, batch_size_squeezed, 1, _outputs=["batch_range"])
+        reshaped_batch_range = op.Reshape(batch_range, [-1, 1, 1, 1], allowzero=1, _outputs=["reshaped_batch_range"])
+        expanded_batch_range = op.Expand(reshaped_batch_range, batch_mask_shape_abs, _outputs=["expanded_batch_range"])
+
+        # Combine KV/Query Ranges for Sliding Window Mask. The combined condition is
+        # expanded_kv_range - 262144 < expanded_query_range <= expanded_kv_range;
+        # 262144 (2**18) appears to be the sliding-window offset baked into the
+        # exported LongRoPe model.
+        kv_range_offset = op.Sub(expanded_kv_range, 262144, _outputs=["kv_range_offset"])
+        query_gt_kv_offset = op.Greater(expanded_query_range, kv_range_offset, _outputs=["query_gt_kv_offset"])
+        query_gt_kv_offset_mask = op.And(True, query_gt_kv_offset, _outputs=["query_gt_kv_offset_mask"])
+        query_le_kv = op.LessOrEqual(expanded_query_range, expanded_kv_range, _outputs=["query_le_kv"])
+        sliding_window_mask = op.And(query_gt_kv_offset_mask, query_le_kv, _outputs=["sliding_window_mask"])
+        sliding_window_mask_final = op.And(True, sliding_window_mask, _outputs=["sliding_window_mask_final"])
+
+        # Combine Query/Batch Ranges for Attention Mask Lookup
+        unsqueezed_query_range = op.Unsqueeze(expanded_query_range, [-1], _outputs=["unsqueezed_query_range"])
+        unsqueezed_batch_range = op.Unsqueeze(expanded_batch_range, [-1], _outputs=["unsqueezed_batch_range"])
+        batch_query_indices = op.Concat(unsqueezed_batch_range, unsqueezed_query_range, axis=-1, _outputs=["batch_query_indices"])
+        attention_mask_bool = op.Cast(attention_mask, to=ir.DataType.BOOL, _outputs=["attention_mask_bool"])
+        attention_lookup = op.GatherND(attention_mask_bool, batch_query_indices, batch_dims=0, _outputs=["attention_lookup"])
+
+        # Final Mask Combination
+        final_attention_mask = op.And(sliding_window_mask_final, attention_lookup, _outputs=["final_attention_mask"])
+        inverted_mask = op.Not(final_attention_mask, _outputs=["inverted_mask"])
+        mask_fp32 = op.Cast(inverted_mask, to=ir.DataType.FLOAT, _outputs=["mask_fp32"])
+        scaled_mask = op.Mul(mask_fp32, pattern.ANY_VALUE)
+
+        # Propagation to GQA
+        sliced_mask = op.Slice(scaled_mask, [0], pattern.ANY_VALUE, [3], [1], _outputs=["sliced_mask"])
+
+        gqa_input = pattern.OrValue([sliced_mask, scaled_mask])
+
+        return op.GQA(
+            gqa_input,
+            _allow_other_inputs=True,
+            _domain="ai.onnxruntime._fusion",
+            _outputs=["attn_output", "key_seq", "value_seq"],
+        )
+
+    def rewrite(
+        self,
+        op,
+        attention_mask,
+        attn_output,
+        **_,
+    ):
+        """
+        Rewrite the GQA node with the new mask information.
+        This method computes the total sequence length and seqlens_k based on the
+        attention_mask and rewrites the GQA node to use these values.
+        """
+        # Compute total_seq_length_int32 and seqlens_k_int32
+        total_seq_length_int32, seqlens_k_int32 = self.compute_mask(op, attention_mask)
+
+        gqa_node = attn_output.producer()
+        assert len(gqa_node.inputs) == 12, (
+            f"Expected 12 inputs for GQA node, got {len(gqa_node.inputs)}"
+        )
+        query, key, value, past_key, past_value = gqa_node.inputs[3:8]
+        cos, sin = gqa_node.inputs[10:12]
+        updated_inputs = [
+            query,
+            key,
+            value,
+            past_key,
+            past_value,
+            seqlens_k_int32,
+            total_seq_length_int32,
+            cos,
+            sin,
+        ]
+        attributes = gqa_node.attributes
+        return op.GroupQueryAttention(
+            *updated_inputs, **attributes, _domain="com.microsoft", _outputs=3
+        )
+
+_basic_gqa_rule = GroupQueryAttention.rule()
+_gqa_causal_mask_rule = GQACausalMask.rule()
+_longrope_gqa_causal_mask_rule = LongRoPeGQACausalMask.rule()
+
+gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule, _gqa_causal_mask_rule, _longrope_gqa_causal_mask_rule])
+
+fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
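+
+# Usage sketch (illustrative, not part of the rewrite rules above). Assuming
+# `fuse_gqa` follows the convention of the other ort_fusions entry points and
+# takes an onnx_ir model; the model file names here are hypothetical:
+#
+#     import onnx_ir as ir
+#     from onnxscript.rewriter.ort_fusions.longrope_gqa import fuse_gqa
+#
+#     model = ir.load("phi3_longrope.onnx")
+#     fuse_gqa(model)  # applies the GQA, GQACausalMask, and LongRoPeGQACausalMask rules
+#     ir.save(model, "phi3_longrope_fused.onnx")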