@@ -33,16 +33,27 @@ set(TEST_NO_PAGED "")
3333set (TEST_PAGED "--use_paged_kv" )
3434
3535foreach (HEAD_DIM 64 96 128 192)
# For the current HEAD_DIM (the loop variable of the enclosing foreach), build
# one prefill and one decode executable per element type. Both executables are
# compiled from the same source file, 06_xe_fmha_fwd.cpp; they differ only in
# the compile definitions attached at the bottom of this loop.
foreach(INPUT_TYPE bfloat16_t float_e5m2_t float_e4m3_t)
  # Pick the marker macro that is defined alongside INPUT_TYPE=<type>.
  # An unmapped type is a configure-time error, so INPUT_MACRO can never be
  # silently inherited from the previous iteration.
  if("${INPUT_TYPE}" STREQUAL "bfloat16_t")
    set(INPUT_MACRO "IS_BFLOAT16")
  elseif("${INPUT_TYPE}" STREQUAL "float_e5m2_t")
    set(INPUT_MACRO "IS_FLOAT_E5M2")
  elseif("${INPUT_TYPE}" STREQUAL "float_e4m3_t")
    set(INPUT_MACRO "IS_FLOAT_E4M3")
  else()
    message(FATAL_ERROR
      "06_xe_fmha_fwd: no INPUT_MACRO mapping for INPUT_TYPE '${INPUT_TYPE}'")
  endif()

  # Each target name must be ONE argument of the form
  # 06_xe_fmha_fwd_<mode>_<type>_hdim<N>. Whitespace before "_hdim" would turn
  # the suffix into a separate argument (an extra source file for the
  # add_executable wrapper, a misplaced scope keyword for
  # target_compile_definitions) and make the names collide across HEAD_DIM.
  set(fmha_prefill_target "06_xe_fmha_fwd_prefill_${INPUT_TYPE}_hdim${HEAD_DIM}")
  set(fmha_decode_target "06_xe_fmha_fwd_decode_${INPUT_TYPE}_hdim${HEAD_DIM}")

  cutlass_example_add_executable(
    ${fmha_prefill_target}
    06_xe_fmha_fwd.cpp
  )

  cutlass_example_add_executable(
    ${fmha_decode_target}
    06_xe_fmha_fwd.cpp
  )

  # PREFILL / DECODE select the mode; the remaining definitions are shared.
  target_compile_definitions(${fmha_prefill_target} PRIVATE
    HEAD_DIM=${HEAD_DIM}
    PREFILL
    SHOW_DIFF=1
    INPUT_TYPE=${INPUT_TYPE}
    ${INPUT_MACRO}
  )
  target_compile_definitions(${fmha_decode_target} PRIVATE
    HEAD_DIM=${HEAD_DIM}
    DECODE
    SHOW_DIFF=1
    INPUT_TYPE=${INPUT_TYPE}
    ${INPUT_MACRO}
  )
endforeach()
4657
4758 cutlass_example_add_executable(
4859 06_bmg_prefill_attention_hdim${HEAD_DIM}
@@ -82,6 +93,4 @@ foreach(HEAD_DIM 64 96 128 192)
8293 target_compile_definitions (06_bmg_decode_attention_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} )
8394 target_compile_definitions (06_bmg_prefill_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} )
8495 target_compile_definitions (06_bmg_decode_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} )
85- target_compile_definitions (06_xe_fmha_fwd_prefill_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} PREFILL SHOW_DIFF=1)
86- target_compile_definitions (06_xe_fmha_fwd_decode_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM} DECODE SHOW_DIFF=1)
8796endforeach ()
0 commit comments