
Commit 191ca18

pylint fix
1 parent 991f07c commit 191ca18

7 files changed: +208 -263 lines changed


internlm/checkpoint/load_funcs.py

+1 -8

@@ -1,14 +1,7 @@
 # Copyright (c) InternLM. All rights reserved.
 
-from internlm.model.model_implementations.transformers.modeling_internlm2 import (
-    InternLM2,
-)
-from internlm.model.model_implementations.transformers.modeling_llama import Llama2
 from internlm.utils.logger import get_logger
 
 logger = get_logger(__file__)
 
-LOAD_FUNC_DICT = {
-    "llama": Llama2.load_llama_pretrained_weights,
-    "internlm2_test": InternLM2.load_internlm2_with_dynamic_parallel_size,
-}
+LOAD_FUNC_DICT = {}
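Note: after this change `LOAD_FUNC_DICT` starts out empty, so any custom loader would have to be registered into it at runtime. The snippet below is only a hypothetical usage sketch of that registration pattern; the loader name, the key, and the lookup site are assumptions rather than APIs introduced by this commit, and only the `(folder, model)` signature mirrors the loaders removed above.

```python
# Hypothetical sketch: registering a checkpoint loader into the now-empty dict.
# None of these names come from this commit; only the (folder, model) signature
# mirrors the load functions that were removed above.
from internlm.checkpoint.load_funcs import LOAD_FUNC_DICT


def load_my_pretrained_weights(folder, model):
    """Assumed loader: populate `model` in place from checkpoint files under `folder`."""
    raise NotImplementedError


# Register the loader under a checkpoint-type key ...
LOAD_FUNC_DICT["my_ckpt_type"] = load_my_pretrained_weights

# ... and dispatch by key wherever a loader is needed (assumed lookup pattern).
load_func = LOAD_FUNC_DICT["my_ckpt_type"]
```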

internlm/core/parallel/comm/zero.py

+3 -1

@@ -159,7 +159,9 @@ def _pre_forward_hook(model: nn.Module, *args, **kwargs): # pylint: disable=W06
             for working_param, all_splited_param in zip(
                 self._block_working_params[block_name], all_splited_param_list
             ):
-                working_param.data.copy_(_flatten_dense_tensors(all_splited_param)[: working_param.numel()].view_as(working_param))
+                working_param.data.copy_(
+                    _flatten_dense_tensors(all_splited_param)[: working_param.numel()].view_as(working_param)
+                )
 
             self._block_allgather_handles[block_name] = None
             self._block_gathered_params[block_name] = []
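This hunk and the matching one in `internlm/solver/optimizer/hybrid_zero_optim_v2.py` further down only re-wrap an over-long line for pylint; the copy itself is unchanged. For reference, a minimal standalone sketch of what that statement does, with made-up tensor shapes:

```python
import torch
from torch._utils import _flatten_dense_tensors

# Made-up shapes: a (2, 3) working parameter and two gathered shards whose
# flattened length (8) exceeds working_param.numel() (6), e.g. due to padding.
working_param = torch.zeros(2, 3)
all_splited_param = [torch.ones(4), torch.full((4,), 2.0)]

# Flatten the shards into one 1-D buffer, trim the padding, reshape to the
# parameter's shape, and copy back in place -- same pattern as the diff above.
flat = _flatten_dense_tensors(all_splited_param)
working_param.data.copy_(flat[: working_param.numel()].view_as(working_param))
print(working_param)  # tensor([[1., 1., 1.], [1., 2., 2.]])
```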

internlm/model/model_implementations/transformers/modeling_internlm.py

-2

@@ -13,7 +13,6 @@
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.naive_amp import set_output_attr_to_module
-from internlm.core.parallel.shard import partition_uniform
 from internlm.model.model_implementations.transformers.base_model import (
     BaseTransformerModel,
 )
@@ -522,7 +521,6 @@ def load_hf_weights(folder: str, model: nn.Module) -> None:
 
         internlm_accelerator.empty_cache()
 
-
     @staticmethod
     def convert_internevo2hf_weights(src: str, tgt: str) -> None:
         model_config = gpc.config.model

internlm/model/model_implementations/transformers/modeling_internlm2.py

-193

@@ -2,7 +2,6 @@
 import math
 import os
 from contextlib import nullcontext
-from functools import reduce
 from typing import Optional
 
 import torch
@@ -13,7 +12,6 @@
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.parallel.comm.cpu_offload import get_cpu_offload_context
-from internlm.core.parallel.shard import partition_uniform
 from internlm.model.model_implementations.transformers.base_model import (
     BaseTransformerModel,
 )
@@ -31,7 +29,6 @@
 from internlm.model.model_ops.utils import (
     convert_attn_args_to_kwargs,
     convert_attn_kwargs_to_args,
-    get_parallel_size_from_file,
 )
 from internlm.solver.activation_checkpoint import activation_checkpoint
 from internlm.utils.logger import get_logger
@@ -636,196 +633,6 @@ def load_hf_weights(folder: str, model: nn.Module) -> None:
 
         internlm_accelerator.empty_cache()
 
-    @staticmethod
-    def load_internlm2_with_dynamic_parallel_size(folder, model):
-        """Load InternLM2 with dynamic parallel size."""
-        assert folder is not None, "Please specify the folder of the pretrained model"
-        assert gpc.config.model_type in ["INTERNLM2"], "dynamic_parallel is only for INTERNLM2"
-
-        fns = get_fns(folder)
-        if gpc.is_rank_for_log():
-            logger.info(f"Loading pretrained model from {folder}")
-        model_fns, old_tp, old_pp = get_parallel_size_from_file(fns)  # pylint: disable=W0612
-
-        tp = gpc.get_world_size(ParallelMode.TENSOR)
-        tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
-        assert old_tp % tp == 0 or tp % old_tp == 0, (
-            f"Expected TP size in loaded checkpoint to be fit with TP size in current config, but got {old_tp} in "
-            f"checkpoint and {tp} in current config"
-        )
-
-        correspond_tps = []
-
-        if old_tp <= tp:
-            correspond_tps.append(tp_rank // (tp // old_tp))
-            ratio = tp // old_tp
-            rank = tp_rank % ratio
-        else:
-            for i in range(old_tp // tp):
-                correspond_tps.append(tp_rank * (old_tp // tp) + i)
-            rank = 0
-            ratio = 1
-
-        current_states = {}
-
-        pp = gpc.get_world_size(ParallelMode.PIPELINE)  # noqa: F841 # pylint: disable=W0612
-
-        assert gpc.config.model.num_chunks == 1, "May cause future collisions, ignore this if necessary"
-
-        old_pp_partition = partition_uniform(gpc.config.model.num_layers, old_pp, 1)
-
-        for idx, parts in enumerate(old_pp_partition):
-            start, end = parts[0]
-            if model.last_layer <= start or model.first_layer >= end:
-                continue
-            tmp_states = {}
-
-            for correspond_tp in correspond_tps:
-                model_name = f"model_tp{correspond_tp}_pp{idx}.pt"
-                states = llm_load(os.path.join(folder, model_name), map_location="cpu")
-                states = {k.replace("model.", ""): v for k, v in states.items()}
-                for i in range(start, end):
-                    if i >= model.last_layer:
-                        break
-                    if i < model.first_layer:
-                        continue
-
-                    for name in list(states.keys()):
-                        if f".{i-start}." in name:
-                            to_name = name.replace(f".{i-start}.", f".{i-model.first_layer}.")
-
-                            if gpc.config.model_type == "INTERNLM2":
-                                if "norm" in name:
-                                    tmp_states[to_name] = [states.pop(name)]
-                                elif any(x in name for x in ("wo", "w2")):
-                                    tmp_states[to_name] = tmp_states.get(to_name, [])
-                                    tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=1)[rank])
-                                elif any(x in name for x in ("w1", "w3")):
-                                    tmp_states[to_name] = tmp_states.get(to_name, [])
-                                    tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank])
-                                elif any(x in name for x in ("wqkv",)):
-                                    tmp_states[to_name] = tmp_states.get(to_name, [])
-                                    if tp > gpc.config.model.num_kv_attention_heads:
-                                        assert old_tp <= gpc.config.model.num_kv_attention_heads, (
-                                            f"`old_tp ({old_tp}) => tp ({tp})` is not supported. "
-                                            "At least one of `tp` and `old_tp` should be less than or "
-                                            "equal to `num_kv_attention_heads`"
-                                        )
-                                        # Suitable for cases where the num_kv_attention_head is small,
-                                        # but you want to have a large TP Size
-                                        q_per_kv = (
-                                            gpc.config.model.num_attention_heads
-                                            // gpc.config.model.num_kv_attention_heads
-                                        )
-                                        head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads
-                                        index = torch.concat(
-                                            (
-                                                torch.arange(q_per_kv).chunk(ratio, dim=0)[tp_rank % ratio],
-                                                torch.tensor([q_per_kv, q_per_kv + 1]),
-                                            )
-                                        )
-                                        index = index + (q_per_kv + 2) * (tp_rank // ratio)
-                                        index = index % (
-                                            (q_per_kv + 2) * (gpc.config.model.num_kv_attention_heads / old_tp)
-                                        )
-                                        index = index * head_dim
-                                        index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(
-                                            index.shape[0]
-                                        )
-                                        tmp_states[to_name].append(
-                                            torch.index_select(states.pop(name), 0, index.to(torch.int32))
-                                        )
-                                    else:
-                                        tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank])
-                                else:
-                                    raise KeyError(f"Unknown key {name}.")
-
-                            else:
-                                assert False, "unsupported model type"
-
-                if "tok_embeddings.weight" in states and model.first_layer == 0:
-                    tmp_states["tok_embeddings.weight"] = tmp_states.get("tok_embeddings.weight", [])
-                    tmp_states["tok_embeddings.weight"].append(
-                        states["tok_embeddings.weight"].chunk(ratio, dim=1)[rank]
-                    )
-                if "output.weight" in states and model.last_layer == gpc.config.model.num_layers:
-                    tmp_states["norm.weight"] = [states["norm.weight"]]
-                    tmp_states["output.weight"] = tmp_states.get("output.weight", [])
-                    tmp_states["output.weight"].append(states["output.weight"].chunk(ratio, dim=0)[rank])
-
-                states = {}
-
-            for name in list(tmp_states.keys()):
-                data = tmp_states.pop(name)
-                if len(data) == 1:
-                    current_states[name] = data[0]
-                else:
-                    current_states[name] = torch.concat(
-                        data, dim=1 if name == "tok_embeddings.weight" or any(x in name for x in ("wo", "w2")) else 0
-                    )
-                # Merge copied kv heads
-                if "wqkv" in name and old_tp > gpc.config.model.num_kv_attention_heads:
-                    assert (
-                        tp <= gpc.config.model.num_kv_attention_heads
-                    ), "new_tp should be less than or equal to num_kv_attention_heads"
-                    head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads
-                    q_per_kv = gpc.config.model.num_attention_heads // gpc.config.model.num_kv_attention_heads
-                    copied_times = old_tp // gpc.config.model.num_kv_attention_heads
-                    cur_q_per_kv = q_per_kv // copied_times
-
-                    # pylint: disable=all
-                    def duplicate_kv_index(i):
-                        if i % (cur_q_per_kv + 2) >= cur_q_per_kv:
-                            return i
-                        else:
-                            return -100
-
-                    def unique_kv_index(i):
-                        if i // (cur_q_per_kv + 2) == copied_times - 1 or i % (cur_q_per_kv + 2) < cur_q_per_kv:
-                            return i
-                        else:
-                            return -100
-
-                    # pylint: enable=all
-
-                    # Verify
-                    duplicate_index = [duplicate_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)]
-                    duplicate_index = [i for i in duplicate_index if i != -100]
-                    duplicate_index = _duplicate_index = torch.tensor(duplicate_index)
-                    for i in range(gpc.config.model.num_kv_attention_heads // tp - 1):
-                        duplicate_index = torch.concat(
-                            (duplicate_index, _duplicate_index + duplicate_index.max() + 1), dim=0
-                        )
-                    duplicate_kv = []
-                    for index in duplicate_index.reshape(-1, copied_times * 2).chunk(copied_times, dim=-1):
-                        index = index.reshape(-1) * head_dim
-                        index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(index.shape[0])
-                        duplicate_kv.append(torch.index_select(current_states[name], 0, index))
-                    assert reduce(
-                        lambda x, y: x and y,
-                        [torch.allclose(duplicate_kv[0], x, atol=1e-5) for x in duplicate_kv[1:]],
-                    ), "Copied kv heads are not equal after training!"
-
-                    # Merge
-                    unique_index = [unique_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)]
-                    unique_index = [i for i in unique_index if i != -100]
-                    unique_index = _unique_index = torch.tensor(unique_index)
-                    for i in range(gpc.config.model.num_kv_attention_heads // tp - 1):
-                        unique_index = torch.concat((unique_index, _unique_index + unique_index.max() + 1), dim=0)
-                    unique_index = unique_index * head_dim
-                    unique_index = unique_index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(
-                        unique_index.shape[0]
-                    )
-                    current_states[name] = torch.index_select(current_states[name], 0, unique_index)
-        missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
-
-        if gpc.get_local_rank(ParallelMode.DATA) == 0:
-            pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
-            logger.info(
-                f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
-                f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
-            )
-
     @staticmethod
     def convert_internevo2hf_weights(src: str, tgt: str) -> None:
         model_config = gpc.config.model

internlm/model/model_implementations/transformers/modeling_llama.py

-57

@@ -586,63 +586,6 @@ def load_hf_weights(folder: str, model: nn.Module):
 
         internlm_accelerator.empty_cache()
 
-    @staticmethod
-    def load_llama_pretrained_weights(folder: str, model: nn.Module) -> None:
-        """NOTE: when loading huggingface's llama pretrained weights, you should set `adapt_hf=True` in your config."""
-        """NOTE: specified for meta-llama/Llama-2-7b"""
-        assert folder is not None, "Please specify the folder of the pretrained model"
-        if gpc.is_rank_for_log():
-            logger.info(f"Loading pretrained model from {folder}")
-
-        fns = get_fns(folder)
-        model_fns = []
-        for fn in fns:
-            if fn.startswith("model_t") and not fn.endswith("md5"):
-                model_fns.append(os.path.join(folder, fn))
-
-        if len(model_fns) == 0:
-            model_fns = [os.path.join(folder, fn) for fn in fns if fn.endswith(".pth") or fn.endswith(".pt")]
-
-        if len(model_fns) == 0:
-            raise FileNotFoundError(f"No checkpoint file found in {folder}")
-
-        model_fns.sort()
-
-        old_tp = len(model_fns)
-        cur_tp = gpc.get_world_size(ParallelMode.TENSOR)
-        # If the two tp are inconsistent, you need to consider the merge before splitting
-        if old_tp != cur_tp:
-            raise RuntimeError(
-                f"Your current tp is `{cur_tp}`, but the tp in folder:`{folder}` is `{old_tp}`, use `` to convert first"
-            )
-
-        states = llm_load(model_fns[gpc.get_local_rank(ParallelMode.TENSOR)], map_location="cpu")
-
-        current_states = {}
-        for idx, i in enumerate(range(model.first_layer, model.last_layer)):
-            for name in list(states.keys()):
-                if f".{i}." in name:
-                    current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name)
-
-        model_state_keys = set(list(model.state_dict().keys()))
-
-        if "tok_embeddings.weight" in model_state_keys:
-            current_states["tok_embeddings.weight"] = states["tok_embeddings.weight"]
-            assert model.first_layer == 0, f"Expect model.NaiveAMPModel to be 0, but got {model.first_layer}"
-        if "output.weight" in model_state_keys:
-            current_states["norm.weight"] = states["norm.weight"]
-            current_states["output.weight"] = states["output.weight"]
-        missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
-
-        if gpc.get_local_rank(ParallelMode.DATA) == 0:
-            pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
-            logger.info(
-                f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
-                f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
-            )
-
-        internlm_accelerator.empty_cache()
-
     @staticmethod
     def convert_internevo2hf_weights(src: str, tgt: str) -> None:
         model_config = gpc.config.model

internlm/solver/optimizer/hybrid_zero_optim_v2.py

+4 -2

@@ -5,8 +5,8 @@
 
 import torch
 import torch.distributed as dist
-from torch.optim import Optimizer
 from torch._utils import _flatten_dense_tensors
+from torch.optim import Optimizer
 
 from internlm.core.context import (
     IS_REPLICA_ZERO_PARALLEL,
@@ -670,7 +670,9 @@ def step(self, closure=None):
 
                 # Update working parameters
                 for working_param, all_splited_param in zip(working_params_list[gather_idx], all_splited_param_list):
-                    working_param.data.copy_(_flatten_dense_tensors(all_splited_param)[: working_param.numel()].view_as(working_param))
+                    working_param.data.copy_(
+                        _flatten_dense_tensors(all_splited_param)[: working_param.numel()].view_as(working_param)
+                    )
 
         for group_id in range(self.num_param_groups):
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
