Commit 746eb9d

Merge branch 'develop' into feat/refactor-impl
2 parents eeca1e8 + caf30d8 commit 746eb9d

5 files changed: +8 -10 lines

configs/7B_isp_sft.py (+1)

@@ -243,6 +243,7 @@
 cudnn_deterministic = False
 cudnn_benchmark = False
 
+
 monitor = dict(
     # feishu alert configs
     alert=dict(

internlm/model/model_implementations/transformers/modeling_internlm.py (-1)

@@ -18,7 +18,6 @@
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import set_output_attr_to_module
-from internlm.core.parallel.shard import partition_uniform
 from internlm.model.model_implementations.transformers.base_model import (
     BaseTransformerModel,
 )

internlm/model/model_ops/moe/moe.py (-1)

@@ -1,5 +1,4 @@
 import torch
-import torch.nn.functional as F
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc

internlm/model/model_ops/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py (+7 -7)

@@ -5,11 +5,10 @@
 from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
 
 from internlm.core.context import global_context as gpc
+from internlm.core.parallel.comm import get_offload_manager
 
 from .utils import RingComm, update_out_and_lse
 
-fa_output_mapping = {}
-
 
 def create_buffer(tensor):
     buffer_shape = list(tensor.shape)

@@ -443,9 +442,10 @@ def forward(
         k = k.contiguous()
         v = v.contiguous()
 
-        if gpc.is_forward is False and gpc.config.selective_checkpoint:
-            assert layer_idx in fa_output_mapping
-            out, softmax_lse = fa_output_mapping.pop(layer_idx)
+        _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers)
+
+        if gpc.is_forward is False and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num:
+            out, softmax_lse = get_offload_manager().get_fa_output_with_layer(layer_idx)
         else:
             out, softmax_lse = zigzag_double_ring_flash_attn_forward(
                 context_group,

@@ -460,8 +460,8 @@
             )
 
         # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled
-        if gpc.is_forward and gpc.config.selective_checkpoint:
-            fa_output_mapping[layer_idx] = (out, softmax_lse)
+        if gpc.is_forward and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num:
+            get_offload_manager().insert_fa_output_with_layer(layer_idx=layer_idx, output=(out, softmax_lse))
 
         # this should be out_padded
         ctx.save_for_backward(q, k, v, out, softmax_lse)
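
The change above replaces the module-level fa_output_mapping dict with a shared offload manager: the forward pass stores (out, softmax_lse) only for the first _ckpt_block_num layers (gpc.config.model.checkpoint read here as a fraction of gpc.config.isp_num_layers), and the recomputation pass during backward fetches the cached output instead of re-running attention. Below is a minimal sketch of the caching interface the diff relies on; the class name and the in-memory dict are illustrative assumptions, not the actual implementation behind internlm.core.parallel.comm.get_offload_manager.

class _FAOutputCacheSketch:
    # Hypothetical stand-in for the offload manager used in the diff.
    def __init__(self):
        self._fa_outputs = {}  # layer_idx -> (out, softmax_lse)

    def insert_fa_output_with_layer(self, layer_idx, output):
        # Forward pass of a checkpointed layer: keep the attention output.
        self._fa_outputs[layer_idx] = output

    def get_fa_output_with_layer(self, layer_idx):
        # Recomputation pass: reuse (and release) the cached output so flash
        # attention is not recomputed for this layer.
        return self._fa_outputs.pop(layer_idx)


_SKETCH_MANAGER = _FAOutputCacheSketch()


def get_offload_manager():
    # Assumed accessor shape only; the real get_offload_manager lives in
    # internlm.core.parallel.comm and may also handle CPU offload of activations.
    return _SKETCH_MANAGER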

internlm/model/model_ops/utils.py (-1)

@@ -5,7 +5,6 @@
 from tqdm import tqdm
 
 from internlm.core.context import global_context as gpc
-from internlm.model.model_ops.modules.mha import MHA
 from internlm.utils.logger import get_logger
 from internlm.utils.storage_manager import get_fns, llm_load
 from internlm.utils.utils import TensorParallelMode
