@@ -504,7 +504,8 @@ def __init__(
         if can_use_trtllm and not flashinfer_disable_q_quantization():
             self.q_data_type = self.kv_cache_dtype
         else:
-            self.q_data_type = self.model_config.dtype
+            # self.q_data_type = self.model_config.dtype
+            self.q_data_type = self.kv_cache_dtype

         # Prefer TRTLLM attention for decoding in all cases.
         # This allows us to use AttentionCGSupport.UNIFORM_BATCH mode.
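Note: with this hunk the builder keeps `q_data_type` equal to `kv_cache_dtype` even on the non-TRTLLM path, i.e. the query is expected in the same (possibly fp8) dtype as the KV cache. The sketch below only illustrates what casting a query to an fp8 e4m3 cache dtype with a per-tensor scale looks like; the helper name and the scale value are hypothetical, not vLLM's actual quantization path.

```python
import torch

def quantize_query_to_kv_dtype(query: torch.Tensor, q_scale: float) -> torch.Tensor:
    """Illustrative per-tensor fp8 quantization of the query.

    Hypothetical helper: it only shows the dtype conversion implied by
    q_data_type == kv_cache_dtype when the KV cache is fp8 (e4m3).
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    scaled = (query.float() / q_scale).clamp(finfo.min, finfo.max)
    return scaled.to(torch.float8_e4m3fn)

q = torch.randn(2, 8, 128, dtype=torch.bfloat16)  # [tokens, heads, head_dim]
q_fp8 = quantize_query_to_kv_dtype(q, q_scale=0.05)
assert q_fp8.dtype == torch.float8_e4m3fn
```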
@@ -803,7 +804,7 @@ def build(

             # The q quantization is not supported for non-trtllm attention,
             # fall back to model dtype.
-            self.q_data_type = self.model_config.dtype
+            # self.q_data_type = self.model_config.dtype

         attn_metadata = FlashInferMetadata(
             num_actual_tokens=num_actual_tokens,
@@ -917,6 +918,7 @@ def build(
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
+                o_data_type=self.model_config.dtype,
                 fixed_split_size=self.prefill_fixed_split_size,
                 disable_split_kv=self.disable_split_kv,
             )
@@ -961,6 +963,7 @@ def build(
                 logits_soft_cap=self.logits_soft_cap,
                 q_data_type=self.q_data_type,
                 kv_data_type=self.kv_cache_dtype,
+                o_data_type=self.model_config.dtype,
                 fixed_split_size=self.decode_fixed_split_size,
                 disable_split_kv=self.disable_split_kv,
             )
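Both plan calls (prefill and decode) now state the output dtype explicitly. Since `q_data_type`/`kv_data_type` may be fp8 while the attention output is produced in the model's compute dtype, the output dtype can no longer be inferred from the query dtype. A minimal illustration of the resulting dtype split, assuming an fp8 e4m3 cache and a bf16 model (the concrete dtypes are assumptions, the mapping follows the diff):

```python
import torch

# Illustrative only: the dtype triple the builder now passes to plan()
# when the KV cache (and hence the query) is fp8 but the model runs bf16.
plan_dtypes = {
    "q_data_type": torch.float8_e4m3fn,   # follows self.kv_cache_dtype
    "kv_data_type": torch.float8_e4m3fn,  # the KV cache dtype itself
    "o_data_type": torch.bfloat16,        # follows self.model_config.dtype
}
assert plan_dtypes["o_data_type"] != plan_dtypes["q_data_type"]
```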
@@ -1047,7 +1050,7 @@ def supports_quant_query_input(self) -> bool:
         if flashinfer_disable_q_quantization():
             return False

-        return self.support_trtllm_attn
+        return self.kv_cache_dtype.startswith("fp8")

     # FlashInfer requires attention sinks to be float32
     def process_weights_after_loading(self, act_dtype: torch.dtype):
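The predicate now keys on the KV cache dtype rather than on TRTLLM support, so any fp8 KV cache advertises quantized query input. A standalone restatement of the new logic (hypothetical free function, with the env-var kill switch folded in as a plain boolean):

```python
def supports_quant_query_input(kv_cache_dtype: str, q_quant_disabled: bool) -> bool:
    # Mirrors the new predicate: query quantization is offered whenever the
    # KV cache is fp8, independent of whether the TRTLLM path is selected.
    if q_quant_disabled:
        return False
    return kv_cache_dtype.startswith("fp8")

assert supports_quant_query_input("fp8_e4m3", q_quant_disabled=False)
assert not supports_quant_query_input("auto", q_quant_disabled=False)
assert not supports_quant_query_input("fp8_e4m3", q_quant_disabled=True)
```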
@@ -1235,6 +1238,7 @@ def forward(
                 prefill_wrapper.run(
                     prefill_query,
                     kv_cache_permute,
+                    q_scale=layer._q_scale_float,
                     k_scale=layer._k_scale_float,
                     v_scale=layer._v_scale_float,
                     out=output[num_decode_tokens:],
@@ -1328,6 +1332,7 @@ def forward(
                     decode_wrapper.run(
                         decode_query,
                         kv_cache_permute,
+                        q_scale=layer._q_scale_float,
                         k_scale=layer._k_scale_float,
                         v_scale=layer._v_scale_float,
                         out=output_tmp,
@@ -1341,6 +1346,7 @@ def forward(
                     decode_wrapper.run(
                         decode_query,
                         kv_cache_permute,
+                        q_scale=layer._q_scale_float,
                         k_scale=layer._k_scale_float,
                         v_scale=layer._v_scale_float,
                         out=output[:num_decode_tokens],
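All three wrapper invocations now also pass the query scale. The reason per-tensor scales can be handed to the kernel instead of dequantizing Q and K up front is that they commute with the matmul, so q_scale and k_scale can be folded into the softmax scaling. A small numeric check of that identity (plain torch, not the FlashInfer kernel):

```python
import torch

# (Q * q_scale) @ (K * k_scale).T * sm_scale
#   == (Q @ K.T) * (q_scale * k_scale * sm_scale)
torch.manual_seed(0)
q = torch.randn(4, 64)   # stand-ins for fp8 tensors after casting to float
k = torch.randn(6, 64)
q_scale, k_scale, sm_scale = 0.07, 0.05, 64 ** -0.5

logits_a = (q * q_scale) @ (k * k_scale).T * sm_scale
logits_b = (q @ k.T) * (q_scale * k_scale * sm_scale)
assert torch.allclose(logits_a, logits_b, atol=1e-6)
```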
@@ -1414,6 +1420,7 @@ def fast_plan_decode(
     logits_soft_cap: float | None = None,
     q_data_type: str | torch.dtype | None = "float16",
     kv_data_type: str | torch.dtype | None = None,
+    o_data_type: str | torch.dtype | None = None,
     data_type: str | torch.dtype | None = None,
     sm_scale: float | None = None,
     rope_scale: float | None = None,
@@ -1452,6 +1459,7 @@ def fast_plan_decode(
             logits_soft_cap,
             q_data_type,
             kv_data_type,
+            o_data_type,
             data_type,
             sm_scale,
             rope_scale,
@@ -1471,24 +1479,6 @@ def fast_plan_decode(
     if logits_soft_cap is None:
         logits_soft_cap = 0.0

-    # Handle data types consistently
-    if data_type is not None:
-        if q_data_type is None:
-            q_data_type = data_type
-        if kv_data_type is None:
-            kv_data_type = data_type
-    elif q_data_type is None:
-        q_data_type = "float16"
-
-    if kv_data_type is None:
-        kv_data_type = q_data_type
-    q_data_type = (
-        getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type
-    )
-    kv_data_type = (
-        getattr(torch, kv_data_type) if isinstance(kv_data_type, str) else kv_data_type
-    )
-
     if batch_size != self._fixed_batch_size:
         raise ValueError(
             "The batch size should be fixed in cudagraph mode, the runtime "
@@ -1508,8 +1498,8 @@ def fast_plan_decode(
     qo_indptr_host = _get_range_buf(batch_size + 1, "cpu")

     try:
-        # Make sure we pass exactly 19 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Make sure we pass exactly 19 arguments for fa2 backend and 16 arguments for fa3 backend
+        args = [
             self._float_workspace_buffer,
             self._int_workspace_buffer,
             self._pin_memory_int_workspace_buffer,
@@ -1526,9 +1516,13 @@ def fast_plan_decode(
             head_dim,
             False,  # causal
             window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
         )
     except Exception as e:
         raise RuntimeError(f"Error in tensor core plan: {e}") from e
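The plan arguments are now assembled in a list so that the split-KV knobs are appended only for the fa2 backend, matching the 19-versus-16 argument counts named in the comment. A toy restatement of that pattern (the knob names and the trailing 0 for num_colocated_ctas come from the diff; the counts are taken from its comment, and this function is only illustrative):

```python
def build_plan_args(backend: str, common_args: list,
                    fixed_split_size: int, disable_split_kv: bool) -> list:
    # fa2 takes three extra trailing arguments; fa3 does not.
    args = list(common_args)
    if backend == "fa2":
        args += [fixed_split_size, disable_split_kv, 0]  # 0 = num_colocated_ctas
    return args

assert len(build_plan_args("fa2", [None] * 16, 2048, False)) == 19
assert len(build_plan_args("fa3", [None] * 16, 2048, False)) == 16
```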