19 changes: 19 additions & 0 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -204,6 +204,7 @@ def __init__(
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
sinks: Optional[torch.Tensor] = None,
**kwargs,
) -> None:
torch.nn.Module.__init__(self)
@@ -258,6 +259,11 @@ def __init__(
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl")
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, ("Sinks must have one entry per attention head of this layer. "
f"Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")

def forward(
self,
@@ -434,6 +440,7 @@ def __init__(
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
sinks: Optional[torch.Tensor] = None,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
@@ -500,6 +507,11 @@ def __init__(
raise NotImplementedError("Encoder self-attention "
"is not implemented for "
"HPUAttentionImpl")
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, ("Sinks must have one entry per attention head of this layer. "
f"Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")

self.is_chunked_attention = False
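Note on the new `sinks` argument: it is expected to carry one learned logit per attention head (hence the `sinks.shape[0] == num_heads` asserts in both constructors), and downstream it behaves like an extra "virtual" attention score that absorbs part of the probability mass. A shape-only sketch of that convention, with made-up sizes:

```python
import torch

# Illustrative sizes only: one learned sink logit per attention head.
batch, num_heads, q_len, kv_len = 2, 8, 4, 16
scores = torch.randn(batch, num_heads, q_len, kv_len)   # Q @ K^T logits
sinks = torch.randn(num_heads)

# Same check as the constructors above: one sink per head.
assert sinks.shape[0] == num_heads

# Broadcast the per-head sink against the score layout [B, H, T_q, 1].
sink_col = sinks.view(1, num_heads, 1, 1).expand(batch, num_heads, q_len, 1)
print(sink_col.shape)   # torch.Size([2, 8, 4, 1])
```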

@@ -582,6 +594,12 @@ def forward(
if kv_cache is not None and isinstance(kv_cache, tuple):
key_cache, value_cache, k_scales, v_scales = \
HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)
if key.dtype != key_cache.dtype:
key = key.to(key_cache.dtype)
if value.dtype != value_cache.dtype:
value = value.to(value_cache.dtype)
if query.dtype != key.dtype:
query = query.to(key.dtype)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
@@ -720,6 +738,7 @@ def common_attention_args(self,
'key_cache': key_cache,
'value_cache': value_cache,
'block_size': block_size,
"sinks": self.sinks,
'k_scales': k_scales,
'v_scales': v_scales,
}
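Two small plumbing notes on this file: `forward` now casts `query`/`key`/`value` to the KV-cache dtype before the cache write and the attention call, and `common_attention_args` forwards `self.sinks` to the paged-attention kernel. A minimal sketch of the dtype alignment (standalone, illustrative dtypes):

```python
import torch

# Assumed mismatch: cache kept in bf16 while activations arrive in fp32.
key_cache = torch.zeros(16, 2, 64, dtype=torch.bfloat16)
key = torch.randn(16, 2, 64, dtype=torch.float32)
value = torch.randn(16, 2, 64, dtype=torch.float32)
query = torch.randn(16, 2, 64, dtype=torch.float32)

# Mirror of the casts added in forward(): align everything with the cache dtype.
if key.dtype != key_cache.dtype:
    key = key.to(key_cache.dtype)
if value.dtype != key_cache.dtype:
    value = value.to(key_cache.dtype)
if query.dtype != key.dtype:
    query = query.to(key.dtype)

assert query.dtype == key.dtype == value.dtype == key_cache.dtype
```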
140 changes: 112 additions & 28 deletions vllm_gaudi/extension/ops.py
@@ -67,8 +67,8 @@ def matmul_shape(lhs, rhs):
return result


def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, batch_size, matmul_av_op, batch2block_matmul_op,
block2batch_matmul_op):
def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, batch_size, matmul_av_op,
batch2block_matmul_op, block2batch_matmul_op):
# When fp32_softmax is enabled attn is left in fp32 after Q@K
# We can return to native dtype after we renormalize and calculate the adjustments
if block_bias is not None and attn.dtype != block_bias.dtype:
@@ -82,11 +82,27 @@ def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, batch_siz
if block_bias is not None:
attn.add_(block_bias)
block_max = attn.amax(dim=-1, keepdim=True)
if sink is not None:
block_max = torch.maximum(block_max, sink)
attn = attn.sub(block_max)
attn = attn.exp()
if attn.dtype == torch.float32:
attn = attn.to(value.dtype)
block_sums = attn.sum(dim=-1, keepdim=True)
attn_shape = attn.shape
block_sums = attn.view(-1, attn_shape[-1]).sum(dim=-1, keepdim=True)
attn_shape = list(attn_shape)
attn_shape[-1] = 1
block_sums = block_sums.view(attn_shape)
if sink is not None:
attn_sink = sink.sub(block_max)
attn_sink = attn_sink.exp()
if attn_sink.dtype == torch.float32:
attn_sink = attn_sink.to(value.dtype)
# TODO: Removing this .sum and using attn_sink directly
# results in wrong output, which does not make sense.
# Looks like a Synapse issue; needs further investigation.
block_sums_sink = attn_sink.sum(dim=-1, keepdim=True)
block_sums = block_sums + block_sums_sink
attn = matmul_av_op(attn, value)
if get_config().fused_block_softmax_adjustment:
out_shape = list(attn.shape[:3]) + [1] * (attn.dim() - 3)
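In `pipelined_pa` the sink never multiplies `value`; it only enters the normalization. The running maximum is raised to `max(block_max, sink)` so the exponentials stay bounded, and `exp(sink - block_max)` is added to the block's softmax denominator. A standalone sketch of that accumulation for a single head and block (illustrative sizes):

```python
import torch

scores = torch.randn(1, 1, 1, 16)            # Q @ K^T logits for one block
sink = torch.full((1, 1, 1, 1), 0.5)          # per-head sink logit, broadcastable over the block

block_max = scores.amax(dim=-1, keepdim=True)
block_max = torch.maximum(block_max, sink)    # keep every exp() argument <= 0

p = (scores - block_max).exp()
denom = p.sum(dim=-1, keepdim=True) + (sink - block_max).exp()
probs = p / denom

# Equivalent to softmax over [scores, sink] with the sink column dropped,
# so each row sums to strictly less than 1.
print(probs.sum(dim=-1))
```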
@@ -179,7 +195,7 @@ def flat_pa_mla(query, key_cache, value_cache, block_list, block_mapping, block_

def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias, block_groups, block_size, scale,
matmul_qk_op, position_bias, matmul_av_op, batch2block_matmul_op, block2batch_matmul_op, keys_fetch_func,
values_fetch_func, k_scales, v_scales, **ignored_args):
values_fetch_func, sinks, k_scales, v_scales, **ignored_args):
batch_size, _, hidden_size = query.shape
_, kv_heads, head_size = key_cache.shape
q_heads = hidden_size // head_size
@@ -197,6 +213,13 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias
value = values_fetch_func(value_cache.unflatten(0, (-1, block_size)),
**get_kv_fetch_extra_args(blocks=block_list, scales=v_scales_uf)).transpose(1, 2)
block_bias = block_bias.view(key.size(0), 1, 1, -1)
sink = None
if sinks is not None:
sinks = sinks.reshape(sinks.shape[0], 1)
sink = sinks.reshape(1, sinks.shape[0], 1, sinks.shape[1])
sink = sink.expand(query.shape[0], -1, query.shape[-2], -1)
if kv_heads != q_heads:
sink = sink.unflatten(1, (kv_heads, -1))
if kv_heads != q_heads:
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
@@ -234,6 +257,7 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias
block_bias,
block_groups,
block_mapping,
sink,
batch_size=batch_size,
matmul_av_op=matmul_av_op,
batch2block_matmul_op=batch2block_matmul_op,
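For grouped-query attention the sink has to follow the same head layout as the attention scores, so `flat_pa` first broadcasts it per query head and then unflattens the head dimension into `(kv_heads, q_heads // kv_heads)` groups. A shape-only sketch with made-up sizes:

```python
import torch

blocks, q_heads, kv_heads, block_q = 3, 8, 2, 5
sinks = torch.randn(q_heads)

sink = sinks.reshape(1, q_heads, 1, 1).expand(blocks, -1, block_q, -1)   # [blocks, Hq, T, 1]
if kv_heads != q_heads:
    sink = sink.unflatten(1, (kv_heads, -1))                             # [blocks, Hkv, Hq//Hkv, T, 1]

print(sink.shape)   # torch.Size([3, 2, 4, 5, 1])
```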
@@ -292,6 +316,7 @@ def _naive_prompt_attention(query: torch.Tensor,
matmul_qk_op=torch.matmul,
softmax_op=torch.softmax,
matmul_av_op=torch.matmul,
sinks: Optional[torch.Tensor] = None,
**ignored_args) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
@@ -323,10 +348,19 @@
if attn_weights.dtype != attn_bias.dtype:
attn_bias = attn_bias.to(dtype=attn_weights.dtype)
attn_weights.add_(attn_bias)
if sinks is not None:
sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
if query_heads != kv_heads:
sink = sink.unflatten(1, (kv_heads, -1))
combined_logits = torch.cat([attn_weights, sink], dim=-1)
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
attn_weights = combined_logits
if get_config().fp32_softmax:
attn_weights = torch.softmax(attn_weights, dim=-1)
else:
attn_weights = softmax_op(attn_weights, dim=-1)
if sinks is not None:
attn_weights = attn_weights[..., :-1]
attn_weights = attn_weights.to(query.dtype)
attn_weights = matmul_av_op(attn_weights, value)
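The prompt path above appends the sink as an extra logit column, renormalizes, and then drops that column, while the decode path in `pipelined_pa` only folds `exp(sink - max)` into the denominator. Both are meant to produce the same probabilities; a quick standalone consistency check with made-up shapes:

```python
import torch

heads, q_len, kv_len = 4, 3, 9
scores = torch.randn(1, heads, q_len, kv_len, dtype=torch.float64)
sinks = torch.randn(heads, dtype=torch.float64)
sink_col = sinks.view(1, heads, 1, 1).expand(1, heads, q_len, 1)

# Path A (prompt): concatenate the sink column, softmax, drop it.
a = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)[..., :-1]

# Path B (decode): add exp(sink - max) to the softmax denominator only.
m = torch.maximum(scores.amax(dim=-1, keepdim=True), sink_col)
p = (scores - m).exp()
b = p / (p.sum(dim=-1, keepdim=True) + (sink_col - m).exp())

print(torch.allclose(a, b))   # True
```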

@@ -345,6 +379,7 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
attn_bias: Optional[torch.Tensor] = None,
valid_seq_lengths: Optional[torch.Tensor] = None,
window_size: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
**ignored_args) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
@@ -366,10 +401,19 @@
query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths,
padding_side
]
args += [window_size] if window_size else []
if sinks is not None:
args += [window_size] if window_size else [None]
else:
args += [window_size] if window_size else []
# use sinks in fsdpa
if sinks is not None:
args += [sinks]
attn_weights = fsdpa_op(*args)

attn_weights = attn_weights.transpose(1, 2)
if sinks is not None:
# TODO - check if we can remove this
htcore.mark_step()
return attn_weights
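Because the fused SDPA kernel takes its optional arguments positionally, the code above keeps the argument order stable: when `sinks` is supplied without a sliding window, a `None` placeholder fills the `window_size` slot so `sinks` still lands in the right position. A generic illustration of that pattern; the callee here is a stand-in, not the HPU kernel API:

```python
from typing import Optional, Sequence

def fused_sdpa_stub(*args):
    # Stand-in for a positional-only fused kernel; it just echoes what it received.
    return args

def build_args(window_size: Optional[int], sinks: Optional[Sequence[float]]):
    args = ["q", "k", "v"]                                   # required arguments, elided here
    if sinks is not None:
        args += [window_size] if window_size else [None]     # keep the window_size slot occupied
        args += [sinks]
    else:
        args += [window_size] if window_size else []
    return args

print(fused_sdpa_stub(*build_args(None, [0.1, 0.2])))   # ('q', 'k', 'v', None, [0.1, 0.2])
print(fused_sdpa_stub(*build_args(128, None)))          # ('q', 'k', 'v', 128)
```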


@@ -494,10 +538,14 @@ class MoeMatmul(torch.nn.Module):

def __init__(self):
super().__init__()
self.bias = None

def set_weight(self, w):
self.weight = w

def set_bias(self, b):
self.bias = b

def forward(self, state, expert_id, w):
raise NotImplementedError()
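`MoeMatmul` now tracks an optional per-expert bias next to its weight; `set_bias` mirrors `set_weight` so biases can be attached during weight loading and checked later when picking the fused op. A minimal standalone sketch of that pattern (illustrative class, not the vLLM module):

```python
import torch

class ExpertSlot(torch.nn.Module):
    """Holds one expert's weight and an optional bias, both attached after construction."""

    def __init__(self):
        super().__init__()
        self.weight = None
        self.bias = None

    def set_weight(self, w: torch.Tensor):
        self.weight = w

    def set_bias(self, b: torch.Tensor):
        self.bias = b

slot = ExpertSlot()
slot.set_weight(torch.randn(16, 32))
slot.set_bias(torch.randn(32))
print(slot.bias is not None)   # True -> this expert can feed the bias-fused MoE op
```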

@@ -509,12 +557,14 @@ def __init__(self,
num_total_experts: int,
experts_min: int = 0,
experts_max: int = 8,
dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None):
dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None,
bias=False):
super().__init__()
self.experts_min = experts_min
self.experts_max = experts_max
self.global_num_experts = global_num_experts
self.num_experts = num_total_experts
self.bias = bias

if MAX_EXPERTS_PER_SLICE > 0:
max_expert_per_slice = MAX_EXPERTS_PER_SLICE
@@ -566,8 +616,9 @@ def __init__(self,
num_total_experts: int,
experts_min: int = 0,
experts_max: int = 8,
dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None):
super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, dispatch_fn)
dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None,
bias=False):
super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, dispatch_fn, bias)
self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])

@@ -580,31 +631,64 @@ def forward(self, hidden_states, expert_routing_table, router_weights, permuted_
w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]

if self.moe_n_slice == 1:
return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=permuted_weights,
activation=activation,
experts_min=self.experts_min,
experts_max=self.experts_max,
**kwargs)
if self.bias and all(self.w13_list[i].bias is not None and self.w2_list[i].bias is not None
for i in experts_range):
w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range]
w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range]
return torch.ops.hpu.mixture_of_experts.bias_fused_weights(hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
w12_bias=w1_bias_list,
w3_bias=w2_bias_list,
permuted_weights=permuted_weights,
experts_min=self.experts_min,
experts_max=self.experts_max)
else:
return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=permuted_weights,
activation=activation,
experts_min=self.experts_min,
experts_max=self.experts_max,
**kwargs)
for i in range(self.moe_n_slice):
w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
min_expert = self.experts_min + i * self.num_expert_per_group
max_expert = min_expert + self.num_expert_per_group - 1
slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list_slice,
w3=w2_list_slice,
permuted_weights=permuted_weights,
activation=activation,
experts_min=min_expert,
experts_max=max_expert,
**kwargs)
if self.bias and all(self.w13_list[j].bias is not None and self.w2_list[j].bias is not None
for j in experts_range):
w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range]
w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range]
w1_bias_list_slice = w1_bias_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
w2_bias_list_slice = w2_bias_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
slice_final_hidden_states = torch.ops.hpu.mixture_of_experts.bias_fused_weights(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list_slice,
w3=w2_list_slice,
w12_bias=w1_bias_list_slice,
w3_bias=w2_bias_list_slice,
permuted_weights=permuted_weights,
experts_min=min_expert,
experts_max=max_expert)
else:
slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list_slice,
w3=w2_list_slice,
permuted_weights=permuted_weights,
activation=activation,
experts_min=min_expert,
experts_max=max_expert,
**kwargs)
if i == 0:
final_hidden_states = slice_final_hidden_states
else:
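The forward path now selects between two kernels: the bias-fused variant when the layer was created with `bias=True` and every expert in the active range has biases loaded, and the existing weight-only op otherwise. A compact sketch of that dispatch, with plain callables standing in for `torch.ops.hpu.mixture_of_experts` and its `bias_fused_weights` variant:

```python
from typing import Optional, Sequence
import torch

class Expert:
    def __init__(self, weight: torch.Tensor, bias: Optional[torch.Tensor] = None):
        self.weight, self.bias = weight, bias

def moe_weight_only(w12: Sequence[torch.Tensor], w3: Sequence[torch.Tensor]) -> str:
    # Stand-in for torch.ops.hpu.mixture_of_experts
    return "weight-only kernel"

def moe_bias_fused(w12, w3, w12_bias, w3_bias) -> str:
    # Stand-in for torch.ops.hpu.mixture_of_experts.bias_fused_weights
    return "bias-fused kernel"

def dispatch(w13_list, w2_list, experts_range, use_bias: bool) -> str:
    have_bias = use_bias and all(
        w13_list[i].bias is not None and w2_list[i].bias is not None for i in experts_range)
    w1 = [w13_list[i].weight for i in experts_range]
    w2 = [w2_list[i].weight for i in experts_range]
    if have_bias:
        b1 = [w13_list[i].bias for i in experts_range]
        b2 = [w2_list[i].bias for i in experts_range]
        return moe_bias_fused(w1, w2, b1, b2)
    return moe_weight_only(w1, w2)

experts = [Expert(torch.randn(32, 64), torch.randn(64)) for _ in range(2)]
print(dispatch(experts, experts, range(2), use_bias=True))    # bias-fused kernel
print(dispatch(experts, experts, range(2), use_bias=False))   # weight-only kernel
```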
6 changes: 4 additions & 2 deletions vllm_gaudi/extension/utils.py
@@ -157,14 +157,16 @@ def forward(
valid_sequence_lengths,
padding_side="left",
window_size=None,
sinks=None,
):
if window_size is not None:
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
recompute_mode, valid_sequence_lengths, padding_side, False, False,
window_size)
window_size, sinks)
else:
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
recompute_mode, valid_sequence_lengths, padding_side)
recompute_mode, valid_sequence_lengths, padding_side, False, False,
(-1, -1), sinks)


def pad_list(input, target_len, val_generator):