11 changes: 8 additions & 3 deletions csrc/apis/attention.hpp
@@ -154,16 +154,22 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
const auto& [batch_size__, max_block_len] = get_shape<2>(block_table);
const auto& [schedule_meta_size, meta_info_size] = get_shape<2>(schedule_meta);
const auto& num_sms = device_runtime->get_num_sms();
const auto& num_kv_multicast = next_n == 4 ? 2 : 1;
const auto& num_clusters = num_sms / num_kv_multicast;
const auto& kv_cache_stride_bytes = fused_kv_cache.stride(0);
const auto& block_table_stride = block_table.stride(0);
const auto& arch_major = device_runtime->get_arch_major();

DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size__);
DG_HOST_ASSERT(batch_size_next_n == batch_size * next_n);
DG_HOST_ASSERT(num_heads == num_heads_ and num_heads_kv == 1);
DG_HOST_ASSERT(head_dim_with_sf == head_dim + static_cast<int>(sizeof(float)));
DG_HOST_ASSERT(schedule_meta_size == num_sms + 1 and meta_info_size == 2);
DG_HOST_ASSERT(num_sms % num_kv_multicast == 0);
DG_HOST_ASSERT(schedule_meta_size == num_clusters + 1 and meta_info_size == 2);

DG_HOST_ASSERT(next_n == 1 or next_n == 2);
DG_HOST_ASSERT(next_n == 1 or next_n == 2 or next_n == 4);
// SM90 does not support next_n == 4 for now
DG_HOST_ASSERT(!(arch_major == 9 and next_n == 4));
DG_HOST_ASSERT(block_kv == 64);

DG_HOST_ASSERT(q.is_contiguous());
@@ -204,7 +210,6 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
logits = logits.slice(-1, 0, max_context_len);

// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9 or arch_major == 10) {
smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta,
batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv,
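The host-side assertions above size the schedule metadata by CTA clusters rather than by raw SMs: with next_n == 4 the kernel launches 2-CTA clusters (KV multicast), so only num_sms / num_kv_multicast clusters are scheduled. A minimal standalone sketch of that sizing, assuming the SM count is passed in (the real code reads it from device_runtime->get_num_sms()):

```cpp
#include <cassert>

// Sketch only: mirrors the sizing asserted in fp8_paged_mqa_logits above.
// `num_sms` stands in for device_runtime->get_num_sms().
static int expected_schedule_meta_rows(int num_sms, int next_n) {
    const int num_kv_multicast = (next_n == 4) ? 2 : 1;  // 2-CTA cluster only for next_n == 4
    assert(num_sms % num_kv_multicast == 0);             // clusters must tile the SMs exactly
    const int num_clusters = num_sms / num_kv_multicast;
    return num_clusters + 1;                             // schedule_meta has one extra row
}
```

For example, a part with 148 SMs and next_n == 4 yields 74 clusters, so schedule_meta is expected to have 75 rows (and 2 columns, per meta_info_size); with next_n == 1 or 2 this reduces to the previous num_sms + 1 sizing.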
2 changes: 1 addition & 1 deletion csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp
@@ -87,7 +87,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
constexpr int block_qh = 128;
constexpr int block_kv = 256;
constexpr int num_specialized_threads = 128;
constexpr int num_math_threads = 256;
const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512);
constexpr int num_q_stages = 3, num_kv_stages = 3;
const int block_q = block_qh / num_heads;
DG_HOST_ASSERT(block_qh % num_heads == 0);
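The math-thread count in smxx_fp8_mqa_logits is now architecture-dependent rather than a compile-time constant: 256 threads (two warp-groups of 128) on SM100, 512 (four warp-groups) otherwise. A hedged sketch of that selection, with the arch query stubbed out as a parameter:

```cpp
// Sketch: `arch_major` stands in for device_runtime->get_arch_major().
// 128 threads per warp-group, so this picks 2 math warp-groups on SM100 and 4 elsewhere.
static int pick_num_math_threads(int arch_major) {
    return arch_major == 10 ? 256 : 512;
}
```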
22 changes: 15 additions & 7 deletions csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp
@@ -102,6 +102,7 @@ class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime<SMXXFP8PagedMQALo

int num_specialized_threads;
int num_math_threads;
int num_kv_multicast;

LaunchArgs launch_args;
};
@@ -123,15 +124,17 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{},
{}, {}
{}, {},
{}
>);
}};
)", arch, arch,
args.next_n, args.num_heads,
args.head_dim, args.block_kv,
args.num_q_stages, args.num_kv_stages,
args.split_kv,
args.num_specialized_threads, args.num_math_threads);
args.num_specialized_threads, args.num_math_threads,
args.num_kv_multicast);
}

static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -172,22 +175,26 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,

// Construct TMAs
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);

const int next_n_per_cta = next_n == 4 ? 2 : next_n;
const int num_kv_multicast = next_n == 4 ? 2 : 1;

const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads,
head_dim, next_n * num_heads, head_dim, head_dim);
head_dim, next_n_per_cta * num_heads, head_dim, head_dim);
const auto& tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks,
head_dim, block_kv, 1,
head_dim, kv_cache_stride_bytes, head_dim);
// TODO: use 1D TMA
const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_cache_scales, block_kv, num_kv_blocks,
block_kv, 1, kv_cache_stride_bytes / static_cast<int>(sizeof(float)), 0);
const auto& tensor_map_weights = make_tma_2d_desc(weights, next_n * num_heads, batch_size,
next_n * num_heads, 1, next_n * num_heads, 0);
next_n_per_cta * num_heads, 1, next_n * num_heads, 0);

// Calculate shared memory size
const int swizzle_alignment = head_dim * 8;

const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
const int smem_q_size_per_stage = next_n_per_cta * num_heads * head_dim * static_cast<int>(q.element_size());
const int aligned_smem_weight_size_per_stage = align(next_n_per_cta * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment);

const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv_cache.element_size());
@@ -224,9 +231,10 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
.tensor_map_weights = tensor_map_weights,
.num_specialized_threads = num_specialized_threads,
.num_math_threads = num_math_threads,
.num_kv_multicast = num_kv_multicast,
.launch_args = LaunchArgs(num_sms,
num_specialized_threads + num_math_threads + num_extra_threads,
smem_size)
smem_size, num_kv_multicast)
};
const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_paged_mqa_logits", code);
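With KV multicast each CTA in the cluster stages only next_n_per_cta = next_n / num_kv_multicast query tokens, which is what the narrower TMA descriptors and the shared-memory budget above are built from, and the cluster size is threaded through to LaunchArgs. A small sketch of the per-CTA Q/weights pipeline sizing, assuming FP8 (1-byte) Q and FP32 (4-byte) weights as above, with a local align_up standing in for align():

```cpp
// Round `x` up to a multiple of `alignment` (same role as align() above).
static int align_up(int x, int alignment) {
    return (x + alignment - 1) / alignment * alignment;
}

// Sketch of the per-CTA shared-memory budget for the Q/weights pipeline.
static int smem_q_pipe_bytes(int next_n, int num_kv_multicast, int num_heads,
                             int head_dim, int num_q_stages) {
    const int next_n_per_cta = next_n / num_kv_multicast;
    const int swizzle_alignment = head_dim * 8;
    const int q_per_stage = next_n_per_cta * num_heads * head_dim;    // FP8 Q tile, 1 B/elem
    const int w_per_stage = align_up(next_n_per_cta * num_heads * 4,  // FP32 weights
                                     swizzle_alignment);
    // Each stage holds Q plus weights; the tail covers the per-stage 8-byte barrier pairs.
    return num_q_stages * (q_per_stage + w_per_stage)
         + align_up(num_q_stages * 8 * 2, swizzle_alignment);
}
```

For instance, with next_n == 4, num_kv_multicast == 2, 64 heads and head_dim == 128, each stage carries 16384 B of Q plus 1024 B of swizzle-aligned weights, half of what a single CTA would need without the cluster split.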
2 changes: 1 addition & 1 deletion deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh
@@ -333,7 +333,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con

auto sum = make_float2(0, 0);

for (int j = 0; j < kNumWeightsInReg; j += 2) {
for (int j = 0; j < kNumWeightsInReg; j += 2) {
auto a = make_float2(fmaxf(accum[j], 0),
fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
34 changes: 20 additions & 14 deletions deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh
@@ -21,7 +21,8 @@ template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads>
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
uint32_t kNumKVMulticast>
__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads + 128, 1)
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
const uint64_t logits_stride, const uint64_t block_table_stride,
@@ -50,10 +51,13 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
}
__syncwarp();

constexpr uint32_t kNextNPerCTA = kNextN / kNumKVMulticast;
DG_STATIC_ASSERT(kNextN % kNumKVMulticast == 0, "Invalid `kNextN` or `kNumKVMulticast`");

// Shared memory configs
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNPerCTA * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNPerCTA * kNumHeads * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
@@ -104,7 +108,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);

constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups;
constexpr uint32_t kNumTmemCols = kNextNPerCTA * kNumHeads * kNumMathWarpGroups;
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
const bool& is_math_warp = (warp_idx < (kNumMathThreads / 32)); // 0 ~ 16
const bool& is_tma_load_warp = (warp_idx >= (kNumMathThreads / 32) and warp_idx < (kNumMathThreads / 32 + 4)); // 16 ~ 20
@@ -147,7 +151,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
constexpr uint32_t kNumMathRegisters = 104;

// Scheduler
auto scheduler = PagedMQALogitsScheduler<BLOCK_KV, kNumMathWarpGroups>(batch_size, blockIdx.x, context_lens, schedule_meta);
auto scheduler = PagedMQALogitsScheduler<BLOCK_KV, kNumMathWarpGroups>(batch_size, cute::cluster_id_in_grid().x, context_lens, schedule_meta);
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");

// Q and KV pipeline
@@ -163,7 +167,8 @@
// Construct instruction with layout F
constexpr uint32_t UMMA_M = 64;
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
constexpr uint32_t UMMA_N = kNextN * kNumHeads;
constexpr uint32_t UMMA_N = kNextNPerCTA * kNumHeads;
const uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster();

if (is_tma_load_warp) {
// TMA warp-group for loading data
@@ -173,8 +178,8 @@

const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
if (kv_group_idx == 0 and cute::elect_one_sync()) {
tma_copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
tma_copy(&tensor_map_weights, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, q_idx);
tma_copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, (q_idx * kNextN + cta_rank_in_cluster * kNextNPerCTA) * kNumHeads);
tma_copy(&tensor_map_weights, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_weights[stage_idx], cta_rank_in_cluster * kNextNPerCTA * kNumHeads, q_idx);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
}
};
@@ -284,7 +289,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,

// Offsets
const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
float weights[kNextN][kNumHeads / 4];
float weights[kNextNPerCTA][kNumHeads / 4];
const auto& sub_warp_offset = (warp_idx % 4) * 16;
const auto& v_0_offset = lane_idx / 4 + 0;
const auto& v_1_offset = lane_idx / 4 + 8;
@@ -308,10 +313,11 @@

// Read weights
#pragma unroll
for (uint32_t i = 0; i < kNextN; ++ i) {
for (uint32_t i = 0; i < kNextNPerCTA; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
for (uint32_t j = 0; j < kNumHeads / 4; ++ j) {
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
}
}
}

@@ -320,9 +326,9 @@
kv_idx = next_kv_idx;

// Calculate KV offset in advance
auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
auto kv_offset = (q_idx * kNextN + cta_rank_in_cluster * kNextNPerCTA) * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);

// Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
// Compute `[kNextNPerCTA * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextNPerCTA, BLOCK_KV]`
// Wait TMA KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
@@ -343,7 +349,7 @@
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
#pragma unroll
for (uint32_t i = 0; i < kNextN; ++ i) {
for (uint32_t i = 0; i < kNextNPerCTA; ++ i) {
// Load from the tensor memory
constexpr uint32_t kNumLDTMElems = UMMA_M * kNumHeads / 128;
uint32_t shifted_accum[kNumLDTMElems];
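On the device side the two CTAs of a cluster split the kNextN query tokens: each covers kNextNPerCTA = kNextN / kNumKVMulticast of them, offsetting its Q, weights and logits accesses by its cluster rank, while the KV tiles are shared across the cluster. A minimal sketch of the row offset used by the TMA Q load above, written as plain C++ with cute::block_rank_in_cluster() passed in as a parameter:

```cpp
#include <cstdint>

// Sketch of the per-CTA Q row offset used by issue_tma_q above.
// `cta_rank_in_cluster` stands in for cute::block_rank_in_cluster().
static uint32_t q_row_offset(uint32_t q_idx, uint32_t next_n, uint32_t num_kv_multicast,
                             uint32_t num_heads, uint32_t cta_rank_in_cluster) {
    const uint32_t next_n_per_cta = next_n / num_kv_multicast;
    // q_idx selects a batch entry; each CTA starts at its own slice of the
    // next_n tokens, and rows are then expanded by the head count.
    return (q_idx * next_n + cta_rank_in_cluster * next_n_per_cta) * num_heads;
}
```

With next_n == 4 and a 2-CTA cluster, rank 0 loads the first two token rows of the request and rank 1 the last two, each as a box of next_n_per_cta * num_heads rows, matching the tensor_map_q descriptor set up on the host.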
7 changes: 5 additions & 2 deletions tests/test_attention.py
@@ -179,7 +179,7 @@ def ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor,
def test_paged_mqa_logits():
print('Testing FP8 Paged MQA Logits:')
max_model_len = 111 * 1000
for batch_size, next_n in [(64, 1), (64, 2), (128, 1)]:
for batch_size, next_n in [(64, 1), (64, 2), (64, 4), (128, 1)]:
for heads, index_dim in [(64, 128)]:
for avg_kv in (8192, 32768):
num_blocks, blocksize = max_model_len * 3, 64
@@ -204,7 +204,10 @@ def test_paged_mqa_logits():
q_fp8 = q.to(torch.float8_e4m3fn)
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)

schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, deep_gemm.get_num_sms())
num_kv_multicast = 2 if next_n == 4 else 1
num_clusters = deep_gemm.get_num_sms() // num_kv_multicast

schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, num_clusters)
logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True)

ref_logits = ref_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len)