-
Notifications
You must be signed in to change notification settings - Fork 757
[JAX] Fix: Use jitted kernels for generating THD (and BSHD) segment pos #2823
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 8 commits
076f831
1e45c66
a915db9
09a2f43
73a8818
25cae4d
56eda2a
12b240c
86038f6
759028f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -854,15 +854,10 @@ def from_seqlens_and_offsets( | |
| def from_segment_ids_and_pos( | ||
| cls, | ||
| segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], | ||
| segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, | ||
| *, | ||
| is_thd: bool, | ||
| is_segment_ids_reordered: bool, | ||
| segment_pos: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If you have time, it may be nice to have the following:
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added in 86038f6 |
||
| ) -> SequenceDescriptor: | ||
|
Comment on lines
854
to
858
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| """ | ||
| Experimental factory method for inputs with segment IDs and optional positions. | ||
| segment_pos = None to be used only for: BSHD with or without load balancing and, | ||
| THD without load balancing | ||
| Experimental factory method for inputs with segment IDs and positions. | ||
| Args: | ||
| segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids): | ||
| - q_segment_ids (jnp.ndarray): | ||
|
|
@@ -876,88 +871,27 @@ def from_segment_ids_and_pos( | |
| The position inside each segment for query, with shape [batch, max_seqlen]. | ||
| - kv_segment_pos (jnp.ndarray): | ||
| The position inside each segment for key, value, with shape [batch, max_seqlen]. | ||
| is_thd(bool): If True, QKVLayout is of type THD, else it is BSHD | ||
| is_segment_ids_reordered(bool): If True, the segment ids have been reordered for load balancing. | ||
| Only THD with load balancing is expected to have this flag set to True | ||
| Return: | ||
| A SequenceDescriptor with segment_ids/segment_pos initialized. | ||
| """ | ||
| # Examples (0 in segment_ids means padding): | ||
| # THD (three segments packed together in a sequence of length 16 with no intra-segment padding): | ||
| # segment_ids = [1, 1, 1, 2, 2, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0] | ||
| # segment_pos = [0, 1, 2, 0, 1, 0, 1, 2, 3, 4, 0, 0, 0, 0, 0, 0] | ||
| # THD (three segments packed together in a sequence of length 16 with intra-segment padding): | ||
| # segment_ids = [1, 1, 1, 2, 2, 3, 3, 3, 0, 0, 4, 4, 0, 0, 0, 0] | ||
| # segment_pos = [0, 1, 2, 0, 1, 0, 1, 2, 3, 4, 0, 1, 0, 0, 0, 0] | ||
| # BSHD (only one segment per sequence): | ||
| # segment_ids = [1, 1, 1, 1, 1, 1, 1, 0, 0] | ||
| # segment_pos = [0, 1, 2, 3, 4, 5, 6, 7, 8] | ||
| # For an example of how to generate the segment_ids and segment_pos, | ||
| # see tests/jax/test_fused_attn.py `generate_random_segment_ids_and_pos()` and `generate_valid_segment_ids_and_pos()` | ||
| q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) | ||
|
|
||
| # Using defaults : segment pos has to be generated. | ||
| if segment_pos is None: | ||
| # THD + load balanced segment_ids are not supported in this function | ||
| # BSHD + load balanced segment_ids are incorrect as BSHD handles reordering within the primitive itself | ||
| if is_segment_ids_reordered: | ||
| assert not is_thd, ( | ||
| f"{segment_pos=} default arg is not supported for load balanced reordered" | ||
| " (Striped) THD inputs. Please pass the load balanced reordered segment_pos" | ||
| " and segment_ids explicitly to {from_segment_ids_and_pos.__qualname__}" | ||
| " using convenience function reorder_causal_load_balancing()" | ||
| ) | ||
| assert is_thd, ( | ||
| f"{segment_pos=} default arg is not supported for load balanced reordered (Dual" | ||
| " Chunk) BSHD inputs. BSHD segment_pos and segment_ids do not need to be load" | ||
| " balanced reordered. The reordering for these is performed within the" | ||
| " primitive" | ||
| ) | ||
|
|
||
| # Generate the default pos for THD and BSHD non-reordered segment_ids | ||
| def generate_default_pos(seg_ids): | ||
| if is_thd: | ||
| batch_size, seq_size = seg_ids.shape | ||
| # Assume that the first token belongs to a segment and is not a padded token | ||
| first_is_segment = jnp.full((batch_size, 1), True, dtype=bool) | ||
| # Get segment start positions | ||
| segment_start = jnp.concatenate( | ||
| [ | ||
| first_is_segment, | ||
| (seg_ids[..., 1:] != seg_ids[..., :-1]) & (seg_ids[..., 1:] != 0), | ||
| ], | ||
| axis=-1, | ||
| ) | ||
| # Get offset for location where new segment starts | ||
| segment_start_idx = jax.vmap(lambda row: jnp.arange(row.size) * row)( | ||
| segment_start | ||
| ) | ||
| segment_start_offsets = jax.vmap(jnp.maximum.accumulate)(segment_start_idx) | ||
|
|
||
| # Get the last non-zero index - after this everything is padding | ||
| # (B,) | ||
| last_nonzero_idx = jax.vmap( | ||
| lambda segids_row: jnp.max( | ||
| jnp.where(segids_row != 0, jnp.arange(seq_size), -1) | ||
| ) | ||
| )(seg_ids) | ||
| seg_pos_no_thd = jnp.arange(seq_size) | ||
| # Get a mask which can be used to zero out all the padding at the end (after the non-zero index) | ||
| mask = seg_pos_no_thd <= last_nonzero_idx[:, None] | ||
|
|
||
| # Get the unmasked seg_pos for the THD sequence | ||
| seg_pos = ( | ||
| jnp.broadcast_to(jnp.arange(seq_size), seg_ids.shape) | ||
| - segment_start_offsets | ||
| ) | ||
|
|
||
| # Use the mask to zero out the padding at the end (after the non-zero index) | ||
| segment_pos = jax.vmap( | ||
| lambda pos_row, mask_row: jnp.where(mask_row, pos_row, 0) | ||
| )(seg_pos, mask) | ||
| return segment_pos | ||
|
|
||
| seqlen = seg_ids.shape[-1] | ||
| return jnp.broadcast_to(jnp.arange(seqlen), seg_ids.shape) | ||
|
|
||
| q_seg_pos = generate_default_pos(q_seg_ids) | ||
| kv_seg_pos = generate_default_pos(kv_seg_ids) | ||
| segment_pos = (q_seg_pos, kv_seg_pos) | ||
| # Explicitly passed segment_pos | ||
| else: | ||
| segment_pos = cls._expand_to_pair(segment_pos) | ||
| q_seg_pos, kv_seg_pos = cls._expand_to_pair(segment_pos) | ||
|
|
||
| return cls( | ||
| segment_ids=(q_seg_ids, kv_seg_ids), | ||
| segment_pos=segment_pos, | ||
| segment_pos=(q_seg_pos, kv_seg_pos), | ||
| ) | ||
|
|
||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.