Commit d41b0c5

Fixed and cleaned up

1 parent 87b0e67 commit d41b0c5
File tree: 3 files changed (+27 lines, -15 lines)

flashinfer/cudnn/prefill.py

Lines changed: 23 additions & 5 deletions
@@ -91,6 +91,7 @@ def _sdpa_prefill_key_fn(
     batch_offsets_stats: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    o_data_type: Optional[torch.dtype] = None,
 ):
     graph_b = actual_seq_lens_q.shape[0]
 
@@ -149,6 +150,7 @@ def _build_prefill_graph(
     batch_offsets_stats: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    o_data_type: Optional[torch.dtype] = None,
 ):
     handle = _create_cudnn_handle(torch.cuda.current_stream(q.device))
 
@@ -163,6 +165,16 @@ def _build_prefill_graph(
     cudnn_k_data_type = cudnn.datatypes._torch_to_cudnn_data_type(k_cache.dtype)
     cudnn_v_data_type = cudnn.datatypes._torch_to_cudnn_data_type(v_cache.dtype)
 
+    cudnn_o_data_type = cudnn.datatypes._torch_to_cudnn_data_type(o_data_type)
+
+    if (
+        cudnn_q_data_type == cudnn.data_type.FP8_E4M3
+        or cudnn_q_data_type == cudnn.data_type.FP8_E5M2
+    ) and cudnn.backend_version() < 91800:
+        raise RuntimeError(
+            f"FP8 is not supported in cuDNN backend version < 9.18.0, current version is {cudnn.backend_version()}"
+        )
+
     with cudnn.graph(handle) as (g, _):
         # Create tensors from the input tensors
         if q.dim() == 3:
@@ -318,7 +330,10 @@ def _build_prefill_graph(
             actual_seq_lens_q is not None and actual_seq_lens_kv is not None
         )
 
-        if cudnn_q_data_type == cudnn.data_type.BFLOAT16:
+        if (
+            cudnn_q_data_type == cudnn.data_type.BFLOAT16
+            or cudnn_q_data_type == cudnn.data_type.HALF
+        ):
             O, Stats = g.sdpa(
                 name="sdpa",
                 q=cudnn_q,
@@ -410,7 +425,7 @@ def _build_prefill_graph(
                 [graph_b, h_qo, graph_s_qo, d_vo]
             ).set_stride(
                 [graph_s_qo * d_vo * h_qo, d_vo, d_vo * h_qo, 1]
-            ).set_data_type(cudnn.data_type.BFLOAT16)
+            ).set_data_type(cudnn_o_data_type)
 
             if return_lse:
                 Stats.set_uid(UIDs.STATS_UID.value).set_output(
@@ -455,6 +470,7 @@ def _batch_prefill_with_kv_cache(
     batch_offsets_stats: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     lse: Optional[torch.Tensor] = None,
+    o_data_type: Optional[torch.dtype] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     graph, tensors = _build_prefill_graph(
         q=q,
@@ -475,6 +491,7 @@ def _batch_prefill_with_kv_cache(
         batch_offsets_stats=batch_offsets_stats,
         out=out,
         lse=lse,
+        o_data_type=o_data_type,
     )
 
     var_map = {
@@ -555,6 +572,7 @@ def cudnn_batch_prefill_with_kv_cache(
     lse: Optional[torch.Tensor] = None,
     is_cuda_graph_compatible: bool = False,
     backend: Optional[str] = None,
+    o_data_type: Optional[torch.dtype] = None,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
     """Performs batched prefill attention with paged KV cache using cuDNN.
 
@@ -581,7 +599,7 @@ def cudnn_batch_prefill_with_kv_cache(
         batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU
         batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU
         batch_offsets_v: Optional batch offsets for value tensor of shape (batch_size,) on GPU
-
+        o_data_type: Optional data type for output tensor
     Returns:
         Output tensor of shape (batch_size * seq_len_q, num_heads_qo, head_dim)
         If return_lse is True, also returns log-sum-exp tensor of shape (batch_size, seq_len_q, num_heads_qo)
@@ -624,8 +642,7 @@ def cudnn_batch_prefill_with_kv_cache(
 
     if out is None:
         out_shape = (num_tokens, h_qo, d_vo)
-        out = torch.empty(out_shape, device=q.device, dtype=torch.float16)
-        # out = torch.empty(out_shape, device=q.device, dtype=q.dtype)
+        out = torch.empty(out_shape, device=q.device, dtype=o_data_type)
 
     if CUDNN_AVAILABLE and backend != "cubin":
         return _batch_prefill_with_kv_cache(
@@ -651,6 +668,7 @@ def cudnn_batch_prefill_with_kv_cache(
             batch_offsets_stats=batch_offsets_stats,
             out=out,
             lse=lse,
+            o_data_type=o_data_type,
         )
     else:
         assert return_lse, "Currently only supports return_lse = True"
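For reference, the new FP8 guard can be exercised on its own. The sketch below is not library code; the helper name `require_fp8_support` is hypothetical, and it assumes the cuDNN frontend Python package (`import cudnn`) is installed. It uses only calls that appear in the diff above: `cudnn.datatypes._torch_to_cudnn_data_type`, `cudnn.data_type.FP8_E4M3`/`FP8_E5M2`, and `cudnn.backend_version()`, which encodes version 9.18.0 as the integer 91800.

```python
import cudnn
import torch

def require_fp8_support(q_dtype: torch.dtype) -> None:
    # Hypothetical standalone mirror of the guard added in _build_prefill_graph:
    # FP8 attention needs cuDNN backend >= 9.18.0 (encoded as 91800).
    cudnn_q_data_type = cudnn.datatypes._torch_to_cudnn_data_type(q_dtype)
    if cudnn_q_data_type in (
        cudnn.data_type.FP8_E4M3,
        cudnn.data_type.FP8_E5M2,
    ) and cudnn.backend_version() < 91800:
        raise RuntimeError(
            f"FP8 is not supported in cuDNN backend version < 9.18.0, "
            f"current version is {cudnn.backend_version()}"
        )

# Example: torch.float8_e4m3fn maps to the FP8_E4M3 cuDNN data type.
require_fp8_support(torch.float8_e4m3fn)
```

With this change, a caller on an FP8 input path can request a higher-precision output by passing e.g. `o_data_type=torch.bfloat16`; the graph's output tensor now takes that dtype instead of the previously hard-coded BFLOAT16 (and the eager allocation path no longer hard-codes `torch.float16`).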

flashinfer/prefill.py

Lines changed: 1 addition & 0 deletions
@@ -2200,6 +2200,7 @@ def run(
                 batch_offsets_o=self._qo_indptr_buf,
                 out=out,
                 lse=lse,
+                o_data_type=out_dtype,
             )
         else:
             if self._backend != "trtllm-gen":

tests/attention/test_cudnn_prefill.py

Lines changed: 3 additions & 10 deletions
@@ -221,11 +221,8 @@ def test_cudnn_prefill_fp8(
         s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
     )
 
-    print("actual_seq_lens_q ", actual_seq_lens_q)
-    print("actual_seq_lens_kv ", actual_seq_lens_kv)
-
     cumsum_s_qo = torch.sum(actual_seq_lens_q)
-    q = torch.ones(
+    q = torch.randn(
         cumsum_s_qo, num_qo_heads, head_dim, device=device, dtype=torch.bfloat16
     )
 
@@ -246,7 +243,7 @@ def test_cudnn_prefill_fp8(
     total_num_pages = num_pages_per_seq * batch_size
 
     kv_cache_shape = (total_num_pages, 2, num_kv_heads, page_size, head_dim)
-    kv_cache = torch.ones(size=kv_cache_shape, dtype=torch.bfloat16).to(device)
+    kv_cache = torch.randn(size=kv_cache_shape, dtype=torch.bfloat16).to(device) * 0.05
     kv_cache = kv_cache.as_strided(
         kv_cache.shape,
         (
@@ -383,8 +380,4 @@ def test_cudnn_prefill_fp8(
 
     output_ref = wrapper.run(q, kv_cache)
 
-    print("output ", output)
-    print("output_ref ", output_ref)
-    print("block_tables ", block_tables)
-
-    torch.testing.assert_close(output, output_ref, atol=1e-1, rtol=1e-1)
+    torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2)
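A plausible reading of the test changes: switching from `torch.ones` to scaled `torch.randn` makes the comparison meaningful (with all-ones inputs, attention degenerates to a constant average), and the 0.05 scale keeps FP8 quantization error small enough to tighten the tolerances from 1e-1 to 1e-2. A rough, self-contained illustration of that error budget (illustrative numbers, not part of the test):

```python
import torch

# Round-trip a small-magnitude tensor through FP8 E4M3 and inspect the error.
x = torch.randn(1024) * 0.05
x_fp8 = x.to(torch.float8_e4m3fn).to(torch.float32)

# E4M3 keeps ~3 mantissa bits, so relative rounding error is a few percent;
# with values mostly in [-0.15, 0.15], the absolute error lands around 1e-2
# or below, which is roughly what the tightened atol/rtol of 1e-2 budgets for.
print((x - x_fp8).abs().max())
```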
