Commit 8bc2b7a

Merge branch 'main' into push_causallm
2 parents: 9f46d77 + 22356c0

10 files changed: +36 -36 lines
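
The change is mechanical: every eqx.static_field() declaration is rewritten as eqx.field(static=True), the spelling current Equinox versions recommend (static_field has been deprecated in favor of field(static=True)). A minimal sketch of the pattern, assuming only standard Equinox behavior; the Mlp class, its field names, and the tanh activation below are illustrative, not taken from the diff:

    import equinox as eqx
    import jax.numpy as jnp
    from typing import Callable

    class Mlp(eqx.Module):
        weight: jnp.ndarray                     # array leaf, traced by JAX
        # old spelling: act: Callable = eqx.static_field()
        act: Callable = eqx.field(static=True)  # static metadata, kept out of the pytree leaves

    mlp = Mlp(weight=jnp.ones((2, 2)), act=jnp.tanh)

Static fields live in the treedef rather than the leaves, so behavior under jit, grad, and checkpointing is unchanged by the rename; only the declaration syntax differs.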

src/levanter/models/attention.py (+1 -1)

@@ -581,7 +581,7 @@ class AttentionMask(eqx.Module):
     """

-    is_causal: bool = eqx.static_field()
+    is_causal: bool = eqx.field(static=True)
     explicit_mask: Optional[NamedArray] = None
     segment_ids: Optional[NamedArray] = None
     # CF https://github.com/jax-ml/jax/blob/47858c4ac2fd4757a3b6fc5bb2981b71a71f00c2/jax/experimental/pallas/ops/tpu/flash_attention.py#L34

src/levanter/models/backpack.py (+5 -5)

@@ -97,7 +97,7 @@ def from_hf_config(cls, hf_config: PretrainedConfig):
 class BackpackMlp(eqx.Module):
     c_fc: hnn.Linear  # projection from Embed to Intermediate (typically 4x Embed)
     c_proj: hnn.Linear  # projection from Intermediate to Embed
-    act: Callable = eqx.static_field()
+    act: Callable = eqx.field(static=True)

     @staticmethod
     def init(

@@ -134,7 +134,7 @@ class WeightsOnlyAttention(ModuleWithStateDictSerialization):
     """

     # No projection
-    config: Gpt2Config = eqx.static_field()
+    config: Gpt2Config = eqx.field(static=True)

     c_attn: hnn.Linear  # input projection from [embed] -> [(q, k, v), heads, head_dim]
     dropout: hnn.Dropout

@@ -225,7 +225,7 @@ class BackpackSenses(eqx.Module):
     ln: hnn.LayerNorm
     final_mlp: BackpackMlp

-    Pos: Axis = eqx.static_field()
+    Pos: Axis = eqx.field(static=True)

     @staticmethod
     def init(

@@ -266,8 +266,8 @@ def sense_embed(self, input_embeds, *, key):


 class BackpackGpt2Embeddings(eqx.Module):
-    Vocab: Axis = eqx.static_field()
-    config: Gpt2Config = eqx.static_field()
+    Vocab: Axis = eqx.field(static=True)
+    config: Gpt2Config = eqx.field(static=True)

     token_embeddings: NamedArray
     position_embeddings: NamedArray

src/levanter/models/gemma.py (+2 -2)

@@ -226,7 +226,7 @@ def __call__(self, x: NamedArray) -> NamedArray:


 class GemmaDecoderLayer(ModuleWithStateDictSerialization):
-    config: GemmaConfig = eqx.static_field()
+    config: GemmaConfig = eqx.field(static=True)
     self_attn: LlamaAttention
     mlp: LlamaMlp
     input_layernorm: GemmaRMSNorm

@@ -267,7 +267,7 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,


 class GemmaTransformer(ModuleWithStateDictSerialization):
-    config: GemmaConfig = eqx.static_field()
+    config: GemmaConfig = eqx.field(static=True)
     layers: BlockFoldable[GemmaDecoderLayer]
     norm: GemmaRMSNorm

src/levanter/models/gpt2.py (+5 -5)

@@ -130,7 +130,7 @@ def flops_per_token(self, vocab_size: int) -> Optional[float]:
 class Gpt2Mlp(eqx.Module):
     c_fc: hnn.Linear  # projection from Embed to Intermediate (typically 4x Embed)
     c_proj: hnn.Linear  # projection from Intermediate to Embed
-    act: Callable = eqx.static_field()
+    act: Callable = eqx.field(static=True)

     @staticmethod
     def init(Embed: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = True) -> "Gpt2Mlp":

@@ -153,7 +153,7 @@ def __call__(self, x: NamedArray, *, key=None):


 class Gpt2Attention(eqx.Module):
-    config: Gpt2Config = eqx.static_field()
+    config: Gpt2Config = eqx.field(static=True)

     c_attn: hnn.Linear  # input projection from [embed] -> [(q, k, v), heads, head_dim]
     c_proj: hnn.Linear  # output projection from [heads, head_dim] -> [embed]

@@ -246,7 +246,7 @@ def __call__(self, x: NamedArray, mask: Optional[AttentionMask | NamedArray], la


 class Gpt2Transformer(ModuleWithStateDictSerialization):
-    config: Gpt2Config = eqx.static_field()
+    config: Gpt2Config = eqx.field(static=True)
     blocks: Stacked[Gpt2Block]
     ln_f: hnn.LayerNorm

@@ -274,8 +274,8 @@ def _state_dict_key_map(self) -> Dict[str, Optional[str]]:


 class Gpt2Embeddings(ModuleWithStateDictSerialization, eqx.Module):
-    Vocab: Axis = eqx.static_field()
-    config: Gpt2Config = eqx.static_field()
+    Vocab: Axis = eqx.field(static=True)
+    config: Gpt2Config = eqx.field(static=True)

     token_embeddings: hnn.Embedding
     position_embeddings: hnn.Embedding

src/levanter/models/llama.py (+8 -8)

@@ -181,7 +181,7 @@ class LlamaMlp(eqx.Module):
     gate_proj: hnn.Linear  # projection from Embed to Mlp
     up_proj: hnn.Linear  # projection from Embed to Mlp
     down_proj: hnn.Linear  # projection from Mlp to Embed
-    act: Callable = eqx.static_field()
+    act: Callable = eqx.field(static=True)

     @staticmethod
     def init(

@@ -207,7 +207,7 @@ def __call__(self, x: NamedArray, *, key=None) -> NamedArray:


 class LlamaAttention(eqx.Module):
-    config: LlamaConfig = eqx.static_field()
+    config: LlamaConfig = eqx.field(static=True)
     q_proj: hnn.Linear  # projection from Embed to query
     k_proj: hnn.Linear  # projection from Embed to key
     v_proj: hnn.Linear  # projection from Embed to value

@@ -276,12 +276,12 @@ class LlamaRMSNorm(eqx.Module):
     Similar to LayerNorm, but uses the RMS of the input along the specified axis (or axes) instead of variance.
     """

-    axis: AxisSpec = eqx.static_field()
+    axis: AxisSpec = eqx.field(static=True)
     weight: Optional[NamedArray]
     bias: Optional[NamedArray]

-    eps: float = eqx.static_field(default=1e-5)
-    dtype: Optional[jnp.dtype] = eqx.static_field(default=jnp.float32)
+    eps: float = eqx.field(static=True, default=1e-5)
+    dtype: Optional[jnp.dtype] = eqx.field(static=True, default=jnp.float32)

     @staticmethod
     def init(axis: AxisSpec, eps: float = 1e-6, use_weight: bool = True, use_bias: bool = True, dtype=jnp.float32):

@@ -316,7 +316,7 @@ def __call__(self, x: NamedArray) -> NamedArray:


 class LlamaDecoderLayer(eqx.Module):
-    config: LlamaConfig = eqx.static_field()
+    config: LlamaConfig = eqx.field(static=True)
     self_attn: LlamaAttention
     mlp: LlamaMlp
     input_layernorm: LlamaRMSNorm

@@ -357,7 +357,7 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,


 class LlamaTransformer(eqx.Module):
-    config: LlamaConfig = eqx.static_field()
+    config: LlamaConfig = eqx.field(static=True)
     layers: BlockFoldable[LlamaDecoderLayer]
     norm: LlamaRMSNorm

@@ -392,7 +392,7 @@ class LlamaEmbedding(ModuleWithStateDictSerialization, eqx.Module):
     - Llama doesn't use dropout.
     """

-    Vocab: Axis = eqx.static_field()
+    Vocab: Axis = eqx.field(static=True)
     token_embeddings: hnn.Embedding

     @staticmethod
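
The LlamaRMSNorm hunk above also covers the keyword form: eqx.static_field(default=1e-5) becomes eqx.field(static=True, default=1e-5), with the default passed through unchanged. A quick sketch under that assumption (the Norm class below is illustrative, not Levanter's LlamaRMSNorm), showing that a static field with a default stays out of the pytree leaves:

    import equinox as eqx
    import jax
    import jax.numpy as jnp

    class Norm(eqx.Module):
        weight: jax.Array
        eps: float = eqx.field(static=True, default=1e-5)  # static, with a default value

    norm = Norm(weight=jnp.ones(4))
    print(norm.eps)                         # 1e-5, taken from the field default
    print(jax.tree_util.tree_leaves(norm))  # only `weight`; eps lives in the treedef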

src/levanter/models/qwen.py (+3 -3)

@@ -117,7 +117,7 @@ def flops_per_token(self, vocab_size: int):

 # Modified attention class for Qwen
 class QwenAttention(eqx.Module):
-    config: QwenConfig = eqx.static_field()
+    config: QwenConfig = eqx.field(static=True)
     q_proj: hnn.Linear
     k_proj: hnn.Linear
     v_proj: hnn.Linear

@@ -201,7 +201,7 @@ def __call__(

 # Modified decoder layer for Qwen
 class QwenDecoderLayer(eqx.Module):
-    config: QwenConfig = eqx.static_field()
+    config: QwenConfig = eqx.field(static=True)
     self_attn: QwenAttention
     mlp: LlamaMlp  # Can reuse Llama MLP as structure is similar
     input_layernorm: LlamaRMSNorm

@@ -242,7 +242,7 @@ def __call__(self, x: NamedArray, mask: Optional[NamedArray | AttentionMask], *,

 # Modified transformer for Qwen
 class QwenTransformer(LlamaTransformer):
-    config: QwenConfig = eqx.static_field()
+    config: QwenConfig = eqx.field(static=True)
     layers: BlockFoldable[QwenDecoderLayer]
     norm: LlamaRMSNorm

src/levanter/models/whisper.py (+6 -6)

@@ -123,7 +123,7 @@ def from_hf_config(cls, hf_config: HfConfig):
 class WhisperMlp(eqx.Module):
     fc1: hnn.Linear  # projection from Embed to Intermediate (typically 4x Embed)
     fc2: hnn.Linear  # projection from Intermediate to Embed
-    act: Callable = eqx.static_field()
+    act: Callable = eqx.field(static=True)

     @staticmethod
     def init(Embed: Axis, Mlp: Axis, activation_fn, *, key, use_bias: bool = True) -> "WhisperMlp":

@@ -146,7 +146,7 @@ def __call__(self, x: NamedArray, *, key=None):


 class WhisperAttention(eqx.Module):
-    config: WhisperConfig = eqx.static_field()
+    config: WhisperConfig = eqx.field(static=True)

     q_proj: hnn.Linear  # input projection from [embed] -> [q, heads, head_dim]
     k_proj: hnn.Linear  # input projection from [embed] -> [k, heads, head_dim]

@@ -296,10 +296,10 @@ def __call__(


 class WhisperEncoder(ModuleWithStateDictSerialization):
-    config: WhisperConfig = eqx.static_field()
+    config: WhisperConfig = eqx.field(static=True)
     conv1: hnn.Conv
     conv2: hnn.Conv
-    act: Callable = eqx.static_field()
+    act: Callable = eqx.field(static=True)

     transformer: WhisperTransformer

@@ -350,8 +350,8 @@ def _state_dict_key_map(self) -> Dict[str, Optional[str]]:


 class WhisperDecoderEmbeddings(eqx.Module):
-    Vocab: Axis = eqx.static_field()
-    config: WhisperConfig = eqx.static_field()
+    Vocab: Axis = eqx.field(static=True)
+    config: WhisperConfig = eqx.field(static=True)

     token_embeddings: hnn.Embedding
     position_embeddings: hnn.Embedding

src/levanter/optim/model_averaging.py (+1 -1)

@@ -33,7 +33,7 @@ class EmaModelAveraging(ModelAveraging[M]):
     """

     model: M
-    beta: float = eqx.static_field()
+    beta: float = eqx.field(static=True)

     def update(self: S, new_model: M, step: int) -> S:
         del step

tests/test_grad_accum.py (+3 -3)

@@ -18,9 +18,9 @@ class Mlp(eqx.Module):

     w_in: hax.NamedArray
     w_out: hax.NamedArray
-    In: hax.Axis = eqx.static_field()
-    Out: hax.Axis = eqx.static_field()
-    Mid: hax.Axis = eqx.static_field()
+    In: hax.Axis = eqx.field(static=True)
+    Out: hax.Axis = eqx.field(static=True)
+    Mid: hax.Axis = eqx.field(static=True)

     @staticmethod
     def init(In: hax.Axis, Out: hax.Axis, Mid: hax.Axis, *, key):

tests/test_utils.py (+2 -2)

@@ -33,8 +33,8 @@ class MLP(eqx.Module):
     """slightly less annoying MLP, used for testing purposes"""

     layers: List[nn.Linear]
-    activation: Callable = eqx.static_field()
-    final_activation: Callable = eqx.static_field()
+    activation: Callable = eqx.field(static=True)
+    final_activation: Callable = eqx.field(static=True)
     in_size: int = static_field()
     out_size: int = static_field()
     width_size: int = static_field()
