
Commit e26d4e9 (1 parent: 6ff82ea)

Use torch compile to fuse weight scaling and multistream

Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>

3 files changed: +31 -19 lines

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 16 additions & 7 deletions
@@ -17,6 +17,7 @@
     maybe_execute_in_parallel
 from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+from tensorrt_llm._torch.utils import maybe_compile
 from tensorrt_llm._utils import get_size_in_bytes
 from tensorrt_llm.bindings import DataType
 from tensorrt_llm.bindings.executor import KvCacheConfig
@@ -572,6 +573,11 @@ def update_for_spec_dec(self):
         self.on_update_kv_lens()


+@maybe_compile(dynamic=True)
+def _scale(weights, q_scale, s):
+    return weights * q_scale.squeeze(-1) * s
+
+
 class Indexer(nn.Module):

     def __init__(self,
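
For reference, a quick standalone check (not part of the commit; toy shapes and dtypes assumed) that the fused expression inside the compiled _scale helper matches the unsqueeze/multiply/squeeze sequence it replaces in weight_scale below, including the float32 promotion from q_scale:

import torch

num_tokens, n_heads = 4, 8
weights = torch.randn(num_tokens, n_heads, dtype=torch.bfloat16)   # toy stand-in for the weights_proj output
q_scale = torch.rand(num_tokens, n_heads, 1, dtype=torch.float32)  # per-head scales, float32
s = 0.5                                                            # stand-in for self.weight_scale_factor

# Old path: broadcast via unsqueeze, then squeeze back.
old = (weights.unsqueeze(-1) * q_scale * s).squeeze(-1)

# New path: the single fused expression compiled in _scale.
new = weights * q_scale.squeeze(-1) * s

assert old.dtype == new.dtype == torch.float32  # type promotion from q_scale
torch.testing.assert_close(old, new)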
@@ -962,9 +968,6 @@ def sparse_attn_indexer(
                                            device=hidden_states.device)
         topk_indices_buffer[:hidden_states.shape[0]] = -1

-        # Store k_fp8 and k_scale into indexer k cache
-        self._update_k_cache(k_fp8, k_scale, metadata)
-
         if has_prefill:
             # Use chunked prefill to reduce memory footprint
             if metadata.indexer_prefill_chunks is not None:
@@ -1099,9 +1102,7 @@ def weight_scale(self, hidden_states: torch.Tensor,
                      q_scale: torch.Tensor) -> torch.Tensor:
         weights = indexer_weights if indexer_weights is not None else self.weights_proj(
             hidden_states)
-        weights = weights.unsqueeze(-1) * q_scale * self.weight_scale_factor
-        # output weights is guaranteed to be float32 due to type promotion from q_scale (float32)
-        weights = weights.squeeze(-1)
+        weights = _scale(weights, q_scale, self.weight_scale_factor)
         return weights

     @torch.inference_mode()
@@ -1170,7 +1171,15 @@ def _prep_q_or_k(qk_pe, qk_nope):
         q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
         q_scale = q_scale.view(-1, self.n_heads, 1)

-        weights = self.weight_scale(hidden_states, indexer_weights, q_scale)
+        weights, _ = maybe_execute_in_parallel(
+            lambda: self.weight_scale(hidden_states, indexer_weights, q_scale),
+            lambda: self._update_k_cache(
+                k_fp8, k_scale, metadata),  # store k_fp8 and k_scale in k cache
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
+
         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
         return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
                                         k_scale, weights)
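
A rough sketch of the overlap the multistream call buys. This is not TensorRT-LLM's maybe_execute_in_parallel (its implementation is not shown in the diff); the helper below only mirrors the event/stream argument shape of the call above, using plain CUDA streams and events:

import torch

def run_in_parallel(fn_main, fn_aux, ev_main, ev_aux, aux_stream):
    # Overlap two independent callables on the main and an auxiliary CUDA stream.
    ev_main.record()                       # point after which aux-stream work may start
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(ev_main)     # aux stream waits for prior main-stream work
        out_aux = fn_aux()
        ev_aux.record()                    # marks completion of the aux-stream work
    out_main = fn_main()                   # launched concurrently on the main stream
    torch.cuda.current_stream().wait_event(ev_aux)  # rejoin before out_aux is consumed
    return out_main, out_aux

# Hypothetical usage shaped like the call in the diff (requires a CUDA device):
if torch.cuda.is_available():
    aux = torch.cuda.Stream()
    ev0, ev1 = torch.cuda.Event(), torch.cuda.Event()
    x = torch.randn(1024, 1024, device="cuda")
    cache = torch.empty_like(x)
    weights, _ = run_in_parallel(
        lambda: x * 0.5,          # stand-in for self.weight_scale(...)
        lambda: cache.copy_(x),   # stand-in for self._update_k_cache(...)
        ev0, ev1, aux,
    )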

tensorrt_llm/_torch/modules/attention.py

Lines changed: 1 addition & 12 deletions
@@ -23,7 +23,7 @@
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_piecewise_running, is_torch_compiling)
+                     is_torch_compiling, maybe_compile)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -76,17 +76,6 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer


-def maybe_compile(func):
-
-    def wrapper(*args, **kwargs):
-        if is_piecewise_running():
-            # When piecewise running, we don't need to compile the function to avoid host overhead in attention op.
-            return func(*args, **kwargs)
-        return torch.compile(func)(*args, **kwargs)
-
-    return wrapper
-
-
 @maybe_compile
 def maybe_compiled_copy_(dst, src):
     dst.copy_(src)

tensorrt_llm/_torch/utils.py

Lines changed: 14 additions & 0 deletions
@@ -325,3 +325,17 @@ def get_device_uuid(device_idx: int) -> str:
     property = torch.cuda.get_device_properties(device_idx)
     uuid = "GPU-" + str(property.uuid)
     return uuid
+
+
+def maybe_compile(func=None, **compile_kwargs):
+
+    def decorator(f):
+
+        def wrapper(*args, **kwargs):
+            if is_piecewise_running():
+                return f(*args, **kwargs)
+            return torch.compile(f, **compile_kwargs)(*args, **kwargs)
+
+        return wrapper
+
+    return decorator(func) if func else decorator
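
Since is_piecewise_running() lives in tensorrt_llm internals, here is a self-contained sketch of the dual-form decorator pattern the new helper uses, covering both the bare spelling kept in attention.py and the keyword spelling added in dsa.py (the module-level flag below is a stand-in for is_piecewise_running()):

import torch

def maybe_compile(func=None, **compile_kwargs):
    # Same dual-form pattern as the new tensorrt_llm._torch.utils.maybe_compile.
    PIECEWISE_RUNNING = False  # stand-in for is_piecewise_running()

    def decorator(f):
        def wrapper(*args, **kwargs):
            if PIECEWISE_RUNNING:
                return f(*args, **kwargs)  # skip compile to avoid host overhead
            return torch.compile(f, **compile_kwargs)(*args, **kwargs)
        return wrapper

    # Bare use (@maybe_compile) receives the function directly; argument use
    # (@maybe_compile(dynamic=True)) returns the decorator to be applied next.
    return decorator(func) if func else decorator

@maybe_compile                   # bare form, as on maybe_compiled_copy_ in attention.py
def copy_into(dst, src):
    dst.copy_(src)

@maybe_compile(dynamic=True)     # argument form, as on _scale in dsa.py
def scale(weights, q_scale, s):
    return weights * q_scale.squeeze(-1) * s

x = torch.randn(4, 8)
out = scale(x, torch.rand(4, 8, 1), 0.5)
buf = torch.empty(4, 8)
copy_into(buf, x)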
