Skip to content

Commit f844905

Browse files
authored
[PyTorch] Make grouped weights opt-in (NVIDIA#2678)
* Make grouped weights opt-in

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change varname

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent 496620a commit f844905

2 files changed

Lines changed: 22 additions & 3 deletions

File tree

tests/pytorch/test_sanity.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -585,10 +585,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
585585
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
586586
@pytest.mark.parametrize("fp8_model_params", all_boolean)
587587
@pytest.mark.parametrize("use_bias", all_boolean)
588+
@pytest.mark.parametrize("single_param", all_boolean)
588589
@pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
589590
@pytest.mark.parametrize("num_gemms", [4])
590591
def test_sanity_grouped_linear(
591-
dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
592+
dtype,
593+
bs,
594+
model,
595+
fp8_recipe,
596+
fp8_model_params,
597+
use_bias,
598+
single_param,
599+
num_gemms,
600+
empty_split,
592601
):
593602
if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
594603
pytest.skip("FP8 model parameters are not supported in debug mode.")
@@ -598,6 +607,9 @@ def test_sanity_grouped_linear(
598607
bs = bs * 16
599608
num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
600609

610+
if single_param:
611+
os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"
612+
601613
if fp8_recipe is not None:
602614
if not is_fp8_supported(config):
603615
pytest.skip("Model config does not support FP8")
@@ -617,7 +629,8 @@ def test_sanity_grouped_linear(
617629
# Verify that weights are stored in contiguous GroupedTensor storage.
618630
weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
619631
if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
620-
check_grouped_tensor_pointers(weights, fp8_recipe)
632+
if single_param:
633+
check_grouped_tensor_pointers(weights, fp8_recipe)
621634

622635
inp_hidden_states = torch.randn(
623636
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
@@ -636,6 +649,9 @@ def test_sanity_grouped_linear(
636649
loss.backward()
637650
assert out.shape == (num_tokens, ffn_hidden_size)
638651

652+
if single_param:
653+
del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]
654+
639655

640656
@pytest.mark.parametrize("dtype", param_types)
641657
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
from typing import Union, Optional, Callable, Tuple, List
77
from itertools import chain
88
import warnings
9+
import os
910

1011
import functools
1112
import torch
@@ -793,7 +794,9 @@ def make_grouped_weights(self, defer_init=False) -> None:
793794

794795
def reset_parameters(self, defer_init=False):
795796
super().reset_parameters(defer_init=defer_init)
796-
self.make_grouped_weights(defer_init=defer_init)
797+
# Grouped tensor weights is an opt-in feature.
798+
if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))):
799+
self.make_grouped_weights(defer_init=defer_init)
797800

798801
def set_tensor_parallel_attributes(self, defer_init=False) -> None:
799802
"""Set attributes needed for TP"""

0 commit comments

Comments (0)