     _get_range_buf,
     _unpack_paged_kv_cache,
     canonicalize_torch_dtype,
+    determine_attention_backend,
     device_support_pdl,
     get_device_sm_count,
     is_float8,
@@ -721,7 +722,7 @@ def __init__(
                 self._jit_module = get_batch_prefill_jit_module(
                     jit_args[0],
                     gen_customize_batch_prefill_module(
-                        "fa2", *jit_args
+                        backend, *jit_args
                     ).build_and_load(),
                 )
         else:
@@ -834,6 +835,7 @@ def plan(
         logits_soft_cap: Optional[float] = None,
         q_data_type: Optional[Union[str, torch.dtype]] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        o_data_type: Optional[Union[str, torch.dtype]] = None,
         data_type: Optional[Union[str, torch.dtype]] = None,
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
@@ -881,6 +883,9 @@ def plan(
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to
             ``q_data_type``. Defaults to ``None``.
+        o_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the output tensor. If None, will be set to ``q_data_type``.
+            For FP8 inputs, this should typically be set to ``torch.float16`` or ``torch.bfloat16``.
         data_type: Optional[Union[str, torch.dtype]]
             The data type of both the query and key/value tensors. Defaults to torch.float16.
             ``data_type`` is deprecated; please use ``q_data_type`` and ``kv_data_type`` instead.
@@ -966,6 +971,10 @@ def plan(
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
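+        # Resolve the output dtype the same way as kv_data_type above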
+        if o_data_type is None:
+            o_data_type = q_data_type
+        o_data_type = canonicalize_torch_dtype(o_data_type)
+
         if fixed_split_size is not None and not self.use_tensor_cores:
             raise ValueError(
                 "fixed_split_size is only supported by tensor core decode for now."
@@ -975,6 +984,7 @@ def plan(

         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
+        self._cached_o_data_type = o_data_type
         self._batch_size = batch_size
         self._num_qo_heads = num_qo_heads
         self._num_kv_heads = num_kv_heads
@@ -1014,7 +1024,7 @@ def plan(
                 self._cached_module = get_trtllm_gen_decode_module(
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,
                     head_dim,
@@ -1029,11 +1039,20 @@ def plan(
             if self._jit_module is not None:
                 self._cached_module = self._jit_module
             else:
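+                # "auto" resolves to a concrete backend (e.g. fa2 or fa3) for this device and dtype combination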
+                if self._backend == "auto":
+                    self._backend = determine_attention_backend(
+                        self.device,
+                        PosEncodingMode[pos_encoding_mode].value,
+                        False,  # use_fp16_qk_reduction
+                        False,  # use_custom_mask
+                        q_data_type,
+                        kv_data_type,
+                    )
                 self._cached_module = get_batch_prefill_module(
-                    "fa2",
+                    self._backend,
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1043,7 +1062,7 @@ def plan(
                     False,  # use_fp16_qk_reduction
                 )

-            self._plan_info = self._cached_module.plan(
+            args = [
                 self._float_workspace_buffer,
                 self._int_workspace_buffer,
                 self._pin_memory_int_workspace_buffer,
@@ -1060,9 +1079,13 @@ def plan(
                 head_dim,
                 False,  # causal
                 window_left,
-                fixed_split_size,
-                disable_split_kv,
-                0,  # num_colocated_ctas
+            ]
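+            # The split-KV tuning arguments below are only accepted by the fa2 plan()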
+            if self._backend == "fa2":
+                args.append(fixed_split_size)
+                args.append(disable_split_kv)
+                args.append(0)  # num_colocated_ctas
+            self._plan_info = self._cached_module.plan(
+                *args,
             )
         else:
             if self._jit_module is not None:
@@ -1071,7 +1094,7 @@ def plan(
                 self._cached_module = get_batch_decode_module(
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1281,9 +1304,13 @@ def run(
             )

         if out is None:
-            out = torch.empty_like(q)
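+            # The output dtype may differ from q (e.g. FP8 in, FP16 out); the last dim follows v_cache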
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            out = torch.empty(
+                q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
+            )
         else:
-            check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")

         if self._backend == "trtllm-gen":
             q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
@@ -1311,6 +1338,14 @@ def run(
         if self._jit_module is not None:
             run_args.extend(list(args))
         else:
+            # Extract FP8 scale tensors from *args if q is FP8
+            fp8_scale_q = None
+            fp8_scale_k = None
+            fp8_scale_v = None
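+            # The scales are expected positionally in *args as (scale_q, scale_k, scale_v)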
+            if is_float8(q) and len(args) >= 3:
+                fp8_scale_q = args[0]
+                fp8_scale_k = args[1]
+                fp8_scale_v = args[2]
             run_args += [
                 None,  # packed_custom_mask
                 None,  # mask_indptr_buf
@@ -1320,9 +1355,9 @@ def run(
                 None,  # maybe_max_item_len_ptr
                 logits_soft_cap,
                 sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
+                fp8_scale_q,
+                fp8_scale_k,
+                fp8_scale_v,
                 rope_scale,
                 rope_theta,
                 0,  # token_pos_in_items_len
@@ -1375,7 +1410,7 @@ def run(
             ]

         self._cached_module.run(*run_args)
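+        # A v_scale of exactly 1.0 is a no-op, so skip the extra elementwise multiply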
-        if v_scale is not None:
+        if v_scale is not None and v_scale != 1.0:
             # TODO(Zihao): fused into kernel
             if is_float8(out):
                 out = (out.to(torch.float32) * v_scale).to(out.dtype)
@@ -2597,8 +2632,8 @@ def fast_decode_plan(
     kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)

     try:
-        # Make sure we pass exactly 16 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Pass exactly 19 arguments for the fa2 backend and 16 for the fa3 backend
+        args = [
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
@@ -2615,9 +2650,13 @@ def fast_decode_plan(
             head_dim,
             False,  # causal
             window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,  # num_colocated_ctas
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
         )
     except Exception as e:
         raise RuntimeError(f"Error in standard plan: {e}") from e