 )
 
 
+def _get_number_of_gpu_sm() -> int:
+    if not torch.cuda.is_available():
+        raise RuntimeError("CUDA is not available")
+    device_props = torch.cuda.get_device_properties(0)
+    return device_props.multi_processor_count
+
+
 def _str_1d_tensor(t: torch.Tensor) -> str:
     sl = [f"{x:7.4f}" for x in t.tolist()]
     if len(sl) > 5:
@@ -48,6 +55,7 @@ def _str_1d_tensor(t: torch.Tensor) -> str:
 def _do_test_all_to_all(
     pgi: ProcessGroupInfo,
     dp_size: int,
+    max_sm_count: int,
     moe: MoEConfig,
     internode: bool,
     use_compile: bool,
@@ -80,6 +88,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )
     else:
         ata = AllToAll.intranode(
@@ -100,6 +109,7 @@ def _do_test_all_to_all(
                     * torch.float32.itemsize
                 )
             ),
+            max_sm_count=max_sm_count,
         )
 
     # Generate the same test data on all ranks
@@ -291,6 +301,7 @@ def _worker_test_all_to_all(
     dp_size: int,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
     moe_config: MoEConfig,
     internode: bool,
     use_compile: bool = False,
@@ -305,18 +316,21 @@ def _worker_test_all_to_all(
         out_dtype=getattr(torch, out_dtype),
     )
 
-    _do_test_all_to_all(pgi, dp_size, moe_config, internode, use_compile)
+    _do_test_all_to_all(pgi, dp_size, max_sm_count, moe_config, internode, use_compile)
 
     nvshmem_finalize()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 GPUs")
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize(
+    "max_sm_count", [_get_number_of_gpu_sm(), _get_number_of_gpu_sm() // 2]
+)
 @pytest.mark.parametrize("internode", [True, False])
 @pytest.mark.parametrize("use_compile", [False, True])
 def test_all_to_all_4_gpu(
-    in_dtype: str, out_dtype: str, internode: bool, use_compile: bool
+    in_dtype: str, out_dtype: str, max_sm_count: int, internode: bool, use_compile: bool
 ) -> None:
     world_size = 4
     dp_size = 2
@@ -326,6 +340,7 @@ def test_all_to_all_4_gpu(
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         small_moe,
         internode,
         use_compile,
@@ -336,13 +351,15 @@ def _worker_test_all_to_all_multi_node(
     pgi: ProcessGroupInfo,
     in_dtype: str,
     out_dtype: str,
+    max_sm_count: int,
 ) -> None:
     dp_size = 4
     _worker_test_all_to_all(
         pgi,
         dp_size,
         in_dtype,
         out_dtype,
+        max_sm_count,
         medium_moe,
         True,
     )
@@ -352,4 +369,7 @@ def _worker_test_all_to_all_multi_node(
 @pytest.mark.parametrize("in_dtype", ["bfloat16", "float8_e4m3fn", "float16"])
 @pytest.mark.parametrize("out_dtype", ["float16", "bfloat16"])
 def test_all_to_all_multi_node(in_dtype: str, out_dtype: str) -> None:
-    parallel_launch_from_env(_worker_test_all_to_all_multi_node, in_dtype, out_dtype)
+    max_sm_count = _get_number_of_gpu_sm()
+    parallel_launch_from_env(
+        _worker_test_all_to_all_multi_node, in_dtype, out_dtype, max_sm_count
+    )
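
# --- Standalone sketch (not part of the diff above) -------------------------
# A minimal, hedged illustration of the SM-count parametrization this change
# introduces: query device 0's streaming-multiprocessor count via
# torch.cuda.get_device_properties() and derive the [full, half] values that
# feed @pytest.mark.parametrize("max_sm_count", ...). Requires a CUDA GPU; the
# numbers are device-dependent (an A100, for example, reports 108 SMs). The
# helper name here mirrors the test's private helper but is illustrative.
import torch


def get_number_of_gpu_sm() -> int:
    # multi_processor_count is the number of SMs on the selected device.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")
    return torch.cuda.get_device_properties(0).multi_processor_count


if __name__ == "__main__":
    full_sm_count = get_number_of_gpu_sm()
    # The tests run each all-to-all case twice: once with all SMs and once
    # with half, to check that capping the kernels' SM usage via max_sm_count
    # does not change the results.
    print([full_sm_count, full_sm_count // 2])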