diff --git a/flashinfer/decode.py b/flashinfer/decode.py
index cc865ae5f8..b2b7a930d5 100644
--- a/flashinfer/decode.py
+++ b/flashinfer/decode.py
@@ -63,6 +63,7 @@
     _get_range_buf,
     _unpack_paged_kv_cache,
     canonicalize_torch_dtype,
+    determine_attention_backend,
     device_support_pdl,
     get_device_sm_count,
     is_float8,
@@ -721,7 +722,7 @@ def __init__(
             self._jit_module = get_batch_prefill_jit_module(
                 jit_args[0],
                 gen_customize_batch_prefill_module(
-                    "fa2", *jit_args
+                    backend, *jit_args
                 ).build_and_load(),
             )
         else:
@@ -834,6 +835,7 @@ def plan(
         logits_soft_cap: Optional[float] = None,
         q_data_type: Optional[Union[str, torch.dtype]] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
+        o_data_type: Optional[Union[str, torch.dtype]] = None,
         data_type: Optional[Union[str, torch.dtype]] = None,
         sm_scale: Optional[float] = None,
         rope_scale: Optional[float] = None,
@@ -881,6 +883,9 @@ def plan(
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to
             ``q_data_type``. Defaults to ``None``.
+        o_data_type : Optional[Union[str, torch.dtype]]
+            The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         data_type: Optional[Union[str, torch.dtype]]
             The data type of both the query and key/value tensors. Defaults to torch.float16.
             data_type is deprecated, please use q_data_type and kv_data_type instead.
@@ -966,6 +971,10 @@ def plan(
         if kv_data_type is None:
             kv_data_type = q_data_type
         kv_data_type = canonicalize_torch_dtype(kv_data_type)
+        if o_data_type is None:
+            o_data_type = q_data_type
+        o_data_type = canonicalize_torch_dtype(o_data_type)
+
         if fixed_split_size is not None and not self.use_tensor_cores:
             raise ValueError(
                 "fixed_split_size is only supported by tensor core decode for now."
@@ -975,6 +984,7 @@ def plan(
         self._cached_q_data_type = q_data_type
         self._cached_kv_data_type = kv_data_type
+        self._cached_o_data_type = o_data_type
         self._batch_size = batch_size
         self._num_qo_heads = num_qo_heads
         self._num_kv_heads = num_kv_heads
@@ -1014,7 +1024,7 @@ def plan(
                 self._cached_module = get_trtllm_gen_decode_module(
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,
                     head_dim,
@@ -1029,11 +1039,20 @@ def plan(
                 if self._jit_module is not None:
                     self._cached_module = self._jit_module
                 else:
+                    if self._backend == "auto":
+                        self._backend = determine_attention_backend(
+                            self.device,
+                            PosEncodingMode[pos_encoding_mode].value,
+                            False,  # use_fp16_qk_reduction
+                            False,  # use_custom_mask
+                            q_data_type,
+                            kv_data_type,
+                        )
                     self._cached_module = get_batch_prefill_module(
-                        "fa2",
+                        self._backend,
                         q_data_type,
                         kv_data_type,
-                        q_data_type,
+                        o_data_type,
                         indptr.dtype,
                         head_dim,  # head_dim_qk
                         head_dim,  # head_dim_vo
@@ -1043,7 +1062,7 @@ def plan(
                         False,  # use_fp16_qk_reduction
                     )
 
-                self._plan_info = self._cached_module.plan(
+                args = [
                     self._float_workspace_buffer,
                     self._int_workspace_buffer,
                     self._pin_memory_int_workspace_buffer,
@@ -1060,9 +1079,13 @@ def plan(
                     head_dim,
                     False,  # causal
                     window_left,
-                    fixed_split_size,
-                    disable_split_kv,
-                    0,  # num_colocated_ctas
+                ]
+                if self._backend == "fa2":
+                    args.append(fixed_split_size)
+                    args.append(disable_split_kv)
+                    args.append(0)  # num_colocated_ctas
+                self._plan_info = self._cached_module.plan(
+                    *args,
                 )
         else:
             if self._jit_module is not None:
@@ -1071,7 +1094,7 @@ def plan(
                 self._cached_module = get_batch_decode_module(
                     q_data_type,
                     kv_data_type,
-                    q_data_type,
+                    o_data_type,
                     indptr.dtype,
                     head_dim,  # head_dim_qk
                     head_dim,  # head_dim_vo
@@ -1281,9 +1304,13 @@ def run(
             )
 
         if out is None:
-            out = torch.empty_like(q)
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            out = torch.empty(
+                q.shape[:-1] + v_cache.shape[-1:], dtype=out_dtype, device=q.device
+            )
         else:
-            check_shape_dtype_device(out, q.shape, q.dtype, q.device, "out")
+            out_dtype = getattr(self, "_cached_o_data_type", None) or q.dtype
+            check_shape_dtype_device(out, q.shape, out_dtype, q.device, "out")
 
         if self._backend == "trtllm-gen":
             q = q.view(q.size(0) // q_len_per_req, q_len_per_req, q.size(1), q.size(2))
@@ -1311,6 +1338,14 @@ def run(
         if self._jit_module is not None:
             run_args.extend(list(args))
         else:
+            # Extract FP8 scale tensors from *args if q is FP8
+            fp8_scale_q = None
+            fp8_scale_k = None
+            fp8_scale_v = None
+            if is_float8(q) and len(args) >= 3:
+                fp8_scale_q = args[0]
+                fp8_scale_k = args[1]
+                fp8_scale_v = args[2]
             run_args += [
                 None,  # packed_custom_mask
                 None,  # mask_indptr_buf
@@ -1320,9 +1355,9 @@ def run(
                 None,  # maybe_max_item_len_ptr
                 logits_soft_cap,
                 sm_scale,
-                None,  # scale_q, not supported yet
-                None,  # scale_k
-                None,  # scale_v
+                fp8_scale_q,
+                fp8_scale_k,
+                fp8_scale_v,
                 rope_scale,
                 rope_theta,
                 0,  # token_pos_in_items_len
@@ -1375,7 +1410,7 @@ def run(
             ]
 
         self._cached_module.run(*run_args)
-        if v_scale is not None:
+        if v_scale is not None and v_scale != 1.0:
             # TODO(Zihao): fused into kernel
             if is_float8(out):
                 out = (out.to(torch.float32) * v_scale).to(out.dtype)
@@ -2597,8 +2632,8 @@ def fast_decode_plan(
    kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size)
 
    try:
-        # Make sure we pass exactly 16 arguments for tensor core version
-        self._plan_info = self._cached_module.plan(
+        # Make sure we pass exactly 19 arguments for fa2 backend and 16 arguments for fa3 backend
+        args = [
            self._float_workspace_buffer,
            self._int_workspace_buffer,
            self._pin_memory_int_workspace_buffer,
@@ -2615,9 +2650,13 @@ def fast_decode_plan(
            head_dim,
            False,  # causal
            window_left,
-            fixed_split_size,
-            disable_split_kv,
-            0,  # num_colocated_ctas
+        ]
+        if self._backend == "fa2":
+            args.append(fixed_split_size)
+            args.append(disable_split_kv)
+            args.append(0)  # num_colocated_ctas
+        self._plan_info = self._cached_module.plan(
+            *args,
        )
    except Exception as e:
        raise RuntimeError(f"Error in standard plan: {e}") from e
diff --git a/flashinfer/jit/attention/modules.py b/flashinfer/jit/attention/modules.py
index d596695ad1..bb6962b791 100755
--- a/flashinfer/jit/attention/modules.py
+++ b/flashinfer/jit/attention/modules.py
@@ -984,6 +984,13 @@ def gen_batch_prefill_module(
     # KV-only quant is not influenced by this flag
     fp8_enabled = dtype_q in [torch.float8_e4m3fn, torch.float8_e5m2]
 
+    assert backend in ["fa2", "fa3"], (
+        f"backend must be fa2 or fa3 in gen_batch_prefill_module(), got: {backend}"
+    )
+    assert dtype_o not in [torch.float8_e4m3fn, torch.float8_e5m2], (
+        "FP8 output is not supported in fa2/fa3 backends yet"
+    )
+
     if backend == "fa2":
         assert not fp8_enabled, "fp8 tensor core is not supported in fa2 backend"
         additional_tensor_names = [
diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py
index 41fac0e4e9..d5e288267d 100755
--- a/flashinfer/prefill.py
+++ b/flashinfer/prefill.py
@@ -1701,7 +1701,7 @@ def plan(
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
         o_data_type : Optional[Union[str, torch.dtype]]
             The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
-            For FP8 inputs, this should typically be set to torch.float16.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         non_blocking : bool
             Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
         prefix_len_ptr :Optional[torch.Tensor]
@@ -2077,6 +2077,8 @@ def run(
 
         *args
             Additional arguments for custom kernels.
+        q_scale : Optional[float]
+            The calibration scale of query for fp8 input, if not provided, will be set to ``1.0``.
         k_scale : Optional[float]
             The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``.
         v_scale : Optional[float]
@@ -2105,6 +2107,11 @@ def run(
         _check_cached_qkv_data_type(
             q, k_cache, self._cached_q_data_type, self._cached_kv_data_type
         )
+        o_dtype = self._cached_o_data_type
+        if out is not None and out.dtype != o_dtype:
+            raise ValueError(
+                f"The dtype of out {out.dtype} does not match the o_data_type {o_dtype} specified in plan function."
+            )
 
         if self._kv_layout == "NHD":
             page_size = k_cache.shape[1]
@@ -2258,7 +2265,7 @@ def run(
 
         assert self._cached_module is not None, "cached module is not initialized"
         self._cached_module.paged_run(*run_args)
-        if v_scale is not None:
+        if v_scale is not None and v_scale != 1.0:
             # TODO(Zihao): fused into kernel
             if is_float8(out):
                 out = (out.to(torch.float32) * v_scale).to(out.dtype)
@@ -2646,7 +2653,7 @@ def plan(
             The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`.
         o_data_type : Optional[Union[str, torch.dtype]]
             The data type of the output tensor. If None, will be set to :attr:`q_data_type`.
-            For FP8 inputs, this should typically be set to torch.float16.
+            For FP8 inputs, this should typically be set to torch.float16 or torch.bfloat16.
         non_blocking : bool
             Whether to copy the input tensors to the device asynchronously, defaults to ``True``.
         prefix_len_ptr :Optional[torch.Tensor]
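
Usage sketch (not part of the patch): the snippet below illustrates the new `o_data_type` keyword in `BatchDecodeWithPagedKVCacheWrapper.plan()` and the FP8 scale tensors picked up from `run(*args)` as `scale_q`/`scale_k`/`scale_v`, as wired up in the decode changes above. Shapes, the workspace size, and the unit scales are placeholder values, and the scale-tensor layout/dtype is an assumption, not something this diff specifies.

```python
# Usage sketch only -- NOT part of the patch. FP8 Q/KV in, fp16 out.
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 32, 8, 128, 16
batch_size, pages_per_seq = 4, 2
num_pages = batch_size * pages_per_seq

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
    workspace, "NHD", use_tensor_cores=True
)

kv_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device="cuda") * pages_per_seq
kv_indices = torch.arange(num_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda")

wrapper.plan(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    q_data_type=torch.float8_e4m3fn,
    kv_data_type=torch.float8_e4m3fn,
    o_data_type=torch.float16,  # new in this diff: FP8 inputs, fp16 output
)

q = torch.randn(
    batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda"
).to(torch.float8_e4m3fn)
kv_cache = torch.randn(
    num_pages, 2, page_size, num_kv_heads, head_dim, dtype=torch.float16, device="cuda"
).to(torch.float8_e4m3fn)

# Per-tensor dequantization scales, forwarded through *args and consumed as
# scale_q / scale_k / scale_v when q is FP8 (placeholder unit scales; the exact
# expected layout/dtype of these tensors is an assumption here).
scale_q = torch.ones(1, dtype=torch.float32, device="cuda")
scale_k = torch.ones(1, dtype=torch.float32, device="cuda")
scale_v = torch.ones(1, dtype=torch.float32, device="cuda")

out = wrapper.run(q, kv_cache, scale_q, scale_k, scale_v)
assert out.dtype == torch.float16  # matches o_data_type from plan()
```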