Skip to content

Add tests for bounds_check_indices v2 #3920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,21 @@ void bounds_check_indices_cuda(
const int64_t info_B_num_bits,
const int64_t info_B_mask,
const int8_t bounds_check_version) {
#if USE_ROCM
// Force using bounds_check_indices v2 on ROCm because ROCm has a constraint
// that the gridDim * blockDim has to be smaller than 2^32. The v1 kernel can
// be launched with gridDim * blockDim > 2^32 while the v2 kernel limits the
// gridDim size to 64 * # of SMs. Thus, its gridDim * blockDim is guaranteed
// to be smaller than 2^32
const auto bounds_check_indices_fn = _bounds_check_indices_cuda_v2;
#else
TORCH_CHECK(bounds_check_version == 1 || bounds_check_version == 2);
const static bool use_v2 =
fbgemm_gpu::config::is_feature_enabled(
fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2) ||
bounds_check_version == 2;
const auto bounds_check_indices_fn =
use_v2 ? _bounds_check_indices_cuda_v2 : _bounds_check_indices_cuda_v1;
const static bool use_v2 = fbgemm_gpu::config::is_feature_enabled(
fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2);
const auto bounds_check_indices_fn = (use_v2 || bounds_check_version == 2)
? _bounds_check_indices_cuda_v2
: _bounds_check_indices_cuda_v1;
#endif
bounds_check_indices_fn(
rows_per_table,
indices,
Expand Down
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,7 @@ def bounds_check_indices_abstract(
b_t_map: Optional[torch.Tensor] = None,
info_B_num_bits: int = -1,
info_B_mask: int = -1,
bounds_check_version: int = 1,
) -> None:
"""
This meta function is used to fake the bounds checking
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,16 @@ def __init__( # noqa C901
# See:
# https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/
cache_precision = SparseType.FP32
self.log("Override cache_precision=SparseType.FP32 on ROCm")

# NOTE: Use bounds_check_indices v2 on ROCm because ROCm has a
# constraint that the gridDim * blockDim has to be smaller than
# 2^32. The v1 kernel can be launched with gridDim * blockDim >
# 2^32 while the v2 kernel limits the gridDim size to 64 * # of
# SMs. Thus, its gridDim * blockDim is guaranteed to be smaller
# than 2^32
self.bounds_check_version = 2
self.log("Override bounds_check_version=2 on ROCm")
else:
# NOTE: The changes from D65865527 are retained here until we can
# test that the hack also works for non-ROCm environments.
Expand Down
9 changes: 9 additions & 0 deletions fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def test_transpose(self, B: int, T: int, E: int) -> None:
]
),
mixed_B=st.booleans(),
bounds_check_version=st.sampled_from((1, 2)),
)
@settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None)
def test_bounds_check( # noqa C901
Expand All @@ -249,6 +250,7 @@ def test_bounds_check( # noqa C901
weighted: bool,
dtype: torch.dtype,
mixed_B: bool,
bounds_check_version: int,
) -> None:
rows_per_table = torch.tensor(
np.random.randint(low=1, high=1000, size=(T,))
Expand Down Expand Up @@ -348,6 +350,7 @@ def test_bounds_check( # noqa C901
bounds_check_mode,
warning,
weights,
bounds_check_version=bounds_check_version,
**vbe_args,
)
# we don't modify when we are in-bounds.
Expand All @@ -361,6 +364,7 @@ def test_bounds_check( # noqa C901
bounds_check_mode,
warning,
weights,
bounds_check_version=bounds_check_version,
**vbe_args,
)
torch.testing.assert_close(indices, torch.zeros_like(indices))
Expand All @@ -376,6 +380,7 @@ def test_bounds_check( # noqa C901
bounds_check_mode,
warning,
weights,
bounds_check_version=bounds_check_version,
**vbe_args,
)
# It would be nice to test the CUDA implementation of BoundsCheckMode==FATAL,
Expand All @@ -397,6 +402,7 @@ def test_bounds_check( # noqa C901
bounds_check_mode,
warning,
weights,
bounds_check_version=bounds_check_version,
**vbe_args,
)
if offsets.numel() > 0:
Expand All @@ -417,6 +423,7 @@ def test_bounds_check( # noqa C901
bounds_check_mode,
warning,
weights,
bounds_check_version=bounds_check_version,
)

# test the offsets.size(0) != B * T + 1 case. Here we test the T >= 2 case.
Expand Down Expand Up @@ -444,6 +451,7 @@ def test_bounds_check( # noqa C901
bounds_check_mode,
warning,
weights,
bounds_check_version=bounds_check_version,
)

# test weights.size(0) != indices.size(0) case
Expand All @@ -459,6 +467,7 @@ def test_bounds_check( # noqa C901
warning,
weights,
**vbe_args,
bounds_check_version=bounds_check_version,
)

@given(
Expand Down
Loading