@@ -181,7 +181,7 @@ class LlamaMlp(eqx.Module):
     gate_proj: hnn.Linear  # projection from Embed to Mlp
     up_proj: hnn.Linear  # projection from Embed to Mlp
     down_proj: hnn.Linear  # projection from Mlp to Embed
-    act: Callable = eqx.static_field()
+    act: Callable = eqx.field(static=True)

     @staticmethod
     def init(
@@ -207,7 +207,7 @@ def __call__(self, x: NamedArray, *, key=None) -> NamedArray:


 class LlamaAttention(eqx.Module):
-    config: LlamaConfig = eqx.static_field()
+    config: LlamaConfig = eqx.field(static=True)
     q_proj: hnn.Linear  # projection from Embed to query
     k_proj: hnn.Linear  # projection from Embed to key
     v_proj: hnn.Linear  # projection from Embed to value
@@ -276,12 +276,12 @@ class LlamaRMSNorm(eqx.Module):
     Similar to LayerNorm, but uses the RMS of the input along the specified axis (or axes) instead of variance.
     """

-    axis: AxisSpec = eqx.static_field()
+    axis: AxisSpec = eqx.field(static=True)
     weight: Optional[NamedArray]
     bias: Optional[NamedArray]

-    eps: float = eqx.static_field(default=1e-5)
-    dtype: Optional[jnp.dtype] = eqx.static_field(default=jnp.float32)
+    eps: float = eqx.field(static=True, default=1e-5)
+    dtype: Optional[jnp.dtype] = eqx.field(static=True, default=jnp.float32)

     @staticmethod
     def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = True, dtype=jnp.float32):
@@ -316,7 +316,7 @@ def __call__(self, x: NamedArray) -> NamedArray:


 class LlamaDecoderLayer(eqx.Module):
-    config: LlamaConfig = eqx.static_field()
+    config: LlamaConfig = eqx.field(static=True)
     self_attn: LlamaAttention
     mlp: LlamaMlp
     input_layernorm: LlamaRMSNorm
@@ -357,7 +357,7 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,


 class LlamaTransformer(eqx.Module):
-    config: LlamaConfig = eqx.static_field()
+    config: LlamaConfig = eqx.field(static=True)
     layers: BlockFoldable[LlamaDecoderLayer]
     norm: LlamaRMSNorm

@@ -392,7 +392,7 @@ class LlamaEmbedding(ModuleWithStateDictSerialization, eqx.Module):
     - Llama doesn't use dropout.
     """

-    Vocab: Axis = eqx.static_field()
+    Vocab: Axis = eqx.field(static=True)
     token_embeddings: hnn.Embedding

     @staticmethod
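Every hunk applies the same migration: eqx.static_field() is deprecated in newer Equinox releases in favor of eqx.field(static=True). Both mark a module field as static, i.e. stored in the pytree structure (treedef) rather than as a leaf, so JAX transformations treat it as compile-time metadata instead of a traced value. Below is a minimal sketch of that behavior; it is not part of this diff, and the Scale module is invented purely for illustration.

import equinox as eqx
import jax
import jax.numpy as jnp


class Scale(eqx.Module):
    weight: jax.Array                   # dynamic field: a pytree leaf, traced by JAX
    name: str = eqx.field(static=True)  # static field: kept in the treedef, never traced

    def __call__(self, x):
        return self.weight * x


m = Scale(weight=jnp.ones(3), name="scale")
print(jax.tree_util.tree_leaves(m))  # only the dynamic `weight` array appears as a leaf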