52 changes: 33 additions & 19 deletions src/transformers/models/mra/modeling_mra.py
@@ -278,33 +278,47 @@ def get_low_resolution_logit(query, key, block_size, mask=None, value=None):
     num_block_per_row = seq_len // block_size
 
     value_hat = None
+
+    # Precompute reshaped tensors for efficiency
+    block_shape = (batch_size, num_block_per_row, block_size, head_dim)
+    query_reshaped = query.reshape(block_shape)
+    key_reshaped = key.reshape(block_shape)
+    value_reshaped = value.reshape(block_shape) if value is not None else None
+
     if mask is not None:
-        token_count = mask.reshape(batch_size, num_block_per_row, block_size).sum(dim=-1)
-        query_hat = query.reshape(batch_size, num_block_per_row, block_size, head_dim).sum(dim=-2) / (
-            token_count[:, :, None] + 1e-6
-        )
-        key_hat = key.reshape(batch_size, num_block_per_row, block_size, head_dim).sum(dim=-2) / (
-            token_count[:, :, None] + 1e-6
-        )
+        mask_reshaped = mask.reshape(batch_size, num_block_per_row, block_size)
+        token_count = mask_reshaped.sum(dim=-1)
+
+        # Compute the denominator once and reuse it for query, key, and value
+        denom = token_count[:, :, None] + 1e-6
+
+        # A fused masked-sum kernel could do this in one call; a plain sum
+        # followed by a single division gives the same result in fewer ops
+        query_hat = query_reshaped.sum(dim=-2) / denom
+        key_hat = key_reshaped.sum(dim=-2) / denom
         if value is not None:
-            value_hat = value.reshape(batch_size, num_block_per_row, block_size, head_dim).sum(dim=-2) / (
-                token_count[:, :, None] + 1e-6
-            )
+            value_hat = value_reshaped.sum(dim=-2) / denom
     else:
-        token_count = block_size * torch.ones(batch_size, num_block_per_row, dtype=torch.float, device=query.device)
-        query_hat = query.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim=-2)
-        key_hat = key.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim=-2)
+        # Build token_count directly with torch.full instead of scaling a ones tensor
+        token_count = torch.full((batch_size, num_block_per_row), block_size, dtype=torch.float, device=query.device)
+
+        # Without a mask, the block means are plain means over each block
+        query_hat = query_reshaped.mean(dim=-2)
+        key_hat = key_reshaped.mean(dim=-2)
         if value is not None:
-            value_hat = value.reshape(batch_size, num_block_per_row, block_size, head_dim).mean(dim=-2)
+            value_hat = value_reshaped.mean(dim=-2)
 
-    low_resolution_logit = torch.matmul(query_hat, key_hat.transpose(-1, -2)) / math.sqrt(head_dim)
+    # Scale the matmul result in place to avoid an extra temporary
+    low_resolution_logit = torch.matmul(query_hat, key_hat.transpose(-1, -2)).div_(math.sqrt(head_dim))
 
-    low_resolution_logit_row_max = low_resolution_logit.max(dim=-1, keepdims=True).values
+    # Use the canonical keepdim keyword spelling
+    low_resolution_logit_row_max = low_resolution_logit.max(dim=-1, keepdim=True).values
 
     if mask is not None:
-        low_resolution_logit = (
-            low_resolution_logit - 1e4 * ((token_count[:, None, :] * token_count[:, :, None]) < 0.5).float()
-        )
+        # Penalize block pairs where either block has no valid tokens;
+        # .float() allocates a fresh tensor, so the in-place mul_ is safe
+        mask_matrix = (token_count[:, None, :] * token_count[:, :, None]) < 0.5
+        low_resolution_logit = low_resolution_logit - mask_matrix.float().mul_(1e4)
 
     return low_resolution_logit, token_count, low_resolution_logit_row_max, value_hat

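For reference, here is a minimal sketch of how the changed function could be exercised end to end. The tensor shapes are illustrative assumptions, not values from this PR; it assumes `get_low_resolution_logit` is importable as the module-level function its signature in the hunk header suggests.

# Toy sanity check for get_low_resolution_logit; all shapes below are
# illustrative assumptions, not taken from the PR.
import torch
from transformers.models.mra.modeling_mra import get_low_resolution_logit

batch_size, seq_len, head_dim, block_size = 2, 64, 32, 16  # assumed toy sizes
query = torch.randn(batch_size, seq_len, head_dim)
key = torch.randn(batch_size, seq_len, head_dim)
value = torch.randn(batch_size, seq_len, head_dim)
mask = torch.ones(batch_size, seq_len)  # all 64 tokens marked valid

logit, token_count, row_max, value_hat = get_low_resolution_logit(
    query, key, block_size, mask=mask, value=value
)
print(logit.shape)      # torch.Size([2, 4, 4]): one logit per pair of 16-token blocks
print(token_count[0])   # tensor([16., 16., 16., 16.]): valid tokens per block
print(value_hat.shape)  # torch.Size([2, 4, 32]): per-block mean of value

Because the mask branch divides by `token_count + 1e-6` rather than calling `.mean()`, blocks that contain only padding produce near-zero block summaries instead of NaNs, and the `mask_matrix` penalty then pushes their logits far below the row max.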