2 changes: 1 addition & 1 deletion docs/source/Megatron-SWIFT/快速开始.md
@@ -65,7 +65,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2
| torch | >=2.0 | 2.6.0/2.7.1 | |
| transformer_engine | >=2.3 | | |
| apex | | 0.1 | |
| megatron_core | >=0.12 | 0.13 | |
| megatron_core | | 0.13 | |
| flash_attn | | 2.8.1/3.0.0b1 | |
| transformers | >=4.33 | 4.56.2 | |
| modelscope | >=1.23 | | |
2 changes: 1 addition & 1 deletion docs/source_en/Megatron-SWIFT/Quick-start.md
@@ -65,7 +65,7 @@ Recommended Operating Environment:
| torch | >=2.0 | 2.6.0/2.7.1 | |
| transformer_engine | >=2.3 | | |
| apex | | 0.1 | |
| megatron_core | >=0.12 | 0.13 | |
| megatron_core | | 0.13 | |
| flash_attn | | 2.8.1/3.0.0b1 | |
| transformers | >=4.33 | 4.56.2 | |
| modelscope | >=1.23 | | |
5 changes: 5 additions & 0 deletions swift/cli/_megatron/export.py
@@ -0,0 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from swift.megatron import megatron_export_main

if __name__ == '__main__':
megatron_export_main()
3 changes: 2 additions & 1 deletion swift/cli/_megatron/main.py
@@ -10,11 +10,12 @@
'pt': 'swift.cli._megatron.pt',
'sft': 'swift.cli._megatron.sft',
'rlhf': 'swift.cli._megatron.rlhf',
'export': 'swift.cli._megatron.export',
}


def cli_main():
return swift_cli_main(ROUTE_MAPPING)
return swift_cli_main(ROUTE_MAPPING, is_megatron=True)


if __name__ == '__main__':
4 changes: 2 additions & 2 deletions swift/cli/main.py
@@ -53,7 +53,7 @@ def _compat_web_ui(argv):
logger.warning('Please use `swift app`.')


def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None:
def cli_main(route_mapping: Optional[Dict[str, str]] = None, is_megatron: bool = False) -> None:
route_mapping = route_mapping or ROUTE_MAPPING
argv = sys.argv[1:]
_compat_web_ui(argv)
@@ -62,7 +62,7 @@ def cli_main(route_mapping: Optional[Dict[str, str]] = None) -> None:
file_path = importlib.util.find_spec(route_mapping[method_name]).origin
torchrun_args = get_torchrun_args()
python_cmd = sys.executable
if torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}:
if not is_megatron and (torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}):
args = [python_cmd, file_path, *argv]
else:
args = [python_cmd, '-m', 'torch.distributed.run', *torchrun_args, file_path, *argv]
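With `is_megatron=True`, the torchrun short-circuit above no longer applies, so every megatron subcommand (including the new `export` route) is launched through `torch.distributed.run`. Below is a simplified, hedged sketch of that dispatch, not the actual implementation; the route table is a subset of the real one and the torchrun arguments are placeholders.

```python
# Simplified sketch of the dispatch logic after this change; requires ms-swift to be
# installed so find_spec can resolve the module.
import importlib.util
import sys

ROUTE_MAPPING = {'export': 'swift.cli._megatron.export'}   # subset of the real table


def build_command(method_name, argv, torchrun_args, is_megatron):
    # Resolve the subcommand module to its script path so it can run as a child process.
    file_path = importlib.util.find_spec(ROUTE_MAPPING[method_name]).origin
    python_cmd = sys.executable
    if not is_megatron and (torchrun_args is None or method_name not in {'pt', 'sft', 'rlhf', 'infer'}):
        return [python_cmd, file_path, *argv]
    # Megatron subcommands always take the distributed path.
    return [python_cmd, '-m', 'torch.distributed.run', *torchrun_args, file_path, *argv]


# Example: the export subcommand is wrapped in a torch.distributed.run invocation.
print(build_command('export', ['--to_hf', 'true'], ['--nproc_per_node', '1'], is_megatron=True))
```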
15 changes: 8 additions & 7 deletions swift/llm/argument/base_args/base_args.py
@@ -306,12 +306,13 @@ def get_model_processor(self,
**kwargs):
if self.tuner_backend == 'unsloth':
return load_by_unsloth(self)
kwargs.update(self.get_model_kwargs())
res = self.get_model_kwargs()
res.update(kwargs)
# compat rlhf
kwargs['model_id_or_path'] = model or self.model
kwargs['model_type'] = model_type or self.model_type
kwargs['model_revision'] = model_revision or self.model_revision
kwargs['task_type'] = task_type or self.task_type
kwargs['num_labels'] = num_labels or self.num_labels
res['model_id_or_path'] = model or self.model
res['model_type'] = model_type or self.model_type
res['model_revision'] = model_revision or self.model_revision
res['task_type'] = task_type or self.task_type
res['num_labels'] = num_labels or self.num_labels

return get_model_tokenizer(**kwargs)
return get_model_tokenizer(**res)
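The reorder above is a precedence fix: `dict.update` lets the last writer win, so building `res` from `get_model_kwargs()` first and applying the caller's `kwargs` second preserves explicit overrides, whereas the old order let the defaults clobber them. A standalone illustration (placeholder keys, not swift code):

```python
# Minimal illustration of the merge-order fix.
defaults = {'torch_dtype': 'bfloat16', 'device_map': 'auto'}   # stand-in for self.get_model_kwargs()
overrides = {'device_map': 'cpu'}                              # stand-in for caller-supplied kwargs

# Old order: kwargs.update(defaults) -- the caller's value is overwritten.
old = dict(overrides)
old.update(defaults)
assert old['device_map'] == 'auto'

# New order: res = defaults; res.update(kwargs) -- the caller's value wins.
new = dict(defaults)
new.update(overrides)
assert new['device_map'] == 'cpu'
```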
2 changes: 1 addition & 1 deletion swift/llm/model/patcher.py
@@ -467,7 +467,7 @@ def patch_tp_plan(load_model: bool):
transformers.__version__) < version.parse('4.50') or 'WORLD_SIZE' not in os.environ:
yield
return
logger.info('Patch tp_plan.')
logger.info_once('Patch tp_plan.')
WORLD_SIZE = os.environ.get('WORLD_SIZE')
os.environ['_PATCH_WORLD_SIZE'] = WORLD_SIZE
os.environ.pop('WORLD_SIZE')
1 change: 1 addition & 0 deletions swift/llm/template/base.py
@@ -1685,6 +1685,7 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
assert res['attention_mask'].dtype is torch.bool, f'attention_mask.dtype: {res["attention_mask"].dtype}'
for i, seq_len in enumerate(seq_lens):
res['attention_mask'][i, :, seq_len:] = 0
res['attention_mask'] = ~res['attention_mask']

for key, pad_value in zip(keys, pad_values):
if key not in res:
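The added inversion flips the boolean mask once the padded tail has been zeroed, so `True` ends up marking padded positions instead of valid tokens (presumably what the downstream attention kernel expects). A small standalone sketch of the effect, assuming a right-padded batch of two sequences:

```python
import torch

# Two sequences of lengths 3 and 2, padded to a common length of 4.
seq_lens = [3, 2]
mask = torch.ones(2, 1, 4, dtype=torch.bool)   # start with "every position attendable"
for i, seq_len in enumerate(seq_lens):
    mask[i, :, seq_len:] = 0                   # zero out the padded tail, as in the collator above
mask = ~mask                                   # after inversion, True == padded position
print(mask[0])  # tensor([[False, False, False,  True]])
print(mask[1])  # tensor([[False, False,  True,  True]])
```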
12 changes: 8 additions & 4 deletions swift/megatron/__init__.py
@@ -13,16 +13,20 @@

if TYPE_CHECKING:
from .train import megatron_sft_main, megatron_pt_main, megatron_rlhf_main
from .utils import convert_hf2mcore, convert_mcore2hf, prepare_mcore_model, adapter_state_dict_context
from .argument import MegatronTrainArguments, MegatronRLHFArguments
from .export import megatron_export_main
from .convert import convert_hf2mcore, convert_mcore2hf
from .utils import prepare_mcore_model, adapter_state_dict_context
from .argument import MegatronTrainArguments, MegatronRLHFArguments, MegatronExportArguments
from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model
from .trainers import MegatronTrainer, MegatronDPOTrainer
from .tuners import LoraParallelLinear
else:
_import_structure = {
'train': ['megatron_sft_main', 'megatron_pt_main', 'megatron_rlhf_main'],
'utils': ['convert_hf2mcore', 'convert_mcore2hf', 'prepare_mcore_model', 'adapter_state_dict_context'],
'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments'],
'export': ['megatron_export_main'],
'convert': ['convert_hf2mcore', 'convert_mcore2hf'],
'utils': ['prepare_mcore_model', 'adapter_state_dict_context'],
'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments', 'MegatronExportArguments'],
'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model'],
'trainers': ['MegatronTrainer', 'MegatronDPOTrainer'],
'tuners': ['LoraParallelLinear'],
1 change: 1 addition & 0 deletions swift/megatron/argument/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .export_args import MegatronExportArguments
from .megatron_args import MegatronArguments
from .rlhf_args import MegatronRLHFArguments
from .train_args import MegatronTrainArguments
59 changes: 59 additions & 0 deletions swift/megatron/argument/export_args.py
@@ -0,0 +1,59 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from dataclasses import dataclass

from swift.llm import HfConfigFactory
from swift.llm.argument.base_args import to_abspath
from swift.utils import get_logger
from .megatron_base_args import MegatronBaseArguments

logger = get_logger()


@dataclass
class MegatronExportArguments(MegatronBaseArguments):
to_hf: bool = False
to_mcore: bool = False
test_convert_precision: bool = False
test_convert_dtype: str = 'float32'
exist_ok: bool = False

def _init_save(self):
if self.save is None:
ckpt_dir = self.ckpt_dir or f'./{self.model_suffix}'
ckpt_dir, ckpt_name = os.path.split(ckpt_dir)
if self.to_mcore:
suffix = 'mcore'
elif self.to_hf:
suffix = 'hf'
self.save = os.path.join(ckpt_dir, f'{ckpt_name}-{suffix}')

self.save = to_abspath(self.save)
if not self.exist_ok and os.path.exists(self.save):
raise FileExistsError(f'args.save: `{self.save}` already exists.')
logger.info(f'args.save: `{self.save}`')

def __post_init__(self):
super().__post_init__()
self._init_save()
self.test_convert_dtype = HfConfigFactory.to_torch_dtype(self.test_convert_dtype)
if self.to_hf or self.to_mcore:
self._init_convert()
if self.model_info.is_moe_model is not None and self.tensor_model_parallel_size > 1:
self.sequence_parallel = True
logger.info('Setting args.sequence_parallel: True')

def _init_convert(self):
convert_kwargs = {
'no_save_optim': True,
'no_save_rng': True,
'no_load_optim': True,
'no_load_rng': True,
'finetune': True,
'attention_backend': 'unfused',
'padding_free': False,
}
for k, v in convert_kwargs.items():
setattr(self, k, v)
if self.model_info.is_moe_model:
self.moe_grouped_gemm = True
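Based only on the field names above, an export invocation might look like the following; this is a speculative sketch (the model id, flag values, and program name are placeholders), driving the new entry point through `sys.argv` the same way the CLI script does.

```python
# Hypothetical usage sketch for the new export pipeline; flag names mirror the
# MegatronExportArguments fields above, everything else is a placeholder.
import sys

from swift.megatron import megatron_export_main

sys.argv = [
    'megatron-export',                        # placeholder program name
    '--model', 'Qwen/Qwen2.5-7B-Instruct',    # placeholder model id
    '--to_mcore', 'true',                     # HF weights -> mcore checkpoint
    '--test_convert_precision', 'true',       # optional round-trip precision check
]
megatron_export_main()
```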
7 changes: 4 additions & 3 deletions swift/megatron/argument/megatron_args.py
@@ -94,6 +94,8 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
torch_dtype: Optional[torch.dtype] = None
padding_free: bool = True
mlp_padding_free: bool = False
load_hf_checkpoint: bool = False
save_hf_checkpoint: bool = False
# streaming dataloader
dataloader_persistent_workers: bool = True
dataloader_prefetch_factor: int = 10
@@ -123,6 +125,8 @@ class ExtraMegatronArguments(RLHFMegatronArgumentsMixin, MegatronTunerMixin):
layer_types: Optional[List[str]] = None
# qwen3_vl, qwen3_omni
mrope_interleaved: Optional[bool] = None
# hf saver
max_shard_size: str = '5GB'

@staticmethod
def load_args_config(ckpt_dir: Optional[str]) -> Dict[str, Any]:
@@ -452,12 +456,9 @@ def __post_init__(self):
self.seq_length = self.max_position_embeddings
if self.position_embedding_type is None:
self.position_embedding_type = 'rope'
if self.tensorboard_dir is None and self.save is not None:
self.tensorboard_dir = f'{self.save}/runs'
self._init_moe()
self._init_mixed_precision()

self.tensorboard_dir = to_abspath(self.tensorboard_dir)
self.megatron_extra_kwargs = json_parse_to_dict(self.megatron_extra_kwargs)
self._init_no_rope_fusion()

45 changes: 45 additions & 0 deletions swift/megatron/argument/megatron_base_args.py
@@ -0,0 +1,45 @@
import math
import os
from dataclasses import dataclass

from swift.llm import BaseArguments
from swift.utils import get_logger
from ..model import get_megatron_model_meta
from ..utils import convert_hf_config
from .megatron_args import MegatronArguments

logger = get_logger()


@dataclass
class MegatronBaseArguments(MegatronArguments, BaseArguments):

def __post_init__(self):
self.sequence_parallel_size = self.context_parallel_size
if self.packing:
self.padding_free = True
BaseArguments.__post_init__(self)
self.megatron_model_meta = get_megatron_model_meta(self.model_type)
self.seq_length = self.seq_length or self.packing_length or self.max_length
if self.streaming:
self.dataloader_type = 'external'
if self.num_workers > 1:
self.num_workers = 1
logger.info('Using streaming dataset, setting args.num_workers to 1.')

def init_model_args(self, tokenizer, config):
if self.task_type == 'seq_cls':
self.problem_type = self.problem_type or getattr(config, 'problem_type', None)
logger.info(f'args.problem_type: {self.problem_type}')
kwargs = convert_hf_config(config)
if self.new_special_tokens and kwargs['padded_vocab_size'] < len(tokenizer):
kwargs['padded_vocab_size'] = math.ceil(len(tokenizer) / 128) * 128
self.initialize_embedding = True
logger.info(f'megatron_config: {kwargs}')
for k, v in kwargs.items():
if getattr(self, k) is None:
setattr(self, k, v)
MegatronArguments.__post_init__(self)
self.extra_args = self.parse_to_megatron()
self.extra_args['model_info'] = self.model_info
self.extra_args['model_meta'] = self.model_meta
self.extra_args['megatron_model_meta'] = self.megatron_model_meta
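The vocab padding above rounds the tokenizer size up to the next multiple of 128 when new special tokens push it past the configured `padded_vocab_size`, which keeps the embedding table evenly divisible across tensor-parallel ranks (the usual reason Megatron pads vocabularies). A quick standalone check of the arithmetic:

```python
import math


def pad_vocab(vocab_size: int, multiple: int = 128) -> int:
    # Round up to the next multiple, mirroring math.ceil(len(tokenizer) / 128) * 128 above.
    return math.ceil(vocab_size / multiple) * multiple


assert pad_vocab(151_665) == 151_680   # e.g. a tokenizer grown by a few new special tokens
assert pad_vocab(128_000) == 128_000   # already aligned, left unchanged
```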
49 changes: 14 additions & 35 deletions swift/megatron/argument/train_args.py
@@ -7,36 +7,17 @@

from swift.llm import BaseArguments
from swift.llm.argument.base_args import to_abspath
from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master
from ..model import get_megatron_model_meta
from .megatron_args import MegatronArguments
from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_last_rank
from .megatron_base_args import MegatronBaseArguments

logger = get_logger()


@dataclass
class MegatronTrainArguments(MegatronArguments, BaseArguments):
class MegatronTrainArguments(MegatronBaseArguments):
add_version: bool = True
load_args: bool = False

def init_model_args(self, tokenizer, config):
if self.task_type == 'seq_cls':
self.problem_type = self.problem_type or getattr(config, 'problem_type', None)
logger.info(f'args.problem_type: {self.problem_type}')
kwargs = self.megatron_model_meta.convert_hf_config(config)
if self.new_special_tokens and kwargs['padded_vocab_size'] < len(tokenizer):
kwargs['padded_vocab_size'] = math.ceil(len(tokenizer) / 128) * 128
self.initialize_embedding = True
logger.info(f'megatron_config: {kwargs}')
for k, v in kwargs.items():
if getattr(self, k) is None:
setattr(self, k, v)
MegatronArguments.__post_init__(self)
self.extra_args = self.parse_to_megatron()
self.extra_args['model_info'] = self.model_info
self.extra_args['model_meta'] = self.model_meta
self.extra_args['megatron_model_meta'] = self.megatron_model_meta

def _init_save(self):
init_process_group(backend=self.ddp_backend, timeout=self.ddp_timeout)
if self.save is None:
@@ -45,7 +26,7 @@ def _init_save(self):
if self.add_version:
self.save = add_version_to_work_dir(self.save)
logger.info(f'args.save: {self.save}')
if is_master():
if is_last_rank():
os.makedirs(self.save, exist_ok=True)

def _init_ckpt_dir(self, adapters=None):
@@ -59,24 +40,22 @@ def _init_ckpt_dir(self, adapters=None):
self.model = old_args.get('model')

def __post_init__(self):
self.sequence_parallel_size = self.context_parallel_size
if self.packing:
self.padding_free = True
self.load = to_abspath(self.load, check_path_exist=True)
BaseArguments.__post_init__(self)
self.megatron_model_meta = get_megatron_model_meta(self.model_type)
super().__post_init__()
if len(self.dataset) == 0 and len(self.cached_dataset) == 0:
raise ValueError(f'self.dataset: {self.dataset}, self.cached_dataset: {self.cached_dataset}. '
'Please input the training dataset.')
self._init_save()
self.seq_length = self.seq_length or self.packing_length or self.max_length
if self.streaming:
self.dataloader_type = 'external'
if self.num_workers > 1:
self.num_workers = 1
logger.info('Using streaming dataset, setting args.num_workers to 1.')
if self.load is None and self.no_initialization:
if self.tensorboard_dir is None and self.save is not None:
self.tensorboard_dir = f'{self.save}/runs'
self.tensorboard_dir = to_abspath(self.tensorboard_dir)
if self.load is None and self.no_initialization and not self.load_hf_checkpoint:
raise ValueError('You did not pass `--load`, so you need to set `--no_initialization false` '
'to allow the model to initialize weights properly.')
if self.cached_dataset and self.context_parallel_size > 1:
raise ValueError('`cached_dataset` does not support context parallelism.')

def get_model_kwargs(self):
res = super().get_model_kwargs()
res['download_model'] = self.load_hf_checkpoint
return res