@@ -514,7 +514,8 @@ def __init__(
         ):
             self.q_data_type = self.kv_cache_dtype
         else:
-            self.q_data_type = self.model_config.dtype
+            # self.q_data_type = self.model_config.dtype
+            self.q_data_type = self.kv_cache_dtype
 
         # Prefer TRTLLM attention for decoding in all cases.
         # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
@@ -814,7 +815,7 @@ def build(
 
         # The q quantization is not supported for non-trtllm attention,
         # fall back to model dtype.
-        self.q_data_type = self.model_config.dtype
+        # self.q_data_type = self.model_config.dtype
 
         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -928,6 +929,7 @@ def build(
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
+                o_data_type=self.model_config.dtype,
                 fixed_split_size=self.prefill_fixed_split_size,
                 disable_split_kv=self.disable_split_kv,
             )
@@ -972,6 +974,7 @@ def build(
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
+                o_data_type=self.model_config.dtype,
                 fixed_split_size=self.decode_fixed_split_size,
                 disable_split_kv=self.disable_split_kv,
             )
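For context (not part of the patch): a minimal sketch of how the dtype plumbing above resolves once the KV cache is FP8-quantized. The helper name `resolve_plan_dtypes` is hypothetical; only the mapping it encodes comes from the diff (the query is planned in the FP8 KV-cache dtype, while the output keeps the model dtype via the new `o_data_type` argument).

```python
# Hypothetical helper, for illustration only; not in the patch.
import torch

def resolve_plan_dtypes(kv_cache_dtype: torch.dtype, model_dtype: torch.dtype):
    """Return (q_data_type, kv_data_type, o_data_type) for the wrapper plan."""
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # With an FP8 KV cache, the query is now planned in the same FP8 dtype...
        q_dtype = kv_cache_dtype
    else:
        q_dtype = model_dtype
    # ...while the attention output stays in the model dtype (o_data_type).
    return q_dtype, kv_cache_dtype, model_dtype

# Example: FP8 KV cache with a bf16 model.
print(resolve_plan_dtypes(torch.float8_e4m3fn, torch.bfloat16))
# -> (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16)
```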
@@ -1045,7 +1048,7 @@ def __init__(
         self.support_trtllm_attn = can_use_trtllm_attention(num_heads, num_kv_heads)
         vllm_config = get_current_vllm_config()
         self.supports_quant_query_input = (
-            self.support_trtllm_attn
+            self.kv_cache_dtype.startswith("fp8")
             and not vllm_config.attention_config.disable_flashinfer_q_quantization
         )
         self.bmm1_scale: float | None = None
@@ -1245,6 +1248,7 @@ def forward(
             prefill_wrapper.run(
                 prefill_query,
                 kv_cache_permute,
+                q_scale=layer._q_scale_float,
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[num_decode_tokens:],
@@ -1338,6 +1342,7 @@ def forward(
             decode_wrapper.run(
                 decode_query,
                 kv_cache_permute,
+                q_scale=layer._q_scale_float,
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output_tmp,
@@ -1354,6 +1359,7 @@ def forward(
             decode_wrapper.run(
                 decode_query,
                 kv_cache_permute,
+                q_scale=layer._q_scale_float,
                 k_scale=layer._k_scale_float,
                 v_scale=layer._v_scale_float,
                 out=output[:num_decode_tokens],
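A note on the new `q_scale` argument: with an FP8 query, the per-tensor query scale has to be applied somewhere, either by the kernel (as here, passed to `run()` alongside `k_scale`/`v_scale`) or folded into the softmax scale. A rough sketch of the usual folding, mirroring the `bmm1_scale`/`bmm2_scale` bookkeeping this backend keeps for the TRTLLM path; the exact handling inside FlashInfer's wrappers may differ.

```python
# Illustration only: how per-tensor FP8 scales are commonly folded into the
# two attention GEMMs. Assumes per-tensor scales; not the FlashInfer internals.
def effective_scales(q_scale: float, k_scale: float, v_scale: float,
                     softmax_scale: float) -> tuple[float, float]:
    bmm1_scale = q_scale * k_scale * softmax_scale  # scales the Q @ K^T logits
    bmm2_scale = v_scale                            # scales the P @ V output
    return bmm1_scale, bmm2_scale
```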
@@ -1427,6 +1433,7 @@ def fast_plan_decode(
     logits_soft_cap: float | None = None,
     q_data_type: str | torch.dtype | None = "float16",
     kv_data_type: str | torch.dtype | None = None,
+    o_data_type: str | torch.dtype | None = None,
     data_type: str | torch.dtype | None = None,
     sm_scale: float | None = None,
     rope_scale: float | None = None,
@@ -1465,6 +1472,7 @@ def fast_plan_decode(
             logits_soft_cap,
             q_data_type,
             kv_data_type,
+            o_data_type,
             data_type,
             sm_scale,
             rope_scale,
@@ -1484,24 +1492,6 @@ def fast_plan_decode(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0
 
-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-    q_data_type = (
-        getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-    )
-    kv_data_type = (
-        getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
-    )
-
     if batch_size != self._fixed_batch_size:
         raise ValueError(
             "The batch size should be fixed in cudagraph mode, the runtime "
@@ -1521,8 +1511,9 @@ def fast_plan_decode(
     qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")
 
     try:
-        # Make sure we pass exactly 19 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Make sure we pass exactly 19 arguments for fa2 backend and 16 arguments for
+        # fa3 backend
+        args = [
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
@@ -1539,9 +1530,13 @@ def fast_plan_decode(
15391530 head_dim ,
15401531 False , # causal
15411532 window_left ,
1542- fixed_split_size ,
1543- disable_split_kv ,
1544- 0 ,
1533+ ]
1534+ if self ._backend == "fa2" :
1535+ args .append (fixed_split_size )
1536+ args .append (disable_split_kv )
1537+ args .append (0 ) # num_colocated_ctas
1538+ self ._plan_info = self ._cached_module .plan (
1539+ * args ,
15451540 )
15461541 except Exception as e :
15471542 raise RuntimeError (f"Error in tensor core plan: { e } " ) from e
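Not part of the patch, but given the comment about argument counts, here is a defensive check one could drop in right after `args` is built; the counts are taken from the comment in the hunk above (19 positional arguments for the fa2 plan module, 16 for fa3).

```python
# Hypothetical sanity check; assumes `args` and `self._backend` as in the hunk above.
expected = 19 if self._backend == "fa2" else 16
assert len(args) == expected, (
    f"plan() for backend {self._backend!r} expects {expected} args, got {len(args)}"
)
```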