-
Notifications
You must be signed in to change notification settings - Fork 757
[PyT] Plumbing correct bias dims from TE to cudnn, while adding support for additional bias shapes #2537
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PyT] Plumbing correct bias dims from TE to cudnn, while adding support for additional bias shapes #2537
Changes from 18 commits
ab542fc
c86328e
fddf0ac
d3aa7ec
4d295c4
f4f9cc6
f9f4fb8
d9547fa
7ede1fe
c20d67a
303aee7
e9f88f0
6bf73e1
ebee29b
126be03
7b0f942
f795056
0e74dcf
0acf8f8
2133bd8
89e90a5
5a25d9c
0e2a72f
f066c88
ff174a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -305,8 +305,25 @@ def run_dpa_with_cp( | |||||||||||||||||||||||||
| x.requires_grad = True | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if config.attn_bias_type not in ["no_bias", "alibi"]: | ||||||||||||||||||||||||||
| attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv) | ||||||||||||||||||||||||||
| bias_shape_map = { | ||||||||||||||||||||||||||
| "1hss": (1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv), | ||||||||||||||||||||||||||
| "11ss": (1, 1, config.max_seqlen_q, config.max_seqlen_kv), | ||||||||||||||||||||||||||
| "b1ss": (config.batch_size, 1, config.max_seqlen_q, config.max_seqlen_kv), | ||||||||||||||||||||||||||
| "bhss": ( | ||||||||||||||||||||||||||
| config.batch_size, | ||||||||||||||||||||||||||
| config.num_heads, | ||||||||||||||||||||||||||
| config.max_seqlen_q, | ||||||||||||||||||||||||||
| config.max_seqlen_kv, | ||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||
| "111s": (1, 1, 1, config.max_seqlen_kv), | ||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||
| attn_bias_shape = bias_shape_map.get(config.bias_shape) | ||||||||||||||||||||||||||
| if attn_bias_shape is None: | ||||||||||||||||||||||||||
| assert False, f"cuDNN does not support {config.bias_shape=}" | ||||||||||||||||||||||||||
| bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() | ||||||||||||||||||||||||||
| # cuDNN does not support dbias calculation for 111s as of cuDNN 9.18 | ||||||||||||||||||||||||||
| # TODO(KshitijLakhani): Set requires_grad to True for all shapes once 111s is supported | ||||||||||||||||||||||||||
| bias.requires_grad = True if config.bias_shape != "111s" else False | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| bias = None | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -338,7 +355,7 @@ def run_dpa_with_cp( | |||||||||||||||||||||||||
| out.backward(dout_fp8) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| out.backward(dout) | ||||||||||||||||||||||||||
| dq, dk, dv = q.grad, k.grad, v.grad | ||||||||||||||||||||||||||
| dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None | ||||||||||||||||||||||||||
| d_softmax_offset = None | ||||||||||||||||||||||||||
| if config.softmax_type != "vanilla": | ||||||||||||||||||||||||||
| d_softmax_offset = core_attn.softmax_offset.grad | ||||||||||||||||||||||||||
|
|
@@ -389,11 +406,27 @@ def run_dpa_with_cp( | |||||||||||||||||||||||||
| q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) | ||||||||||||||||||||||||||
| q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] | ||||||||||||||||||||||||||
| if bias_ is not None: | ||||||||||||||||||||||||||
| bias_ = bias_.view( | ||||||||||||||||||||||||||
| *bias_.shape[:-2], 2 * world_size, bias_.shape[-2] // (2 * world_size), bias_.shape[-1] | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| bias_ = bias_.index_select(2, seq_idx) | ||||||||||||||||||||||||||
| bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) | ||||||||||||||||||||||||||
| ndim = bias_.ndim | ||||||||||||||||||||||||||
| seq_q_dim = ndim - 2 | ||||||||||||||||||||||||||
| if qkv_format == "thd": | ||||||||||||||||||||||||||
| bias_seq_idx = seq_idx_q | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| bias_seq_idx = seq_idx | ||||||||||||||||||||||||||
| shape_before_seq = bias_.shape[:seq_q_dim] | ||||||||||||||||||||||||||
| seq_q_size = bias_.shape[seq_q_dim] | ||||||||||||||||||||||||||
| seq_kv_size = bias_.shape[-1] | ||||||||||||||||||||||||||
| if seq_q_size == 1: | ||||||||||||||||||||||||||
| # TODO(KshitijLakhani): Set to True always once cuDNN supports dbias for 111s | ||||||||||||||||||||||||||
| bias_.requires_grad = False | ||||||||||||||||||||||||||
| # Bias is broadcast, no need to partition along sequence dimension | ||||||||||||||||||||||||||
| pass | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| bias_ = bias_.view( | ||||||||||||||||||||||||||
| *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| bias_ = bias_.index_select(seq_q_dim, bias_seq_idx) | ||||||||||||||||||||||||||
| bias_ = bias_.view(*shape_before_seq, -1, seq_kv_size) | ||||||||||||||||||||||||||
| bias_.requires_grad = True | ||||||||||||||||||||||||||
| # set up environment | ||||||||||||||||||||||||||
| core_attn.set_context_parallel_group( | ||||||||||||||||||||||||||
| cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, | ||||||||||||||||||||||||||
|
|
@@ -433,23 +466,27 @@ def run_dpa_with_cp( | |||||||||||||||||||||||||
| out_.backward(dout_fp8_) | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| out_.backward(dout_) | ||||||||||||||||||||||||||
| dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad | ||||||||||||||||||||||||||
| dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad if bias_ is not None else None | ||||||||||||||||||||||||||
| d_softmax_offset_ = None | ||||||||||||||||||||||||||
| if config.softmax_type != "vanilla": | ||||||||||||||||||||||||||
| d_softmax_offset_ = core_attn.softmax_offset.grad.clone() | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # get outputs | ||||||||||||||||||||||||||
| tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_] | ||||||||||||||||||||||||||
| tensors = [out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_] | ||||||||||||||||||||||||||
| if fp8_mha: | ||||||||||||||||||||||||||
| tensors_to_deq = [out, out_] if not fp8_bwd else tensors | ||||||||||||||||||||||||||
| for i, tensor in enumerate(tensors_to_deq): | ||||||||||||||||||||||||||
| tensors_to_deq[i] = tensor.dequantize() | ||||||||||||||||||||||||||
| # dbias/dbias_ could be None, so skip check for it | ||||||||||||||||||||||||||
| if tensor is not None: | ||||||||||||||||||||||||||
| tensors_to_deq[i] = tensor.dequantize() | ||||||||||||||||||||||||||
| if not fp8_bwd: | ||||||||||||||||||||||||||
| tensors[0], tensors[4] = tensors_to_deq | ||||||||||||||||||||||||||
| tensors[0], tensors[5] = tensors_to_deq | ||||||||||||||||||||||||||
| for tensor in tensors: | ||||||||||||||||||||||||||
| assert torch.all(~torch.isnan(tensor)) | ||||||||||||||||||||||||||
| assert torch.all(~torch.isinf(tensor)) | ||||||||||||||||||||||||||
| out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors | ||||||||||||||||||||||||||
| # dbias/dbias_ could be None, so skip check for it | ||||||||||||||||||||||||||
| if tensor is not None: | ||||||||||||||||||||||||||
| assert torch.all(~torch.isnan(tensor)) | ||||||||||||||||||||||||||
| assert torch.all(~torch.isinf(tensor)) | ||||||||||||||||||||||||||
| out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| ############ compare results between CP and no-CP ############ | ||||||||||||||||||||||||||
| if qkv_format == "bshd" or qkv_format == "sbhd": | ||||||||||||||||||||||||||
|
|
@@ -467,6 +504,21 @@ def run_dpa_with_cp( | |||||||||||||||||||||||||
| x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) | ||||||||||||||||||||||||||
| for x in [dq_, dk_, dv_, out_] | ||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||
| if dbias is not None and dbias_ is not None: | ||||||||||||||||||||||||||
| ndim = dbias.ndim | ||||||||||||||||||||||||||
| # Query seq is at dim -2 | ||||||||||||||||||||||||||
| seq_q_dim = ndim - 2 | ||||||||||||||||||||||||||
| shape_before_seq = dbias.shape[:seq_q_dim] | ||||||||||||||||||||||||||
| seq_q_size = dbias.shape[seq_q_dim] | ||||||||||||||||||||||||||
| seq_kv_size = dbias.shape[-1] | ||||||||||||||||||||||||||
| # Reshape to split seq_q dimension | ||||||||||||||||||||||||||
| dbias = dbias.view( | ||||||||||||||||||||||||||
| *shape_before_seq, 2 * world_size, seq_q_size // (2 * world_size), seq_kv_size | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| # Index select on the newly created dimension (now at position seq_q_dim) | ||||||||||||||||||||||||||
| dbias = dbias.index_select(seq_q_dim, seq_idx) | ||||||||||||||||||||||||||
| dbias_ = dbias_.view(*shape_before_seq, 2, dbias_.shape[seq_q_dim] // 2, seq_kv_size) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| elif qkv_format == "thd": | ||||||||||||||||||||||||||
| dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] | ||||||||||||||||||||||||||
| dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] | ||||||||||||||||||||||||||
|
|
@@ -509,57 +561,127 @@ def run_dpa_with_cp( | |||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| atol, rtol, rmse_tol = get_tols(config, dtype) | ||||||||||||||||||||||||||
| tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_] | ||||||||||||||||||||||||||
| tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit] | ||||||||||||||||||||||||||
| names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"] | ||||||||||||||||||||||||||
| tensors_cp = [out_, dq_, dk_, dv_, dbias_, d_softmax_offset_, max_logit_] | ||||||||||||||||||||||||||
| tensors_no_cp = [out, dq, dk, dv, dbias, d_softmax_offset, max_logit] | ||||||||||||||||||||||||||
| names = ["out", "dq", "dk", "dv", "dbias", "d_softmax_offset", "max_logit"] | ||||||||||||||||||||||||||
| names_cp = [x + "_cp" for x in names] | ||||||||||||||||||||||||||
| names_no_cp = [x + "_no_cp" for x in names] | ||||||||||||||||||||||||||
| is_fp8 = dtype == "fp8" | ||||||||||||||||||||||||||
| for i, t in enumerate(tensors_no_cp): | ||||||||||||||||||||||||||
| if t is not None: | ||||||||||||||||||||||||||
| if "softmax_offset" not in names[i] and "max_logit" not in names[i]: | ||||||||||||||||||||||||||
| if qkv_format == "bshd": | ||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t[:, 0], | ||||||||||||||||||||||||||
| tensors_cp[i][:, 0], | ||||||||||||||||||||||||||
| names_no_cp[i], | ||||||||||||||||||||||||||
| names_cp[i], | ||||||||||||||||||||||||||
| atol, | ||||||||||||||||||||||||||
| rtol, | ||||||||||||||||||||||||||
| rmse_tol, | ||||||||||||||||||||||||||
| is_fp8, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t[:, 1], | ||||||||||||||||||||||||||
| tensors_cp[i][:, 1], | ||||||||||||||||||||||||||
| names_no_cp[i], | ||||||||||||||||||||||||||
| names_cp[i], | ||||||||||||||||||||||||||
| atol, | ||||||||||||||||||||||||||
| rtol, | ||||||||||||||||||||||||||
| rmse_tol, | ||||||||||||||||||||||||||
| is_fp8, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| # Compare the two sequence chunks separately | ||||||||||||||||||||||||||
| # Compare dbias | ||||||||||||||||||||||||||
| if names[i] == "dbias": | ||||||||||||||||||||||||||
| # After reshaping: (1, 1, 2, seq_q//2, seq_kv) | ||||||||||||||||||||||||||
| # Compare along dimension 2 (the split sequence dimension) | ||||||||||||||||||||||||||
| ndim_bias = t.ndim | ||||||||||||||||||||||||||
| seq_q_dim_bias = ndim_bias - 2 # Query sequence dimension | ||||||||||||||||||||||||||
| # After reshaping both have shape: [..., 2, seq_q//2, seq_kv] | ||||||||||||||||||||||||||
| # The split dimension is at seq_q_dim_bias | ||||||||||||||||||||||||||
| slice_0 = [slice(None)] * ndim_bias | ||||||||||||||||||||||||||
| slice_0[seq_q_dim_bias] = 0 | ||||||||||||||||||||||||||
| slice_1 = [slice(None)] * ndim_bias | ||||||||||||||||||||||||||
| slice_1[seq_q_dim_bias] = 1 | ||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t[tuple(slice_0)], # First sequence chunk | ||||||||||||||||||||||||||
| tensors_cp[i][tuple(slice_0)], | ||||||||||||||||||||||||||
| names_no_cp[i], | ||||||||||||||||||||||||||
| names_cp[i], | ||||||||||||||||||||||||||
| atol, | ||||||||||||||||||||||||||
| rtol, | ||||||||||||||||||||||||||
| rmse_tol, | ||||||||||||||||||||||||||
| is_fp8, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t[tuple(slice_1)], # First sequence chunk | ||||||||||||||||||||||||||
| tensors_cp[i][tuple(slice_1)], | ||||||||||||||||||||||||||
| names_no_cp[i], | ||||||||||||||||||||||||||
| names_cp[i], | ||||||||||||||||||||||||||
| atol, | ||||||||||||||||||||||||||
| rtol, | ||||||||||||||||||||||||||
| rmse_tol, | ||||||||||||||||||||||||||
| is_fp8, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| # Compare Q/K/V/out | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| # Compare along dimension 1 (the split sequence dimension) | ||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t[:, 0], | ||||||||||||||||||||||||||
| tensors_cp[i][:, 0], | ||||||||||||||||||||||||||
| names_no_cp[i], | ||||||||||||||||||||||||||
| names_cp[i], | ||||||||||||||||||||||||||
| atol, | ||||||||||||||||||||||||||
| rtol, | ||||||||||||||||||||||||||
| rmse_tol, | ||||||||||||||||||||||||||
| is_fp8, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t[:, 1], | ||||||||||||||||||||||||||
| tensors_cp[i][:, 1], | ||||||||||||||||||||||||||
| names_no_cp[i], | ||||||||||||||||||||||||||
| names_cp[i], | ||||||||||||||||||||||||||
| atol, | ||||||||||||||||||||||||||
| rtol, | ||||||||||||||||||||||||||
| rmse_tol, | ||||||||||||||||||||||||||
| is_fp8, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| elif qkv_format == "sbhd": | ||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t[0], | ||||||||||||||||||||||||||
| tensors_cp[i][0], | ||||||||||||||||||||||||||
| names_no_cp[i], | ||||||||||||||||||||||||||
| names_cp[i], | ||||||||||||||||||||||||||
| atol, | ||||||||||||||||||||||||||
| rtol, | ||||||||||||||||||||||||||
| rmse_tol, | ||||||||||||||||||||||||||
| is_fp8, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t[1], | ||||||||||||||||||||||||||
| tensors_cp[i][1], | ||||||||||||||||||||||||||
| names_no_cp[i], | ||||||||||||||||||||||||||
| names_cp[i], | ||||||||||||||||||||||||||
| atol, | ||||||||||||||||||||||||||
| rtol, | ||||||||||||||||||||||||||
| rmse_tol, | ||||||||||||||||||||||||||
| is_fp8, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| # Compare the two sequence chunks separately | ||||||||||||||||||||||||||
| # Compare dbias (same as BSHD) | ||||||||||||||||||||||||||
| if names[i] == "dbias": | ||||||||||||||||||||||||||
| # After reshaping: (1, 1, 2, seq_q//2, seq_kv) | ||||||||||||||||||||||||||
| # Compare along dimension 2 (the split sequence dimension) | ||||||||||||||||||||||||||
| ndim_bias = t.ndim | ||||||||||||||||||||||||||
| seq_q_dim_bias = ndim_bias - 2 | ||||||||||||||||||||||||||
| slice_0 = [slice(None)] * ndim_bias | ||||||||||||||||||||||||||
| slice_0[seq_q_dim_bias] = 0 | ||||||||||||||||||||||||||
| slice_1 = [slice(None)] * ndim_bias | ||||||||||||||||||||||||||
| slice_1[seq_q_dim_bias] = 1 | ||||||||||||||||||||||||||
|
Comment on lines
+678
to
+682
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same wrong split dimension for This block is a copy of the The fix mirrors the one for
Suggested change
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch by greptile |
||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t[tuple(slice_0)], # First sequence chunk | ||||||||||||||||||||||||||
| tensors_cp[i][tuple(slice_0)], | ||||||||||||||||||||||||||
| names_no_cp[i], | ||||||||||||||||||||||||||
| names_cp[i], | ||||||||||||||||||||||||||
| atol, | ||||||||||||||||||||||||||
| rtol, | ||||||||||||||||||||||||||
| rmse_tol, | ||||||||||||||||||||||||||
| is_fp8, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t[tuple(slice_1)], # First sequence chunk | ||||||||||||||||||||||||||
| tensors_cp[i][tuple(slice_1)], | ||||||||||||||||||||||||||
| names_no_cp[i], | ||||||||||||||||||||||||||
| names_cp[i], | ||||||||||||||||||||||||||
| atol, | ||||||||||||||||||||||||||
| rtol, | ||||||||||||||||||||||||||
| rmse_tol, | ||||||||||||||||||||||||||
| is_fp8, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| # Compare Q/K/V/out | ||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||
| # Compare along dimension 0 (the split sequence dimension) | ||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t[0], | ||||||||||||||||||||||||||
| tensors_cp[i][0], | ||||||||||||||||||||||||||
| names_no_cp[i], | ||||||||||||||||||||||||||
| names_cp[i], | ||||||||||||||||||||||||||
| atol, | ||||||||||||||||||||||||||
| rtol, | ||||||||||||||||||||||||||
| rmse_tol, | ||||||||||||||||||||||||||
| is_fp8, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t[1], | ||||||||||||||||||||||||||
| tensors_cp[i][1], | ||||||||||||||||||||||||||
| names_no_cp[i], | ||||||||||||||||||||||||||
| names_cp[i], | ||||||||||||||||||||||||||
| atol, | ||||||||||||||||||||||||||
| rtol, | ||||||||||||||||||||||||||
| rmse_tol, | ||||||||||||||||||||||||||
| is_fp8, | ||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| elif qkv_format == "thd": | ||||||||||||||||||||||||||
| compare_and_assert( | ||||||||||||||||||||||||||
| t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wrong split dimension used in dbias comparison
After the reshape+
index_selectat lines 536–543,dbiashas shape[..., 2, sq//(2*world_size), skv](e.g.[1, H, 2, sq//4, skv]for1hss). At that pointt.ndim == 5, soseq_q_dim_bias = ndim_bias - 2 = 3.But the dimension with exactly 2 elements (the CP split that should be compared chunk-by-chunk) is at index 2, not 3. Index 3 holds
sq//(2*world_size)elements. The current code therefore slices along the inner sub-sequence dimension instead of the CP-half dimension, so errors confined to one CP half (dim 2 == 1) are never independently validated.The fix is to use the same
seq_q_dimvalue that was used during the earlier reshape (computed from the originalndim - 2 = 2):There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch by greptile
Was comparing the dim 3/-2 of the 5d tensor instead of dim 2/-3 of the tensor
Re ran the test and all pass locally