Commit 910aeaf

[RELAX] Fix rotary embedding buffer size calculation (#18102)
* Change head_dim//2 to rotary_dim//2 in LongRope scaling
* Fixes buffer size when rotary_dim differs from head_dim
1 parent 9eb8b30 commit 910aeaf

1 file changed (+1, -1)

python/tvm/relax/frontend/nn/llm/position_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -493,7 +493,7 @@ def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
         var_q: T.handle,
         var_k: T.handle,
         var_v: T.handle,
-        ext_factors: T.Buffer((head_dim // 2,), "float32"),  # type: ignore
+        ext_factors: T.Buffer((rotary_dim // 2,), "float32"),  # type: ignore
     ):
         T.func_attr(
             {
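
The reason the buffer length follows rotary_dim rather than head_dim can be illustrated with a small standalone sketch. It is not TVM code: the helper names `num_ext_factors` and `scaled_inv_freqs`, the default `theta`, and the frequency formula (one extension factor dividing each rotary inverse frequency) are assumptions made for illustration. The point is only that rotary embeddings act on pairs of dimensions, so a model that rotates `rotary_dim` of its `head_dim` dimensions needs `rotary_dim // 2` per-frequency factors.

```python
# Illustrative sketch only; these helpers are hypothetical and not part of TVM.


def num_ext_factors(rotary_dim: int) -> int:
    """Return the number of per-frequency scaling factors.

    Rotary embeddings rotate dimensions in pairs, so there is one factor
    per pair of rotated dimensions.
    """
    if rotary_dim % 2 != 0:
        raise ValueError("rotary_dim must be even")
    return rotary_dim // 2


def scaled_inv_freqs(rotary_dim, ext_factors, theta=10000.0):
    """Compute inverse frequencies divided by per-frequency extension factors.

    Assumed formula: inv_freq[i] = 1 / (ext_factors[i] * theta ** (2 * i / rotary_dim)).
    """
    expected = num_ext_factors(rotary_dim)
    if len(ext_factors) != expected:
        raise ValueError(
            f"expected {expected} extension factors, got {len(ext_factors)}"
        )
    return [
        1.0 / (ext_factors[i] * theta ** (2 * i / rotary_dim))
        for i in range(expected)
    ]


if __name__ == "__main__":
    head_dim, rotary_dim = 128, 96  # hypothetical partial-rotary configuration
    print("factors sized by head_dim:  ", head_dim // 2)
    print("factors sized by rotary_dim:", num_ext_factors(rotary_dim))
```

When rotary_dim equals head_dim the two sizes coincide, which is why a mismatch of this kind would only surface for models that apply rotary embeddings to part of each head.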
