@@ -239,7 +239,6 @@ def check_fp8(state, var_collect, inputs, masks, labels):
239239def train_and_evaluate (args ):
240240 """Execute model training and evaluation loop."""
241241 print (args )
242- jax .config .update ("jax_use_shardy_partitioner" , args .enable_shardy )
243242
244243 train_ds , test_ds , num_embed = get_datasets (args .max_seq_len )
245244
@@ -474,9 +473,6 @@ def encoder_parser(args):
474473 parser .add_argument (
475474 "--enable-sp" , action = "store_true" , default = False , help = "Enable sequence parallelism."
476475 )
477- parser .add_argument (
478- "--enable-shardy" , action = "store_true" , default = False , help = "Enable Shardy (experimental)."
479- )
480476
481477 return parser .parse_args (args )
482478
@@ -559,70 +555,6 @@ def test_te_nvfp4_with_sp(self):
559555 actual = train_and_evaluate (self .args )
560556 assert actual [0 ] < 0.40 and actual [1 ] > 0.82
561557
562- @unittest .skipIf (not is_bf16_supported (), "Device compute capability 8.0+ is required for BF16" )
563- def test_te_bf16_shardy (self ):
564- """Test Transformer Engine with BF16"""
565- self .args .enable_shardy = True
566- actual = train_and_evaluate (self .args )
567- assert actual [0 ] < 0.36 and actual [1 ] > 0.84
568-
569- @unittest .skipIf (not is_fp8_supported , fp8_reason )
570- def test_te_delayed_scaling_fp8_shardy (self ):
571- """Test Transformer Engine with DelayedScaling FP8"""
572- self .args .enable_shardy = True
573- self .args .use_fp8 = True
574- self .args .fp8_recipe = "DelayedScaling"
575- actual = train_and_evaluate (self .args )
576- assert actual [0 ] < 0.362 and actual [1 ] > 0.84
577-
578- @unittest .skipIf (not is_fp8_supported , fp8_reason )
579- def test_te_delayed_scaling_fp8_with_sp_shardy (self ):
580- """Test Transformer Engine with DelayedScaling FP8 + SP"""
581- self .args .enable_shardy = True
582- self .args .enable_sp = True
583- self .args .use_fp8 = True
584- self .args .fp8_recipe = "DelayedScaling"
585- actual = train_and_evaluate (self .args )
586- assert actual [0 ] < 0.362 and actual [1 ] > 0.84
587-
588- @unittest .skipIf (not is_mxfp8_supported , mxfp8_reason )
589- def test_te_mxfp8_shardy (self ):
590- """Test Transformer Engine with MXFP8"""
591- self .args .enable_shardy = True
592- self .args .use_fp8 = True
593- self .args .fp8_recipe = "MXFP8BlockScaling"
594- actual = train_and_evaluate (self .args )
595- assert actual [0 ] < 0.36 and actual [1 ] > 0.84
596-
597- @unittest .skipIf (not is_nvfp4_supported , nvfp4_reason )
598- def test_te_nvfp4_shardy (self ):
599- """Test Transformer Engine with NVFP4"""
600- self .args .enable_shardy = True
601- self .args .use_fp8 = True
602- self .args .fp8_recipe = "NVFP4BlockScaling"
603- actual = train_and_evaluate (self .args )
604- assert actual [0 ] < 0.40 and actual [1 ] > 0.82
605-
606- @unittest .skipIf (not is_mxfp8_supported , mxfp8_reason )
607- def test_te_mxfp8_with_sp_shardy (self ):
608- """Test Transformer Engine with MXFP8 + SP"""
609- self .args .enable_shardy = True
610- self .args .enable_sp = True
611- self .args .use_fp8 = True
612- self .args .fp8_recipe = "MXFP8BlockScaling"
613- actual = train_and_evaluate (self .args )
614- assert actual [0 ] < 0.36 and actual [1 ] > 0.84
615-
616- @unittest .skipIf (not is_nvfp4_supported , nvfp4_reason )
617- def test_te_nvfp4_with_sp_shardy (self ):
618- """Test Transformer Engine with NVFP4"""
619- self .args .enable_shardy = True
620- self .args .enable_sp = True
621- self .args .use_fp8 = True
622- self .args .fp8_recipe = "NVFP4BlockScaling"
623- actual = train_and_evaluate (self .args )
624- assert actual [0 ] < 0.40 and actual [1 ] > 0.82
625-
626558
627559if __name__ == "__main__" :
628560 train_and_evaluate (encoder_parser (None ))
0 commit comments