17 | 17 | maybe_execute_in_parallel |
18 | 18 | from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding |
19 | 19 | from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager |
| 20 | +from tensorrt_llm._torch.utils import maybe_compile |
20 | 21 | from tensorrt_llm._utils import get_size_in_bytes |
21 | 22 | from tensorrt_llm.bindings import DataType |
22 | 23 | from tensorrt_llm.bindings.executor import KvCacheConfig |
@@ -572,6 +573,11 @@ def update_for_spec_dec(self): |
572 | 573 | self.on_update_kv_lens() |
573 | 574 |
574 | 575 |
| 576 | +@maybe_compile(dynamic=True) |
| 577 | +def _scale(weights, q_scale, s): |
| 578 | + return weights * q_scale.squeeze(-1) * s |
| 579 | + |
| 580 | + |
575 | 581 | class Indexer(nn.Module): |
576 | 582 |
577 | 583 | def __init__(self, |
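
The new _scale helper hoists the element-wise re-scaling out of Indexer.weight_scale so it can be decorated with maybe_compile(dynamic=True). The decorator body is not shown in this diff; as a rough mental model only (an assumption, not the actual tensorrt_llm._torch.utils implementation), it acts like a guarded torch.compile wrapper, so the two multiplies and the squeeze can fuse into one kernel when compilation is enabled and fall back to eager mode otherwise. The sketch below uses a made-up switch name to make that gating explicit.

    # Illustrative stand-in for tensorrt_llm._torch.utils.maybe_compile;
    # the real decorator's gating logic may differ.
    import os

    import torch


    def sketch_maybe_compile(**compile_kwargs):

        def decorator(fn):
            # Hypothetical opt-out switch, not a real TRT-LLM flag.
            if os.environ.get("SKETCH_DISABLE_COMPILE", "0") == "1":
                return fn  # eager fallback
            return torch.compile(fn, **compile_kwargs)

        return decorator


    @sketch_maybe_compile(dynamic=True)
    def _scale(weights, q_scale, s):
        # q_scale carries a trailing singleton dim; squeezing it lets the
        # multiply broadcast elementwise, and the fp32 q_scale promotes the
        # result to float32 exactly as the old unsqueeze/squeeze path did.
        return weights * q_scale.squeeze(-1) * s

Passing dynamic=True keeps a single dynamic-shape graph even though the leading token dimension changes from step to step, avoiding a recompile per batch size.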
@@ -962,9 +968,6 @@ def sparse_attn_indexer( |
962 | 968 | device=hidden_states.device) |
963 | 969 | topk_indices_buffer[:hidden_states.shape[0]] = -1 |
964 | 970 |
965 | | - # Store k_fp8 and k_scale into indexer k cache |
966 | | - self._update_k_cache(k_fp8, k_scale, metadata) |
967 | | - |
968 | 971 | if has_prefill: |
969 | 972 | # Use chunked prefill to reduce memory footprint |
970 | 973 | if metadata.indexer_prefill_chunks is not None: |
@@ -1099,9 +1102,7 @@ def weight_scale(self, hidden_states: torch.Tensor, |
1099 | 1102 | q_scale: torch.Tensor) -> torch.Tensor: |
1100 | 1103 | weights = indexer_weights if indexer_weights is not None else self.weights_proj( |
1101 | 1104 | hidden_states) |
1102 | | - weights = weights.unsqueeze(-1) * q_scale * self.weight_scale_factor |
1103 | | - # output weights is guaranteed to be float32 due to type promotion from q_scale (float32) |
1104 | | - weights = weights.squeeze(-1) |
| 1105 | + weights = _scale(weights, q_scale, self.weight_scale_factor) |
1105 | 1106 | return weights |
1106 | 1107 |
1107 | 1108 | @torch.inference_mode() |
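
The rewritten weight_scale is numerically equivalent to the removed unsqueeze/multiply/squeeze sequence: with weights of shape [num_tokens, n_heads] from weights_proj and q_scale viewed as [num_tokens, n_heads, 1], multiplying by q_scale.squeeze(-1) performs the same broadcast, and the float32 q_scale still promotes the output to float32, which is what the deleted comment used to guarantee. A standalone check of that equivalence (shapes, dtypes and the scale factor below are placeholders, not the model's real values):

    import torch

    num_tokens, n_heads = 8, 64
    weights = torch.randn(num_tokens, n_heads, dtype=torch.bfloat16)
    q_scale = torch.rand(num_tokens, n_heads, 1, dtype=torch.float32)
    weight_scale_factor = 0.5  # placeholder value

    # Old path: unsqueeze weights to [num_tokens, n_heads, 1], scale, squeeze back.
    old = (weights.unsqueeze(-1) * q_scale * weight_scale_factor).squeeze(-1)
    # New path, i.e. what _scale computes: squeeze q_scale instead.
    new = weights * q_scale.squeeze(-1) * weight_scale_factor

    assert old.dtype == new.dtype == torch.float32  # promoted by fp32 q_scale
    torch.testing.assert_close(old, new)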
@@ -1170,7 +1171,15 @@ def _prep_q_or_k(qk_pe, qk_nope): |
1170 | 1171 | q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim) |
1171 | 1172 | q_scale = q_scale.view(-1, self.n_heads, 1) |
1172 | 1173 |
1173 | | - weights = self.weight_scale(hidden_states, indexer_weights, q_scale) |
| 1174 | + weights, _ = maybe_execute_in_parallel( |
| 1175 | + lambda: self.weight_scale(hidden_states, indexer_weights, q_scale), |
| 1176 | + lambda: self._update_k_cache( |
| 1177 | + k_fp8, k_scale, metadata), # store k_fp8 and k_scale in k cache |
| 1178 | + self.ln_events[0], |
| 1179 | + self.ln_events[1], |
| 1180 | + self.aux_stream, |
| 1181 | + ) |
| 1182 | + |
1174 | 1183 | # Return topk indices buffer for sparse attention [num_tokens, index_topk] |
1175 | 1184 | return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8, |
1176 | 1185 | k_scale, weights) |
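
With the k-cache update pulled out of sparse_attn_indexer, maybe_execute_in_parallel overlaps it with the weight scaling: as called here it takes the two callables, a pair of CUDA events, and the auxiliary stream, and returns both results as a tuple (only weights is consumed; the cache update runs for its side effect). A minimal sketch of the stream/event pattern such a helper typically implements, assuming it hands fn1 to the aux stream and synchronizes back before returning; the real tensorrt_llm helper may differ in detail:

    import torch


    def sketch_execute_in_parallel(fn0, fn1, event0, event1, aux_stream):
        # Assumed overlap pattern: fn1 runs on the aux stream while fn0 stays
        # on the main stream; the two events order the hand-off both ways.
        main_stream = torch.cuda.current_stream()
        event0.record(main_stream)       # make fn1's inputs visible to the aux stream
        with torch.cuda.stream(aux_stream):
            aux_stream.wait_event(event0)
            out1 = fn1()                 # e.g. the indexer k-cache update
            event1.record(aux_stream)
        out0 = fn0()                     # e.g. weight_scale on the main stream
        main_stream.wait_event(event1)   # later kernels see both results
        return out0, out1

The overlap is safe because weight_scale only reads hidden_states, indexer_weights and q_scale, while _update_k_cache only touches k_fp8, k_scale and the indexer k cache, so neither callable depends on the other's output.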