Commit 142d3ea

[Feature] Add --mg-fl-prefer argument for Megatron-LM-FL vendor selection (#1183)
## Summary

- Add `--mg-fl-prefer` CLI argument to select the preferred vendor backend (`cuda`/`musa`/`txda`) for Megatron-LM-FL plugin override dispatch
- Sync the argument value to the `MG_FL_PREFER` environment variable at training startup, following the same pattern as `--te-fl-prefer` / `TE_FL_PREFER`
- Rename the argument group from "flagscale transformer engine fl" to "flagscale fl" to reflect the broader scope

## Changes

### `flagscale/train/megatron/training/arguments_fs.py`

- Add `--mg-fl-prefer` argument with choices `['cuda', 'musa', 'txda']`, default empty string
- Rename argument group title to "flagscale fl"

### `flagscale/train/megatron/training/training.py`

- Sync `args.mg_fl_prefer` to `os.environ['MG_FL_PREFER']` in `pretrain()`, alongside the existing `TE_FL_PREFER` sync logic

## Usage

CLI:

```bash
python train.py --mg-fl-prefer musa --te-fl-prefer reference
```

YAML config:

```yaml
model:
  mg_fl_prefer: musa
  te_fl_prefer: reference
```

## Test plan

- [ ] Verify `--mg-fl-prefer musa` sets `MG_FL_PREFER=musa` in the environment
- [ ] Verify omitting `--mg-fl-prefer` does not set `MG_FL_PREFER`
- [ ] Verify invalid values are rejected by argparse choices validation
- [ ] Verify compatibility with existing `--te-fl-prefer` argument
1 parent 613b0d8 commit 142d3ea

2 files changed: 5 additions, 1 deletion

flagscale/train/megatron/training/arguments_fs.py

Lines changed: 3 additions & 1 deletion
@@ -765,7 +765,9 @@ def _add_regularization_args(parser):
 
 
 def _add_flagos_args(parser):
-    group = parser.add_argument_group(title="flagscale transformer engine fl")
+    group = parser.add_argument_group(title="flagscale fl")
+    group.add_argument('--mg-fl-prefer', type=str, choices=['cuda', 'musa', 'txda'], default='',
+                       help='Backend selection for megatron fl.')
     group.add_argument('--te-fl-prefer', type=str, choices=['flagos', 'vendor', 'reference'], default='vendor',
                        help='Backend selection for transformer engine fl.')
     group.add_argument('--te-fl-per-op', type=str, default=None,
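
The argument definitions in this hunk can be exercised in isolation to check the "invalid values are rejected" and "compatible with `--te-fl-prefer`" items of the test plan. The following is a minimal sketch, not FlagScale code: `build_parser` and `is_rejected` are hypothetical helpers, and the parser is a bare `argparse.ArgumentParser` standing in for the real Megatron parser. Only the two `add_argument` calls are copied from the diff.

```python
import argparse


def build_parser():
    # Stand-in parser (assumption): the real one is assembled by Megatron.
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title="flagscale fl")
    group.add_argument('--mg-fl-prefer', type=str, choices=['cuda', 'musa', 'txda'], default='',
                       help='Backend selection for megatron fl.')
    group.add_argument('--te-fl-prefer', type=str, choices=['flagos', 'vendor', 'reference'],
                       default='vendor', help='Backend selection for transformer engine fl.')
    return parser


def is_rejected(argv):
    """Return True if argparse refuses argv (it reports errors via SystemExit)."""
    try:
        build_parser().parse_args(argv)
    except SystemExit:
        return True
    return False


if __name__ == "__main__":
    args = build_parser().parse_args(['--mg-fl-prefer', 'musa', '--te-fl-prefer', 'reference'])
    print(args.mg_fl_prefer, args.te_fl_prefer)
    print(is_rejected(['--mg-fl-prefer', 'rocm']))
```

argparse checks `choices` only for values supplied on the command line, so the empty-string default is accepted even though `''` is not listed in `choices`.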

flagscale/train/megatron/training/training.py

Lines changed: 2 additions & 0 deletions
@@ -805,6 +805,8 @@ def pretrain(
 
     ###### FlagScale Begin ######
     args = get_args()
+    if args.mg_fl_prefer:
+        os.environ['MG_FL_PREFER'] = args.mg_fl_prefer
     # enable flagos:triton / vendor:cuda / reference:torch backend for transformer engine fl
     if args.te_fl_prefer:
         os.environ['TE_FL_PREFER'] = args.te_fl_prefer
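
The sync logic in this hunk can be checked without launching training by factoring it into a function that takes the environment mapping as a parameter. This is a sketch under assumptions: `sync_fl_prefer_env` is a hypothetical helper that does not exist in the commit, and `SimpleNamespace` stands in for the object returned by `get_args()`.

```python
import os
from types import SimpleNamespace


def sync_fl_prefer_env(args, environ=os.environ):
    """Copy non-empty *_fl_prefer values into the given environment mapping."""
    # Mirrors the two truthiness checks in pretrain(): empty or None values are skipped.
    if args.mg_fl_prefer:
        environ['MG_FL_PREFER'] = args.mg_fl_prefer
    if args.te_fl_prefer:
        environ['TE_FL_PREFER'] = args.te_fl_prefer
    return environ


if __name__ == "__main__":
    # Use a plain dict so the real process environment is left alone.
    env = sync_fl_prefer_env(SimpleNamespace(mg_fl_prefer='musa', te_fl_prefer='reference'), {})
    print(env)
```

Because the assignment is guarded by `if args.mg_fl_prefer:`, omitting the flag leaves any `MG_FL_PREFER` value already present in the environment untouched rather than clearing it.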
