Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions atom/model_ops/base_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,8 @@ def fake_(
k: torch.Tensor,
v: torch.Tensor,
positions: torch.Tensor,
kv_cache: torch.Tensor,
kv_scale: torch.Tensor,
layer_name: str,
use_mla: bool,
qkv: torch.Tensor,
Expand All @@ -342,13 +344,19 @@ def fake_(
# Dynamo will not try to inspect any of the internal operations for prefill or decode
# This way, although attention operation is complicated,
# we can still capture the model's computation graph as a full-graph
@mark_spliting_op(is_custom=True, gen_fake=fake_, mutates_args=[])
@mark_spliting_op(
is_custom=True,
gen_fake=fake_,
mutates_args=["kv_cache", "kv_scale"],
)
def unified_attention_with_output_base(
q: torch.Tensor,
q_scale: Optional[torch.Tensor],
k: torch.Tensor,
v: torch.Tensor,
positions: torch.Tensor,
kv_cache: torch.Tensor,
kv_scale: torch.Tensor,
layer_name: str,
use_mla: bool,
qkv: torch.Tensor,
Expand All @@ -368,9 +376,11 @@ def unified_attention_with_output_base(
query=q,
key=k,
value=v,
position=positions,
positions=positions,
q_scale=q_scale,
qkv=qkv,
kv_cache=kv_cache,
kv_scale=kv_scale,
)


Expand All @@ -379,6 +389,7 @@ def linear_attention_with_output_base_fake(
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(core_attn_out)
Expand All @@ -387,19 +398,20 @@ def linear_attention_with_output_base_fake(
@mark_spliting_op(
is_custom=True,
gen_fake=linear_attention_with_output_base_fake,
mutates_args=[],
mutates_args=["kv_cache"],
)
def linear_attention_with_output_base(
mixed_qkv: torch.Tensor,
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
atom_config = get_current_atom_config()
self = atom_config.compilation_config.static_forward_context[layer_name]
ret = torch.empty_like(core_attn_out)
ret = self.impl.forward(mixed_qkv, b, a, ret, layer_name)
ret = self.impl.forward(mixed_qkv, b, a, ret, kv_cache, layer_name)
return ret


Expand Down Expand Up @@ -503,6 +515,7 @@ def __init__(
compilation_config = atom_config.compilation_config
default_name = f"Linear_{layer_num}"
self.layer_name = prefix if prefix is not None else default_name
self.kv_cache = torch.tensor([])
if self.layer_name in compilation_config.static_forward_context:
raise ValueError("Duplicate layer: {}".format(self.layer_name))
compilation_config.static_forward_context[self.layer_name] = self
Expand All @@ -515,6 +528,6 @@ def forward(
core_attn_out: torch.Tensor,
):
output = torch.ops.aiter.linear_attention_with_output_base(
mixed_qkv, b, a, core_attn_out, self.layer_name
mixed_qkv, b, a, core_attn_out, self.kv_cache, self.layer_name
)
return output
18 changes: 17 additions & 1 deletion atom/model_ops/paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,23 @@ def forward(
qkv: torch.Tensor = None,
**kwargs,
):
kv_cache = getattr(self.impl, "kv_cache", query.new_empty(0))
kv_scale_fn = getattr(self.impl, "_kv_scale_arg_for_forward", None)
kv_scale = (
kv_scale_fn(query, kv_cache)
if kv_scale_fn is not None
else query.new_empty(0)
)
output = torch.ops.aiter.unified_attention_with_output_base(
query, q_scale, key, value, positions, self.layer_name, self.use_mla, qkv
query,
q_scale,
key,
value,
positions,
kv_cache,
kv_scale,
self.layer_name,
self.use_mla,
qkv,
)
return output
49 changes: 37 additions & 12 deletions atom/model_ops/triton_fused_qkv_norm_rope_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,13 @@ def _fused_qkv_norm_rope_cache_kernel(
ROTARY_DIM: tl.constexpr,
ROTARY_DIM_HALF: tl.constexpr,
IS_FP8: tl.constexpr,
FP8_MAX: tl.constexpr,
# M-RoPE section boundaries (cumulative)
MROPE_S0: tl.constexpr = 0,
MROPE_S1: tl.constexpr = 0,
MROPE_SECTION_H: tl.constexpr = 0,
MROPE_SECTION_W: tl.constexpr = 0,
MROPE_INTERLEAVED: tl.constexpr = False,
IS_MROPE: tl.constexpr = False,
):
# Grid: (num_tokens * (num_heads + num_kv_heads),)
Expand Down Expand Up @@ -122,11 +126,16 @@ def _fused_qkv_norm_rope_cache_kernel(
pos_t = tl.load(pos_ptr + 0 * pos_stride_row + token_id)
pos_h = tl.load(pos_ptr + 1 * pos_stride_row + token_id)
pos_w = tl.load(pos_ptr + 2 * pos_stride_row + token_id)
pos_per_dim = tl.where(
d_cos_idx < MROPE_S0,
pos_t,
tl.where(d_cos_idx < MROPE_S1, pos_h, pos_w),
)
if MROPE_INTERLEAVED:
use_h = ((d_cos_idx % 3) == 1) & (d_cos_idx < MROPE_SECTION_H * 3)
use_w = ((d_cos_idx % 3) == 2) & (d_cos_idx < MROPE_SECTION_W * 3)
pos_per_dim = tl.where(use_h, pos_h, tl.where(use_w, pos_w, pos_t))
else:
pos_per_dim = tl.where(
d_cos_idx < MROPE_S0,
pos_t,
tl.where(d_cos_idx < MROPE_S1, pos_h, pos_w),
)
cos_base = pos_per_dim * cos_sin_stride_pos
else:
pos = tl.load(pos_ptr + token_id)
Expand Down Expand Up @@ -191,11 +200,16 @@ def _fused_qkv_norm_rope_cache_kernel(
pos_t = tl.load(pos_ptr + 0 * pos_stride_row + token_id)
pos_h = tl.load(pos_ptr + 1 * pos_stride_row + token_id)
pos_w = tl.load(pos_ptr + 2 * pos_stride_row + token_id)
pos_per_dim = tl.where(
d_cos_idx < MROPE_S0,
pos_t,
tl.where(d_cos_idx < MROPE_S1, pos_h, pos_w),
)
if MROPE_INTERLEAVED:
use_h = ((d_cos_idx % 3) == 1) & (d_cos_idx < MROPE_SECTION_H * 3)
use_w = ((d_cos_idx % 3) == 2) & (d_cos_idx < MROPE_SECTION_W * 3)
pos_per_dim = tl.where(use_h, pos_h, tl.where(use_w, pos_w, pos_t))
else:
pos_per_dim = tl.where(
d_cos_idx < MROPE_S0,
pos_t,
tl.where(d_cos_idx < MROPE_S1, pos_h, pos_w),
)
cos_base = pos_per_dim * cos_sin_stride_pos
else:
pos = tl.load(pos_ptr + token_id)
Expand Down Expand Up @@ -244,7 +258,7 @@ def _fused_qkv_norm_rope_cache_kernel(
if IS_FP8:
# FP8 per-token quantization for k
k_abs_max = tl.max(tl.abs(k_roped), axis=0)
k_scale = k_abs_max / 240.0
k_scale = k_abs_max / FP8_MAX
k_scale = tl.where(k_scale == 0.0, 1.0, k_scale)
k_quant = (k_roped / k_scale).to(k_cache_ptr.dtype.element_ty)

Expand Down Expand Up @@ -276,7 +290,7 @@ def _fused_qkv_norm_rope_cache_kernel(
# FP8 per-token quantization for v
v_f32 = v.to(tl.float32)
v_abs_max = tl.max(tl.abs(v_f32), axis=0)
v_scale = v_abs_max / 240.0
v_scale = v_abs_max / FP8_MAX
v_scale = tl.where(v_scale == 0.0, 1.0, v_scale)
v_quant = (v_f32 / v_scale).to(v_cache_ptr.dtype.element_ty)

Expand Down Expand Up @@ -342,6 +356,7 @@ def triton_fused_norm_rope_cache(
sin_cache = rotary_emb.sin_cache.squeeze(-2).squeeze(-2)

is_fp8 = kv_cache_dtype == "fp8"
fp8_max = torch.finfo(k_cache.dtype).max if is_fp8 else 1.0

block_size = k_cache.shape[3] # k_cache: [B, H, D//X, block_size, X]
x_size = k_cache.shape[4]
Expand All @@ -353,10 +368,16 @@ def triton_fused_norm_rope_cache(
assert mrope_section is not None, "M-RoPE requires rotary_emb.mrope_section"
s0 = mrope_section[0]
s1 = s0 + mrope_section[1]
section_h = mrope_section[1]
section_w = mrope_section[2]
mrope_interleaved = getattr(rotary_emb, "mrope_interleaved", False)
pos_stride_row = positions.stride(0)
else:
s0 = 0
s1 = 0
section_h = 0
section_w = 0
mrope_interleaved = False
pos_stride_row = 0

# Allocate contiguous output tensors
Expand Down Expand Up @@ -410,8 +431,12 @@ def triton_fused_norm_rope_cache(
ROTARY_DIM=rotary_dim,
ROTARY_DIM_HALF=rotary_dim // 2,
IS_FP8=is_fp8,
FP8_MAX=fp8_max,
MROPE_S0=s0,
MROPE_S1=s1,
MROPE_SECTION_H=section_h,
MROPE_SECTION_W=section_w,
MROPE_INTERLEAVED=mrope_interleaved,
IS_MROPE=is_mrope,
)

Expand Down
7 changes: 5 additions & 2 deletions atom/plugin/vllm/attention/layer_gdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ def forward(
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
kv_cache: torch.Tensor,
layer_name: str,
):
"""
Expand Down Expand Up @@ -448,8 +449,10 @@ def forward(
non_spec_state_indices_tensor = (
attn_metadata.non_spec_state_indices_tensor
) # noqa: E501
compilation_config = forward_context.no_compile_layers
self_kv_cache = compilation_config[layer_name].kv_cache
if kv_cache is None or kv_cache.numel() == 0:
compilation_config = forward_context.no_compile_layers
kv_cache = compilation_config[layer_name].kv_cache
self_kv_cache = kv_cache
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
Expand Down
Loading
Loading