Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/native/kernels/flash_attention.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// FlashAttention-2 forward + backward CUDA kernels
// Tiled attention with online softmax, O(S) memory instead of O(S^2).

#define BLOCK_SIZE 32

extern "C" __global__
void flash_attention_forward_f32(
float* __restrict__ out, // [B*H, S, D]
Expand All @@ -16,7 +14,7 @@ void flash_attention_forward_f32(
int causal // 1 for causal masking
) {
int bh = blockIdx.x;
int row = blockIdx.y * BLOCK_SIZE + threadIdx.y;
int row = blockIdx.y * blockDim.y + threadIdx.y;
int d = threadIdx.x;

if (row >= S || d >= D) return;
Expand Down Expand Up @@ -70,7 +68,7 @@ void flash_attention_backward_f32(
int causal
) {
int bh = blockIdx.x;
int row = blockIdx.y * BLOCK_SIZE + threadIdx.y;
int row = blockIdx.y * blockDim.y + threadIdx.y;
int d = threadIdx.x;

if (row >= S || d >= D) return;
Expand Down
14 changes: 8 additions & 6 deletions src/native/src/ops/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ pub fn flash_attention(
let lse_ptr = store.dev_ptr(lse_id);

let causal_i = if causal { 1i32 } else { 0i32 };
let block_size = 32u32;
let grid = (bh as u32, (s as u32 + block_size - 1) / block_size, 1);
let block = (d.min(32) as u32, block_size, 1);
let block_x = d as u32;
let block_y = (1024u32 / block_x).min(32);
let grid = (bh as u32, (s as u32 + block_y - 1) / block_y, 1);
let block = (block_x, block_y, 1);

let func = dev.get_func("flash_attention_forward_f32");
unsafe {
Expand Down Expand Up @@ -222,9 +223,10 @@ pub fn flash_attention_backward(
let lse_ptr = store.dev_ptr(*lse);

let causal_i = if *causal { 1i32 } else { 0i32 };
let block_size = 32u32;
let grid = (bh as u32, (*s as u32 + block_size - 1) / block_size, 1);
let block = ((*d).min(32) as u32, block_size, 1);
let block_x = *d as u32;
let block_y = (1024u32 / block_x).min(32);
let grid = (bh as u32, (*s as u32 + block_y - 1) / block_y, 1);
let block = (block_x, block_y, 1);

let func = dev.get_func("flash_attention_backward_f32");
unsafe {
Expand Down
Loading