 from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
 
 from internlm.core.context import global_context as gpc
+from internlm.core.parallel.comm import get_offload_manager
 
 from .utils import RingComm, update_out_and_lse
 
-fa_output_mapping = {}
-
 
 def create_buffer(tensor):
     buffer_shape = list(tensor.shape)
@@ -443,9 +442,10 @@ def forward(
         k = k.contiguous()
         v = v.contiguous()
 
-        if gpc.is_forward is False and gpc.config.selective_checkpoint:
-            assert layer_idx in fa_output_mapping
-            out, softmax_lse = fa_output_mapping.pop(layer_idx)
+        _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers)
+
+        if gpc.is_forward is False and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num:
+            out, softmax_lse = get_offload_manager().get_fa_output_with_layer(layer_idx)
         else:
             out, softmax_lse = zigzag_double_ring_flash_attn_forward(
                 context_group,
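
As a worked example of the new gate (illustrative config values, not taken from this PR): with `gpc.config.model.checkpoint = 0.5` and `gpc.config.isp_num_layers = 32`, `_ckpt_block_num = int(0.5 * 32) = 16`, so only layers 0-15 read their cached `(out, softmax_lse)` back from the offload manager during recomputation, while layers 16-31 (and all layers when `selective_checkpoint` is off) fall through to `zigzag_double_ring_flash_attn_forward` as before.
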
@@ -460,8 +460,8 @@ def forward(
             )
 
         # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled
-        if gpc.is_forward and gpc.config.selective_checkpoint:
-            fa_output_mapping[layer_idx] = (out, softmax_lse)
+        if gpc.is_forward and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num:
+            get_offload_manager().insert_fa_output_with_layer(layer_idx=layer_idx, output=(out, softmax_lse))
 
         # this should be out_padded
         ctx.save_for_backward(q, k, v, out, softmax_lse)
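
Taken together, the change replaces the module-level `fa_output_mapping` dict with a shared manager from `internlm.core.parallel.comm`, keyed by `layer_idx` and used only for the first `_ckpt_block_num` layers. The snippet below is a minimal sketch of what such a per-layer cache could look like, assuming only the two method names that appear in the diff (`insert_fa_output_with_layer`, `get_fa_output_with_layer`); it is not the actual `get_offload_manager` implementation, which may additionally offload tensors to CPU or overlap transfers with compute.

```python
# Hypothetical sketch of a per-layer flash-attention output cache exposing the two
# entry points the diff relies on. The real manager returned by get_offload_manager()
# in internlm.core.parallel.comm is assumed to do more (e.g. CPU offload, prefetch);
# this only illustrates the store-then-consume flow.
from typing import Dict, Tuple

import torch


class _FAOutputCache:
    def __init__(self) -> None:
        # layer_idx -> (out, softmax_lse) saved during the checkpointed forward pass
        self._outputs: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}

    def insert_fa_output_with_layer(
        self, layer_idx: int, output: Tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        # called on the normal forward pass (gpc.is_forward is True)
        self._outputs[layer_idx] = output

    def get_fa_output_with_layer(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # called during recomputation (gpc.is_forward is False); pop so the cached
        # tensors are released as soon as the checkpointed block has consumed them
        return self._outputs.pop(layer_idx)


_CACHE = _FAOutputCache()


def get_offload_manager() -> _FAOutputCache:
    return _CACHE
```

Compared with the old module-level dict, routing everything through a single manager keeps the selective-checkpoint state in one place, and the `layer_idx < _ckpt_block_num` guard on both the insert and the lookup removes the `assert layer_idx in fa_output_mapping` failure mode: layers beyond the checkpointed range never touch the cache in either direction.
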