Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
2d2e33a
Fix incorrect calculation of segment pos from segment ids for thd cas…
KshitijLakhani Dec 16, 2025
65e6b4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 16, 2025
9857577
Correct the assert condition
KshitijLakhani Dec 17, 2025
b20ac22
Modify fused attn tests to pass new args to from_segment_ids_and_pos()
KshitijLakhani Dec 17, 2025
03398a4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
0a47eb6
Calculate seg ids before pos
KshitijLakhani Dec 17, 2025
217ea58
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2025
ca9d3bc
1. Change the signature for from_segment_ids_and_pos()
Dec 23, 2025
0ee40a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2025
ceec1ea
Pass keyword-only args by name
Dec 23, 2025
ab00bb0
nit: Fix typo to use seg_ids instead of segment_ids
Dec 23, 2025
059c48d
nit: Fix comments
Dec 23, 2025
d524ad6
Modify the function call to differentiate between load balancing and …
Dec 23, 2025
d419f98
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 23, 2025
3efa504
Fix the is_segment_ids_reordered to be set only when CP and load bala…
KshitijLakhani Dec 24, 2025
e65062c
Fix comments for from_segment_ids_and_pos()
KshitijLakhani Dec 24, 2025
74a352e
Code clean up
pre-commit-ci[bot] Dec 24, 2025
e5381fb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 24, 2025
879afb4
Merge branch 'main' into klakhani/fix/incorrect-sequence-descr-from-s…
ksivaman Dec 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,8 @@ def generate_random_segment_ids(
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
),
is_thd=self.qkv_layout.is_thd(),
is_load_balanced=self.cp_size > 1 and self.cp_load_balanced,
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
Expand Down Expand Up @@ -704,6 +706,8 @@ def generate_random_segment_ids(
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
None,
is_thd=self.qkv_layout.is_thd(),
is_load_balanced=self.cp_size > 1 and self.cp_load_balanced,
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
Expand Down
72 changes: 65 additions & 7 deletions transformer_engine/jax/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,14 +791,19 @@ def from_seqlens_and_offsets(
q_offsets, kv_offsets = cls._expand_to_pair(seq_offsets)
return cls(seqlens=(q_seqlens, kv_seqlens), seq_offsets=(q_offsets, kv_offsets))

# TODO(KshitijLakhani), TODO(mgoldfarb-nvidia): Consider adding support for THD layout (non load balanced).
@classmethod
def from_segment_ids_and_pos(
cls,
segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None,
*,
is_thd: bool,
is_load_balanced: bool,
) -> SequenceDescriptor:
"""
Experimental factory method for inputs with segment IDs and optional positions. (THD)
Experimental factory method for inputs with segment IDs and optional positions.
segment_pos = None is supported only without load balancing (BSHD, or THD when is_thd=True); load balanced inputs must pass segment_pos explicitly.
Args:
segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids):
- q_segment_ids (jnp.ndarray):
Expand All @@ -812,22 +817,75 @@ def from_segment_ids_and_pos(
The position inside each segment for query, with shape [batch, max_seqlen].
- kv_segment_pos (jnp.ndarray):
The position inside each segment for key, value, with shape [batch, max_seqlen].
is_thd(bool): If True, QKVLayout is of type THD, else it is BSHD
is_load_balanced(bool): If True, CP is being used and the inputs have been load balanced.
Return:
A SequenceDescriptor with segment_ids/segment_pos initialized.
"""
q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids)

if segment_pos is not None:
segment_pos = cls._expand_to_pair(segment_pos)
else:
# Using defaults. segment pos has to be generated.
if segment_pos is None:
# Segment pos is not calculated implicitly for load balanced inputs
assert not is_load_balanced, (
f"{segment_pos=} default arg is not supported for load balanced inputs. Please pass"
" the load balanced segment_pos and segment_ids using helper function"
" reorder_causal_load_balancing()"
)

def generate_default_pos(segment_ids):
seqlen = segment_ids.shape[-1]
return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape)
def generate_default_pos(seg_ids):
    """Generate default per-token positions for one segment-id array.

    Reads ``is_thd`` from the enclosing scope.

    Args:
        seg_ids: Array of segment ids. The THD branch unpacks its shape as
            (batch, seq); a value of 0 is treated as padding.

    Returns:
        An integer array with the same shape as ``seg_ids``.
        - Non-THD: ``0..seqlen-1`` along the last axis, broadcast to the
          shape of ``seg_ids``.
        - THD: the position restarts at 0 wherever a new non-zero segment
          begins, and every token after the last non-zero segment id in a
          row (trailing padding) is set to 0.
    """
    if not is_thd:
        # Use the argument itself, not the enclosing `segment_ids`: the outer
        # value may be a (q, kv) tuple (no `.shape`) or the wrong array of the pair.
        seqlen = seg_ids.shape[-1]
        return jnp.broadcast_to(jnp.arange(seqlen), seg_ids.shape)

    batch_size, seq_size = seg_ids.shape
    token_idx = jnp.arange(seq_size)

    # Assume that the first token belongs to a segment and is not a padded token
    first_is_segment = jnp.full((batch_size, 1), True, dtype=bool)
    # A segment starts where the id changes to a different, non-padding id
    segment_start = jnp.concatenate(
        [
            first_is_segment,
            (seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0),
        ],
        axis=-1,
    )

    # Index of each segment start (0 elsewhere), then carried forward with a
    # running max so every token knows the start index of its own segment.
    segment_start_idx = token_idx * segment_start
    segment_start_offsets = jax.lax.cummax(segment_start_idx, axis=1)

    # Last non-zero index per row, shape (B,); -1 when the whole row is padding.
    # Everything after this index is trailing padding.
    last_nonzero_idx = jnp.max(jnp.where(seg_ids != 0, token_idx, -1), axis=-1)
    not_trailing_pad = token_idx <= last_nonzero_idx[:, None]

    # Position inside the segment = absolute index - start index of that segment
    seg_pos = token_idx - segment_start_offsets

    # Zero out the trailing padding only; padding between segments is left as is
    return jnp.where(not_trailing_pad, seg_pos, 0)

q_seg_pos = generate_default_pos(q_seg_ids)
kv_seg_pos = generate_default_pos(kv_seg_ids)
segment_pos = (q_seg_pos, kv_seg_pos)
# Explicitly passed segment_pos
else:
segment_pos = cls._expand_to_pair(segment_pos)

return cls(
segment_ids=(q_seg_ids, kv_seg_ids),
Expand Down
Loading