[JAX] Fix: Use jitted kernels for generating THD (and BSHD) segment pos#2823
Conversation
|
/te-ci L0 |
…only segment id is passed Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
… SequenceDescriptor mandatory Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…its. Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
98a4c72 to
a915db9
Compare
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Greptile SummaryThis PR simplifies Confidence Score: 5/5Safe to merge; the simplification is sound and the dtype concern from prior review has been resolved. All previously raised concerns (dtype mismatch, breaking-change label) have been addressed. The only remaining finding is a minor typo in the error message (P2). No logic or correctness issues were found. No files require special attention. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["from_segment_ids_and_pos(segment_ids, segment_pos)"] --> B{segment_pos is None?}
B -- Yes --> C["raise ValueError\n(segment_pos now required)"]
B -- No --> D["_expand_to_pair(segment_ids)\n→ q_seg_ids, kv_seg_ids"]
D --> E["_expand_to_pair(segment_pos)\n→ q_seg_pos, kv_seg_pos"]
E --> F["return SequenceDescriptor(\n segment_ids=(q, kv),\n segment_pos=(q, kv)\n)"]
Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| def from_segment_ids_and_pos( | ||
| cls, | ||
| segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], | ||
| segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, | ||
| *, | ||
| is_thd: bool, | ||
| is_segment_ids_reordered: bool, | ||
| segment_pos: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], | ||
| ) -> SequenceDescriptor: |
There was a problem hiding this comment.
Breaking API change not flagged in checklist
from_segment_ids_and_pos previously accepted segment_pos=None (optional) plus two keyword-only parameters is_thd: bool and is_segment_ids_reordered: bool. All three have been removed, making segment_pos a required positional argument. Any caller passing is_thd=True or is_segment_ids_reordered=False will receive a TypeError at runtime, and any caller relying on the segment_pos=None default path for BSHD or non-load-balanced THD will break silently (missing required argument). The PR checklist marks "Breaking change" as unchecked, which is incorrect for this function, even with its "Experimental" label.
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…nsistent with THD when setting up test inputs Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
| *, | ||
| is_thd: bool, | ||
| is_segment_ids_reordered: bool, | ||
| segment_pos: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], |
There was a problem hiding this comment.
If you have time, it may be nice to have the following:
def from_segment_ids_and_pos(
cls,
...,
segment_pos: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]] = None,
):
assert segment_pos is not None, "segment_pos is now required as the previous automatic generation of segment_pos did not have the proper context to generate correct a correct segment_pos to account for all load-balancing and context-parallelism strategies the user may be using when calling TransformerEngine's DotProductAttention."
…th a message Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
jberchtold-nvidia
left a comment
There was a problem hiding this comment.
LGTM, thanks!
|
CI passed for all jobs (L0, L1, L2) except a few which were related to CI node issues |
…os (#2823) * Fix: Use jitted kernels for generating THD (and BSHD) segment pos if only segment id is passed Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Make passing of segment_pos to from_segmet_ids_and_pos for creating a SequenceDescriptor mandatory Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Make test changes for from_segmet_ids_and_pos API change. Also some nits. Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * nit: Make segment_pos arg mandatory and not Optional Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Add comments for from_segment_ids_and_pos Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * nit: Change data types for BSHD seg pos and seg id to be int32 adn consistent with THD when setting up test inputs Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Replace a TypeError if segment_pos is not passed with a ValueError with a message Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
|
NOTE: The title of the PR is not rightly frames (which also reflects in the PR merge commit) |
…os (NVIDIA#2823) * Fix: Use jitted kernels for generating THD (and BSHD) segment pos if only segment id is passed Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Make passing of segment_pos to from_segmet_ids_and_pos for creating a SequenceDescriptor mandatory Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Make test changes for from_segmet_ids_and_pos API change. Also some nits. Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * nit: Make segment_pos arg mandatory and not Optional Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * Add comments for from_segment_ids_and_pos Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * nit: Change data types for BSHD seg pos and seg id to be int32 adn consistent with THD when setting up test inputs Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Replace a TypeError if segment_pos is not passed with a ValueError with a message Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Description
This is a follow up to PR #2523.
PR #2523 made an API change to the user facing convenience method to create a SequenceDescriptor:
from_segment_ids_and_pos.However, what we've observed is that:
As there are multiple combinations: thd/bshd, cp/reordering, with/without segment_pos which can be difficult to support in an easy manner and because users usually tend to generate their own segment_pos, there isn't much value in continuing to support the ability to generate
segment_posfrom withinfrom_segment_ids_and_posand the complexities associated with it.Type of change
Changes
This PR does two things:
segment_posmandatory (in addition to thesegment_id) to generate theSequenceDescriptor`Comments point to test code where
Checklist: