
Flash Attention and Dynamic Mask Attention Integration

Table of Contents

  1. Introduction
  2. Flash Attention Algorithm
  3. Dynamic Mask Attention (DMA) Algorithm
  4. Comparative Analysis
  5. Flash-DMA Integration
  6. Technical Implementation Challenges
  7. Optimization Strategies
  8. Implementation Roadmap
  9. API Design
  10. Conclusion

Introduction

This document provides a comprehensive explanation of Flash Attention and Dynamic Mask Attention (DMA) algorithms, along with a detailed proposal for integrating these approaches. The goal is to combine the memory efficiency of Flash Attention with the computational efficiency of DMA to create a high-performance attention mechanism for large sequence processing.

As Transformer models continue to scale to longer sequences and larger batch sizes, the attention mechanism becomes a significant bottleneck in terms of both memory usage and computational efficiency. Flash Attention addresses the memory bottleneck by using a block-based approach that avoids materializing the full attention matrix, while Dynamic Mask Attention reduces computational complexity by selectively focusing on the most important keys for each query.

By integrating these complementary approaches, we aim to create an attention mechanism that can efficiently handle extremely long sequences while maintaining high computational throughput and numerical accuracy.

Flash Attention Algorithm

Key Concepts

Flash Attention is built on several key innovations that distinguish it from standard attention implementations:

  1. Block-based Processing: Instead of computing the entire attention matrix at once, Flash Attention divides it into blocks and processes them iteratively, substantially reducing memory requirements.

  2. Online Softmax Algorithm: Flash Attention uses an online algorithm to compute softmax progressively as blocks are processed, maintaining numerical stability without storing the full attention matrix.

  3. Tiling for Shared Memory: The algorithm uses carefully designed tiling strategies to maximize data reuse in GPU shared memory, minimizing global memory accesses.

  4. Mixed Precision Computation: Flash Attention performs accumulation in higher precision (e.g., FP32) while storing intermediate results in lower precision (e.g., FP16/BF16).

  5. Log-Sum-Exp (LSE) Tracking: For numerical stability and to enable the online softmax, the algorithm maintains rolling LSE values.

Core Algorithm Overview

At a high level, Flash Attention works by processing the attention computation in blocks:

  1. Initialization: Set up data structures for the output and the LSE values.

  2. Block-wise Computation: For each block of queries:

    • For each block of keys:
      • Load query and key blocks into shared memory
      • Compute the attention scores (Q·K^T) for this block
      • Apply masking (causal, padding, etc.) if needed
      • Update running max values and exponential sums for softmax
      • Load value block into shared memory
      • Compute weighted values and update output
  3. Normalization: Apply final normalization using the accumulated LSE values.

The key insight is that by processing blocks in a specific order and maintaining sufficient statistics (max values and sums for softmax), Flash Attention can produce exactly the same result as standard attention without materializing the full attention matrix.
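To make the loop structure concrete, here is a minimal PyTorch sketch of the blocked computation for a single attention head. The function name and block size are illustrative, and the outer loop over query blocks is omitted because it affects only parallelism, not the result:

import torch

def blocked_attention(q, k, v, block_k=128):
    # q: [seqlen_q, head_dim]; k, v: [seqlen_k, head_dim]
    scale = q.shape[-1] ** -0.5
    out = torch.zeros(q.shape, dtype=torch.float32, device=q.device)
    m = torch.full((q.shape[0],), float("-inf"), device=q.device)  # running row maxima
    l = torch.zeros(q.shape[0], device=q.device)                   # running sums of exponentials
    for start in range(0, k.shape[0], block_k):
        kb, vb = k[start:start + block_k], v[start:start + block_k]
        s = (q.float() @ kb.float().T) * scale             # scores for this key block
        m_new = torch.maximum(m, s.max(dim=-1).values)     # update running max
        p = torch.exp(s - m_new[:, None])                  # unnormalized block probabilities
        correction = torch.exp(m - m_new)                  # rescale previously accumulated blocks
        l = l * correction + p.sum(dim=-1)
        out = out * correction[:, None] + p @ vb.float()
        m = m_new
    return (out / l[:, None]).to(q.dtype)                  # final normalization by the accumulated sums

For any block_k, the result matches softmax(Q·K^T / sqrt(d))·V up to floating-point rounding.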

Implementation Details

The Flash Attention implementation in flash_attention_fwd_kernel.h employs several sophisticated techniques:

Memory Management and Tensor Layout

Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
                        typename Kernel_traits::SmemLayoutQ{});
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
                        typename Kernel_traits::SmemLayoutKV{});
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});

The code meticulously manages shared memory to store query, key, and value blocks. It often uses the same memory for different tensors at different phases of the computation to minimize memory usage.

Block Processing Loop

The core computation happens in two phases:

  1. Processing blocks with masking (for causal or local attention):
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
    // Load V for current block
    // Compute Q*K^T for current block
    // Apply masking
    // Update softmax stats
    // Compute attention output with V
}
  2. Processing remaining blocks without masking:
for (; n_block >= n_block_min; --n_block) {
    // Similar process but without causal masking
}

This separation optimizes performance by avoiding unnecessary masking operations where possible.

Online Softmax Implementation

The online softmax algorithm is a critical component that enables processing in blocks:

masking_step == 0
    ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
    : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);

For each block, it:

  1. Updates the running maximum value
  2. Scales previous accumulated values if the maximum changed
  3. Computes normalized values for the current block
  4. Updates the running sum of exponentials

This allows stable softmax computation without materializing the full attention matrix.
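In symbols, writing m for the running maximum, l for the running sum of exponentials, O for the output accumulator, and s_j for the scores of the current block, the per-block update is:

$$
\begin{aligned}
m_{\mathrm{new}} &= \max\left(m_{\mathrm{old}},\ \max_j s_j\right) \\
l_{\mathrm{new}} &= e^{m_{\mathrm{old}} - m_{\mathrm{new}}}\, l_{\mathrm{old}} + \sum_j e^{s_j - m_{\mathrm{new}}} \\
O_{\mathrm{new}} &= e^{m_{\mathrm{old}} - m_{\mathrm{new}}}\, O_{\mathrm{old}} + \sum_j e^{s_j - m_{\mathrm{new}}}\, V_j
\end{aligned}
$$

After the final block, the output is O / l and the stored LSE is m + log(l).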

Performance Characteristics

Flash Attention achieves significant performance improvements over standard attention:

  1. Memory Complexity: Reduces memory usage from O(N²) to O(N), where N is the sequence length.

  2. Memory Bandwidth Optimization: Carefully designed to minimize HBM (high-bandwidth memory) accesses through shared memory reuse.

  3. Throughput: Achieves up to 3-5x speedup over standard attention implementations for long sequences.

  4. Scaling Efficiency: Performance gains increase with sequence length, making it particularly effective for long-sequence tasks.

  5. Numerical Accuracy: Produces the same results as standard attention, up to floating-point rounding, despite the block-based approach.

Dynamic Mask Attention Algorithm

DMA Key Concepts

Dynamic Mask Attention (DMA) introduces a different approach to optimizing attention by focusing on reducing the computational complexity:

  1. Selective Key Processing: DMA processes only a subset of keys for each query, determined by a learned importance criterion.

  2. Importance-based Selection: A learned projection matrix transforms values to generate importance scores that determine which keys to keep.

  3. Top-K Filtering: For each query, only the top-k keys with the highest importance scores are used for attention computation.

  4. Sparse Attention Computation: By computing attention only with selected keys, DMA substantially reduces the computational complexity.

  5. Dynamic Per-Query Selection: Unlike static sparse patterns, the selection is dynamic and specific to each query.

DMA Core Algorithm

The DMA algorithm consists of the following steps (a consolidated sketch follows the list):

  1. Value Transformation: Project value states using a learned matrix to create importance scores:

    dt_result = matmul(value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), dt_proj.T)
    
  2. Importance Score Generation: Apply activation function and scaling to generate scores:

    zero_hold_states = exp(softplus(dt_result) * A)
    
  3. Masking: Apply causal or other masking to the importance scores if needed:

    zero_hold_state = zero_hold_states[b_idx, kv_idx, q_idx, :].masked_fill(causal_mask[b_idx, 0, q_idx, :] != 0, 0)
    
  4. Top-K Selection: Select the most important keys based on scores:

    topk_values, topk_indices = torch.topk(zero_hold_state, keep_window_size, dim=-1)
    dynamic_mask = torch.zeros_like(zero_hold_state)
    dynamic_mask.scatter_(-1, topk_indices, topk_values)
    
  5. Sparse Attention Computation: Compute attention only for the selected keys:

    mask_indices = non_zero_mask_indices(dynamic_mask)
    k_vecs = key_states[b_idx, kv_idx, mask_indices, :]
    v_vecs = value_states[b_idx, kv_idx, mask_indices, :]
    
  6. Weighted Sum Computation: Calculate the final attention output:

    attn_weight = torch.sum(q_vec.unsqueeze(0) * k_vecs, dim=-1)
    attn_weight = attn_weight + dynamic_mask[mask_indices]
    attn_weight = F.softmax(attn_weight, dim=-1)
    attn_output = torch.sum(attn_weight.unsqueeze(1) * v_vecs, dim=0)
    

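To make the control flow above concrete, the following is a small self-contained PyTorch sketch of the six steps, written with dense tensors and with the selected keys expressed as an additive bias rather than an explicit index gather. The function name dma_reference is illustrative, the causal path assumes equal query and key lengths, and softmax scaling is omitted to mirror the fragments above:

import torch
import torch.nn.functional as F

def dma_reference(query_states, key_states, value_states, dt_proj, A, keep_window_size, causal=True):
    # query_states: [B, H, Lq, D]; key_states/value_states: [B, H_kv, Lk, D]
    # dt_proj: [H_kv, H_kv * D]; A: [H_kv]
    B, H, Lq, D = query_states.shape
    H_kv, Lk = key_states.shape[1], key_states.shape[2]
    groups = H // H_kv                                             # query heads per key-value head (MQA/GQA)

    # Steps 1-2: value transformation and importance scores
    dt_result = torch.matmul(value_states.transpose(-2, -3).reshape(B, Lk, -1), dt_proj.T)   # [B, Lk, H_kv]
    zero_hold = torch.exp(F.softplus(dt_result) * A).transpose(-1, -2)                       # [B, H_kv, Lk]
    zero_hold = zero_hold.unsqueeze(-2).expand(-1, -1, Lq, -1).clone()                       # [B, H_kv, Lq, Lk]

    # Step 3: causal masking of the importance scores
    if causal:
        future = torch.triu(torch.ones(Lq, Lk, dtype=torch.bool, device=zero_hold.device), diagonal=1)
        zero_hold.masked_fill_(future, 0.0)

    # Step 4: keep only the top-k keys per query (dense mask representation)
    if Lk > keep_window_size:
        topk_values, topk_indices = torch.topk(zero_hold, keep_window_size, dim=-1)
        dynamic_mask = torch.zeros_like(zero_hold).scatter_(-1, topk_indices, topk_values)
    else:
        dynamic_mask = zero_hold

    # Steps 5-6: attention restricted to the selected keys, with the mask values added as a bias
    # (unselected keys receive -inf so they drop out of the softmax, mirroring the index gathering above)
    bias = dynamic_mask.masked_fill(dynamic_mask == 0, float("-inf")).repeat_interleave(groups, dim=1)
    k = key_states.repeat_interleave(groups, dim=1)
    v = value_states.repeat_interleave(groups, dim=1)
    scores = torch.matmul(query_states, k.transpose(-1, -2)) + bias
    attn = F.softmax(scores, dim=-1)
    return torch.matmul(attn, v)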
DMA Implementation Details

The Dynamic Mask Attention implementation from dma.py has several notable features:

Value Transformation and Importance Scoring

dt_result = torch.matmul(value_states.transpose(-2, -3).reshape(batch_size, key_len, -1), dt_proj.T)
zero_hold_states = torch.exp(F.softplus(dt_result) * A).transpose(-1, -2).unsqueeze(-2).expand(-1, -1, query_len, -1)

This transformation is crucial as it determines which keys are important for each query. The dt_proj matrix and coefficient A are learned parameters that control the selection process.

Key Selection via Top-K

if key_len > keep_window_size:
    topk_values, topk_indices = torch.topk(zero_hold_state, keep_window_size, dim=-1)
    dynamic_mask = torch.zeros_like(zero_hold_state)
    dynamic_mask.scatter_(-1, topk_indices, topk_values)
else:
    dynamic_mask = zero_hold_state

This selective process is what gives DMA its computational advantage. By processing only keep_window_size keys instead of the full key_len, it reduces the computational complexity significantly; with key_len = 16,384 and keep_window_size = 1,024, for example, each query scores only 1/16 of the keys.

Sparse Attention Computation

mask_indices = non_zero_mask_indices(dynamic_mask)
if len(mask_indices) == 0:
    continue

k_vecs = key_states[b_idx, kv_idx, mask_indices, :] # [keep_window_size, head_dim]
v_vecs = value_states[b_idx, kv_idx, mask_indices, :] # [keep_window_size, head_dim]

This sparse computation is fundamentally different from both standard attention and Flash Attention. Instead of processing all keys, it only processes the selected ones, which can be a small fraction of the total.

Multi-Query Attention Support

for q_group_idx in range(num_queries_per_kv):
    h_idx = kv_idx * num_queries_per_kv + q_group_idx
    q_vec = query_states[b_idx, h_idx, q_idx, :] # [head_dim]
    
    # Compute attention and output for this query
    # ...

DMA naturally supports multi-query attention (MQA) and grouped-query attention (GQA) where multiple query heads can share the same key-value pairs.

DMA Performance Characteristics

Dynamic Mask Attention offers distinct performance advantages:

  1. Computational Complexity: Reduces computation from O(N²) to O(N*k) where k is the number of selected keys (typically k << N).

  2. Memory Usage: The Python implementation still requires O(N²) memory for initialization, but a CUDA implementation could achieve O(N) memory usage.

  3. Adaptability: The key selection adapts to the content, making it more effective for diverse attention patterns compared to fixed sparse patterns.

  4. Scalability: Performance improvements increase with sequence length, similar to Flash Attention but through a different mechanism.

  5. Training Dynamics: The key selection mechanism is learned during training, allowing the model to adaptively focus on relevant information.

Comparative Analysis

Standard Attention vs. Flash Attention vs. DMA

Here's a comparison of the three attention mechanisms:

| Feature | Standard Attention | Flash Attention | Dynamic Mask Attention |
|---|---|---|---|
| Computational Complexity | O(N²) | O(N²) | O(N*k) where k << N |
| Memory Complexity | O(N²) | O(N) | O(N²) in Python, O(N) possible in CUDA |
| Key Processing Strategy | All keys | All keys | Selected top-k keys |
| Implementation Approach | Dense matmul | Block-based tiling | Sparse selection and computation |
| Masking Support | Fixed masks | Fixed masks | Learned, dynamic masks |
| MQA/GQA Support | Requires adaptation | Specialized variants | Native support |

Memory and Computational Complexity

Standard Attention:

  • Computes and stores the entire N×N attention matrix
  • Memory usage: O(N²)
  • Computation: O(N²D) where D is head dimension

Flash Attention:

  • Never materializes the full attention matrix
  • Memory usage: O(N) + O(B²) where B is block size
  • Computation: Still O(N²D) but with better constants and memory locality

Dynamic Mask Attention:

  • Only computes attention for selected keys
  • Memory usage: O(N²) in naive implementation, O(N) in optimized version
  • Computation: O(NkD) where k is the number of selected keys

Theoretical Performance Model

The integrated Flash-DMA approach offers significant performance benefits that can be quantified:

Memory Complexity:

  • Standard Attention: O(B×H×N²)
  • Flash Attention: O(B×H×N)
  • Flash-DMA: O(B×H×N)

Where B = batch size, H = number of heads, N = sequence length

Computational Complexity:

  • Standard Attention: O(B×H×N²×D)
  • Flash Attention: O(B×H×N²×D)
  • Flash-DMA: O(B×H×N×k×D)

Where D = head dimension, k = average number of selected keys per query

Expected Speedup Model:

  • For sequence length N and selection ratio r = k/N:
    • Theoretical speedup vs. Flash Attention: ~1/r
    • Practical speedup accounting for overhead: ~1/(r + c)

Where c is an implementation-dependent constant representing overhead (estimated 0.05-0.1)

Projected Performance:

| Sequence Length | Selection Ratio | Theoretical Speedup | Estimated Practical Speedup |
|---|---|---|---|
| 1,024 | 0.2 | 5.0× | 3.3-4.0× |
| 4,096 | 0.1 | 10.0× | 5.0-6.7× |
| 16,384 | 0.05 | 20.0× | 6.7-10.0× |
| 65,536 | 0.025 | 40.0× | 8.0-13.3× |

Note: These estimates assume efficient sparse operations implementation and may vary based on hardware and specific workloads.
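The projected numbers follow directly from the speedup model above; a few lines of Python reproduce them, with the overhead range c = 0.05-0.1 taken as the stated assumption:

def dma_speedup(selection_ratio, overhead=(0.05, 0.1)):
    # Theoretical speedup is 1/r; practical speedup is 1/(r + c) over the assumed overhead range
    theoretical = 1.0 / selection_ratio
    practical = tuple(1.0 / (selection_ratio + c) for c in sorted(overhead, reverse=True))
    return theoretical, practical  # practical = (lower bound, upper bound)

for seq_len, r in [(1024, 0.2), (4096, 0.1), (16384, 0.05), (65536, 0.025)]:
    theo, (low, high) = dma_speedup(r)
    print(f"N={seq_len:>6}  r={r:<5}  theoretical={theo:.1f}x  practical~{low:.1f}-{high:.1f}x")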

Use Cases and Tradeoffs

Standard Attention:

  • Best for short sequences where memory is not a constraint
  • Simplest to implement and debug
  • Compatible with all existing optimization techniques

Flash Attention:

  • Ideal for medium to long sequences
  • When memory bandwidth is the bottleneck
  • When exact attention computation is required

Dynamic Mask Attention:

  • Best for very long sequences where computational cost is prohibitive
  • When the attention pattern is naturally sparse
  • When approximate attention is acceptable

Combined Flash-DMA:

  • Optimal for extremely long sequences (tens of thousands of tokens)
  • When both memory and computation are constraints
  • For applications requiring selective attention with efficient memory usage

Flash-DMA Integration

Motivation and Benefits

The integration of Flash Attention and Dynamic Mask Attention creates a powerful combination:

  1. Complementary Strengths: Flash Attention optimizes memory usage, while DMA reduces computation through selective key processing.

  2. Extended Sequence Length Support: The combined approach could efficiently handle sequences of 100K tokens or more.

  3. Memory and Computation Optimization: Achieves both O(N) memory complexity and O(N*k) computational complexity.

  4. Hardware Efficiency: Maintains Flash Attention's optimized memory access patterns while reducing the number of operations.

  5. Adaptive Processing: The dynamic selection mechanism allows the model to focus computational resources on the most relevant parts of the input.

Architectural Overview

The integrated Flash-DMA approach modifies the Flash Attention algorithm in three key ways:

  1. Key Selection Phase: Adds a preprocessing step that determines important keys using the DMA selection mechanism.

  2. Sparse Block Processing: Modifies the block-based processing to only compute attention for selected keys within each block.

  3. Memory Management for Selected Indices: Adds efficient handling of the selected key indices.

The high-level architecture looks like this:

┌─────────────────┐      ┌───────────────────┐      ┌──────────────────┐
│                 │      │                   │      │                  │
│  Key Selection  │─────▶│  Sparse Block     │─────▶│  Output          │
│  Phase          │      │  Processing       │      │  Normalization   │
│                 │      │                   │      │                  │
└─────────────────┘      └───────────────────┘      └──────────────────┘

Key Selection Integration

The key selection phase can be integrated as a preprocessing step:

template<typename Kernel_traits, bool Is_causal, typename Params>
inline __device__ void compute_key_importance(
    const Params &params,
    Tensor& value_states,
    Tensor& dt_proj,
    Tensor& importance_scores,
    float scale_factor
) {
    // Get thread and block indices
    const int tidx = threadIdx.x;
    const int bidb = blockIdx.y;
    const int bidh = blockIdx.z;
    
    // Transform values using the projection matrix
    // This is equivalent to: dt_result = matmul(value_states, dt_proj.T)
    // But implemented as a block-based matrix multiplication
    
    // Apply softplus activation and scaling
    // This is equivalent to: importance_scores = exp(softplus(dt_result) * scale_factor)
    // But implemented in a numerically stable way
    
    // Handle causal masking if needed
    if (Is_causal) {
        apply_causal_mask(importance_scores, params);
    }
}

This would be followed by a parallel top-k selection kernel that identifies the most important keys:

template<typename Kernel_traits, typename Params>
inline __device__ void select_top_k_keys(
    const Params &params,
    Tensor& importance_scores,
    Tensor& selected_indices,
    int keep_window_size
) {
    // Use parallel reduction to find the top-k values and their indices
    // Store the result in selected_indices
}

Block Processing Modifications

The core block processing loop would be modified to only process selected keys:

for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
    // Determine which keys in this block were selected
    int block_start = n_block * kBlockN;
    int block_end = min((n_block + 1) * kBlockN, params.seqlen_k);
    
    // Get indices of selected keys in this block
    Tensor block_selected_indices = get_block_selected_indices(
        selected_indices, block_start, block_end
    );
    
    if (block_selected_indices.size() == 0) {
        // Skip this block if no keys were selected
        continue;
    }
    
    // Load only selected keys and values
    load_selected_kv(block_selected_indices, tKgK, tKsK, tVgV, tVsV);
    
    // Compute Q*K^T only for selected keys
    compute_sparse_qk(acc_s, tSrQ, tSrK, block_selected_indices);
    
    // Apply masking, softmax, etc. similar to Flash Attention
    
    // Compute attention output with selected V
    compute_sparse_attention_output(acc_o, acc_s, tOrVt, block_selected_indices);
}

Memory Management Strategies

Efficient management of selected key indices is critical for performance:

  1. Global Storage: Store all selected indices in global memory, compressed to minimize space.

  2. Block-Level Filtering: For each block, filter the global indices to identify which ones fall within the current block.

  3. Shared Memory Caching: Load relevant indices for the current block into shared memory for fast access.

__shared__ int smem_indices[MAX_SELECTED_PER_BLOCK];
__shared__ int smem_num_selected;

if (threadIdx.x == 0) {
    // Initialize counter
    smem_num_selected = 0;
}
__syncthreads();

// Each thread processes some subset of the global indices
for (int i = threadIdx.x; i < num_global_selected; i += blockDim.x) {
    int global_idx = global_selected_indices[i];
    if (global_idx >= block_start && global_idx < block_end) {
        // Use atomic add to get a unique position in the shared memory array
        int pos = atomicAdd(&smem_num_selected, 1);
        if (pos < MAX_SELECTED_PER_BLOCK) {
            smem_indices[pos] = global_idx - block_start;  // Convert to block-local index
        }
    }
}
__syncthreads();

Technical Implementation Challenges

Sparse Operations in CUDA

Implementing sparse operations efficiently in CUDA presents several challenges:

  1. Irregular Memory Access: Accessing only selected elements leads to non-coalesced memory access patterns, which can significantly degrade performance.

  2. Sparse Matrix Multiplication: Efficiently computing QK^T and attentionV when only a subset of K is used requires specialized sparse matrix multiplication routines.

  3. Dynamic Sparsity Pattern: Unlike static sparse matrices, the sparsity pattern in DMA is determined at runtime and differs for each query.

Potential solutions include:

  1. Specialized Sparse Kernels: Implementing optimized CUDA kernels for the specific sparsity patterns encountered in DMA.

  2. Coalescing Through Reordering: Reordering selected keys to improve memory access patterns (see the sketch after this list).

  3. Batched Processing: Grouping queries with similar selected keys to reduce divergence.
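As a host-side illustration of the reordering idea, sorting each query's selected indices before gathering makes the loads walk the key rows in ascending order, which is the same trick a CUDA kernel would use to keep global-memory accesses closer to coalesced. The helper below is a PyTorch-level sketch, not the kernel itself:

import torch

def gather_selected_keys(key_states, selected_indices):
    # key_states: [B, H_kv, seqlen_k, D]; selected_indices: [B, H_kv, seqlen_q, k] (unordered top-k output)
    # Sorting per query means the subsequent gather reads key rows in ascending order.
    sorted_idx, _ = torch.sort(selected_indices, dim=-1)
    b = torch.arange(key_states.shape[0], device=key_states.device)[:, None, None, None]
    h = torch.arange(key_states.shape[1], device=key_states.device)[None, :, None, None]
    gathered = key_states[b, h, sorted_idx]   # [B, H_kv, seqlen_q, k, D]
    return gathered, sorted_idx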

Concrete CUDA Implementation Details

The core Flash-DMA integration requires specific CUDA implementation techniques:

Key Selection Kernel

template<typename Kernel_traits, bool Is_causal>
__global__ void compute_and_select_keys_kernel(
    const typename Kernel_traits::Params params,
    int* selected_indices,       // Output: indices of selected keys [batch, heads, query, top_k]
    float* importance_scores     // Optional: store importance scores for debugging
) {
    // Block/thread indices
    const int tid = threadIdx.x;
    const int bid = blockIdx.x;
    const int batch_id = blockIdx.y;
    const int head_id = blockIdx.z;
    
    // Shared memory for collaborative filtering
    __shared__ float sm_scores[BLOCK_SIZE];
    __shared__ int sm_indices[BLOCK_SIZE];
    
    // Each thread computes importance scores for a subset of keys.
    // For clarity, this sketch assumes the keys fit in a single BLOCK_SIZE pass,
    // so the top-k selection below sees all keys at once.
    for (int key_idx = tid; key_idx < params.seqlen_k; key_idx += BLOCK_SIZE) {
        // Project value to get importance score using dt_proj
        float score = 0.0f;
        for (int d = 0; d < params.d; d++) {
            int v_idx = batch_id * params.v_batch_stride + key_idx * params.v_row_stride + 
                        head_id * params.v_head_stride + d;
            int proj_idx = head_id * params.dt_head_stride + d;
            score += params.v_ptr[v_idx] * params.dt_proj_ptr[proj_idx];
        }
        
        // Apply softplus and scaling, matching exp(softplus(x) * A) from the reference
        // (exp is monotonic, so it does not change which keys are selected)
        score = expf(log1pf(expf(score)) * params.a_coef_ptr[head_id]);
        
        // Apply causal masking if needed: keys strictly after the query position are excluded
        if (Is_causal && key_idx > params.query_positions[bid]) {
            score = -INFINITY;
        }
        
        // Store in shared memory
        sm_scores[tid] = score;
        sm_indices[tid] = key_idx;
        __syncthreads();
        
        // Parallel reduction for top-k
        for (int k = 0; k < params.keep_window_size && k < BLOCK_SIZE; k++) {
            // Find max score and its position
            float max_score = -INFINITY;
            int max_pos = -1;
            for (int i = 0; i < BLOCK_SIZE; i++) {
                if (sm_scores[i] > max_score) {
                    max_score = sm_scores[i];
                    max_pos = i;
                }
            }
            
            // Only thread 0 writes the result
            if (tid == 0 && max_pos >= 0) {
                int out_idx = batch_id * params.batch_stride + head_id * params.head_stride + 
                             bid * params.query_stride + k;
                selected_indices[out_idx] = sm_indices[max_pos];
                if (importance_scores != nullptr) {
                    importance_scores[out_idx] = sm_scores[max_pos];
                }
            }
            __syncthreads();
            
            // Mark the max element as processed
            if (tid == max_pos) {
                sm_scores[tid] = -INFINITY;
            }
            __syncthreads();
        }
    }
}

Sparse Block Processing

template<typename Kernel_traits, bool Is_causal>
__device__ void process_sparse_block(
    const typename Kernel_traits::Params params,
    const int* selected_indices,    // [batch, heads, query, top_k]
    int block_start,                // Starting key index of current block
    int block_end,                  // Ending key index of current block
    Tensor& tSrQ,                   // Query in registers
    Tensor& acc_s,                  // Accumulator for scores
    Tensor& acc_o                   // Accumulator for output
) {
    const int tid = threadIdx.x;
    const int batch_id = blockIdx.y;
    const int head_id = blockIdx.z;
    const int query_block = blockIdx.x;
    const int query_idx = query_block * BLOCK_SIZE + tid / 32;  // Query index for this thread
    
    // Find which selected keys fall into this block
    __shared__ int sm_block_indices[MAX_SELECTED_PER_BLOCK];
    __shared__ int sm_block_count;
    
    if (tid == 0) {
        sm_block_count = 0;
    }
    __syncthreads();
    
    // Each thread checks some of the selected indices
    const int idx_offset = batch_id * params.batch_stride + head_id * params.head_stride + 
                          query_idx * params.query_stride;
    
    for (int k = tid % 32; k < params.keep_window_size; k += 32) {
        int key_idx = selected_indices[idx_offset + k];
        if (key_idx >= block_start && key_idx < block_end) {
            int pos = atomicAdd(&sm_block_count, 1);
            if (pos < MAX_SELECTED_PER_BLOCK) {
                sm_block_indices[pos] = key_idx - block_start;  // Convert to block-local index
            }
        }
    }
    __syncthreads();
    
    // Process only selected keys in this block
    for (int i = 0; i < sm_block_count; i++) {
        // Load key and compute attention score for selected key
        int local_key_idx = sm_block_indices[i];
        float key_val = load_key_value(params, batch_id, head_id, block_start + local_key_idx);
        
        // Compute score and update accumulators similar to Flash Attention
        // But only for the selected keys
        // ...
    }
}

Load Balancing and Warp Efficiency

CUDA threads within a warp execute in lockstep, making load balancing critical:

  1. Thread Divergence: When different threads process different numbers of keys, warp divergence can severely impact performance.

  2. Workload Distribution: Efficiently distributing the selected keys across threads and warps to maximize utilization.

  3. Idle Threads: Managing threads that have no keys to process in their assigned range.

Strategies to address these challenges:

  1. Work Stealing: Implementing work-stealing algorithms to redistribute work among threads.

  2. Warp-Level Primitives: Using warp-level voting and shuffle operations for efficient coordination.

  3. Persistent Threads: Keeping threads active and continuously assigning new work as it becomes available.

Cooperative Block Mapping

Instead of assigning fixed thread responsibilities, we implement cooperative mapping where threads dynamically process available work:

__device__ void cooperative_sparse_processing(
    const Params& params,
    int* selected_indices,
    int num_selected
) {
    __shared__ int work_counter;
    if (threadIdx.x == 0) work_counter = 0;
    __syncthreads();
    
    while (true) {
        // Atomically grab the next chunk of work
        int work_idx = -1;
        if (threadIdx.x % 32 == 0) {
            work_idx = atomicAdd(&work_counter, WORK_CHUNK_SIZE);
        }
        // Broadcast result to all threads in warp
        work_idx = __shfl_sync(0xffffffff, work_idx, 0);
        
        if (work_idx >= num_selected) break;
        
        // Process this chunk of work
        int end_idx = min(work_idx + WORK_CHUNK_SIZE, num_selected);
        for (int i = work_idx + threadIdx.x % 32; i < end_idx; i += 32) {
            // Process selected_indices[i]
        }
    }
}

Density-Based Processing Strategy

We dynamically choose between sparse and dense processing based on key selection density:

  1. Query Binning: Group queries based on the number of selected keys in each block

    __shared__ int sparse_queries[MAX_QUERIES_PER_BLOCK];
    __shared__ int dense_queries[MAX_QUERIES_PER_BLOCK];
    __shared__ int sparse_count, dense_count;
    
    // Initialize the bin counters before any thread writes to them
    if (threadIdx.x == 0) {
        sparse_count = 0;
        dense_count = 0;
    }
    __syncthreads();
    
    // Determine processing mode for each query
    if (threadIdx.x < num_queries_in_block) {
        int query_idx = block_query_base + threadIdx.x;
        int selected_in_block = count_selected_keys_in_block(query_idx, block_idx);
        float density = (float)selected_in_block / BLOCK_SIZE;
        
        if (density > DENSITY_THRESHOLD) {
            int idx = atomicAdd(&dense_count, 1);
            dense_queries[idx] = query_idx;
        } else {
            int idx = atomicAdd(&sparse_count, 1);
            sparse_queries[idx] = query_idx;
        }
    }
  2. Two-Phase Processing: Process dense queries first, then sparse queries

    // Process dense queries (standard Flash Attention)
    for (int i = 0; i < dense_count; i++) {
        process_query_dense(dense_queries[i]);
    }
    
    // Process sparse queries (DMA approach)
    for (int i = 0; i < sparse_count; i++) {
        process_query_sparse(sparse_queries[i]);
    }

Dynamic Workload Distribution

For highly variable workloads, implement dynamic redistribution:

  1. Work Queue System: Maintain a queue of pending work
  2. Persistent Threads: Keep threads active and pulling from queue
  3. Work Stealing: Allow idle blocks to steal work from busy ones
// Global work queue in device memory
struct WorkQueue {
    int queue[MAX_WORK_ITEMS];
    int head;
    int tail;
};

__device__ void process_with_persistent_threads(WorkQueue* queue) {
    while (true) {
        // Lane 0 of each warp atomically grabs the next work item; the warp-level
        // shuffle below only broadcasts within a warp, so the atomicAdd must be
        // done per warp rather than by thread 0 of the block alone
        int work_idx = -1;
        if (threadIdx.x % 32 == 0) {
            work_idx = atomicAdd(&queue->head, 1);
        }
        work_idx = __shfl_sync(0xffffffff, work_idx, 0);
        
        if (work_idx < 0 || work_idx >= queue->tail) return;
        
        // Process this work item
        process_work_item(queue->queue[work_idx]);
    }
}

These strategies ensure high GPU utilization even with irregular sparsity patterns, minimizing the impact of thread divergence and load imbalance.

Validation Strategy

Ensuring correctness and performance:

  1. Correctness Validation:

    • Compare outputs against standard attention and a PyTorch DMA reference for small examples (a sketch follows this list)
    • Validate intermediate results at each stage
    • Test with various mask configurations and sequence lengths
  2. Performance Validation:

    • Benchmark against Flash Attention and DMA separately
    • Test with varying sequence lengths and batch sizes
    • Measure memory usage and computational throughput
  3. Integration Testing:

    • Verify behavior when integrated into transformer models
    • Test impact on model convergence and accuracy
    • Validate across different hardware platforms
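A minimal correctness check of the first point could compare the fused kernel against a slow PyTorch reference on a small random problem, as sketched below. Both callables are passed in: flash_dma_attention is the proposed entry point from the API section, and reference_attention is assumed to be a dense PyTorch implementation of the same computation. The shapes and tolerances are placeholders:

import torch

def test_against_reference(flash_dma_attention, reference_attention, atol=1e-2, rtol=1e-3):
    torch.manual_seed(0)
    batch, heads, kv_heads, seq_len, head_dim = 2, 8, 2, 128, 64
    q = torch.randn(batch, seq_len, heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seq_len, kv_heads, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, seq_len, kv_heads, head_dim, device="cuda", dtype=torch.float16)
    dt_proj = torch.randn(kv_heads, kv_heads * head_dim, device="cuda", dtype=torch.float16)
    a_coef = torch.rand(kv_heads, device="cuda", dtype=torch.float16)

    out_fused = flash_dma_attention(q, k, v, dt_proj, a_coef, keep_window_size=64, causal=True)
    out_ref = reference_attention(q, k, v, dt_proj, a_coef, keep_window_size=64, causal=True)

    # FP16 kernels will not match bit-for-bit; the tolerances are meant to be tuned per dtype.
    torch.testing.assert_close(out_fused, out_ref, atol=atol, rtol=rtol)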

Comprehensive Benchmarking Framework

To rigorously evaluate the Flash-DMA integration, we propose a structured benchmarking framework:

Performance Metrics

| Metric | Description | Measurement Method |
|---|---|---|
| Throughput | Tokens/second processed | Time entire forward pass and divide by total tokens |
| Memory Usage | Peak memory consumption | Track GPU memory allocation high-water mark |
| Computational Efficiency | FLOPS utilization | Compare achieved vs. theoretical FLOPS |
| Sparsity Efficiency | Speedup relative to selection density | Measure performance across varying selection ratios |
| Scaling Efficiency | Performance vs. sequence length | Benchmark with exponentially increasing lengths |

Benchmark Scenarios

  1. Synthetic Benchmarks

    • Uniform random data
    • Controlled sparsity patterns
    • Variable sequence lengths (128 to 128K)
    • Different batch sizes and head configurations
  2. Real-world Workloads

    • Language modeling (long document processing)
    • Vision transformers (high-resolution images)
    • Multi-modal transformers
    • Time series analysis
  3. Ablation Studies

    • Effect of selection ratio (k/N)
    • Impact of block sizes
    • Influence of key selection algorithms
    • Dense vs. sparse block processing thresholds

Comparison Methodology

import numpy as np
import torch

WARMUP_STEPS, BENCHMARK_STEPS = 5, 20  # illustrative defaults

def run_benchmark_suite(models, datasets, configs):
    results = {}
    
    for model_name, model_fn in models.items():
        for dataset_name, dataset in datasets.items():
            for config_name, config in configs.items():
                # Initialize model with configuration
                model = model_fn(**config)
                # Each config is assumed to carry the problem size used below
                batch_size, seq_len = config["batch_size"], config["seq_len"]
                
                # Warm-up runs
                for _ in range(WARMUP_STEPS):
                    run_forward(model, dataset.get_batch(batch_size))
                
                # Timed runs
                times = []
                mem_usage = []
                for i in range(BENCHMARK_STEPS):
                    batch = dataset.get_batch(batch_size)
                    
                    # Reset the peak counter, then record the current allocation
                    torch.cuda.reset_peak_memory_stats()
                    mem_before = torch.cuda.memory_allocated()
                    
                    # Time the forward pass
                    start = torch.cuda.Event(enable_timing=True)
                    end = torch.cuda.Event(enable_timing=True)
                    start.record()
                    output = run_forward(model, batch)
                    end.record()
                    torch.cuda.synchronize()
                    
                    # Record metrics
                    elapsed_time = start.elapsed_time(end) / 1000  # seconds
                    times.append(elapsed_time)
                    mem_used = torch.cuda.max_memory_allocated() - mem_before
                    mem_usage.append(mem_used)
                
                # Save results
                results[f"{model_name}_{dataset_name}_{config_name}"] = {
                    "mean_time": np.mean(times),
                    "std_time": np.std(times),
                    "tokens_per_second": batch_size * seq_len / np.mean(times),
                    "mean_memory": np.mean(mem_usage),
                    "peak_memory": np.max(mem_usage)
                }
    
    return results

Visualization and Analysis Tools

  1. Performance Curves

    • Speedup vs. sequence length (an example plotting helper follows this list)
    • Memory usage vs. sequence length
    • Throughput vs. sparsity ratio
  2. Profiling Integration

    • NVIDIA Nsight integration
    • Kernel execution timelines
    • Memory access pattern analysis
  3. Automated Regression Testing

    • CI/CD integration
    • Comparison against baseline implementations
    • Performance regression alerts
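As an example of the first kind of curve, a small matplotlib helper can plot speedup against sequence length from the results dictionary produced by run_benchmark_suite above; the key naming scheme assumed here (model name, dataset name encoding the length, config name) is illustrative:

import matplotlib.pyplot as plt

def plot_speedup_vs_seqlen(results, baseline="flash_attention", candidate="flash_dma",
                           seq_lens=(1024, 4096, 16384, 65536)):
    # Keys are assumed to follow the f"{model}_{dataset}_{config}" convention from run_benchmark_suite,
    # with the dataset name encoding the sequence length, e.g. "synthetic_4096".
    speedups = [results[f"{baseline}_synthetic_{n}_default"]["mean_time"] /
                results[f"{candidate}_synthetic_{n}_default"]["mean_time"] for n in seq_lens]
    plt.plot(seq_lens, speedups, marker="o")
    plt.xscale("log")
    plt.xlabel("Sequence length")
    plt.ylabel("Speedup over Flash Attention")
    plt.title("Flash-DMA speedup vs. sequence length")
    plt.savefig("speedup_vs_seqlen.png", dpi=150)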

This comprehensive benchmarking framework will provide actionable insights for optimizing the Flash-DMA implementation and quantify its benefits across diverse workloads and configurations.

API Design

Python Interface

A user-friendly Python API for Flash-DMA:

def flash_dma_attention(
    query: torch.Tensor,               # [batch_size, seq_len_q, num_heads, head_dim]
    key: torch.Tensor,                 # [batch_size, seq_len_k, num_kv_heads, head_dim]
    value: torch.Tensor,               # [batch_size, seq_len_k, num_kv_heads, head_dim]
    dt_proj: torch.Tensor,             # [num_kv_heads, num_kv_heads * head_dim]
    a_coef: torch.Tensor,              # [num_kv_heads]
    keep_window_size: int = 1024,      # Number of keys to keep per query
    dropout_p: float = 0.0,
    causal: bool = False,
    softmax_scale: Optional[float] = None,
    return_attn_probs: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Compute attention using the integrated Flash-DMA approach.
    
    Args:
        query: Query tensor
        key: Key tensor
        value: Value tensor
        dt_proj: Projection matrix for value transformation
        a_coef: Scaling coefficient for importance scores
        keep_window_size: Number of keys to keep per query
        dropout_p: Dropout probability
        causal: Whether to apply causal masking
        softmax_scale: Scale factor for softmax (default: 1/sqrt(head_dim))
        return_attn_probs: Whether to return attention probabilities
    
    Returns:
        attention_output: Output tensor
        attention_probs: Attention probabilities (optional)
    """
    # Implementation
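A typical call might look like the following; shapes follow the signature above, and the import path flash_dma is a placeholder for wherever the extension is ultimately exposed:

import torch
from flash_dma import flash_dma_attention  # placeholder import path

batch, seq_len, num_heads, num_kv_heads, head_dim = 2, 8192, 32, 8, 128
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.bfloat16)
dt_proj = torch.randn(num_kv_heads, num_kv_heads * head_dim, device="cuda", dtype=torch.bfloat16)
a_coef = torch.rand(num_kv_heads, device="cuda", dtype=torch.bfloat16)

out = flash_dma_attention(q, k, v, dt_proj, a_coef, keep_window_size=1024, causal=True)
print(out.shape)  # [batch, seq_len, num_heads, head_dim]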

Configuration Options

Configuration options to control behavior (an example configuration follows the list):

  1. Selection Parameters:

    • keep_window_size: Number of keys to keep per query
    • selection_method: Algorithm for top-k selection ("exact", "approximate")
    • min_density_threshold: Minimum density for dense processing
  2. Processing Options:

    • block_size: Size of blocks for processing
    • mixed_processing: Whether to use mixed dense/sparse processing
    • use_reordering: Whether to reorder keys for better memory access
  3. Memory Management:

    • max_sequence_length: Maximum supported sequence length
    • max_batch_size: Maximum supported batch size
    • max_selection_ratio: Maximum ratio of keys to select (for memory allocation)
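Collected in one place, these options might be passed as a plain dictionary; the names mirror the list above and the default values shown are illustrative rather than tuned:

flash_dma_config = {
    # Selection parameters
    "keep_window_size": 1024,
    "selection_method": "exact",       # or "approximate"
    "min_density_threshold": 0.5,
    # Processing options
    "block_size": 128,
    "mixed_processing": True,
    "use_reordering": True,
    # Memory management
    "max_sequence_length": 131072,
    "max_batch_size": 8,
    "max_selection_ratio": 0.25,
}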

Integration with Existing Frameworks

Seamless integration with existing frameworks:

  1. PyTorch Integration:

    • Drop-in replacement for torch.nn.MultiheadAttention
    • Compatible with PyTorch's autograd system
    • Support for distributed training
  2. Hugging Face Transformers:

    • Compatible with Hugging Face attention implementations
    • Integration with popular transformer architectures
    • Support for flash-attention configuration options
  3. NVIDIA Optimizations:

    • Compatibility with NVIDIA's Deep Learning Examples
    • Support for TensorRT integration
    • Optimizations for different GPU architectures

Conclusion

Summary of Benefits

The integrated Flash-DMA approach offers significant advantages:

  1. Memory Efficiency: Maintains Flash Attention's O(N) memory complexity.

  2. Computational Efficiency: Achieves DMA's O(N*k) computational complexity.

  3. Scalability: Enables efficient processing of extremely long sequences (100K tokens and beyond).

  4. Adaptive Processing: Focuses computational resources on the most important keys.

  5. Hardware Optimization: Maximizes GPU utilization through careful memory management and access patterns.

Future Directions

Potential areas for future research and development:

  1. Automatic Parameter Tuning: Dynamically adjust the number of keys to select based on sequence content and hardware capabilities.

  2. Multi-GPU Scaling: Extend the algorithm for efficient multi-GPU implementation to handle even longer sequences.

  3. Alternative Selection Criteria: Explore different mechanisms for determining key importance.

  4. Architecture-Specific Optimizations: Develop specialized versions for different GPU architectures.

  5. Integration with Other Attention Variants: Combine with other attention optimizations like linear attention or gated attention.

This integrated approach represents a significant step forward in making transformer models more efficient and capable of handling longer contexts, potentially enabling new applications in long-document processing, genomics, and other domains requiring analysis of very long sequences.
