Skip to content

Commit f3dbbb8

Browse files
committed
Add docstring and typehint
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
1 parent 7d6a551 commit f3dbbb8

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,8 @@ def update_for_spec_dec(self):
574574

575575

576576
@maybe_compile(dynamic=True)
577-
def _scale(weights, q_scale, s):
577+
def _scale(weights: torch.Tensor, q_scale: torch.Tensor,
578+
s: float) -> torch.Tensor:
578579
return weights * q_scale.squeeze(-1) * s
579580

580581

tensorrt_llm/_torch/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,15 @@ def get_device_uuid(device_idx: int) -> str:
328328

329329

330330
def maybe_compile(func=None, **compile_kwargs):
331+
"""
332+
Conditionally compile a function with torch.compile.
333+
If is_piecewise_running() is True, the function will be compiled with torch.compile.
334+
Args:
335+
func: The function to decorate (optional, for direct decoration).
336+
**compile_kwargs: Keyword arguments for torch.compile.
337+
Returns:
338+
The conditionally compiled function..
339+
"""
331340

332341
def decorator(f):
333342

0 commit comments

Comments
 (0)