fix: [https://nvbugspro.nvidia.com/bug/5242406][fix] Fix fp8 kvcache support #3877

Merged 1 commit on Apr 29, 2025

Changes from all commits
8 changes: 3 additions & 5 deletions tensorrt_llm/_torch/attention_backend/flashinfer.py
@@ -422,11 +422,9 @@ def __init__(
def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
self.quant_config = new_quant_config
self.has_fp8_kv_cache = False
if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant(
):
quant_mode = self.quant_config.layer_quant_mode
if quant_mode.has_fp8_kv_cache():
self.has_fp8_kv_cache = True
if self.quant_config:
self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache(
)

def forward(self,
q: torch.Tensor,
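The nested `has_any_quant()` check collapses into a single expression, so the backend's `has_fp8_kv_cache` flag now simply mirrors `layer_quant_mode.has_fp8_kv_cache()` whenever a quant config is present. A minimal, self-contained sketch of the resulting predicate (the `QuantModeSketch` class below is an illustrative stand-in, not the library's `QuantMode`):

```python
from enum import Flag, auto
from typing import Optional


class QuantModeSketch(Flag):
    """Illustrative stand-in for the library's quant-mode bit flags."""
    NONE = 0
    FP8_QDQ = auto()
    FP8_KV_CACHE = auto()

    def has_fp8_kv_cache(self) -> bool:
        return bool(self & QuantModeSketch.FP8_KV_CACHE)


def derive_has_fp8_kv_cache(layer_quant_mode: Optional[QuantModeSketch]) -> bool:
    # Mirrors the simplified update_quant_config: no nesting under
    # has_any_quant(); just check the FP8 KV-cache bit when a config exists.
    return layer_quant_mode is not None and layer_quant_mode.has_fp8_kv_cache()


assert derive_has_fp8_kv_cache(None) is False
assert derive_has_fp8_kv_cache(QuantModeSketch.FP8_KV_CACHE) is True
assert derive_has_fp8_kv_cache(QuantModeSketch.FP8_QDQ) is False
```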
19 changes: 5 additions & 14 deletions tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -80,7 +80,6 @@ def __init__(
head_dim (int): The size of each attention head (hidden_size // num_heads).
num_kv_heads (int): The number of kv heads. Defaults to num_heads if None.
pos_embd_params (PositionalEmbeddingParams): Optional parameters defining how positional embedding should be applied.
quant_config (QuantConfig): Optional quantization configuration. If None, no quantization is applied.
"""
rope_params = None
if pos_embd_params is not None:
@@ -126,7 +125,7 @@ def __init__(
self.kwargs = {}
self.kwargs.update(kwargs)

def create_weights(self, quant_config: Optional[QuantConfig] = None):
def update_quant_config(self, quant_config: Optional[QuantConfig] = None):
quant_config = quant_config or QuantConfig()
self.quant_mode = int(quant_config.layer_quant_mode)

@@ -623,16 +622,17 @@ def __init__(

def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
self.quant_config = new_quant_config
self.wrapper.create_weights(self.quant_config)
self.wrapper.update_quant_config(self.quant_config)

self.has_fp8_qdq = self.has_fp8_kv_cache = self.has_nvfp4 = False
if self.quant_config is not None:
self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache(
)

self.has_fp8_qdq = self.quant_config.layer_quant_mode.has_fp8_qdq()
self.has_fp8_block_wise = self.quant_config.layer_quant_mode.has_fp8_block_scales(
)
self.has_nvfp4 = self.quant_config.layer_quant_mode.has_nvfp4()
self.has_fp8_kv_cache = self.quant_config.layer_quant_mode.has_fp8_kv_cache(
)
self.has_nvfp4 = self.quant_config.layer_quant_mode.has_nvfp4()

def forward(
@@ -662,15 +662,6 @@ def forward(
or metadata.runtime_features.has_speculative_draft_tokens
) if metadata.runtime_features else False

if use_paged_context_fmha and self.has_fp8_kv_cache:
# NOTE: W4A8_AWQ can be included too, exclude for now since
# we don't use int4 in PyTorch
if not (self.has_fp8_qdq or self.has_nvfp4
or self.has_fp8_block_wise):
raise RuntimeError(
"When FP8 KV cache is being used, paged context FMHA cannot be used without "
"FP8 attention.")

num_seqs = metadata.num_seqs
self.wrapper.plan(
tokens_per_block=metadata.tokens_per_block,
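Two things change here: the wrapper hook is renamed from `create_weights` to `update_quant_config`, and the guard that rejected FP8 KV cache combined with paged context FMHA when no FP8 attention mode was set is removed. A minimal sketch of the flag derivation, using an illustrative flag enum rather than the real `QuantMode`, shows the KV-cache-only combination that the deleted guard used to reject:

```python
from enum import Flag, auto


class QuantModeSketch(Flag):
    """Illustrative stand-in for the quant-mode bits this backend reads."""
    NONE = 0
    FP8_QDQ = auto()
    FP8_BLOCK_SCALES = auto()
    NVFP4 = auto()
    FP8_KV_CACHE = auto()


def derive_flags(mode: QuantModeSketch) -> dict:
    # Mirrors the reordered update_quant_config: every flag is derived
    # independently from the layer quant mode.
    return {
        "has_fp8_kv_cache": bool(mode & QuantModeSketch.FP8_KV_CACHE),
        "has_fp8_qdq": bool(mode & QuantModeSketch.FP8_QDQ),
        "has_fp8_block_wise": bool(mode & QuantModeSketch.FP8_BLOCK_SCALES),
        "has_nvfp4": bool(mode & QuantModeSketch.NVFP4),
    }


# KV-cache-only quantization: exactly the combination the deleted guard used
# to reject when paged context FMHA was enabled. It is now planned normally.
flags = derive_flags(QuantModeSketch.FP8_KV_CACHE)
assert flags["has_fp8_kv_cache"]
assert not (flags["has_fp8_qdq"] or flags["has_fp8_block_wise"] or flags["has_nvfp4"])
```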
8 changes: 8 additions & 0 deletions tensorrt_llm/_torch/model_config.py
@@ -108,10 +108,18 @@ def from_pretrained(cls,
mixed_quant_config_file = model_dir / 'quant_cfg.json'
with open(mixed_quant_config_file) as fm:
mixed_quant_configs = json.load(fm)
# kv_cache_quant_algo is global regardless of MIXED_PRECISION
kv_cache_quant_algo = mixed_quant_configs[
'kv_cache_quant_algo']
mixed_quant_configs = mixed_quant_configs[
'quantized_layers']
if kv_cache_quant_algo is not None and quant_config.kv_cache_quant_algo is not None:
if kv_cache_quant_algo != quant_config.kv_cache_quant_algo:
raise RuntimeError(
f"The kvcache config in 'quant_cfg.json', {kv_cache_quant_algo},"
f"is different from 'hf_quant_config.json', {quant_config.kv_cache_quant_algo}!"
)
kv_cache_quant_algo = kv_cache_quant_algo or quant_config.kv_cache_quant_algo

for layer in mixed_quant_configs:
config = QuantConfig()
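The global `kv_cache_quant_algo` entry in `quant_cfg.json` is now read alongside the per-layer `quantized_layers` section and cross-checked against `hf_quant_config.json`. A runnable sketch of that reconciliation follows; the `quant_cfg.json` payload is hypothetical, only the two keys and the precedence rule come from this diff:

```python
import json
from typing import Optional


def reconcile_kv_cache_algo(mixed_cfg: dict,
                            hf_kv_algo: Optional[str]) -> Optional[str]:
    # Sketch of the rule added here: the global 'kv_cache_quant_algo' from
    # quant_cfg.json must agree with the one in hf_quant_config.json when
    # both are set; otherwise whichever is present wins.
    mixed_kv_algo = mixed_cfg["kv_cache_quant_algo"]
    if mixed_kv_algo is not None and hf_kv_algo is not None:
        if mixed_kv_algo != hf_kv_algo:
            raise RuntimeError(
                f"The kvcache config in 'quant_cfg.json', {mixed_kv_algo}, "
                f"is different from 'hf_quant_config.json', {hf_kv_algo}!")
    return mixed_kv_algo or hf_kv_algo


# Hypothetical quant_cfg.json content, for illustration only.
mixed_quant_configs = json.loads("""
{
  "kv_cache_quant_algo": "FP8",
  "quantized_layers": {
    "model.layers.0.self_attn.qkv_proj": {"quant_algo": "FP8"}
  }
}
""")
assert reconcile_kv_cache_algo(mixed_quant_configs, None) == "FP8"
assert reconcile_kv_cache_algo(mixed_quant_configs, "FP8") == "FP8"
```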
12 changes: 9 additions & 3 deletions tensorrt_llm/_torch/models/modeling_utils.py
@@ -11,9 +11,9 @@
from torch.utils._pytree import tree_any_only
from tqdm import tqdm

from tensorrt_llm.mapping import Mapping

from ...logger import logger
from ...mapping import Mapping
from ...models.modeling_utils import QuantConfig
from ..attention_backend import AttentionMetadata
from ..model_config import ModelConfig, TConfig
from ..modules.attention import Attention
@@ -432,15 +432,21 @@ def __post_init__(self):
# TODO: support MLA

# 2. skip quant for modules in QuantConfig.exclude_modules
# kv_cache_quant_algo takes precedence over exclude_modules
quant_config = self.model_config.quant_config
kv_cache_quant_algo = None
if quant_config:
kv_cache_quant_algo = quant_config.kv_cache_quant_algo
new_config = QuantConfig(kv_cache_quant_algo=kv_cache_quant_algo)

if quant_config is not None:
if quant_config.exclude_modules is not None:
for name, module in self.named_modules():
is_excluded = quant_config.is_module_excluded_from_quantization(
name)
if is_excluded and getattr(module, "quant_config",
None) is not None:
module.quant_config = None
module.quant_config = new_config

for _, module in self.named_modules():
if callable(getattr(module, "create_weights", None)):
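Excluded modules previously had `quant_config` reset to `None`, which also dropped KV-cache quantization; they now receive a fresh `QuantConfig` carrying only `kv_cache_quant_algo`. A minimal sketch of that behaviour with a stand-in config class (not the real `QuantConfig`):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class QuantConfigSketch:
    """Illustrative stand-in for QuantConfig; only the fields used here."""
    quant_algo: Optional[str] = None
    kv_cache_quant_algo: Optional[str] = None


def config_for_excluded_module(
        model_quant: Optional[QuantConfigSketch]) -> QuantConfigSketch:
    # New behaviour: an excluded module loses its weight/activation quant
    # algo but keeps the global KV-cache quant algo, instead of having its
    # quant_config reset to None.
    kv_algo = model_quant.kv_cache_quant_algo if model_quant else None
    return QuantConfigSketch(kv_cache_quant_algo=kv_algo)


model_quant = QuantConfigSketch(quant_algo="FP8", kv_cache_quant_algo="FP8")
excluded = config_for_excluded_module(model_quant)
assert excluded.quant_algo is None            # weights stay unquantized
assert excluded.kv_cache_quant_algo == "FP8"  # KV cache stays FP8
```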
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/modules/fused_moe.py
@@ -410,7 +410,8 @@ def create_weights(self):
self.has_fp8_qdq = False
self.has_fp8_block_scales = False
self.has_nvfp4 = False
if self.quant_config and self.quant_config.quant_mode.has_any_quant():
if self.quant_config and self.quant_config.quant_mode.has_any_quant(
exclude_kv_cache=True):
self.has_any_quant = True
qc = self.quant_config
if qc.quant_mode.has_fp8_qdq():
@@ -1128,7 +1129,8 @@ def load_expert_w2_weight(w2_weight,
load_expert_w2_weight(w2_weight, self.w2_weight.data[expert_idx],
is_trtllm_nvfp4)

if self.quant_config and self.quant_config.quant_mode.has_any_quant():
if self.quant_config and self.quant_config.quant_mode.has_any_quant(
exclude_kv_cache=True):
if self.quant_config.quant_mode.has_fp8_qdq():
self._load_fp8_qdq_scales(weights)
elif self.quant_config.quant_mode.has_nvfp4():
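Both call sites in this file now pass `exclude_kv_cache=True`, so a module whose only quantization is the KV cache keeps plain weights and skips the quantized-scale loading path; `linear.py` below applies the same gate. A small sketch of the caller-side effect, with an illustrative stand-in for the quant mode:

```python
from enum import Flag, auto


class QuantModeSketch(Flag):
    """Illustrative stand-in for the quant-mode flags."""
    NONE = 0
    FP8_QDQ = auto()
    FP8_KV_CACHE = auto()

    def has_any_quant(self, exclude_kv_cache: bool = False) -> bool:
        bits = QuantModeSketch.FP8_QDQ
        if not exclude_kv_cache:
            bits |= QuantModeSketch.FP8_KV_CACHE
        return bool(self & bits)


def needs_quantized_weights(mode: QuantModeSketch) -> bool:
    # Mirrors the updated call sites: only weight/activation quantization
    # should trigger quantized-weight creation and scale loading.
    return mode.has_any_quant(exclude_kv_cache=True)


assert needs_quantized_weights(QuantModeSketch.FP8_QDQ) is True
# KV-cache-only quantization no longer drags the MoE (or Linear) module
# into the quantized-weight code path.
assert needs_quantized_weights(QuantModeSketch.FP8_KV_CACHE) is False
```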
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/modules/linear.py
@@ -210,9 +210,9 @@ def create_weights(self):
self.has_fp8_qdq = False
self.has_fp8_block_scales = False
self.has_nvfp4 = False
# Only _create_weights here; quantized weights are loaded directly.

if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant(
):
exclude_kv_cache=True):
self.has_any_quant = True
qc = self.quant_config
if qc.layer_quant_mode.has_fp8_qdq():
22 changes: 13 additions & 9 deletions tensorrt_llm/quantization/mode.py
@@ -175,15 +175,19 @@ def has_nvfp4(self):
def has_weight_quant(self):
return self._any(self.INT4_WEIGHTS | self.INT8_WEIGHTS)

def has_any_quant(self):
return self._any(self.INT4_WEIGHTS
| self.INT8_WEIGHTS
| self.ACTIVATIONS
| self.INT8_KV_CACHE | self.FP8_KV_CACHE
| self.NVFP4_KV_CACHE
| self.FP8_QDQ | self.FP8_ROWWISE | self.W4A8_QSERVE
| self.FP8_1x128_128x128
| self.NVFP4)
def has_any_quant(self, exclude_kv_cache: bool = False):
has_quant = self._any(self.INT4_WEIGHTS
| self.INT8_WEIGHTS
| self.ACTIVATIONS
| self.FP8_QDQ | self.FP8_ROWWISE
| self.W4A8_QSERVE
| self.FP8_1x128_128x128
| self.NVFP4)
if exclude_kv_cache:
return has_quant

return has_quant | self._any(self.INT8_KV_CACHE | self.FP8_KV_CACHE
| self.NVFP4_KV_CACHE)

def set_int8_kv_cache(self):
return self | self.INT8_KV_CACHE
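`has_any_quant` now partitions the bits: weight and activation modes always count, KV-cache modes only when `exclude_kv_cache` is False (the default), so existing callers keep the old behaviour. A reduced, illustrative sketch of that partition (not the real `QuantMode` bit layout):

```python
from enum import IntFlag, auto


class ModeSketch(IntFlag):
    """Reduced, illustrative subset of the quant-mode bits."""
    NONE = 0
    INT4_WEIGHTS = auto()
    FP8_QDQ = auto()
    NVFP4 = auto()
    INT8_KV_CACHE = auto()
    FP8_KV_CACHE = auto()

    def has_any_quant(self, exclude_kv_cache: bool = False) -> bool:
        weight_or_act = (ModeSketch.INT4_WEIGHTS | ModeSketch.FP8_QDQ
                         | ModeSketch.NVFP4)
        has_quant = bool(self & weight_or_act)
        if exclude_kv_cache:
            return has_quant
        kv_cache = ModeSketch.INT8_KV_CACHE | ModeSketch.FP8_KV_CACHE
        return has_quant or bool(self & kv_cache)


kv_only = ModeSketch.FP8_KV_CACHE
assert kv_only.has_any_quant() is True                        # default keeps old semantics
assert kv_only.has_any_quant(exclude_kv_cache=True) is False  # new: KV-cache bits ignored
assert (ModeSketch.FP8_QDQ | ModeSketch.FP8_KV_CACHE).has_any_quant(
    exclude_kv_cache=True) is True
```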