@@ -585,10 +585,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("use_bias", all_boolean)
+@pytest.mark.parametrize("single_param", all_boolean)
 @pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
 @pytest.mark.parametrize("num_gemms", [4])
 def test_sanity_grouped_linear(
-    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
+    dtype,
+    bs,
+    model,
+    fp8_recipe,
+    fp8_model_params,
+    use_bias,
+    single_param,
+    num_gemms,
+    empty_split,
 ):
     if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
         pytest.skip("FP8 model parameters are not supported in debug mode.")
@@ -598,6 +607,9 @@ def test_sanity_grouped_linear(
         bs = bs * 16
     num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
 
+    if single_param:
+        os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"
+
     if fp8_recipe is not None:
         if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
@@ -617,7 +629,8 @@ def test_sanity_grouped_linear(
     # Verify that weights are stored in contiguous GroupedTensor storage.
     weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
     if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
-        check_grouped_tensor_pointers(weights, fp8_recipe)
+        if single_param:
+            check_grouped_tensor_pointers(weights, fp8_recipe)
 
     inp_hidden_states = torch.randn(
         num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
@@ -636,6 +649,9 @@ def test_sanity_grouped_linear(
     loss.backward()
     assert out.shape == (num_tokens, ffn_hidden_size)
 
+    if single_param:
+        del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]
+
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
0 commit comments