
Commit cd99740

[Fix] Lib with small max_seq_len incompatible with prebuilt weight (#840)
This PR fixes an issue introduced by #780, which broke our intended behavior of keeping the cos/sin cache shape independent of the max sequence length, so that no matter what max sequence length people use, they can reuse the same set of prebuilt weights and do not need to clone different weight repositories. At the same time, the need for larger max sequence lengths is growing: prior to #780, whenever the max sequence length exceeded 2048, the cached cos/sin tables were too short and the model broke. To stay as compatible as possible, this PR changes the behavior to take the maximum of 2048 and the specified max sequence length when building the model lib. With this fix, when the max sequence length is at most 2048 we can still use the prebuilt weights; when it is larger, we can only use the weights converted alongside the build.
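The sizing rule described above can be sketched as follows. This is a minimal NumPy illustration of the cos/sin cache construction, not the actual MLC-LLM code; `build_rotary_cache` and the sample `max_sequence_length`/`head_dim` values are made up for demonstration.

```python
import numpy as np

def build_rotary_cache(max_sequence_length, head_dim, base=10000.0):
    """Precompute cos/sin tables for rotary embeddings.

    The cache covers max(max_sequence_length, 2048) positions, so libs
    built with a small max_seq_len still match the prebuilt weight shape.
    """
    cache_len = max(max_sequence_length, 2048)
    inv_freq = 1.0 / (base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim))
    t = np.arange(cache_len, dtype=inv_freq.dtype)
    freqs = np.einsum("i,j->ij", t, inv_freq)
    emb = np.concatenate((freqs, freqs), axis=-1)
    return np.cos(emb), np.sin(emb)

# A small max_seq_len still yields the 2048-row prebuilt shape...
cos, sin = build_rotary_cache(max_sequence_length=512, head_dim=128)
print(cos.shape)  # (2048, 128)

# ...while a larger one grows the cache (and needs freshly converted weights).
cos, sin = build_rotary_cache(max_sequence_length=4096, head_dim=128)
print(cos.shape)  # (4096, 128)
```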
1 parent 8bd6918 commit cd99740

File tree

2 files changed: +7 additions, −5 deletions


mlc_llm/relax_model/llama.py

Lines changed: 5 additions & 5 deletions
@@ -577,13 +577,13 @@ def __init__(self, config: LlamaConfig, sep_embed: bool = False):
         assert config.hidden_size % config.num_attention_heads == 0
         head_dim = config.hidden_size // config.num_attention_heads

-        # Set the cached sin/cos to the max seq len.
+        # Set the cached sin/cos to the maximum of 2048 and max seq len.
         # This will be eliminated further with online rotary embedding calculation.
         self.cos_cached = nn.Parameter(
-            (config.max_sequence_length, head_dim), dtype=config.dtype, name="cos_cached"
+            (max(config.max_sequence_length, 2048), head_dim), dtype=config.dtype, name="cos_cached"
         )
         self.sin_cached = nn.Parameter(
-            (config.max_sequence_length, head_dim), dtype=config.dtype, name="sin_cached"
+            (max(config.max_sequence_length, 2048), head_dim), dtype=config.dtype, name="sin_cached"
         )
         ############ End ############

@@ -892,9 +892,9 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
         inv_freq = 1.0 / (
             config.position_embedding_base ** (np.arange(0, head_dim, 2).astype("float32") / head_dim)
         )
-        # Set the cached sin/cos to the max sequence length.
+        # Set the cached sin/cos to the maximum of 2048 and max sequence length.
         # This will be eliminated further with online rotary embedding calculation.
-        t = np.arange(config.max_sequence_length, dtype=inv_freq.dtype)
+        t = np.arange(max(config.max_sequence_length, 2048), dtype=inv_freq.dtype)
         freqs = np.einsum("i,j->ij", t, inv_freq)
         emb = np.concatenate((freqs, freqs), axis=-1)
         param_list[-2] = tvm.nd.array(np.cos(emb).astype(config.dtype), device)
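Why the cache shape matters for weight reuse: weight loaders typically require an exact shape match between a model lib's declared parameter and the prebuilt array. The sketch below illustrates this with a hypothetical `load_param` helper (not an MLC-LLM or TVM API); the shapes are illustrative.

```python
import numpy as np

# Shape baked into the prebuilt weights (2048 cached positions, head_dim 128).
prebuilt_cos = np.zeros((2048, 128), dtype="float16")

def load_param(expected_shape, array):
    """Hypothetical loader: reject arrays whose shape differs from the lib's declaration."""
    if array.shape != expected_shape:
        raise ValueError(f"shape mismatch: expected {expected_shape}, got {array.shape}")
    return array

# Before the fix, a lib built with max_seq_len=512 declared a (512, 128)
# cache and rejected the prebuilt table:
try:
    load_param((512, 128), prebuilt_cos)
except ValueError as err:
    print(err)

# With the max(..., 2048) clamp the declared shape is (2048, 128), so
# the prebuilt weights load cleanly.
load_param((max(512, 2048), 128), prebuilt_cos)
```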

mlc_llm/transform/fuse_split_rotary_embedding.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@

 def get_split_rotary(num_attention_heads, head_dim, max_sequence_length=2048):
     hidden_size = num_attention_heads * head_dim
+    max_sequence_length = max(max_sequence_length, 2048)

     @T.prim_func
     def split_rotary(

@@ -77,6 +78,7 @@ def split_rotary(

 def fuse_split_rotary_embedding(mod, num_attention_heads, hidden_size, max_sequence_length=2048):
     head_dim = hidden_size // num_attention_heads
+    max_sequence_length = max(max_sequence_length, 2048)

     mod["split_rotary"] = get_split_rotary(num_attention_heads, head_dim, max_sequence_length)
