Flash-DMA is a high-performance attention implementation that integrates Flash Attention's memory efficiency with Dynamic Mask Attention's computational efficiency for processing extremely long sequences in transformer models.
- Sparse Attention Computation: Dynamically selects the most important keys for each query, reducing computation from $O(N^2)$ to $O(N \cdot k)$ where $k \ll N$.
- Memory Efficiency: Maintains Flash Attention's $O(N)$ memory complexity without materializing the full attention matrix.
- CUDA-Accelerated: Deep integration at the CUDA kernel level for maximum performance.
- Long Sequence Support: Efficiently handles sequences of 128K+ tokens that would be impractical with standard attention.
- Backward Compatible: API compatible with existing Flash Attention implementations.
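Because the kernel is meant to be a drop-in replacement for Flash Attention, a call could look roughly like the sketch below. The module path `flash_dmattn`, the function name `flash_dmattn_func`, and the keyword arguments are assumptions modeled on the `flash_attn_func`-style interface, not confirmed API; see the API Reference for the actual signatures.

```python
import torch

# Hypothetical import -- module and function names are assumptions,
# mirroring a flash_attn_func-style interface.
from flash_dmattn import flash_dmattn_func

batch, seqlen, nheads, headdim = 2, 8192, 16, 64

# Flash Attention layout: (batch, seqlen, nheads, headdim), fp16 on GPU.
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Assumed drop-in call: same tensor layout and causal flag as flash_attn_func.
out = flash_dmattn_func(q, k, v, causal=True)
print(out.shape)  # (batch, seqlen, nheads, headdim)
```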
- Python: 3.7 or later
- PyTorch: 1.10.0 or later
- CUDA: 11.0 or later (for GPU acceleration)
- NVIDIA GPU: Compute Capability 6.0 or higher
- C++ Compiler: GCC 7+ or compatible
Ensure your CUDA environment is properly configured:
```bash
# Check CUDA installation
nvcc --version

# Set CUDA_HOME if needed
export CUDA_HOME=/usr/local/cuda
```
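It can also help to confirm that PyTorch itself sees the GPU and the CUDA toolkit it will build against. The check below uses only standard PyTorch calls and makes no Flash-DMA-specific assumptions.

```python
import torch
from torch.utils.cpp_extension import CUDA_HOME

# The extension is compiled against the toolkit PyTorch finds via CUDA_HOME,
# so both should point at a CUDA 11.0+ installation.
print("CUDA available to PyTorch:", torch.cuda.is_available())
print("PyTorch built with CUDA:  ", torch.version.cuda)
print("CUDA_HOME used for builds:", CUDA_HOME)
```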
```bash
git clone https://github.com/SmallDoges/flash-dmattn.git
cd flash-dmattn
git submodule update --init --recursive
pip install .
```
Flash-DMA combines two complementary techniques:
- Dynamic Mask Attention: Computes relevance scores for keys and selects only the most important ones for attention computation
- Flash Attention: Processes attention in blocks to reduce memory usage and HBM access
The integration happens at the CUDA kernel level with several key components:
- ZOH States: Pre-computed importance scores for key selection
- Active Masks: Binary masks indicating which keys should be considered for each query
- Sparse Matrix Multiplication: Custom CUDA kernels for efficient sparse attention computation
- Block-Based Processing: Maintains Flash Attention's block-based approach for memory efficiency
This creates a hybrid attention mechanism that achieves both memory and computational efficiency.
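To make the hybrid concrete, here is a minimal pure-PyTorch sketch of the dynamic-mask idea: importance scores (the role played by the ZOH states) pick the top-k keys per query, an active mask discards everything else, and attention is accumulated only over the selected keys. This is a conceptual reference for the algorithm described above, not the block-based CUDA kernel, and the importance score here is a stand-in.

```python
import torch
import torch.nn.functional as F

def dynamic_mask_attention_reference(q, k, v, keep_k):
    """Conceptual reference: per-query top-k key selection + masked softmax.

    q, k, v: (batch, heads, seqlen, headdim); keep_k: number of keys kept per query.
    The real kernel fuses this selection with Flash Attention's block-wise streaming.
    """
    scale = q.shape[-1] ** -0.5
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale        # (B, H, N, N)

    # Stand-in for the pre-computed ZOH importance scores: here we simply
    # reuse the raw attention scores to rank keys for each query.
    importance = scores

    # Active mask: True for the keep_k most important keys of each query.
    topk_idx = importance.topk(keep_k, dim=-1).indices
    active = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True)

    # Softmax only over the selected keys; everything else is masked out.
    scores = scores.masked_fill(~active, float("-inf"))
    probs = F.softmax(scores, dim=-1)
    return torch.matmul(probs, v)

# Example: 1 batch, 4 heads, 512 queries/keys, keep only 64 keys per query.
q = torch.randn(1, 4, 512, 64)
k = torch.randn(1, 4, 512, 64)
v = torch.randn(1, 4, 512, 64)
out = dynamic_mask_attention_reference(q, k, v, keep_k=64)
print(out.shape)  # torch.Size([1, 4, 512, 64])
```

Note that this reference still materializes the full score matrix; the point of the CUDA-level integration is to produce the same result while streaming over blocks and skipping inactive keys, keeping memory at $O(N)$.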
For detailed technical documentation, see:
- Integration Guide - Comprehensive technical details
- API Reference - Function signatures and parameters
```bash
# Clone with submodules
git clone --recursive https://github.com/SmallDoges/flash-dmattn.git
cd flash-dmattn

# Build in development mode
pip install -e .
```
- CUDA Toolkit 11.0+
- CUTLASS library (included as submodule)
- CUB library (included as submodule)
- SM 6.0+ (Pascal, Volta, Turing, Ampere, Ada Lovelace)
- Optimized for SM 8.0+ (Ampere and newer)
```bash
# Gradient equivalence benchmark
python benchmarks/benchmark_grad.py
```
| Component | Supported Versions |
|-----------|--------------------|
| PyTorch   | 1.10.0+            |
| CUDA      | 11.0+              |
| Python    | 3.7+               |
| GPU Arch  | SM 6.0+            |
Compilation Errors
```bash
# Ensure CUDA_HOME is set
export CUDA_HOME=/usr/local/cuda

# Verify which nvcc is being picked up
which nvcc
```
Performance Issues
- Ensure GPU has sufficient compute capability (6.0+)
- Use appropriate data types (float16 recommended)
- Verify CUDA kernels are being used (not CPU fallback)
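A quick way to sanity-check the points above, using only standard PyTorch calls:

```python
import torch

# Compute capability must be 6.0+ for the CUDA kernels (8.0+ is the optimized path).
major, minor = torch.cuda.get_device_capability()
print(f"{torch.cuda.get_device_name()}: SM {major}.{minor}")
assert (major, minor) >= (6, 0), "GPU compute capability too low for the CUDA kernels"

# float16 inputs are recommended for the CUDA kernels.
q = torch.randn(1, 1024, 8, 64, device="cuda", dtype=torch.float16)
print("dtype OK:", q.dtype == torch.float16)
```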
This project is licensed under the BSD 3-Clause License. See LICENSE for details.
If you use Flash-DMA in your research, please cite:
```bibtex
@misc{flash_dma_2025,
  title={Trainable Dynamic Mask Sparse Attention},
  author={Jingze Shi and Yifan Wu and Bingheng Wu and Yiran Peng and Yuyu Luo},
  year={2025},
  url={https://github.com/SmallDoges/flash-dmattn}
}
```
This project builds upon the excellent work of:
- Flash-Attention by Tri Dao et al.
- NVIDIA CUTLASS library for efficient matrix operations