Skip to content

Commit 87a9dba

Browse files
authored
torch.testing.assert_equal didn't make it (#273)
looks like pt-1.11 dropped `torch.testing.assert_equal`, so using `torch.testing.assert_equal` instead
1 parent affff3d commit 87a9dba

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

Diff for: megatron/testing_utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,9 @@ def get_gpu_count():
232232
return 0
233233

234234
def torch_assert_equal(actual, expected, **kwargs):
235-
# assert_equal was added around pt-1.9, it does better checks - e.g will check dimensions match
236-
if hasattr(torch.testing, "assert_equal"):
237-
return torch.testing.assert_equal(actual, expected, **kwargs)
235+
# assert_close was added around pt-1.9, it does better checks - e.g will check dimensions match
236+
if hasattr(torch.testing, "assert_close"):
237+
return torch.testing.assert_close(actual, expected, rtol=0.0, atol=0.0, **kwargs)
238238
else:
239239
return torch.allclose(actual, expected, rtol=0.0, atol=0.0)
240240

@@ -886,4 +886,4 @@ def flatten_arguments(args):
886886
887887
Example: {"arg1": "value1", "arg2": "value2"} -> ["IGNORED", "arg1", "value1", "arg2", "value2"]
888888
"""
889-
return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]
889+
return ["IGNORED"] + [item for key_value in args.items() for item in key_value if item != ""]

0 commit comments

Comments
 (0)