Skip to content

[JAX] Fix norm workspace on global shapes#3085

Draft
jberchtold-nvidia wants to merge 4 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/norm-workspace-size-fix
Draft

[JAX] Fix norm workspace on global shapes#3085
jberchtold-nvidia wants to merge 4 commits into
NVIDIA:mainfrom
jberchtold-nvidia:jberchtold/norm-workspace-size-fix

Conversation

@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator

Description

When calling TE/JAX's norm, the cuDNN norm workspace size is queried twice, once for the outer abstract on the global shapes, once on the inner abstract on the local shard's shape.

A bug was noticed where if the global shape exceeded INT32_MAX, cuDNN would raise an error when we queried the workspace size with the global shape, despite us never actually executing the cuDNN graph on the outer shape.

This PR removes this workspace querying in the outer primitive as it was unnecessary and removing it fixes the false-positive error reporting in these case described above.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Avoid querying workspace size of cuDNN norm on global shape in outer abstract

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Jun 4, 2026

Greptile Summary

This PR fixes a false-positive cuDNN error that occurred when a norm's global tensor shape exceeded INT32_MAX, by removing the workspace-size query from the outer abstract (which operates on global shapes). The fix collapses outer_abstract into abstract via an is_outer flag, returning early (without workspace) when called on the outer primitive.

  • NormFwdPrimitive.abstract and NormBwdPrimitive.abstract now accept is_outer; when True they return the 7/3 output avals immediately, skipping the cuDNN workspace query that triggered the error on large global shapes.
  • NormFwdPrimitive.impl introduces a Python syntax error (lines 332–338): an incomplete bind() call is prepended before the existing tuple-unpacking assignment, producing code that cannot be parsed. The module will fail to import.
  • sharded_impl correctly passes is_outer=False so the inner primitive always allocates 8/4 output buffers to match what the C++ FFI kernel actually writes.

Confidence Score: 1/5

Not safe to merge — the file has a Python syntax error that prevents the module from importing.

Lines 332–338 of NormFwdPrimitive.impl insert an incomplete inner_primitive.bind( call (ending with norm_type=norm_type,) directly before the existing tuple-unpacking statement. Python never sees a closing parenthesis for that first call, so the entire remainder of the function body — including a return statement — is parsed as arguments to it, resulting in a SyntaxError. The module cannot be imported, making every JAX norm path completely broken.

transformer_engine/jax/cpp_extensions/normalization.py — NormFwdPrimitive.impl (lines 332–365) must be fixed before this file can be imported.

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/normalization.py Refactors NormFwdPrimitive and NormBwdPrimitive to skip workspace allocation for the outer abstract when is_outer=True; introduces a Python SyntaxError in NormFwdPrimitive.impl (lines 332–338) that prevents the module from loading.

Reviews (2): Last reviewed commit: "Apply suggestions from code review" | Re-trigger Greptile

Comment thread transformer_engine/jax/cpp_extensions/normalization.py
Comment thread transformer_engine/jax/cpp_extensions/normalization.py
jberchtold-nvidia and others added 2 commits June 4, 2026 15:46
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft June 4, 2026 22:52
Comment thread transformer_engine/jax/cpp_extensions/normalization.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 jax

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant