Commit caf30d8

feat(2D): support selective-checkpoint-offload for 2D attention (#396)
1 parent 20cbcfb commit caf30d8
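
For context: activation checkpointing replays a checkpointed layer's forward pass during backward, and the diff below uses gpc.is_forward to tell the real forward from that replay. A tiny, generic demonstration of the replay with torch.utils.checkpoint — illustrative only, not InternEvo code:

import torch
from torch.utils.checkpoint import checkpoint

calls = {"n": 0}

def layer(x):
    calls["n"] += 1  # incremented on the real forward and again on the replay
    return torch.nn.functional.silu(x)

x = torch.randn(4, requires_grad=True)
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
print(calls["n"])  # 2: forward ran once, then was replayed during backward

Caching the flash-attention output and serving it back during the replay, as this commit does for the checkpointed layers, skips exactly that second attention computation.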

2 files changed: +8, -7 lines

2 files changed

+8
-7
lines changed

configs/7B_isp_sft.py (+1)

@@ -245,6 +245,7 @@
 cudnn_deterministic = False
 cudnn_benchmark = False

+
 monitor = dict(
     # feishu alert configs
     alert=dict(
internlm/model/ops/ring_flash_attn/zigzag_ring_flash_attn_with_sliding_window.py (+7, -7)
@@ -5,11 +5,10 @@
 from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward

 from internlm.core.context.parallel_context import global_context as gpc
+from internlm.core.parallel.comm import get_offload_manager

 from .utils import RingComm, update_out_and_lse

-fa_output_mapping = {}
-

 def create_buffer(tensor):
     buffer_shape = list(tensor.shape)
@@ -443,9 +442,10 @@ def forward(
         k = k.contiguous()
         v = v.contiguous()

-        if gpc.is_forward is False and gpc.config.selective_checkpoint:
-            assert layer_idx in fa_output_mapping
-            out, softmax_lse = fa_output_mapping.pop(layer_idx)
+        _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers)
+
+        if gpc.is_forward is False and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num:
+            out, softmax_lse = get_offload_manager().get_fa_output_with_layer(layer_idx)
         else:
             out, softmax_lse = zigzag_double_ring_flash_attn_forward(
                 context_group,
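
The new gate limits the cache to the first _ckpt_block_num layers. Assuming gpc.config.model.checkpoint is the checkpointed fraction of layers and gpc.config.isp_num_layers is the total layer count (the values below are hypothetical, for illustration only):

checkpoint = 0.25      # fraction of layers under activation checkpointing
isp_num_layers = 32    # total transformer layers
_ckpt_block_num = int(checkpoint * isp_num_layers)  # = 8
# layers 0..7 store/fetch their flash-attn outputs; layers 8..31 recompute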
@@ -460,8 +460,8 @@
             )

         # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled
-        if gpc.is_forward and gpc.config.selective_checkpoint:
-            fa_output_mapping[layer_idx] = (out, softmax_lse)
+        if gpc.is_forward and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num:
+            get_offload_manager().insert_fa_output_with_layer(layer_idx=layer_idx, output=(out, softmax_lse))

         # this should be out_padded
         ctx.save_for_backward(q, k, v, out, softmax_lse)
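
get_offload_manager() and its two methods live in internlm.core.parallel.comm and are not shown in this diff. A minimal in-memory stand-in consistent with the call sites above might look as follows; the real manager presumably offloads the cached activations rather than keeping everything on-device, which this sketch does not attempt:

from typing import Dict, Tuple

import torch

class FakeOffloadManager:
    # In-memory stand-in matching only the two call sites in this diff.
    def __init__(self) -> None:
        self._fa_outputs: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}

    def insert_fa_output_with_layer(self, layer_idx: int, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
        # real forward pass: remember (out, softmax_lse) for this layer
        self._fa_outputs[layer_idx] = output

    def get_fa_output_with_layer(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # checkpoint replay: hand the stored output back; popping mirrors
        # the old fa_output_mapping.pop(layer_idx) behavior
        return self._fa_outputs.pop(layer_idx)

Compared with the removed module-level fa_output_mapping dict, routing the cache through a manager keeps the lifetime of these tensors in one place and opens the door to moving them off the GPU between forward and backward.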
