Skip to content

Commit 245d7a8

Browse files
sryapfacebook-github-bot
authored andcommitted
Add tests for bounds_check_indices v2 (#3920)
Summary: X-link: facebookresearch/FBGEMM#1008 As title Differential Revision: D72345566
1 parent 9bdb32a commit 245d7a8

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

Diff for: fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py

+9
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def test_transpose(self, B: int, T: int, E: int) -> None:
237237
]
238238
),
239239
mixed_B=st.booleans(),
240+
bounds_check_version=st.sampled_from((1, 2)),
240241
)
241242
@settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
242243
def test_bounds_check( # noqa C901
@@ -249,6 +250,7 @@ def test_bounds_check( # noqa C901
249250
weighted: bool,
250251
dtype: torch.dtype,
251252
mixed_B: bool,
253+
bounds_check_version: int,
252254
) -> None:
253255
rows_per_table = torch.tensor(
254256
np.random.randint(low=1, high=1000, size=(T,))
@@ -348,6 +350,7 @@ def test_bounds_check( # noqa C901
348350
bounds_check_mode,
349351
warning,
350352
weights,
353+
bounds_check_version=bounds_check_version,
351354
**vbe_args,
352355
)
353356
# we don't modify when we are in-bounds.
@@ -361,6 +364,7 @@ def test_bounds_check( # noqa C901
361364
bounds_check_mode,
362365
warning,
363366
weights,
367+
bounds_check_version=bounds_check_version,
364368
**vbe_args,
365369
)
366370
torch.testing.assert_close(indices, torch.zeros_like(indices))
@@ -376,6 +380,7 @@ def test_bounds_check( # noqa C901
376380
bounds_check_mode,
377381
warning,
378382
weights,
383+
bounds_check_version=bounds_check_version,
379384
**vbe_args,
380385
)
381386
# It would be nice to test the CUDA implementation of BoundsCheckMode==FATAL,
@@ -397,6 +402,7 @@ def test_bounds_check( # noqa C901
397402
bounds_check_mode,
398403
warning,
399404
weights,
405+
bounds_check_version=bounds_check_version,
400406
**vbe_args,
401407
)
402408
if offsets.numel() > 0:
@@ -417,6 +423,7 @@ def test_bounds_check( # noqa C901
417423
bounds_check_mode,
418424
warning,
419425
weights,
426+
bounds_check_version=bounds_check_version,
420427
)
421428

422429
# test offsets.size(0) ! = B * T + 1 case. Here we test with T >= 2 case.
@@ -444,6 +451,7 @@ def test_bounds_check( # noqa C901
444451
bounds_check_mode,
445452
warning,
446453
weights,
454+
bounds_check_version=bounds_check_version,
447455
)
448456

449457
# test weights.size(0) != indices.size(0) case
@@ -459,6 +467,7 @@ def test_bounds_check( # noqa C901
459467
warning,
460468
weights,
461469
**vbe_args,
470+
bounds_check_version=bounds_check_version,
462471
)
463472

464473
@given(

0 commit comments

Comments
 (0)