diff --git a/csrc/apis/attention.hpp b/csrc/apis/attention.hpp
index eb037261..5eb765c3 100644
--- a/csrc/apis/attention.hpp
+++ b/csrc/apis/attention.hpp
@@ -154,16 +154,22 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
     const auto& [batch_size__, max_block_len] = get_shape<2>(block_table);
     const auto& [schedule_meta_size, meta_info_size] = get_shape<2>(schedule_meta);
     const auto& num_sms = device_runtime->get_num_sms();
+    const auto& num_kv_multicast = next_n == 4 ? 2 : 1;
+    const auto& num_clusters = num_sms / num_kv_multicast;
     const auto& kv_cache_stride_bytes = fused_kv_cache.stride(0);
     const auto& block_table_stride = block_table.stride(0);
+    const auto& arch_major = device_runtime->get_arch_major();

     DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size__);
     DG_HOST_ASSERT(batch_size_next_n == batch_size * next_n);
     DG_HOST_ASSERT(num_heads == num_heads_ and num_heads_kv == 1);
     DG_HOST_ASSERT(head_dim_with_sf == head_dim + static_cast<int>(sizeof(float)));
-    DG_HOST_ASSERT(schedule_meta_size == num_sms + 1 and meta_info_size == 2);
+    DG_HOST_ASSERT(num_sms % num_kv_multicast == 0);
+    DG_HOST_ASSERT(schedule_meta_size == num_clusters + 1 and meta_info_size == 2);

-    DG_HOST_ASSERT(next_n == 1 or next_n == 2);
+    DG_HOST_ASSERT(next_n == 1 or next_n == 2 or next_n == 4);
+    // SM90 does not support next_n == 4 for now
+    DG_HOST_ASSERT(!(arch_major == 9 and next_n == 4));
     DG_HOST_ASSERT(block_kv == 64);
     DG_HOST_ASSERT(q.is_contiguous());
@@ -204,7 +210,6 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
     logits = logits.slice(-1, 0, max_context_len);

     // Dispatch implementation
-    const auto& arch_major = device_runtime->get_arch_major();
     if (arch_major == 9 or arch_major == 10) {
         smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits,
                                   block_table, schedule_meta, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv,
diff --git a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp
index ffd6f439..3a431651 100644
--- a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp
+++ b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp
@@ -87,7 +87,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
     constexpr int block_qh = 128;
     constexpr int block_kv = 256;
     constexpr int num_specialized_threads = 128;
-    constexpr int num_math_threads = 256;
+    const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512);
     constexpr int num_q_stages = 3, num_kv_stages = 3;
     const int block_q = block_qh / num_heads;
     DG_HOST_ASSERT(block_qh % num_heads == 0);
diff --git a/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp
index 38bbfb9d..76edb039 100644
--- a/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp
+++ b/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp
@@ -102,6 +102,7 @@ class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime<SMXXFP8PagedMQALogitsRuntime> {
-        {}, {}>);
+        {}, {},
+        {}>);
 }};
 )", arch, arch,
@@ -131,7 +133,8 @@ static void __instantiate_kernel() {{
         args.head_dim, args.block_kv,
         args.num_q_stages, args.num_kv_stages,
         args.split_kv,
-        args.num_specialized_threads, args.num_math_threads);
+        args.num_specialized_threads, args.num_math_threads,
+        args.num_kv_multicast);
     }

     static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -172,8 +175,12 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
     // Construct TMAs
     DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
+
+    const int next_n_per_cta = next_n == 4 ? 2 : next_n;
+    const int num_kv_multicast = next_n == 4 ? 2 : 1;
+
     const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads,
-                                                head_dim, next_n * num_heads, head_dim, head_dim);
+                                                head_dim, next_n_per_cta * num_heads, head_dim, head_dim);
     const auto& tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks,
                                                  head_dim, block_kv, 1, head_dim, kv_cache_stride_bytes, head_dim);
@@ -181,13 +188,13 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
     const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_cache_scales, block_kv, num_kv_blocks,
                                                         block_kv, 1, kv_cache_stride_bytes / static_cast<int>(sizeof(float)), 0);
     const auto& tensor_map_weights = make_tma_2d_desc(weights, next_n * num_heads, batch_size,
-                                                      next_n * num_heads, 1, next_n * num_heads, 0);
+                                                      next_n_per_cta * num_heads, 1, next_n * num_heads, 0);

     // Calculate shared memory size
     const int swizzle_alignment = head_dim * 8;
-    const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
-    const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
+    const int smem_q_size_per_stage = next_n_per_cta * num_heads * head_dim * static_cast<int>(q.element_size());
+    const int aligned_smem_weight_size_per_stage = align(next_n_per_cta * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
     const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment);
     const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv_cache.element_size());
@@ -224,9 +231,10 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
         .tensor_map_weights = tensor_map_weights,
         .num_specialized_threads = num_specialized_threads,
         .num_math_threads = num_math_threads,
+        .num_kv_multicast = num_kv_multicast,
         .launch_args = LaunchArgs(num_sms,
                                   num_specialized_threads + num_math_threads + num_extra_threads,
-                                  smem_size)
+                                  smem_size, num_kv_multicast)
     };
     const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args);
     const auto& runtime = compiler->build("sm90_fp8_paged_mqa_logits", code);
diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh
index 5252ddbb..74742957 100644
--- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh
+++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh
@@ -333,7 +333,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
                 auto sum = make_float2(0, 0);
-                for (int j = 0; j < kNumWeightsInReg; j += 2) {
+                for (int j = 0; j < kNumWeightsInReg; j += 2) {
                     auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
                     auto b = make_float2(weights[i][j], weights[i][j + 1]);
diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh
index 4a53421f..b46798df 100644
--- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh
+++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh
@@ -21,7 +21,8 @@ template
-          uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads>
+          uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
+          uint32_t kNumKVMulticast>
 __global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads + 128, 1)
 void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
                                 const uint64_t logits_stride, const uint64_t block_table_stride,
@@ -50,10 +51,13 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
     }
     __syncwarp();

+    constexpr uint32_t kNextNPerCTA = kNextN / kNumKVMulticast;
+    DG_STATIC_ASSERT(kNextN % kNumKVMulticast == 0, "Invalid `kNextN` or `kNumKVMulticast`");
+
     // Shared memory configs
     static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
-    static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
-    static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
+    static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNPerCTA * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
+    static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNPerCTA * kNumHeads * sizeof(float);
     static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
     static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
@@ -104,7 +108,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
     auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
     auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);

-    constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups;
+    constexpr uint32_t kNumTmemCols = kNextNPerCTA * kNumHeads * kNumMathWarpGroups;
     DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
     const bool& is_math_warp = (warp_idx < (kNumMathThreads / 32));          // 0 ~ 16
     const bool& is_tma_load_warp = (warp_idx >= (kNumMathThreads / 32) and warp_idx < (kNumMathThreads / 32 + 4));  // 16 ~ 20
@@ -147,7 +151,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
     constexpr uint32_t kNumMathRegisters = 104;

     // Scheduler
-    auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta);
+    auto scheduler = PagedMQALogitsScheduler(batch_size, cute::cluster_id_in_grid().x, context_lens, schedule_meta);
     DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");

     // Q and KV pipeline
@@ -163,7 +167,8 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
     // Construct instruction with layout F
     constexpr uint32_t UMMA_M = 64;
     constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
-    constexpr uint32_t UMMA_N = kNextN * kNumHeads;
+    constexpr uint32_t UMMA_N = kNextNPerCTA * kNumHeads;
+    const uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster();

     if (is_tma_load_warp) {
         // TMA warp-group for loading data
@@ -173,8 +178,8 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
         const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
             if (kv_group_idx == 0 and cute::elect_one_sync()) {
-                tma_copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
-                tma_copy(&tensor_map_weights, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_weights[stage_idx], 0, q_idx);
+                tma_copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_q[stage_idx], 0, (q_idx * kNextN + cta_rank_in_cluster * kNextNPerCTA) * kNumHeads);
+                tma_copy(&tensor_map_weights, reinterpret_cast<uint64_t*>(full_q_barriers[stage_idx]), smem_weights[stage_idx], cta_rank_in_cluster * kNextNPerCTA * kNumHeads, q_idx);
                 full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
             }
         };
@@ -284,7 +289,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
             // Offsets
             const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
-            float weights[kNextN][kNumHeads / 4];
+            float weights[kNextNPerCTA][kNumHeads / 4];
             const auto& sub_warp_offset = (warp_idx % 4) * 16;
             const auto& v_0_offset = lane_idx / 4 + 0;
             const auto& v_1_offset = lane_idx / 4 + 8;
@@ -308,10 +313,11 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
                     // Read weights
 #pragma unroll
-                    for (uint32_t i = 0; i < kNextN; ++ i) {
+                    for (uint32_t i = 0; i < kNextNPerCTA; ++ i) {
 #pragma unroll
-                        for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
+                        for (uint32_t j = 0; j < kNumHeads / 4; ++ j) {
                             weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
+                        }
                     }
                 }
@@ -320,9 +326,9 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
                 kv_idx = next_kv_idx;

                 // Calculate KV offset in advance
-                auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
+                auto kv_offset = (q_idx * kNextN + cta_rank_in_cluster * kNextNPerCTA) * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);

-                // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
+                // Compute `[kNextNPerCTA * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextNPerCTA, BLOCK_KV]`
                 // Wait TMA KV arrival
                 CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
                 full_kv_barriers[kv_stage_idx]->wait(kv_phase);
@@ -343,7 +349,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
                 static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
                 DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
 #pragma unroll
-                for (uint32_t i = 0; i < kNextN; ++ i) {
+                for (uint32_t i = 0; i < kNextNPerCTA; ++ i) {
                     // Load from the tensor memory
                     constexpr uint32_t kNumLDTMElems = UMMA_M * kNumHeads / 128;
                     uint32_t shifted_accum[kNumLDTMElems];
diff --git a/tests/test_attention.py b/tests/test_attention.py
index 1baa80f1..cca109d7 100644
--- a/tests/test_attention.py
+++ b/tests/test_attention.py
@@ -179,7 +179,7 @@ def ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor,
 def test_paged_mqa_logits():
     print('Testing FP8 Paged MQA Logits:')
     max_model_len = 111 * 1000
-    for batch_size, next_n in [(64, 1), (64, 2), (128, 1)]:
+    for batch_size, next_n in [(64, 1), (64, 2), (64, 4), (128, 1)]:
         for heads, index_dim in [(64, 128)]:
             for avg_kv in (8192, 32768):
                 num_blocks, blocksize = max_model_len * 3, 64
@@ -204,7 +204,10 @@ def test_paged_mqa_logits():
                 q_fp8 = q.to(torch.float8_e4m3fn)
                 kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)

-                schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, deep_gemm.get_num_sms())
+                num_kv_multicast = 2 if next_n == 4 else 1
+                num_clusters = deep_gemm.get_num_sms() // num_kv_multicast
+
+                schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, num_clusters)
                 logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True)
                 ref_logits = ref_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len)