Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ab542fc
Plumbing correct bias dims from TE to cudnn
KshitijLakhani Dec 20, 2025
c86328e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 22, 2025
fddf0ac
Make changes for cp bias code
KshitijLakhani Jan 9, 2026
d3aa7ec
Add dbias and dbias_ to run_dpa_with_cp test
KshitijLakhani Jan 9, 2026
4d295c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 9, 2026
f4f9cc6
Fix: Use output_dBias instead of input_dBias to extract the shape
KshitijLakhani Jan 9, 2026
f9f4fb8
Add guards for bias/bias_/dbias/dbias_ being None
KshitijLakhani Jan 21, 2026
d9547fa
Add support for bias shape 111s in addition to the original 1hss, 11s…
KshitijLakhani Jan 22, 2026
7ede1fe
Add support for dbias calculation and variant packing for the dbias s…
KshitijLakhani Feb 6, 2026
c20d67a
Add support for 111s bias shape in DPA
KshitijLakhani Feb 6, 2026
303aee7
Allow fused attn for dbias calculation for 11ss, b1ss, bhss. Disable …
KshitijLakhani Feb 6, 2026
e9f88f0
Disable requires_grad for bias for shape 111s in tests
KshitijLakhani Feb 6, 2026
6bf73e1
Disable bias grad / training flag for 111s bias in the non-CP attn te…
KshitijLakhani Feb 6, 2026
ebee29b
Fix to correctly create the bias shape tensor instead of the hard cod…
KshitijLakhani Feb 6, 2026
126be03
Add fused attn cp test cases for all supported bias shapes
KshitijLakhani Feb 6, 2026
7b0f942
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2026
f795056
nit: switch to elif for bias grad conditional
KshitijLakhani Feb 13, 2026
0e74dcf
Add CP support for bias/dbias shape 111s
KshitijLakhani Feb 13, 2026
0acf8f8
Add support for is_training in CP attn tests
KshitijLakhani Feb 13, 2026
2133bd8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 13, 2026
89e90a5
nit: Fix incorrect comment
KshitijLakhani Feb 17, 2026
5a25d9c
nit: Fix incorrect comment and assert string
KshitijLakhani Feb 17, 2026
0e2a72f
Create the dbias graph tensor only if it is a cuDNN supported bias shape
KshitijLakhani Feb 18, 2026
f066c88
Fix the dim that is being compared for the two cp chunks in the test
KshitijLakhani Feb 18, 2026
ff174a8
nit: Reinstate the original test for right side swa
KshitijLakhani Feb 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 179 additions & 57 deletions tests/pytorch/attention/run_attention_with_cp.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,25 @@ def run_dpa_with_cp(
x.requires_grad = True

if config.attn_bias_type not in ["no_bias", "alibi"]:
attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv)
bias_shape_map = {
"1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv),
"11ss": (1, 1, config.max_seqlen_q, config.max_seqlen_kv),
"b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv),
"bhss": (
config.batch_size,
config.num_heads,
config.max_seqlen_q,
config.max_seqlen_kv,
),
"111s": (1, 1, 1, config.max_seqlen_kv),
}
attn_bias_shape = bias_shape_map.get(config.bias_shape)
if attn_bias_shape is None:
assert False, f"cuDNN does not support {config.bias_shape=}"
bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda()
# cuDNN does not support dbias calculation for 111s as of cuDNN 9.18
# TODO(KshitijLakhani): Set requires_grad to True for all shapes once 111s is supported
bias.requires_grad = True if config.bias_shape != "111s" else False
else:
bias = None

Expand Down Expand Up @@ -338,7 +355,7 @@ def run_dpa_with_cp(
out.backward(dout_fp8)
else:
out.backward(dout)
dq, dk, dv = q.grad, k.grad, v.grad
dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None
d_softmax_offset = None
if config.softmax_type != "vanilla":
d_softmax_offset = core_attn.softmax_offset.grad
Expand Down Expand Up @@ -389,11 +406,27 @@ def run_dpa_with_cp(
q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer)
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
bias_ = bias_.view(
*bias_.shape[:-2], 2 * world_size, bias_.shape[-2] // (2 * world_size), bias_.shape[-1]
)
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
ndim = bias_.ndim
seq_q_dim = ndim - 2
if qkv_format == "thd":
bias_seq_idx = seq_idx_q
else:
bias_seq_idx = seq_idx
shape_before_seq = bias_.shape[:seq_q_dim]
seq_q_size = bias_.shape[seq_q_dim]
seq_kv_size = bias_.shape[-1]
if seq_q_size == 1:
# TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s
bias_.requires_grad = False
# Bias is broadcast, no need to partition along sequence dimension
pass
else:
bias_ = bias_.view(
*shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size
)
bias_ = bias_.index_select(seq_q_dim, bias_seq_idx)
bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size)
bias_.requires_grad = True
# set up environment
core_attn.set_context_parallel_group(
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
Expand Down Expand Up @@ -433,23 +466,27 @@ def run_dpa_with_cp(
out_.backward(dout_fp8_)
else:
out_.backward(dout_)
dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad if bias_ is not None else None
d_softmax_offset_ = None
if config.softmax_type != "vanilla":
d_softmax_offset_ = core_attn.softmax_offset.grad.clone()

# get outputs
tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_]
tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_]
if fp8_mha:
tensors_to_deq = [out, out_] if not fp8_bwd else tensors
for i, tensor in enumerate(tensors_to_deq):
tensors_to_deq[i] = tensor.dequantize()
# dbias/dbias_ could be None, so skip check for it
if tensor is not None:
tensors_to_deq[i] = tensor.dequantize()
if not fp8_bwd:
tensors[0], tensors[4] = tensors_to_deq
tensors[0], tensors[5] = tensors_to_deq
for tensor in tensors:
assert torch.all(~torch.isnan(tensor))
assert torch.all(~torch.isinf(tensor))
out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors
# dbias/dbias_ could be None, so skip check for it
if tensor is not None:
assert torch.all(~torch.isnan(tensor))
assert torch.all(~torch.isinf(tensor))
out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors

############ compare results between CP and no-CP ############
if qkv_format == "bshd" or qkv_format == "sbhd":
Expand All @@ -467,6 +504,21 @@ def run_dpa_with_cp(
x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
for x in [dq_, dk_, dv_, out_]
]
if dbias is not None and dbias_ is not None:
ndim = dbias.ndim
# Query seq is at dim -2
seq_q_dim = ndim - 2
shape_before_seq = dbias.shape[:seq_q_dim]
seq_q_size = dbias.shape[seq_q_dim]
seq_kv_size = dbias.shape[-1]
# Reshape to split seq_q dimension
dbias = dbias.view(
*shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size
)
# Index select on the newly created dimension (now at position seq_q_dim)
dbias = dbias.index_select(seq_q_dim, seq_idx)
dbias_ = dbias_.view(*shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size)

elif qkv_format == "thd":
dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]]
Expand Down Expand Up @@ -509,57 +561,127 @@ def run_dpa_with_cp(
)

atol, rtol, rmse_tol = get_tols(config, dtype)
tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_]
tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit]
names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"]
tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_]
tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit]
names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"]
names_cp = [x + "_cp" for x in names]
names_no_cp = [x + "_no_cp" for x in names]
is_fp8 = dtype == "fp8"
for i, t in enumerate(tensors_no_cp):
if t is not None:
if "softmax_offset" not in names[i] and "max_logit" not in names[i]:
if qkv_format == "bshd":
compare_and_assert(
t[:, 0],
tensors_cp[i][:, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, 1],
tensors_cp[i][:, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
# Compare the two sequence chunks separately
# Compare dbias
if names[i] == "dbias":
# After reshaping: (1, 1, 2, seq_q//2, seq_kv)
# Compare along dimension 2 (the split sequence dimension)
ndim_bias = t.ndim
seq_q_dim_bias = ndim_bias - 2 # Query sequence dimension
# After reshaping both have shape: [..., 2, seq_q//2, seq_kv]
# The split dimension is at seq_q_dim_bias
slice_0 = [slice(None)] * ndim_bias
slice_0[seq_q_dim_bias] = 0
slice_1 = [slice(None)] * ndim_bias
slice_1[seq_q_dim_bias] = 1
Comment on lines +624 to +628

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong split dimension used in dbias comparison

After the reshape+index_select at lines 536–543, dbias has shape [..., 2, sq//(2*world_size), skv] (e.g. [1, H, 2, sq//4, skv] for 1hss). At that point t.ndim == 5, so seq_q_dim_bias = ndim_bias - 2 = 3.

But the dimension with exactly 2 elements (the CP split that should be compared chunk-by-chunk) is at index 2, not 3. Index 3 holds sq//(2*world_size) elements. The current code therefore slices along the inner sub-sequence dimension instead of the CP-half dimension, so errors confined to one CP half (dim 2 == 1) are never independently validated.

The fix is to use the same seq_q_dim value that was used during the earlier reshape (computed from the original ndim - 2 = 2):

Suggested change
ndim_bias = t.ndim
seq_q_dim_bias = ndim_bias - 2 # Query sequence dimension
# After reshaping both have shape: [..., 2, seq_q//2, seq_kv]
# The split dimension is at seq_q_dim_bias
slice_0 = [slice(None)] * ndim_bias
slice_0[seq_q_dim_bias] = 0
slice_1 = [slice(None)] * ndim_bias
slice_1[seq_q_dim_bias] = 1
if names[i] == "dbias":
# After reshaping both tensors have shape: [..., 2, seq_q//2, seq_kv]
# The CP-split dimension is at index seq_q_dim (= original ndim - 2 = 2)
split_dim = ndim - 2 # original ndim before reshape, i.e. 2 for [B,H,sq,skv]
slice_0 = [slice(None)] * t.ndim
slice_0[split_dim] = 0
slice_1 = [slice(None)] * t.ndim
slice_1[split_dim] = 1

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch by greptile
Was comparing the dim 3/-2 of the 5d tensor instead of dim 2/-3 of the tensor
Re ran the test and all pass locally

compare_and_assert(
t[tuple(slice_0)], # First sequence chunk
tensors_cp[i][tuple(slice_0)],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[tuple(slice_1)], # First sequence chunk
tensors_cp[i][tuple(slice_1)],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
# Compare Q/K/V/out
else:
# Compare along dimension 1 (the split sequence dimension)
compare_and_assert(
t[:, 0],
tensors_cp[i][:, 0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[:, 1],
tensors_cp[i][:, 1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "sbhd":
compare_and_assert(
t[0],
tensors_cp[i][0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[1],
tensors_cp[i][1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
# Compare the two sequence chunks separately
# Compare dbias (same as BSHD)
if names[i] == "dbias":
# After reshaping: (1, 1, 2, seq_q//2, seq_kv)
# Compare along dimension 2 (the split sequence dimension)
ndim_bias = t.ndim
seq_q_dim_bias = ndim_bias - 2
slice_0 = [slice(None)] * ndim_bias
slice_0[seq_q_dim_bias] = 0
slice_1 = [slice(None)] * ndim_bias
slice_1[seq_q_dim_bias] = 1
Comment on lines +678 to +682

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same wrong split dimension for sbhd format

This block is a copy of the bshd block above and has the same issue: seq_q_dim_bias = ndim_bias - 2 resolves to dimension 3 of the post-reshape tensor (which has sq//2 elements), while the CP-split dimension (2 elements) is at dimension 2.

The fix mirrors the one for bshd: capture the original ndim - 2 (== 2 for a 4-D bias) and use it as the split axis, rather than recomputing from the already-expanded t.ndim.

Suggested change
ndim_bias = t.ndim
seq_q_dim_bias = ndim_bias - 2
slice_0 = [slice(None)] * ndim_bias
slice_0[seq_q_dim_bias] = 0
slice_1 = [slice(None)] * ndim_bias
slice_1[seq_q_dim_bias] = 1
if names[i] == "dbias":
split_dim = ndim - 2 # original ndim before reshape, i.e. 2 for [B,H,sq,skv]
slice_0 = [slice(None)] * t.ndim
slice_0[split_dim] = 0
slice_1 = [slice(None)] * t.ndim
slice_1[split_dim] = 1

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch by greptile
Was comparing the dim 3/-2 of the 5d tensor instead of dim 2/-3 of the tensor
Re ran the test and all pass locally

compare_and_assert(
t[tuple(slice_0)], # First sequence chunk
tensors_cp[i][tuple(slice_0)],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[tuple(slice_1)], # First sequence chunk
tensors_cp[i][tuple(slice_1)],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
# Compare Q/K/V/out
else:
# Compare along dimension 0 (the split sequence dimension)
compare_and_assert(
t[0],
tensors_cp[i][0],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
compare_and_assert(
t[1],
tensors_cp[i][1],
names_no_cp[i],
names_cp[i],
atol,
rtol,
rmse_tol,
is_fp8,
)
elif qkv_format == "thd":
compare_and_assert(
t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8
Expand Down
22 changes: 19 additions & 3 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,16 @@ def test_dot_product_attention(
)

# Get backends
# For 111s, dbias calculation is not supported as of cuDNN 9.18, hence, test fwd only for 111s.
# For all other shapes test fwd+bwd
is_training = True
# TODO(KshitijLakhani): Set is_training to True for all cases once cuDNN supports dbias for 111s.
if config.bias_shape == "111s":
is_training = False
logging.info(
"Setting is_training to False as cuDNN does not support dbias for"
f" {config.bias_shape=} "
)
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
Expand Down Expand Up @@ -636,7 +645,8 @@ def test_dpa_bias(dtype, model_configs, model):
"bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"),
"bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"),
"bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"),
"bias_1_4": ModelConfig(
"bias_1_4": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="111s"),
"bias_1_5": ModelConfig(
4,
2048,
24,
Expand All @@ -646,7 +656,7 @@ def test_dpa_bias(dtype, model_configs, model):
bias_shape="1hss",
alibi_type="custom",
),
"bias_1_5": ModelConfig(
"bias_1_6": ModelConfig(
2,
2048,
24,
Expand Down Expand Up @@ -1143,10 +1153,16 @@ def _run_dot_product_attention(
bias = None
if config.attn_bias_type == "post_scale_bias":
shape = "_".join(config.bias_shape)
# For 1hss, 11ss, b1ss, bhss
shape_cache = shape
shape = shape.replace("_s_s", "_sq_skv")
# For 111s
if shape == shape_cache:
shape = shape.replace("_1_s", "_1_skv")
tensor_shape = [dim_to_num[j] for j in shape.split("_")]
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
if config.bias_shape != "1hss":
# For 111s, dbias calculation is not supported as of cuDNN 9.18
if config.bias_shape == "111s":
bias.requires_grad = False

# Create RNG
Expand Down
Loading