
Commit 150153f

use HF standard naming for qkv and mlp layers
1 parent 9a4d45f commit 150153f

File tree

5 files changed: +108 -45 lines changed

internlm/model/model_implementations/transformers/modeling_internlm2.py

+26 -11
@@ -809,19 +809,34 @@ def unique_kv_index(i):
         )
         current_states[name] = torch.index_select(current_states[name], 0, unique_index)
 
+    mlp_layer_fusion = any(".w13." in key for key in current_states.keys())
     fixed_current_states = {}
-    for name in current_states.keys():
-        if "wqkv" in name:
-            new_name = name.replace("wqkv", "qkv_proj")
-        elif ".w1." in name:
-            new_name = name.replace("feed_forward.w1", "feed_forward.gate_proj")
-        elif ".w2." in name:
-            new_name = name.replace("feed_forward.w2", "feed_forward.down_proj")
-        elif ".w3." in name:
-            new_name = name.replace("feed_forward.w3", "feed_forward.up_proj")
+    for key in current_states.keys():
+        if "wqkv" in key:
+            new_key = key.replace("wqkv", "qkv_proj")
+        # elif "wq" in key:
+        #     new_key = key.replace("wq", "q_proj")
+        # elif "wk" in key:
+        #     new_key = key.replace("wk", "k_proj")
+        # elif "wv" in key:
+        #     new_key = key.replace("wv", "v_proj")
+        # elif "wo" in key:
+        #     new_key = key.replace("wo", "o_proj")
+        elif ".w1." in key:
+            new_key = key.replace("feed_forward.w1", "feed_forward.gate_proj")
+        elif ".w2." in key:
+            new_key = (
+                key.replace("feed_forward.w2", "feed_forward.dense_4h_to_h")
+                if mlp_layer_fusion
+                else key.replace("feed_forward.w2", "feed_forward.down_proj")
+            )
+        elif ".w3." in key:
+            new_key = key.replace("feed_forward.w3", "feed_forward.up_proj")
+        elif ".w13." in key:
+            new_key = key.replace("feed_forward.w13", "feed_forward.dense_h_to_4h")
         else:
-            new_name = name
-        fixed_current_states[new_name] = current_states[name]
+            new_key = key
+        fixed_current_states[new_key] = current_states[key]
 
     missing_keys, unexpected_keys = model.load_state_dict(fixed_current_states, strict=False)

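For reference, the loop above maps the old InternEvo parameter names onto Hugging Face-style names when loading a checkpoint. Below is a minimal, self-contained sketch of the same mapping; it is not part of the commit, and the layer-key strings are illustrative placeholders rather than the exact checkpoint keys.

import torch

# Illustrative old-style keys (placeholder names, not the exact InternEvo keys).
old_states = {
    "model.layers.0.attention.wqkv.weight": torch.randn(6, 4),
    "model.layers.0.feed_forward.w1.weight": torch.randn(8, 4),
    "model.layers.0.feed_forward.w2.weight": torch.randn(4, 8),
    "model.layers.0.feed_forward.w3.weight": torch.randn(8, 4),
}

# Same mapping as in the diff: wqkv -> qkv_proj, w1/w2/w3 -> gate/down/up_proj
# (w13 -> dense_h_to_4h and w2 -> dense_4h_to_h when the MLP layers are fused).
mlp_layer_fusion = any(".w13." in k for k in old_states)
renamed = {}
for key, value in old_states.items():
    if "wqkv" in key:
        new_key = key.replace("wqkv", "qkv_proj")
    elif ".w1." in key:
        new_key = key.replace("feed_forward.w1", "feed_forward.gate_proj")
    elif ".w2." in key:
        new_key = (
            key.replace("feed_forward.w2", "feed_forward.dense_4h_to_h")
            if mlp_layer_fusion
            else key.replace("feed_forward.w2", "feed_forward.down_proj")
        )
    elif ".w3." in key:
        new_key = key.replace("feed_forward.w3", "feed_forward.up_proj")
    elif ".w13." in key:
        new_key = key.replace("feed_forward.w13", "feed_forward.dense_h_to_4h")
    else:
        new_key = key
    renamed[new_key] = value

assert "model.layers.0.attention.qkv_proj.weight" in renamed
assert "model.layers.0.feed_forward.gate_proj.weight" in renamed

Note that the choice between down_proj and dense_4h_to_h depends only on whether fused .w13. keys are present in the incoming state dict.
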
internlm/model/model_ops/modules/mha.py

+14 -10
@@ -46,27 +46,27 @@ def split_fused_wqkv_weight(wqkv, *args, **kwargs): # pylint: disable=W0613
 
 
 def _qkv_pre_load_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) -> None:  # pylint: disable=W0613
-    wq_name, wk_name, wv_name, fused_name = (
+    wq_name, wk_name, wv_name, wqkv_name = (
         f"{prefix}q_proj.weight",
         f"{prefix}k_proj.weight",
         f"{prefix}v_proj.weight",
         f"{prefix}qkv_proj.weight",
     )
 
-    if module.enable_qkv_fusion and fused_name not in state_dict:
+    if module.enable_qkv_fusion and wqkv_name not in state_dict:
         wq, wk, wv = state_dict.pop(wq_name), state_dict.pop(wk_name), state_dict.pop(wv_name)
-        state_dict[fused_name] = torch.cat([wq, wk, wv], dim=0)
+        state_dict[wqkv_name] = torch.cat([wq, wk, wv], dim=0)
 
     if not module.enable_qkv_fusion and (
         wq_name not in state_dict or wk_name not in state_dict or wv_name not in state_dict
     ):
         state_dict[wq_name], state_dict[wk_name], state_dict[wv_name] = split_fused_wqkv_weight(
-            state_dict.pop(fused_name), *args, **kwargs
+            state_dict.pop(wqkv_name), *args, **kwargs
         )
 
 
 def _qkv_save_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) -> Dict:  # pylint: disable=W0613
-    wq_name, wk_name, wv_name, fused_name = (
+    wq_name, wk_name, wv_name, wqkv_name = (
         f"{prefix}q_proj.weight",
         f"{prefix}k_proj.weight",
         f"{prefix}v_proj.weight",
@@ -75,7 +75,7 @@ def _qkv_save_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) -
 
     if module.enable_qkv_fusion:
         state_dict[wq_name], state_dict[wk_name], state_dict[wv_name] = split_fused_wqkv_weight(
-            state_dict.pop(fused_name), *args, **kwargs
+            state_dict.pop(wqkv_name), *args, **kwargs
         )
 
     return state_dict
@@ -162,6 +162,10 @@ def __init__(
         self.q_proj = new_linear("wq", embed_dim, embed_dim, bias, **factory_kwargs)
         self.k_proj = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs)
         self.v_proj = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs)
+        self.register_checkpoint_compatibility_hooks(
+            partial(_qkv_pre_load_convert, q_dim=self.embed_dim, kv_dim=self.kv_dim),
+            partial(_qkv_save_convert, q_dim=self.embed_dim, kv_dim=self.kv_dim),
+        )
 
         self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
         self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
@@ -462,14 +466,14 @@ def __init__(
         if enable_qkv_fusion:
             assert bias is False, "Fuesd wqkv only support bias is False."
             self.qkv_proj = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs)
-            self._register_load_state_dict_pre_hook(
-                partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True
-            )
-            self._register_state_dict_hook(partial(_qkv_save_convert, q_dim=q_dim, kv_dim=self.kv_dim))
         else:
             self.q_proj = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs)
             self.k_proj = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs)
             self.v_proj = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs)
+        self.register_checkpoint_compatibility_hooks(
+            partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim),
+            partial(_qkv_save_convert, q_dim=self.embed_dim, kv_dim=self.kv_dim),
+        )
 
         self.inner_attn = SelfAttention(
             causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx

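The _qkv_pre_load_convert and _qkv_save_convert hooks translate between separate q_proj/k_proj/v_proj weights and a single fused qkv_proj weight. Below is a rough round-trip sketch with toy dimensions; the body of split_fused_wqkv_weight is not shown in this diff, so the torch.split call is only a simplified stand-in that undoes the torch.cat used by the pre-load hook.

import torch

# Toy dimensions, chosen only for illustration.
embed_dim, q_dim, kv_dim = 16, 16, 8

wq = torch.randn(q_dim, embed_dim)   # q_proj.weight
wk = torch.randn(kv_dim, embed_dim)  # k_proj.weight
wv = torch.randn(kv_dim, embed_dim)  # v_proj.weight

# What _qkv_pre_load_convert does when the module expects a fused qkv_proj:
wqkv = torch.cat([wq, wk, wv], dim=0)  # shape: (q_dim + 2 * kv_dim, embed_dim)

# Simplified stand-in for split_fused_wqkv_weight (real implementation not in this diff):
wq2, wk2, wv2 = torch.split(wqkv, [q_dim, kv_dim, kv_dim], dim=0)

assert torch.equal(wq, wq2) and torch.equal(wk, wk2) and torch.equal(wv, wv2)
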
internlm/model/model_ops/modules/mlp.py

+36 -12
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional
 
 import torch
 from torch import nn
@@ -22,21 +22,38 @@ def split_fused_mlp_weight(w1_w3):
 def _mlp_pre_load_convert(
     module: "FeedForward", state_dict, prefix: str, *args, **kwargs  # pylint: disable=W0613
 ) -> None:
-    w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight"
+    gate_proj_name, up_proj_name, dense_h_to_4h_name = (
+        f"{prefix}gate_proj.weight",
+        f"{prefix}up_proj.weight",
+        f"{prefix}dense_h_to_4h.weight",
+    )
+    down_proj_name, dense_4h_to_h_name = f"{prefix}down_proj.weight", f"{prefix}dense_4h_to_h.weight"
 
-    if module.mlp_layer_fusion and fused_name not in state_dict:
-        w1, w3 = state_dict.pop(w1_name), state_dict.pop(w3_name)
-        state_dict[fused_name] = torch.cat([w1, w3], dim=0)
+    if module.mlp_layer_fusion and dense_h_to_4h_name not in state_dict:
+        gate_proj, up_proj = state_dict.pop(gate_proj_name), state_dict.pop(up_proj_name)
+        state_dict[dense_h_to_4h_name] = torch.cat([gate_proj, up_proj], dim=0)
+        state_dict[dense_4h_to_h_name] = state_dict.pop(down_proj_name)
 
-    if not module.mlp_layer_fusion and (w1_name not in state_dict or w3_name not in state_dict):
-        state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name))
+    if not module.mlp_layer_fusion and (gate_proj_name not in state_dict or up_proj_name not in state_dict):
+        state_dict[gate_proj_name], state_dict[up_proj_name] = split_fused_mlp_weight(
+            state_dict.pop(dense_h_to_4h_name)
+        )
+        state_dict[down_proj_name] = state_dict.pop(dense_4h_to_h_name)
 
 
 def _mlp_save_convert(module: "FeedForward", state_dict, prefix: str, *args, **kwargs) -> Dict:  # pylint: disable=W0613
-    w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight"
+    gate_proj_name, up_proj_name, dense_h_to_4h_name = (
+        f"{prefix}gate_proj.weight",
+        f"{prefix}up_proj.weight",
+        f"{prefix}dense_h_to_4h.weight",
+    )
+    down_proj_name, dense_4h_to_h_name = f"{prefix}down_proj.weight", f"{prefix}dense_4h_to_h.weight"
 
     if module.mlp_layer_fusion:
-        state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name))
+        state_dict[gate_proj_name], state_dict[up_proj_name] = split_fused_mlp_weight(
+            state_dict.pop(dense_h_to_4h_name)
+        )
+        state_dict[down_proj_name] = state_dict.pop(dense_4h_to_h_name)
 
     return state_dict
 
@@ -92,9 +109,6 @@ def __init__(
             self.dense_4h_to_h = new_linear(
                 "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert
             )
-
-            self._register_load_state_dict_pre_hook(_mlp_pre_load_convert, with_module=True)
-            self._register_state_dict_hook(_mlp_save_convert)
         else:
             self.gate_proj = new_linear(
                 "w1", in_features, hidden_features, bias, device=device, dtype=dtype, is_expert=is_expert
@@ -105,6 +119,16 @@ def __init__(
             self.down_proj = new_linear(
                 "w2", hidden_features, out_features, bias, device=device, dtype=dtype, is_expert=is_expert
             )
+        self.register_checkpoint_compatibility_hooks(_mlp_pre_load_convert, _mlp_save_convert)
+
+    def register_checkpoint_compatibility_hooks(
+        self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None
+    ):
+        # Here we explicitly expose the checkpoint compatibility interface of the module,
+        # hoping that model developers will make good use of it when adapting.
+        # Is this interface already meeting all reasonable requirements?
+        self._register_load_state_dict_pre_hook(pre_load_hook, with_module=True)
+        self._register_state_dict_hook(pre_save_hook)
 
     def forward(self, x):
         if not self.mlp_layer_fusion:

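Analogously to the attention change, dense_h_to_4h (the old fused w13) is the concatenation of gate_proj and up_proj, and the new register_checkpoint_compatibility_hooks method is a named wrapper around the two underlying PyTorch hook registrations shown in the diff. A minimal fuse/split round-trip sketch with toy sizes follows; split_fused_mlp_weight is not shown in this diff, so the equal two-way chunk below is only the simplest assumption consistent with the torch.cat in the pre-load hook.

import torch

hidden, ffn = 8, 32  # toy sizes, for illustration only

gate_proj = torch.randn(ffn, hidden)  # gate_proj.weight (old w1)
up_proj = torch.randn(ffn, hidden)    # up_proj.weight (old w3)

# Fused layout used when mlp_layer_fusion is on (dense_h_to_4h, the old w13):
dense_h_to_4h = torch.cat([gate_proj, up_proj], dim=0)

# Simplified stand-in for split_fused_mlp_weight: undo the concat with two equal chunks.
gate_back, up_back = torch.chunk(dense_h_to_4h, 2, dim=0)

assert torch.equal(gate_back, gate_proj) and torch.equal(up_back, up_proj)
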
tests/test_training/train_CI.py

+32 -12
@@ -22,26 +22,26 @@
 from internlm.checkpoint import CheckpointManager  # noqa: E402
 from internlm.core.context import ParallelMode  # noqa: E402
 from internlm.core.context import global_context as gpc  # noqa: E402
+from internlm.core.trainer import record_current_batch_training_metrics  # noqa: E402
 from internlm.core.trainer import Trainer, TrainState  # noqa: E402
 from internlm.data import (  # noqa: E402
     build_train_loader_with_data_type,
     build_valid_loader_with_data_type,
 )
 from internlm.eval import evaluate_on_val_dls  # noqa: E402
 from internlm.initialize import initialize_launcher  # noqa: E402
+from internlm.initialize import initialize_trainer  # noqa: E402
 from internlm.initialize.initialize_model import (  # noqa: E402
     initialize_model_and_parallel_communicator,
 )
-from internlm.initialize import initialize_trainer  # noqa: E402
 from internlm.model.model_ops.losses import InternLoss  # noqa: E402
 from internlm.model.model_ops.metrics import (  # noqa: E402
     AccPerplex,
     SchedulerMetricHook,
 )
 from internlm.monitor import initialize_monitor_manager  # noqa: E402
-from internlm.monitor import monitor_manager as mm  # noqa: E402
 from internlm.monitor import send_alert_message  # noqa: E402
-from internlm.core.trainer import record_current_batch_training_metrics  # noqa: E402
+from internlm.monitor import monitor_manager as mm  # noqa: E402
 from internlm.utils.common import (  # noqa: E402
     BatchSkipper,
     get_current_device,
@@ -65,18 +65,38 @@ def check_model_weights(model, ckpt_path, total_equal=False):
     model1_dict = torch.load(ckpt_path, map_location="cuda")
     model2_dict = model.state_dict()
 
-    copy_of_ordered_dict = model2_dict.copy()
-
-    for key in copy_of_ordered_dict.keys():
+    mlp_layer_fusion = any(".w13." in key for key in model1_dict.keys())
+    fixed_model1_dict = {}
+    for key in model1_dict.keys():
         if "wqkv" in key:
-            model2_dict[key.replace("wqkv", "Wqkv")] = model2_dict.pop(key)
-            key = key.replace("wqkv", "Wqkv")
-        if key not in model1_dict:
-            assert False, f"Error: The key {key} for current model dose not exist in standard ckpt!"
+            new_key = key.replace("wqkv", "qkv_proj")
+        elif "wq" in key:
+            new_key = key.replace("wq", "q_proj")
+        elif "wk" in key:
+            new_key = key.replace("wk", "k_proj")
+        elif "wv" in key:
+            new_key = key.replace("wv", "v_proj")
+        # elif "wo" in key:
+        #     new_key = key.replace("wo", "o_proj")
+        elif ".w1." in key:
+            new_key = key.replace("feed_forward.w1", "feed_forward.gate_proj")
+        elif ".w2." in key:
+            new_key = (
+                key.replace("feed_forward.w2", "feed_forward.dense_4h_to_h")
+                if mlp_layer_fusion
+                else key.replace("feed_forward.w2", "feed_forward.down_proj")
+            )
+        elif ".w3." in key:
+            new_key = key.replace("feed_forward.w3", "feed_forward.up_proj")
+        elif ".w13." in key:
+            new_key = key.replace("feed_forward.w13", "feed_forward.dense_h_to_4h")
+        else:
+            new_key = key
+        fixed_model1_dict[new_key] = model1_dict[key]
 
-    for key in model1_dict.keys():
+    for key in fixed_model1_dict.keys():
         if key in model2_dict:
-            tensor1 = model1_dict[key]
+            tensor1 = fixed_model1_dict[key]
             tensor2 = model2_dict[key]
             if total_equal:
                 assert torch.equal(tensor1, tensor2), "model weights are not equal"

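The rewritten check_model_weights first renames the reference checkpoint's keys to the HF-style names and then compares tensors key by key. A stripped-down sketch of the comparison step, independent of the test harness; the key name used below is a hypothetical example.

import torch

def compare_state_dicts(reference, current, total_equal=False, rtol=1e-5, atol=1e-8):
    # Compare tensors that share a key; exact match if total_equal, tolerant otherwise.
    for key in reference:
        if key in current:
            t1, t2 = reference[key], current[key]
            if total_equal:
                assert torch.equal(t1, t2), f"mismatch at {key}"
            else:
                assert torch.allclose(t1, t2, rtol=rtol, atol=atol), f"mismatch at {key}"

# Toy usage with a hypothetical HF-style key name:
ref = {"layers.0.attention.qkv_proj.weight": torch.ones(4, 4)}
cur = {"layers.0.attention.qkv_proj.weight": torch.ones(4, 4)}
compare_state_dicts(ref, cur, total_equal=True)
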
tools/__init__.py

Whitespace-only changes.
