@@ -2,7 +2,6 @@
 import math
 import os
 from contextlib import nullcontext
-from functools import reduce
 from typing import Optional
 
 import torch
@@ -13,7 +12,6 @@
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.parallel.comm.cpu_offload import get_cpu_offload_context
-from internlm.core.parallel.shard import partition_uniform
 from internlm.model.model_implementations.transformers.base_model import (
     BaseTransformerModel,
 )
@@ -31,7 +29,6 @@
 from internlm.model.model_ops.utils import (
     convert_attn_args_to_kwargs,
     convert_attn_kwargs_to_args,
-    get_parallel_size_from_file,
 )
 from internlm.solver.activation_checkpoint import activation_checkpoint
 from internlm.utils.logger import get_logger
@@ -636,196 +633,6 @@ def load_hf_weights(folder: str, model: nn.Module) -> None:
 
     internlm_accelerator.empty_cache()
 
-    @staticmethod
-    def load_internlm2_with_dynamic_parallel_size(folder, model):
-        """Load InternLM2 with dynamic parallel size."""
-        assert folder is not None, "Please specify the folder of the pretrained model"
-        assert gpc.config.model_type in ["INTERNLM2"], "dynamic_parallel is only for INTERNLM2"
-
-        fns = get_fns(folder)
-        if gpc.is_rank_for_log():
-            logger.info(f"Loading pretrained model from {folder}")
-        model_fns, old_tp, old_pp = get_parallel_size_from_file(fns)  # pylint: disable=W0612
-
-        tp = gpc.get_world_size(ParallelMode.TENSOR)
-        tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
-        assert old_tp % tp == 0 or tp % old_tp == 0, (
-            f"Expected TP size in loaded checkpoint to be fit with TP size in current config, but got {old_tp} in "
-            f"checkpoint and {tp} in current config"
-        )
-
-        correspond_tps = []
-
-        if old_tp <= tp:
-            correspond_tps.append(tp_rank // (tp // old_tp))
-            ratio = tp // old_tp
-            rank = tp_rank % ratio
-        else:
-            for i in range(old_tp // tp):
-                correspond_tps.append(tp_rank * (old_tp // tp) + i)
-            rank = 0
-            ratio = 1
-
-        current_states = {}
-
-        pp = gpc.get_world_size(ParallelMode.PIPELINE)  # noqa: F841 # pylint: disable=W0612
-
-        assert gpc.config.model.num_chunks == 1, "May cause future collisions, ignore this if necessary"
-
-        old_pp_partition = partition_uniform(gpc.config.model.num_layers, old_pp, 1)
-
-        for idx, parts in enumerate(old_pp_partition):
-            start, end = parts[0]
-            if model.last_layer <= start or model.first_layer >= end:
-                continue
-            tmp_states = {}
-
-            for correspond_tp in correspond_tps:
-                model_name = f"model_tp{correspond_tp}_pp{idx}.pt"
-                states = llm_load(os.path.join(folder, model_name), map_location="cpu")
-                states = {k.replace("model.", ""): v for k, v in states.items()}
-                for i in range(start, end):
-                    if i >= model.last_layer:
-                        break
-                    if i < model.first_layer:
-                        continue
-
-                    for name in list(states.keys()):
-                        if f".{i-start}." in name:
-                            to_name = name.replace(f".{i-start}.", f".{i-model.first_layer}.")
-
-                            if gpc.config.model_type == "INTERNLM2":
-                                if "norm" in name:
-                                    tmp_states[to_name] = [states.pop(name)]
-                                elif any(x in name for x in ("wo", "w2")):
-                                    tmp_states[to_name] = tmp_states.get(to_name, [])
-                                    tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=1)[rank])
-                                elif any(x in name for x in ("w1", "w3")):
-                                    tmp_states[to_name] = tmp_states.get(to_name, [])
-                                    tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank])
-                                elif any(x in name for x in ("wqkv",)):
-                                    tmp_states[to_name] = tmp_states.get(to_name, [])
-                                    if tp > gpc.config.model.num_kv_attention_heads:
-                                        assert old_tp <= gpc.config.model.num_kv_attention_heads, (
-                                            f"`old_tp ({old_tp}) => tp ({tp})` is not supported. "
-                                            "At least one of `tp` and `old_tp` should be less than or "
-                                            "equal to `num_kv_attention_heads`"
-                                        )
-                                        # Suitable for cases where the num_kv_attention_head is small,
-                                        # but you want to have a large TP Size
-                                        q_per_kv = (
-                                            gpc.config.model.num_attention_heads
-                                            // gpc.config.model.num_kv_attention_heads
-                                        )
-                                        head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads
-                                        index = torch.concat(
-                                            (
-                                                torch.arange(q_per_kv).chunk(ratio, dim=0)[tp_rank % ratio],
-                                                torch.tensor([q_per_kv, q_per_kv + 1]),
-                                            )
-                                        )
-                                        index = index + (q_per_kv + 2) * (tp_rank // ratio)
-                                        index = index % (
-                                            (q_per_kv + 2) * (gpc.config.model.num_kv_attention_heads / old_tp)
-                                        )
-                                        index = index * head_dim
-                                        index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(
-                                            index.shape[0]
-                                        )
-                                        tmp_states[to_name].append(
-                                            torch.index_select(states.pop(name), 0, index.to(torch.int32))
-                                        )
-                                    else:
-                                        tmp_states[to_name].append(states.pop(name).chunk(ratio, dim=0)[rank])
-                                else:
-                                    raise KeyError(f"Unknown key {name}.")
-
-                            else:
-                                assert False, "unsupported model type"
-
-                if "tok_embeddings.weight" in states and model.first_layer == 0:
-                    tmp_states["tok_embeddings.weight"] = tmp_states.get("tok_embeddings.weight", [])
-                    tmp_states["tok_embeddings.weight"].append(
-                        states["tok_embeddings.weight"].chunk(ratio, dim=1)[rank]
-                    )
-                if "output.weight" in states and model.last_layer == gpc.config.model.num_layers:
-                    tmp_states["norm.weight"] = [states["norm.weight"]]
-                    tmp_states["output.weight"] = tmp_states.get("output.weight", [])
-                    tmp_states["output.weight"].append(states["output.weight"].chunk(ratio, dim=0)[rank])
-
-                states = {}
-
-            for name in list(tmp_states.keys()):
-                data = tmp_states.pop(name)
-                if len(data) == 1:
-                    current_states[name] = data[0]
-                else:
-                    current_states[name] = torch.concat(
-                        data, dim=1 if name == "tok_embeddings.weight" or any(x in name for x in ("wo", "w2")) else 0
-                    )
-                # Merge copied kv heads
-                if "wqkv" in name and old_tp > gpc.config.model.num_kv_attention_heads:
-                    assert (
-                        tp <= gpc.config.model.num_kv_attention_heads
-                    ), "new_tp should be less than or equal to num_kv_attention_heads"
-                    head_dim = gpc.config.model.hidden_size // gpc.config.model.num_attention_heads
-                    q_per_kv = gpc.config.model.num_attention_heads // gpc.config.model.num_kv_attention_heads
-                    copied_times = old_tp // gpc.config.model.num_kv_attention_heads
-                    cur_q_per_kv = q_per_kv // copied_times
-
-                    # pylint: disable=all
-                    def duplicate_kv_index(i):
-                        if i % (cur_q_per_kv + 2) >= cur_q_per_kv:
-                            return i
-                        else:
-                            return -100
-
-                    def unique_kv_index(i):
-                        if i // (cur_q_per_kv + 2) == copied_times - 1 or i % (cur_q_per_kv + 2) < cur_q_per_kv:
-                            return i
-                        else:
-                            return -100
-
-                    # pylint: enable=all
-
-                    # Verify
-                    duplicate_index = [duplicate_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)]
-                    duplicate_index = [i for i in duplicate_index if i != -100]
-                    duplicate_index = _duplicate_index = torch.tensor(duplicate_index)
-                    for i in range(gpc.config.model.num_kv_attention_heads // tp - 1):
-                        duplicate_index = torch.concat(
-                            (duplicate_index, _duplicate_index + duplicate_index.max() + 1), dim=0
-                        )
-                    duplicate_kv = []
-                    for index in duplicate_index.reshape(-1, copied_times * 2).chunk(copied_times, dim=-1):
-                        index = index.reshape(-1) * head_dim
-                        index = index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(index.shape[0])
-                        duplicate_kv.append(torch.index_select(current_states[name], 0, index))
-                    assert reduce(
-                        lambda x, y: x and y,
-                        [torch.allclose(duplicate_kv[0], x, atol=1e-5) for x in duplicate_kv[1:]],
-                    ), "Copied kv heads are not equal after training!"
-
-                    # Merge
-                    unique_index = [unique_kv_index(i) for i in range((cur_q_per_kv + 2) * copied_times)]
-                    unique_index = [i for i in unique_index if i != -100]
-                    unique_index = _unique_index = torch.tensor(unique_index)
-                    for i in range(gpc.config.model.num_kv_attention_heads // tp - 1):
-                        unique_index = torch.concat((unique_index, _unique_index + unique_index.max() + 1), dim=0)
-                    unique_index = unique_index * head_dim
-                    unique_index = unique_index.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(
-                        unique_index.shape[0]
-                    )
-                    current_states[name] = torch.index_select(current_states[name], 0, unique_index)
-        missing_keys, unexpected_keys = model.load_state_dict(current_states, strict=False)
-
-        if gpc.get_local_rank(ParallelMode.DATA) == 0:
-            pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
-            logger.info(
-                f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
-                f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
-            )
-
     @staticmethod
     def convert_internevo2hf_weights(src: str, tgt: str) -> None:
         model_config = gpc.config.model
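
Note on the deleted helper: when the new TP world size exceeds `num_kv_attention_heads`, `load_internlm2_with_dynamic_parallel_size` re-sharded the fused `wqkv` weight by row selection rather than a plain `chunk`, since a kv group's k/v rows must be copied to every rank that holds a slice of its query heads. The sketch below reproduces that index arithmetic in isolation so it can be checked by eye; the model sizes are made up for the example, and it uses integer division where the deleted branch used `/` before casting the index back to int.

```python
import torch

# Hypothetical sizes for illustration only (not from a real config).
num_attention_heads = 8
num_kv_attention_heads = 2   # GQA: q_per_kv = 4 query heads share one (k, v)
head_dim = 4                 # hidden_size // num_attention_heads
old_tp, tp = 2, 4            # checkpoint saved with TP=2, loaded with TP=4

ratio = tp // old_tp         # number of new shards carved from each old shard
q_per_kv = num_attention_heads // num_kv_attention_heads

for tp_rank in range(tp):
    # Rows of one kv group in the fused weight: [q_0 .. q_{q_per_kv-1}, k, v].
    # Take this rank's slice of the queries plus the shared k and v heads.
    index = torch.concat(
        (
            torch.arange(q_per_kv).chunk(ratio, dim=0)[tp_rank % ratio],
            torch.tensor([q_per_kv, q_per_kv + 1]),
        )
    )
    # Shift into the kv group this rank owns, then wrap back into the old
    # shard, which stores num_kv_attention_heads // old_tp groups.
    index = index + (q_per_kv + 2) * (tp_rank // ratio)
    index = index % ((q_per_kv + 2) * (num_kv_attention_heads // old_tp))
    # Expand head indices into row indices of the weight matrix.
    rows = index * head_dim
    rows = rows.repeat_interleave(head_dim) + torch.arange(head_dim).repeat(rows.shape[0])
    print(f"tp_rank={tp_rank}: heads {index.tolist()} -> rows {rows.tolist()}")
```

Each new rank ends up with a contiguous slice of one kv group's query heads plus a copy of that group's k and v rows, which is what the `torch.index_select` in the deleted branch materialized from the old shard.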
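The reverse direction (checkpoint saved with `old_tp > num_kv_attention_heads`) had to deduplicate those copied k/v rows before loading. Below is a condensed sketch of the two head-level filters (`duplicate_kv_index` / `unique_kv_index`) from the deleted merge step, again with hypothetical sizes:

```python
# Hypothetical sizes: each (k, v) pair was stored twice in the checkpoint.
q_per_kv = 4                             # query heads per kv pair in the model
copied_times = 2                         # old_tp // num_kv_attention_heads
cur_q_per_kv = q_per_kv // copied_times  # query heads per stored copy
heads_per_copy = cur_q_per_kv + 2        # layout per copy: [q .. q, k, v]

# Heads that are duplicates across copies (the k/v rows) -- used to verify
# that the copies are still numerically identical after training.
duplicate_index = [
    i for i in range(heads_per_copy * copied_times)
    if i % heads_per_copy >= cur_q_per_kv
]
# Heads kept after the merge: every query head, plus k/v from the last copy.
unique_index = [
    i for i in range(heads_per_copy * copied_times)
    if i // heads_per_copy == copied_times - 1 or i % heads_per_copy < cur_q_per_kv
]
print(duplicate_index)  # [2, 3, 6, 7]       -> the two stored copies of (k, v)
print(unique_index)     # [0, 1, 4, 5, 6, 7] -> q0, q1, q2, q3, k, v
```

The deleted code expanded these head indices by `head_dim` into row indices, asserted with `torch.allclose` that all stored k/v copies still matched, and then kept only the rows selected by `unique_index`.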