
Commit 370207b

Revert "feat: Support fp8 qkv, fp16/bf16 out MHA for trtllm-gen. (flashinfer-ai#1490) (flashinfer-ai#1496)
<!-- .github/pull_request_template.md --> ## 📌 Description This reverts commit 96d142d. (flashinfer-ai#1490) ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent a78ae1a commit 370207b

9 files changed (+45, -64 lines)


flashinfer/artifacts.py

Lines changed: 2 additions & 2 deletions
@@ -48,7 +48,7 @@ def get_available_cubin_files(source, retries=3, delay=5, timeout=10):
 
 
 class ArtifactPath:
-    TRTLLM_GEN_FMHA: str = "85756bb4c98571b902095e944fc27bda1fcec3e4/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "c8e0abb4b0438880a2b0a9b68449e3cf1513aadf/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
         "5d347c6234c9f0e7f1ab6519ea933183b48216ed/batched_gemm-32110eb-5262bae/"
     )
@@ -61,7 +61,7 @@ class ArtifactPath:
 
 class MetaInfoHash:
     TRTLLM_GEN_FMHA: str = (
-        "24db1b729f7cb86d3f4dc4fbb7de479c5a854cd03ffbf75dc356bffcde48700c"
+        "0d124e546c8a2e9fa59499625e8a6d140a2465573d4a3944f9d29f29f73292fb"
     )
     TRTLLM_GEN_BMM: str = (
         "aae02e5703ee0ce696c4b3a1f2a32936fcc960dcb69fdef52b6d0f8a7b673000"
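The revert pins TRTLLM_GEN_FMHA back to the earlier artifact directory and restores the matching metainfo checksum. A minimal sketch of how a downloaded metainfo header could be sanity-checked against the pinned value (assumptions: the hash is a SHA-256 hex digest of the file contents, and the local file name below is hypothetical):

    import hashlib

    # Pinned digest restored by this revert (MetaInfoHash.TRTLLM_GEN_FMHA).
    EXPECTED = "0d124e546c8a2e9fa59499625e8a6d140a2465573d4a3944f9d29f29f73292fb"

    def sha256_hex(path: str) -> str:
        # Stream the file so it need not fit in memory at once.
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                h.update(chunk)
        return h.hexdigest()

    # "flashInferMetaInfo.h" stands in for a local copy of the artifact.
    assert sha256_hex("flashInferMetaInfo.h") == EXPECTED, "metainfo hash mismatch"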

flashinfer/decode.py

Lines changed: 9 additions & 14 deletions
@@ -49,7 +49,7 @@
     _check_cached_qkv_data_type,
     _check_kv_layout,
     _check_pos_encoding_mode,
-    check_shape_dtype_device,
+    _check_shape_dtype_device,
     _get_cache_alibi_slopes_buf,
     _get_cache_buf,
     _get_range_buf,
@@ -1231,14 +1231,14 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )
         else:
-            check_shape_dtype_device(
+            _check_shape_dtype_device(
                 lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
             )
 
         if out is None:
             out = torch.empty_like(q)
         else:
-            check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
+            _check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
 
         if self.use_tensor_cores:
             run_args = [
@@ -1749,7 +1749,7 @@ def run(
         if out is None:
             out = torch.empty_like(q_nope, device=device)
         else:
-            check_shape_dtype_device(
+            _check_shape_dtype_device(
                 out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
             )
 
@@ -1761,7 +1761,7 @@ def run(
                 device=device,
             )
         else:
-            check_shape_dtype_device(
+            _check_shape_dtype_device(
                 lse,
                 (q_nope.size(0), q_nope.size(1)),
                 q_nope.dtype,
@@ -2113,9 +2113,9 @@ def trtllm_batch_decode_with_kv_cache(
         assert isinstance(out, torch.Tensor)
 
         # Use uint8 as the container dtype to compliant with next fp4 gemm.
-        check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
+        _check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
 
-        check_shape_dtype_device(
+        _check_shape_dtype_device(
             out_scale_factor,
             fp4_out_scale_shape,
             torch.float8_e4m3fn,
@@ -2141,12 +2141,7 @@ def trtllm_batch_decode_with_kv_cache(
         o_sf_start_index = 0
         out_dtype = out_dtype or query.dtype
         out = out if out is not None else torch.empty_like(query, dtype=out_dtype)
-        assert out_dtype in (
-            query.dtype,
-            torch.float16,
-            torch.bfloat16,
-        )
-        check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out")
+        _check_shape_dtype_device(out, query.shape, query.dtype, query.device, "out")
     else:
         raise ValueError(f"Invalid out_dtype: {out_dtype}")
 
@@ -2299,7 +2294,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
             out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
         else:
             batch_size, _, num_q_heads, _ = query.shape
-            check_shape_dtype_device(
+            _check_shape_dtype_device(
                 out,
                 [batch_size, num_q_heads, kv_lora_rank],
                 torch.bfloat16,
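The substantive change is in the last non-fp4 hunk of trtllm_batch_decode_with_kv_cache: the assert that let an fp8 query produce a fp16/bf16 output is gone, and out is now validated against query.dtype rather than out_dtype. A minimal sketch of the restored contract (toy shapes; assumes flashinfer is installed so the private helper is importable):

    import torch
    from flashinfer.utils import _check_shape_dtype_device

    query = torch.zeros(4, 8, 128, dtype=torch.bfloat16)

    # An output matching the query's dtype passes validation.
    out = torch.empty_like(query)
    _check_shape_dtype_device(out, query.shape, query.dtype, query.device, "out")

    # A fp16 output for a bf16 query now fails, since the check uses query.dtype:
    # _check_shape_dtype_device(torch.empty_like(query, dtype=torch.float16),
    #                           query.shape, query.dtype, query.device, "out")
    # -> ValueError: Invalid dtype of out: expected torch.bfloat16, got torch.float16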

flashinfer/fused_moe/core.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@
 from ..jit.cubin_loader import get_cubin
 from ..jit.cutlass_gemm.generate_kernels import generate_gemm_operations
 from ..utils import (
-    check_shape_dtype_device,
+    _check_shape_dtype_device,
     device_support_pdl,
     get_shuffle_matrix_a_row_indices,
     get_shuffle_matrix_sf_a_row_indices,
@@ -784,7 +784,7 @@ def cutlass_fused_moe(
     if output is None:
         output = torch.empty(output_shape, dtype=output_dtype, device=input.device)
     else:
-        check_shape_dtype_device(
+        _check_shape_dtype_device(
             output, output_shape, output_dtype, input.device, "output"
         )
 

flashinfer/mla.py

Lines changed: 4 additions & 4 deletions
@@ -22,7 +22,7 @@
 from .jit import JitSpec
 from .jit import env as jit_env
 from .jit import gen_batch_mla_module, gen_jit_spec, sm100a_nvcc_flags
-from .utils import MaskMode, check_shape_dtype_device, determine_mla_backend
+from .utils import MaskMode, _check_shape_dtype_device, determine_mla_backend
 
 
 def _check_cutlass_shape(q_nope_pe, ckv_kpe_cache, kv_len, page_table):
@@ -394,7 +394,7 @@ def run(
         if out is None:
             out = torch.empty_like(q_nope)
         else:
-            check_shape_dtype_device(
+            _check_shape_dtype_device(
                 out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
             )
         q_nope_pe = torch.cat([q_nope, q_pe], dim=-1)
@@ -426,15 +426,15 @@ def run(
         if out is None:
             out = torch.empty_like(q_nope)
         else:
-            check_shape_dtype_device(
+            _check_shape_dtype_device(
                 out, q_nope.shape, q_nope.dtype, q_nope.device, "out"
             )
 
         if return_lse:
             if lse is None:
                 lse = torch.empty(q_nope.shape[:2], dtype=torch.float32, device=device)
             else:
-                check_shape_dtype_device(
+                _check_shape_dtype_device(
                     lse, q_nope.shape[:2], torch.float32, q_nope.device, "lse"
                 )
         profiler_args = (profiler_buffer,) if self._use_profiler else ()

flashinfer/prefill.py

Lines changed: 8 additions & 13 deletions
@@ -44,7 +44,7 @@
     _check_cached_qkv_data_type,
     _check_kv_layout,
     _check_pos_encoding_mode,
-    check_shape_dtype_device,
+    _check_shape_dtype_device,
     _get_cache_alibi_slopes_buf,
     _get_cache_buf,
     _unpack_paged_kv_cache,
@@ -2034,7 +2034,7 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )
         else:
-            check_shape_dtype_device(
+            _check_shape_dtype_device(
                 lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
             )
 
@@ -2043,7 +2043,7 @@ def run(
                 q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.device
             )
         else:
-            check_shape_dtype_device(
+            _check_shape_dtype_device(
                 out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out"
             )
 
@@ -2833,15 +2833,15 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )
         else:
-            check_shape_dtype_device(
+            _check_shape_dtype_device(
                 lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
             )
         if out is None:
             out = torch.empty(
                 q.shape[:-1] + v.shape[-1:], dtype=q.dtype, device=q.device
             )
         else:
-            check_shape_dtype_device(
+            _check_shape_dtype_device(
                 out, q.shape[:-1] + v.shape[-1:], q.dtype, q.device, "out"
             )
         if self._backend == "cutlass":
@@ -3244,9 +3244,9 @@ def trtllm_batch_context_with_kv_cache(
         assert isinstance(out, torch.Tensor)
 
         # Use uint8 as the container dtype to compliant with next fp4 gemm.
-        check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
+        _check_shape_dtype_device(out, fp4_out_shape, torch.uint8, query.device, "out")
 
-        check_shape_dtype_device(
+        _check_shape_dtype_device(
             out_scale_factor,
             fp4_out_scale_shape,
             torch.float8_e4m3fn,
@@ -3271,13 +3271,8 @@ def trtllm_batch_context_with_kv_cache(
         out_scale_factor = None
         o_sf_start_index = 0
         out_dtype = out_dtype or query.dtype
-        assert out_dtype in (
-            query.dtype,
-            torch.float16,
-            torch.bfloat16,
-        )
         out = out if out is not None else torch.empty_like(query, dtype=out_dtype)
-        check_shape_dtype_device(out, query.shape, out_dtype, query.device, "out")
+        _check_shape_dtype_device(out, query.shape, query.dtype, query.device, "out")
     else:
         raise ValueError(f"Invalid out_dtype: {out_dtype}")
 

flashinfer/sparse.py

Lines changed: 5 additions & 5 deletions
@@ -28,7 +28,7 @@
     PosEncodingMode,
     TensorLayout,
     _check_pos_encoding_mode,
-    check_shape_dtype_device,
+    _check_shape_dtype_device,
     _get_cache_alibi_slopes_buf,
     canonicalize_torch_dtype,
     determine_attention_backend,
@@ -577,14 +577,14 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )
         else:
-            check_shape_dtype_device(
+            _check_shape_dtype_device(
                 lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
             )
 
         if out is None:
             out = torch.empty_like(q, dtype=self._o_dtype)
         else:
-            check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out")
+            _check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out")
 
         if is_float8(q):
             assert q.dtype == k.dtype == v.dtype
@@ -1157,14 +1157,14 @@ def run(
                 (q.size(0), q.size(1)), dtype=torch.float32, device=q.device
             )
         else:
-            check_shape_dtype_device(
+            _check_shape_dtype_device(
                 lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse"
             )
 
         if out is None:
             out = torch.empty_like(q, dtype=self._o_dtype)
         else:
-            check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out")
+            _check_shape_dtype_device(out, q.shape, self._o_dtype, q.device, "out")
 
         if self._backend == "fa3":
             if (

flashinfer/utils.py

Lines changed: 7 additions & 7 deletions
@@ -431,22 +431,22 @@ def determine_mla_backend(device: torch.device) -> str:
     return "fa3" if is_sm90a_supported(device) else "fa2"
 
 
-def check_shape_dtype_device(
+def _check_shape_dtype_device(
     x: torch.Tensor,
-    expected_shape: Optional[Sequence[int]],
-    expected_dtype: Optional[torch.dtype],
-    expected_device: Optional[torch.device],
+    expected_shape: Sequence[int],
+    expected_dtype: torch.dtype,
+    expected_device: torch.device,
     name: str,
 ) -> None:
-    if expected_shape and x.shape != torch.Size(expected_shape):
+    if x.shape != torch.Size(expected_shape):
         raise ValueError(
             f"Invalid shape of {name}: expected {expected_shape}, got {x.shape}"
         )
-    if expected_dtype and x.dtype != expected_dtype:
+    if x.dtype != expected_dtype:
         raise ValueError(
             f"Invalid dtype of {name}: expected {expected_dtype}, got {x.dtype}"
         )
-    if expected_device and x.device != expected_device:
+    if x.device != expected_device:
         raise ValueError(
             f"Invalid device of {name}: expected {expected_device}, got {x.device}"
         )
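With the revert, the helper is private again and all three expectations are mandatory; callers can no longer pass None to skip a check. A minimal illustration (CPU tensors; assumes flashinfer is installed):

    import torch
    from flashinfer.utils import _check_shape_dtype_device

    x = torch.zeros(2, 4, dtype=torch.float32)

    # Shape, dtype, and device all match: returns None.
    _check_shape_dtype_device(x, (2, 4), torch.float32, x.device, "x")

    # Every expectation is now required, so any mismatch raises:
    # _check_shape_dtype_device(x, (2, 4), torch.float16, x.device, "x")
    # -> ValueError: Invalid dtype of x: expected torch.float16, got torch.float32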

include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Lines changed: 2 additions & 3 deletions
@@ -597,9 +597,8 @@ class TllmFmhaKernelFactory {
     std::lock_guard<std::mutex> lg(s_mutex);
 
     if (!metainfo_loaded) {
-      std::string metainfo_raw =
-          getMetaInfo(tllm_gen_fmha_cubin_path + "/include/flashInferMetaInfo",
-                      tllm_gen_fmha_metainfo_hash, ".h");
+      std::string metainfo_raw = getMetaInfo(tllm_gen_fmha_cubin_path + "flashInferMetaInfo",
+                                             tllm_gen_fmha_metainfo_hash, ".h");
       metainfo = KernelType::KernelMeta::loadFromMetaInfoRaw(metainfo_raw);
       metainfo_loaded = true;
     }
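Note that the reverted loader appends "flashInferMetaInfo" directly to tllm_gen_fmha_cubin_path, dropping the "/include/" segment; this relies on the cubin path ending in a slash, which the restored TRTLLM_GEN_FMHA value in flashinfer/artifacts.py ("…/fmha/trtllm-gen/") provides.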

tests/test_trtllm_gen_attention.py

Lines changed: 6 additions & 14 deletions
@@ -8,7 +8,7 @@
 from flashinfer.utils import FP4Tensor, ceil_div, round_up
 
 DTYPE_MAP = {
-    "fp16": torch.float16,
+    "half": torch.float16,
     "bf16": torch.bfloat16,
     "fp8": torch.float8_e4m3fn,
     "nvfp4": "nvfp4",
@@ -237,10 +237,8 @@ def unpack_compare_nvfp4(
 @pytest.mark.parametrize(
     "q_dtype,kv_dtype,o_dtype",
     [
+        ("half", "half", "half"),
         ("bf16", "bf16", "bf16"),
-        ("fp16", "fp16", "fp16"),
-        ("fp8", "fp8", "bf16"),
-        ("fp8", "fp8", "fp16"),
         ("fp8", "fp8", "fp8"),
         ("fp8", "fp8", "nvfp4"),
     ],
@@ -357,10 +355,8 @@ def test_trtllm_batch_prefill(
         )
         assert o_scale == 1.0
         rtol, atol = 4e-1, 1e0
-    elif q_dtype == "fp8" and o_dtype == "fp8":
+    elif o_dtype == "fp8":
         rtol, atol = 5e-2, 7e-2
-    elif q_dtype == "fp8" and o_dtype in ["bf16", "fp16"]:
-        rtol, atol = 4e-2, 6e-2
     else:
         rtol, atol = 1e-2, 1e-2
 
@@ -403,12 +399,10 @@ def test_trtllm_batch_prefill(
 @pytest.mark.parametrize(
     "q_dtype,kv_dtype,o_dtype",
     [
+        ("half", "half", "half"),
+        ("half", "fp8", "half"),
         ("bf16", "bf16", "bf16"),
-        ("fp16", "fp16", "fp16"),
         ("bf16", "fp8", "bf16"),
-        ("fp16", "fp8", "fp16"),
-        ("fp8", "fp8", "bf16"),
-        ("fp8", "fp8", "fp16"),
         ("fp8", "fp8", "fp8"),
         ("fp8", "fp8", "nvfp4"),
     ],
@@ -518,10 +512,8 @@ def test_trtllm_batch_decode(
         )
         assert o_scale == 1.0
         rtol, atol = 3e-1, 1e0
-    elif q_dtype == "fp8" and o_dtype == "fp8":
+    elif o_dtype == "fp8":
         rtol, atol = 5e-2, 7e-2
-    elif q_dtype == "fp8" and o_dtype in ["bf16", "fp16"]:
-        rtol, atol = 4e-2, 6e-2
     else:
         rtol, atol = 1e-2, 1e-2
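With the fp8-query, fp16/bf16-output cases removed, the tolerance selection collapses back to a single fp8-output branch. A condensed sketch of the restored decode-test logic (the first branch's guard lies outside the hunk and is assumed here to be the nvfp4 output path; the prefill test differs only in using 4e-1 for that branch):

    def pick_tolerances(o_dtype: str) -> tuple:
        # Post-revert: the dedicated tolerances for fp8 queries with
        # fp16/bf16 outputs are gone along with those test cases.
        if o_dtype == "nvfp4":  # assumed guard; only its body is visible in the hunk
            return 3e-1, 1e0
        elif o_dtype == "fp8":
            return 5e-2, 7e-2
        return 1e-2, 1e-2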
