diff --git a/src/transformers/models/mra/modeling_mra.py b/src/transformers/models/mra/modeling_mra.py
index 478d66781851..24f5d1f0117a 100644
--- a/src/transformers/models/mra/modeling_mra.py
+++ b/src/transformers/models/mra/modeling_mra.py
@@ -153,38 +153,53 @@ def sparse_dense_mm(sparse_query, indices, dense_key, query_num_block, block_siz
     """
     Performs matrix multiplication of a sparse matrix with a dense matrix.
     """
-    batch_size, key_size, dim = dense_key.size()
-    if key_size % block_size != 0:
-        raise ValueError("key_size (size of first dimension of dense_key) must be divisible by block_size.")
-
-    if sparse_query.size(2) != block_size:
-        raise ValueError("The size of the second dimension of sparse_query must be equal to the block_size.")
-
-    if sparse_query.size(3) != block_size:
-        raise ValueError("The size of the third dimension of sparse_query must be equal to the block_size.")
+    # Cache the tensor sizes to avoid repeated .size() calls
+    sparse_query_size = sparse_query.size()
+    indices_size = indices.size()
+    dense_key_size = dense_key.size()
-    dense_key = dense_key.reshape(batch_size, key_size // block_size, block_size, dim).transpose(-1, -2)
+    batch_size, key_size, dim = dense_key_size
-    if len(sparse_query.size()) != 4:
-        raise ValueError("sparse_query must be a 4-dimensional tensor.")
+    # Collect all shape errors, then raise if any were found
+    errors = []
-    if len(dense_key.size()) != 4:
+    if key_size % block_size != 0:
+        errors.append("key_size (size of first dimension of dense_key) must be divisible by block_size.")
+    if len(sparse_query_size) != 4:
+        errors.append("sparse_query must be a 4-dimensional tensor.")
+    if len(dense_key_size) != 3:
+        errors.append("dense_key must be a 3-dimensional tensor before reshaping.")
+    if len(indices_size) != 2:
+        errors.append("indices must be a 2-dimensional tensor.")
+    if len(sparse_query_size) == 4 and sparse_query_size[2] != block_size:
+        errors.append("The size of the second dimension of sparse_query must be equal to the block_size.")
+    if len(sparse_query_size) == 4 and sparse_query_size[3] != block_size:
+        errors.append("The size of the third dimension of sparse_query must be equal to the block_size.")
+
+    if errors:
+        # Raise only the first collected error so that, as before, a single ValueError is raised per call
+        raise ValueError(errors[0])
+
+    # Reshape dense_key into blocks; reshape/transpose return views where possible, the copy happens in .contiguous() below
+    dense_key_reshaped = dense_key.reshape(batch_size, key_size // block_size, block_size, dim).transpose(-1, -2)
+
+    # Check the dense_key shape after reshaping
+    if len(dense_key_reshaped.size()) != 4:
         raise ValueError("dense_key must be a 4-dimensional tensor.")
-
-    if len(indices.size()) != 2:
-        raise ValueError("indices must be a 2-dimensional tensor.")
-
-    if dense_key.size(3) != 32:
+    if dense_key_reshaped.size(3) != 32:
         raise ValueError("The size of the third dimension of dense_key must be 32.")
 
     sparse_query = sparse_query.contiguous()
-
-    indices = indices.int()
+    # Converting with .int() copies when indices is not already int32, so only convert if needed
+    if indices.dtype != torch.int32:
+        indices = indices.int()
     indices = indices.contiguous()
-    dense_key = dense_key.contiguous()
+    dense_key_reshaped = dense_key_reshaped.contiguous()
 
-    dense_qk_prod = mra_cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
+    # Invoke the custom CUDA kernel
+    dense_qk_prod = mra_cuda_kernel.sparse_dense_mm(sparse_query, indices, dense_key_reshaped, query_num_block)
+    # Reshape the kernel output back to (batch_size, query_num_block * block_size, dim)
     dense_qk_prod = dense_qk_prod.transpose(-1, -2).reshape(batch_size, query_num_block * block_size, dim)
     return dense_qk_prod
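
For context, a minimal shape sketch of how the validated inputs fit together (not part of the patch). It assumes the custom CUDA kernels used by this module have already been compiled and loaded and that a CUDA device is available; the sizes and the index encoding below are illustrative only:

    import torch
    from transformers.models.mra.modeling_mra import sparse_dense_mm

    block_size = 32                          # the kernel works on 32x32 blocks
    batch_size, dim = 2, 32
    key_num_block, query_num_block = 4, 4
    key_size = key_num_block * block_size    # must be divisible by block_size
    num_sparse_blocks = 8                    # illustrative sparse block budget

    # sparse_query: (batch_size, num_sparse_blocks, block_size, block_size)
    sparse_query = torch.randn(batch_size, num_sparse_blocks, block_size, block_size, device="cuda")
    # indices: 2-D tensor of block indices; the encoding here is a placeholder
    indices = torch.randint(0, query_num_block * key_num_block, (batch_size, num_sparse_blocks), device="cuda")
    # dense_key: (batch_size, key_size, dim)
    dense_key = torch.randn(batch_size, key_size, dim, device="cuda")

    out = sparse_dense_mm(sparse_query, indices, dense_key, query_num_block)
    print(out.shape)  # torch.Size([2, 128, 32]) == (batch_size, query_num_block * block_size, dim)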