TE-compatible CUDA graph callable runtime with optional TransformerEngine integration and a PyTorch-only runtime path.
from te_graph_runtime import make_graphed_callablesThe public entrypoint intentionally follows TransformerEngine v2.16's
make_graphed_callables API, with the CUDA graph parameter-gradient lifetime
fix from NVIDIA/TransformerEngine PR #2937 and capture_time_hooks from PR
#2831 carried on top. The implementation does not delegate to
transformer_engine.pytorch.graph.make_graphed_callables or to
torch.cuda.make_graphed_callables; it captures and replays graphs directly with
PyTorch CUDA graph and autograd primitives.
Editable checkout:
pip install -e .With optional TransformerEngine support:
pip install -e '.[te]'Importing te_graph_runtime is lazy: it does not import torch or
transformer_engine until make_graphed_callables is called.
- No TransformerEngine installed: generic PyTorch
nn.Modulegraphing works, includingsample_kwargs,_order,_num_layers_per_chunk, warmup hooks,capture_time_hooks, graph reset, stream/event replay kwargs, structuralbackward_dwsupport, and PR #2937's_clone_param_grads_on_returnlifetime policy. FP8/TE-specific options fail fast with an actionable error. - TransformerEngine installed: TE internals are used for FP8/FP4 recipes,
amax reduction, TE module state save/restore, TE RNG tracker state, and TE
module
backward_dwhandling. Tests compare behavior with TE's upstream implementation.
examples/custom_framework_minimal.py- wrap a custom framework block and use keyword tensor inputs.examples/custom_framework_pipeline.py- integrate an interleaved pipeline schedule with_orderand optional static-buffer reuse.examples/optional_te_fp8.py- use one code path that enables FP8 only when TransformerEngine and FP8-capable hardware are present.
See docs/custom_framework_integration.md for integration patterns, failure
modes, and test commands.
CPU/local static tests:
PYTHONPATH=src python -m pytest -qGPU tests without requiring TransformerEngine:
PYTHONPATH=src python -m pytest -q tests/test_cuda_runtime.pyGPU tests with TransformerEngine parity checks:
PYTHONPATH=src python -m pytest -q tests/test_te_parity.pySee UPSTREAM.md for the pinned upstream source and update workflow.