diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 9d868d99cf..db13e9f1e0 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -32,6 +32,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py" # debug tests diff --git a/setup.py b/setup.py index 3a66e624e3..ec277b6349 100644 --- a/setup.py +++ b/setup.py @@ -78,6 +78,11 @@ def setup_common_extension() -> CMakeExtension: ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}") cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}") + if bool(int(os.getenv("NVTE_WITH_CUSOLVERMP", "0"))): + cmake_flags.append("-DNVTE_WITH_CUSOLVERMP=ON") + cusolvermp_dir = os.getenv("CUSOLVERMP_HOME", "/usr") + cmake_flags.append(f"-DCUSOLVERMP_DIR={cusolvermp_dir}") + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: diff --git a/tests/pytorch/distributed/run_newton_schulz.py b/tests/pytorch/distributed/run_newton_schulz.py new file mode 100644 index 0000000000..bbd0733447 --- /dev/null +++ b/tests/pytorch/distributed/run_newton_schulz.py @@ -0,0 +1,127 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & 
def newton_schulz_reference(in_x: torch.Tensor, coefficients: list[float]) -> torch.Tensor:
    """Apply the Newton-Schulz polynomial iteration on a single device.

    Every consecutive triple ``(a, b, c)`` in ``coefficients`` performs one
    update ``X <- a*X + b*(X X^T) X + c*(X X^T)^2 X``. Trailing values that do
    not form a complete triple are ignored.

    Args:
        in_x: Input matrix. It is cloned, so the caller's tensor is not modified.
        coefficients: Flat list of coefficients, three per iteration.

    Returns:
        The matrix after ``len(coefficients) // 3`` iterations.
    """
    x = in_x.clone()
    for i in range(len(coefficients) // 3):
        a, b, c = coefficients[3 * i : 3 * (i + 1)]
        xxt = x @ x.mT
        x = a * x + b * xxt @ x + c * xxt @ xxt @ x
    return x


@record
def main():
    """Run one distributed Newton-Schulz check; meant to be launched by torchrun.

    Rank 0 builds a random matrix with singular values in [0.1, 0.9) and
    broadcasts it. Every rank orthogonalizes its column slice in place through
    ``newton_schulz``; the slices are gathered again and rank 0 verifies either
    orthogonality or agreement with :func:`newton_schulz_reference`.

    Prints ``NUMERICAL CHECK PASSED`` to stdout on success. On failure prints
    ``NUMERICAL CHECK FAILED`` to stderr and exits with status 1.

    Raises:
        ValueError: If the column count is not divisible by the world size.
    """
    import os  # Local import: only needed to read torchrun's LOCAL_RANK.

    parser = argparse.ArgumentParser(description="Newton-Schulz distributed test")
    parser.add_argument(
        "--check", type=str, default="orthogonality", choices=["orthogonality", "reference"]
    )
    parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"])
    parser.add_argument("--matrix-rows", type=int, default=256)
    parser.add_argument("--matrix-cols", type=int, default=None)
    parser.add_argument("--num-iterations", type=int, default=5)
    parser.add_argument("--coeff-type", type=str, default="quintic")
    parser.add_argument("--atol", type=float, default=1e-2)
    parser.add_argument("--rtol", type=float, default=1e-2)
    args = parser.parse_args()

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    # The device index must be the node-local rank. The global rank only
    # coincides with it on a single node, so prefer torchrun's LOCAL_RANK.
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", rank)))

    dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16
    m = args.matrix_rows
    n = args.matrix_cols if args.matrix_cols is not None else args.matrix_rows
    coefficients = get_coefficients(args.num_iterations, args.coeff_type)

    # Columns are split evenly across ranks. Raise explicitly rather than
    # assert, so the check also runs under `python -O`.
    if n % world_size != 0:
        raise ValueError(f"Matrix columns {n} must be divisible by world_size {world_size}")

    # Create a random matrix on rank 0 with singular values in (0, 1),
    # which keeps the Newton-Schulz iterations in the convergence regime.
    if rank == 0:
        torch.manual_seed(42)
        k = min(m, n)
        U, _ = torch.linalg.qr(
            torch.randn(m, k, device="cuda", dtype=torch.float32), mode="reduced"
        )
        V, _ = torch.linalg.qr(
            torch.randn(n, k, device="cuda", dtype=torch.float32), mode="reduced"
        )
        singular_values = torch.rand(k, device="cuda", dtype=torch.float32) * 0.8 + 0.1
        A = U @ torch.diag(singular_values) @ V.T
        A = A.to(dtype)
    else:
        A = torch.empty(m, n, device="cuda", dtype=dtype)

    # Broadcast the full matrix to all ranks
    dist.broadcast(A, src=0)

    # Each rank keeps its own contiguous slice of columns.
    local_cols = n // world_size
    x_local = A[:, rank * local_cols : (rank + 1) * local_cols].contiguous()

    ctx = CusolverMpCtx(dist.group.WORLD)
    try:
        newton_schulz(x_local, ctx, args.num_iterations, coefficients=coefficients)
    finally:
        ctx.destroy()

    # Gather results
    gathered = [torch.empty_like(x_local) for _ in range(world_size)]
    dist.all_gather(gathered, x_local)
    X = torch.cat(gathered, dim=1)

    # Only rank 0 judges the result; all other ranks report success.
    passed = True
    if rank == 0:
        if args.check == "orthogonality":
            # Use the smaller Gram matrix: rows are orthonormal when m <= n,
            # columns otherwise.
            if m <= n:
                gram = X @ X.t()
                expected = torch.eye(m, device=gram.device, dtype=gram.dtype)
                max_diff = (gram - expected).abs().max().item()
                print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True)
            else:
                gram = X.t() @ X
                expected = torch.eye(n, device=gram.device, dtype=gram.dtype)
                max_diff = (gram - expected).abs().max().item()
                print(f"Max |X.t() @ X - I|: {max_diff:.6e}", flush=True)
            passed = torch.allclose(gram, expected, atol=args.atol, rtol=args.rtol)
        else:
            reference = newton_schulz_reference(A.float(), coefficients).to(dtype)
            max_diff = (X - reference).abs().max().item()
            print(f"Max |distributed - reference|: {max_diff:.6e}", flush=True)
            passed = torch.allclose(X, reference, atol=args.atol, rtol=args.rtol)

        if passed:
            print("NUMERICAL CHECK PASSED", flush=True)
        else:
            print("NUMERICAL CHECK FAILED", flush=True, file=sys.stderr)

    # Tear down the process group on every rank before reporting failure, so
    # rank 0 does not exit while its peers are still shutting down.
    dist.destroy_process_group()
    if not passed:
        sys.exit(1)


if __name__ == "__main__":
    main()
# The distributed worker needs a real multi-GPU NCCL group.
if torch.cuda.device_count() < 2:
    pytest.skip("Newton-Schulz tests require at least 2 GPUs.", allow_module_level=True)

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS = torch.cuda.device_count()
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
# Shapes scale with the process count so the column dimension always divides
# evenly across ranks: square, wide (rows < cols) and tall (rows > cols).
ORTHOGONALITY_SHAPES = [
    (NUM_PROCS * 64, NUM_PROCS * 64),
    (NUM_PROCS * 64, NUM_PROCS * 96),
    (NUM_PROCS * 96, NUM_PROCS * 64),
]
REFERENCE_SHAPES = [(NUM_PROCS * 64, NUM_PROCS * 64)]


def _run_test(dtype, matrix_shape, num_iterations, coeff_type, check):
    """Launch ``run_newton_schulz.py`` under torchrun and assert that it passed.

    Args:
        dtype: ``"float32"`` or ``"bfloat16"``; bfloat16 runs with looser tolerances.
        matrix_shape: ``(rows, cols)`` of the global matrix.
        num_iterations: Number of Newton-Schulz iterations.
        coeff_type: Name of the coefficient schedule passed to the worker.
        check: ``"orthogonality"`` or ``"reference"``.

    Raises:
        AssertionError: If the worker times out, exits non-zero, reports a
            numerical failure, or never reports success. The message carries
            the captured stdout/stderr.
    """
    rows, cols = matrix_shape
    test_path = TEST_ROOT / "run_newton_schulz.py"
    test_cmd = LAUNCH_CMD + [
        str(test_path),
        f"--check={check}",
        f"--dtype={dtype}",
        f"--matrix-rows={rows}",
        f"--matrix-cols={cols}",
        f"--num-iterations={num_iterations}",
        f"--coeff-type={coeff_type}",
    ]
    if dtype == "bfloat16":
        test_cmd += ["--atol=5e-2", "--rtol=5e-2"]

    def _decode(raw):
        # Worker logs may contain non-UTF-8 bytes; a decoding error must never
        # hide the real failure. `raw` is None when nothing was captured.
        return "" if raw is None else raw.decode(errors="replace")

    try:
        result = subprocess.run(
            test_cmd, env=os.environ, capture_output=True, check=False, timeout=300
        )
    except subprocess.TimeoutExpired as err:
        # Surface whatever the worker printed before it hung.
        raise AssertionError(
            "Newton-Schulz test timed out.\n"
            f"stdout: {_decode(err.stdout)}\n"
            f"stderr: {_decode(err.stderr)}"
        ) from err

    stdout = _decode(result.stdout)
    stderr = _decode(result.stderr)
    if (
        result.returncode != 0
        or "NUMERICAL CHECK FAILED" in stderr
        or "NUMERICAL CHECK PASSED" not in stdout
    ):
        raise AssertionError(
            "Newton-Schulz test failed.\n"
            f"return code: {result.returncode}\n"
            f"stdout: {stdout}\n"
            f"stderr: {stderr}"
        )


@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
@pytest.mark.parametrize("matrix_shape", ORTHOGONALITY_SHAPES)
@pytest.mark.parametrize("num_iterations,coeff_type", [(5, "quintic"), (8, "polar_express")])
def test_orthogonality(dtype, matrix_shape, num_iterations, coeff_type):
    """Test distributed Newton-Schulz orthogonality."""
    _run_test(dtype, matrix_shape, num_iterations, coeff_type, "orthogonality")


@pytest.mark.parametrize("dtype", ["float32", "bfloat16"])
@pytest.mark.parametrize("matrix_shape", REFERENCE_SHAPES)
@pytest.mark.parametrize("num_iterations,coeff_type", [(5, "quintic"), (8, "polar_express")])
def test_against_reference(dtype, matrix_shape, num_iterations, coeff_type):
    """Test distributed Newton-Schulz against a local reference implementation."""
    _run_test(dtype, matrix_shape, num_iterations, coeff_type, "reference")
set(NVTE_BUILD_NUM_PHILOX_ROUNDS_STR $ENV{NVTE_BUILD_NUM_PHILOX_ROUNDS}) if (NOT NVTE_BUILD_NUM_PHILOX_ROUNDS_STR) diff --git a/transformer_engine/common/include/transformer_engine/newton_schulz.h b/transformer_engine/common/include/transformer_engine/newton_schulz.h new file mode 100644 index 0000000000..bea8e32b1e --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/newton_schulz.h @@ -0,0 +1,66 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file newton_schulz.h + * \brief Functions for distributed Newton-Schulz matrix orthogonalization. + * + * This API is a TE-native binding to the cuSolverMp library. + * It computes an iterative Newton-Schulz matrix orthogonalization on a distributed matrix. + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ +#define TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ + +#include +#include + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct NVTECusolverMpCtx NVTECusolverMpCtx; + +/*! \brief Create a cuSolverMp context for Newton-Schulz operations. + * + * Creates a dedicated CUDA stream internally (cuSolverMp requires a + * non-default stream). + * + * \param[in] comm NCCL communicator. + * \param[in] nranks Number of ranks. + * \param[in] rank Local rank. + */ +NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank); + +/*! \brief Destroy a cuSolverMp context. + * + * \param[in] ctx Context to destroy. + */ +void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx); + +/*! \brief Compute Newton-Schulz matrix orthogonalization in-place. + * + * \param[in] ctx cuSolverMp context. + * \param[in] m Global number of rows. + * \param[in] n Global number of columns. 
+ * \param[in,out] x Local part of the matrix (modified in-place). + * \param[in] num_iterations Number of Newton-Schulz iterations. + * \param[in] coefficients Array of polynomial coefficients (length depends on polynomial + * degree used internally by cuSolverMp). + * \param[in] num_coefficients Number of elements in the coefficients array. + * \param[in] caller_stream CUDA stream on which the caller produced the input tensor. + * Used for event-based synchronisation with the internal stream. + */ +void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, + int64_t num_iterations, const float* coefficients, int64_t num_coefficients, + cudaStream_t caller_stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ diff --git a/transformer_engine/common/newton_schulz/newton_schulz.cpp b/transformer_engine/common/newton_schulz/newton_schulz.cpp new file mode 100644 index 0000000000..0d6426a156 --- /dev/null +++ b/transformer_engine/common/newton_schulz/newton_schulz.cpp @@ -0,0 +1,267 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "transformer_engine/newton_schulz.h" + +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" + +#ifdef NVTE_WITH_CUSOLVERMP + +#include + +using namespace transformer_engine; + +namespace { + +struct CudaStreamDeleter { + void operator()(std::remove_pointer_t* stream) const { cudaStreamDestroy(stream); } +}; +using CudaStream = std::unique_ptr, CudaStreamDeleter>; + +struct CudaEventDeleter { + void operator()(std::remove_pointer_t* event) const { cudaEventDestroy(event); } +}; +using CudaEvent = std::unique_ptr, CudaEventDeleter>; + +struct CusolverMpHandleDeleter { + void operator()(cusolverMpHandle_t handle) const { cusolverMpDestroy(handle); } +}; +using CusolverMpHandle = + std::unique_ptr, CusolverMpHandleDeleter>; + +struct CusolverMpGridDeleter { + void operator()(cusolverMpGrid_t grid) const { cusolverMpDestroyGrid(grid); } +}; +using CusolverMpGrid = + std::unique_ptr, CusolverMpGridDeleter>; + +struct CusolverMpMatrixDescDeleter { + void operator()(cusolverMpMatrixDescriptor_t desc) const { cusolverMpDestroyMatrixDesc(desc); } +}; +using CusolverMpMatrixDesc = std::unique_ptr, + CusolverMpMatrixDescDeleter>; + +struct CusolverMpNSDescDeleter { + void operator()(cusolverMpNewtonSchulzDescriptor_t desc) const { + cusolverMpNewtonSchulzDescriptorDestroy(desc); + } +}; +using CusolverMpNSDesc = std::unique_ptr, + CusolverMpNSDescDeleter>; + +CusolverMpHandle MakeCusolverMpHandle(int device_id, cudaStream_t stream) { + cusolverMpHandle_t raw{}; + NVTE_CHECK_CUSOLVERMP(cusolverMpCreate(&raw, device_id, stream)); + return CusolverMpHandle(raw); +} + +CusolverMpGrid MakeCusolverMpGrid(cusolverMpHandle_t handle, ncclComm_t comm, int32_t nprow, + int32_t npcol, cusolverMpGridMapping_t mapping) { + cusolverMpGrid_t raw{}; + NVTE_CHECK_CUSOLVERMP(cusolverMpCreateDeviceGrid(handle, &raw, comm, nprow, npcol, mapping)); + return CusolverMpGrid(raw); 
+} + +CusolverMpMatrixDesc MakeCusolverMpMatrixDesc(cusolverMpGrid_t grid, cudaDataType_t dtype, + int64_t m, int64_t n, int64_t mb, int64_t nb, + uint32_t rsrc, uint32_t csrc, int64_t lld) { + cusolverMpMatrixDescriptor_t raw{}; + NVTE_CHECK_CUSOLVERMP( + cusolverMpCreateMatrixDesc(&raw, grid, dtype, m, n, mb, nb, rsrc, csrc, lld)); + return CusolverMpMatrixDesc(raw); +} + +CusolverMpNSDesc MakeCusolverMpNSDesc() { + cusolverMpNewtonSchulzDescriptor_t raw{}; + NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulzDescriptorCreate(&raw)); + return CusolverMpNSDesc(raw); +} + +CudaStream MakeCudaStream() { + cudaStream_t raw{}; + NVTE_CHECK_CUDA(cudaStreamCreate(&raw)); + return CudaStream(raw); +} + +CudaEvent MakeCudaEvent() { + cudaEvent_t raw{}; + NVTE_CHECK_CUDA(cudaEventCreate(&raw)); + return CudaEvent(raw); +} + +} // namespace + +struct NVTECusolverMpCtx { + int64_t nranks; + int64_t rank; + CudaStream stream; + CudaEvent in_ready; + CudaEvent out_ready; + CusolverMpHandle handle; + CusolverMpGrid grid; + void* workspace; + size_t workspace_size; + bool workspace_registered; +}; + +namespace { + +void FreeWorkspace(NVTECusolverMpCtx* ctx) { + if (ctx->workspace == nullptr) { + return; + } + if (ctx->workspace_registered) { + NVTE_CHECK_CUSOLVERMP(cusolverMpBufferDeregister(ctx->grid.get(), ctx->workspace)); + NVTE_CHECK_NCCL(ncclMemFree(ctx->workspace)); + } else { + NVTE_CHECK_CUDA(cudaFree(ctx->workspace)); + } + ctx->workspace = nullptr; + ctx->workspace_size = 0; + ctx->workspace_registered = false; +} + +} // namespace + +NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank) { + NVTE_API_CALL(nvte_cusolvermp_ctx_create); + int device_id{}; + NVTE_CHECK_CUDA(cudaGetDevice(&device_id)); + + auto stream = MakeCudaStream(); + auto in_ready = MakeCudaEvent(); + auto out_ready = MakeCudaEvent(); + + auto handle = MakeCusolverMpHandle(device_id, stream.get()); + auto grid = MakeCusolverMpGrid(handle.get(), comm, nranks, 1, 
CUSOLVERMP_GRID_MAPPING_COL_MAJOR); + + return new NVTECusolverMpCtx{ + nranks, + rank, + std::move(stream), + std::move(in_ready), + std::move(out_ready), + std::move(handle), + std::move(grid), + nullptr, + 0, + false, + }; +} + +void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx) { + NVTE_API_CALL(nvte_cusolvermp_ctx_destroy); + FreeWorkspace(ctx); + // Destroy handle and grid before the stream they depend on + ctx->grid.reset(); + ctx->handle.reset(); + delete ctx; +} + +void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, + int64_t num_iterations, const float* coefficients, int64_t num_coefficients, + cudaStream_t caller_stream) { + NVTE_API_CALL(nvte_newton_schulz); + NVTE_CHECK(num_coefficients == num_iterations * 3, num_iterations, " iterations require ", + num_iterations * 3, " coefficients, but ", num_coefficients, " are passed"); + const auto* t = convertNVTETensorCheck(x); + + // Make the internal stream wait for the caller's stream so that + // the input tensor is ready before cuSolverMp reads it. 
+ NVTE_CHECK_CUDA(cudaEventRecord(ctx->in_ready.get(), caller_stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream.get(), ctx->in_ready.get())); + + // Block size for ScaLAPACK-style distribution + const int64_t mb = m; + const int64_t nb = (n + ctx->nranks - 1) / ctx->nranks; + + // Compute local leading dimension + const int64_t local_cols = cusolverMpNUMROC(n, nb, ctx->rank, 0, ctx->nranks); + NVTE_CHECK(t->shape().size() == 2, "Shape size:", t->shape().size()); + NVTE_CHECK(t->shape()[1] == local_cols, "Tensor cols:", t->shape()[1], "Local cols:", local_cols); + const int64_t lld = std::max(local_cols, static_cast(1)); + + const cudaDataType_t cuda_dtype = get_cuda_dtype(t->dtype()); + + // Create matrix descriptor + auto mat_desc = MakeCusolverMpMatrixDesc(ctx->grid.get(), cuda_dtype, n, m, nb, mb, 0, 0, lld); + + // Create Newton-Schulz descriptor + auto ns_desc = MakeCusolverMpNSDesc(); + + // Query workspace sizes + size_t wrksp_size_device = 0; + size_t wrksp_size_host = 0; + NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulz_bufferSize( + ctx->handle.get(), ns_desc.get(), n, m, t->data.dptr, 1, 1, mat_desc.get(), num_iterations, + coefficients, CUDA_R_32F, &wrksp_size_device, &wrksp_size_host)); + + // Allocate/grow device workspace + if (ctx->workspace_size < wrksp_size_device) { + FreeWorkspace(ctx); + + void* workspace = nullptr; + bool workspace_registered = false; + + if (ncclMemAlloc(&workspace, wrksp_size_device) == ncclSuccess) { + if (cusolverMpBufferRegister(ctx->grid.get(), workspace, wrksp_size_device) == + CUSOLVER_STATUS_SUCCESS) { + workspace_registered = true; + } else { + NVTE_CHECK_NCCL(ncclMemFree(workspace)); + workspace = nullptr; + } + } + + if (workspace == nullptr) { + NVTE_CHECK_CUDA(cudaMalloc(&workspace, wrksp_size_device)); + } + + ctx->workspace = workspace; + ctx->workspace_size = wrksp_size_device; + ctx->workspace_registered = workspace_registered; + } + + // Allocate host workspace + std::vector 
workspace_host(wrksp_size_host); + + // Execute Newton-Schulz + NVTE_CHECK_CUSOLVERMP(cusolverMpNewtonSchulz( + ctx->handle.get(), ns_desc.get(), n, m, t->data.dptr, 1, 1, mat_desc.get(), num_iterations, + coefficients, CUDA_R_32F, ctx->workspace, ctx->workspace_size, workspace_host.data(), + workspace_host.size(), nullptr)); + + // Make the caller's stream wait for the internal stream so that + // the output tensor is ready before the caller uses it. + NVTE_CHECK_CUDA(cudaEventRecord(ctx->out_ready.get(), ctx->stream.get())); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(caller_stream, ctx->out_ready.get())); +} + +#else // NVTE_WITH_CUSOLVERMP + +struct NVTECusolverMpCtx {}; + +NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank) { + NVTE_ERROR("Transformer Engine has not been built with cuSolverMp support."); +} + +void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx) { + NVTE_ERROR("Transformer Engine has not been built with cuSolverMp support."); +} + +void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, + int64_t num_iterations, const float* coefficients, int64_t num_coefficients, + cudaStream_t caller_stream) { + NVTE_ERROR("Transformer Engine has not been built with cuSolverMp support."); +} + +#endif // NVTE_WITH_CUSOLVERMP diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 8031e342e2..da8b9b377d 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -18,6 +18,10 @@ #include #endif // NVTE_WITH_CUBLASMP +#ifdef NVTE_WITH_CUSOLVERMP +#include +#endif // NVTE_WITH_CUSOLVERMP + #include #include #include @@ -106,6 +110,18 @@ #endif // NVTE_WITH_CUBLASMP +#ifdef NVTE_WITH_CUSOLVERMP + +#define NVTE_CHECK_CUSOLVERMP(expr) \ + do { \ + const cusolverStatus_t status_NVTE_CHECK_CUSOLVERMP = (expr); \ + if (status_NVTE_CHECK_CUSOLVERMP != CUSOLVER_STATUS_SUCCESS) { \ + NVTE_ERROR("cuSolverMp Error: ", 
std::to_string(status_NVTE_CHECK_CUSOLVERMP)); \ + } \ + } while (false) + +#endif // NVTE_WITH_CUSOLVERMP + #define NVTE_CHECK_NCCL(expr) \ do { \ const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \ diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index bbc1d7fab6..d145cf0a21 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -63,6 +63,10 @@ from transformer_engine.pytorch import optimizers from transformer_engine.pytorch.export import onnx_export from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy +from transformer_engine.pytorch.newton_schulz import ( + CusolverMpCtx, + newton_schulz, +) from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.quantized_tensor import Quantizer diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e4bc744e7e..9890f6742a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -593,6 +593,17 @@ void nvshmem_finalize(); void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream); +/*************************************************************************************************** + * Newton-Schulz (cuSolverMp) + **************************************************************************************************/ + +int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank); + +void cusolvermp_ctx_destroy(int64_t ctx_ptr); + +void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t num_iterations, + std::vector coefficients); + } // namespace transformer_engine::pytorch /*************************************************************************************************** diff --git 
a/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp new file mode 100644 index 0000000000..8b24e8fdb9 --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/newton_schulz.cpp @@ -0,0 +1,40 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/newton_schulz.h" + +#include "../extensions.h" + +namespace transformer_engine::pytorch { + +int64_t cusolvermp_ctx_create(int64_t nccl_comm_ptr, int nranks, int rank) { + auto comm = reinterpret_cast(nccl_comm_ptr); + auto* ctx = nvte_cusolvermp_ctx_create(comm, nranks, rank); + return reinterpret_cast(ctx); +} + +void cusolvermp_ctx_destroy(int64_t ctx_ptr) { + auto* ctx = reinterpret_cast(ctx_ptr); + nvte_cusolvermp_ctx_destroy(ctx); +} + +void newton_schulz(int64_t ctx_ptr, int64_t m, int64_t n, at::Tensor x, int64_t num_iterations, + std::vector coefficients) { + auto* ctx = reinterpret_cast(ctx_ptr); + + // Build NVTETensor from PyTorch tensor + auto x_sizes = x.sizes().vec(); + std::vector shape(x_sizes.begin(), x_sizes.end()); + + auto te_dtype = GetTransformerEngineDType(x.scalar_type()); + TensorWrapper x_tensor(x.data_ptr(), shape, te_dtype); + + auto caller_stream = at::cuda::getCurrentCUDAStream().stream(); + nvte_newton_schulz(ctx, m, n, x_tensor.data(), num_iterations, coefficients.data(), + static_cast(coefficients.size()), caller_stream); +} + +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 18da5d0e9f..4a20be6361 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -559,6 
_COEFFICIENT_SETS = {
    # Values are rounded to closest representable in single precision.
    "simple": [
        (3.4445, -4.7750, 2.0315),
    ],
    "quintic": [
        # optimized for a quintic iteration.
        # Source: https://leloykun.github.io/ponder/muon-opt-coeffs/#how-do-we-optimize-the-coefficients
        # Numbers from: https://github.com/KellerJordan/modded-nanogpt/blob/0674386070ceb4dcd207e1aca747ffcea6c15250/train_gpt_medium.py#L45
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ],
    "polar_express": [
        # Polar Express iteration from: https://arxiv.org/abs/2505.16932
        # We include PolarExpress' division by 1.01^polynomial_degree (as stated in their Algorithm 1) in the coefficient list.
        # This is a safety factor for numerical stability.
        (8.2051, -22.9019, 16.4607),
        (4.0664, -2.8612, 0.5184),
        (3.9096, -2.8234, 0.5250),
        (3.2856, -2.4153, 0.4853),
        (2.2779, -1.6198, 0.3985),
        (1.8726, -1.2307, 0.3585),
        (1.8564, -1.2132, 0.3568),
        (1.8750, -1.2500, 0.3750),
    ],
    "cans": [
        # CANS from: http://arxiv.org/abs/2506.10935
        # CANS iteration (Remez + adaptive interval) based coefficients.
        # Source (for generating CANS coefficients): https://github.com/GrishKate/accelerating_orthogonalization/blob/main/polynomials.py
        (8.4703, -25.1081, 18.6293),
        (4.1828, -3.1087, 0.5806),
        (3.9619, -2.9541, 0.5630),
        (3.2866, -2.4647, 0.5074),
        (2.2737, -1.6447, 0.4162),
    ],
    "aol": [
        # from https://github.com/thib-s/flash-newton-schulz/blob/main/newton_schulz_triton.py#L511
        (4.0098, -7.0585, 2.4635),
        (3.4585, -5.5479, 2.5959),
        (2.7573, -3.2939, 1.4254),
        (2.7215, -3.0494, 1.3169),
    ],
}

# Literal needs the literal values spelled out; keep in sync with the keys of
# _COEFFICIENT_SETS.
NSCoeffT = Literal["simple", "quintic", "polar_express", "cans", "aol"]

CoeffIterMode = Literal["cycle", "repeat_last"]


def get_coefficient_iterator(
    steps: int,
    coefficient_sets: Sequence[tuple[float, float, float]],
    mode: CoeffIterMode = "cycle",
) -> Iterator[tuple[float, float, float]]:
    """Iterate through coefficient sets with configurable end behavior using itertools.

    Args:
        steps: The number of tuples to yield.
        coefficient_sets: A sequence of (a, b, c) coefficient tuples.
        mode: Iteration mode:
            - "cycle": After the last element, restart from the beginning.
            - "repeat_last": After the last element, keep yielding the last tuple.

    Returns:
        An iterator over exactly ``steps`` tuples (a, b, c) taken from
        coefficient_sets according to the specified mode.

    Raises:
        ValueError: If coefficient_sets is empty.
        ValueError: If an invalid mode is provided.
    """
    if not coefficient_sets:
        raise ValueError("coefficient_sets must be non-empty.")

    base: Iterator[tuple[float, float, float]]
    if mode == "cycle":
        base = cycle(coefficient_sets)
    elif mode == "repeat_last":
        # Chain the original list with an infinite repeat of the last item
        base = chain(coefficient_sets, repeat(coefficient_sets[-1]))
    else:
        raise ValueError(f"Invalid mode: {mode}. Expected 'cycle' or 'repeat_last'.")

    return islice(base, steps)


def get_coefficients(steps: int, coefficient_type: NSCoeffT = "quintic") -> List[float]:
    """Return the coefficient schedule for Newton-Schulz.

    Parameter ``coefficient_type`` can be one of the following
      - "simple": Default coefficient set.
      - "quintic": Quintic iteration with optimized coefficients.
      - "polar_express": Polar Express iteration with optimized coefficients.
      - "cans": CANS iteration with Remez + adaptive interval coefficients.
      - "aol": AOL coefficient set.

    Schedules shorter than ``steps`` are extended: "polar_express" and "cans"
    repeat their last tuple, all other sets cycle from the beginning.

    Returns:
        A flat list of ``3 * steps`` floats, three (a, b, c) per iteration.

    Raises:
        ValueError: If ``coefficient_type`` is not a known set.
    """
    if coefficient_type not in _COEFFICIENT_SETS:
        # f-string so a non-str argument still yields the intended ValueError.
        raise ValueError(f"Invalid coefficient type: {coefficient_type}")
    iter_mode: CoeffIterMode = (
        "repeat_last" if coefficient_type in ("polar_express", "cans") else "cycle"
    )
    coeff_iter = get_coefficient_iterator(
        steps, _COEFFICIENT_SETS[coefficient_type], mode=iter_mode
    )
    return list(chain.from_iterable(coeff_iter))


class CusolverMpCtx:
    """cuSolverMp context for Newton-Schulz matrix orthogonalization.

    Context creation is expensive; create once and reuse across multiple
    :func:`newton_schulz` calls. Call :meth:`destroy` when done.
    """

    def __init__(self, group: dist.ProcessGroup) -> None:
        # Assigned first so destroy()/__del__ stay safe if anything below raises.
        self._ptr = None
        self.nranks = dist.get_world_size(group)
        self._ptr = tex.cusolvermp_ctx_create(
            _get_nccl_comm_ptr(group), self.nranks, dist.get_rank(group)
        )

    def destroy(self) -> None:
        """Destroy the underlying cuSolverMp context. Safe to call more than once."""
        ptr = getattr(self, "_ptr", None)
        # Clear the handle before releasing it so a failure cannot lead to a
        # second destroy of the same pointer.
        self._ptr = None
        if ptr is not None:
            tex.cusolvermp_ctx_destroy(ptr)

    def __del__(self) -> None:
        # Runs on garbage collection, including for instances whose __init__
        # failed part-way; destroy() tolerates both cases.
        self.destroy()


def _get_nccl_comm_ptr(group: dist.ProcessGroup) -> int:
    """Extract the raw NCCL communicator pointer from a PyTorch process group.

    Raises:
        RuntimeError: If the group's backend is not NCCL.
    """
    backend = dist.get_backend(group)
    if backend != "nccl":
        raise RuntimeError(f"Newton-Schulz requires NCCL backend, got '{backend}'")
    nccl_backend = group._get_backend(torch.device("cuda"))
    return nccl_backend._comm_ptr()


def newton_schulz(
    x: torch.Tensor,
    ctx: CusolverMpCtx,
    num_iterations: int = 5,
    coefficients: Optional[List[float]] = None,
) -> None:
    """Compute Newton-Schulz matrix orthogonalization in-place on a distributed matrix.

    Parameters
    ----------
    x : torch.Tensor
        Local part of the distributed matrix (modified in-place).
        Must be a contiguous 2D CUDA tensor of type float32 or bfloat16.
        Columns are distributed across ranks.
    ctx : CusolverMpCtx
        cuSolverMp context, an instance of :class:`CusolverMpCtx` that has
        not been destroyed.
    num_iterations : int, optional
        Number of Newton-Schulz iterations. Default: 5.
    coefficients : list of float, optional
        Polynomial coefficients for the Newton-Schulz iteration, three per
        iteration. Defaults to ``get_coefficients(num_iterations)``.

    Raises
    ------
    ValueError
        If the coefficient count does not match ``num_iterations`` or ``x``
        violates the tensor requirements above.
    RuntimeError
        If ``ctx`` has already been destroyed.
    """
    if coefficients is None:
        coefficients = get_coefficients(num_iterations)
    if len(coefficients) != num_iterations * 3:
        raise ValueError(
            f"Unexpected number of coefficients: {len(coefficients)} for"
            f" {num_iterations} iterations"
        )

    if x.dim() != 2:
        raise ValueError(f"Expected 2D tensor, got {x.dim()}D")
    if x.dtype not in (torch.float32, torch.bfloat16):
        raise ValueError(f"Expected float32 or bfloat16 tensor, got {x.dtype}")
    if not x.is_contiguous():
        raise ValueError("Input tensor must be contiguous")
    if not x.is_cuda:
        raise ValueError("Input tensor must be on CUDA device")
    if ctx._ptr is None:
        raise RuntimeError("CusolverMpCtx has already been destroyed")

    # Global matrix dimensions; columns are distributed across ranks.
    m = x.size(0)
    n = x.size(1) * ctx.nranks

    tex.newton_schulz(ctx._ptr, m, n, x, num_iterations, coefficients)