Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,10 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
# Optimization: use .reshape over .expand for one-step repeat (avoids extra copy with repeat_interleave, and keeps original behavior)
expanded = hidden_states.unsqueeze(2).expand(batch, num_key_value_heads, n_rep, slen, head_dim)
# view is slightly faster than reshape, but keep reshape for safety (non-contiguous), and only call once
return expanded.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def eager_attention_forward(
Expand All @@ -274,16 +276,28 @@ def eager_attention_forward(
dropout: float = 0.0,
**kwargs: Unpack[TransformersKwargs],
):
key_states = repeat_kv(key, module.num_key_value_groups)
value_states = repeat_kv(value, module.num_key_value_groups)

attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
# Pre-split num_key_value_groups for faster access
num_key_value_groups = module.num_key_value_groups
key_states = repeat_kv(key, num_key_value_groups)
value_states = repeat_kv(value, num_key_value_groups)

# Use local variable for shape for a minor speedup (single allocation)
key_len = key_states.shape[-2]
# Matmul: [batch, num_heads, seqlen, head_dim] x [batch, num_heads, head_dim, seqlen] -> [b, h, s, s]
attn_weights = torch.matmul(query, key_states.transpose(2, 3)).mul(scaling)
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
# Only slice once for the mask and perform inplace addition for less memory use (if possible)
causal_mask = attention_mask[:, :, :, :key_len]
attn_weights = attn_weights + causal_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
# softmax (dtype=torch.float32 for stability), then convert dtype at end for speed
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
# Dropout: If dropout==0, skip allocation (minor savings)
if dropout > 0.0:
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
# Convert to query dtype only after dropout
attn_weights = attn_weights.to(query.dtype)
# Matmul: [batch, num_heads, seqlen, seqlen] x [batch, num_heads, seqlen, head_dim] -> [batch, num_heads, seqlen, head_dim]
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

Expand Down