Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 30 additions & 31 deletions tests/jax/test_fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,13 +547,20 @@ def _setup_inputs(self):
else:
self.softmax_offset = None

def gen_valid(bs, max_seqlen, pad_ratio):
def generate_valid_segment_ids_and_pos(bs, max_seqlen, pad_ratio):
pad_len = int(max_seqlen * pad_ratio)
valid_len = max_seqlen - pad_len
tokens = jnp.concatenate([jnp.ones((bs, valid_len)), jnp.zeros((bs, pad_len))], axis=-1)
return tokens, jnp.logical_not(tokens)
tokens = jnp.concatenate(
[
jnp.ones((bs, valid_len), dtype=jnp.int32),
jnp.zeros((bs, pad_len), dtype=jnp.int32),
],
axis=-1,
)
segment_pos = jnp.broadcast_to(jnp.arange(max_seqlen, dtype=jnp.int32), tokens.shape)
return tokens, segment_pos, jnp.logical_not(tokens)
Comment thread
jberchtold-nvidia marked this conversation as resolved.

def generate_random_segment_ids(
def generate_random_segment_ids_and_pos(
batch_size,
sequence_length,
num_segments,
Expand Down Expand Up @@ -601,8 +608,10 @@ def generate_random_segment_ids(
return segment_ids, segment_pos, segment_pad

if self.qkv_layout.is_thd():
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
self.segment_ids_q, self.segment_pos_q, self.pad_q = (
generate_random_segment_ids_and_pos(
self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42
)
)
self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q)
# TODO(rewang): record only self attention and find the reason of cross attention
Expand All @@ -617,22 +626,23 @@ def generate_random_segment_ids(
self.window_size is not None or self.attn_mask_type.is_bottom_right()
): # SWA or BRCM requires kv_len >= q_len
min_segment_len = self.seqlens_q
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids(
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
min_segment_len=min_segment_len,
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = (
generate_random_segment_ids_and_pos(
self.batch_size,
self.max_seqlen_kv,
self.num_segments_per_seq,
seed=2024,
min_segment_len=min_segment_len,
)
)
self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv)
else:
self.segment_ids_q, self.pad_q = gen_valid(
self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_valid_segment_ids_and_pos(
self.batch_size, self.max_seqlen_q, pad_ratio
)
self.segment_ids_kv, self.pad_kv = gen_valid(
self.batch_size, self.max_seqlen_kv, pad_ratio
self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = (
generate_valid_segment_ids_and_pos(self.batch_size, self.max_seqlen_kv, pad_ratio)
)
self.segment_pos_q = self.segment_pos_kv = None
self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None

# For reference code
Expand Down Expand Up @@ -682,24 +692,15 @@ def generate_random_segment_ids(
(self.offsets_q, self.offsets_kv),
)
case SeqDescFormat.SegmentIDs:
# Exercise the path to generate the segment_pos in from_segment_ids_and_pos()
# if no CP and load balancing, else explicitly pass the segment_pos
# from_segment_ids_and_pos requires explicit segment_pos.
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(
self.cp_reorder_fn(self.segment_ids_q),
self.cp_reorder_fn(self.segment_ids_kv),
),
(
(
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
)
if self.cp_size > 1 and self.cp_load_balanced
else None
),
is_thd=self.qkv_layout.is_thd(),
is_segment_ids_reordered=(
True if self.cp_size > 1 and self.cp_load_balanced else False
self.cp_reorder_fn(self.segment_pos_q),
self.cp_reorder_fn(self.segment_pos_kv),
),
)
case _:
Expand Down Expand Up @@ -727,9 +728,7 @@ def generate_random_segment_ids(
case SeqDescFormat.SegmentIDs:
self.sequence_desciptor = SequenceDescriptor.from_segment_ids_and_pos(
(self.segment_ids_q, self.segment_ids_kv),
None,
is_thd=self.qkv_layout.is_thd(),
is_segment_ids_reordered=False,
(self.segment_pos_q, self.segment_pos_kv),
)
case _:
raise ValueError(f"Unknown {self.seq_desc_format=}")
Expand Down
98 changes: 16 additions & 82 deletions transformer_engine/jax/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,15 +854,10 @@ def from_seqlens_and_offsets(
def from_segment_ids_and_pos(
cls,
segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None,
*,
is_thd: bool,
is_segment_ids_reordered: bool,
segment_pos: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have time, it may be nice to have the following:

def from_segment_ids_and_pos(
    cls,
    ...,
    segment_pos: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]] = None,
):
    assert segment_pos is not None, "segment_pos is now required as the previous automatic generation of segment_pos did not have the proper context to generate a correct segment_pos to account for all load-balancing and context-parallelism strategies the user may be using when calling TransformerEngine's DotProductAttention."

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in 86038f6

) -> SequenceDescriptor:
Comment on lines 854 to 858

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Breaking API change not flagged in checklist

from_segment_ids_and_pos previously accepted segment_pos=None (optional) plus two keyword-only parameters is_thd: bool and is_segment_ids_reordered: bool. The two keyword-only parameters and the None default have been removed, making segment_pos a required positional argument. Any caller passing is_thd=True or is_segment_ids_reordered=False will receive a TypeError at runtime, and any caller relying on the segment_pos=None default path for BSHD or non-load-balanced THD will also fail with a TypeError (missing required argument) instead of getting an auto-generated segment_pos. The PR checklist marks "Breaking change" as unchecked, which is incorrect for this function, even with its "Experimental" label.

"""
Experimental factory method for inputs with segment IDs and optional positions.
segment_pos = None to be used only for: BSHD with or without load balancing and,
THD without load balancing
Experimental factory method for inputs with segment IDs and positions.
Args:
segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids):
- q_segment_ids (jnp.ndarray):
Expand All @@ -876,88 +871,27 @@ def from_segment_ids_and_pos(
The position inside each segment for query, with shape [batch, max_seqlen].
- kv_segment_pos (jnp.ndarray):
The position inside each segment for key, value, with shape [batch, max_seqlen].
is_thd(bool): If True, QKVLayout is of type THD, else it is BSHD
is_segment_ids_reordered(bool): If True, the segment ids have been reordered for load balancing.
Only THD with load balancing is expected to have this flag set to True
Return:
A SequenceDescriptor with segment_ids/segment_pos initialized.
"""
# Examples (0 in segment_ids means padding):
# THD (three segments packed together in a sequence of length 16 with no intra-segment padding):
# segment_ids = [1, 1, 1, 2, 2, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0]
# segment_pos = [0, 1, 2, 0, 1, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0]
# THD (three segments packed together in a sequence of length 16 with intra-segment padding):
# segment_ids = [1, 1, 1, 2, 2, 3, 3, 3, 0, 0, 4, 4, 0, 0, 0, 0]
# segment_pos = [0, 1, 2, 0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 0, 0, 0]
# BSHD (only one segment per sequence):
# segment_ids = [1, 1, 1, 1, 1, 1, 1, 0, 0]
# segment_pos = [0, 1, 2, 3, 4, 5, 6, 7, 8]
# For an example of how to generate the segment_ids and segment_pos,
# see tests/jax/test_fused_attn.py `generate_random_segment_ids_and_pos()` and `generate_valid_segment_ids_and_pos()`
q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids)

# Using defaults : segment pos has to be generated.
if segment_pos is None:
# THD + load balanced segment_ids are not supported in this function
# BSHD + load balanced segment_ids are incorrect as BSHD handles reordering within the primitive itself
if is_segment_ids_reordered:
assert not is_thd, (
f"{segment_pos=} default arg is not supported for load balanced reordered"
" (Striped) THD inputs. Please pass the load balanced reordered segment_pos"
" and segment_ids explicitly to {from_segment_ids_and_pos.__qualname__}"
" using convenience function reorder_causal_load_balancing()"
)
assert is_thd, (
f"{segment_pos=} default arg is not supported for load balanced reordered (Dual"
" Chunk) BSHD inputs. BSHD segment_pos and segment_ids do not need to be load"
" balanced reordered. The reordering for these is performed within the"
" primitive"
)

# Generate the default pos for THD and BSHD non-reordered segment_ids
def generate_default_pos(seg_ids):
if is_thd:
batch_size, seq_size = seg_ids.shape
# Assume that the first token belongs to a segment and is not a padded token
first_is_segment = jnp.full((batch_size, 1), True, dtype=bool)
# Get segment start positions
segment_start = jnp.concatenate(
[
first_is_segment,
(seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0),
],
axis=-1,
)
# Get offset for location where new segment starts
segment_start_idx = jax.vmap(lambda row: jnp.arange(row.size) * row)(
segment_start
)
segment_start_offsets = jax.vmap(jnp.maximum.accumulate)(segment_start_idx)

# Get the last non-zero index - after this everything is padding
# (B,)
last_nonzero_idx = jax.vmap(
lambda segids_row: jnp.max(
jnp.where(segids_row != 0, jnp.arange(seq_size), -1)
)
)(seg_ids)
seg_pos_no_thd = jnp.arange(seq_size)
# Get a mask which can be used to zero out all the padding at the end (after the non-zero index)
mask = seg_pos_no_thd <= last_nonzero_idx[:, None]

# Get the unmasked seg_pos for the THD sequence
seg_pos = (
jnp.broadcast_to(jnp.arange(seq_size), seg_ids.shape)
- segment_start_offsets
)

# Use the mask to zero out the padding at the end (after the non-zero index)
segment_pos = jax.vmap(
lambda pos_row, mask_row: jnp.where(mask_row, pos_row, 0)
)(seg_pos, mask)
return segment_pos

seqlen = seg_ids.shape[-1]
return jnp.broadcast_to(jnp.arange(seqlen), seg_ids.shape)

q_seg_pos = generate_default_pos(q_seg_ids)
kv_seg_pos = generate_default_pos(kv_seg_ids)
segment_pos = (q_seg_pos, kv_seg_pos)
# Explicitly passed segment_pos
else:
segment_pos = cls._expand_to_pair(segment_pos)
q_seg_pos, kv_seg_pos = cls._expand_to_pair(segment_pos)

return cls(
segment_ids=(q_seg_ids, kv_seg_ids),
segment_pos=segment_pos,
segment_pos=(q_seg_pos, kv_seg_pos),
)


Expand Down