Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
7553e6a
Stage 1&2: Python containers + quantize/gemm dispatch/unwrap
negvet Mar 31, 2026
19acc5e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 31, 2026
3f33f40
Merge branch 'main' into hybrid_quantization
negvet Apr 6, 2026
f80f5d0
Enable quantized_model_init
negvet Apr 16, 2026
2185b30
FSDP support
negvet Apr 17, 2026
f22a395
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2026
c9754e9
Merge branch 'main' into hybrid_quantization
negvet Apr 17, 2026
103fffe
Enable CPU offloading
negvet Apr 22, 2026
16fb371
Activation recomputation
negvet Apr 24, 2026
a50fd63
TP/SP
negvet Apr 24, 2026
2214843
Resolve comments: hybrid uniform list, make_empty try, __repr__, etc.
negvet Apr 24, 2026
88fe467
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 24, 2026
f649cc9
Merge branch 'main' into hybrid_quantization
negvet Apr 29, 2026
4858491
Respect usage
negvet Apr 29, 2026
ef31a9a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 29, 2026
e7136fd
Merge main
negvet May 13, 2026
c7da5b2
Misc minor fixes: comments, tests, etc.
negvet May 20, 2026
a164cd3
Towards MLM integration
negvet May 21, 2026
62e7668
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 21, 2026
4b0d18c
Merge main
negvet Jun 1, 2026
7316516
Resolve comments: improve fsdp/tp/sp tests + amax reduction fix
negvet Jun 3, 2026
5892a74
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2026
ec7b84c
Enable FSDP2 hybrid protocol for Float8Block tensor
negvet Jun 5, 2026
b99277a
Enable Identity (no-op) quantization
negvet Jun 9, 2026
8cc3332
Bug fixing
negvet Jun 9, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hybrid_quantization.xml $TE_PATH/tests/pytorch/test_hybrid_quantization.py || test_fail "test_hybrid_quantization.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
Expand Down
1 change: 1 addition & 0 deletions qa/L1_pytorch_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PAT
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hybrid_tp_sp.xml $TE_PATH/tests/pytorch/distributed/test_hybrid_tp_sp.py || test_fail "test_hybrid_tp_sp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py"
Expand Down
22 changes: 22 additions & 0 deletions tests/pytorch/distributed/fsdp2_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ def _check_nvfp4_support():
("NVFP4BlockScaling", _check_nvfp4_support),
]

_HYBRID_RECIPE_CONFIGS = [
("HybridFP8CurrentScaling", fp8.check_fp8_support),
("HybridMXFP8", fp8.check_mxfp8_support),
("HybridFloat8BlockScaling", fp8.check_fp8_block_scaling_support),
("HybridMixed_MXFP8_FP8", fp8.check_mxfp8_support),
]


def _parametrize_recipes():
params = []
Expand All @@ -56,6 +63,16 @@ def _parametrize_recipes():
return params


def _parametrize_hybrid_recipes():
params = []
for name, check_fn in _HYBRID_RECIPE_CONFIGS:
supported, reason = check_fn()
params.append(
pytest.param(name, id=name, marks=pytest.mark.skipif(not supported, reason=reason))
)
return params


# ── Session / per-test fixtures ──────────────────────────────────────
@pytest.fixture(scope="session", autouse=True)
def dist_init():
Expand Down Expand Up @@ -83,3 +100,8 @@ def _cleanup():
@pytest.fixture(params=_parametrize_recipes())
def recipe_name(request):
return request.param


@pytest.fixture(params=_parametrize_hybrid_recipes())
def hybrid_recipe_name(request):
return request.param
106 changes: 105 additions & 1 deletion tests/pytorch/distributed/fsdp2_tests/fsdp2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,117 @@
"""Shared utility functions for FSDP2 distributed tests."""

import transformer_engine.common.recipe
from transformer_engine.pytorch import QuantizedTensor
from transformer_engine.pytorch import HybridQuantizer, IdentityQuantizer, QuantizedTensor
from transformer_engine.pytorch.custom_recipes.quantization_recipes_base import (
current_scaling_quantizer_factory,
float8_block_scaling_quantizer_factory,
mxfp8_quantizer_factory,
)


def get_recipe_from_string(recipe):
return getattr(transformer_engine.common.recipe, recipe)()


def _hybrid_fp8_current_qfactory(role):
"""FP8 current-scaling rowwise + FP8 current-scaling columnwise."""
is_linear = role is not None and role.module_type in ("linear", "grouped_linear")
if is_linear and role.tensor_type in ("input", "weight", "output"):
return HybridQuantizer(
rowwise_quantizer=current_scaling_quantizer_factory(role),
columnwise_quantizer=current_scaling_quantizer_factory(role),
)
return current_scaling_quantizer_factory(role)


def _hybrid_mxfp8_qfactory(role):
"""MXFP8 rowwise + MXFP8 columnwise."""
is_linear = role is not None and role.module_type in ("linear", "grouped_linear")
if is_linear and role.tensor_type in ("input", "weight", "output"):
return HybridQuantizer(
rowwise_quantizer=mxfp8_quantizer_factory(role),
columnwise_quantizer=mxfp8_quantizer_factory(role),
)
return mxfp8_quantizer_factory(role)


def _hybrid_float8_block_qfactory(role):
"""Float8 block-scaling rowwise + Float8 block-scaling columnwise."""
is_linear = role is not None and role.module_type in ("linear", "grouped_linear")
if is_linear and role.tensor_type in ("input", "weight", "output"):
return HybridQuantizer(
rowwise_quantizer=float8_block_scaling_quantizer_factory(role),
columnwise_quantizer=float8_block_scaling_quantizer_factory(role),
)
return float8_block_scaling_quantizer_factory(role)


def _hybrid_mixed_mxfp8_fp8_qfactory(role):
"""MXFP8 rowwise + FP8 current columnwise (cross-format hybrid)."""
is_linear = role is not None and role.module_type in ("linear", "grouped_linear")
if is_linear and role.tensor_type in ("input", "weight", "output"):
return HybridQuantizer(
rowwise_quantizer=mxfp8_quantizer_factory(role),
columnwise_quantizer=current_scaling_quantizer_factory(role),
)
return current_scaling_quantizer_factory(role)


def _hybrid_fp8_current_identity_qfactory(role):
"""FP8 current forward + high-precision backward via Identity."""
is_linear = role is not None and role.module_type in ("linear", "grouped_linear")
if is_linear and role.tensor_type in ("input", "weight", "output"):
return HybridQuantizer(
rowwise_quantizer=current_scaling_quantizer_factory(role),
columnwise_quantizer=IdentityQuantizer(),
)
if is_linear and role.tensor_type in ("grad_output", "grad_input"):
return IdentityQuantizer()
return current_scaling_quantizer_factory(role)


def _identity_qfactory(role): # pylint: disable=unused-argument
"""High-precision passthrough for every quantizer slot."""
return IdentityQuantizer()


# The qfactories above are registered here as module-level functions (not
# lambdas or closures) on purpose: DCP serializes ``CustomRecipe`` via
# ``pickle``, and closure-based qfactories (or inner functions capturing state)
# are not picklable. Keeping them at module scope lets them pickle by reference.
# See ``run_fsdp2_fused_adam.py::test_hybrid_dcp_output_parity``.
_HYBRID_QFACTORIES = {
"HybridFP8CurrentScaling": _hybrid_fp8_current_qfactory,
"HybridMXFP8": _hybrid_mxfp8_qfactory,
"HybridFloat8BlockScaling": _hybrid_float8_block_qfactory,
"HybridMixed_MXFP8_FP8": _hybrid_mixed_mxfp8_fp8_qfactory,
"HybridFP8CurrentScalingIdentity": _hybrid_fp8_current_identity_qfactory,
"Identity": _identity_qfactory,
}


def get_hybrid_recipe_from_string(recipe):
"""Build a CustomRecipe wrapping a module-level (picklable) hybrid qfactory.

Each hybrid qfactory composes one or two role-aware base factories from
``quantization_recipes_base`` per direction; per-role behavior is delegated
to the base factory and the hybrid layer only decides the direction pairing.

Supported values:
"HybridFP8CurrentScaling" — FP8 current for both directions
"HybridMXFP8" — MXFP8 for both directions
"HybridFloat8BlockScaling" — Float8 block scaling for both directions
"HybridMixed_MXFP8_FP8" — MXFP8 rowwise + FP8 current columnwise
"HybridFP8CurrentScalingIdentity" — FP8 current forward + Identity backward
"Identity" — high-precision passthrough for every slot
"""
if recipe not in _HYBRID_QFACTORIES:
raise ValueError(
f"Unknown hybrid recipe '{recipe}'. Supported: {sorted(_HYBRID_QFACTORIES.keys())}"
)
return transformer_engine.common.recipe.CustomRecipe(qfactory=_HYBRID_QFACTORIES[recipe])


def save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
Expand Down
Loading