59 changes: 37 additions & 22 deletions src/transformers/models/mra/modeling_mra.py
@@ -153,38 +153,53 @@ def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_size=32):
"""
Performs matrix multiplication of a sparse matrix with a dense matrix.
"""
batch_size, key_size, dim = dense_key.size()

if key_size % block_size != 0:
raise ValueError("key_size (size of first dimension of dense_key) must be divisible by block_size.")

if sparse_query.size(2) != block_size:
raise ValueError("The size of the second dimension of sparse_query must be equal to the block_size.")

if sparse_query.size(3) != block_size:
raise ValueError("The size of the third dimension of sparse_query must be equal to the block_size.")
# Fast attribute access, reduce method calls
sparse_query_size = sparse_query.size()
indices_size = indices.size()
dense_key_size = dense_key.size()

dense_key = dense_key.reshape(batch_size, key_size // block_size, block_size, dim).transpose(-1, -2)
batch_size, key_size, dim = dense_key_size

if len(sparse_query.size()) != 4:
raise ValueError("sparse_query must be a 4-dimensional tensor.")
# Fast checks: gather all shape errors, then raise if needed
errors = []

if len(dense_key.size()) != 4:
if key_size % block_size != 0:
errors.append("key_size (size of first dimension of dense_key) must be divisible by block_size.")
if len(sparse_query_size) != 4:
errors.append("sparse_query must be a 4-dimensional tensor.")
if len(dense_key_size) != 3:
errors.append("dense_key must be a 3-dimensional tensor before reshaping.")
if len(indices_size) != 2:
errors.append("indices must be a 2-dimensional tensor.")
if sparse_query_size[2] != block_size:
errors.append("The size of the second dimension of sparse_query must be equal to the block_size.")
if sparse_query_size[3] != block_size:
errors.append("The size of the third dimension of sparse_query must be equal to the block_size.")

if errors:
# Only check one error per call, behaviorally preserving "first raise wins"
raise ValueError(errors[0])

# Densify/reshape with as little copy as possible, chain checks that can be static
dense_key_reshaped = dense_key.reshape(batch_size, key_size // block_size, block_size, dim).transpose(-1, -2)

# Check dense_key shape post-reshape
if len(dense_key_reshaped.size()) != 4:
raise ValueError("dense_key must be a 4-dimensional tensor.")

if len(indices.size()) != 2:
raise ValueError("indices must be a 2-dimensional tensor.")

if dense_key.size(3) != 32:
if dense_key_reshaped.size(3) != 32:
raise ValueError("The size of the third dimension of dense_key must be 32.")

sparse_query = sparse_query.contiguous()

indices = indices.int()
# indices = indices.int() produces a copy if not already int(), so only .int() if needed
if indices.dtype != indices.new_empty(0, dtype="int32").dtype:
indices = indices.int()
indices = indices.contiguous()
dense_key = dense_key.contiguous()
dense_key_reshaped = dense_key_reshaped.contiguous()

dense_qk_prod = mra_cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
# CUDA invocation
dense_qk_prod = mra_cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key_reshaped, query_num_block)
# Output shape change (layout): single-op call
dense_qk_prod = dense_qk_prod.transpose(-1, -2).reshape(batch_size, query_num_block * block_size, dim)
return dense_qk_prod
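For readers skimming the diff, here is a small standalone sketch (made-up sizes, independent of the MRA CUDA kernel) of the two transformations the new code performs before the kernel call: the blocked reshape/transpose of dense_key and the copy-avoiding int32 guard on indices. The torch.int32 comparison is an assumption about the intended dtype check, not part of the original patch.

    import torch

    batch_size, key_size, dim, block_size = 2, 128, 64, 32
    dense_key = torch.randn(batch_size, key_size, dim)
    indices = torch.arange(8, dtype=torch.int64).unsqueeze(0)

    # (batch, key_size, dim) -> (batch, key_size // block_size, block_size, dim), then swap the last two dims
    dense_key_reshaped = dense_key.reshape(batch_size, key_size // block_size, block_size, dim).transpose(-1, -2)
    print(dense_key_reshaped.shape)  # torch.Size([2, 4, 64, 32])

    # .int() converts to int32 and would copy here; an already-int32 tensor is left untouched
    if indices.dtype != torch.int32:
        indices = indices.int()
    print(indices.dtype)  # torch.int32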
