[JAX] Fix norm workspace on global shapes #3085
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Greptile Summary
This PR fixes a false-positive cuDNN error that occurred when a norm's global tensor shape exceeded INT32_MAX, by removing the workspace-size query from the outer abstract (which operates on global shapes). The fix collapses
Confidence Score: 1/5
Not safe to merge: the file has a Python syntax error that prevents the module from importing.
Lines 332–338 of NormFwdPrimitive.impl insert an incomplete inner_primitive.bind( call (ending with norm_type=norm_type,) directly before the existing tuple-unpacking statement. Python never sees a closing parenthesis for that first call, so the entire remainder of the function body, including a return statement, is parsed as arguments to it, resulting in a SyntaxError. The module cannot be imported, which breaks every JAX norm path.
transformer_engine/jax/cpp_extensions/normalization.py: NormFwdPrimitive.impl (lines 332–365) must be fixed before this file can be imported.
Important Files Changed
Reviews (2): Last reviewed commit: "Apply suggestions from code review"
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
/te-ci L1 jax
Description
When calling TE/JAX's norm, the cuDNN norm workspace size is queried twice: once in the outer abstract on the global shapes, and once in the inner abstract on the local shard's shape.
A bug was noticed where, if the global shape exceeded INT32_MAX, cuDNN raised an error when we queried the workspace size with the global shape, even though we never actually execute the cuDNN graph on the global (outer) shape.
This PR removes the workspace query from the outer primitive. The query was unnecessary, and removing it fixes the false-positive error reporting in the case described above.
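To make the failure mode concrete, here is a minimal sketch in plain Python of how a global shape can exceed INT32_MAX while each local shard stays well below it. It assumes the limit applies to the total element count and that the tensor is sharded evenly along one axis; the description does not say which quantity cuDNN actually checks. The helper names (`num_elements`, `local_shard_shape`) and the example shape are hypothetical and are not part of TE/JAX.

```python
import math

INT32_MAX = 2**31 - 1  # 2,147,483,647


def num_elements(shape):
    """Total number of elements in a tensor of the given shape."""
    return math.prod(shape)


def local_shard_shape(global_shape, num_shards, axis=0):
    """Shape of one shard when `axis` is split evenly across `num_shards` devices.

    Hypothetical helper for illustration only; real sharding is decided by JAX.
    """
    dim = global_shape[axis]
    if dim % num_shards != 0:
        raise ValueError("axis size must be divisible by num_shards")
    shape = list(global_shape)
    shape[axis] = dim // num_shards
    return tuple(shape)


# Illustrative shape, chosen so that only the global view crosses the limit.
global_shape = (65536, 40960)
local_shape = local_shard_shape(global_shape, num_shards=8)

# The global element count is above INT32_MAX, so a query on the global shape
# could be rejected (under the assumption stated above).
print(num_elements(global_shape) > INT32_MAX)

# The per-shard element count is within range; this is the shape the cuDNN
# graph is actually executed on.
print(num_elements(local_shape) <= INT32_MAX)
```

Under these assumptions, only the inner (per-shard) query reflects a shape that is ever executed, which is why dropping the outer query removes the spurious error without changing behavior.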
Type of change
Changes
Checklist: