[JAX] Fix incorrect calculation of segment pos from segment ids in user-facing API #2523
Conversation
/te-ci jax L0 L1
/te-ci jax L0 L1
Greptile Summary
Fixed incorrect segment position calculation in …
Key changes: …
Breaking change: Existing users must now provide …
Confidence Score: 1/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant SequenceDescriptor
participant generate_default_pos
User->>SequenceDescriptor: from_segment_ids_and_pos(segment_ids, segment_pos=None, is_thd, is_segment_ids_reordered)
alt segment_pos is None
SequenceDescriptor->>SequenceDescriptor: Check if is_segment_ids_reordered
alt is_segment_ids_reordered = True
SequenceDescriptor->>SequenceDescriptor: assert not is_thd (line 833)
SequenceDescriptor->>SequenceDescriptor: assert is_thd (line 839)
Note right of SequenceDescriptor: BUG: Contradictory assertions!<br/>Will always fail
end
SequenceDescriptor->>generate_default_pos: Call for q_seg_ids
alt is_thd = True
generate_default_pos->>generate_default_pos: Find segment boundaries
generate_default_pos->>generate_default_pos: Calculate segment_start_offsets
generate_default_pos->>generate_default_pos: Compute positions relative to segment start
generate_default_pos->>generate_default_pos: Apply padding mask
generate_default_pos-->>SequenceDescriptor: Return THD segment_pos
else is_thd = False (BSHD)
generate_default_pos->>generate_default_pos: Use simple arange
generate_default_pos-->>SequenceDescriptor: Return BSHD segment_pos
end
SequenceDescriptor->>generate_default_pos: Call for kv_seg_ids
generate_default_pos-->>SequenceDescriptor: Return kv segment_pos
else segment_pos provided
SequenceDescriptor->>SequenceDescriptor: Expand to pair
end
SequenceDescriptor-->>User: Return SequenceDescriptor with segment_ids and segment_pos
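The THD and BSHD branches of the diagram can be sketched as follows. This is a minimal NumPy sketch, not the code from `transformer_engine/jax/attention.py`: the function name is borrowed from the diagram, and three things are assumptions taken from the review comments in this thread rather than from the merged code: padding is marked by `segment_id == 0`, padded positions are set to 0, and the first token of each row starts a segment. The same array operations are available under `jax.numpy`.

```python
import numpy as np


def generate_default_pos(seg_ids, is_thd):
    """Sketch: derive segment_pos from non-reordered segment_ids (0 = padding, assumed)."""
    seg_ids = np.asarray(seg_ids)
    seq_len = seg_ids.shape[-1]
    idx = np.broadcast_to(np.arange(seq_len), seg_ids.shape)
    if not is_thd:
        # BSHD: one sequence per row, so the position is just the token index.
        return idx.copy()
    # 1. Find segment boundaries: a token starts a segment if its id differs
    #    from the previous id and it is not padding. The first token is
    #    assumed to start a segment.
    is_start = np.ones(seg_ids.shape, dtype=bool)
    is_start[..., 1:] = (seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0)
    # 2. Segment start offsets: index of the latest segment start at or before each token.
    start_offsets = np.maximum.accumulate(np.where(is_start, idx, 0), axis=-1)
    # 3. Positions relative to the segment start.
    pos = idx - start_offsets
    # 4. Apply the padding mask (padded positions set to 0, an assumption).
    return np.where(seg_ids != 0, pos, 0)


packed = [[1, 1, 1, 2, 2, 0, 0]]
print(generate_default_pos(packed, is_thd=True).tolist())   # [[0, 1, 2, 0, 1, 0, 0]]
print(generate_default_pos(packed, is_thd=False).tolist())  # [[0, 1, 2, 3, 4, 5, 6]]
```

The second call shows why a plain `arange` is only appropriate for BSHD: on packed THD input it keeps counting across the boundary between segments 1 and 2.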
Additional Comments (1)
- transformer_engine/jax/attention.py, line 841-842 (logic): `q_seg_ids` and `kv_seg_ids` are used here but not defined until line 847. This will cause a `NameError` at runtime when `segment_pos` is `None`.
2 files reviewed, 1 comment
/te-ci jax L0 L1
Additional Comments (1)
- transformer_engine/jax/attention.py, line 829-832 (style): f-strings in assert messages won't be evaluated until the assertion fails
2 files reviewed, 1 comment
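The point about assert messages can be checked with a small sketch. It relies only on standard Python semantics (the message expression of an `assert` is evaluated only when the condition is false); `describe` and `check` are hypothetical helpers, not names from the PR.

```python
calls = []


def describe(value):
    # Records every time the message is actually built.
    calls.append(value)
    return f"unexpected value: {value}"


def check(value):
    # The message expression runs only if `value > 0` is false.
    assert value > 0, describe(value)
    return value


check(5)
calls_after_pass = list(calls)  # message was never built

try:
    check(-1)
except AssertionError as err:
    failure_message = str(err)
```

So an f-string in an assert message costs nothing on the passing path; it is formatted only when the assertion fails (and asserts are skipped entirely under `python -O`).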
…es and load balanced cases in from_segment_ids_and_pos. Enforce passing of segment_pos for THD cases and load balanced cases Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
499f4de to
217ea58
Compare
Additional Comments (1)
- transformer_engine/jax/attention.py, line 833-837 (logic): warning is shown even for the valid case (BSHD without load balancing). This will cause users to see warnings every time they correctly use `segment_pos=None` with BSHD layout.
2 files reviewed, 1 comment
2. Add support for THD in from_segment_ids_and_pos() 3. Assert if load balanced segment_ids is passed to generate a segment_pos Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Additional Comments (3)
- transformer_engine/jax/attention.py, line 880-881 (syntax): `segment_ids` is undefined, should be `seg_ids` (the parameter name on line 836)
- tests/jax/test_fused_attn.py, line 680-681 (syntax): `is_thd` and `is_load_balanced` are keyword-only parameters (note the `*` on line 800 in `attention.py`), so they must be passed as keyword arguments
- tests/jax/test_fused_attn.py, line 709-710 (syntax): `is_thd` and `is_load_balanced` are keyword-only parameters (note the `*` on line 800 in `attention.py`), so they must be passed as keyword arguments
2 files reviewed, 3 comments
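The keyword-only behaviour flagged in these comments can be reproduced with a stand-in function. The stub below is hypothetical: it copies only the shape of the signature described in the review (a bare `*` before `is_thd` and `is_load_balanced`) and does none of the real work of `from_segment_ids_and_pos()`.

```python
def from_segment_ids_and_pos_stub(segment_ids, segment_pos=None, *, is_thd, is_load_balanced):
    """Hypothetical stand-in: everything after the bare `*` is keyword-only."""
    return {"is_thd": is_thd, "is_load_balanced": is_load_balanced}


def try_call(*args, **kwargs):
    # Returns the result, or the TypeError text if the call is rejected.
    try:
        return from_segment_ids_and_pos_stub(*args, **kwargs)
    except TypeError as err:
        return f"TypeError: {err}"


seg_ids = [1, 1, 2]
by_keyword = try_call(seg_ids, None, is_thd=True, is_load_balanced=False)
by_position = try_call(seg_ids, None, True, False)  # rejected: too many positional arguments
omitted = try_call(seg_ids)                          # rejected: required keyword-only arguments missing
```

Because the two flags have no defaults, leaving them out fails in the same way as passing them positionally, which is what makes a signature change of this kind breaking for existing callers.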
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Additional Comments (3)
- tests/jax/test_fused_attn.py, line 706-711 (logic): passing `segment_pos=None` with `is_load_balanced=True` will trigger an assertion error on `attention.py:830`. The BSHD + load balanced + SegmentIDs combination should either:
  - pass explicit segment_pos (like THD does on line 677-678), or
  - skip this format when load balanced

  Currently this only works because the default `mesh_shape=(1,1,1)` means `cp_size=1`.
- transformer_engine/jax/attention.py, line 794 (style): misleading TODO. THD layout (non load balanced) IS now supported by the implementation below (lines 837-878).
- transformer_engine/jax/attention.py, line 839-840 (style): assumption not enforced. If the sequence starts with padding (`seg_ids[0]=0`), `first_is_segment=True` will incorrectly mark it as a segment start. Consider asserting `seg_ids[:, 0] != 0` or handle leading padding explicitly.
2 files reviewed, 3 comments
Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
…actually reordered segment_ids and segment_pos Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
…ncing Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Additional Comments (2)
- transformer_engine/jax/attention.py, line 833-844 (logic): contradictory assertions will always fail when `is_segment_ids_reordered=True`.
  - line 833 asserts `not is_thd` (meaning is_thd must be False)
  - line 839 asserts `is_thd` (meaning is_thd must be True)

  Both assertions cannot be satisfied simultaneously, causing the function to always fail when `segment_pos=None` and `is_segment_ids_reordered=True`.
- transformer_engine/jax/attention.py, line 850-851 (style): the assumption that the first token is always a valid segment may be incorrect if THD format allows leading padding. If `seg_ids[0, 0] == 0` (first token is padding), this assumption causes incorrect segment_pos calculation. Verify that THD format guarantees no leading padding, or add validation.
2 files reviewed, 2 comments
for more information, see https://pre-commit.ci Fix lint errors Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
4101df5 to
74a352e
Compare
/te-ci jax L0 L1
Additional Comments (3)
- transformer_engine/jax/attention.py, line 833-844 (logic): contradictory assertions on lines 833 and 839. Both will execute when `is_segment_ids_reordered=True`, but line 833 requires `is_thd=False` while line 839 requires `is_thd=True`.

  This means when `is_segment_ids_reordered=True`:
  - if `is_thd=True`: line 833 fails with "not supported for load balanced reordered (Striped) THD"
  - if `is_thd=False`: line 839 fails with "not supported for load balanced reordered (Dual Chunk) BSHD"

  Based on line 822 comments and PR description, only THD with load balancing should set `is_segment_ids_reordered=True`, so line 839's assertion should be removed (it contradicts the intended behavior).
- transformer_engine/jax/attention.py, line 856 (style): potential edge case: the `(seg_ids[..., 1:] != 0)` condition may not handle all segment transitions correctly.

  Consider the sequence `[1, 1, 0, 2, 2]` where 0 is padding in the middle:
  - transition from 1→0 at index 2: `(1 != 0) & (0 != 0)` = False (correctly not marked as segment start)
  - transition from 0→2 at index 3: `(0 != 2) & (2 != 0)` = True (marked as segment start)

  However, the assumption on line 850 is that "the first token belongs to a segment and is not a padded token", which conflicts with the possibility of having padding in the middle. Verify whether THD format allows mid-sequence padding; if so, this logic needs adjustment. Does THD format allow padding (segment_id=0) in the middle of a sequence, or is padding only at the end?
- tests/jax/test_fused_attn.py, line 687-689 (style): simplify boolean expression
2 files reviewed, 3 comments
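The reviewer's `[1, 1, 0, 2, 2]` example, and the leading-padding case raised earlier in the thread, can be worked through directly. This is a NumPy sketch of the boundary test as the review describes it (previous id differs and current id is non-zero, with the first token unconditionally treated as a segment start); the helper names are hypothetical and the real code on the referenced lines may differ.

```python
import numpy as np


def segment_start_mask(seg_ids):
    """Boundary test as described in the review; first token is assumed to start a segment."""
    seg_ids = np.asarray(seg_ids)
    first = np.ones_like(seg_ids[..., :1], dtype=bool)
    rest = (seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0)
    return np.concatenate([first, rest], axis=-1)


def has_leading_padding(seg_ids):
    """Validation suggested in the review: flag rows whose first token is padding (id 0)."""
    return bool(np.any(np.asarray(seg_ids)[..., 0] == 0))


mid_padding = np.array([[1, 1, 0, 2, 2]])
leading_padding = np.array([[0, 1, 1, 2, 2]])

mask_mid = segment_start_mask(mid_padding)
# [[True, False, False, True, False]]: the 1->0 step is skipped, the 0->2 step is a start.
mask_leading = segment_start_mask(leading_padding)
# [[True, True, False, True, False]]: the padded first token is wrongly marked as a start.
```

In this sketch mid-sequence padding is handled by the `!= 0` term, while leading padding is the case that breaks the first-token assumption, which is why a check like `has_leading_padding` (or an equivalent assertion) is what the review asks for.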
/te-ci jax L0 L1
jberchtold-nvidia
left a comment
LGTM pending CI, thanks!
/te-ci jax L0 L1
Additional Comments (2)
- transformer_engine/jax/attention.py, line 833-844 (logic): contradictory assertions will always fail when `is_segment_ids_reordered=True`.

  Line 833 asserts `not is_thd` (THD must be False), but line 839 asserts `is_thd` (THD must be True). When `is_segment_ids_reordered=True`, the first assertion will pass only if `is_thd=False`, but then the second assertion will always fail because it requires `is_thd=True`.

  Based on the comments and documentation (line 822: "Only THD with load balancing is expected to have this flag set to True"), the correct logic should be:
- tests/jax/test_fused_attn.py, line 671-672 (style): comment doesn't match implementation logic. The comment says "if no CP and load balancing", but the code on line 683 does the opposite: it explicitly passes `segment_pos` when `cp_size > 1 and self.cp_load_balanced` (i.e., WITH CP and load balancing), and passes `None` otherwise.
2 files reviewed, 2 comments
@ksivaman this PR has a breaking change (FYI for TE release notes)
@mgoldfarb-nvidia @huanghua1994 @mingxu1067 - even though this PR has been merged, it would be great to get your feedback on this, in case you spot something that needs to be reverted. Thanks!
TransformerEngine v2.12 (NVIDIA/TransformerEngine#2523) made `is_thd` and `is_segment_ids_reordered` required keyword arguments on `SequenceDescriptor.from_segment_ids_and_pos()` to fix incorrect segment position calculation for THD layouts. Since the packing branch in `cudnn_flash_attention` uses `qkv_layout="THD_THD_THD"` with standard (non-reordered) segment IDs, the correct values are `is_thd=True, is_segment_ids_reordered=False`. Without this fix, any configuration using `attention="cudnn_flash_te"` with `packing=True` and real data (`dataset_type != "synthetic"`) fails with: TypeError: SequenceDescriptor.from_segment_ids_and_pos() missing 2 required keyword-only arguments: 'is_thd' and 'is_segment_ids_reordered'
The existing test_gpu_packed_attention uses dataset_type=synthetic, which bypasses SequenceDescriptor.from_segment_ids_and_pos() entirely (takes the elif branch in cudnn_flash_attention). This meant the TE v2.12 API breakage went undetected. Add test_gpu_packed_attention_hf that uses HF parquet data with packing=True + attention=cudnn_flash_te, exercising the actual SequenceDescriptor codepath. This serves as a regression test for NVIDIA/TransformerEngine#2523.
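To make the flag combinations concrete, here is a hypothetical decision sketch. It is an assumption pieced together from the PR description and the sequence diagram in this thread, not TransformerEngine code: the function name and the returned labels are invented for illustration, and it raises `ValueError` where the real method is described as asserting.

```python
def choose_segment_pos_strategy(segment_pos, *, is_thd, is_segment_ids_reordered):
    """Hypothetical summary of how segment_pos is obtained for each flag combination."""
    if segment_pos is not None:
        # Caller supplied positions; nothing to derive.
        return "use_provided"
    if is_segment_ids_reordered:
        # Assumed from the thread: positions cannot be derived from reordered
        # (load balanced) segment_ids, so they must be passed explicitly.
        raise ValueError("segment_pos must be passed when segment_ids are reordered")
    # Non-reordered ids: per-segment positions for packed THD, plain arange for BSHD.
    return "per_segment_positions" if is_thd else "arange"


# The packing case described above: THD layout, standard (non-reordered) segment IDs.
packing_case = choose_segment_pos_strategy(None, is_thd=True, is_segment_ids_reordered=False)
bshd_case = choose_segment_pos_strategy(None, is_thd=False, is_segment_ids_reordered=False)
```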
Description
SequenceDescriptor's `from_segment_ids_and_pos()` accepts the `segment_ids` and an optional `segment_pos` as input. This method is supposed to serve as a convenience to do two things:
- … `segment_ids` and `segment_pos` in a SequenceDescriptor object for TE to use downstream
- If `segment_pos` is not passed, then calculate/extrapolate it

In its current form, the second functionality gives incorrect results for THD + non-reordered and THD + reordered cases, as it merely uses an `arange` to calculate the `segment_pos` naively. This could result in incorrect masking for these cases.
Type of change
Changes
This PR makes a few changes:
- … `from_segment_ids_and_pos()`: `is_thd` and `is_segment_ids_reordered`. The only cases that this function can currently guarantee to support are BSHD with and without load balancing, and THD without load balancing.
- … `from_segment_ids_and_pos()`. However, if the `segment_pos` are reordered and passed to `from_segment_ids_and_pos()` it will assert
- … `from_segment_ids_and_pos()`, it will assert
- … `from_segment_ids_and_pos()`

Impact on user of the API:
- `is_thd` and `is_segment_ids_reordered` are not Optional, so omitting them causes a `TypeError` for current users of this API, which is a breaking change. However, this is needed to ensure correct usage of this API
- … `segment_ids` are reordered or not. It is expected that the `segment_ids` passed will be reordered only for THD load balancing. For all other cases the `segment_ids` should not be reordered

Checklist: