@@ -493,7 +493,7 @@ def test_moe_deepseek_scanned_bf16(self):
         "megablox=False",
         "per_device_batch_size=2",
         "max_target_length=1024",
-        "attention=dot_product",  # Change to flash attention once it works for MLA
+        "attention=flash",
         "dtype=bfloat16",
         "weight_dtype=bfloat16",
         "scan_layers=True",
@@ -518,7 +518,7 @@ def test_moe_deepseek_unscanned_bf16(self):
         "megablox=False",
         "per_device_batch_size=1",
         "max_target_length=1024",
-        "attention=dot_product",  # Change to flash attention once it works for MLA
+        "attention=flash",
         "dtype=bfloat16",
         "weight_dtype=bfloat16",
         "scan_layers=False",
@@ -541,7 +541,7 @@ def test_moe_deepseek_with_device_limit(self):
         "megablox=False",
         "per_device_batch_size=1",
         "max_target_length=1024",
-        "attention=dot_product",  # Change to flash attention once it works for MLA
+        "attention=flash",
         "dtype=bfloat16",
         "weight_dtype=bfloat16",
         "n_routing_groups=8",
@@ -565,7 +565,7 @@ def test_moe_deepseek_without_device_limit(self):
         "megablox=False",
         "per_device_batch_size=1",
         "max_target_length=1024",
-        "attention=dot_product",  # Change to flash attention once it works for MLA
+        "attention=flash",
         "dtype=bfloat16",
         "weight_dtype=bfloat16",
         "n_routing_groups=-1",
@@ -585,7 +585,7 @@ def test_moe_deepseek_pipeline_subset(self):
         "compile_topology_num_slices=8",
         "use_iota_embed=true",
         "model_name=deepseek3-671b",
-        "megablox=False",  # dropless not yet supported (b/418313093)
+        "megablox=True",
         "sparse_matmul=False",
         "capacity_factor=1",
         "per_device_batch_size=1",
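
For context, the test arguments above are plain "key=value" override strings (e.g. "attention=flash", "megablox=True"). The following is a minimal, hypothetical sketch of how such overrides could be parsed into a config dict with simple type coercion; the helper name and coercion rules are assumptions for illustration only, not MaxText's actual config loader.

# Hypothetical illustration: parse "key=value" test overrides (like those in
# the diff above) into a config dict. NOT MaxText's real config loader;
# the function name and coercion rules are assumed for this sketch.
from typing import Any, Dict, List


def parse_overrides(args: List[str]) -> Dict[str, Any]:
  """Split each "key=value" string and coerce obvious bools/ints/floats."""
  config: Dict[str, Any] = {}
  for arg in args:
    key, _, raw = arg.partition("=")
    value: Any = raw
    if raw.lower() in ("true", "false"):
      value = raw.lower() == "true"
    else:
      for cast in (int, float):
        try:
          value = cast(raw)
          break
        except ValueError:
          continue
    config[key] = value
  return config


if __name__ == "__main__":
  overrides = ["attention=flash", "megablox=True", "per_device_batch_size=2"]
  print(parse_overrides(overrides))
  # {'attention': 'flash', 'megablox': True, 'per_device_batch_size': 2}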