Skip to content

[JAX] Fix: Use jitted kernels for generating THD (and BSHD) segment pos#2823

Merged
KshitijLakhani merged 10 commits into
NVIDIA:mainfrom
KshitijLakhani:klakhani/fix/perf-regression-from_segment_ids_and_pos
Apr 3, 2026
Merged

[JAX] Fix: Use jitted kernels for generating THD (and BSHD) segment pos#2823
KshitijLakhani merged 10 commits into
NVIDIA:mainfrom
KshitijLakhani:klakhani/fix/perf-regression-from_segment_ids_and_pos

Conversation

@KshitijLakhani

@KshitijLakhani KshitijLakhani commented Apr 1, 2026

Copy link
Copy Markdown
Collaborator

Description

This is a follow up to PR #2523.
PR #2523 made an API change to the user facing convenience method to create a SequenceDescriptor: from_segment_ids_and_pos.
However, what we've observed is that:

As there are multiple combinations: thd/bshd, cp/reordering, with/without segment_pos which can be difficult to support in an easy manner and because users usually tend to generate their own segment_pos, there isn't much value in continuing to support the ability to generate segment_pos from within from_segment_ids_and_posand the complexities associated with it.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This PR does two things:

Comments point to test code where

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani self-assigned this Apr 1, 2026
@KshitijLakhani

Copy link
Copy Markdown
Collaborator Author

/te-ci L0

…only segment id is passed

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
… SequenceDescriptor mandatory

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…its.

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/perf-regression-from_segment_ids_and_pos branch from 98a4c72 to a915db9 Compare April 2, 2026 21:34
pre-commit-ci Bot and others added 2 commits April 2, 2026 21:35
@KshitijLakhani KshitijLakhani marked this pull request as ready for review April 2, 2026 22:03
@KshitijLakhani KshitijLakhani removed the performance Performance issues label Apr 2, 2026
@greptile-apps

greptile-apps Bot commented Apr 2, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR simplifies SequenceDescriptor.from_segment_ids_and_pos by removing the auto-generation of segment_pos and the is_thd/is_segment_ids_reordered keyword arguments, making segment_pos a required input. The test helpers are updated to always produce explicit segment_pos (now consistently int32) for both BSHD and THD paths, and the conditional CP/reorder logic that gated None vs explicit segment_pos is removed.

Confidence Score: 5/5

Safe to merge; the simplification is sound and the dtype concern from prior review has been resolved.

All previously raised concerns (dtype mismatch, breaking-change label) have been addressed. The only remaining finding is a minor typo in the error message (P2). No logic or correctness issues were found.

No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/jax/attention.py Removes complex auto-generation of segment_pos and the is_thd/is_segment_ids_reordered keyword args; now raises ValueError when segment_pos is None (with a TODO to make it a required arg by June 2026). Minor typo in the error message.
tests/jax/test_fused_attn.py Renames gen_valid → generate_valid_segment_ids_and_pos and generate_random_segment_ids → generate_random_segment_ids_and_pos; explicitly passes segment_pos (now int32) for both BSHD and THD paths; removes the conditional CP/reorder logic that selected None vs explicit segment_pos.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["from_segment_ids_and_pos(segment_ids, segment_pos)"] --> B{segment_pos is None?}
    B -- Yes --> C["raise ValueError\n(segment_pos now required)"]
    B -- No --> D["_expand_to_pair(segment_ids)\n→ q_seg_ids, kv_seg_ids"]
    D --> E["_expand_to_pair(segment_pos)\n→ q_seg_pos, kv_seg_pos"]
    E --> F["return SequenceDescriptor(\n  segment_ids=(q, kv),\n  segment_pos=(q, kv)\n)"]
Loading

Reviews (3): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines 854 to 858
def from_segment_ids_and_pos(
cls,
segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None,
*,
is_thd: bool,
is_segment_ids_reordered: bool,
segment_pos: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],
) -> SequenceDescriptor:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Breaking API change not flagged in checklist

from_segment_ids_and_pos previously accepted segment_pos=None (optional) plus two keyword-only parameters is_thd: bool and is_segment_ids_reordered: bool. All three have been removed, making segment_pos a required positional argument. Any caller passing is_thd=True or is_segment_ids_reordered=False will receive a TypeError at runtime, and any caller relying on the segment_pos=None default path for BSHD or non-load-balanced THD will break silently (missing required argument). The PR checklist marks "Breaking change" as unchecked, which is incorrect for this function, even with its "Experimental" label.

Comment thread tests/jax/test_fused_attn.py
KshitijLakhani and others added 3 commits April 2, 2026 16:06
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…nsistent with THD when setting up test inputs

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Comment thread tests/jax/test_fused_attn.py
Comment thread transformer_engine/jax/attention.py Outdated
*,
is_thd: bool,
is_segment_ids_reordered: bool,
segment_pos: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]],

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have time, it may be nice to have the following:

def from_segment_ids_and_pos(
    cls,
    ...,
    segment_pos: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]] = None,
):
    assert segment_pos is not None, "segment_pos is now required as the previous automatic generation of segment_pos did not have the proper context to generate correct a correct segment_pos to account for all load-balancing and context-parallelism strategies the user may be using when calling TransformerEngine's DotProductAttention."

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in 86038f6

KshitijLakhani and others added 2 commits April 2, 2026 16:46

@jberchtold-nvidia jberchtold-nvidia left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@KshitijLakhani

Copy link
Copy Markdown
Collaborator Author

CI passed for all jobs (L0, L1, L2) except a few which were related to CI node issues
Safe to merge

@KshitijLakhani KshitijLakhani merged commit 9d77dcb into NVIDIA:main Apr 3, 2026
2 checks passed
@KshitijLakhani KshitijLakhani deleted the klakhani/fix/perf-regression-from_segment_ids_and_pos branch April 3, 2026 06:07
KshitijLakhani added a commit that referenced this pull request Apr 3, 2026
…os (#2823)

* Fix: Use jitted kernels for generating THD (and BSHD) segment pos if only segment id is passed

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Make passing of segment_pos to from_segmet_ids_and_pos for creating a SequenceDescriptor mandatory

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Make test changes for from_segmet_ids_and_pos API change. Also some nits.

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* nit: Make segment_pos arg mandatory and not Optional

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Add comments for from_segment_ids_and_pos

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* nit: Change data types for BSHD seg pos and seg id to be int32 adn consistent with THD when setting up test inputs

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Replace a TypeError if segment_pos is not passed with a ValueError with a message

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@KshitijLakhani

KshitijLakhani commented Apr 3, 2026

Copy link
Copy Markdown
Collaborator Author

NOTE: The title of the PR is not rightly frames (which also reflects in the PR merge commit)
"[JAX] Fix: Use jitted kernels for generating THD (and BSHD) segment pos" was the initial scope, but should have been revised to "[JAX] Fix: Force user to pass segment_pos in from_segment_ids_and_pos"
Latter is more appropriate as that was the final scope of the PR

faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
…os (NVIDIA#2823)

* Fix: Use jitted kernels for generating THD (and BSHD) segment pos if only segment id is passed

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Make passing of segment_pos to from_segmet_ids_and_pos for creating a SequenceDescriptor mandatory

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Make test changes for from_segmet_ids_and_pos API change. Also some nits.

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* nit: Make segment_pos arg mandatory and not Optional

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Add comments for from_segment_ids_and_pos

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* nit: Change data types for BSHD seg pos and seg id to be int32 adn consistent with THD when setting up test inputs

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Replace a TypeError if segment_pos is not passed with a ValueError with a message

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants