diff --git a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp index f29645c217..7fd2377bcd 100644 --- a/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp +++ b/fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp @@ -60,13 +60,21 @@ void bounds_check_indices_cuda( const int64_t info_B_num_bits, const int64_t info_B_mask, const int8_t bounds_check_version) { +#if USE_ROCM + // Force using bounds_check_indices v2 on ROCm because ROCm has a constraint + // that the gridDim * blockDim has to be smaller than 2^32. The v1 kernel can + // be launched with gridDim * blockDim > 2^32 while the v2 kernel limits the + // gridDim size to 64 * # of SMs. Thus, its gridDim * blockDim is guaranteed + // to be smaller than 2^32 + const auto bounds_check_indices_fn = _bounds_check_indices_cuda_v2; +#else TORCH_CHECK(bounds_check_version == 1 || bounds_check_version == 2); - const static bool use_v2 = - fbgemm_gpu::config::is_feature_enabled( - fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2) || - bounds_check_version == 2; - const auto bounds_check_indices_fn = - use_v2 ? _bounds_check_indices_cuda_v2 : _bounds_check_indices_cuda_v1; + const static bool use_v2 = fbgemm_gpu::config::is_feature_enabled( + fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2); + const auto bounds_check_indices_fn = (use_v2 || bounds_check_version == 2) + ? _bounds_check_indices_cuda_v2 + : _bounds_check_indices_cuda_v1; +#endif bounds_check_indices_fn( rows_per_table, indices, diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 8b642361d8..a2fb1ea26b 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -829,6 +829,7 @@ def bounds_check_indices_abstract( b_t_map: Optional[torch.Tensor] = None, info_B_num_bits: int = -1, info_B_mask: int = -1, + bounds_check_version: int = 1, ) -> None: """ This meta function is used to fake the bounds checking diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index c5a13cd768..cf99e15d68 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -716,6 +716,16 @@ def __init__( # noqa C901 # See: # https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/ cache_precision = SparseType.FP32 + self.log("Override cache_precision=SparseType.FP32 on ROCm") + + # NOTE: Use bounds_check_indices v2 on ROCm because ROCm has a + # constraint that the gridDim * blockDim has to be smaller than + # 2^32. The v1 kernel can be launched with gridDim * blockDim > + # 2^32 while the v2 kernel limits the gridDim size to 64 * # of + # SMs. Thus, its gridDim * blockDim is guaranteed to be smaller + # than 2^32 + self.bounds_check_version = 2 + self.log("Override bounds_check_version=2 on ROCm") else: # NOTE: The changes from D65865527 are retained here until we can # test that the the hack also works for non-ROCm environments. diff --git a/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py b/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py index e88934693c..fd9ccdcad1 100644 --- a/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py +++ b/fbgemm_gpu/test/tbe/utils/split_embeddings_utils_test.py @@ -237,6 +237,7 @@ def test_transpose(self, B: int, T: int, E: int) -> None: ] ), mixed_B=st.booleans(), + bounds_check_version=st.sampled_from((1, 2)), ) @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) def test_bounds_check( # noqa C901 @@ -249,6 +250,7 @@ def test_bounds_check( # noqa C901 weighted: bool, dtype: torch.dtype, mixed_B: bool, + bounds_check_version: int, ) -> None: rows_per_table = torch.tensor( np.random.randint(low=1, high=1000, size=(T,)) @@ -348,6 +350,7 @@ def test_bounds_check( # noqa C901 bounds_check_mode, warning, weights, + bounds_check_version=bounds_check_version, **vbe_args, ) # we don't modify when we are in-bounds. @@ -361,6 +364,7 @@ def test_bounds_check( # noqa C901 bounds_check_mode, warning, weights, + bounds_check_version=bounds_check_version, **vbe_args, ) torch.testing.assert_close(indices, torch.zeros_like(indices)) @@ -376,6 +380,7 @@ def test_bounds_check( # noqa C901 bounds_check_mode, warning, weights, + bounds_check_version=bounds_check_version, **vbe_args, ) # It would be nice to test the CUDA implementation of BoundsCheckMode==FATAL, @@ -397,6 +402,7 @@ def test_bounds_check( # noqa C901 bounds_check_mode, warning, weights, + bounds_check_version=bounds_check_version, **vbe_args, ) if offsets.numel() > 0: @@ -417,6 +423,7 @@ def test_bounds_check( # noqa C901 bounds_check_mode, warning, weights, + bounds_check_version=bounds_check_version, ) # test offsets.size(0) ! = B * T + 1 case. Here we test with T >= 2 case. @@ -444,6 +451,7 @@ def test_bounds_check( # noqa C901 bounds_check_mode, warning, weights, + bounds_check_version=bounds_check_version, ) # test weights.size(0) != indices.size(0) case @@ -459,6 +467,7 @@ def test_bounds_check( # noqa C901 warning, weights, **vbe_args, + bounds_check_version=bounds_check_version, ) @given(