@@ -104,9 +104,10 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
104
104
activation_fn = getattr (opt , "pos_ffn_activation_fn" , "relu" )
105
105
num_heads = getattr (opt , "heads" , 8 )
106
106
num_kv = getattr (opt , "num_kv" , 0 )
107
- if num_kv == num_heads :
107
+ if num_kv == num_heads or num_kv == 0 :
108
108
num_kv = None
109
109
rotary_dim = 0 if with_rotary else None
110
+ rotary_interleave = getattr (opt , "rotary_interleave" , True )
110
111
ffn_glu = activation_fn == "silu"
111
112
sliding_window = getattr (opt , "sliding_window" , 0 )
112
113
@@ -119,7 +120,7 @@ def _get_model_spec_lm(opt, variables, src_vocabs, tgt_vocabs, num_source_embedd
119
120
alibi = with_alibi ,
120
121
rms_norm = opt .layer_norm == "rms" ,
121
122
rotary_dim = rotary_dim ,
122
- rotary_interleave = True ,
123
+ rotary_interleave = rotary_interleave ,
123
124
multi_query_attention = getattr (opt , "multiquery" , False ),
124
125
num_heads_kv = num_kv ,
125
126
sliding_window = sliding_window ,
@@ -329,7 +330,7 @@ def set_linear(spec, variables, scope):
329
330
spec .weight = _get_variable (variables , "%s.weight" % scope )
330
331
bias = variables .get ("%s.bias" % scope )
331
332
if bias is not None :
332
- spec .bias = bias . numpy ()
333
+ spec .bias = bias
333
334
334
335
335
336
def set_embeddings (spec , variables , scope ):
@@ -341,7 +342,7 @@ def set_position_encodings(spec, variables, scope):
341
342
342
343
343
344
def _get_variable (variables , name ):
344
- return variables [name ]. numpy ()
345
+ return variables [name ]
345
346
346
347
347
348
def main ():
0 commit comments