From a84ee5a6117acf8edd40ad359525e75faca5af5d Mon Sep 17 00:00:00 2001 From: James Xu Date: Mon, 11 Nov 2024 17:59:46 -0500 Subject: [PATCH 1/4] Add RoPE comparison across flashinfer and vLLM --- .../compare_flashinfer_vllm_rope.py | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) create mode 100644 scripts/playground/compare_flashinfer_vllm_rope.py diff --git a/scripts/playground/compare_flashinfer_vllm_rope.py b/scripts/playground/compare_flashinfer_vllm_rope.py new file mode 100644 index 00000000000..4a003ad8ddb --- /dev/null +++ b/scripts/playground/compare_flashinfer_vllm_rope.py @@ -0,0 +1,88 @@ +import torch +from einops import rearrange + + +def flashinfer_rope( + q: torch.tensor, + k: torch.tensor, + positions: torch.tensor, + rotary_dim: int, + rope_theta: int, +): + from flashinfer.rope import apply_rope_pos_ids + + q_rope, k_rope = apply_rope_pos_ids( + q, + k, + pos_ids=positions, + rotary_dim=rotary_dim, + rope_theta=rope_theta, + interleave=False, + ) + return q_rope, k_rope + + +def vllm_rope( + q: torch.tensor, + k: torch.tensor, + positions: torch.tensor, + head_size: int, + rotary_dim: int, + rope_theta: int, + max_position: int, +): + from vllm.model_executor.layers.rotary_embedding import get_rope + + rotary_emb = get_rope( + head_size=head_size, + rotary_dim=rotary_dim, + max_position=max_position, + base=rope_theta, + is_neox_style=True, + ) + + q_rope, k_rope = rotary_emb(positions, q, k) + return q_rope, k_rope + + +def main(): + batch_size, seq_len, head_size = 2, 10, 64 + rotary_dim = head_size // 2 + rope_theta = 1e4 + + torch.cuda.manual_seed_all(42) + q = torch.rand((batch_size, seq_len, head_size), dtype=torch.float16, device="cuda") + k = torch.rand((batch_size, seq_len, head_size), dtype=torch.float16, device="cuda") + + max_position = seq_len + positions = torch.randint(0, seq_len, (batch_size + 1,), device="cuda") + + # (batch_size, seq_len, head_size) -> flashinfer input shape (nnz, num_heads, head_dim) + flashinfer_q_rope, flashinfer_k_rope = flashinfer_rope( + rearrange(q, "b s h -> (b s) 1 h"), + rearrange(k, "b s h -> (b s) 1 h"), + positions, + rotary_dim, + rope_theta, + ) + + # flashinfer output shape (nnz, num_heads, head_dim) -> (batch_size, seq_len, head_size) + flashinfer_q_rope, flashinfer_k_rope = rearrange( + flashinfer_q_rope, "(b s) 1 h -> b s h", b=batch_size, s=seq_len + ), rearrange(flashinfer_k_rope, "(b s) 1 h -> b s h", b=batch_size, s=seq_len) + + # looks like this is doing something in-place? 
+ vllm_q_rope, vllm_k_rope = vllm_rope( + q, k, positions, head_size, rotary_dim, rope_theta, max_position + ) + + # Mismatched elements: 2 / 1280 (0.2%) + # Greatest absolute difference: 0.0001220703125 at index (0, 1, 4) (up to 1e-05 allowed) + # Greatest relative difference: 0.017852783203125 at index (0, 2, 6) (up to 0.001 allowed) + + torch.testing.assert_close(flashinfer_q_rope, vllm_q_rope, atol=2e-4, rtol=2e-1) + torch.testing.assert_close(flashinfer_k_rope, vllm_k_rope, atol=2e-4, rtol=2e-1) + + +if __name__ == "__main__": + main() From 0d70c4367bcd74aaf2b3ac374a54f4a84436c309 Mon Sep 17 00:00:00 2001 From: James Xu Date: Thu, 14 Nov 2024 15:38:34 -0500 Subject: [PATCH 2/4] feat: use FlashInfer RoPE (llama) --- python/sglang/srt/layers/rotary_embedding.py | 327 +++++++++++++++++++ python/sglang/srt/models/llama.py | 2 +- 2 files changed, 328 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 80158573bd6..2cfb5978935 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -15,6 +15,39 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +from einops import rearrange +from vllm.model_executor.custom_op import CustomOp + +from sglang.srt.layers.custom_op_util import register_custom_op + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) class MRotaryEmbedding: @@ -110,3 +143,297 @@ def get_next_input_positions( ) for _ in range(3) ] + + +@register_custom_op("sglang_rope") +class RotaryEmbedding(CustomOp): + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. 
+ inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from flashinfer.rope import apply_rope_pos_ids_inplace + + if offsets is not None: + positions = positions + offsets + seq_len, num_q_heads, num_k_heads = ( + positions.shape[0], + query.shape[1] // self.head_size, + key.shape[1] // self.head_size, + ) + + # (seq_len, num_heads * head_dim) -> flashinfer input shape (nnz=seq_len, num_heads, head_dim) + flashinfer_query, flashinfer_key = rearrange( + query.type(torch.float16), + "s (n_h h_d) -> s n_h h_d", + n_h=num_q_heads, + h_d=self.head_size, + ), rearrange( + key.type(torch.float16), + "s (n_h h_d) -> s n_h h_d", + n_h=num_k_heads, + h_d=self.head_size, + ) + apply_rope_pos_ids_inplace( + flashinfer_query, + flashinfer_key, + pos_ids=positions, + rotary_dim=self.rotary_dim, + rope_theta=self.base, + interleave=(not self.is_neox_style), + ) + + # flashinfer output shape (nnz=seq_len, num_heads, head_dim) -> (seq_len, num_heads * head_dim) + return rearrange( + flashinfer_query.type(self.dtype), "s n_h h_d -> s (n_h h_d)" + ), rearrange(flashinfer_key.type(self.dtype), "s n_h h_d -> s (n_h h_d)") + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} + + +class Llama3RotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor 
= high_freq_factor + self.orig_max_position = orig_max_position + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = ( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling_args, + dtype, + ) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + if rope_scaling is None: + rotary_emb = RotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, dtype + ) + else: + scaling_type = rope_scaling["rope_type"] + + if scaling_type == "llama3": + scaling_factor = rope_scaling["factor"] + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + scaling_factor, + low_freq_factor, + high_freq_factor, + original_max_position, + ) + elif scaling_type == "default": + if "mrope_section" in rope_scaling: + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + ) + # else: + # rotary_emb = RotaryEmbedding( + # head_size, + # rotary_dim, + # max_position, + # base, + # is_neox_style, + # dtype, + # ) + # elif scaling_type == "linear": + # scaling_factor = rope_scaling["factor"] + # rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, + # max_position, base, + # is_neox_style, + # scaling_factor, dtype) + # elif scaling_type == "dynamic": + # scaling_factor = rope_scaling["factor"] + # rotary_emb = DynamicNTKScalingRotaryEmbedding( + # head_size, rotary_dim, max_position, base, is_neox_style, + # scaling_factor, dtype) + # elif scaling_type == "yarn": + # scaling_factor = rope_scaling["factor"] + # original_max_position = rope_scaling[ + # "original_max_position_embeddings"] + # extra_kwargs = { + # k: v + # for k, v in rope_scaling.items() + # if k in ("extrapolation_factor", "attn_factor", "beta_fast", + # "beta_slow") + # } + # rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, + # original_max_position, + # base, is_neox_style, + # scaling_factor, dtype, + # **extra_kwargs) + # elif scaling_type == "deepseek_yarn": + # scaling_factor = rope_scaling["factor"] + # original_max_position = rope_scaling[ + # "original_max_position_embeddings"] + # # assert max_position == original_max_position * scaling_factor + # extra_kwargs = { + # k: v + # for k, v in rope_scaling.items() + # if k in ("extrapolation_factor", "attn_factor", "beta_fast", + # "beta_slow", "mscale", "mscale_all_dim") + # } + # rotary_emb = DeepseekScalingRotaryEmbedding( + # head_size, rotary_dim, original_max_position, base, + # is_neox_style, scaling_factor, dtype, **extra_kwargs) + 
# elif scaling_type == "longrope": + # short_factor = rope_scaling["short_factor"] + # long_factor = rope_scaling["long_factor"] + # original_max_position = rope_scaling[ + # "original_max_position_embeddings"] + # extra_kwargs = { + # k: v + # for k, v in rope_scaling.items() + # if k in ("short_mscale", "long_mscale") + # } + # rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + # head_size, rotary_dim, max_position, original_max_position, + # base, is_neox_style, dtype, short_factor, long_factor, + # **extra_kwargs) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + _ROPE_DICT[key] = rotary_emb + return rotary_emb diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index e3e44ea6ffc..857256230d9 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -23,7 +23,6 @@ from torch import nn from transformers import LlamaConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm @@ -36,6 +35,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, From 82af8ae3bf7a28415f5dee136f88ab0af9e22a3e Mon Sep 17 00:00:00 2001 From: James Xu Date: Tue, 24 Dec 2024 06:57:32 +0000 Subject: [PATCH 3/4] Use int32 positions with flashinfer --- python/sglang/srt/layers/rotary_embedding.py | 3 ++- scripts/playground/compare_flashinfer_vllm_rope.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 2cfb5978935..3d90b959a04 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -253,10 +253,11 @@ def forward_cuda( n_h=num_k_heads, h_d=self.head_size, ) + apply_rope_pos_ids_inplace( flashinfer_query, flashinfer_key, - pos_ids=positions, + pos_ids=positions.int(), rotary_dim=self.rotary_dim, rope_theta=self.base, interleave=(not self.is_neox_style), diff --git a/scripts/playground/compare_flashinfer_vllm_rope.py b/scripts/playground/compare_flashinfer_vllm_rope.py index 4a003ad8ddb..2aaf683035b 100644 --- a/scripts/playground/compare_flashinfer_vllm_rope.py +++ b/scripts/playground/compare_flashinfer_vllm_rope.py @@ -47,7 +47,7 @@ def vllm_rope( def main(): batch_size, seq_len, head_size = 2, 10, 64 - rotary_dim = head_size // 2 + rotary_dim = head_size rope_theta = 1e4 torch.cuda.manual_seed_all(42) @@ -61,7 +61,7 @@ def main(): flashinfer_q_rope, flashinfer_k_rope = flashinfer_rope( rearrange(q, "b s h -> (b s) 1 h"), rearrange(k, "b s h -> (b s) 1 h"), - positions, + positions.int(), rotary_dim, rope_theta, ) From d0cfd4a96e404e34a89a72a331b26f4953eb9a1e Mon Sep 17 00:00:00 2001 From: James Xu Date: Sat, 28 Dec 2024 00:07:27 +0000 Subject: [PATCH 4/4] Fix reference script positions after flashinfer bump --- scripts/playground/compare_flashinfer_vllm_rope.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scripts/playground/compare_flashinfer_vllm_rope.py b/scripts/playground/compare_flashinfer_vllm_rope.py index 2aaf683035b..9a40d5d5468 100644 --- 
a/scripts/playground/compare_flashinfer_vllm_rope.py
+++ b/scripts/playground/compare_flashinfer_vllm_rope.py
@@ -55,13 +55,13 @@ def main():
     k = torch.rand((batch_size, seq_len, head_size), dtype=torch.float16, device="cuda")

     max_position = seq_len
-    positions = torch.randint(0, seq_len, (batch_size + 1,), device="cuda")
+    positions = torch.randint(0, seq_len, (batch_size * seq_len,), device="cuda")

     # (batch_size, seq_len, head_size) -> flashinfer input shape (nnz, num_heads, head_dim)
     flashinfer_q_rope, flashinfer_k_rope = flashinfer_rope(
         rearrange(q, "b s h -> (b s) 1 h"),
         rearrange(k, "b s h -> (b s) 1 h"),
-        positions.int(),
+        positions.int(),  # flashinfer positions must be int32
         rotary_dim,
         rope_theta,
     )
@@ -76,9 +76,9 @@
         q, k, positions, head_size, rotary_dim, rope_theta, max_position
     )

-    # Mismatched elements: 2 / 1280 (0.2%)
-    # Greatest absolute difference: 0.0001220703125 at index (0, 1, 4) (up to 1e-05 allowed)
-    # Greatest relative difference: 0.017852783203125 at index (0, 2, 6) (up to 0.001 allowed)
+    # Mismatched elements: 35 / 1280 (2.7%)
+    # Greatest absolute difference: 0.000732421875 at index (1, 0, 7) (up to 1e-05 allowed)
+    # Greatest relative difference: 0.08514404296875 at index (0, 1, 35) (up to 0.001 allowed)

     torch.testing.assert_close(flashinfer_q_rope, vllm_q_rope, atol=2e-4, rtol=2e-1)
     torch.testing.assert_close(flashinfer_k_rope, vllm_k_rope, atol=2e-4, rtol=2e-1)
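
Note (not part of the patches above): in PATCH 2/4, Llama3RotaryEmbedding only stores scaling_factor, low_freq_factor, high_freq_factor, and orig_max_position; no _compute_inv_freq override appears in the hunk, so as shown the class computes the same frequencies as the base RotaryEmbedding. Below is a minimal sketch of the kind of override that llama3-style RoPE scaling normally relies on; the exact form is an assumption modeled on the common llama3 scaling recipe, not code taken from this series, and it assumes low_freq_factor != high_freq_factor.

import math

import torch

class Llama3RotaryEmbedding(RotaryEmbedding):  # RotaryEmbedding as added in PATCH 2/4
    # __init__ as in PATCH 2/4 ...

    def _compute_inv_freq(self, base):
        # Start from the unscaled frequencies of the base class.
        inv_freq = super()._compute_inv_freq(base)
        low_freq_wavelen = self.orig_max_position / self.low_freq_factor
        high_freq_wavelen = self.orig_max_position / self.high_freq_factor

        wavelen = 2 * math.pi / inv_freq
        # Interpolation weight for wavelengths between the two cutoffs
        # (assumes low_freq_factor != high_freq_factor).
        smooth = (self.orig_max_position / wavelen - self.low_freq_factor) / (
            self.high_freq_factor - self.low_freq_factor
        )
        return torch.where(
            wavelen < high_freq_wavelen,
            inv_freq,  # short wavelengths: keep as-is
            torch.where(
                wavelen > low_freq_wavelen,
                inv_freq / self.scaling_factor,  # long wavelengths: fully scaled
                (1 - smooth) * inv_freq / self.scaling_factor + smooth * inv_freq,
            ),
        )

With the values shipped by llama 3.1 checkpoints (factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_position_embeddings=8192), this keeps the high-frequency components untouched and divides the lowest frequencies by the scaling factor.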
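
Also not part of the patches: a quick way to exercise the new sglang get_rope end to end, comparing the PyTorch-native path against the flashinfer path added in PATCH 2/4. This is a hypothetical snippet; it assumes a CUDA device, flashinfer installed, and the patched sglang.srt.layers.rotary_embedding importable. The loose tolerances mirror the ones used in compare_flashinfer_vllm_rope.py, since the flashinfer path runs in float16.

import torch

from sglang.srt.layers.rotary_embedding import get_rope  # added in PATCH 2/4

seq_len, num_heads, head_size = 16, 4, 64
rope = get_rope(
    head_size=head_size,
    rotary_dim=head_size,
    max_position=4096,
    base=10000,
    is_neox_style=True,
    dtype=torch.float16,
).to("cuda")

positions = torch.arange(seq_len, device="cuda")
q = torch.rand(seq_len, num_heads * head_size, dtype=torch.float16, device="cuda")
k = torch.rand(seq_len, num_heads * head_size, dtype=torch.float16, device="cuda")

# forward_native is the pure-PyTorch reference; forward_cuda goes through
# flashinfer's in-place kernel, so pass clones to keep q and k intact.
q_ref, k_ref = rope.forward_native(positions, q.clone(), k.clone())
q_fi, k_fi = rope.forward_cuda(positions, q.clone(), k.clone())

torch.testing.assert_close(q_fi, q_ref, atol=2e-4, rtol=2e-1)
torch.testing.assert_close(k_fi, k_ref, atol=2e-4, rtol=2e-1)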