Commit 9b829cd

Authored Feb 15, 2025
Move attention block assertion for shape (#960)
The original assertion assumed the `qkv` matrices are square, which is not a requirement for the attention block. Moving and updating the assertion enables architectures that do not impose this restriction.
1 parent: 1aa89af

File tree: 1 file changed, +5 -2 lines

sharktank/sharktank/layers/paged_llama_attention_block.py (+5 -2)
```diff
@@ -101,13 +101,16 @@ def forward(
     ):
         assert bool(start_index is not None) ^ bool(embedding_batch_mask is not None)
         x = self.attn_norm(h)
-        bs, batch_seq_len, feature_dim = x.shape
-        assert feature_dim == self.head_count * self.head_dim
+        bs, batch_seq_len, _ = x.shape

         xq = self.attn_q(x)
         xk = self.attn_k(x)
         xv = self.attn_v(x)

+        assert xq.shape[-1] == self.head_count * self.head_dim
+        assert xk.shape[-1] == self.head_count_kv * self.head_dim
+        assert xv.shape[-1] == self.head_count_kv * self.head_dim
+
         xq = xq.view(bs, batch_seq_len, self.head_count, self.head_dim)
         xk = xk.view(bs, batch_seq_len, self.head_count_kv, self.head_dim)
         xv = xv.view(bs, batch_seq_len, self.head_count_kv, self.head_dim)
```
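To see why the relocated assertions are the right check, consider a grouped-query attention configuration in which the query and key/value projections have different output widths and the model width does not equal `head_count * head_dim`. The old assertion rejects such a model even though the projections and views are perfectly well formed. The sketch below is illustrative only: the shapes and the free-standing `Linear` modules are assumptions for demonstration, not code from sharktank.

```python
import torch

# Hypothetical grouped-query-attention shapes (illustrative, not from any
# real sharktank config): 8 query heads share 2 KV heads, and the model
# width (1024) differs from head_count * head_dim (512), so none of the
# projections are square.
feature_dim = 1024     # width of the incoming hidden state
head_count = 8         # query heads
head_count_kv = 2      # key/value heads
head_dim = 64

attn_q = torch.nn.Linear(feature_dim, head_count * head_dim, bias=False)     # 1024 -> 512
attn_k = torch.nn.Linear(feature_dim, head_count_kv * head_dim, bias=False)  # 1024 -> 128
attn_v = torch.nn.Linear(feature_dim, head_count_kv * head_dim, bias=False)  # 1024 -> 128

x = torch.randn(2, 16, feature_dim)  # (bs, batch_seq_len, feature_dim)
bs, batch_seq_len, _ = x.shape

xq, xk, xv = attn_q(x), attn_k(x), attn_v(x)

# The old pre-projection assertion would fail here, since
# feature_dim (1024) != head_count * head_dim (512):
# assert feature_dim == head_count * head_dim  # AssertionError

# The relocated assertions check the projection *outputs*, which hold for
# square and non-square projections alike:
assert xq.shape[-1] == head_count * head_dim
assert xk.shape[-1] == head_count_kv * head_dim
assert xv.shape[-1] == head_count_kv * head_dim

xq = xq.view(bs, batch_seq_len, head_count, head_dim)
xk = xk.view(bs, batch_seq_len, head_count_kv, head_dim)
xv = xv.view(bs, batch_seq_len, head_count_kv, head_dim)
print(xq.shape, xk.shape, xv.shape)
# torch.Size([2, 16, 8, 64]) torch.Size([2, 16, 2, 64]) torch.Size([2, 16, 2, 64])
```

Checking each projection's output width, rather than the input width, keeps the guard meaningful while permitting multi-query and grouped-query layouts where `head_count_kv < head_count`.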
