Skip to content

Commit 2f08aa3

Browse files
[None][feat] AutoDeploy: Perf improvement for mamba layers (NVIDIA#8991)
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Suyog Gupta <[email protected]>
Co-authored-by: Suyog Gupta <[email protected]>
1 parent 68e98ed commit 2f08aa3

File tree

2 files changed

+7
-13
lines changed

2 files changed

+7
-13
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -204,7 +204,9 @@ def _cuda_cached_causal_conv1d(
204204

205205
if y_dec.dim() == 3:
206206
y_dec = y_dec.squeeze(-1)
207-
y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(y_dec)
207+
y_flat[total_prefill_tokens : total_prefill_tokens + num_decode].copy_(
208+
y_dec.to(y_flat.dtype)
209+
)
208210

209211
# Custom op must not return an alias of any input; return a fresh tensor
210212
return y
@@ -296,9 +298,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
296298
stride, padding, dilation, groups, padding_mode = extract_op_args(
297299
source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
298300
)
299-
# activation parameter may not exist in the source node (added by fusion later)
300-
try:
301-
activation = extract_op_args(source_attn_node, "activation")[0]
302-
except (RuntimeError, IndexError):
303-
activation = None
304-
return [stride, padding, dilation, groups, padding_mode, activation]
301+
# None is for activation parameter, which may not exist in the source node (added by fusion later)
302+
return [stride, padding, dilation, groups, padding_mode, None]

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -355,9 +355,5 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]:
355355
stride, padding, dilation, groups, padding_mode = extract_op_args(
356356
source_attn_node, "stride", "padding", "dilation", "groups", "padding_mode"
357357
)
358-
# activation parameter may not exist in the source node (added by fusion later)
359-
try:
360-
activation = extract_op_args(source_attn_node, "activation")[0]
361-
except (RuntimeError, IndexError):
362-
activation = None
363-
return [stride, padding, dilation, groups, padding_mode, activation]
358+
# None is for activation parameter, which may not exist in the source node (added by fusion later)
359+
return [stride, padding, dilation, groups, padding_mode, None]

0 commit comments

Comments
 (0)