-
Notifications
You must be signed in to change notification settings - Fork 738
Newton-Schulz via cuSOLVERMp #2706
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 43 commits
a536ff4
e8fea44
2e5d826
48f549b
c154d98
2f62321
e4a9999
0cf4327
f24dd8f
3badc16
8a11b4e
5c5d206
b0b1367
7cfc57c
f64d8f6
c634d61
de423aa
6d3a4dc
e424057
e5ca4b3
f86f8bb
9d503e0
89c5594
879fd38
9a7386b
ff78aa3
823a2f5
257cc43
274c06d
a1026fb
295504e
b9c6bc8
8de97d5
9811950
8cadb0d
4913a9d
a9411e1
c825455
842ed71
dd3c2e4
5015e58
3acab29
a389e14
61cff6f
960dd0f
e2576a7
33bb8fd
7e53e11
da9dea3
70d2ea8
fc47fc0
eca8616
d4d3c93
2b8d56a
bfe7484
ce0c44b
739fd08
c99e42c
1ee7dd8
ae4f539
72335af
8e14fa7
f48bbfc
5d6cc7b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """Distributed Newton-Schulz test worker. | ||
|
|
||
| Launched via torchrun from test_newton_schulz.py. | ||
| """ | ||
|
|
||
| import argparse | ||
| import sys | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from torch.distributed.elastic.multiprocessing.errors import record | ||
|
|
||
| from transformer_engine.pytorch.newton_schulz import ( | ||
| cusolvermp_ctx_create, | ||
| get_coefficients, | ||
| ) | ||
| import transformer_engine_torch as tex | ||
|
|
||
|
|
||
| def newton_schulz_reference(in_x: torch.Tensor, coefficients: list[float]) -> torch.Tensor: | ||
| """Local Newton-Schulz reference mirroring the provided Octave update.""" | ||
| x = in_x.clone() | ||
| for i in range(len(coefficients) // 3): | ||
| a, b, c = coefficients[3 * i : 3 * (i + 1)] | ||
| xxt = x @ x.mT | ||
| x = a * x + b * xxt @ x + c * xxt @ xxt @ x | ||
| return x | ||
|
|
||
|
|
||
| def to_column_major_local(x: torch.Tensor) -> torch.Tensor: | ||
| """Copy a logical 2D tensor into a column-major local buffer.""" | ||
| x_col_major = torch.empty_strided( | ||
| size=x.shape, | ||
| stride=(1, x.shape[0]), | ||
| dtype=x.dtype, | ||
| device=x.device, | ||
| ) | ||
| x_col_major.copy_(x) | ||
| return x_col_major | ||
|
|
||
|
|
||
| @record | ||
| def main(): | ||
| parser = argparse.ArgumentParser(description="Newton-Schulz distributed test") | ||
| parser.add_argument("--check", type=str, default="orthogonality", choices=["orthogonality", "reference"]) | ||
| parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "bfloat16"]) | ||
| parser.add_argument("--matrix-rows", type=int, default=256) | ||
| parser.add_argument("--matrix-cols", type=int, default=None) | ||
| parser.add_argument("--num-iterations", type=int, default=5) | ||
| parser.add_argument("--atol", type=float, default=1e-2) | ||
| parser.add_argument("--rtol", type=float, default=1e-2) | ||
| args = parser.parse_args() | ||
|
|
||
| dist.init_process_group(backend="nccl") | ||
| rank = dist.get_rank() | ||
| world_size = dist.get_world_size() | ||
| torch.cuda.set_device(rank) | ||
|
|
||
| dtype = torch.float32 if args.dtype == "float32" else torch.bfloat16 | ||
| m = args.matrix_rows | ||
| n = args.matrix_cols if args.matrix_cols is not None else args.matrix_rows | ||
| coefficients = get_coefficients(args.num_iterations) | ||
|
|
||
| # Ensure the distributed row dimension is divisible by world_size. | ||
| assert m % world_size == 0, f"Matrix rows {m} must be divisible by world_size {world_size}" | ||
|
|
||
| # Create a random matrix on rank 0 with singular values in (0, 1), | ||
| # which keeps the Newton-Schulz iterations in the convergence regime. | ||
| if rank == 0: | ||
| torch.manual_seed(42) | ||
| k = min(m, n) | ||
| U, _ = torch.linalg.qr(torch.randn(m, k, device="cuda", dtype=torch.float32), mode="reduced") | ||
| V, _ = torch.linalg.qr(torch.randn(n, k, device="cuda", dtype=torch.float32), mode="reduced") | ||
| singular_values = torch.rand(k, device="cuda", dtype=torch.float32) * 0.8 + 0.1 | ||
| A = U @ torch.diag(singular_values) @ V.T | ||
| A = A.to(dtype) | ||
| else: | ||
| A = torch.empty(m, n, device="cuda", dtype=dtype) | ||
|
|
||
| # Broadcast the full matrix to all ranks | ||
| dist.broadcast(A, src=0) | ||
|
|
||
| # Scatter rows to each rank | ||
| local_rows = m // world_size | ||
| x_local = A[rank * local_rows : (rank + 1) * local_rows, :].contiguous() | ||
|
|
||
| group = dist.group.WORLD | ||
| ctx = cusolvermp_ctx_create(group) | ||
| try: | ||
| # cuSOLVERMp expects the local shard to use column-major storage. | ||
| x_local_col_major = to_column_major_local(x_local) | ||
| tex.newton_schulz(ctx._ptr, m, n, x_local_col_major, args.num_iterations, coefficients) | ||
| x_local = x_local_col_major.contiguous() | ||
| finally: | ||
| ctx.destroy() | ||
|
|
||
| # Gather results | ||
| gathered = [torch.empty_like(x_local) for _ in range(world_size)] | ||
| dist.all_gather(gathered, x_local) | ||
| X = torch.cat(gathered, dim=0) | ||
|
|
||
| # Check: the resulting matrix should be orthogonal, or match a local reference. | ||
| if rank == 0: | ||
| if args.check == "orthogonality": | ||
| if m <= n: | ||
| gram = X @ X.t() | ||
| expected = torch.eye(m, device=gram.device, dtype=gram.dtype) | ||
| max_diff = (gram - expected).abs().max().item() | ||
| print(f"Max |X @ X.t() - I|: {max_diff:.6e}", flush=True) | ||
| else: | ||
| gram = X.t() @ X | ||
| expected = torch.eye(n, device=gram.device, dtype=gram.dtype) | ||
| max_diff = (gram - expected).abs().max().item() | ||
| print(f"Max |X.t() @ X - I|: {max_diff:.6e}", flush=True) | ||
| passed = torch.allclose(gram, expected, atol=args.atol, rtol=args.rtol) | ||
| else: | ||
| reference = newton_schulz_reference(A.float(), coefficients).to(dtype) | ||
| max_diff = (X - reference).abs().max().item() | ||
| print(f"Max |distributed - reference|: {max_diff:.6e}", flush=True) | ||
| passed = torch.allclose(X, reference, atol=args.atol, rtol=args.rtol) | ||
|
|
||
| if passed: | ||
| print("NUMERICAL CHECK PASSED", flush=True) | ||
| else: | ||
| print("NUMERICAL CHECK FAILED", flush=True, file=sys.stderr) | ||
| sys.exit(1) | ||
|
|
||
| dist.destroy_process_group() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,68 @@ | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
||
| """Tests for distributed Newton-Schulz matrix orthogonalization.""" | ||
|
|
||
| import os | ||
| import subprocess | ||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| if torch.cuda.device_count() < 2: | ||
| pytest.skip("Newton-Schulz tests require at least 2 GPUs.", allow_module_level=True) | ||
|
|
||
| TEST_ROOT = Path(__file__).parent.resolve() | ||
| NUM_PROCS = torch.cuda.device_count() | ||
| LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] | ||
| ORTHOGONALITY_SHAPES = [ | ||
| (NUM_PROCS * 64, NUM_PROCS * 64), | ||
| (NUM_PROCS * 64, NUM_PROCS * 96), | ||
| (NUM_PROCS * 96, NUM_PROCS * 64), | ||
| ] | ||
| REFERENCE_SHAPES = [(NUM_PROCS * 64, NUM_PROCS * 64)] | ||
|
|
||
|
|
||
| def _run_test(dtype, matrix_shape, num_iterations, check): | ||
| rows, cols = matrix_shape | ||
| test_path = TEST_ROOT / "run_newton_schulz.py" | ||
| test_cmd = LAUNCH_CMD + [ | ||
| str(test_path), | ||
| f"--check={check}", | ||
| f"--dtype={dtype}", | ||
| f"--matrix-rows={rows}", | ||
| f"--matrix-cols={cols}", | ||
| f"--num-iterations={num_iterations}", | ||
| ] | ||
| if dtype == "bfloat16": | ||
| test_cmd += ["--atol=5e-2", "--rtol=5e-2"] | ||
|
|
||
| result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False) | ||
|
vcherepanov-nv marked this conversation as resolved.
Outdated
|
||
| if ( | ||
| result.returncode != 0 | ||
| or "NUMERICAL CHECK FAILED" in result.stderr.decode() | ||
| or "NUMERICAL CHECK PASSED" not in result.stdout.decode() | ||
| ): | ||
| raise AssertionError( | ||
| "Newton-Schulz test failed.\n" | ||
| f"stdout: {result.stdout.decode()}\n" | ||
| f"stderr: {result.stderr.decode()}" | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) | ||
| @pytest.mark.parametrize("matrix_shape", ORTHOGONALITY_SHAPES) | ||
| @pytest.mark.parametrize("num_iterations", [5, 15]) | ||
| def test_orthogonality(dtype, matrix_shape, num_iterations): | ||
| """Test distributed Newton-Schulz orthogonality.""" | ||
| _run_test(dtype, matrix_shape, num_iterations, "orthogonality") | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("dtype", ["float32", "bfloat16"]) | ||
| @pytest.mark.parametrize("matrix_shape", REFERENCE_SHAPES) | ||
| @pytest.mark.parametrize("num_iterations", [5, 15]) | ||
| def test_against_reference(dtype, matrix_shape, num_iterations): | ||
| """Test distributed Newton-Schulz against a local reference implementation.""" | ||
| _run_test(dtype, matrix_shape, num_iterations, "reference") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -227,6 +227,11 @@ list(APPEND transformer_engine_SOURCES | |
| comm_gemm/comm_gemm.cpp) | ||
| endif() | ||
|
|
||
| if (NVTE_WITH_CUSOLVERMP) | ||
| list(APPEND transformer_engine_SOURCES | ||
| newton_schulz/newton_schulz.cpp) | ||
| endif() | ||
|
vcherepanov-nv marked this conversation as resolved.
Outdated
|
||
|
|
||
| add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) | ||
| target_include_directories(transformer_engine PUBLIC | ||
| "${CMAKE_CURRENT_SOURCE_DIR}/include") | ||
|
|
@@ -300,6 +305,19 @@ if (NVTE_WITH_CUBLASMP) | |
| message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") | ||
| endif() | ||
|
|
||
| option(NVTE_WITH_CUSOLVERMP "Use cuSolverMp for distributed Newton-Schulz" OFF) | ||
| if (NVTE_WITH_CUSOLVERMP) | ||
| target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUSOLVERMP) | ||
| target_include_directories(transformer_engine PRIVATE ${CUSOLVERMP_DIR}/include) | ||
| find_library(CUSOLVERMP_LIB | ||
| NAMES cusolverMp libcusolverMp | ||
| PATHS ${CUSOLVERMP_DIR} | ||
| PATH_SUFFIXES lib | ||
| REQUIRED) | ||
| target_link_libraries(transformer_engine PUBLIC ${CUSOLVERMP_LIB}) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||
| message(STATUS "Using cuSolverMp at: ${CUSOLVERMP_DIR}") | ||
| endif() | ||
|
|
||
| # Hack to enable dynamic loading in cuDNN frontend | ||
| target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,68 @@ | ||||||||||||||
| /************************************************************************* | ||||||||||||||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||
| * | ||||||||||||||
| * See LICENSE for license information. | ||||||||||||||
| ************************************************************************/ | ||||||||||||||
|
|
||||||||||||||
| /*! \file newton_schulz.h | ||||||||||||||
| * \brief Functions for distributed Newton-Schulz matrix orthogonalization. | ||||||||||||||
| * | ||||||||||||||
| * This API is a TE-native binding to the cuSolverMp library. | ||||||||||||||
| * It computes an iterative Newton-Schulz matrix orthogonalization on a distributed matrix. | ||||||||||||||
| */ | ||||||||||||||
|
|
||||||||||||||
| #ifndef TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ | ||||||||||||||
| #define TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ | ||||||||||||||
|
|
||||||||||||||
| #include <nccl.h> | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unconditional
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm also a little uneasy at exposing NCCL as a required dependency, but I see we already import the NCCL header elsewhere in the codebase:
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We use NCCL types (ncclComm_t) in nvte_ctx_create API, so if we're taking the route of always defining the functions - we need to pull NCCL headers unconditionally. |
||||||||||||||
| #include <stdint.h> | ||||||||||||||
|
|
||||||||||||||
| #include "transformer_engine.h" | ||||||||||||||
|
|
||||||||||||||
| #ifdef __cplusplus | ||||||||||||||
| extern "C" { | ||||||||||||||
| #else | ||||||||||||||
| #include <stdbool.h> | ||||||||||||||
|
timmoon10 marked this conversation as resolved.
Outdated
|
||||||||||||||
| #endif | ||||||||||||||
|
|
||||||||||||||
| typedef struct NVTECusolverMpCtx NVTECusolverMpCtx; | ||||||||||||||
|
|
||||||||||||||
| /*! \brief Create a cuSolverMp context for Newton-Schulz operations. | ||||||||||||||
| * | ||||||||||||||
| * Creates a dedicated CUDA stream internally (cuSolverMp requires a | ||||||||||||||
| * non-default stream). | ||||||||||||||
| * | ||||||||||||||
| * \param[in] comm NCCL communicator. | ||||||||||||||
| * \param[in] nranks Number of ranks. | ||||||||||||||
| * \param[in] rank Local rank. | ||||||||||||||
| */ | ||||||||||||||
| NVTECusolverMpCtx* nvte_cusolvermp_ctx_create(ncclComm_t comm, int nranks, int rank); | ||||||||||||||
|
|
||||||||||||||
| /*! \brief Destroy a cuSolverMp context. | ||||||||||||||
| * | ||||||||||||||
| * \param[in] ctx Context to destroy. | ||||||||||||||
| */ | ||||||||||||||
| void nvte_cusolvermp_ctx_destroy(NVTECusolverMpCtx* ctx); | ||||||||||||||
|
|
||||||||||||||
| /*! \brief Compute Newton-Schulz matrix orthogonalization in-place. | ||||||||||||||
| * | ||||||||||||||
| * \param[in] ctx cuSolverMp context. | ||||||||||||||
| * \param[in] m Global number of rows. | ||||||||||||||
| * \param[in] n Global number of columns. | ||||||||||||||
| * \param[in,out] x Local part of the matrix (modified in-place). | ||||||||||||||
| * \param[in] num_iterations Number of Newton-Schulz iterations. | ||||||||||||||
| * \param[in] coefficients Array of polynomial coefficients (length depends on polynomial | ||||||||||||||
| * degree used internally by cuSolverMp). | ||||||||||||||
| * \param[in] num_coefficients Number of elements in the coefficients array. | ||||||||||||||
| * \param[in] caller_stream CUDA stream on which the caller produced the input tensor. | ||||||||||||||
| * Used for event-based synchronisation with the internal stream. | ||||||||||||||
| */ | ||||||||||||||
| void nvte_newton_schulz(NVTECusolverMpCtx* ctx, int64_t m, int64_t n, NVTETensor x, | ||||||||||||||
| int64_t num_iterations, const float* coefficients, int64_t num_coefficients, | ||||||||||||||
| cudaStream_t caller_stream); | ||||||||||||||
|
|
||||||||||||||
| #ifdef __cplusplus | ||||||||||||||
| } // extern "C" | ||||||||||||||
| #endif | ||||||||||||||
|
|
||||||||||||||
| #endif // TRANSFORMER_ENGINE_COMMON_NEWTON_SCHULZ_H_ | ||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.