Commit b6aa597

modify comment
Signed-off-by: HaochenYuan <haocheny@nvidia.com>
1 parent 2c58118 commit b6aa597

1 file changed

Lines changed: 3 additions & 3 deletions

transformer_engine/pytorch/attention/dot_product_attention/backends.py

@@ -1679,9 +1679,9 @@ def backward(ctx, d_out, *_args):
 rest = [None]
 if ctx.use_FAv2_bwd:
     softmax_lse, rng_state = aux_ctx_tensors
-    # During CUDA graph capture, allocate with zeros so the memset is baked into
-    # the captured graph and replay buffers start clean. Outside capture, allocate
-    # with empty for perf and rely on the explicit tail zero-fill below.
+    # Keep capture replay buffers zero-initialized; outside capture, use
+    # empty_like to avoid the extra memset. The THD tail zero-fill below
+    # clears tail positions in both modes.
     if torch.cuda.is_current_stream_capturing():
         dq = torch.zeros_like(q)
         dk = torch.zeros_like(k)
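
The allocation policy described in the new comment can be sketched in isolation as follows. This is a minimal illustration, not the code in backends.py: `alloc_grad_like` and `zero_tail` are hypothetical helpers, the capture state is passed in as a plain `capturing` flag so the sketch runs on CPU (the diff above queries `torch.cuda.is_current_stream_capturing()` instead), and the tensor layout and valid-token count are assumptions made for the example.

```python
import torch


def alloc_grad_like(x: torch.Tensor, capturing: bool) -> torch.Tensor:
    """Allocate a gradient buffer shaped like x (hypothetical helper).

    `capturing` stands in for torch.cuda.is_current_stream_capturing()
    so that this sketch does not need a CUDA device.
    """
    if capturing:
        # zeros_like issues a fill; under graph capture that fill is
        # recorded, so each replay starts from a zeroed buffer.
        return torch.zeros_like(x)
    # empty_like skips the fill; contents are unspecified until written.
    return torch.empty_like(x)


def zero_tail(grad: torch.Tensor, num_valid_tokens: int) -> torch.Tensor:
    """Zero every row past the last valid token (hypothetical helper)."""
    grad[num_valid_tokens:].zero_()
    return grad


# Assumed packed layout: [total_tokens, heads, head_dim], with the last
# three token rows treated as padding for the purpose of this example.
q = torch.randn(8, 2, 4)
num_valid = 5

dq_capture = zero_tail(alloc_grad_like(q, capturing=True), num_valid)
dq_eager = zero_tail(alloc_grad_like(q, capturing=False), num_valid)
```

In both branches the tail rows end up zeroed. Only the capture branch pays for a full-buffer fill, which is the trade-off the revised comment describes.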