Skip to content

Newton-Schulz via cuSOLVERMp#2706

Merged
vcherepanov-nv merged 64 commits into
NVIDIA:mainfrom
vcherepanov-nv:newton-schulz
Apr 15, 2026
Merged

Newton-Schulz via cuSOLVERMp#2706
vcherepanov-nv merged 64 commits into
NVIDIA:mainfrom
vcherepanov-nv:newton-schulz

Conversation

@vcherepanov-nv
Copy link
Copy Markdown
Collaborator

Description

Adds an API to call Newton-Schulz method on a distributed tensor.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Integrate cuSOLVERMp as a new dependency
  • Add corresponding API to TE/common
  • Add PyTorch binding and tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Feb 25, 2026

Greptile Summary

This PR adds a distributed Newton-Schulz matrix orthogonalization API backed by cuSolverMp, with RAII-managed C++ context lifecycle, a Python CusolverMpCtx class intended to be reused across calls, and a comprehensive distributed test suite using torchrun. The implementation correctly handles the row-major ↔ column-major impedance mismatch between PyTorch tensors and cuSolverMp by transposing the matrix descriptor dimensions, and prior review threads have driven meaningful fixes (RAII handles, dtype/contiguity guards, timeout on subprocess, reference check mode).

Confidence Score: 4/5

Safe to merge after addressing a handful of previously-identified items that remain open (option() placement, num_coefficients dead parameter, devInfo nullptr). The one new finding (Literal type annotation) is minor and non-blocking at runtime.

All P0 concerns from prior threads have been resolved (RAII handles, contiguity/dtype guards, test timeout, row-major transpose trick correctly applied). Several P1/P2 items flagged in prior rounds remain open per the thread (option() placement, num_coefficients unused, devInfo nullptr convergence diagnostics, debug print in build_tools/utils.py), which keeps the score at 4 rather than 5.

transformer_engine/pytorch/newton_schulz.py (Literal annotation), transformer_engine/common/newton_schulz/newton_schulz.cpp (devInfo nullptr, unused num_coefficients — both from prior threads still open)

Important Files Changed

Filename Overview
transformer_engine/pytorch/newton_schulz.py Python binding for distributed Newton-Schulz; adds CusolverMpCtx lifecycle management and coefficient scheduling. Invalid Literal[dict.keys()] annotation on line 63 will fail static type checkers.
transformer_engine/common/newton_schulz/newton_schulz.cpp Core C++ implementation wrapping cuSolverMp; uses RAII for all CUDA/cuSolverMp handles, correctly applies row-major↔column-major transpose trick, and caches device workspace across calls.
transformer_engine/common/include/transformer_engine/newton_schulz.h Public C header exposing cuSolverMp context and Newton-Schulz APIs; unconditionally includes nccl.h (previously discussed trade-off).
transformer_engine/common/CMakeLists.txt Adds cuSolverMp as an optional dependency; option() declared after first use of the newton_schulz source (previously flagged). NCCL_LIB find_library duplicated between cuBLASMp and cuSolverMp blocks but CMake caches the result so it is harmless.
tests/pytorch/distributed/test_newton_schulz.py Pytest harness launching torchrun subprocesses; correctly includes timeout=300 and checks for explicit PASSED/FAILED markers in stdout/stderr.
tests/pytorch/distributed/run_newton_schulz.py Distributed test worker; now correctly validates orthogonality (X @ X.t() ≈ I) and includes a reference check mode. Uses column-scattering with an assert guard for divisibility.
transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp Thin PyTorch ↔ TE tensor bridge for newton_schulz; correctly builds TensorWrapper and threads through the caller CUDA stream.
setup.py Wires NVTE_WITH_CUSOLVERMP env-var and CUSOLVERMP_HOME into CMake flags; straightforward and consistent with cuBLASMp handling.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant CusolverMpCtx
    participant newton_schulz_py as newton_schulz (Python)
    participant tex as transformer_engine_torch (C++ ext)
    participant cpp as nvte_newton_schulz (C++)
    participant cuSolverMp

    Caller->>CusolverMpCtx: CusolverMpCtx(group)
    CusolverMpCtx->>tex: cusolvermp_ctx_create(nccl_comm_ptr, nranks, rank)
    tex->>cpp: nvte_cusolvermp_ctx_create(comm, nranks, rank)
    cpp->>cuSolverMp: cusolverMpCreate + cusolverMpCreateDeviceGrid
    cpp-->>tex: NVTECusolverMpCtx*
    tex-->>CusolverMpCtx: ctx_ptr (int64)

    Caller->>newton_schulz_py: newton_schulz(x_local, ctx, num_iterations, coefficients)
    newton_schulz_py->>newton_schulz_py: validate dtype, contiguity, dims
    newton_schulz_py->>tex: newton_schulz(ctx_ptr, m, n, x, iters, coeffs)
    tex->>cpp: nvte_newton_schulz(ctx, m, n, x_tensor, iters, coeffs, caller_stream)
    cpp->>cpp: record in_ready event on caller_stream
    cpp->>cpp: wait on internal stream
    cpp->>cuSolverMp: cusolverMpNewtonSchulz_bufferSize
    cpp->>cpp: allocate/grow workspace (ncclMemAlloc or cudaMalloc)
    cpp->>cuSolverMp: cusolverMpNewtonSchulz (in-place on x^T)
    cpp->>cpp: record out_ready event, caller stream waits
    cpp-->>newton_schulz_py: return
    newton_schulz_py-->>Caller: x_local modified in-place

    Caller->>CusolverMpCtx: ctx.destroy()
    CusolverMpCtx->>tex: cusolvermp_ctx_destroy(ctx_ptr)
    tex->>cpp: nvte_cusolvermp_ctx_destroy(ctx)
    cpp->>cuSolverMp: free workspace + destroy grid + destroy handle
Loading

Reviews (34): Last reviewed commit: "Couple num_iterations with coeff types i..." | Re-trigger Greptile

Copy link
Copy Markdown
Contributor

@greptile-apps greptile-apps Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

15 files reviewed, 15 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +93 to +98
# Check: if X = A^{-1/2}, then X @ A @ X should be the identity matrix
if rank == 0:
XXT = X @ X.t()
I = torch.eye(N, device=XXT.device, dtype=XXT.dtype)
max_diff = (XXT - I).abs().max().item()
print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

verification doesn't match the comment - if X = A^{-1/2}, the check should be X @ A @ X ≈ I, not X @ X.t() ≈ I. The current check verifies X is orthogonal, not that X is the inverse square root of A. Note that A_orig is created on line 76 but never used.

Suggested change
# Check: if X = A^{-1/2}, then X @ A @ X should be the identity matrix
if rank == 0:
XXT = X @ X.t()
I = torch.eye(N, device=XXT.device, dtype=XXT.dtype)
max_diff = (XXT - I).abs().max().item()
print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True)
# Check: if X = A^{-1/2}, then X @ A @ X should be the identity matrix
XAX = X @ A_orig @ X
I = torch.eye(N, device=XAX.device, dtype=XAX.dtype)
max_diff = (XAX - I).abs().max().item()
print(f"Max |X @ A @ X - I|: {max_diff:.6e}", flush=True)
if torch.allclose(XAX, I, atol=args.atol, rtol=args.rtol):

Comment on lines +31 to +32
nccl_backend = group._get_backend(torch.device("cuda"))
return nccl_backend._comm_ptr()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uses private PyTorch APIs (_get_backend, _comm_ptr) that may change in future versions

Comment on lines +39 to +58
quintic_coefficients = [
4.0848,
-6.8946,
2.9270,
3.9505,
-6.3029,
2.6377,
3.7418,
-5.5913,
2.3037,
2.8769,
-3.1427,
1.2046,
2.8366,
-3.0525,
1.2012,
]
coefficients = (
quintic_coefficients if args.num_iterations == 5 else [1.5, -0.5, 0.0] * args.num_iterations
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

coefficients mismatch with API defaults - test uses 15 coefficients for 5 iterations, but newton_schulz.py defaults to 5 coefficients. This inconsistency means default API behavior isn't tested.

Comment on lines +8 to +12
* \brief Functions for distributed Newton-Schulz inverse square root.
*
* This API is a TE-native binding to the cuSolverMp library.
* It computes an iterative Newton-Schulz inverse square root
* approximation on a distributed matrix.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Documentation claims this computes "inverse square root" but the test validates orthogonality (X @ X.t() ≈ I), and commit dd1dd0b states "it approximates orthogonal matrix, not inverse square root". If this computes the polar decomposition (orthogonal factor), the documentation should be updated to reflect that. Inverse square root would satisfy X @ A @ X ≈ I, which is different from orthogonality.

Comment on lines +79 to +80
m = x.size(0) * nranks # rows are distributed across ranks
n = x.size(1)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assumes rows are evenly distributed (m = x.size(0) * nranks) but doesn't validate this. If matrix size isn't divisible by nranks, the computed global size m will be incorrect, leading to wrong results from cuSOLVERMp. Consider adding validation:

Suggested change
m = x.size(0) * nranks # rows are distributed across ranks
n = x.size(1)
# Global matrix dimensions
# Rows must be evenly distributed across ranks
local_rows = x.size(0)
m = local_rows * nranks
n = x.size(1)

Then add a validation check that all ranks have the same local_rows via dist.all_reduce.

num_iterations: int = 5,
coefficients: Optional[List[float]] = None,
) -> None:
"""Compute Newton-Schulz inverse square root in-place on a distributed matrix.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Docstring says "inverse square root" but test checks orthogonality. Update to match actual behavior (see comment on header file).

Comment thread transformer_engine/common/newton_schulz/newton_schulz.cpp
Comment thread transformer_engine/pytorch/__init__.py Outdated
from transformer_engine.pytorch import optimizers
from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy
from transformer_engine.pytorch.newton_schulz import newton_schulz
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unconditional import of optional feature

newton_schulz is unconditionally imported and exported as part of the public API, even when TE is built without NVTE_WITH_CUSOLVERMP. While the function itself raises a runtime error when called, this exposes the symbol to all users and makes it appear as a supported feature in auto-complete and docs. Consider guarding this import behind a check (similar to how other optional features are handled), or at minimum adding a note in the docstring that the function requires NVTE_WITH_CUSOLVERMP=1 at build time.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment thread transformer_engine/common/newton_schulz/newton_schulz.cpp
Comment on lines +82 to +86
ctx_ptr = tex.cusolvermp_ctx_create(nccl_comm_ptr, nranks, rank)
try:
tex.newton_schulz(ctx_ptr, m, n, x, num_iterations, coefficients)
finally:
tex.cusolvermp_ctx_destroy(ctx_ptr)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Context created/destroyed per call wastes resources

A new NVTECusolverMpCtx is created and destroyed on every invocation of newton_schulz. Context creation involves cudaStreamCreate, two cudaEventCreate calls, cusolverMpCreate, and cusolverMpCreateDeviceGrid — all of which are heavyweight operations. And since the context is destroyed afterward, the grow-only workspace caching in the C++ layer (lines 170-177 of newton_schulz.cpp) is never actually reused.

Consider caching the context (e.g., in a module-level dict keyed by (nccl_comm_ptr, nranks, rank)) and reusing it across calls, or exposing the context lifecycle to callers so they can amortize the cost when calling newton_schulz repeatedly in a training loop.

Comment on lines +65 to +67
assert (
len(coefficients) == num_iterations * 3
), f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use ValueError instead of assert for validation - assert can be disabled with Python's -O flag

Suggested change
assert (
len(coefficients) == num_iterations * 3
), f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations"
if len(coefficients) != num_iterations * 3:
raise ValueError(
f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations"
)

Comment thread transformer_engine/pytorch/newton_schulz.py
Comment thread transformer_engine/pytorch/newton_schulz.py
Comment thread tests/pytorch/distributed/test_newton_schulz.py Outdated
PATHS ${CUSOLVERMP_DIR}
PATH_SUFFIXES lib
REQUIRED)
target_link_libraries(transformer_engine PUBLIC ${CUSOLVERMP_LIB})
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PUBLIC linkage exposes cuSOLVERMp to all downstream consumers of transformer_engine library. Since newton_schulz.h doesn't expose cuSOLVERMp types in the public API, PRIVATE linkage would provide better encapsulation (consumers don't need cuSOLVERMp at link time).

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment thread transformer_engine/common/newton_schulz/newton_schulz.cpp Outdated
Comment thread transformer_engine/common/newton_schulz/newton_schulz.cpp
Comment thread transformer_engine/common/newton_schulz/newton_schulz.cpp
Comment thread build_tools/utils.py Outdated
f"'nm' failed on {lib_path} (exit code {e.returncode}):\n{e.stderr}"
) from e

return symbol in result.stdout
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Substring match can produce false positives

symbol in result.stdout does a plain substring search over the entire nm output. If the library ever contains a symbol that has the target symbol as a prefix (e.g. nvte_cusolvermp_ctx_create_with_options), this check will incorrectly return True and enable cuSolverMp support in the PyTorch extension even though the real symbol is absent.

The nm -D output format is <address> <type> <name>\n per line. A safer check is to match against word boundaries:

import re
return bool(re.search(r'\b' + re.escape(symbol) + r'\b', result.stdout))

or equivalently test for whitespace/line boundaries:

return any(line.split()[-1] == symbol for line in result.stdout.splitlines() if line.strip())

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Mar 4, 2026

Additional Comments (1)

transformer_engine/common/CMakeLists.txt
option() declared after first use

The option(NVTE_WITH_CUSOLVERMP ...) declaration appears ~12 lines later in the file (at the target_link_libraries block), but this if (NVTE_WITH_CUSOLVERMP) guard relies on it. While this works when -DNVTE_WITH_CUSOLVERMP=ON is passed on the command line (the cache entry is set before CMake processes any option()), a developer who adds a second reference to NVTE_WITH_CUSOLVERMP above the option() call — or who processes the file without the cache — will get confusing behaviour.

CMake convention is to declare option() before first use. Moving the option() declaration to just above the first if (NVTE_WITH_CUSOLVERMP) block (line 281) would improve robustness and readability.

Comment thread build_tools/utils.py Outdated
Comment thread transformer_engine/common/CMakeLists.txt Outdated
#ifndef TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_
#define TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_

#include <nccl.h>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unconditional #include <nccl.h> in a public header.

newton_schulz.h is installed as a public header (under include/transformer_engine/). The unconditional #include <nccl.h> means that any downstream project that includes this header — even one with no interest in Newton-Schulz — now requires NCCL in its include path.

ncclComm_t is only used in the function signatures of nvte_cusolvermp_ctx_create and nvte_newton_schulz, which are themselves only meaningful when NVTE_WITH_CUSOLVERMP is defined. Guarding the include and the declarations together would prevent the leakage:

Suggested change
#include <nccl.h>
#ifdef NVTE_WITH_CUSOLVERMP
#include <nccl.h>
// ... struct and function declarations ...
#endif // NVTE_WITH_CUSOLVERMP

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also a little uneasy at exposing NCCL as a required dependency, but I see we already import the NCCL header elsewhere in the codebase:

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use NCCL types (ncclComm_t) in nvte_ctx_create API, so if we're taking the route of always defining the functions - we need to pull NCCL headers unconditionally.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Mar 4, 2026

Additional Comments (3)

build_tools/utils.py, line 95
FileNotFoundError propagates to the build system

When libtransformer_engine.so cannot be located, common_lib_has_symbol raises FileNotFoundError. The call-site in pytorch.py is a bare if common_lib_has_symbol(...) with no try/except, so any missing-library condition aborts the entire PyTorch extension build rather than gracefully treating it as "symbol absent → no cuSOLVERMp".

This scenario is realistic: the common library might be built in a non-standard path, or a developer may be doing an incremental build of only the PyTorch extension. The function should return False (with an optional warning) instead of raising:

    if lib_path is None:
        import warnings
        warnings.warn(
            "Could not find libtransformer_engine.so; cuSOLVERMp auto-detection disabled.",
            stacklevel=2,
        )
        return False

The subprocess.CalledProcessError path (nm failure) has the same problem and should similarly return False or at most warn.


tests/pytorch/distributed/test_newton_schulz.py, line 252
Missing skip guard for builds without cuSOLVERMp

The file correctly skips when fewer than 2 GPUs are present, but there is no skip when TE was built without NVTE_WITH_CUSOLVERMP=1. In that case, newton_schulz raises RuntimeError("newton_schulz requires Transformer Engine to be built with NVTE_WITH_CUSOLVERMP=1"), which will appear in CI as a test failure rather than a skip.

A simple guard at module level (or inside the test) would mirror the GPU-count check:

import transformer_engine_torch as tex
if not hasattr(tex, "newton_schulz"):
    pytest.skip(
        "Newton-Schulz tests require TE built with NVTE_WITH_CUSOLVERMP=1.",
        allow_module_level=True,
    )

transformer_engine/pytorch/newton_schulz.py, line 813
Undocumented fallback coefficients for num_iterations != 5

When coefficients is None and num_iterations != 5, the function silently falls back to [1.5, -0.5, 0.0] * num_iterations. These are generic cubic Newton-Schulz coefficients and are likely far less accurate than the carefully tuned QUINTIC_COEFFICIENTS. Nothing in the public docstring tells callers about this behaviour.

If 15-iteration runs use the test's custom coefficients (which are different from both defaults), the fallback path is also never exercised by the current tests with the default API — making the generic path untested in practice.

Consider either:

  • Documenting that "optimised coefficients are only bundled for num_iterations=5; for other values you must supply coefficients explicitly", and turning the fallback into a ValueError, or
  • Bundling a second set of optimised coefficients for the 15-iteration case that matches what the test actually uses.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Mar 4, 2026

Additional Comments (12)

qa/L1_pytorch_distributed_unittest/test.sh, line 35
Newton-Schulz test runs unconditionally in QA

The test is added without any guard for NVTE_WITH_CUSOLVERMP. If the CI runner is not built with cuSolverMp, this subprocess will fail with a RuntimeError, causing the entire L1 QA job to fail. Other optional features are guarded in their test paths.

Add a conditional:

if [ "${NVTE_WITH_CUSOLVERMP:-0}" == "1" ]; then
    python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py"
fi

tests/pytorch/distributed/test_newton_schulz.py, line 15
Missing skip guard for cuSolverMp availability

The file skips when fewer than 2 GPUs are available, but not when Transformer Engine is built without NVTE_WITH_CUSOLVERMP=1. On such builds, every test fails inside the subprocess with RuntimeError, producing confusing stderr messages.

Add a module-level skip check:

import transformer_engine_torch as tex
if not hasattr(tex, "newton_schulz"):
    pytest.skip(
        "TE not built with NVTE_WITH_CUSOLVERMP=1; skipping Newton-Schulz tests.",
        allow_module_level=True,
    )

build_tools/pytorch.py, line 95
Silent default for CUSOLVERMP_HOME inconsistent with NVSHMEM pattern

The cuSolverMp block silently defaults to "/usr" when CUSOLVERMP_HOME is unset. The NVSHMEM block asserts that NVSHMEM_HOME is explicitly set, providing a clear error message. If the library is not installed under /usr/include and /usr/lib, the build fails with a generic linker error rather than a clear message about the missing environment variable.

Align with the NVSHMEM pattern:

if bool(int(os.getenv("NVTE_WITH_CUSOLVERMP", "0"))):
    assert (
        os.getenv("CUSOLVERMP_HOME") is not None
    ), "CUSOLVERMP_HOME must be set when compiling with NVTE_WITH_CUSOLVERMP=1"
    cusolvermp_home = Path(os.getenv("CUSOLVERMP_HOME"))
    include_dirs.append(cusolvermp_home / "include")
    library_dirs.append(cusolvermp_home / "lib")
    libraries.append("cusolverMp")
    cxx_flags.append("-DNVTE_WITH_CUSOLVERMP")

transformer_engine/pytorch/newton_schulz.py, line 21
Uses private PyTorch APIs that may change

Lines 20-21 use _get_backend() and _comm_ptr(), which are private PyTorch APIs (underscore prefix indicates internal/unstable). These can change in future PyTorch versions, breaking this code.

Consider using public APIs or documenting this dependency clearly in comments, noting that this code may need updates with new PyTorch releases.


transformer_engine/pytorch/newton_schulz.py, line 67
Use ValueError instead of assert for validation

The assertion on line 65 validates user input. Assertions can be disabled with Python's -O flag, silently allowing invalid inputs. Use ValueError for user-facing validation:

if len(coefficients) != num_iterations * 3:
    raise ValueError(
        f"Unexpected number of coefficients: {len(coefficients)} for {num_iterations} iterations"
    )

transformer_engine/pytorch/newton_schulz.py, line 72
Missing tensor memory layout validation

The C++ code calls data_ptr() which requires contiguous memory. Non-contiguous tensors will cause silent incorrect results. Add a contiguity check before validation:

if not x.is_contiguous():
    raise ValueError("Input tensor must be contiguous (C-order)")

Also add dtype validation since the docstring specifies float32 or bfloat16:

if x.dtype not in (torch.float32, torch.bfloat16):
    raise ValueError(f"Input tensor must be float32 or bfloat16, got {x.dtype}")

tests/pytorch/distributed/test_newton_schulz.py, line 37
Missing subprocess timeout

The distributed test subprocess has no timeout. If it deadlocks (e.g., NCCL communication issue), the test will block indefinitely, hanging the CI job. Add a timeout:

result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False, timeout=300)

transformer_engine/common/CMakeLists.txt, line 317
PUBLIC linkage exposes cuSOLVERMp to all downstream consumers

cuSOLVERMp is linked with PUBLIC visibility, meaning all projects depending on Transformer Engine must have it in their link path, even those that don't use Newton-Schulz. Since the public API doesn't expose cuSOLVERMp types, PRIVATE linkage would provide better encapsulation:

target_link_libraries(transformer_engine PRIVATE ${CUSOLVERMP_LIB})

transformer_engine/common/CMakeLists.txt, line 230
CMake option declared after first use

The option(NVTE_WITH_CUSOLVERMP) is declared at line 308, but used at line 230. Every other optional feature declares the option immediately before using it. While this works when -DNVTE_WITH_CUSOLVERMP=ON is passed on the command line, it breaks the established pattern and could surprise developers adding follow-on logic.

Move the option() declaration to just before line 230.


transformer_engine/common/include/transformer_engine/newton_schulz.h, line 17
Unconditional #include <nccl.h> in public header

newton_schulz.h is installed as a public header. The unconditional #include <nccl.h> forces NCCL into every downstream project's include path, even those with no interest in Newton-Schulz. Since ncclComm_t is only used in functions meaningful when NVTE_WITH_CUSOLVERMP is defined, guarding both the include and declarations would prevent this leakage:

#ifdef NVTE_WITH_CUSOLVERMP
#include <nccl.h>
...
NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank);
void nvte_newton_schulz(...);
#endif

transformer_engine/pytorch/__init__.py, line 62
Unconditional import of optional feature

newton_schulz is unconditionally imported and exported as public API, even when TE is built without NVTE_WITH_CUSOLVERMP. While the function raises a runtime error when called, this exposes the symbol to all users and makes it appear as a supported feature in auto-complete and documentation.

Guard this import similarly to other optional features, or add a note in the docstring that NVTE_WITH_CUSOLVERMP=1 at build time is required.


transformer_engine/common/newton_schulz/newton_schulz.cpp, line 119
Resource leak on exception during context creation

Raw CUDA handles (stream, in_ready, out_ready) are created at lines 99, 102, 104 but not moved into the struct until lines 112-114. If MakeCusolverMpHandle (line 106) or MakeCusolverMpGrid (line 107) throw (both contain NVTE_CHECK_CUSOLVERMP which can throw), these three handles are never cleaned up.

Wrap each handle in an RAII type (similar to CusolverMpHandle and CusolverMpGrid) to ensure safe cleanup on any early-exit path.

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Mar 4, 2026

Additional Comments (4)

build_tools/pytorch.py, line 95
Missing assertion for CUSOLVERMP_HOME

When NVTE_WITH_CUSOLVERMP=1 is set but CUSOLVERMP_HOME is not, this silently falls back to /usr and will produce a confusing linker error at build time rather than a clear configuration message.

Compare to the NVSHMEM_HOME handling directly above (lines 81-83), which explicitly asserts the variable must be set. Add the same guard here:

if bool(int(os.getenv("NVTE_WITH_CUSOLVERMP", "0"))):
    assert (
        os.getenv("CUSOLVERMP_HOME") is not None
    ), "CUSOLVERMP_HOME must be set when compiling with NVTE_WITH_CUSOLVERMP=1"
    cusolvermp_home = Path(os.getenv("CUSOLVERMP_HOME"))
    ...

qa/L1_pytorch_distributed_unittest/test.sh, line 35
Test runs unconditionally regardless of build config

This test is added unconditionally to the QA script, so it will always execute even when TE is built without NVTE_WITH_CUSOLVERMP=1. The subprocess will fail with a runtime error about missing cuSolverMp support, breaking the CI job.

Add a build-flag guard matching the build configuration:

if [ "${NVTE_WITH_CUSOLVERMP:-0}" = "1" ]; then
    python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py"
fi

tests/pytorch/distributed/test_newton_schulz.py, line 15
No guard for missing cuSolverMp build support

The test only skips when fewer than 2 GPUs are available, but does not check whether TE was built with NVTE_WITH_CUSOLVERMP=1. On a system with 2+ GPUs but a TE build without cuSolverMp, the torchrun subprocess will fail with a runtime error, causing the test to report AssertionError with confusing output rather than a clean skip.

Add an early skip guard:

import transformer_engine_torch as tex

if not hasattr(tex, "newton_schulz"):
    pytest.skip("Newton-Schulz tests require TE built with NVTE_WITH_CUSOLVERMP=1.", allow_module_level=True)

transformer_engine/pytorch/newton_schulz.py, line 64
Fallback coefficients for num_iterations != 5 are undocumented and degrade polynomial degree

When num_iterations != 5 and no custom coefficients are supplied, the fallback is [1.5, -0.5, 0.0] * num_iterations. The trailing 0.0 silently degenerates the quintic polynomial to a cubic one (a·X + b·X³ + 0·X⁵). This means users calling with, e.g., num_iterations=10 will unknowingly use different convergence behavior than the optimized 5-iteration case.

Consider either:

  1. Raising a ValueError when num_iterations != 5 and no coefficients are provided, forcing users to supply their own, or
  2. Documenting clearly in the docstring that only 5-iteration defaults are optimised and all other counts fall back to generic cubic steps

@vcherepanov-nv vcherepanov-nv changed the title [Draft] Newton-Schulz via cuSOLVERMp Newton-Schulz via cuSOLVERMp Mar 9, 2026
@cyanguwa cyanguwa requested review from cyanguwa and timmoon10 April 1, 2026 00:11
@cyanguwa cyanguwa added the 2.15.0 label Apr 1, 2026
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we agreed on putting this file in transformer_engine/pytorch/optimizers?

Also, we want to fully replicate the functionality of newton_schulz_tp, so may need to take a look at supporting the partition_dim and mode parameters? Looking at Emerging-Optimizers' implementation, I think it's mostly wrapper code that we need to add in Python.

Copy link
Copy Markdown
Member

@timmoon10 timmoon10 Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I'd suggest something liketransformer_engine/pytorch/optimizers/muon.

Copy link
Copy Markdown
Member

@timmoon10 timmoon10 Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this depends on how we're planning on exposing this. If TE owns an implementation of Muon, then I agree we can treat this as internal utilities. If the Muon implementation lives in Megatron-LM or somewhere else, then we should treat this as a general API in something like transformer_engine/pytorch/cusolvermp.

dist.all_gather(gathered, x_local)
X = torch.cat(gathered, dim=0)

# Check: the resulting matrix should be orthogonal
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking whether the result is orthogonal might not be enough - I think we said we'd compare the actual values of the matrix with the result of a non-distributed version?

@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
@pytest.mark.parametrize("matrix_size", [256])
@pytest.mark.parametrize("num_iterations", [5, 15])
def test_newton_schulz(dtype, matrix_size, num_iterations):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we expand the tests a bit to cover more sizes, shapes, etc?

https://github.com/NVIDIA-NeMo/Emerging-Optimizers/blob/main/tests/test_muon_utils.py

Comment thread transformer_engine/pytorch/newton_schulz.py
@vcherepanov-nv vcherepanov-nv force-pushed the newton-schulz branch 2 times, most recently from 16ba1fd to ea4db43 Compare April 2, 2026 19:44
vcherepanov-nv and others added 8 commits April 2, 2026 19:51
Add a new distributed Newton-Schulz inverse square root API to Transformer
Engine's common C library. This wraps the cusolverMpNewtonSchulz library
function, following the same pattern as the existing cuBLASMp integration
for comm_gemm.

New files:
- newton_schulz.h: Public C API header with context management and
  computation functions
- newton_schulz/newton_schulz.cpp: Implementation with RAII wrappers
  for cuSolverMp handles

Build integration:
- New NVTE_WITH_CUSOLVERMP CMake option and CUSOLVERMP_HOME env var
- NVTE_CHECK_CUSOLVERMP error checking macro in logging.h
- Conditional compilation guarded by NVTE_WITH_CUSOLVERMP

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Add PyTorch-level bindings for the cuSolverMp Newton-Schulz inverse
square root API introduced in the previous commit.

New files:
- pytorch/csrc/extensions/newton_schulz.cpp: C++ extension wrapping
  the C API with PyTorch tensor support
- pytorch/newton_schulz.py: Python wrapper that extracts NCCL
  communicator from torch.distributed ProcessGroup
- tests/pytorch/distributed/test_newton_schulz.py: pytest launcher
- tests/pytorch/distributed/run_newton_schulz.py: distributed test
  worker with reference implementation for numerical validation

Modified files:
- pytorch/csrc/extensions.h: Function declarations
- pytorch/csrc/extensions/pybind.cpp: pybind11 registrations
- pytorch/__init__.py: Public API export

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Fix API mismatches discovered during compilation:
- cusolverMpCreate takes (handle*, deviceId, stream), not (handle*, stream)
- cusolverMpCreateDeviceGrid takes handle as first arg with different
  parameter order
- Use cusolverMpGridMapping_t (not cusolverMpGridLayout_t) and
  CUSOLVERMP_GRID_MAPPING_COL_MAJOR
- cusolverMpCreateMatrixDesc has different parameter order: (desc*,
  grid, dtype, M, N, MB, NB, RSRC, CSRC, LLD)
- cusolverMpNewtonSchulzDescriptorCreate takes only (nsDesc*) with no
  iteration/coefficient args
- No cusolverMpStreamSet exists; create handle per-call with user stream
- cusolverMpNewtonSchulz requires computeType and info parameters
- Switch from generic template RAII to explicit deleter structs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
…build

Add NVTE_WITH_CUSOLVERMP compiler define and cusolverMp include/library
paths to the PyTorch C++ extension build, following the same pattern as
NVTE_UB_WITH_MPI and NVTE_ENABLE_NVSHMEM.

Without this, the #ifdef NVTE_WITH_CUSOLVERMP guards in the PyTorch
extension code would never be active since the define was only set as
PRIVATE in the CMake build for the common library.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Two fixes:
- Use ProcessGroupNCCL._comm_ptr() to extract the raw NCCL communicator
  pointer instead of the non-existent get_nccl_comm() method
- Pass global matrix dimensions (m, n) from Python to C++ instead of
  using local tensor dimensions, which would produce incorrect
  ScaLAPACK block sizes in the distributed computation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
cuSolverMp handle and grid creation are expensive operations. Move them
from per-call creation in nvte_newton_schulz into the NVTECusolverMpCtx,
which is their natural home — the context exists to encapsulate the grid.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
cuSolverMp cannot work with the default CUDA stream. Create a dedicated
stream inside nvte_cusolvermp_ctx_create and remove the stream parameter
from both C API functions since the context now owns its stream.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
The internal dedicated stream was reading the input tensor before the
caller's stream had finished producing it, resulting in all-zero output.

Add event-based synchronisation: the internal stream waits for the
caller's input to be ready, and the caller's stream waits for the
output to be written. Replaces the blocking cudaStreamSynchronize.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Comment thread transformer_engine/pytorch/newton_schulz.py
Comment thread transformer_engine/common/newton_schulz/newton_schulz.cpp
Comment thread transformer_engine/common/include/transformer_engine/newton_schulz.h Outdated
Comment thread transformer_engine/common/newton_schulz/newton_schulz.cpp Outdated
Comment thread transformer_engine/pytorch/csrc/extensions/pybind.cpp Outdated
Copy link
Copy Markdown
Member

@timmoon10 timmoon10 Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea, I'd suggest something liketransformer_engine/pytorch/optimizers/muon.

Comment thread build_tools/pytorch.py Outdated
Comment thread transformer_engine/common/CMakeLists.txt Outdated
Comment thread transformer_engine/common/newton_schulz/newton_schulz.cpp
Comment thread transformer_engine/common/newton_schulz/newton_schulz.cpp Outdated
vcherepanov-nv and others added 14 commits April 3, 2026 00:23
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vcherepanov-nv <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vcherepanov-nv <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
…t cuSOLVERMp support

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Copy link
Copy Markdown
Member

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could tweak the design of the Python wrapper for the cuSOLVERMp context, but otherwise this looks good to me.

#ifndef TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_
#define TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_

#include <nccl.h>
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also a little uneasy at exposing NCCL as a required dependency, but I see we already import the NCCL header elsewhere in the codebase:

Comment thread transformer_engine/pytorch/newton_schulz.py Outdated
Comment thread transformer_engine/pytorch/newton_schulz.py Outdated
vcherepanov-nv and others added 6 commits April 13, 2026 12:46
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vcherepanov-nv <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
timmoon10
timmoon10 previously approved these changes Apr 14, 2026
Copy link
Copy Markdown
Member

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
@vcherepanov-nv vcherepanov-nv merged commit a073ad5 into NVIDIA:main Apr 15, 2026
10 of 12 checks passed
@ptrendx ptrendx added this to the 2.15 milestone Apr 23, 2026
faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
* [Common] Add Newton-Schulz inverse square root C API via cuSolverMp

Add a new distributed Newton-Schulz inverse square root API to Transformer
Engine's common C library. This wraps the cusolverMpNewtonSchulz library
function, following the same pattern as the existing cuBLASMp integration
for comm_gemm.

New files:
- newton_schulz.h: Public C API header with context management and
  computation functions
- newton_schulz/newton_schulz.cpp: Implementation with RAII wrappers
  for cuSolverMp handles

Build integration:
- New NVTE_WITH_CUSOLVERMP CMake option and CUSOLVERMP_HOME env var
- NVTE_CHECK_CUSOLVERMP error checking macro in logging.h
- Conditional compilation guarded by NVTE_WITH_CUSOLVERMP

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [PyTorch] Add Newton-Schulz PyTorch bindings and distributed tests

Add PyTorch-level bindings for the cuSolverMp Newton-Schulz inverse
square root API introduced in the previous commit.

New files:
- pytorch/csrc/extensions/newton_schulz.cpp: C++ extension wrapping
  the C API with PyTorch tensor support
- pytorch/newton_schulz.py: Python wrapper that extracts NCCL
  communicator from torch.distributed ProcessGroup
- tests/pytorch/distributed/test_newton_schulz.py: pytest launcher
- tests/pytorch/distributed/run_newton_schulz.py: distributed test
  worker with reference implementation for numerical validation

Modified files:
- pytorch/csrc/extensions.h: Function declarations
- pytorch/csrc/extensions/pybind.cpp: pybind11 registrations
- pytorch/__init__.py: Public API export

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [Common] Fix cuSolverMp API signatures in Newton-Schulz implementation

Fix API mismatches discovered during compilation:
- cusolverMpCreate takes (handle*, deviceId, stream), not (handle*, stream)
- cusolverMpCreateDeviceGrid takes handle as first arg with different
  parameter order
- Use cusolverMpGridMapping_t (not cusolverMpGridLayout_t) and
  CUSOLVERMP_GRID_MAPPING_COL_MAJOR
- cusolverMpCreateMatrixDesc has different parameter order: (desc*,
  grid, dtype, M, N, MB, NB, RSRC, CSRC, LLD)
- cusolverMpNewtonSchulzDescriptorCreate takes only (nsDesc*) with no
  iteration/coefficient args
- No cusolverMpStreamSet exists; create handle per-call with user stream
- cusolverMpNewtonSchulz requires computeType and info parameters
- Switch from generic template RAII to explicit deleter structs

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [PyTorch] Propagate NVTE_WITH_CUSOLVERMP define to PyTorch extension build

Add NVTE_WITH_CUSOLVERMP compiler define and cusolverMp include/library
paths to the PyTorch C++ extension build, following the same pattern as
NVTE_UB_WITH_MPI and NVTE_ENABLE_NVSHMEM.

Without this, the #ifdef NVTE_WITH_CUSOLVERMP guards in the PyTorch
extension code would never be active since the define was only set as
PRIVATE in the CMake build for the common library.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [PyTorch] Fix NCCL comm extraction and pass global dims to Newton-Schulz

Two fixes:
- Use ProcessGroupNCCL._comm_ptr() to extract the raw NCCL communicator
  pointer instead of the non-existent get_nccl_comm() method
- Pass global matrix dimensions (m, n) from Python to C++ instead of
  using local tensor dimensions, which would produce incorrect
  ScaLAPACK block sizes in the distributed computation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [Common] Cache cuSolverMp handle and grid in Newton-Schulz context

cuSolverMp handle and grid creation are expensive operations. Move them
from per-call creation in nvte_newton_schulz into the NVTECusolverMpCtx,
which is their natural home — the context exists to encapsulate the grid.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [Common] Create dedicated CUDA stream in Newton-Schulz context

cuSolverMp cannot work with the default CUDA stream. Create a dedicated
stream inside nvte_cusolvermp_ctx_create and remove the stream parameter
from both C API functions since the context now owns its stream.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [Common] Fix Newton-Schulz zero output with event-based stream sync

The internal dedicated stream was reading the input tensor before the
caller's stream had finished producing it, resulting in all-zero output.

Add event-based synchronisation: the internal stream waits for the
caller's input to be ready, and the caller's stream waits for the
output to be written. Replaces the blocking cudaStreamSynchronize.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [Common] Fix Newton-Schulz NaNs by keeping host workspace alive

cuSolverMp is asynchronous and uses the host workspace during multi-GPU
execution. The event-based output sync did not block the host, so the
local workspace_host vector was destroyed while the GPU was still
reading from it. Restore cudaStreamSynchronize to ensure the host
workspace remains valid for the full duration of the operation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [Common] Cache CUDA event in Newton-Schulz context

Avoid creating and destroying a cudaEvent_t on every
nvte_newton_schulz call by making it a persistent member of
NVTECusolverMpCtx, matching the existing pattern for the stream.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [Common] Use separate in/out events for Newton-Schulz stream sync

Replace single event with in_ready and out_ready events. After the
cuSolverMp call, record out_ready on the internal stream and make the
caller's stream wait on it, ensuring the output tensor is ready before
the caller uses it.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Correct coefficients

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* No stream synchronize

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [Test] Verify Newton-Schulz result with XAX=I identity check

Replace reference-comparison test with a direct arithmetic check:
if X is the inverse square root of A, then X @ A @ X must equal the
identity matrix. This is more robust and removes the need for a
separate reference implementation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Change test - it approximates orthogonal matrix, not inverse square root

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Generalize number of iterations in tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove extra info diag - everything should be in logs

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add Newton-Schulz tests to the QA script

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix outdated comments

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unused variable

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Move magic numbers from tests to impl

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix outdated comments

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Check num_coefficients

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Auto-detect cuSolverMp support from common library binary

Instead of requiring NVTE_WITH_CUSOLVERMP env var to be set for
both the common library and PyTorch extension builds, inspect the
already-built libtransformer_engine.so for exported symbols. This
is more robust for incremental builds and CI environments where
the env var may not be propagated to the extension build step.

The PyTorch extension only calls nvte_* C API functions, so it
does not need cusolverMp headers or libraries — only the compile
definition.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Conditionally exclude Newton-Schulz API from PyTorch extension

When NVTE_WITH_CUSOLVERMP is not defined, omit the Newton-Schulz
functions entirely from the pybind module instead of registering
stubs that throw runtime errors. The Python wrapper checks for
the attribute at call time and raises a clear error message.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Make symbol detection errors fatal in common_lib_has_symbol

Raise FileNotFoundError when no libtransformer_engine.so is found in
any candidate location, and raise RuntimeError when nm is unavailable
or exits non-zero, rather than silently returning False in both cases.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Search for libtransformer_engine.so via installed module location first

In common_lib_has_symbol, prepend a candidate derived by importing
transformer_engine via importlib.util.find_spec and using the package
directory as the root. This correctly resolves the SO path for source
and PyPI installs (where it lives inside transformer_engine/), before
falling back to the repo-root and CMake build dir candidates.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add site packages to search paths for TE common

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Revert "Auto-detect cuSolverMp support from common library binary"

This reverts commit 8f50bd5.

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unused import

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix incorrect 'inverse square root' references in Newton-Schulz comments

Replace misleading 'inverse square root' descriptions with accurate
'matrix orthogonalization' in the module docstring, function docstring,
and pybind11 binding docstring.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [PyTorch] Expose cuSolverMp context creation/destruction as public API

Context creation is expensive and should not happen on every
newton_schulz call. Introduce CusolverMpCtx and cusolvermp_ctx_create()
so callers can create a context once from a ProcessGroup and reuse it.
CusolverMpCtx supports explicit destroy() and use as a context manager.
newton_schulz() now takes CusolverMpCtx instead of ProcessGroup.

Export CusolverMpCtx and cusolvermp_ctx_create from the pytorch package.
Update the distributed test worker to use explicit context lifecycle.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [PyTorch] Strengthen input validation in newton_schulz

Replace assert with ValueError for the coefficients length check.
Add dtype (float32/bfloat16) and contiguity checks for the input tensor.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use ncclMemAlloc for cuSolverMp Newton-Schulz workspace

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add Newton-Schulz reference tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix Newton-Schulz reference test logic

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix column-major usage of cuSOLVERMp; add rectangular test cases

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Avoid explicit transpose

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More cleanup

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Update transformer_engine/common/newton_schulz/newton_schulz.cpp

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vcherepanov-nv <vcherepanov@nvidia.com>

* Fix syntax

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Apply suggestions from code review

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: vcherepanov-nv <vcherepanov@nvidia.com>

* Add timeout

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use RAII for cusolvermp CUDA resources

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Make NS API declared unconditional, with stub / runtime errors without cuSOLVERMp support

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix index in diag

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* CMake fixes

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Update transformer_engine/pytorch/newton_schulz.py

Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vcherepanov-nv <vcherepanov@nvidia.com>

* Fix a typo

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup context management

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Borrow more coefficient sets from Emerging Optimizers

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Couple num_iterations with coeff types in tests

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: vcherepanov-nv <vcherepanov@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants