
Commit 08d9f20

Merge pull request #2368 from AI-Hypercomputer:chengnuojin-fix-vmem
PiperOrigin-RevId: 809331612
2 parents: 974323d + 44ef76c

2 files changed: 6 additions, 6 deletions


.github/workflows/run_tests_internal.yml

Lines changed: 1 addition & 1 deletion

@@ -67,4 +67,4 @@ jobs:
           FINAL_PYTEST_MARKER="${{ inputs.pytest_marker }} and not scheduled_only"
         fi
         python3 -m pip install -e . --no-dependencies &&
-        python3 -m pytest -v -m "${FINAL_PYTEST_MARKER}" --durations=0
+        LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' python3 -m pytest -v -m "${FINAL_PYTEST_MARKER}" --durations=0
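The same scoped-VMEM limit can be mirrored outside CI. A minimal sketch in Python, assuming a TPU host with JAX installed; the flag value comes from the diff above, and setting LIBTPU_INIT_ARGS via os.environ before the first jax import is an assumption about how the process is launched (libtpu reads the variable when the TPU backend initializes):

```python
# Minimal sketch: apply the CI's scoped-VMEM setting in-process.
# LIBTPU_INIT_ARGS is read when the TPU backend initializes, so it
# must be set before the first `import jax` in this process.
import os

os.environ["LIBTPU_INIT_ARGS"] = "--xla_tpu_scoped_vmem_limit_kib=65536"

import jax  # noqa: E402  (import after env setup is intentional)

print(jax.devices())  # TPU devices initialized with the raised VMEM limit
```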

tests/train_compile_test.py

Lines changed: 5 additions & 5 deletions
@@ -493,7 +493,7 @@ def test_moe_deepseek_scanned_bf16(self):
         "megablox=False",
         "per_device_batch_size=2",
         "max_target_length=1024",
-        "attention=dot_product",  # Change to flash attention once it works for MLA
+        "attention=flash",
         "dtype=bfloat16",
         "weight_dtype=bfloat16",
         "scan_layers=True",
@@ -518,7 +518,7 @@ def test_moe_deepseek_unscanned_bf16(self):
         "megablox=False",
         "per_device_batch_size=1",
         "max_target_length=1024",
-        "attention=dot_product",  # Change to flash attention once it works for MLA
+        "attention=flash",
         "dtype=bfloat16",
         "weight_dtype=bfloat16",
         "scan_layers=False",
@@ -541,7 +541,7 @@ def test_moe_deepseek_with_device_limit(self):
         "megablox=False",
         "per_device_batch_size=1",
         "max_target_length=1024",
-        "attention=dot_product",  # Change to flash attention once it works for MLA
+        "attention=flash",
         "dtype=bfloat16",
         "weight_dtype=bfloat16",
         "n_routing_groups=8",
@@ -565,7 +565,7 @@ def test_moe_deepseek_without_device_limit(self):
         "megablox=False",
         "per_device_batch_size=1",
         "max_target_length=1024",
-        "attention=dot_product",  # Change to flash attention once it works for MLA
+        "attention=flash",
         "dtype=bfloat16",
         "weight_dtype=bfloat16",
         "n_routing_groups=-1",
@@ -585,7 +585,7 @@ def test_moe_deepseek_pipeline_subset(self):
         "compile_topology_num_slices=8",
         "use_iota_embed=true",
         "model_name=deepseek3-671b",
-        "megablox=False",  # dropless not yet supported (b/418313093)
+        "megablox=True",
         "sparse_matmul=False",
         "capacity_factor=1",
         "per_device_batch_size=1",
