
Commit 97cb5d1

fix ci
1 parent 4b41bb5 commit 97cb5d1

6 files changed: +9 -7 lines changed


internlm/core/context/parallel_context.py

-1
@@ -159,7 +159,6 @@ def __init__(self):
         self.virtual_pipeline_parallel_rank = None
         self._expert_parallel_group_names = []
         self.is_evaluating = False
-        self.recompute_forward_no_comm = False

     @property
     def config(self):

internlm/initialize/launch.py

+2
@@ -294,6 +294,8 @@ def args_sanity_check():
             "torch.tf32",
         ]

+    gpc.config._add_item("recompute_forward_no_comm", False)
+
     if "checkpoint" in model:
         if "checkpoint_tp_no_comm" not in model:
             gpc.config.model._add_item("checkpoint_tp_no_comm", True)
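
The flag is now registered on gpc.config with a default of False, so later code can read it even when a user config never mentions it. A minimal, self-contained sketch of that register-then-guard pattern (the Config class below is a stand-in for illustration, not InternLM's actual gpc.config):

class Config:
    """Toy stand-in for a config object that supports late registration of items."""

    def _add_item(self, key, value):
        # Register a config entry with a default value if it is not already set.
        if not hasattr(self, key):
            setattr(self, key, value)


config = Config()
config._add_item("recompute_forward_no_comm", False)

# Readers guard the access, mirroring the hasattr checks added in this commit,
# so code paths that run before registration do not crash.
if hasattr(config, "recompute_forward_no_comm") and config.recompute_forward_no_comm:
    print("skip tensor-parallel communication during recompute")
else:
    print("communicate as usual")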

internlm/model/modeling_internlm.py

+1 -1

@@ -216,7 +216,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):
         hidden_states = self.mlp(hidden_states)

         # pad residual
-        if gpc.recompute_forward_no_comm and is_using_sequence_parallel():
+        if gpc.config.recompute_forward_no_comm and is_using_sequence_parallel():
             residual = padding_residual(residual)

         return hidden_states + residual

internlm/model/modeling_internlm2.py

+1 -1

@@ -261,7 +261,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states):
         hidden_states = self.feed_forward(hidden_states)

         # pad residual
-        if gpc.recompute_forward_no_comm and is_using_sequence_parallel():
+        if gpc.config.recompute_forward_no_comm and is_using_sequence_parallel():
             residual = padding_residual(residual)

         return hidden_states + residual
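
Both modeling files make the same one-line change: the guard now reads the flag from gpc.config instead of the removed gpc attribute, and the padding path itself is untouched. For orientation only, a rough sketch of why a residual might need padding on this path (this assumes padding_residual zero-pads a sequence-parallel-sharded residual up to the length of the un-reduced MLP output; the real helper in InternLM may behave differently):

import torch


def padding_residual_sketch(residual, target_seq_len):
    # Hypothetical stand-in: zero-pad the sequence dimension of a sharded residual.
    pad_len = target_seq_len - residual.shape[1]
    if pad_len <= 0:
        return residual
    pad = residual.new_zeros(residual.shape[0], pad_len, residual.shape[2])
    return torch.cat([residual, pad], dim=1)


hidden_states = torch.randn(2, 8, 16)  # full-length output when communication is skipped
residual = torch.randn(2, 4, 16)       # residual still holds only the local sequence shard
residual = padding_residual_sketch(residual, hidden_states.shape[1])
out = hidden_states + residual         # shapes now match for the residual add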

internlm/model/modules/mlp.py

+1 -1

@@ -99,7 +99,7 @@ def forward(self, x):
         else:
             fussed_out = self.fused_w1_w3(x)
         w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1)
-        out = self.w2(Silu(w1_o, w3_o), no_communication=gpc.recompute_forward_no_comm)
+        out = self.w2(Silu(w1_o, w3_o), no_communication=gpc.config.recompute_forward_no_comm)
         return out
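
Here the no_communication keyword passed to self.w2 simply picks up the flag from its new home on gpc.config. A minimal sketch of a layer that accepts such a flag (ToyRowParallelLinear is hypothetical and the communication step is stubbed out; it is not InternLM's actual parallel linear):

import torch
import torch.nn as nn


class ToyRowParallelLinear(nn.Module):
    """Hypothetical layer: local matmul, then an optional 'communication' step."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x, no_communication=False):
        out = self.linear(x)  # partial result from the local weight shard
        if not no_communication:
            # Placeholder for the all-reduce / reduce-scatter a real
            # tensor-parallel layer would perform across ranks.
            out = out.clone()
        return out


w2 = ToyRowParallelLinear(16, 16)
x = torch.randn(2, 4, 16)
full = w2(x)                            # normal path: result is "communicated"
partial = w2(x, no_communication=True)  # recompute path: keep the local partial result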

internlm/solver/activation_checkpoint.py

+4 -3

@@ -47,8 +47,8 @@ def recompute_forward_context(args, no_communication):
     handle = None
     try:
         # Set True when entering the context
-        if no_communication:
-            gpc.recompute_forward_no_comm = True
+        if no_communication and hasattr(gpc.config, "recompute_forward_no_comm"):
+            gpc.config.recompute_forward_no_comm = True
         if is_using_sequence_parallel():
             # overlap all_gather
             grad_output = args[0]
@@ -58,7 +58,8 @@ def recompute_forward_context(args, no_communication):
         yield
     finally:
         # Set False when exiting the context
-        gpc.recompute_forward_no_comm = False
+        if hasattr(gpc.config, "recompute_forward_no_comm"):
+            gpc.config.recompute_forward_no_comm = False

     if handle:
         handle.wait()
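
recompute_forward_context sets the flag only for the duration of the recompute and always clears it in the finally block; the new hasattr guards mean it never touches a config that did not register the item. A compact sketch of that enter/exit discipline (ToyConfig is a stand-in for the real gpc.config):

from contextlib import contextmanager


class ToyConfig:
    recompute_forward_no_comm = False  # registered default, as launch.py now does


@contextmanager
def recompute_forward_context_sketch(config, no_communication):
    try:
        # Set the flag only when requested and only if the item exists.
        if no_communication and hasattr(config, "recompute_forward_no_comm"):
            config.recompute_forward_no_comm = True
        yield
    finally:
        # Always clear it on exit so later forward passes communicate normally.
        if hasattr(config, "recompute_forward_no_comm"):
            config.recompute_forward_no_comm = False


cfg = ToyConfig()
with recompute_forward_context_sketch(cfg, no_communication=True):
    assert cfg.recompute_forward_no_comm is True
assert cfg.recompute_forward_no_comm is False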
