diff --git a/pyproject.toml b/pyproject.toml
index 6aac64426..584466741 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,6 +68,13 @@ rl = [
 video = [
     "decord",
 ]
+all = [
+    "decord",
+    "ray[default]",
+    "httpx",
+    "fastapi",
+    "uvicorn",
+]
 
 [tool.mypy]
 ignore_missing_imports = true
diff --git a/tests/engine/test_dense_train_engine.py b/tests/engine/test_dense_train_engine.py
index 9f1b621e1..3b523eaeb 100644
--- a/tests/engine/test_dense_train_engine.py
+++ b/tests/engine/test_dense_train_engine.py
@@ -5,7 +5,7 @@
 import parametrize
 import torch
 import torch.distributed as dist
-from torch.testing._internal.common_distributed import DistributedTestBase
+from xtuner._testing import DeterministicDDPTestCase
 from transformers import AutoTokenizer
 
 from xtuner.v1.model.moe.moe import SequenceContext
@@ -25,7 +25,7 @@
 DEVICE = get_device()
 
 
-class TestDenseEngine(DistributedTestBase):
+class TestDenseEngine(DeterministicDDPTestCase):
     @parametrize.parametrize(
         "device,tp_size,sp_size",
         [
diff --git a/tests/engine/test_moe_train_engine.py b/tests/engine/test_moe_train_engine.py
index ee588db21..f353341e2 100644
--- a/tests/engine/test_moe_train_engine.py
+++ b/tests/engine/test_moe_train_engine.py
@@ -6,7 +6,7 @@
 import parametrize
 import torch
 import torch.distributed as dist
-from torch.testing._internal.common_distributed import DistributedTestBase
+from xtuner._testing import DeterministicDDPTestCase
 from transformers import AutoTokenizer
 
 from xtuner.v1.model.moe.moe import SequenceContext
@@ -21,12 +21,13 @@
 from xtuner.v1.utils.device import get_device
 from xtuner.v1.utils.test_utils import init_data_mesh
 
+
 # Qwen3 30B A3
 QWEN3_MOE_PATH = os.environ["QWEN3_MOE_PATH"]
 
 DEVICE = get_device()
 
 
-class TestMoEEngine(DistributedTestBase):
+class TestMoEEngine(DeterministicDDPTestCase):
     @parametrize.parametrize(
         "device,ep_size,sp_size",
         [
@@ -101,9 +102,9 @@ def warmup_fn(x):
             lr_scheduler.step()
             losses.append(loss_log["reduced_llm_loss"])
-        losses_ref = [2.44, 2.44, 2.42, 2.41, 2.34, 2.33, 2.16, 2.13, 1.71, 1.55]
-        for loss, loss_ref in zip(losses, losses_ref):
-            self.assertTrue(abs(loss - loss_ref) / loss_ref < 0.02)
+        losses_ref = torch.tensor([2.44, 2.44, 2.42, 2.41, 2.34, 2.33, 2.16, 2.13, 1.71, 1.55])
+        losses = torch.tensor(losses)
+        self._check_loss_curve(losses, losses_ref)
 
         torch.cuda.empty_cache()
         try:
diff --git a/tests/engine/test_moe_train_engine_float8.py b/tests/engine/test_moe_train_engine_float8.py
index 980630af7..6258f7b5f 100644
--- a/tests/engine/test_moe_train_engine_float8.py
+++ b/tests/engine/test_moe_train_engine_float8.py
@@ -6,7 +6,7 @@
 import parametrize
 import torch
 import torch.distributed as dist
-from torch.testing._internal.common_distributed import DistributedTestBase
+from xtuner._testing import DeterministicDDPTestCase
 from transformers import AutoTokenizer
 
 from xtuner.v1.model.moe.moe import SequenceContext
@@ -22,12 +22,13 @@
 from xtuner.v1.model.moe.moe import BalancingLossConfig
 
+
 # Qwen3 30B A3
 QWEN3_MOE_PATH = os.environ["QWEN3_MOE_PATH"]
 
 DEVICE = get_device()
 
 
-class TestMoEEngineFloat8(DistributedTestBase):
+class TestMoEEngineFloat8(DeterministicDDPTestCase):
 
     @parametrize.parametrize(
         "device,ep_size,hsdp_sharding_size",
@@ -101,17 +102,16 @@ def warmup_fn(x):
             engine.step_optimizer(grad_norm)
             lr_scheduler.step()
             losses.append(loss_log["reduced_llm_loss"])
-        losses_ref = [2.41, 2.41, 1.79, 1.39, 1.02, 0.68, 0.52, 0.31, 0.18, 0.12]
+        losses = torch.tensor(losses)
+        losses_ref = torch.tensor([2.41, 2.41, 1.79, 1.39, 1.02, 0.68, 0.52, 0.31, 0.18, 0.12])
-        for loss, loss_ref in zip(losses, losses_ref):
-            self.assertTrue(abs(loss - loss_ref) < 0.2)
-
+        self._check_loss_curve(losses, losses_ref, sim_tol=0.02, rtol=0.2)
         torch.cuda.empty_cache()
         try:
             dist.destroy_process_group(pg)
         except:
             pass
-
+
     @parametrize.parametrize(
         "device,ep_size,hsdp_sharding_size",
         [
@@ -184,17 +184,18 @@ def warmup_fn(x):
             engine.step_optimizer(grad_norm)
             lr_scheduler.step()
             losses.append(loss_log["reduced_llm_loss"])
-        losses_ref = [2.45, 2.45, 1.78, 1.31, 0.95, 0.67, 0.45, 0.31, 0.18, 0.12]
-        for loss, loss_ref in zip(losses, losses_ref):
-            self.assertTrue(abs(loss - loss_ref) < 0.2)
-
+        losses_ref = torch.tensor([2.45, 2.45, 1.78, 1.31, 0.95, 0.67, 0.45, 0.31, 0.18, 0.12])
+        losses = torch.tensor(losses)
+
+        self._check_loss_curve(losses, losses_ref, sim_tol=0.02, rtol=0.1)
+
         torch.cuda.empty_cache()
         try:
             dist.destroy_process_group(pg)
         except:
             pass
-
+
     @parametrize.parametrize(
         "device,ep_size,hsdp_sharding_size",
         [
@@ -286,20 +287,19 @@ def warmup_fn(x):
             engine.step_optimizer(grad_norm)
             lr_scheduler.step()
             losses.append(loss_log["reduced_llm_loss"])
-        losses_ref = [2.41, 2.41, 2.47, 2.42, 2.44, 2.44, 2.42, 2.38, 2.31, 2.30]
-
-        for loss, loss_ref in zip(losses, losses_ref):
-            self.assertTrue(abs(loss - loss_ref) < 0.2)
+        losses_ref = torch.tensor([2.41, 2.41, 2.47, 2.42, 2.44, 2.44, 2.42, 2.38, 2.31, 2.30])
+        losses = torch.tensor(losses)
+        self._check_loss_curve(losses, losses_ref)
 
         if dist.get_rank() == 0:
             shutil.rmtree(temp_dir)
-
+
         torch.cuda.empty_cache()
         try:
             dist.destroy_process_group(pg)
         except:
             pass
-
+
     @property
     def world_size(self) -> int:
         return int(os.getenv("XTUNER_TEST_WORLD_SIZE", "8"))
@@ -309,7 +309,7 @@
     def destroy_pg_upon_exit(self) -> bool:
         return False
 
 
-class TestMoEEngineFloat8Case2(DistributedTestBase):
+class TestMoEEngineFloat8Case2(DeterministicDDPTestCase):
 
     @parametrize.parametrize(
         "device,ep_size,hsdp_sharding_size",
diff --git a/tests/model/test_gpt_oss_moe.py b/tests/model/test_gpt_oss_moe.py
index 2c84a2b29..f43c4dbf2 100644
--- a/tests/model/test_gpt_oss_moe.py
+++ b/tests/model/test_gpt_oss_moe.py
@@ -6,7 +6,7 @@
 import parametrize
 import torch
-from torch.testing._internal.common_distributed import DistributedTestBase
+from xtuner._testing import DeterministicDDPTestCase, patch_hf_rms_norm
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 import tempfile
 from pathlib import Path
@@ -30,7 +30,7 @@ def wrapper(self, *args, **kwargs):
         return wrapper
 
 
-class TestGptOss(DistributedTestBase):
+class TestGptOss(DeterministicDDPTestCase):
     @parametrize.parametrize(
         "device,dispatcher,ep_size,compile,tol,loss_class",
         [
@@ -56,6 +56,7 @@ def test_gpt_oss_run(self, device, dispatcher, ep_size, compile, tol, loss_class
             device_map="cuda"
         )
         hf_model.train()
+        patch_hf_rms_norm(hf_model)
         tokenizer = AutoTokenizer.from_pretrained(GPT_OSS_MINI_PATH)
         input_ids = tokenizer("吃葡萄不吐葡萄皮", return_tensors="pt").input_ids.to("cuda")
         # assert input_ids.size(1) > 128
@@ -117,6 +118,7 @@ def test_fsdp_accuracy(self, device, dispatcher, ep_size):
             config=hf_config,
             device_map="cuda"
         )
+        patch_hf_rms_norm(hf_model)
         hf_model.train()
         tokenizer = AutoTokenizer.from_pretrained(GPT_OSS_MINI_PATH)
         input_ids = tokenizer("吃葡萄不吐葡萄皮", return_tensors="pt").input_ids.to("cuda")
diff --git a/tests/model/test_intern_s1.py b/tests/model/test_intern_s1.py
index 969a89bf9..719f2a8ea 100644
--- a/tests/model/test_intern_s1.py
+++ b/tests/model/test_intern_s1.py
@@ -2,7 +2,7 @@
 import parametrize
 import torch
-from torch.testing._internal.common_distributed import DistributedTestBase
+from xtuner._testing import patch_hf_rms_norm, DeterministicDDPTestCase
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 import torch.distributed as dist
 import tempfile
@@ -24,7 +24,7 @@
 INTERNS1_DENSE_PATH = os.environ["INTERNS1_DENSE_PATH"]
 
 
-class TestInternS1(DistributedTestBase):
+class TestInternS1(DeterministicDDPTestCase):
     @parametrize.parametrize(
         "device,tol",
         [
@@ -48,6 +48,7 @@ def test_interns1_text_run(self, device, tol):
             trust_remote_code=True,
             device_map="cuda"
         ).eval()  # avoid open drop_path
+        patch_hf_rms_norm(hf_model)
         tokenizer = AutoTokenizer.from_pretrained(INTERNS1_DENSE_PATH, trust_remote_code=True)
 
         input_ids = tokenizer("吃葡萄不吐葡萄皮", return_tensors="pt").input_ids.to(device)
@@ -118,6 +119,7 @@ def test_interns1_image_run(self, device, sp_size, tol):
             trust_remote_code=True,
             device_map=device
         ).eval()  # avoid open drop_path
+        patch_hf_rms_norm(hf_model)
         tokenizer = AutoTokenizer.from_pretrained(INTERNS1_DENSE_PATH, trust_remote_code=True)
@@ -233,6 +235,7 @@ def test_fsdp_text_accuracy(self, device, tol):
             trust_remote_code=True,
             device_map="cuda"
         ).eval()  # avoid open drop_path
+        patch_hf_rms_norm(hf_model)
         tokenizer = AutoTokenizer.from_pretrained(INTERNS1_DENSE_PATH, trust_remote_code=True)
         input_ids = tokenizer("吃葡萄不吐葡萄皮", return_tensors="pt").input_ids.to("cuda")
@@ -317,6 +320,7 @@ def test_fsdp_image_accuracy(self, device, sp_size, compile, tol):
             trust_remote_code=True,
             device_map="cuda"
         ).eval()  # avoid open drop_path
+        patch_hf_rms_norm(hf_model)
         tokenizer = AutoTokenizer.from_pretrained(INTERNS1_DENSE_PATH, trust_remote_code=True)
 
         conversations = [{"from": "human", "value": '\nPlease describe the image shortly.'}]
diff --git a/tests/model/test_moe.py b/tests/model/test_moe.py
index 364171af0..bc5b48be4 100644
--- a/tests/model/test_moe.py
+++ b/tests/model/test_moe.py
@@ -7,13 +7,15 @@
 from copy import deepcopy
 
 from xtuner.v1.loss.ce_loss import CELossContext, CELossConfig, CELossContextInputItem
-from torch.testing._internal.common_distributed import DistributedTestBase
+from xtuner._testing import DeterministicDDPTestCase
+from xtuner.v1.utils.compile import maybe_compile
 import parametrize
 
 
 class TestMoE:
     @parametrize.parametrize("dtype,device", [(torch.bfloat16, "cuda")])
     def test_moe_config(self, dtype, device):
+        maybe_compile.clear_compile_targets()
         router_config = NoAuxRouterConfig(
             scoring_func="sigmoid",
             router_scaling_factor=1.0,
@@ -46,7 +48,7 @@
             num_experts_per_tok=2,
             first_k_dense_replace=1,
             hidden_factor=1.0,
-            moe_intermediate_size=256,  # grouped linear kernel need this to be multiple of 256
+            moe_intermediate_size=512,  # TODO: Restriction of triton grouped gemm, should be optimized
             router=router_config,
         )
         model = MoE(config=config).to(dtype).to(device)
@@ -73,7 +75,7 @@
         model(seq_ctx=seq_ctx, loss_ctx=loss_ctx)
 
 
-class TestDistributedMoE(DistributedTestBase):
+class TestDistributedMoE(DeterministicDDPTestCase):
     @parametrize.parametrize(
         "dtype,device,dispatcher,n_shared_experts,first_k_dense_replace",
         [
@@ -116,7 +118,7 @@ def test_parralel_accuracy(self, dtype, device, dispatcher, n_shared_experts, fi
             num_experts_per_tok=2,
             first_k_dense_replace=first_k_dense_replace,
             hidden_factor=1.0,
-            moe_intermediate_size=256,  # grouped linear kernel need this to be multiple of 256
+            moe_intermediate_size=512,  # TODO: Restriction of triton grouped gemm, should be optimized
             router=router_config,
         )
         loss_cfg = CELossConfig()
diff --git a/tests/model/test_qwen3_dense.py b/tests/model/test_qwen3_dense.py
index 50b654da4..949435648 100644
--- a/tests/model/test_qwen3_dense.py
+++ b/tests/model/test_qwen3_dense.py
@@ -16,12 +16,13 @@
 from xtuner.v1.config import FSDPConfig
 from xtuner.v1.utils.compile import maybe_compile
 from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem
+from xtuner._testing import patch_hf_rms_norm, DeterministicDDPTestCase
 
 
 # Qwen3 8B
 QWEN3_PATH = os.environ["QWEN3_PATH"]
 
 
-class TestQwen3Dense(DistributedTestBase):
+class TestQwen3Dense(DeterministicDDPTestCase):
     @parametrize.parametrize(
         "device,tp_size,compile,tol,loss_class",
         [
@@ -39,6 +40,7 @@ def test_qwen3_dense_run(self, device, tp_size, compile, tol, loss_class):
             torch_dtype=torch.bfloat16,
             device_map="cuda"
         )
+        patch_hf_rms_norm(hf_model)
         tokenizer = AutoTokenizer.from_pretrained(QWEN3_PATH)
         input_ids = tokenizer("吃葡萄不吐葡萄皮", return_tensors="pt").input_ids.to("cuda")
         with torch.no_grad():
diff --git a/tests/model/test_qwen3_moe.py b/tests/model/test_qwen3_moe.py
index c6b364ff0..1238b9e42 100644
--- a/tests/model/test_qwen3_moe.py
+++ b/tests/model/test_qwen3_moe.py
@@ -4,7 +4,6 @@
 import parametrize
 import torch
-from torch.testing._internal.common_distributed import DistributedTestBase
 import torch.distributed as dist
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import tempfile
@@ -17,23 +16,18 @@
 from xtuner.v1.config import FSDPConfig
 from xtuner.v1.utils.compile import maybe_compile
 from xtuner.v1.loss.ce_loss import CELossConfig, CELossContextInputItem
+from xtuner._testing import patch_hf_rms_norm, DeterministicDDPTestCase
+
 
 # Qwen3 30B A3
 QWEN3_MOE_PATH = os.environ["QWEN3_MOE_PATH"]
 
 
-def prepare(fn):
-    @wraps(fn)
-    def wrapper(self, *args, **kwargs):
+class TestQwen3MoE(DeterministicDDPTestCase):
+    def prepare(self):
         self.temp_dir = tempfile.TemporaryDirectory()
-        ret = fn(self, *args, **kwargs)
         self.temp_dir.cleanup()
-        return ret
-
-    return wrapper
-
-class TestQwen3MoE(DistributedTestBase):
     @parametrize.parametrize(
         "device,dispatcher,ep_size,compile,tol,loss_class",
         [
@@ -44,7 +38,6 @@
         ("cuda", None, 1, False, 1e-2, "chunk_cross_entropy"),
         ],
     )
-    @prepare
     def test_qwen3_moe_run(self, device, dispatcher, ep_size, compile, tol, loss_class):
         os.environ["TRITON_CACHE_DIR"] = str(Path(self.temp_dir.name) / "triton_cache")
         self.create_pg(device)
@@ -57,6 +50,7 @@ def test_qwen3_moe_run(self, device, dispatcher, ep_size, compile, tol, loss_cla
             trust_remote_code=True,
             device_map="cuda"
         )
+        patch_hf_rms_norm(hf_model)
         tokenizer = AutoTokenizer.from_pretrained(QWEN3_MOE_PATH, trust_remote_code=True)
         input_ids = tokenizer("吃葡萄不吐葡萄皮", return_tensors="pt").input_ids.to("cuda")
         with torch.no_grad():
diff --git a/tests/ops/foreach_allgather.py b/tests/ops/test_foreach_allgather.py
similarity index 100%
rename from tests/ops/foreach_allgather.py
rename to tests/ops/test_foreach_allgather.py
diff --git a/tests/ops/grouped_gemm_triton.py b/tests/ops/test_grouped_gemm_triton.py
similarity index 95%
rename from tests/ops/grouped_gemm_triton.py
rename to tests/ops/test_grouped_gemm_triton.py
index 61aa909f8..de04a8be1 100644
--- a/tests/ops/grouped_gemm_triton.py
+++ b/tests/ops/test_grouped_gemm_triton.py
@@ -1,6 +1,6 @@
 import torch
 import random
-from xtuner.v1.ops import grouped_gemm_triton
+from xtuner.v1.ops import group_gemm
 
 
 def grouped_gemm_torch(x, w, tokens_per_expert):
@@ -56,7 +56,7 @@ def test_grouped_gemm_triton():
     x_ref = x.clone().detach().requires_grad_(True)
     w_ref = w.clone().detach().requires_grad_(True)
     out_ref = grouped_gemm_torch(x_ref, w_ref, tokens_per_expert)
-    out = grouped_gemm_triton(x, w, tokens_per_expert)
+    out = group_gemm(x, w, tokens_per_expert)
     out.mean().backward()
     out_ref.mean().backward()
     assert torch.allclose(out, out_ref, rtol=1e-2, atol=1e-2), "Output mismatch between Triton and PyTorch implementations"
diff --git a/xtuner/_testing/__init__.py b/xtuner/_testing/__init__.py
new file mode 100644
index 000000000..aba2abe6e
--- /dev/null
+++ b/xtuner/_testing/__init__.py
@@ -0,0 +1,3 @@
+from .patch_hf import patch_hf_rms_norm
+from .utils import enable_full_determinism
+from .testcase import DeterministicDDPTestCase
diff --git a/xtuner/_testing/patch_hf.py b/xtuner/_testing/patch_hf.py
new file mode 100644
index 000000000..c11fec270
--- /dev/null
+++ b/xtuner/_testing/patch_hf.py
@@ -0,0 +1,9 @@
+from xtuner.v1.module import RMSNorm
+import torch.nn as nn
+
+
+def patch_hf_rms_norm(module: nn.Module) -> None:
+    for submodule in module.modules():
+        if "RMSNorm" in submodule.__class__.__name__ and isinstance(submodule, nn.Module):
+            submodule.__class__.forward = RMSNorm.forward
+
diff --git a/xtuner/_testing/testcase.py b/xtuner/_testing/testcase.py
new file mode 100644
index 000000000..100cf468a
--- /dev/null
+++ b/xtuner/_testing/testcase.py
@@ -0,0 +1,93 @@
+from torch.testing._internal.common_distributed import DistributedTestBase, MultiProcessTestCase, logger, TEST_SKIPS, c10d
+import torch
+import threading
+import sys
+import os
+import unittest
+import traceback
+from .utils import enable_full_determinism
+import torch.nn.functional as F
+
+
+class DeterministicDDPTestCase(DistributedTestBase):
+    def prepare(self):
+        return
+
+    def run_func(self, test_name):
+        enable_full_determinism()
+        self.prepare()
+        return getattr(self, test_name)()
+
+    def run_test(self, test_name: str, parent_pipe) -> None:
+        # Start event listener thread.
+        signal_recv_pipe, signal_send_pipe = torch.multiprocessing.Pipe(duplex=False)
+        event_listener_thread = threading.Thread(
+            target=MultiProcessTestCase._event_listener,
+            args=(parent_pipe, signal_recv_pipe, self.rank),
+            daemon=True,
+        )
+        event_listener_thread.start()
+        if sys.platform != "win32" and sys.platform != "darwin":
+            # Register signal handler to dump stack traces on FATALs.
+            # Windows and MacOS do not support the signal handlers.
+            torch._C._set_print_stack_traces_on_fatal_signal(True)
+        # Show full C++ stacktraces when a Python error originating from C++ is raised.
+        os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
+
+        # self.id() == e.g. '__main__.TestDistributed.test_get_rank'
+        # We're retrieving a corresponding test and executing it.
+        try:
+            self.run_func(test_name)
+        except unittest.SkipTest as se:
+            logger.info(
+                "Process %s skipping test %s for following reason: %s", self.rank, test_name, str(se)
+            )
+            sys.exit(TEST_SKIPS["generic"].exit_code)
+        except Exception:
+            logger.error(
+                "Caught exception: \n%s exiting "
+                "process %s with exit code: %s",
+                traceback.format_exc(), self.rank, MultiProcessTestCase.TEST_ERROR_EXIT_CODE
+            )
+            # Send error to parent process.
+            parent_pipe.send(traceback.format_exc())
+            sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
+        finally:
+            if signal_send_pipe is not None:
+                signal_send_pipe.send(None)
+
+            assert event_listener_thread is not None
+            event_listener_thread.join()
+            # Close pipe after done with test.
+            parent_pipe.close()
+
+        if self.destroy_pg_upon_exit:
+            try:
+                # Some tests do destroy the pgs, and destroy can't be called twice.
+                # This avoids spewing warnings about improperly shutting down.
+                c10d.destroy_process_group()
+            except (AssertionError, ValueError):
+                pass
+
+    def _check_loss_curve(
+        self,
+        losses: torch.Tensor,
+        losses_ref: torch.Tensor,
+        sim_tol: float = 0.01,
+        rtol: float = 0.01,
+    ):
+        loss1_norm = F.normalize(losses, dim=0)
+        loss2_norm = F.normalize(losses_ref, dim=0)
+
+        similarity = torch.cosine_similarity(loss1_norm, loss2_norm, dim=0)
+        if similarity <= 1 - sim_tol:
+            raise AssertionError(
+                f"Failed to check the similarity of loss! expected: {losses_ref}, got {losses}, similarity: {similarity}")
+
+        avg_relative_diff = ((losses - losses_ref) / losses_ref).abs().mean()
+        if avg_relative_diff >= rtol:
+            raise AssertionError(
+                f"Failed to check the relative error of loss! expected: {losses_ref}, got {losses}, mean relative diff: {avg_relative_diff}")
diff --git a/xtuner/_testing/utils.py b/xtuner/_testing/utils.py
new file mode 100644
index 000000000..7180dded1
--- /dev/null
+++ b/xtuner/_testing/utils.py
@@ -0,0 +1,17 @@
+import os
+import torch
+
+
+def enable_full_determinism():
+    """
+    Helper function for reproducible behavior during distributed training. See
+    - https://pytorch.org/docs/stable/notes/randomness.html for PyTorch
+    """
+    # Enable PyTorch deterministic mode. This potentially requires either the environment
+    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
+    # depending on the CUDA version, so we set them both here
+    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+
+    # torch.use_deterministic_algorithms(True, warn_only=True)
+    torch.set_deterministic_debug_mode(0)
diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index 08d7059b9..9dd9f6723 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -1,4 +1,5 @@
 import json
+import os
 import random
 import sys
 import time
@@ -429,6 +430,7 @@ def _init_logger(self, work_dir: Path):
 
     def _set_deterministic(self):
         if XTUNER_DETERMINISTIC:
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
             torch.use_deterministic_algorithms(True, warn_only=True)
 
     def _set_random_seed(self, seed: int):
diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py
index 4944d62ed..d39698163 100644
--- a/xtuner/v1/train/trainer.py
+++ b/xtuner/v1/train/trainer.py
@@ -837,6 +837,7 @@ def _get_checkpoint_path(self, epoch: int, step: int) -> Path:
 
     def _set_deterministic(self):
         if XTUNER_DETERMINISTIC:
+            os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
             torch.use_deterministic_algorithms(True, warn_only=True)
 
     def _set_random_seed(self, seed: int):
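
Note on the new tolerance check: the per-step absolute assertions removed above are replaced by _check_loss_curve, which accepts a run when the loss curve has the same shape as the reference (cosine similarity within sim_tol of 1) and its mean relative error stays below rtol. Below is a minimal standalone sketch of that logic, not part of the patch; the free function check_loss_curve is illustrative only, and the reference values are taken from test_moe_train_engine.py.

    import torch
    import torch.nn.functional as F

    def check_loss_curve(losses, losses_ref, sim_tol=0.01, rtol=0.01):
        # Shape gate: cosine similarity of the unit-normalized curves
        # must land within sim_tol of a perfect match.
        similarity = torch.cosine_similarity(
            F.normalize(losses, dim=0), F.normalize(losses_ref, dim=0), dim=0
        )
        # Magnitude gate: mean relative deviation must stay below rtol.
        avg_relative_diff = ((losses - losses_ref) / losses_ref).abs().mean()
        return bool(similarity > 1 - sim_tol) and bool(avg_relative_diff < rtol)

    losses_ref = torch.tensor([2.44, 2.44, 2.42, 2.41, 2.34, 2.33, 2.16, 2.13, 1.71, 1.55])
    assert check_loss_curve(losses_ref * 1.005, losses_ref)      # small uniform drift passes both gates
    assert not check_loss_curve(losses_ref.flip(0), losses_ref)  # reversed curve fails the shape gate

Cosine similarity is invariant to a uniform rescaling of the curve, so the shape gate alone would accept a curve that is parallel to the reference but off in magnitude; the mean-relative-error bound is what catches that drift.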