From fc7b929d2cab11f2a2f88daff5c6d46b8f2a2fd3 Mon Sep 17 00:00:00 2001 From: sarunya Date: Wed, 2 Apr 2025 16:16:38 -0700 Subject: [PATCH 1/2] Use bounds_check_indices v2 on ROCm Differential Revision: D72334377 --- .../utils/embedding_bounds_check_host.cpp | 20 +++++++++++++------ ...t_table_batched_embeddings_ops_training.py | 10 ++++++++++ 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp index f29645c21..7fd2377bc 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp @@ -60,13 +60,21 @@ void bounds_check_indices_cuda( const int64_t info_B_num_bits, const int64_t info_B_mask, const int8_t bounds_check_version) { +#if USE_ROCM + // Force using bounds_check_indices v2 on ROCm because ROCm has a constraint + // that the gridDim * blockDim has to be smaller than 2^32. The v1 kernel can + // be launched with gridDim * blockDim > 2^32 while the v2 kernel limits the + // gridDim size to 64 * # of SMs. Thus, its gridDim * blockDim is guaranteed + // to be smaller than 2^32 + const auto bounds_check_indices_fn = _bounds_check_indices_cuda_v2; +#else TORCH_CHECK(bounds_check_version == 1 || bounds_check_version == 2); - const static bool use_v2 = - fbgemm_gpu::config::is_feature_enabled( - fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2) || - bounds_check_version == 2; - const auto bounds_check_indices_fn = - use_v2 ? _bounds_check_indices_cuda_v2 : _bounds_check_indices_cuda_v1; + const static bool use_v2 = fbgemm_gpu::config::is_feature_enabled( + fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2); + const auto bounds_check_indices_fn = (use_v2 || bounds_check_version == 2) + ? _bounds_check_indices_cuda_v2 + : _bounds_check_indices_cuda_v1; +#endif bounds_check_indices_fn( rows_per_table, indices, diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index c5a13cd76..cf99e15d6 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -716,6 +716,16 @@ def __init__( # noqa C901 # See: # https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/ cache_precision = SparseType.FP32 + self.log("Override cache_precision=SparseType.FP32 on ROCm") + + # NOTE: Use bounds_check_indices v2 on ROCm because ROCm has a + # constraint that the gridDim * blockDim has to be smaller than + # 2^32. The v1 kernel can be launched with gridDim * blockDim > + # 2^32 while the v2 kernel limits the gridDim size to 64 * # of + # SMs. Thus, its gridDim * blockDim is guaranteed to be smaller + # than 2^32 + self.bounds_check_version = 2 + self.log("Override bounds_check_version=2 on ROCm") else: # NOTE: The changes from D65865527 are retained here until we can # test that the the hack also works for non-ROCm environments. From 65f154d13e989732f8f371e28b00ac66cbf9bf0c Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Wed, 2 Apr 2025 23:21:20 -0700 Subject: [PATCH 2/2] Add tests for bounds_check_indices v2 (#3920) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/1008 Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3920 As title Reviewed By: q10 Differential Revision: D72345566 --- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 1 + fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 8b642361d..a2fb1ea26 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -829,6 +829,7 @@ def bounds_check_indices_abstract( b_t_map: Optional[torch.Tensor] = None, info_B_num_bits: int = -1, info_B_mask: int = -1, + bounds_check_version: int = 1, ) -> None: """ This meta function is used to fake the bounds checking diff --git a/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py b/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py index e88934693..fd9ccdcad 100644 --- a/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py +++ b/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py @@ -237,6 +237,7 @@ def test_transpose(self, B: int, T: int, E: int) -> None: ] ), mixed_B=st.booleans(), + bounds_check_version=st.sampled_from((1, 2)), ) @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_bounds_check( # noqa C901 @@ -249,6 +250,7 @@ def test_bounds_check( # noqa C901 weighted: bool, dtype: torch.dtype, mixed_B: bool, + bounds_check_version: int, ) -> None: rows_per_table = torch.tensor( np.random.randint(low=1, high=1000, size=(T,)) @@ -348,6 +350,7 @@ def test_bounds_check( # noqa C901 bounds_check_mode, warning, weights, + bounds_check_version=bounds_check_version, **vbe_args, ) # we don't modify when we are in-bounds. @@ -361,6 +364,7 @@ def test_bounds_check( # noqa C901 bounds_check_mode, warning, weights, + bounds_check_version=bounds_check_version, **vbe_args, ) torch.testing.assert_close(indices, torch.zeros_like(indices)) @@ -376,6 +380,7 @@ def test_bounds_check( # noqa C901 bounds_check_mode, warning, weights, + bounds_check_version=bounds_check_version, **vbe_args, ) # It would be nice to test the CUDA implementation of BoundsCheckMode==FATAL, @@ -397,6 +402,7 @@ def test_bounds_check( # noqa C901 bounds_check_mode, warning, weights, + bounds_check_version=bounds_check_version, **vbe_args, ) if offsets.numel() > 0: @@ -417,6 +423,7 @@ def test_bounds_check( # noqa C901 bounds_check_mode, warning, weights, + bounds_check_version=bounds_check_version, ) # test offsets.size(0) ! = B * T + 1 case. Here we test with T >= 2 case. @@ -444,6 +451,7 @@ def test_bounds_check( # noqa C901 bounds_check_mode, warning, weights, + bounds_check_version=bounds_check_version, ) # test weights.size(0) != indices.size(0) case @@ -459,6 +467,7 @@ def test_bounds_check( # noqa C901 warning, weights, **vbe_args, + bounds_check_version=bounds_check_version, ) @given(