Skip to content

Commit 6df6dd1

Browse files
authored
Add new flash attention FP8 support on BMG (#613)
## Description Add FP8 support to the new flash attention API examples and fix the accuracy issue caused by reorder --------- Signed-off-by: Chen, Xi2 <[email protected]>
1 parent 45a33fe commit 6df6dd1

File tree

4 files changed

+40
-15
lines changed

4 files changed

+40
-15
lines changed

examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,18 @@ int main(int argc, const char **argv) {
147147
#else
148148
constexpr int PipelineStages = 2;
149149
#endif
150-
151-
return FMHAConfig<false, ShapeQK, ShapePV, ShapeOut, SubgroupLayoutQK, void, PipelineStages>::run(options);
150+
#ifdef IS_FLOAT_E5M2
151+
using ElementQ = cutlass::float_e5m2_t;
152+
using ElementK = cutlass::float_e5m2_t;
153+
using ElementV = cutlass::float_e5m2_t;
154+
#elif defined(IS_FLOAT_E4M3)
155+
using ElementQ = cutlass::float_e4m3_t;
156+
using ElementK = cutlass::float_e4m3_t;
157+
using ElementV = cutlass::float_e4m3_t;
158+
#else
159+
using ElementQ = bfloat16_t;
160+
using ElementK = bfloat16_t;
161+
using ElementV = bfloat16_t;
162+
#endif
163+
return FMHAConfig<false, ShapeQK, ShapePV, ShapeOut, SubgroupLayoutQK, void, PipelineStages, ElementQ, ElementK, ElementV>::run(options);
152164
}

examples/06_bmg_flash_attention/CMakeLists.txt

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,27 @@ set(TEST_NO_PAGED "")
3333
set(TEST_PAGED "--use_paged_kv")
3434

3535
foreach(HEAD_DIM 64 96 128 192)
36+
foreach(INPUT_TYPE bfloat16_t float_e5m2_t float_e4m3_t)
37+
cutlass_example_add_executable(
38+
06_xe_fmha_fwd_prefill_${INPUT_TYPE}_hdim${HEAD_DIM}
39+
06_xe_fmha_fwd.cpp
40+
)
3641

37-
cutlass_example_add_executable(
38-
06_xe_fmha_fwd_prefill_hdim${HEAD_DIM}
39-
06_xe_fmha_fwd.cpp
40-
)
42+
cutlass_example_add_executable(
43+
06_xe_fmha_fwd_decode_${INPUT_TYPE}_hdim${HEAD_DIM}
44+
06_xe_fmha_fwd.cpp
45+
)
46+
if(INPUT_TYPE STREQUAL "bfloat16_t")
47+
set(INPUT_MACRO "IS_BFLOAT16")
48+
elseif(INPUT_TYPE STREQUAL "float_e5m2_t")
49+
set(INPUT_MACRO "IS_FLOAT_E5M2")
50+
elseif(INPUT_TYPE STREQUAL "float_e4m3_t")
51+
set(INPUT_MACRO "IS_FLOAT_E4M3")
52+
endif()
4153

42-
cutlass_example_add_executable(
43-
06_xe_fmha_fwd_decode_hdim${HEAD_DIM}
44-
06_xe_fmha_fwd.cpp
45-
)
54+
target_compile_definitions(06_xe_fmha_fwd_prefill_${INPUT_TYPE}_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} PREFILL SHOW_DIFF=1 INPUT_TYPE=${INPUT_TYPE} ${INPUT_MACRO})
55+
target_compile_definitions(06_xe_fmha_fwd_decode_${INPUT_TYPE}_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} DECODE SHOW_DIFF=1 INPUT_TYPE=${INPUT_TYPE} ${INPUT_MACRO})
56+
endforeach()
4657

4758
cutlass_example_add_executable(
4859
06_bmg_prefill_attention_hdim${HEAD_DIM}
@@ -82,6 +93,4 @@ foreach(HEAD_DIM 64 96 128 192)
8293
target_compile_definitions(06_bmg_decode_attention_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
8394
target_compile_definitions(06_bmg_prefill_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
8495
target_compile_definitions(06_bmg_decode_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM})
85-
target_compile_definitions(06_xe_fmha_fwd_prefill_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} PREFILL SHOW_DIFF=1)
86-
target_compile_definitions(06_xe_fmha_fwd_decode_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} DECODE SHOW_DIFF=1)
8796
endforeach()

examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ template <class FMHAKernel> struct ExampleRunner {
359359

360360
// Check if output from CUTLASS kernel and reference kernel are equal or not
361361
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
362-
block_O.size(), ElementO{0.005}, ElementO{0.005});
362+
block_O.size(), ElementO{0.05}, ElementO{0.05});
363363

364364
return passed;
365365
}
@@ -531,7 +531,11 @@ struct FMHAConfig {
531531

532532
static constexpr int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))();
533533
using MMAOperation = cute::conditional_t<is_void_v<MMAOperation_>,
534-
XE_DPAS_TT<cute::gcd(SGTileQ, 8), float, ElementQ>,
534+
typename cute::conditional_t<
535+
cute::is_same_v<ElementQ, cutlass::float_e5m2_t> || cute::is_same_v<ElementQ, cutlass::float_e4m3_t>,
536+
XE_DPAS_TT<cute::gcd(SGTileQ, 8), float, half_t>,
537+
XE_DPAS_TT<cute::gcd(SGTileQ, 8), float, ElementQ>
538+
>,
535539
MMAOperation_>;
536540
using SubgroupLayoutPV = cute::conditional_t<is_void_v<SubgroupLayoutPV_>,
537541
decltype(cutlass::fmha::collective::get_sg_layout_pv(SubgroupLayoutQK{})),

include/cute/atom/copy_traits_xe_2d.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1216,7 +1216,7 @@ make_block_2d_prefetch(const Shape&, Stride<Strides...> const& stride, const XMo
12161216
constexpr auto shape_y = get<YMode::value>(Shape{});
12171217

12181218
// Try to retrieve whole cache lines (contiguous dimension = x)
1219-
constexpr auto width = cute::min(shape_x, 512 / sizeof_bits_v<ValType>);
1219+
constexpr auto width = cute::gcd(shape_x, 512 / sizeof_bits_v<ValType>);
12201220

12211221
// Do a preliminary tiling to choose appropriate height.
12221222
constexpr int n_sg_x = cute::gcd(SGCount, ceil_div(shape_x, width));

0 commit comments

Comments
 (0)