@@ -91,6 +91,7 @@ def _sdpa_prefill_key_fn(
9191 batch_offsets_stats : Optional [torch .Tensor ] = None ,
9292 out : Optional [torch .Tensor ] = None ,
9393 lse : Optional [torch .Tensor ] = None ,
94+ o_data_type : Optional [torch .dtype ] = None ,
9495):
9596 graph_b = actual_seq_lens_q .shape [0 ]
9697
@@ -149,6 +150,7 @@ def _build_prefill_graph(
149150 batch_offsets_stats : Optional [torch .Tensor ] = None ,
150151 out : Optional [torch .Tensor ] = None ,
151152 lse : Optional [torch .Tensor ] = None ,
153+ o_data_type : Optional [torch .dtype ] = None ,
152154 ):
153155 handle = _create_cudnn_handle (torch .cuda .current_stream (q .device ))
154156
@@ -163,6 +165,16 @@ def _build_prefill_graph(
163165 cudnn_k_data_type = cudnn .datatypes ._torch_to_cudnn_data_type (k_cache .dtype )
164166 cudnn_v_data_type = cudnn .datatypes ._torch_to_cudnn_data_type (v_cache .dtype )
165167
168+ cudnn_o_data_type = cudnn .datatypes ._torch_to_cudnn_data_type (o_data_type )
169+
170+ if (
171+ cudnn_q_data_type == cudnn .data_type .FP8_E4M3
172+ or cudnn_q_data_type == cudnn .data_type .FP8_E5M2
173+ ) and cudnn .backend_version () < 91800 :
174+ raise RuntimeError (
175+ f"FP8 is not supported in cuDNN backend version < 9.18.0, current version is { cudnn .backend_version ()} "
176+ )
177+
166178 with cudnn .graph (handle ) as (g , _ ):
167179 # Create tensors from the input tensors
168180 if q .dim () == 3 :
@@ -318,7 +330,10 @@ def _build_prefill_graph(
318330 actual_seq_lens_q is not None and actual_seq_lens_kv is not None
319331 )
320332
321- if cudnn_q_data_type == cudnn .data_type .BFLOAT16 :
333+ if (
334+ cudnn_q_data_type == cudnn .data_type .BFLOAT16
335+ or cudnn_q_data_type == cudnn .data_type .HALF
336+ ):
322337 O , Stats = g .sdpa (
323338 name = "sdpa" ,
324339 q = cudnn_q ,
@@ -410,7 +425,7 @@ def _build_prefill_graph(
410425 [graph_b , h_qo , graph_s_qo , d_vo ]
411426 ).set_stride (
412427 [graph_s_qo * d_vo * h_qo , d_vo , d_vo * h_qo , 1 ]
413- ).set_data_type (cudnn . data_type . BFLOAT16 )
428+ ).set_data_type (cudnn_o_data_type )
414429
415430 if return_lse :
416431 Stats .set_uid (UIDs .STATS_UID .value ).set_output (
@@ -455,6 +470,7 @@ def _batch_prefill_with_kv_cache(
455470 batch_offsets_stats : Optional [torch .Tensor ] = None ,
456471 out : Optional [torch .Tensor ] = None ,
457472 lse : Optional [torch .Tensor ] = None ,
473+ o_data_type : Optional [torch .dtype ] = None ,
458474) -> tuple [torch .Tensor , torch .Tensor ]:
459475 graph , tensors = _build_prefill_graph (
460476 q = q ,
@@ -475,6 +491,7 @@ def _batch_prefill_with_kv_cache(
475491 batch_offsets_stats = batch_offsets_stats ,
476492 out = out ,
477493 lse = lse ,
494+ o_data_type = o_data_type ,
478495 )
479496
480497 var_map = {
@@ -555,6 +572,7 @@ def cudnn_batch_prefill_with_kv_cache(
555572 lse : Optional [torch .Tensor ] = None ,
556573 is_cuda_graph_compatible : bool = False ,
557574 backend : Optional [str ] = None ,
575+ o_data_type : Optional [torch .dtype ] = None ,
558576) -> tuple [torch .Tensor , Optional [torch .Tensor ]]:
559577 """Performs batched prefill attention with paged KV cache using cuDNN.
560578
@@ -581,7 +599,7 @@ def cudnn_batch_prefill_with_kv_cache(
581599 batch_offsets_o: Optional batch offsets for output tensor of shape (batch_size,) on GPU
582600 batch_offsets_k: Optional batch offsets for key tensor of shape (batch_size,) on GPU
583601 batch_offsets_v: Optional batch offsets for value tensor of shape (batch_size,) on GPU
584-
602+ o_data_type: Optional data type for output tensor
585603 Returns:
586604 Output tensor of shape (batch_size * seq_len_q, num_heads_qo, head_dim)
587605 If return_lse is True, also returns log-sum-exp tensor of shape (batch_size, seq_len_q, num_heads_qo)
@@ -624,8 +642,7 @@ def cudnn_batch_prefill_with_kv_cache(
624642
625643 if out is None :
626644 out_shape = (num_tokens , h_qo , d_vo )
627- out = torch .empty (out_shape , device = q .device , dtype = torch .float16 )
628- # out = torch.empty(out_shape, device=q.device, dtype=q.dtype)
645+ out = torch .empty (out_shape , device = q .device , dtype = o_data_type )
629646
630647 if CUDNN_AVAILABLE and backend != "cubin" :
631648 return _batch_prefill_with_kv_cache (
@@ -651,6 +668,7 @@ def cudnn_batch_prefill_with_kv_cache(
651668 batch_offsets_stats = batch_offsets_stats ,
652669 out = out ,
653670 lse = lse ,
671+ o_data_type = o_data_type ,
654672 )
655673 else :
656674 assert return_lse , "Currently only supports return_lse = True"
0 commit comments