Skip to content

Commit 46bdc85

Browse files
negvetpre-commit-ci[bot]root
authored
[PyTorch] Preserve fprop operands for dequantized backward override (#3141)
* Preserve fprop operands for dequantized backward override Signed-off-by: Evgeny <etsykunov@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add test_grouped_linear_backward_override_high_precision_forces_save_original_input test Signed-off-by: root <root@prenyx0017.a51.clusters.nvidia.com> --------- Signed-off-by: Evgeny <etsykunov@nvidia.com> Signed-off-by: root <root@prenyx0017.a51.clusters.nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: root <root@prenyx0017.a51.clusters.nvidia.com>
1 parent 90baf02 commit 46bdc85

3 files changed

Lines changed: 216 additions & 0 deletions

File tree

tests/pytorch/test_backward_override.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,218 @@ def test_backward_override_recipe_matches_requested_mode(
858858
assert quant_recipe.backward_override is None
859859

860860

861+
@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
862+
@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias"))
863+
def test_linear_backward_override_dequantized_ignores_save_original_input(
864+
recipe_name: str,
865+
use_bias: bool,
866+
) -> None:
867+
reset_rng_states()
868+
dtype = torch.bfloat16
869+
input_shape = (32, 128)
870+
out_features = 128
871+
_maybe_skip_recipe_dtype(recipe_name, dtype, "linear")
872+
_maybe_skip_unsupported_recipe_module_combo(recipe_name, "linear")
873+
_maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "linear")
874+
875+
mode_recipe = make_recipe(recipe_name, backward_override="dequantized")
876+
skip_unsupported_backward_override("linear", mode_recipe, "dequantized")
877+
878+
module_ref = te.Linear(
879+
input_shape[-1],
880+
out_features,
881+
bias=use_bias,
882+
params_dtype=dtype,
883+
device="cuda",
884+
save_original_input=False,
885+
)
886+
module_test = te.Linear(
887+
input_shape[-1],
888+
out_features,
889+
bias=use_bias,
890+
params_dtype=dtype,
891+
device="cuda",
892+
save_original_input=True,
893+
)
894+
_copy_named_parameters(module_ref, module_test)
895+
896+
x = torch.randn(*input_shape, dtype=dtype, device="cuda")
897+
dy = torch.randn(input_shape[0], out_features, dtype=dtype, device="cuda")
898+
899+
y_ref, dx_ref, dw_ref, db_ref = _run_single_step(module_ref, x, dy, mode_recipe)
900+
y_test, x_test, saved_operands = _run_single_step_with_saved_operands(
901+
module_test, x, mode_recipe
902+
)
903+
_assert_saved_quantized_operand_uses_rowwise_only(saved_operands[0], name="linear_input")
904+
905+
y_test_detached = y_test.detach().clone()
906+
y_test.backward(dy)
907+
assert x_test.grad is not None
908+
assert module_test.weight.grad is not None
909+
dx_test = x_test.grad.detach().clone()
910+
dw_test = module_test.weight.grad.detach().clone()
911+
test_bias = getattr(module_test, "bias", None)
912+
db_test = (
913+
None if test_bias is None or test_bias.grad is None else test_bias.grad.detach().clone()
914+
)
915+
916+
assert_close(y_test_detached, y_ref, rtol=0, atol=0, check_dtype=True)
917+
assert_close(dx_test, dx_ref, rtol=0, atol=0, check_dtype=True)
918+
assert_close(dw_test, dw_ref, rtol=0, atol=0, check_dtype=True)
919+
if use_bias:
920+
assert db_test is not None and db_ref is not None
921+
assert_close(db_test, db_ref, rtol=0, atol=0, check_dtype=True)
922+
923+
924+
@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
925+
@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias"))
926+
def test_grouped_linear_backward_override_dequantized_ignores_save_original_input(
927+
recipe_name: str,
928+
use_bias: bool,
929+
) -> None:
930+
reset_rng_states()
931+
dtype = torch.bfloat16
932+
in_features = 128
933+
out_features = 128
934+
m_splits = [64, 64]
935+
num_gemms = len(m_splits)
936+
num_tokens = sum(m_splits)
937+
_maybe_skip_recipe_dtype(recipe_name, dtype, "grouped_linear")
938+
_maybe_skip_unsupported_recipe_module_combo(recipe_name, "grouped_linear")
939+
_maybe_skip_unsupported_grouped_splits(recipe_name, m_splits)
940+
941+
mode_recipe = make_recipe(recipe_name, backward_override="dequantized")
942+
skip_unsupported_backward_override("grouped_linear", mode_recipe, "dequantized")
943+
944+
module_ref = te.GroupedLinear(
945+
num_gemms,
946+
in_features,
947+
out_features,
948+
bias=use_bias,
949+
params_dtype=dtype,
950+
device="cuda",
951+
save_original_input=False,
952+
)
953+
module_test = te.GroupedLinear(
954+
num_gemms,
955+
in_features,
956+
out_features,
957+
bias=use_bias,
958+
params_dtype=dtype,
959+
device="cuda",
960+
save_original_input=True,
961+
)
962+
_copy_named_parameters(module_ref, module_test)
963+
964+
x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda")
965+
dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda")
966+
967+
y_ref, dx_ref, dw_ref, db_ref = _run_grouped_linear_single_step(
968+
module_ref, x, m_splits, dy, mode_recipe
969+
)
970+
y_test, x_test, saved_operands = _run_grouped_linear_step_with_saved_operands(
971+
module_test, x, m_splits, mode_recipe
972+
)
973+
saved_inputs = saved_operands[:num_gemms]
974+
for i, saved_input in enumerate(saved_inputs):
975+
_assert_saved_quantized_operand_uses_rowwise_only(
976+
saved_input, name=f"grouped_linear_input{i}"
977+
)
978+
979+
y_test_detached = y_test.detach().clone()
980+
y_test.backward(dy)
981+
assert x_test.grad is not None
982+
dx_test = x_test.grad.detach().clone()
983+
dw_test = [getattr(module_test, f"weight{i}").grad.detach().clone() for i in range(num_gemms)]
984+
db_test: list[Optional[torch.Tensor]] = []
985+
for i in range(num_gemms):
986+
if use_bias:
987+
db_test.append(getattr(module_test, f"bias{i}").grad.detach().clone())
988+
else:
989+
db_test.append(None)
990+
991+
assert_close(y_test_detached, y_ref, rtol=0, atol=0, check_dtype=True)
992+
assert_close(dx_test, dx_ref, rtol=0, atol=0, check_dtype=True)
993+
for test_dw, ref_dw in zip(dw_test, dw_ref):
994+
assert_close(test_dw, ref_dw, rtol=0, atol=0, check_dtype=True)
995+
if use_bias:
996+
for test_db, ref_db in zip(db_test, db_ref):
997+
assert test_db is not None and ref_db is not None
998+
assert_close(test_db, ref_db, rtol=0, atol=0, check_dtype=True)
999+
1000+
1001+
@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
1002+
def test_linear_backward_override_high_precision_forces_save_original_input(
1003+
recipe_name: str,
1004+
) -> None:
1005+
reset_rng_states()
1006+
dtype = torch.bfloat16
1007+
input_shape = (32, 128)
1008+
_maybe_skip_recipe_dtype(recipe_name, dtype, "linear")
1009+
_maybe_skip_unsupported_recipe_module_combo(recipe_name, "linear")
1010+
_maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, "linear")
1011+
1012+
mode_recipe = make_recipe(recipe_name, backward_override="high_precision")
1013+
skip_unsupported_backward_override("linear", mode_recipe, "high_precision")
1014+
1015+
module = te.Linear(
1016+
input_shape[-1],
1017+
128,
1018+
bias=False,
1019+
params_dtype=dtype,
1020+
device="cuda",
1021+
save_original_input=False,
1022+
)
1023+
x = torch.randn(*input_shape, dtype=dtype, device="cuda")
1024+
1025+
_, _, saved_operands = _run_single_step_with_saved_operands(module, x, mode_recipe)
1026+
1027+
assert isinstance(saved_operands[0], torch.Tensor)
1028+
1029+
1030+
@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
1031+
def test_grouped_linear_backward_override_high_precision_forces_save_original_input(
1032+
recipe_name: str,
1033+
) -> None:
1034+
reset_rng_states()
1035+
dtype = torch.bfloat16
1036+
in_features = 128
1037+
out_features = 128
1038+
m_splits = [64, 64]
1039+
num_gemms = len(m_splits)
1040+
num_tokens = sum(m_splits)
1041+
_maybe_skip_recipe_dtype(recipe_name, dtype, "grouped_linear")
1042+
_maybe_skip_unsupported_recipe_module_combo(recipe_name, "grouped_linear")
1043+
_maybe_skip_unsupported_grouped_splits(recipe_name, m_splits)
1044+
1045+
mode_recipe = make_recipe(recipe_name, backward_override="high_precision")
1046+
skip_unsupported_backward_override("grouped_linear", mode_recipe, "high_precision")
1047+
1048+
module = te.GroupedLinear(
1049+
num_gemms,
1050+
in_features,
1051+
out_features,
1052+
bias=False,
1053+
params_dtype=dtype,
1054+
device="cuda",
1055+
save_original_input=False,
1056+
)
1057+
x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda")
1058+
1059+
_, _, saved_operands = _run_grouped_linear_step_with_saved_operands(
1060+
module, x, m_splits, mode_recipe
1061+
)
1062+
1063+
saved_inputs = saved_operands[:num_gemms]
1064+
assert isinstance(saved_inputs[0], torch.Tensor)
1065+
assert saved_inputs[0].shape == x.shape
1066+
assert all(saved_input is None for saved_input in saved_inputs[1:])
1067+
1068+
saved_weights = saved_operands[2 * num_gemms : 3 * num_gemms]
1069+
for saved_weight in saved_weights:
1070+
assert isinstance(saved_weight, torch.Tensor)
1071+
1072+
8611073
@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list)
8621074
@pytest.mark.parametrize("module_type", ("linear", "layernorm_linear", "ops_linear"))
8631075
@pytest.mark.parametrize("input_shape,out_features", _shape_test_cases)

transformer_engine/pytorch/module/grouped_linear.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,8 @@ def forward(
431431
backward_override = None
432432
if backward_override == "high_precision":
433433
save_original_input = True
434+
elif backward_override == "dequantized":
435+
save_original_input = False
434436

435437
num_gemms = len(m_splits)
436438
weights = weights_and_biases[:num_gemms]

transformer_engine/pytorch/module/linear.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,8 @@ def _linear_forward_impl(
289289
is_fsdp2 = args.is_fsdp2
290290
if backward_override == "high_precision":
291291
save_original_input = True
292+
elif backward_override == "dequantized":
293+
save_original_input = False
292294

293295
# NVTX label for profiling
294296
nvtx_label = "transformer_engine._Linear.forward"

0 commit comments

Comments
 (0)