diff --git a/flagscale/train/megatron/training/arguments_fs.py b/flagscale/train/megatron/training/arguments_fs.py index 1bf27fc328..1b9e6e0be4 100644 --- a/flagscale/train/megatron/training/arguments_fs.py +++ b/flagscale/train/megatron/training/arguments_fs.py @@ -765,7 +765,9 @@ def _add_regularization_args(parser): def _add_flagos_args(parser): - group = parser.add_argument_group(title="flagscale transformer engine fl") + group = parser.add_argument_group(title="flagscale fl") + group.add_argument('--mg-fl-prefer', type=str, choices=['cuda', 'musa', 'txda'], default='', + help='Backend selection for megatron fl.') group.add_argument('--te-fl-prefer', type=str, choices=['flagos', 'vendor', 'reference'], default='vendor', help='Backend selection for transformer engine fl.') group.add_argument('--te-fl-per-op', type=str, default=None, diff --git a/flagscale/train/megatron/training/training.py b/flagscale/train/megatron/training/training.py index 674403e3dc..cd727f9ac2 100644 --- a/flagscale/train/megatron/training/training.py +++ b/flagscale/train/megatron/training/training.py @@ -805,6 +805,8 @@ def pretrain( ###### FlagScale Begin ###### args = get_args() + if args.mg_fl_prefer: + os.environ['MG_FL_PREFER'] = args.mg_fl_prefer # enable flagos:triton / vendor:cuda / reference:torch backend for transformer engine fl if args.te_fl_prefer: os.environ['TE_FL_PREFER'] = args.te_fl_prefer