From cdf6bac527ac50e202670d7615ba21f525a56878 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:16:16 +0200 Subject: [PATCH 01/11] remove te v1 tests from CI --- .lightning/workflows/transformer-engine.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.lightning/workflows/transformer-engine.yaml b/.lightning/workflows/transformer-engine.yaml index c0c3e8acaf..6034d20e69 100644 --- a/.lightning/workflows/transformer-engine.yaml +++ b/.lightning/workflows/transformer-engine.yaml @@ -12,7 +12,6 @@ parametrize: matrix: test_file: - test_transformer_engine_executor.py - - test_transformer_engine_v1_executor.py run: | whereis nvidia From 6fc46fe4c00bf0861b7b31e9f9ad07968181aa02 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:18:17 +0200 Subject: [PATCH 02/11] remove te v1 tests and from benchmarks --- thunder/benchmarks/benchmark_litgpt.py | 5 - thunder/tests/distributed/test_ddp.py | 271 +--------------- thunder/tests/distributed/test_fsdp.py | 276 +--------------- thunder/tests/test_extend.py | 1 - thunder/tests/test_recipes.py | 8 +- .../test_transformer_engine_v1_executor.py | 304 ------------------ 6 files changed, 8 insertions(+), 857 deletions(-) delete mode 100644 thunder/tests/test_transformer_engine_v1_executor.py diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index 32264ebdc2..14bef9d946 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -686,11 +686,6 @@ def setup_compile(self, model): executors.insert(0, torch_compile_ex) - if "transformerengine_v1" in self.compile: - from thunder.executors.transformer_engine_v1ex import transformer_engine_v1_ex - - executors.insert(0, transformer_engine_v1_ex) - elif "transformerengine" in self.compile: from thunder.executors.transformer_engineex import ( transformer_engine_ex, diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index 99617366de..cfb2bf8111 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -11,6 +11,7 @@ import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close +from lightning_utilities.core.imports import package_available import thunder import thunder.executors @@ -19,11 +20,6 @@ from thunder.distributed import ddp from thunder.tests.framework import instantiate, TorchExecutor -from thunder.executors.transformer_engine_v1ex import ( - transformer_engine_v1_ex, - TE_AVAILABLE, - te_sync_fp8_meta_bwd, -) from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform @@ -32,7 +28,7 @@ # This will be correctly updated below when TE Engine is installed # and if the current environment doesn't support FP8. 
fp8_support_reason: str = "" -if TE_AVAILABLE: +if package_available("transformer_engine"): from transformer_engine.pytorch import fp8_autocast from transformer_engine.pytorch import Linear as TELinear from transformer_engine.pytorch.fp8 import ( @@ -40,7 +36,6 @@ FP8GlobalStateManager, get_default_fp8_recipe, ) - from transformer_engine.common.recipe import MXFP8BlockScaling from thunder.tests.test_transformer_engine_executor import te_assert_close is_fp8_supported, fp8_support_reason = check_fp8_support() @@ -360,221 +355,6 @@ def _test_native_ddp_helper(input_data): return None -def _test_ddp_transformer_engine_v1(input_data): - # Test Description: We run a dummy training loop for a simple `Linear(Relu(Linear(x)))` - # model with thunder (using TE executor) and with PyTorch eager + TE - # and verify that the weights have converged to same value and - # fp8 meta state is same after `n_iter`. - init_method, world_size, rank, executor, device, dtype, _unused_kwargs = input_data - devicetype = devices.device_from_string(device).devicetype - _unused_dtype = ltorch.to_torch_dtype(dtype) - init_per_process_distributed(init_method, devicetype, world_size, rank) - - torch.cuda.set_device(rank) - - dim = 256 - # Running more iterations leads to `nan` for both eager and thunder - # with BlockScaling. - # Potentially because we are training on dummy data and task - n_iter = 5 - - class ThunderModel(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.fc1 = torch.nn.Linear(dim, dim, bias=False) - self.fc2 = torch.nn.Linear(dim, dim, bias=False) - - def forward(self, x): - return self.fc2(torch.nn.functional.relu(self.fc1(x))) - - # Weights - fc1_weight = torch.randn(dim, dim, requires_grad=True).cuda() - fc2_weight = torch.randn(dim, dim, requires_grad=True).cuda() - - # Inputs (different input on different rank). - if rank == 0: - x = torch.arange(dim * dim, dtype=torch.float).view(dim, dim).cuda() - if rank == 1: - x = torch.randn(dim, dim).cuda() * 100 - - thunder_model = ThunderModel().cuda() - thunder_model.fc1.weight.data = fc1_weight.clone() - thunder_model.fc2.weight.data = fc2_weight.clone() - - jit_model = thunder.distributed.ddp( - thunder.jit( - thunder_model, - executors=[ - transformer_engine_v1_ex, - ] - + executor.executors_list(), - ) - ) - - optim = torch.optim.SGD(jit_model.parameters()) - - for _ in range(n_iter): - o = jit_model(x).sum() - o.backward() - optim.step() - optim.zero_grad() - - # See https://github.com/NVIDIA/TransformerEngine/issues/814 - FP8GlobalStateManager.reset() - - class TEModel(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.fc1 = TELinear(dim, dim, bias=False) - self.fc2 = TELinear(dim, dim, bias=False) - - def forward(self, x): - return self.fc2(torch.nn.functional.relu(self.fc1(x))) - - te_model = TEModel().cuda() - te_model.fc1.weight.data = fc1_weight.clone() - te_model.fc2.weight.data = fc2_weight.clone() - - ddp_model = DDP(te_model) - - optim = torch.optim.SGD(te_model.parameters()) - - for _ in range(n_iter): - with fp8_autocast(): - o = ddp_model(x).sum() - - o.backward() - optim.step() - optim.zero_grad() - - thunder_to_te_layer_map = {"te_linear_0": te_model.fc1, "te_linear_1": te_model.fc2} - - fwd_traces = thunder.last_traces(jit_model) - - def is_same_across_ranks(t): - t_clone = t.clone() - torch.distributed.all_reduce(t_clone, op=torch.distributed.ReduceOp.AVG) - assert_close(t, t_clone) - - # Compare the state of the two models. 
- comparison_exceptions = [] - if not isinstance( - get_default_fp8_recipe(), MXFP8BlockScaling - ): # MXFP8BlockScaling recipe doesn't have state like scale, amax_history. - for bound_symbol in fwd_traces[-1].bound_symbols: - if "te_linear" in bound_symbol.sym.name: - thunder_fp8_meta = bound_symbol._call_ctx[bound_symbol.sym.name].func.fp8_meta - te_fp8_meta = thunder_to_te_layer_map[bound_symbol.sym.name].fp8_meta - try: - # fwd tensor history - assert_close(thunder_fp8_meta["scaling_fwd"].scale, te_fp8_meta["scaling_fwd"].scale) - assert_close(thunder_fp8_meta["scaling_fwd"].amax_history, te_fp8_meta["scaling_fwd"].amax_history) - # bwd tensor history - assert_close(thunder_fp8_meta["scaling_bwd"].scale, te_fp8_meta["scaling_bwd"].scale) - assert_close(thunder_fp8_meta["scaling_bwd"].amax_history, te_fp8_meta["scaling_bwd"].amax_history) - - # This has to be on all ranks so that the computation is not blocked - is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale) - # NOTE: TE forward tensor meta-data sync - # Syncing of FP8 meta-data happens in two step in the forward pass. - # 1. When we enter the fp8_autocast(), all the forward fp8 meta-data - # in global buffer is synced. - # See: https://github.com/NVIDIA/TransformerEngine/blob/6a9edc38bf9b941b7d369af5103fa8fe0b121d61/transformer_engine/pytorch/fp8.py#L409-L412 - # 2. Post this, in the forward pass of the module in `prepare_forward`, - # we read from the global-buffer the synced meta-data. - # See: https://github.com/NVIDIA/TransformerEngine/blob/6a9edc38bf9b941b7d369af5103fa8fe0b121d61/transformer_engine/pytorch/module/base.py#L539-L545 - # However, at the end of this forward pass, we have seen new inputs and outputs. Their amax are recorded on - # 0th row of `amax_history` (which will be synced only in the next forward pass). - # So, here we check that every row except for `0` is same. - is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history[1:]) - is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale) - is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].amax_history) - except Exception as e: - # Return exceptions only for rank==0 - if rank == 0: - comparison_exceptions.append(e) - - # Compare weights after `n_iters` - try: - assert_close(thunder_model.fc1.weight, te_model.fc1.weight) - assert_close(thunder_model.fc2.weight, te_model.fc2.weight) - except Exception as e: - # Return exceptions only for rank==0 - if rank == 0: - comparison_exceptions.append(e) - - return comparison_exceptions - - -def _test_ddp_transformer_engine_v1_llama_sanity(input_data): - # Test Description: We run a dummy training loop for a Transformer Model - # We run a few iterations to see that TransformerEngine doesn't throw internal assertion - # due to reordering of forward and backward operators. - # (This test will fail without `_rearrange_transformer_engine_linear` in `torch_autograd.py`) - # For more details, see docstring for `_rearrange_transformer_engine_linear` in transformer_engine_v1_ex.py. 
- from thunder.tests.llama2_model import Transformer, ModelArgs - - init_method, world_size, rank, executor, device, dtype, _unused_kwargs = input_data - devicetype = devices.device_from_string(device).devicetype - _unused_dtype = ltorch.to_torch_dtype(dtype) - init_per_process_distributed(init_method, devicetype, world_size, rank) - - torch.cuda.set_device(rank) - # data - batch_size = 64 - max_seq_len = 64 - vocab_size = 64 - - model_args = dict( - dim=64, - n_layers=1, - n_heads=2, - n_kv_heads=2, - vocab_size=vocab_size, - multiple_of=32, - max_seq_len=max_seq_len, - dropout=0.0, - hidden_dim=64, - ) - gptconf = ModelArgs(**model_args) - model = Transformer(gptconf) - model.to(device) - x = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device) - y = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device) - jit_model = thunder.distributed.ddp( - thunder.jit(model, executors=(transformer_engine_v1_ex,) + thunder.get_default_executors()) - ) - - sanity_exceptions = [] - try: - for _ in range(5): - out = jit_model(x, y).sum() - out.backward() - - bwd_exec_trace = thunder.last_backward_traces(jit_model)[-1] - - # Last symbol of the trace should be `return` - return_sym_idx = len(bwd_exec_trace.bound_symbols) - 1 - assert thunder.core.prims.PrimIDs.RETURN == bwd_exec_trace.bound_symbols[return_sym_idx].sym.id - - # Verify that the symbol to sync backward - # fp8 metadata is present in backward trace. - for idx, bsym in enumerate(bwd_exec_trace.bound_symbols): - if bsym.sym.id == te_sync_fp8_meta_bwd.id: - # Verify that `te_sync_fp8_meta_bwd` is before the last symbol of the trace - # which is `return` - assert idx < return_sym_idx - break - else: - raise RuntimeError("Backward sync symbol not found.") - except Exception as e: - sanity_exceptions.append(e) - - if rank == 0: - return sanity_exceptions - return None - - def _test_ddp_transformer_engine(input_data): # Test Description: We run a dummy training loop for a simple `Linear(Relu(Linear(x)))` # model with thunder (using TE executor) and with PyTorch eager + TE @@ -682,8 +462,6 @@ def _test_ddp_transformer_engine_llama_sanity(input_data): # Test Description: We run a dummy training loop for a Transformer Model # We run a few iterations to see that TransformerEngine doesn't throw internal assertion # due to reordering of forward and backward operators. - # (This test will fail without `_rearrange_transformer_engine_linear` in `torch_autograd.py`) - # For more details, see docstring for `_rearrange_transformer_engine_linear` in transformer_engine_v1_ex.py. from thunder.tests.llama2_model import Transformer, ModelArgs from thunder.core.proxies import variableify @@ -788,51 +566,6 @@ def test_native_ddp(executor, devices, dtype, bucket_size_in_mb): pass -@instantiate( - dtypes=(thunder.float32,), - num_devices=2, - devicetypes=(devices.DeviceType.CUDA,), - executors=(TorchExecutor,), - decorators=( - pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."), - pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason), - # NOTE: Setting `NVTE_TORCH_COMPILE` - # It is important to set this flag so that TE doesn't use - # `torch.compile` to fuse a few operations. This is because - # `torch.compile` creates a new process and that leads to - # the error : daemonic processes are not allowed to have children - # when running the tests. 
- # With the setting below, we use `torch.jit` for this test suite - # See: https://github.com/NVIDIA/TransformerEngine/blob/a38b291b0d1b04847e8ab1df8550df642a03a27d/transformer_engine/pytorch/jit.py#L11-L19 - # NOTE: We don't pass `clear=True` to `unittest.mock.patch.dict` as that may clear paths - # from environment leading to picking up of incorrect dependencies in the spawned process. - unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}), - ), -) -@distributed_wrapper("test_ddp_transformer_engine_v1", _test_ddp_transformer_engine_v1) -def test_ddp_transformer_engine_v1(executor, devices, dtype): - pass - - -@instantiate( - dtypes=(thunder.float32,), - num_devices=2, - devicetypes=(devices.DeviceType.CUDA,), - executors=(TorchExecutor,), - decorators=( - pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."), - pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason), - # See NOTE: Setting `NVTE_TORCH_COMPILE` - # NOTE: We don't pass `clear=True` to `unittest.mock.patch.dict` as that may clear paths - # from environment leading to picking up of incorrect dependencies in the spawned process. - unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}), - ), -) -@distributed_wrapper("test_ddp_transformer_engine_v1_llama_sanity", _test_ddp_transformer_engine_v1_llama_sanity) -def test_ddp_transformer_engine_v1_llama_sanity(executor, devices, dtype): - pass - - @instantiate( dtypes=(thunder.float32,), num_devices=2, diff --git a/thunder/tests/distributed/test_fsdp.py b/thunder/tests/distributed/test_fsdp.py index 92832d17cf..cc6389ad27 100644 --- a/thunder/tests/distributed/test_fsdp.py +++ b/thunder/tests/distributed/test_fsdp.py @@ -14,6 +14,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel from torch.distributed.fsdp.wrap import always_wrap_policy from torch.testing import assert_close, make_tensor +from lightning_utilities.core.imports import package_available import thunder import thunder.executors @@ -23,11 +24,6 @@ from thunder.distributed import fsdp from thunder.tests.framework import instantiate, TorchExecutor -from thunder.executors.transformer_engine_v1ex import ( - transformer_engine_v1_ex, - TE_AVAILABLE, -) - from thunder.executors.transformer_engineex import transformer_engine_ex, TransformerEngineTransform @@ -35,7 +31,7 @@ # This will be correctly updated below when TE Engine is installed # and if the current environment doesn't support FP8. fp8_support_reason: str = "" -if TE_AVAILABLE: +if package_available("transformer_engine"): from transformer_engine.pytorch import fp8_autocast from transformer_engine.pytorch import Linear as TELinear from transformer_engine.pytorch.fp8 import ( @@ -769,215 +765,6 @@ def finalize_pg(pg): return None -def _test_fsdp_transformer_engine_v1(input_data): - # Test Description: We run a dummy training loop for a simple `Linear(Relu(Linear(x)))` - # model with thunder (using TE executor) and with PyTorch eager + TE - # and verify that the weights have converged to same value and - # fp8 meta state is same after `n_iter`. 
- init_method, world_size, rank, executor, device, _unused_dtype, kwargs = input_data - thunder_fsdp_strategy, intermediate_activation_sharding = kwargs["thunder_fsdp_strategy_and_intermediate_sharding"] - devicetype = devices.device_from_string(device).devicetype - - # Setting LOCAL_RANK is necessary for thunder.distributed.fsdp - with unittest.mock.patch.dict(os.environ, {"LOCAL_RANK": str(rank)}): - init_per_process_distributed(init_method, devicetype, world_size, rank) - torch.cuda.set_device(rank) - - dim = 256 - # Running more iterations leads to `nan` for both eager and thunder - # with BlockScaling. - n_iter = 5 - - class ThunderModel(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.fc1 = torch.nn.Linear(dim, dim, bias=False) - self.fc2 = torch.nn.Linear(dim, dim, bias=False) - - def forward(self, x): - return self.fc2(torch.nn.functional.relu(self.fc1(x))) - - # Weights - fc1_weight = torch.randn(dim, dim, requires_grad=True, device="cuda") - fc2_weight = torch.randn(dim, dim, requires_grad=True, device="cuda") - - # Inputs (different input on different rank). - if rank == 0: - x = torch.arange(dim * dim, dtype=torch.float, device="cuda").view(dim, dim) - if rank == 1: - x = torch.randn(dim, dim, device="cuda") * 100 - - with torch.device("cuda"): - thunder_model = ThunderModel() - thunder_model.fc1.weight.data = fc1_weight.clone() - thunder_model.fc2.weight.data = fc2_weight.clone() - - jit_model = thunder.distributed.fsdp( - thunder.jit( - thunder_model, - executors=[ - transformer_engine_v1_ex, - ] - + executor.executors_list(), - fp8_shard_intermediate_activation=intermediate_activation_sharding, - ), - sharding_strategy=thunder_fsdp_strategy, - ) - - optim = torch.optim.SGD(jit_model.parameters()) - - for _ in range(n_iter): - o = jit_model(x).sum() - o.backward() - optim.step() - optim.zero_grad() - - # See https://github.com/NVIDIA/TransformerEngine/issues/814 - FP8GlobalStateManager.reset() - - class TEModel(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.fc1 = TELinear(dim, dim, bias=False) - self.fc2 = TELinear(dim, dim, bias=False) - - def forward(self, x): - return self.fc2(torch.nn.functional.relu(self.fc1(x))) - - with torch.device("cuda"): - te_model = TEModel() - te_model.fc1.weight.data = fc1_weight.clone() - te_model.fc2.weight.data = fc2_weight.clone() - - fsdp_model = FullyShardedDataParallel(te_model, auto_wrap_policy=always_wrap_policy) - if intermediate_activation_sharding: - transformer_engine.pytorch.distributed.prepare_te_modules_for_fsdp(fsdp_model) - optim = torch.optim.SGD(te_model.parameters()) - - for _ in range(n_iter): - with fp8_autocast(): - o = fsdp_model(x).sum() - - o.backward() - optim.step() - optim.zero_grad() - - thunder_to_te_layer_map = {"te_linear_0": te_model.fc1, "te_linear_1": te_model.fc2} - - fwd_traces = thunder.last_traces(jit_model) - - def is_same_across_ranks(t): - t_clone = t.clone() - torch.distributed.all_reduce(t_clone, op=torch.distributed.ReduceOp.AVG) - assert_close(t, t_clone) - - # Compare the state of the two models. - comparison_exceptions = [] - if not isinstance( - get_default_fp8_recipe(), MXFP8BlockScaling - ): # BlockScaling recipe doesn't have state like scale, amax_history. 
- for bound_symbol in fwd_traces[-1].bound_symbols: - if "te_linear" in bound_symbol.sym.name: - thunder_fp8_meta = bound_symbol._call_ctx[bound_symbol.sym.name].func.fp8_meta - te_fp8_meta = thunder_to_te_layer_map[bound_symbol.sym.name].fp8_meta - try: - # fwd tensor history - assert_close(thunder_fp8_meta["scaling_fwd"].scale, te_fp8_meta["scaling_fwd"].scale) - assert_close( - thunder_fp8_meta["scaling_fwd"].amax_history, te_fp8_meta["scaling_fwd"].amax_history - ) - # bwd tensor history - assert_close(thunder_fp8_meta["scaling_bwd"].scale, te_fp8_meta["scaling_bwd"].scale) - assert_close( - thunder_fp8_meta["scaling_bwd"].amax_history, te_fp8_meta["scaling_bwd"].amax_history - ) - - # This has to be on all ranks so that the computation is not blocked - is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].scale) - # See NOTE: TE forward tensor meta-data sync - is_same_across_ranks(thunder_fp8_meta["scaling_fwd"].amax_history[1:]) - is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].scale) - is_same_across_ranks(thunder_fp8_meta["scaling_bwd"].amax_history) - except Exception as e: - # Return exceptions only for rank==0 - if rank == 0: - comparison_exceptions.append(e) - - # Compare weights after `n_iters` - shard_size = int(dim / world_size) - fsdp_te_params = tuple(te_model.parameters()) - try: - assert_close(jit_model.get_parameter("fc1.weight"), fsdp_te_params[0].view(shard_size, dim)) - assert_close(jit_model.get_parameter("fc2.weight"), fsdp_te_params[1].view(shard_size, dim)) - except Exception as e: - # Return exceptions only for rank==0 - if rank == 0: - comparison_exceptions.append(e) - - return comparison_exceptions - - -def _test_fsdp_transformer_engine_v1_bucketing(input_data): - # Test Description: Test is to that TE works with bucketing. 
- from thunder.tests.llama2_model import Transformer, ModelArgs - - init_method, world_size, rank, executor, device, _unused_dtype, kwargs = input_data - thunder_fsdp_strategy, bucketing = kwargs["thunder_fsdp_strategy_and_bucketing"] - devicetype = devices.device_from_string(device).devicetype - - # Setting LOCAL_RANK is necessary for thunder.distributed.fsdp - with unittest.mock.patch.dict(os.environ, {"LOCAL_RANK": str(rank)}): - init_per_process_distributed(init_method, devicetype, world_size, rank) - torch.cuda.set_device(rank) - - # data - batch_size = 64 - max_seq_len = 64 - vocab_size = 64 - - model_args = dict( - dim=64, - n_layers=2, - n_heads=2, - n_kv_heads=2, - vocab_size=vocab_size, - multiple_of=32, - max_seq_len=max_seq_len, - dropout=0.0, - hidden_dim=64, - ) - gptconf = ModelArgs(**model_args) - model = Transformer(gptconf) - model.to(device) - x = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device) - y = torch.randint(0, vocab_size, (batch_size, max_seq_len), dtype=torch.int64, device=device) - jit_model = thunder.distributed.fsdp( - thunder.jit(model, executors=(transformer_engine_v1_ex,) + thunder.get_default_executors()), - sharding_strategy=thunder_fsdp_strategy, - bucketing_strategy=bucketing, - ) - - sanity_exceptions = [] - try: - for _ in range(5): - out = jit_model(x, y).sum() - out.backward() - - # Verifies te_linear was called - forward_trace = thunder.last_traces(jit_model) - backward_trace = thunder.last_backward_traces(jit_model) - assert any(bsym.sym.name.startswith("te_linear") for bsym in forward_trace[-1].bound_symbols) - assert any( - bsym.sym.name.startswith("te_functional_linear_backward") for bsym in backward_trace[-1].bound_symbols - ) - except Exception as e: - sanity_exceptions.append(e) - - if rank == 0: - return sanity_exceptions - return None - - # NOTE CPU is skipped because of # RuntimeError: no support for _allgather_base in Gloo process group @instantiate( @@ -1000,65 +787,6 @@ def test_native_fsdp(executor, devices, dtype, fsdp_bucketing_strategy): pass -@instantiate( - dtypes=(thunder.float32,), - num_devices=2, - devicetypes=(devices.DeviceType.CUDA,), - executors=(TorchExecutor,), - decorators=( - # NOTE: ddp_wrapper - pytest.mark.parametrize( - "thunder_fsdp_strategy_and_intermediate_sharding", - ( - (FSDPType.ZERO2, False), - (FSDPType.ZERO3, False), - # Intermediate sharding is only availabe TE v1.8 onwards - pytest.param( - (FSDPType.ZERO3, True), - marks=pytest.mark.skip("Intermediate sharding is errors in TE 2.0 (also with eager)."), - ), - ), - ), - pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."), - pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason), - # See NOTE: Setting `NVTE_TORCH_COMPILE` - # NOTE: We don't pass `clear=True` to `unittest.mock.patch.dict` as that may clear paths - # from environment leading to picking up of incorrect dependencies in the spawned process. 
- unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}), - ), -) -@distributed_wrapper("test_fsdp_transformer_engine_v1", _test_fsdp_transformer_engine_v1) -def test_fsdp_transformer_engine_v1(executor, devices, dtype, thunder_fsdp_strategy_and_intermediate_sharding): - pass - - -@instantiate( - dtypes=(thunder.float32,), - num_devices=2, - devicetypes=(devices.DeviceType.CUDA,), - executors=(TorchExecutor,), - decorators=( - # NOTE: ddp_wrapper - pytest.mark.parametrize( - "thunder_fsdp_strategy_and_bucketing", - ( - (FSDPType.ZERO3, FSDPBucketingStrategy.LAYER), - (FSDPType.ZERO3, FSDPBucketingStrategy.BLOCK), - ), - ), - pytest.mark.skipif(not TE_AVAILABLE, reason="TransformerEngine is not installed."), - pytest.mark.skipif(not is_fp8_supported, reason=fp8_support_reason), - # See NOTE: Setting `NVTE_TORCH_COMPILE` - # NOTE: We don't pass `clear=True` to `unittest.mock.patch.dict` as that may clear paths - # from environment leading to picking up of incorrect dependencies in the spawned process. - unittest.mock.patch.dict(os.environ, {"NVTE_TORCH_COMPILE": "0"}), - ), -) -@distributed_wrapper("test_fsdp_transformer_engine_bucketing", _test_fsdp_transformer_engine_v1_bucketing) -def test_fsdp_transformer_engine_v1_bucketing(executor, devices, dtype, thunder_fsdp_strategy_and_bucketing): - pass - - def _test_fsdp_transformer_engine(input_data): # Test Description: We run a dummy training loop for a simple `Linear(Relu(Linear(x)))` # model with thunder (using TE executor) and with PyTorch eager + TE diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py index 13b8d50ef1..66067cbd1c 100644 --- a/thunder/tests/test_extend.py +++ b/thunder/tests/test_extend.py @@ -134,7 +134,6 @@ def test_get_all_executors_includes_all_native_executors(): "torchcompile_cat", "torchcompile_xentropy", "python", - "transformer_engine_v1", } if package_available("triton"): # `triton` maybe installed on a system without GPU. 
diff --git a/thunder/tests/test_recipes.py b/thunder/tests/test_recipes.py index 273baa19ec..efc15ce861 100644 --- a/thunder/tests/test_recipes.py +++ b/thunder/tests/test_recipes.py @@ -224,13 +224,13 @@ def test_plugins_composition(monkeypatch): with patch("thunder.jit") as mock_jit: _ = thunder.compile(model, plugins="fp8") call_args = mock_jit.call_args - assert "transformer_engine_v1" in [el.name for el in call_args.kwargs["executors"]] + assert "transformer_engine" in [el.name for el in call_args.kwargs["executors"]] for ex in get_expected_executors(): assert ex.name in [el.name for el in call_args.kwargs["executors"]] _ = thunder.compile(model, plugins=["fp8"]) call_args = mock_jit.call_args - assert "transformer_engine_v1" in [el.name for el in call_args.kwargs["executors"]] + assert "transformer_engine" in [el.name for el in call_args.kwargs["executors"]] for ex in get_expected_executors(): assert ex.name in [el.name for el in call_args.kwargs["executors"]] @@ -238,7 +238,7 @@ def test_plugins_composition(monkeypatch): _ = thunder.compile(model, plugins=[FP8()]) call_args = mock_jit.call_args - assert "transformer_engine_v1" in [el.name for el in call_args.kwargs["executors"]] + assert "transformer_engine" in [el.name for el in call_args.kwargs["executors"]] for ex in get_expected_executors(): assert ex.name in [el.name for el in call_args.kwargs["executors"]] @@ -266,7 +266,7 @@ def test_plugins_composition(monkeypatch): transforms = call_args.kwargs["transforms"] for expected in expected_transforms: assert any(isinstance(el, expected) for el in transforms) - assert "transformer_engine_v1" in [el.name for el in call_args.kwargs["executors"]] + assert "transformer_engine" in [el.name for el in call_args.kwargs["executors"]] @pytest.mark.skipif(not nvfuser_available(), reason="nvFuser is not available") diff --git a/thunder/tests/test_transformer_engine_v1_executor.py b/thunder/tests/test_transformer_engine_v1_executor.py deleted file mode 100644 index 1ddad9c4f7..0000000000 --- a/thunder/tests/test_transformer_engine_v1_executor.py +++ /dev/null @@ -1,304 +0,0 @@ -import pytest -import torch -from torch.testing import assert_close - -import thunder -from thunder.tests.framework import requiresCUDA - -# NOTE: On SM120/121, TE defaults to using Float8BlockScaling -# which is currently unsupported in thunder, we skip the tests for these SM architectures. -from thunder.tests.utils import skip_on_sm120_and_sm121, is_sm120_orsm121 - -pytest.importorskip("transformer_engine", reason="transformer_engine was not found, skipping the tests.") -from thunder.executors.transformer_engine_v1ex import transformer_engine_v1_ex -from transformer_engine.common import recipe -import transformer_engine.pytorch as te - -# FP8 is supported on compute arch 8.9 onwards. -# MXFP8 is supported on compute arch 10.0 onwards. -# Skip the tests if current hardware is not supported. -is_fp8_supported, msg_fp8 = te.fp8.check_fp8_support() -is_mxfp8_supported, msg_mxfp8 = te.fp8.check_mxfp8_support() -if not is_fp8_supported: - pytest.skip(msg_fp8, allow_module_level=True) - -hybrid_fp8_delayed_scaling_recipe = recipe.DelayedScaling() -mxfp8_e4m3_recipe = recipe.MXFP8BlockScaling() - -# `None` is used to test the default recipe. 
-recipes = (None, hybrid_fp8_delayed_scaling_recipe, mxfp8_e4m3_recipe) -recipe_ids = ("default", "delayed_scaling", "mxfp8_e4m3") - - -@requiresCUDA -@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) -def test_te_linear_forward_backward(fp8_recipe: recipe.Recipe): - if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): - pytest.skip(msg_mxfp8) - - if is_sm120_orsm121 and fp8_recipe is None: - pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") - - # Test Description: - # Verify that `torch.nn.functional.linear` is replaced with `te_linear_*` - # and the output as well as the gradients match for thunder compiled code. - dtype = torch.bfloat16 - device = "cuda" - - # TE inputs (3D input) - x_te = torch.randn(3, 768, 4096, device=device, dtype=dtype, requires_grad=True) - te_linear1 = te.Linear(4096, 4096, params_dtype=dtype) - te_linear2 = te.Linear(4096, 2048, params_dtype=dtype) - - # thunder inputs - x = x_te.detach().clone() - x.requires_grad_(True) - w1 = te_linear1.weight.detach().clone() - w1.requires_grad_(True) - w2 = te_linear2.weight.detach().clone() - w2.requires_grad_(True) - - def fn(x, w1, w2): - o = torch.nn.functional.linear(x, w1) - return torch.nn.functional.linear(o + x, w2) - - cfn = thunder.jit(fn, executors=[transformer_engine_v1_ex], te_fp8_recipe=fp8_recipe) - - # Enable autocasting for the forward pass - thunder_result = cfn(x, w1, w2) - - # Enable autocasting for the forward pass - with te.fp8_autocast(fp8_recipe=fp8_recipe): - inter_result = te_linear1(x_te) - te_result = te_linear2(inter_result + x_te) - - # Verifies the result is close to TE - assert_close(thunder_result, te_result) - - grad_output = torch.randn_like(te_result) - te_result.backward(grad_output) - thunder_result.backward(grad_output) - - assert_close(x.grad, x_te.grad) - assert_close(w1.grad, te_linear1.weight.grad) - assert_close(w2.grad, te_linear2.weight.grad) - - # Verifies te_linear was called - forward_trace = thunder.last_traces(cfn) - backward_trace = thunder.last_backward_traces(cfn) - assert any(bsym.sym.name.startswith("te_linear") for bsym in forward_trace[-1].bound_symbols) - assert any(bsym.sym.name.startswith("te_functional_linear_backward") for bsym in backward_trace[-1].bound_symbols) - - -@requiresCUDA -@pytest.mark.parametrize("fp8_recipe", recipes, ids=recipe_ids) -def test_te_linear_forward_backward_multiple_iteration(fp8_recipe): - if fp8_recipe and not (fp8_recipe.delayed() or is_mxfp8_supported): - pytest.skip(msg_mxfp8) - - if is_sm120_orsm121 and fp8_recipe is None: - pytest.skip("On SM120/121, default recipe is Float8BlockScaling which is not supported") - - # Test Description: - # In this test, we verify whether a model using TransformerEngine Linear - # and transformer_engine executor converge to same state. - # Since, the FP8 operations are stateful, we want to verify that - # our output matches over multiple iterations (where state handling comes into picture) - dtype = torch.bfloat16 - device = "cuda" - # Running more iterations leads to `nan` for both eager and thunder - # with BlockScaling. 
- # Potentially because we are training on dummy data and task - iterations = 6 - - # TE inputs - input_shape = (768, 4096) - te_linear1 = te.Linear(4096, 4096, params_dtype=dtype) - te_linear2 = te.Linear(4096, 2048, params_dtype=dtype) - - torch.nn.init.kaiming_uniform_(te_linear1.weight) - torch.nn.init.kaiming_uniform_(te_linear2.weight) - - def clone_params(*params): - return tuple(param.detach().clone() for param in params) - - # Parameters for thunder to optimize - w1, w2, b1, b2 = clone_params(te_linear1.weight, te_linear2.weight, te_linear1.bias, te_linear2.bias) - - target_value = torch.randint(42, (768,), dtype=torch.int64, device=device) - - inputs = tuple(torch.rand(*input_shape, device=device, dtype=dtype) for _ in range(iterations)) - - def train_model(model, optimizer): - # Run for `iterations`. - for iter_n in range(iterations): - x = inputs[iter_n] - result = model(x) - loss = torch.nn.functional.cross_entropy(result, target_value) - loss.backward() - optimizer.step() - optimizer.zero_grad() - - def te_model(x): - # Enable autocasting for the forward pass - with te.fp8_autocast(fp8_recipe=fp8_recipe): - return te_linear2(te_linear1(x)) - - te_sgd_optimizer = torch.optim.SGD(list(te_linear1.parameters()) + list(te_linear2.parameters())) - - train_model(te_model, te_sgd_optimizer) - - def fn(x, w1, w2, b1, b2): - o = torch.nn.functional.linear(x, w1, b1) - return torch.nn.functional.linear(o, w2, b2) - - cfn = thunder.jit(fn, executors=[transformer_engine_v1_ex], te_fp8_recipe=fp8_recipe) - - # Enable grad on thunder params. - list(map(lambda t: t.requires_grad_(True), (w1, w2, b1, b2))) - thunder_sgd_optimizer = torch.optim.SGD([w1, w2, b1, b2]) - - def thunder_model(x): - return cfn(x, w1, w2, b1, b2) - - train_model(thunder_model, thunder_sgd_optimizer) - - # Verify that the weights and biases converge to same value after few iterations. - assert_close(w1, te_linear1.weight) - assert_close(w2, te_linear2.weight) - assert_close(b1, te_linear1.bias) - assert_close(b2, te_linear2.bias) - - -@requiresCUDA -@skip_on_sm120_and_sm121 -def test_te_linear_invalid_inputs(): - def assert_not_transformed(x, w): - def fn(x, w): - return torch.nn.functional.linear(x, w) - - cfn = thunder.jit(fn, executors=[transformer_engine_v1_ex]) - cfn(x, w) - trace = thunder.last_traces(cfn)[-1] - assert not any(bsym.sym.name.startswith("te_linear") for bsym in trace.bound_symbols) - - # CPU is not supported. - device = "cpu" - x = torch.randn(16, 16, device=device) - w = torch.randn(16, 16, device=device) - assert_not_transformed(x, w) - - # Input shapes are not supported by TE. - device = "cuda" - x = torch.randn(16, 4, device=device) - w = torch.randn(16, 4, device=device) - assert_not_transformed(x, w) - - -@requiresCUDA -@skip_on_sm120_and_sm121 -def test_te_with_autocast(): - from thunder.transforms.autocast import autocast - - def foo(x, w): - return thunder.torch.linear(x, w) - - device = "cuda" - x = torch.randn(64, 64, device=device, requires_grad=True) - w = torch.randn(64, 64, device=device, requires_grad=True) - - cfunc = thunder.jit( - autocast(foo, dtype=thunder.dtypes.bfloat16), - executors=[transformer_engine_v1_ex], - disable_preprocessing=True, - ) - cfunc(x, w) - - fwd_traces = thunder.last_traces(cfunc) - # Verify that we have replaced `prims.linear` with `te_linear` - assert any(bsym.sym.name.startswith("te_linear") for bsym in fwd_traces[-1].bound_symbols) - - -# NOTE: strict=False as it passes on Blackwell. -# NOTE: Type of the error is different in different versions. 
-@pytest.mark.xfail( - strict=False, - raises=(ValueError, TypeError), - reason="See https://github.com/Lightning-AI/lightning-thunder/issues/2221", -) -@requiresCUDA -@skip_on_sm120_and_sm121 -def test_te_with_retain_graph(): - def foo(x, w): - return thunder.torch.linear(x, w) - - device = "cuda" - x = torch.randn(16, 16, device=device, requires_grad=True) - w = torch.randn(16, 16, device=device, requires_grad=True) - - cfunc = thunder.jit( - foo, - executors=[transformer_engine_v1_ex], - ) - out = cfunc(x, w) - - # Retain graph is not supported correctly by TE - # https://github.com/NVIDIA/TransformerEngine/issues/990 - out.backward(torch.randn_like(out), retain_graph=True) - out.backward(torch.randn_like(out)) - - -@requiresCUDA -@skip_on_sm120_and_sm121 -def test_te_trace_metadata_propagation(): - # This test is to verify that we correctly propagate metadata `_include_te_fp8_autocast` on - # trace using `from_trace`. `_include_te_fp8_autocast` is used to enable wrapping forward trace with `fp8_autocast`. - def foo(x, w): - return torch.nn.functional.linear(x, w) - - device = "cuda" - x = torch.randn(64, 64, device=device, requires_grad=True) - w = torch.randn(64, 64, device=device, requires_grad=True) - - class MyNoopTransform(thunder.core.transforms.Transform): - def transform_trace_post_optimization(self, computation_trace, **kwargs): - new_trace = thunder.core.trace.from_trace(computation_trace) - new_trace.bound_symbols = computation_trace.bound_symbols - return new_trace - - cfunc = thunder.jit( - foo, - executors=[transformer_engine_v1_ex], - transforms=[ - MyNoopTransform(), - ], - ) - cfunc(x, w) - - fwd_traces = thunder.last_traces(cfunc) - - # Verify that we have `te_linear` in the trace. - assert any(bsym.sym.name.startswith("te_linear") for bsym in fwd_traces[-1].bound_symbols) - - -@skip_on_sm120_and_sm121 -def test_te_grad_computation_with_intermediate(): - # Test for issue - https://github.com/Lightning-AI/lightning-thunder/issues/1966 - def fn(x, w): - # Due to autocast, trace becomes something like this - # t4 = prims.convert_element_type(x, dtypes.bfloat16) # t4: "cuda:0 bf16[32, 32]" - # t5 = prims.convert_element_type(w, dtypes.bfloat16) # t5: "cuda:0 bf16[32, 32]" - # t6 = prims.linear(t4, t5, None) # t6: "cuda:0 bf16[32, 32]" - with torch.autocast("cuda", torch.bfloat16): - return torch.nn.functional.linear(x, w) - - with torch.device("cuda"): - x = torch.randn(32, 32, requires_grad=True) - w = torch.randn(32, 32, requires_grad=True) - - tfn = thunder.jit(fn, executors=(transformer_engine_v1_ex,)) - - o = tfn(x, w) - o.sum().backward() - - assert w.grad is not None From 60682c5e1852dcbb6a3574019fb491f22778bbd7 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:18:43 +0200 Subject: [PATCH 03/11] update recipe with new executor --- thunder/plugins/fp8.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/thunder/plugins/fp8.py b/thunder/plugins/fp8.py index 7d4f2ff1a1..916a591cf6 100644 --- a/thunder/plugins/fp8.py +++ b/thunder/plugins/fp8.py @@ -5,17 +5,29 @@ class FP8(Plugin): """ Plugin for enabling FP8 precision via NVIDIA Transformer Engine, enabling higher throughput of matrix operations in FP8. - See `lightning-thunder/thunder/executors/transformer_engine_v1ex.py` for implementation details. + See `lightning-thunder/thunder/executors/transformer_engine_ex.py` for implementation details. 
""" + def setup_transforms(self): + """ + Fetches the TransformerEngine transform. + + Returns: + list[Transform]: A list containing the TransformerEngine transforms. + """ + + from thunder.executors.transformer_engineex import TransformerEngineTransform + + return [TransformerEngineTransform()] + def setup_executors(self): """ - Imports the Transformer Engine executor. + Imports the TransformerEngine executor. Returns: list[Executor]: A list containing the Transformer Engine executor. """ - from thunder.executors.transformer_engine_v1ex import transformer_engine_v1_ex + from thunder.executors.transformer_engineex import transformer_engine_ex - return [transformer_engine_v1_ex] + return [transformer_engine_ex] From cdd258db3279153b4253ec2967ee20423a6280f6 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:19:10 +0200 Subject: [PATCH 04/11] remove te v1 executor --- thunder/core/trace.py | 12 - thunder/executors/transformer_engine_v1ex.py | 563 ------------------- thunder/extend/__init__.py | 1 - thunder/transforms/autodiff.py | 6 - 4 files changed, 582 deletions(-) delete mode 100644 thunder/executors/transformer_engine_v1ex.py diff --git a/thunder/core/trace.py b/thunder/core/trace.py index ac205c235a..48962ec7d1 100644 --- a/thunder/core/trace.py +++ b/thunder/core/trace.py @@ -406,18 +406,6 @@ def keyfn(class_or_module: type | ModuleType) -> str: program.append("") if include_decorators: - # NOTE: For TransformerEngine executor, we want to wrap the generated - # forward function in fp8_autocast ctx manager. - # In the future, if other executor has similar requirements, we should - # add a new extension point for executors - # NOTE: For TE v1.6 onwards, `fp8_autocast` checks if `torch.is_grad_enabled` for updating - # the FP8 scales/inverses. So this decorator should be applied before `torch.no_grad` (so that - # it is in grad enabled part). 
- from thunder.executors.transformer_engine_v1ex import _is_te_linear_enabled, _get_te_wrapper_string - - if TraceTag.AUGMENTED_FORWARD and _is_te_linear_enabled(import_ctx, object_ctx): - program.append(_get_te_wrapper_string()) - # Disable gradients since Thunder takes care of this (for when calling torch operations) program.append("@torch.no_grad()") # Disable autocast since we already generated the trace with it in consideration (for when calling torch diff --git a/thunder/executors/transformer_engine_v1ex.py b/thunder/executors/transformer_engine_v1ex.py deleted file mode 100644 index 941e736d90..0000000000 --- a/thunder/executors/transformer_engine_v1ex.py +++ /dev/null @@ -1,563 +0,0 @@ -from functools import partial -from itertools import chain -from typing import Any -from collections.abc import Sequence -from collections.abc import Callable -from contextlib import contextmanager, nullcontext -from collections import deque -from importlib.metadata import version -from looseversion import LooseVersion -import warnings - -import torch - -from lightning_utilities.core.imports import package_available - -from thunder.core.proxies import TensorProxy -from thunder.core.trace import get_tracectx -from thunder.core.symbol import Symbol, BoundSymbol -import thunder.core.devices as devices -import thunder.core.prims as prims -from thunder.core.proxies import AnyProxy -from thunder.core.vjp_utils import disable_caching_split_forward_and_backward -from thunder.extend import OperatorExecutor, register_executor -from thunder.core.compile_data import get_compile_option, get_compile_data -from thunder.distributed import FSDPType -from thunder.executors.utils import Context, set_saved_tensors - - -__all__ = [ - "transformer_engine_v1_ex", -] - -TE_AVAILABLE: bool = package_available("transformer_engine") - -# We rely on internal details of TransformerEngine like `_Linear` autograd.Function. -# As these details are not public, they can change -# Ex. addition of a positional argument for cpu_offloading (not as the last argument) -# between version 1.2 and 1.3. -# Hence, we have these guards based on version. - -te: None | Any = None -if TE_AVAILABLE: - try: - import transformer_engine.pytorch as te - from transformer_engine.common.recipe import MXFP8BlockScaling, DelayedScaling - from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE - from transformer_engine.pytorch.module.linear import _Linear - from transformer_engine.pytorch.module.base import TransformerEngineBaseModule - from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, get_default_fp8_recipe - from transformer_engine.pytorch.utils import check_dim_for_fp8_exec - from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled - import transformer_engine_torch as tex - except Exception as ex: - warnings.warn(f"transformer_engine failed to import with exception {ex}") - TE_AVAILABLE = False - - TE_VERSION_2_0_PLUS = LooseVersion(version("transformer_engine")) > LooseVersion("2.0") - if not TE_VERSION_2_0_PLUS: - msg = f"Installed version of transformer_engine {version('transformer_engine')} is not supported, please upgrade to version 2.0 from https://github.com/NVIDIA/TransformerEngine/tree/release_v2.0. `transformer_engine_v1_ex` will not be used." - warnings.warn(msg) - TE_AVAILABLE = False - - -if not TE_AVAILABLE: - TransformerEngineBaseModule = object - -# [NOTE] IMPLEMENTATION DETAILS -# -# We try to re-use TransformerEngine implementation of `Linear` and `_Linear` as much as possible. 
-# As `thunder` expects operator to be passed all of its inputs, we have `TELinear` module which doesn't -# register any parameters and takes all `Tensor` arguments as input (It based on `Linear` from TE) -# FP8 tensors require extra meta-data per Tensor. Similar to TE, this meta-data is saved in module `TELinear`. -# NOTE: Implementation supports a limited set of input sizes where dim0 is divisible by 8 and dim1 is divisible by 16. -# -# Ref to `_Linear`: https://github.com/NVIDIA/TransformerEngine/blob/b957aa475bcbcf22405381d18bd7fefe4fb6b171/transformer_engine/pytorch/module/linear.py#L52 -# Ref to `Linear`: https://github.com/NVIDIA/TransformerEngine/blob/b957aa475bcbcf22405381d18bd7fefe4fb6b171/transformer_engine/pytorch/module/linear.py#L543 -# Stateful Operator: -# This means that every call to this `linear` requires a corresponding `TELinear` instance for -# backing the required FP8 state. This is done by creating a new `BoundSymbol` with corresponding instance -# when replacing calls to `prims.linear` (see `_create_fp8_linear_bound_symbol`). -# Eg. -# Original Program: -# -# def func(a, b, d): -# out = torch.nn.functional.linear(a, b) -# out = torch.nn.functional.linear(out, d) -# return out -# -# Traced Program: -# -# @torch.no_grad() -# @no_autocast -# @transformer_engine.fp8_autocast(fp8_recipe=te_fp8_recipe) -# def func(a, b, d): -# # a: "cuda:0 bf16[16, 32]" -# # b: "cuda:0 bf16[64, 32]" -# # d: "cuda:0 bf16[32, 64]" -# (t0, _) = te_linear_0(a, b, None, is_grad_enabled=False) # Backed by it's own instance of TELinear -# del a, b -# (t1, _) = te_linear_1(t0, d, None, is_grad_enabled=False) # Backed by it's own instance of TELinear -# del t0, d -# return t1 -# -# Managing Residuals for Backward: -# As we re-use `_Linear` which is a `torch.autograd.Function`, it requires a `ctx` Context object to -# save required objects for backward. We have our own `Context` class for the same. -# `_Linear` saves a lot of objects in `ctx` some of which is generated during the first call to `forward`. -# -# [NOTE] Enable grad within context -# To correctly compute the gradients, `_Linear` expects `requires_grad` to be -# set on the `input`, `weight` and `bias` tensor. -# But when applying `vjp`, the input tensor may not have requires_grad -# (as the rules take care relevant transformation). Thus we use `enable_grad` decorator -# when applying the forward and backward rule. -# -# Reference to points where TE looks at `requires_grad`: -# Ref: https://github.com/NVIDIA/TransformerEngine/blob/b957aa475bcbcf22405381d18bd7fefe4fb6b171/transformer_engine/pytorch/module/linear.py#L264 -# Ref: https://github.com/NVIDIA/TransformerEngine/blob/b957aa475bcbcf22405381d18bd7fefe4fb6b171/transformer_engine/pytorch/module/linear.py#L434 - - -# Eagerly apply map without -# storing the output. -def eager_map(*args): - return deque(map(*args), maxlen=0) - - -# Set requires_grad to True for passed tensors -# in this context. 
-@contextmanager -def enable_grad(*tensors): - original_requires_grad = tuple(map(lambda t: t.requires_grad, tensors)) - eager_map(lambda t: t.requires_grad_(True), tensors) - try: - yield - finally: - eager_map(lambda t, org_r_grad: t.requires_grad_(org_r_grad), tensors, original_requires_grad) - - -FP8_SHARD_INTERMEDIATE_ACTIVATIONS = "fp8_shard_intermediate_activation" - - -def _should_shard_intermediate() -> bool: - compile_data = get_compile_data() - - should_shard_intermediate_options: bool | None = get_compile_option( - FP8_SHARD_INTERMEDIATE_ACTIVATIONS, - "transformer_engine_v1_ex: Whether the intermediate activations should be sharded or not. Only applicable with FSDP Zero3, ignored otherwise.", - ) - - if getattr(compile_data.fn, "use_fsdp", False): - if getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO3 and should_shard_intermediate_options: - return True - - if should_shard_intermediate_options: # user passed `True` but FSDPType was not Zero3 - warnings.warn( - f"transformer_engine_v1_ex: {FP8_SHARD_INTERMEDIATE_ACTIVATIONS} is only applicable for FSDP Zero3" - ) - - return False - - -def _get_num_saved_tensors(fp8_recipe): - MIN_DIM = MXFP8_BLOCK_SCALING_SIZE - te_linear = te.Linear(MIN_DIM, MIN_DIM) - - x = torch.randn(MIN_DIM, MIN_DIM, device="cuda") - with te.fp8_autocast(fp8_recipe=fp8_recipe): - o = te_linear(x) - return len(o.grad_fn.saved_tensors) - - -class TELinear(TransformerEngineBaseModule): - def __init__(self, in_features: int, out_features: int) -> None: - super().__init__() - - self.in_features = in_features - self.out_features = out_features - - # Used by `get_fp8_weights_scratchpad` - self.primary_weights_in_fp8 = False - - if FP8GlobalStateManager.with_fp8_parameters(): - raise RuntimeError("Primary weights in FP8 is not supported under `thunder.jit`.") - - # NOTE - This is available only v1.8 onwards - if _should_shard_intermediate(): - self.pg = get_compile_data().process_group_for_ddp - else: - self.pg = None - - def forward(self, inp, weight, bias, is_grad_enabled: bool = False): - # NOTE: Backward FP8 metadata sync - # TransformerEngine v1.6 onwards, we control the sync and update of FP8 metadata for FP8 tensors - # tied to backward pass (i.e. the gradient tensors) - # Also, note that the forward tensor metadata sync occurs at the exit of `fp8_autocast` context manager - # which is not controlled by us. - # - # We consume the `is_first_fp8_module` so that the automatic sync for FP8 metadata is disabled. - FP8GlobalStateManager.is_first_fp8_module() # Consume first module token. - - tensor_inputs = tuple(filter(lambda t: isinstance(t, torch.Tensor), (inp, weight, bias))) - # See [NOTE] Enable grad within context - # TE backward depends on `requires_grad` to compute grads. - # so under grad mode we enable grad for input tensors - # Ref: https://github.com/NVIDIA/TransformerEngine/blob/b957aa475bcbcf22405381d18bd7fefe4fb6b171/transformer_engine/pytorch/module/linear.py#L264 - grad_ctx = enable_grad(*tensor_inputs) if is_grad_enabled else nullcontext() - with grad_ctx, self.prepare_forward(inp) as inp: - assert self.fp8 or not self.primary_weights_in_fp8, ( - "Need to run inside fp8_autocast region when weights are stored in FP8." 
- ) - - ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_output_quantizer, - grad_input_quantizer, - ) = self._get_quantizers(is_grad_enabled) - - ctx = Context() if is_grad_enabled else None - - import inspect - - params = inspect.signature(_Linear.forward).parameters - - # Currently we do not support `tp` meaning tensor model parallel case. - # We hard-code the arguments related to distributed for now. - use_bias = bias is not None - - kwargs = { - "ctx": ctx, - "weight": weight, - "inp": inp, - "bias": torch.tensor([]) if not use_bias else bias, - "is_first_microbatch": None, - "fp8": self.fp8, - "fp8_calibration": self.fp8_calibration, - "input_quantizer": input_quantizer, - "weight_quantizer": weight_quantizer, - "output_quantizer": output_quantizer, - "grad_output_quantizer": grad_output_quantizer, - "grad_input_quantizer": grad_input_quantizer, - "fuse_wgrad_accumulation": False, - "cpu_offloading": CPUOffloadEnabled, - "tp_group": None, - "tp_size": 1, - "sequence_parallel": False, - "tensor_parallel": False, - "activation_dtype": inp.dtype, - "parallel_mode": None, - "is_grad_enabled": is_grad_enabled, - "ub_overlap_rs": False, - "ub_overlap_ag": False, - "ub_name": None, - "fp8_output": False, - "fsdp_group": self.pg, - "module": self, - "skip_fp8_weight_update": None, - } - - # Optimistic key value insertion for the sake of compatibility with main branch - for param_name in params: - if param_name not in kwargs: - param = params[param_name] - if param.default is not param.empty: - kwargs[param_name] = param.default - else: - kwargs[param_name] = None - - # Remove kwargs if they are not used in the current version. - unused_kwargs = set(kwargs.keys()) - set(params) - for unused_kwarg in unused_kwargs: - kwargs.pop(unused_kwarg) - - out = _Linear.forward(**kwargs) - ctx = ctx if is_grad_enabled else None - saved_tensors = ctx.pop_saved_tensors() if is_grad_enabled else None - return out, saved_tensors, ctx - - def _get_quantizers(self, is_grad_enabled): - # NOTE: Currently, we disallow changing these settings. - fp8_output = False - fp8_grad = False - - if not self.fp8: - return [None] * 5 - grad_input_quantizer = None - grad_output_quantizer = None - output_quantizer = None - input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] - input_quantizer.internal = True - weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] - weight_quantizer.internal = True - if fp8_output: - output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] - if is_grad_enabled: - grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] - grad_output_quantizer.internal = True - if fp8_grad: - grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] - return ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_output_quantizer, - grad_input_quantizer, - ) - - -# # # # # # # # # -# Make Executor for TE -# # # # # # # # # -transformer_engine_v1_ex = OperatorExecutor("transformer_engine_v1") -register_executor(transformer_engine_v1_ex) - - -def make_te_linear_meta(is_grad_enabled: bool = False): - def _te_functional_linear_meta( - a: TensorProxy, w: TensorProxy, bias: None | TensorProxy - ) -> tuple[TensorProxy, AnyProxy | None]: - # Input Shape : (*, Hin) - # Output Shape : (*, Hout) where * is any number of dims including None. 
- output_shape = list(a.shape) - output_shape[-1] = w.shape[0] - if is_grad_enabled: - global LINEAR_CALLS_COUNTER - ctx_dict = AnyProxy(object(), prefix=f"ctx_te_{LINEAR_CALLS_COUNTER}") - - # It's not critical to model the exact shape and dtype of - # saved_tensors since they are not used in Thunder's meta functions. - saved_tensors = tuple( - TensorProxy(like=a, shape=a.shape) - for _ in range(_get_num_saved_tensors(get_recipe_from_options_or_default_recipe())) - ) - - return TensorProxy(like=a, shape=output_shape), saved_tensors, ctx_dict - return TensorProxy(like=a, shape=output_shape), None, None - - return _te_functional_linear_meta - - -# -# Registers the backward function -# -def _te_functional_linear_backward_impl( - a_shape: tuple, - w_shape: tuple, - b_shape: tuple | None, - ctx: Context, - saved_tensors: Sequence[torch.Tensor], - g: torch.Tensor, -) -> [torch.Tensor, torch.Tensor, None | torch.Tensor]: - with set_saved_tensors(ctx, saved_tensors): - grads = _Linear.backward(ctx, g) - - grad_inputs = (grads[1], grads[0], grads[2]) - return grad_inputs - - -def _te_functional_linear_backward_meta( - a_shape: tuple, - w_shape: tuple, - b_shape: tuple | None, - ctx: Context, - saved_tensors: Sequence[TensorProxy], - g: TensorProxy, -) -> [TensorProxy, TensorProxy, None | TensorProxy]: - return ( - TensorProxy(like=g, shape=a_shape), - TensorProxy(like=g, shape=w_shape), - TensorProxy(like=g, shape=b_shape) if b_shape else None, - ) - - -te_functional_linear_backward = transformer_engine_v1_ex.register_operator( - "te_functional_linear_backward", meta=_te_functional_linear_backward_meta, fn=_te_functional_linear_backward_impl -) - -LINEAR_CALLS_COUNTER = 0 - -if TE_AVAILABLE: - # Recipe is chosen based on hardware platform - # For H100 or lower, it returns DelayedScaling recipe. - # For B200, it returns MXFP8BlockScaling recipe. - _DEFAULT_RECIPE = get_default_fp8_recipe() - -IMPORT_CTX_TE_KEY = "transformer_engine" -FP8_RECIPE_KEY = "te_fp8_recipe" - - -def get_recipe_from_options_or_default_recipe(): - desc = "transformer_engine_v1_ex: Optional fp8_recipe for `fp8_autocast` context manager." - if (fp8_recipe := get_compile_option(FP8_RECIPE_KEY, desc)) is None: - fp8_recipe = _DEFAULT_RECIPE - - return fp8_recipe - - -# Creates a new stateful operator for each invocation of `linear`. -def _create_fp8_linear_bound_symbol( - a: TensorProxy, w: TensorProxy, b: TensorProxy, is_grad_enabled=False -) -> tuple[torch.Tensor, AnyProxy | None]: - linear_fn = partial(TELinear(w.shape[1], w.shape[0]), is_grad_enabled=is_grad_enabled) - global LINEAR_CALLS_COUNTER - name = f"te_linear_{LINEAR_CALLS_COUNTER}" - - fp8_recipe = get_recipe_from_options_or_default_recipe() - - def bind_postprocess(bsym: BoundSymbol) -> None: - # This dict is then used by trace.python_ctx() to resolve the - # BoundSymbol to the actual function. - bsym._call_ctx: dict[str, Callable] = {name: linear_fn} - bsym._import_ctx: dict[str, Any] = {IMPORT_CTX_TE_KEY: te} - bsym._object_ctx: dict[str, Any] = {FP8_RECIPE_KEY: fp8_recipe} - - meta_fn = make_te_linear_meta(is_grad_enabled=is_grad_enabled) - sym = Symbol( - name=name, - meta=meta_fn, - is_prim=True, - executor=transformer_engine_v1_ex, - _bind_postprocess=bind_postprocess, - tags=(prims.OpTags.DONT_RECOMPUTE_IN_BACKWARD,), - ) - bsym = sym.bind(a, w, b, output=meta_fn(a, w, b)) - - # Now we need to append the BoundSymbol to the current trace. 
- trace = get_tracectx() - trace.scopes[-1].append(bsym) - for p in chain(bsym.flat_proxy_outs, bsym.flat_proxy_args): - trace.names.add(p.name) - - LINEAR_CALLS_COUNTER += 1 - - # Used in augmented forward rule. - # Returns are `result, saved_tensors, ctx`. - if is_grad_enabled: - return bsym.output - - return bsym.output[0] - - -# -# Registers transformer_engine_v1_ex as an executor for torch.nn.functional.linear -# - - -def _linear_checker( - a: TensorProxy, - w: TensorProxy, - bias: None | TensorProxy, -) -> bool: - # Make sure that we don't claim an operator - # if `TransformerEngine` is not available (not installed or version requirements not met) - # and it is passed as an executor to `thunder.jit()` - if not TE_AVAILABLE: - return False - - def is_cuda(t): - return t.device.devicetype == devices.DeviceType.CUDA - - inputs = (a, w) - if bias is not None: - inputs = inputs + (bias,) - - # Helper function as input shape can be (*, Hin) - def _view_input_as_2d(x): - shape = x.shape - return x.view((-1, shape[-1])) - - fp8_recipe = get_recipe_from_options_or_default_recipe() - - def check_valid_fp8_shapes(a): - # DelayedScaling and MXFP8BlockScaling have different shape requirements. - if isinstance(fp8_recipe, DelayedScaling): - return check_dim_for_fp8_exec(a) - - assert isinstance(fp8_recipe, MXFP8BlockScaling) - shape = a.shape - return shape[0] % MXFP8_BLOCK_SCALING_SIZE == 0 and shape[1] % MXFP8_BLOCK_SCALING_SIZE == 0 - - # Inputs must be on CUDA and - # input sizes must satisfy -> dim0 is divisible by 8 and dim1 is divisible by 16. - return all(map(is_cuda, inputs)) and check_valid_fp8_shapes(_view_input_as_2d(a)) and check_valid_fp8_shapes(w) - - -def linear_forward_rule(a, w, bias): - out, saved_tensors, ctx = _create_fp8_linear_bound_symbol(a, w, bias, is_grad_enabled=True) - primal = out - saved_for_backward = (a.shape, w.shape, bias.shape if bias is not None else None, ctx, saved_tensors) - return primal, saved_for_backward - - -# Translate calls from torch.nn.functional.linear to te.Linear (when the checker above returns True) -def _linear_transform(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> torch.Tensor: - return _create_fp8_linear_bound_symbol(a, w, b, is_grad_enabled=False) - - -@disable_caching_split_forward_and_backward -def _linear_grad(a: TensorProxy, w: TensorProxy, b: TensorProxy) -> TensorProxy: - out, saved_for_backward = linear_forward_rule(a, w, b) - g = prims.get_grad(out) - ga, gw, gb = te_functional_linear_backward(*saved_for_backward, g) - prims.put_grad(a, ga) - prims.put_grad(w, gw) - if b is not None: - prims.put_grad(b, gb) - return out - - -# Registers the implementation for torch.nn.functional.linear -transformer_engine_v1_ex.register_implementation( - prims.linear, - checker=_linear_checker, - execution_transform=_linear_transform, - grad_transform=_linear_grad, -) - - -def _is_te_linear_enabled(import_ctx, object_ctx): - # These keys are present in `import_ctx` and `object_ctx` only if - # we actually replaced a linear call with a new TE operator. 
- is_te_exec_enabled = IMPORT_CTX_TE_KEY in import_ctx and FP8_RECIPE_KEY in object_ctx - return is_te_exec_enabled - - -TE_CTX_STR = f"@{IMPORT_CTX_TE_KEY}.fp8_autocast(fp8_recipe={FP8_RECIPE_KEY})" - - -def _get_te_wrapper_string(): - return TE_CTX_STR - - -def te_sync_fp8_meta_bwd_meta(): - pass - - -def te_sync_fp8_meta_bwd_impl(): - FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) - - -te_sync_fp8_meta_bwd = transformer_engine_v1_ex.register_operator( - "te_sync_fp8_meta_bwd", meta=te_sync_fp8_meta_bwd_meta, fn=te_sync_fp8_meta_bwd_impl -) - - -def _transformer_engine_bwd_fp8_meta_sync(_, bw_extrace): - # See doc of `_insert_bwd_fp8_meta_sync` for more details. - _insert_bwd_fp8_meta_sync(bw_extrace) - - -def _insert_bwd_fp8_meta_sync(bw_extrace): - # This functions insert the symbol `te_sync_fp8_meta_bwd` to the end of the backward - # trace which takes care of syncing and updating the FP8 metadata for backward tensors. - # See NOTE: Backward FP8 metadata sync - bwd_idx = len(bw_extrace.bound_symbols) - 1 - bw_extrace.bound_symbols.insert(bwd_idx, te_sync_fp8_meta_bwd.bind(output=None)) - - -def transformer_engine_v1_bwd_fp8_meta_sync(forward_trace, backward_trace): - if transformer_engine_v1_ex in get_compile_data().executors_list: - # NOTE: `_transformer_engine_bwd_fp8_meta_sync` may mutate `fw_extrace` or `bw_extrace`. - _transformer_engine_bwd_fp8_meta_sync(forward_trace, backward_trace) diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index 2373b8d500..da2e2dedd6 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -540,7 +540,6 @@ def get_all_executors() -> tuple[Executor, ...]: fa3ex, torch_compile, torchex, - transformer_engine_v1ex, transformer_engineex, triton_crossentropy, ) diff --git a/thunder/transforms/autodiff.py b/thunder/transforms/autodiff.py index 43a05f8460..2d4261cdf6 100644 --- a/thunder/transforms/autodiff.py +++ b/thunder/transforms/autodiff.py @@ -594,12 +594,6 @@ def backward_fn(saved_for_backward, cotangents): backward_trace = dce(backward_trace) - # Importing here to avoid cyclical dependencies in future. - # NOTE: This is required only for v1 executor. - from thunder.executors.transformer_engine_v1ex import transformer_engine_v1_bwd_fp8_meta_sync - - transformer_engine_v1_bwd_fp8_meta_sync(forward_trace, backward_trace) - # We only want to apply it on backward trace. from thunder.torch.experimental.dtensor_utils import check_dtensor_cotangent_metadata_in_backward From e6301228d8191f69ef881966b4ddd747f5c5abd0 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:22:31 +0200 Subject: [PATCH 05/11] update comment --- thunder/executors/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/executors/utils.py b/thunder/executors/utils.py index 6c12685329..606c8c1bf9 100644 --- a/thunder/executors/utils.py +++ b/thunder/executors/utils.py @@ -97,7 +97,6 @@ def __repr__(self) -> str: # Helper to use torch.autograd.Function as an implementation for a symbol. -# See `transformer_engine_v1ex.py` for example. 
class Context: def __init__(self): self.saved_tensors = () From be758c33c58c020a2ae63afb6cbde9b353cb12f4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:25:19 +0000 Subject: [PATCH 06/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- thunder/tests/distributed/test_fsdp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/tests/distributed/test_fsdp.py b/thunder/tests/distributed/test_fsdp.py index cc6389ad27..231dea2c9c 100644 --- a/thunder/tests/distributed/test_fsdp.py +++ b/thunder/tests/distributed/test_fsdp.py @@ -39,7 +39,6 @@ FP8GlobalStateManager, get_default_fp8_recipe, ) - from transformer_engine.common.recipe import MXFP8BlockScaling import transformer_engine from thunder.tests.test_transformer_engine_executor import te_assert_close From ecaac00d18649d31723f7b57feae3ee5eb60be90 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:47:11 +0200 Subject: [PATCH 07/11] add te_available flag --- thunder/tests/distributed/test_ddp.py | 3 ++- thunder/tests/distributed/test_fsdp.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/thunder/tests/distributed/test_ddp.py b/thunder/tests/distributed/test_ddp.py index cfb2bf8111..e7eeadea50 100644 --- a/thunder/tests/distributed/test_ddp.py +++ b/thunder/tests/distributed/test_ddp.py @@ -28,7 +28,8 @@ # This will be correctly updated below when TE Engine is installed # and if the current environment doesn't support FP8. fp8_support_reason: str = "" -if package_available("transformer_engine"): +TE_AVAILABLE = package_available("transformer_engine") +if TE_AVAILABLE: from transformer_engine.pytorch import fp8_autocast from transformer_engine.pytorch import Linear as TELinear from transformer_engine.pytorch.fp8 import ( diff --git a/thunder/tests/distributed/test_fsdp.py b/thunder/tests/distributed/test_fsdp.py index cc6389ad27..aa3bc47ac9 100644 --- a/thunder/tests/distributed/test_fsdp.py +++ b/thunder/tests/distributed/test_fsdp.py @@ -31,7 +31,8 @@ # This will be correctly updated below when TE Engine is installed # and if the current environment doesn't support FP8. fp8_support_reason: str = "" -if package_available("transformer_engine"): +TE_AVAILABLE = package_available("transformer_engine") +if TE_AVAILABLE: from transformer_engine.pytorch import fp8_autocast from transformer_engine.pytorch import Linear as TELinear from transformer_engine.pytorch.fp8 import ( From bc5e5321fcd4cd3a0fdc718891089d45b30ad014 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Thu, 20 Nov 2025 18:12:45 +0200 Subject: [PATCH 08/11] update comment and add None guard for transform --- thunder/plugins/fp8.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/thunder/plugins/fp8.py b/thunder/plugins/fp8.py index 916a591cf6..de4b489651 100644 --- a/thunder/plugins/fp8.py +++ b/thunder/plugins/fp8.py @@ -5,7 +5,7 @@ class FP8(Plugin): """ Plugin for enabling FP8 precision via NVIDIA Transformer Engine, enabling higher throughput of matrix operations in FP8. - See `lightning-thunder/thunder/executors/transformer_engine_ex.py` for implementation details. + See `lightning-thunder/thunder/executors/transformer_engineex.py` for implementation details. 
""" def setup_transforms(self): @@ -18,6 +18,10 @@ def setup_transforms(self): from thunder.executors.transformer_engineex import TransformerEngineTransform + # When TE executor is not available, both the transform and the executor will be None. + if TransformerEngineTransform is None: + return [] + return [TransformerEngineTransform()] def setup_executors(self): From c0655f4b5351f9b8d683a9a9539bd67e5b77bd75 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Thu, 20 Nov 2025 18:57:28 +0200 Subject: [PATCH 09/11] add more guarding for TE recipe --- thunder/plugins/fp8.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/thunder/plugins/fp8.py b/thunder/plugins/fp8.py index de4b489651..98e4d70e7d 100644 --- a/thunder/plugins/fp8.py +++ b/thunder/plugins/fp8.py @@ -34,4 +34,8 @@ def setup_executors(self): """ from thunder.executors.transformer_engineex import transformer_engine_ex + # When TE executor is not available, both the transform and the executor will be None. + if transformer_engine_ex is None: + return [] + return [transformer_engine_ex] From 6430a7d1947200cb850ad5f868b6946a7f20f572 Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Thu, 20 Nov 2025 19:32:39 +0200 Subject: [PATCH 10/11] add test skip for recipe composition test --- thunder/tests/test_recipes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/thunder/tests/test_recipes.py b/thunder/tests/test_recipes.py index efc15ce861..aa6e75f353 100644 --- a/thunder/tests/test_recipes.py +++ b/thunder/tests/test_recipes.py @@ -8,6 +8,7 @@ from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM from transformers.models.llama import LlamaConfig, LlamaForCausalLM +from lightning_utilities.core.imports import package_available from thunder.extend import deregister_executor from torch.testing import assert_close from thunder.recipes import HFTransformers @@ -215,6 +216,7 @@ def test_plugins_basics(): # test skipped if nvfuser isn't available because providing plugins calls BaseRecipe @pytest.mark.skipif(not nvfuser_available(), reason="nvFuser is not available") +@pytest.mark.skipif(not package_available("transformer_engine"), reason="TransformerEngine is not available") @pytest.mark.skipif(IS_WINDOWS, reason="libuv error with PT build on windows") def test_plugins_composition(monkeypatch): model = torch.nn.Sequential(torch.nn.Linear(2048, 4096), torch.nn.ReLU(), torch.nn.Linear(4096, 64)) From b2712d9138f0f69ffebb92e72264da57fa8f764c Mon Sep 17 00:00:00 2001 From: riccardofelluga <11768013+riccardofelluga@users.noreply.github.com> Date: Fri, 21 Nov 2025 15:11:09 +0200 Subject: [PATCH 11/11] add doctring with example --- thunder/executors/utils.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/thunder/executors/utils.py b/thunder/executors/utils.py index 606c8c1bf9..3f01214e99 100644 --- a/thunder/executors/utils.py +++ b/thunder/executors/utils.py @@ -96,8 +96,33 @@ def __repr__(self) -> str: return s -# Helper to use torch.autograd.Function as an implementation for a symbol. class Context: + """Helper to use torch.autograd.Function as an implementation for a symbol. + + This class provides a minimal interface for saving and retrieving tensors between + forward and backward passes, compatible with torch.autograd.Function implementations. + + Usage Pattern: + Forward Pass: + 1. Create Context() instance + 2. 
Pass to autograd function which calls ctx.save_for_backward(...) + 3. Extract tensors with ctx.pop_saved_tensors() + + Backward Pass: + 1. Restore tensors using set_saved_tensors(ctx, saved_tensors) context manager + 2. Call backward function which accesses ctx.saved_tensors + + Example: + # Forward + ctx = Context() + out = CustomFunction.forward(ctx, x, weight) + saved = ctx.pop_saved_tensors() + + # Backward + with set_saved_tensors(ctx, saved): + grads = CustomFunction.backward(ctx, grad_out) + """ + def __init__(self): self.saved_tensors = ()