
Commit d90c1da

enable fsdp training and support huggingface models with ckpt in or out

1 parent 24180aa

58 files changed: +1170 -369 lines

configs/1.8B_MoE16_sft.py

+1 -2
@@ -170,7 +170,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
@@ -197,7 +196,7 @@
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=-1, fsdp=False),
+    zero1=dict(size=-1),
     tensor=dict(size=1, mode="mtp"),
     pipeline=dict(size=1, interleaved_overlap=True),
     weight=dict(size=1, overlap=True),
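Note: across these config files the only change is that the fsdp flag is dropped from zero1; the three zero1 regimes described in the unchanged docstring above still apply. A minimal sketch of what each regime looks like in this config syntax (the values are illustrative, not recommendations):

    # zero1 <= 0: the zero1 group spans the whole data-parallel group,
    # so optimizer states are sharded across all dp ranks
    parallel = dict(zero1=dict(size=-1), tensor=dict(size=1, mode="mtp"))

    # zero1 == 1: ZeRO-1 is disabled; every dp rank keeps the full optimizer states
    parallel = dict(zero1=dict(size=1), tensor=dict(size=1, mode="mtp"))

    # 1 < zero1 <= dp world size: shard within a subset of dp ranks,
    # e.g. size=8 to keep sharding within a node for smaller models
    parallel = dict(zero1=dict(size=8), tensor=dict(size=1, mode="mtp"))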

configs/57B_qwen2_MoE.py

+1 -2
@@ -175,7 +175,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
@@ -202,7 +201,7 @@
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=-1, fsdp=False),
+    zero1=dict(size=-1),
     tensor=dict(size=1, mode="mtp"),
     pipeline=dict(size=1, interleaved_overlap=True),
     weight=dict(size=1, overlap=True),

configs/7B_MoE4_sft.py

+1 -2
@@ -182,7 +182,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
@@ -217,7 +216,7 @@
     4. forward_overlap_per: str, all gather prefetch granularity, per 'module' or per 'layer', defaults to 'layer'.
 """
 parallel = dict(
-    zero1=dict(size=-1, fsdp=False),
+    zero1=dict(size=-1),
     tensor=dict(size=1, mode="mtp"),
     pipeline=dict(size=1, interleaved_overlap=True),
     weight=dict(size=1, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),

configs/7B_baichuan2.py

-1
@@ -165,7 +165,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],

configs/7B_gemma.py

-1
@@ -172,7 +172,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],

configs/7B_internlm2.py

-1
@@ -174,7 +174,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],

configs/7B_isp_sft.py

-1
@@ -187,7 +187,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],

configs/7B_llama2.py

-1
@@ -164,7 +164,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],

configs/7B_qwen2.py

-1
@@ -172,7 +172,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],

configs/7B_sft.py

-1
@@ -174,7 +174,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],

configs/8x22B_mixtral.py

+1 -2
@@ -176,7 +176,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
@@ -203,7 +202,7 @@
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=-1, fsdp=False),
+    zero1=dict(size=-1),
     tensor=dict(size=1, mode="mtp"),
     pipeline=dict(size=1, interleaved_overlap=True),
     weight=dict(size=1, overlap=True),

configs/8x7B_mixtral.py

+1 -2
@@ -176,7 +176,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
@@ -203,7 +202,7 @@
     2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
 """
 parallel = dict(
-    zero1=dict(size=-1, fsdp=False),
+    zero1=dict(size=-1),
     tensor=dict(size=1, mode="mtp"),
     pipeline=dict(size=1, interleaved_overlap=True),
     weight=dict(size=1, overlap=True),

configs/_base_/models/internlm2_1B.py

-1
@@ -51,7 +51,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],

configs/_base_/models/internlm2_20B.py

-1
@@ -48,7 +48,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],

configs/_base_/models/internlm2_7B.py

-1
@@ -48,7 +48,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],

configs/_base_/models/internlm_20B.py

-1
@@ -43,7 +43,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],

configs/_base_/models/internlm_7B.py

-1
@@ -43,7 +43,6 @@
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],

doc/code-docs/source/initialize.rst

+1 -1
@@ -43,7 +43,7 @@ InternEvo uses `argparse <https://docs.python.org/3/library/argparse.html>`_
 Model initialization
 -------------------------
 
-.. autofunction:: internlm.train.initialize_model
+.. autofunction:: internlm.train.initialize_model_and_parallel_communicator
 
 InternEvo uses the ``model_type`` and ``model`` fields in the configuration file to control model initialization. An example model initialization configuration is defined as follows:

doc/code-docs/source/training.rst

+1 -1
@@ -27,7 +27,7 @@
 - Initialize the model
 .. code-block:: python
 
-    model = initialize_model()
+    model = initialize_model_and_parallel_communicator()
 
 For a detailed introduction, see: `Model initialization <https://internevo.readthedocs.io/zh-cn/latest/initialize.html#internlm-model-init>`_

doc/en/train_performance.md

+1 -1
@@ -121,7 +121,7 @@ model = dict(
 )
 
 parallel = dict(
-    zero1=dict(size=8, fsdp=False),
+    zero1=dict(size=8),
     tensor=1,
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=False,

doc/train_performance.md

+1 -1
@@ -117,7 +117,7 @@ model = dict(
 )
 
 parallel = dict(
-    zero1=dict(size=8, fsdp=False),
+    zero1=dict(size=8),
     tensor=1,
     pipeline=dict(size=1, interleaved_overlap=True),
     sequence_parallel=False,

doc/usage.md

-2
@@ -268,7 +268,6 @@ zero1 parallel (dict):
         * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters.
         * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size.
         For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 tensor parallel (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'],
@@ -432,7 +431,6 @@ parallel = dict(
     - `zero1 <= 0`: the zero1 process group is the same size as the data-parallel process group, so optimizer state parameters are sharded across the data-parallel ranks
     - `zero1 == 1`: zero1 is not used, and every data-parallel group keeps the full optimizer state parameters
     - `zero1 > 1` and `zero1 <= data_parallel_world_size`: the zero1 process group is a subset of the data-parallel process group
-    2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False.
 - tensor (dict):
     1. size: int, the size of tensor parallel.
     2. mode: str, the tensor parallel mode, should be one of ['mtp', 'msp', 'fsp', 'isp'],

generate.py

+2 -3
@@ -21,7 +21,7 @@
 from internlm.initialize import initialize_distributed_env
 from internlm.monitor import initialize_monitor_manager
 from internlm.monitor.monitor import monitor_manager as mm
-from internlm.train import initialize_model, initialize_parallel_communicator
+from internlm.train import initialize_model_and_parallel_communicator
 from internlm.utils.common import (
     enable_pytorch_expandable_segments,
     launch_time,
@@ -106,8 +106,7 @@ def main():
         raise e
 
     # initialize model
-    model = initialize_model()
-    _ = initialize_parallel_communicator(model)
+    model, _ = initialize_model_and_parallel_communicator()
     model = model.model
 
     state_dict = merge_pp_within_tp(generation_config.ckpt_folder, del_model_prefix=True)
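The generate.py hunks above show the API consolidation that accompanies this commit: the separate initialize_model and initialize_parallel_communicator calls are replaced by a single initialize_model_and_parallel_communicator call returning both objects. A minimal usage sketch based only on what this diff shows (the distributed setup around it is assumed to have run already):

    from internlm.train import initialize_model_and_parallel_communicator

    # build the model and its parallel communicator in one call;
    # the second return value (the communicator) is unused in generate.py
    model, _ = initialize_model_and_parallel_communicator()
    model = model.model  # unwrap to the underlying module, as generate.py does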

internlm/checkpoint/checkpoint_manager.py

+4 -1
@@ -23,6 +23,7 @@
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.parallel import is_using_fsdp, is_using_hf
 from internlm.utils.storage_manager import (
     get_storage_manager,
     init_storage_manager,
@@ -271,7 +272,7 @@ def __init__(
         self.storage_manager = get_storage_manager()
         self.snapshot_counter = -1
 
-        if hasattr(model, "model"):
+        if hasattr(model, "model") and not is_using_fsdp():
            model = model.model
 
         self.model = model
@@ -575,6 +576,8 @@ def try_resume_training(self, train_state: TrainState, current_time=""):
                 f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
                 f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
             )
+        elif is_using_fsdp() and is_using_hf() and not self.auto_resume:
+            pass
         else:
             load_path = self.load_ckpt_info["path"]
             load_content = self.load_ckpt_info["content"]
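Taken together, the checkpoint_manager.py hunks wire the new FSDP/HuggingFace support into checkpointing: the manager no longer unwraps model.model when FSDP is active, and explicit checkpoint loading is skipped for FSDP-wrapped HuggingFace models when auto-resume is off, so training starts from the HF weights. A small sketch of that resume decision, using only the helpers imported above (their implementations are not part of this diff) and a hypothetical helper name for illustration:

    from internlm.utils.parallel import is_using_fsdp, is_using_hf

    def skip_explicit_ckpt_load(auto_resume: bool) -> bool:
        # Hypothetical helper mirroring the new elif branch: for an FSDP-wrapped
        # HuggingFace model without auto-resume, no checkpoint path is loaded.
        return is_using_fsdp() and is_using_hf() and not auto_resume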
