Commit b59afa0

Refactor model configuration classes to be consistent between JAX and PyTorch, and unify initialization so it is the same in both
1 parent a58fbd5 commit b59afa0

File tree

5 files changed (+180 −156 lines)

algoperf/workloads/lm/lm_jax/nanodo_model.py

Lines changed: 40 additions & 39 deletions

```diff
@@ -12,32 +12,33 @@
 
 
 @dataclasses.dataclass
-class DoConfig:
+class ModelConfig:
   """Hyper-parameters for Transformer decoder-only."""
 
-  D: int # model/embed dim = qkv dim
-  H: int # num attention heads
-  L: int # max context/sequence length
-  N: int # number of transformer block layers
-  V: int # vocab size
-  F: int # FF inner dimension
+  model_dim: int # model/embed dim = qkv dim
+  num_heads: int # num attention heads
+  seq_len: int # max context/sequence length
+  num_layers: int # number of transformer block layers
+  vocab_size: int # vocab size
+  expanded_model_dim: int # FF inner dimension
+  multiple_of: int = 256
+  rmsnorm_epsilon: float = 1e-6
+  use_residual_scaling: bool = True
+  tie_embeddings: bool = True # Whether to tie input and output embed
+
+  dtype: jnp.dtype = jnp.float32
   attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
   linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
   embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
-  use_residual_scaling: bool = True
-  dtype: jnp.dtype = jnp.float32
-  rmsnorm_epsilon: float = 1e-6
-  multiple_of: int = 256
-  tie_embeddings: bool = True # Whether to tie input and output embed
 
   def __post_init__(self):
-    self.residual_init = nn.initializers.normal(stddev=0.02/jnp.sqrt(2 * self.N))
+    self.residual_init = nn.initializers.normal(stddev=0.02/jnp.sqrt(2 * self.num_layers))
 
 
 class Mlp(nn.Module):
   """Multilayer perceptron with GLU activation."""
 
-  cfg: DoConfig
+  cfg: ModelConfig
 
   @nn.compact
   def __call__(self, x_BxLxD: jax.Array):
@@ -49,15 +50,15 @@ def __call__(self, x_BxLxD: jax.Array):
     # Adjust hidden dimension to keep the number of parameters invariant to
     # the activation function used since the GLU MLP has 3 * hidden_dim * D
     # parameters instead of 2 * hidden_dim * D parameters
-    hidden_dim = cfg.F * 2 / 3
+    hidden_dim = cfg.expanded_model_dim * 2 / 3
     hidden_dim = cfg.multiple_of * (
-      (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of
+      (cfg.expanded_model_dim + cfg.multiple_of - 1) // cfg.multiple_of
     )
     # Double the hidden dimension for GLU
     x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD)
     # Apply GLU activation
     x_BxLxF = nn.glu(x_BxLx2F, axis=-1)
-    x_BxLxD = nn.Dense(cfg.D, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF)
+    x_BxLxD = nn.Dense(cfg.model_dim, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF)
     return x_BxLxD
 
 
@@ -109,21 +110,21 @@ def rotate_tensor(x):
 class CausalAttn(nn.Module):
   """Causal attention layer with rotary embeddings."""
 
-  cfg: DoConfig
+  cfg: ModelConfig
 
   def setup(self):
     cfg = self.cfg
-    assert cfg.D % cfg.H == 0, f'D {cfg.D} not divisible by H {cfg.H}'
-    self.Dh = cfg.D // cfg.H
+    assert cfg.model_dim % cfg.num_heads == 0, f'D {cfg.model_dim} not divisible by H {cfg.num_heads}'
+    self.Dh = cfg.model_dim // cfg.num_heads
 
     # Initialize rotary embeddings
-    self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H)
+    self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads)
 
     # Maps D -> (H, Dh)
     self.multilinear = partial(
       nn.DenseGeneral,
       axis=-1,
-      features=(cfg.H, self.Dh),
+      features=(cfg.num_heads, self.Dh),
       kernel_init=cfg.attention_init,
       use_bias=False,
       dtype=cfg.dtype,
@@ -133,7 +134,7 @@ def setup(self):
     self.multilinear_key = self.multilinear(name='key')
     self.multilinear_value = self.multilinear(name='value')
     self.output_projection = nn.DenseGeneral(
-      features=cfg.D,
+      features=cfg.model_dim,
       name='attn_out_proj',
       # axis=(-2, -1), #
      kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init,
@@ -183,7 +184,7 @@ def __call__(self, x_BxLxD: jax.Array):
 class TBlock(nn.Module):
   """Transformer Block."""
 
-  docfg: DoConfig
+  docfg: ModelConfig
 
   @nn.compact
   def __call__(self, in_BxLxD: jax.Array):
@@ -208,25 +209,25 @@ def __call__(self, in_BxLxD: jax.Array):
 class TransformerDo(nn.Module):
   """Transformer decoder-only."""
 
-  docfg: DoConfig
+  docfg: ModelConfig
 
   def setup(self):
     cfg = self.docfg
     self.embed = nn.Embed(
-      num_embeddings=cfg.V,
-      features=cfg.D,
+      num_embeddings=cfg.vocab_size,
+      features=cfg.model_dim,
       embedding_init=cfg.embed_init,
     )
 
-    self.blocks = [TBlock(cfg) for _ in range(cfg.N)]
+    self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)]
     self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)
 
     # Output projection - tied to input embeddings if configured
     if cfg.tie_embeddings:
       self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32))
     else:
       self.output_proj = nn.Dense(
-        cfg.V, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj'
+        cfg.vocab_size, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj'
       )
 
   def __call__(self, y_BxL: jax.Array):
@@ -255,9 +256,9 @@ def predict(self, y_BxL: jax.Array, k: int = 1):
     original_input = y_BxL
 
     # Make sure we don't exceed the model's context length
-    if seq_len + k > cfg.L:
+    if seq_len + k > cfg.seq_len:
       raise ValueError(
-        f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})"
+        f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.seq_len})"
       )
 
     # Generate k tokens autoregressively
@@ -288,25 +289,25 @@ def main():
   """Create and run the DecoderOnly Transformer model."""
   # Initialize model configuration with smaller parameters for demo
   B, L = (2, 128) # Batch size, sequence length
-  cfg = DoConfig(D=128, H=4, L=L, N=2, V=256, F=4 * 128)
+  cfg = ModelConfig(model_dim=128, num_heads=4, seq_len=L, num_layers=2, vocab_size=256, expanded_model_dim=4 * 128)
   model = TransformerDo(cfg)
 
   # Print model info
   print('\nModel Configuration:')
-  print(f' - Model dimension (D): {cfg.D}')
-  print(f' - Number of heads (H): {cfg.H}')
-  print(f' - Max sequence length (L): {cfg.L}')
-  print(f' - Number of layers (N): {cfg.N}')
-  print(f' - Vocabulary size (V): {cfg.V}')
-  print(f' - Feed forward dimension (F): {cfg.F}')
+  print(f' - Model dimension (D): {cfg.model_dim}')
+  print(f' - Number of heads (H): {cfg.num_heads}')
+  print(f' - Max sequence length (L): {cfg.seq_len}')
+  print(f' - Number of layers (N): {cfg.num_layers}')
+  print(f' - Vocabulary size (V): {cfg.vocab_size}')
+  print(f' - Feed forward dimension (F): {cfg.expanded_model_dim}')
 
   # Create random input tokens (simulated token IDs)
   rng_key = jax.random.PRNGKey(42)
   input_rng, init_rng = jax.random.split(rng_key)
 
   # Generate random token IDs (integers between 0 and vocab_size-1)
   x_BxL = jax.random.randint(
-    input_rng, shape=(B, L), minval=0, maxval=cfg.V, dtype=jnp.int32
+    input_rng, shape=(B, L), minval=0, maxval=cfg.vocab_size, dtype=jnp.int32
   )
 
   # Initialize model parameters
```

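As an aside for readers of the `Mlp` hunk above: the comment about keeping the parameter count invariant comes down to a small piece of arithmetic. A plain two-layer MLP with inner width F costs about 2 * F * D parameters, while a GLU MLP with hidden width h costs about 3 * h * D, so h = 2F/3 matches the budget before being rounded up to a multiple of `multiple_of`. A minimal standalone sketch of that arithmetic (plain Python; the helper names are illustrative and not part of the codebase):

```python
# Illustrative arithmetic for the parameter-invariance comment in Mlp
# (standalone sketch, not the committed code).
def matched_glu_hidden_dim(expanded_model_dim: int) -> int:
  # A plain 2-layer MLP with inner width F costs ~2*F*D parameters;
  # a GLU MLP with hidden width h costs ~3*h*D, so h = 2F/3 matches the budget.
  return (2 * expanded_model_dim) // 3

def round_up(dim: int, multiple_of: int = 256) -> int:
  # The config's `multiple_of` field then rounds the width up for hardware friendliness.
  return multiple_of * ((dim + multiple_of - 1) // multiple_of)

if __name__ == "__main__":
  model_dim, expanded_model_dim = 128, 4 * 128  # demo values from main() above
  h = matched_glu_hidden_dim(expanded_model_dim)
  print(2 * expanded_model_dim * model_dim)  # 131072: plain MLP budget
  print(3 * h * model_dim)                   # 130944: GLU MLP with h = 2F/3
  print(round_up(h))                         # 512 after rounding to multiple_of=256
```
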
algoperf/workloads/lm/lm_jax/workload.py

Lines changed: 8 additions & 8 deletions

```diff
@@ -8,7 +8,7 @@
 from algoperf import jax_sharding_utils, param_utils, spec
 from algoperf.workloads.lm.input_pipeline import get_data_iter
 from algoperf.workloads.lm.lm_jax.nanodo_model import (
-  DoConfig,
+  ModelConfig,
   TransformerDo,
 )
 from algoperf.workloads.lm.workload import BaseLmWorkload
@@ -46,13 +46,13 @@ def init_model_fn(
     aux_dropout_rate: Optional[float] = None,
   ) -> spec.ModelInitState:
     # Initialize NanoDO transformer model
-    cfg = DoConfig(
-      D=self._emb_dim, # embedding dim
-      H=self._n_heads, # num heads
-      L=self._seq_len,
-      N=self._n_layers, # num layers
-      V=self._vocab_size,
-      F=self._mlp_dim, # feedforward dim
+    cfg = ModelConfig(
+      model_dim=self._emb_dim, # embedding dim
+      num_heads=self._n_heads, # num heads
+      seq_len=self._seq_len,
+      num_layers=self._n_layers, # num layers
+      vocab_size=self._vocab_size,
+      expanded_model_dim=self._mlp_dim, # feedforward dim
       dtype=jnp.float32,
     )
     self._model = TransformerDo(cfg)
```

algoperf/workloads/lm/lm_pytorch/plainlm_model.py

Lines changed: 31 additions & 30 deletions

```diff
@@ -15,15 +15,16 @@
 
 @dataclass
 class ModelConfig:
-  vocab_size: int
+  model_dim: int
+  num_heads: int
   seq_len: int
-  dim: int
-  expand: float
-  n_layers: int
-  n_heads: int
-  rmsnorm_eps: float = 1e-6
-  tie_embeddings: bool = True
+  num_layers: int
+  vocab_size: int
+  expanded_model_dim: int
+  multiple_of: int = 256
+  rmsnorm_epsilon: float = 1e-6
   use_residual_scaling: bool = True
+  tie_embeddings: bool = True
 
 
 class MLP(nn.Module):
@@ -81,13 +82,13 @@ def apply_rotary_emb_complex_like(
 class Attention(nn.Module):
   def __init__(self, cfg: ModelConfig):
     super().__init__()
-    assert cfg.dim % cfg.n_heads == 0
-    self.dim = cfg.dim
-    self.n_heads = cfg.n_heads
-    self.head_dim = cfg.dim // cfg.n_heads
+    assert cfg.model_dim % cfg.num_heads == 0
+    self.dim = cfg.model_dim
+    self.n_heads = cfg.num_heads
+    self.head_dim = cfg.model_dim // cfg.num_heads
 
-    self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False)
-    self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False)
+    self.w_qkv = nn.Linear(cfg.model_dim, 3 * cfg.model_dim, bias=False)
+    self.w_out = nn.Linear(cfg.model_dim, cfg.model_dim, bias=False)
     # Split into Q, K, V sections
     wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0)
     for w in [wq, wk, wv]:
@@ -131,9 +132,9 @@ class Block(nn.Module):
   def __init__(self, layer_id: int, cfg: ModelConfig):
     super().__init__()
     self.attn = Attention(cfg)
-    self.attn_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps)
-    self.mlp = MLP(dim=cfg.dim, hidden_dim=int(cfg.expand * cfg.dim))
-    self.mlp_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps)
+    self.attn_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon)
+    self.mlp = MLP(dim=cfg.model_dim, hidden_dim=cfg.expanded_model_dim, multiple_of=cfg.multiple_of)
+    self.mlp_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon)
     self.layer_id = layer_id
 
   def forward(self, x, freqs_cis):
@@ -144,19 +145,19 @@ def forward(self, x, freqs_cis):
 
 
 class Transformer(nn.Module):
-  def __init__(self, cfg):
+  def __init__(self, cfg: ModelConfig):
     super().__init__()
-    self.n_layers = cfg.n_layers
+    self.n_layers = cfg.num_layers
     self.cfg = cfg
-    head_dim = cfg.dim // cfg.n_heads
-    assert cfg.dim % cfg.n_heads == 0
+    head_dim = cfg.model_dim // cfg.num_heads
+    assert cfg.model_dim % cfg.num_heads == 0
 
-    self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.dim)
+    self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.model_dim)
     self.layers = nn.ModuleList(
-      [Block(idx, cfg) for idx in range(cfg.n_layers)]
+      [Block(idx, cfg) for idx in range(cfg.num_layers)]
     )
-    self.out_norm = nn.RMSNorm(cfg.dim, eps=cfg.rmsnorm_eps)
-    self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
+    self.out_norm = nn.RMSNorm(cfg.model_dim, eps=cfg.rmsnorm_epsilon)
+    self.lm_head = nn.Linear(cfg.model_dim, cfg.vocab_size, bias=False)
 
     # Initialize freqs_cis on CPU first (more memory efficient)
     self.register_buffer(
@@ -184,7 +185,7 @@ def forward(self, x, targets=None):
     # Make sure we have enough precomputed frequencies
     if L > self.freqs_cis.shape[1]:
       # Need to recompute for longer sequence
-      head_dim = self.cfg.dim // self.cfg.n_heads
+      head_dim = self.cfg.model_dim // self.cfg.num_heads
       new_freqs = precompute_freqs_cis(
         head_dim, max(L, self.cfg.seq_len), 500000
       )
@@ -290,11 +291,11 @@ def main():
   config = ModelConfig(
     vocab_size=50257, # Common vocab size for tokenizers like BPE or SentencePiece
     seq_len=seq_length, # Maximum sequence length
-    dim=1024, # Embedding dimension
-    expand=4.0, # MLP expansion factor
-    n_layers=12, # Number of transformer layers
-    n_heads=8, # Number of attention heads
-    rmsnorm_eps=1e-6, # RMSNorm epsilon
+    model_dim=1024, # Embedding dimension
+    expanded_model_dim=4.0, # MLP expansion factor
+    num_layers=12, # Number of transformer layers
+    num_heads=8, # Number of attention heads
+    rmsnorm_epsilon=1e-6, # RMSNorm epsilon
     tie_embeddings=True, # Tie embedding and output weights
   )
 
```

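The commit message also says initialization is unified between the two implementations. The hunks shown here keep the PyTorch `use_residual_scaling` flag but do not show the initializer itself; based on the JAX `residual_init` above (normal with stddev 0.02 / sqrt(2 * num_layers)), the unified scheme would look roughly like the following sketch. This is an assumption drawn from the JAX side, not code taken from this diff, and the helper name is made up:

```python
import math

import torch.nn as nn


def init_linear_(layer: nn.Linear, num_layers: int, is_residual_proj: bool,
                 use_residual_scaling: bool = True, std: float = 0.02) -> None:
  """Sketch of normal(0.02) init with residual-branch scaling (assumed, see lead-in)."""
  if is_residual_proj and use_residual_scaling:
    # Mirrors the JAX residual_init: stddev 0.02 / sqrt(2 * num_layers).
    std = std / math.sqrt(2 * num_layers)
  nn.init.normal_(layer.weight, mean=0.0, std=std)
```

In the model above, the natural places to apply the scaled variant are the residual-branch outputs (`Attention.w_out` and the final MLP projection), with every other weight, including the embedding, kept at stddev 0.02 to match the JAX `attention_init`, `linear_init`, and `embed_init` defaults.
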
algoperf/workloads/lm/lm_pytorch/workload.py

Lines changed: 5 additions & 6 deletions

```diff
@@ -39,12 +39,11 @@ def init_model_fn(
     cfg = ModelConfig(
       vocab_size=self._vocab_size,
       seq_len=self._seq_len,
-      dim=self._emb_dim, # Model dimension
-      expand=self._mlp_dim // self._emb_dim, # MLP expansion factor
-      # FIXME(rka97): fix expansion factor
-      n_layers=self._n_layers, # Number of transformer layers
-      n_heads=self._n_heads, # Number of attention heads
-      rmsnorm_eps=1e-6,
+      model_dim=self._emb_dim, # Model dimension
+      expanded_model_dim=self._mlp_dim, # MLP expansion factor
+      num_layers=self._n_layers, # Number of transformer layers
+      num_heads=self._n_heads, # Number of attention heads
+      rmsnorm_epsilon=1e-6,
       tie_embeddings=True,
     )
     self._model = Transformer(cfg)
```

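Taken together, the two workload diffs mean the JAX and PyTorch workloads now build their configs from identical field names; only `dtype` remains JAX-specific. A minimal comparison sketch with placeholder values (not taken from the diff):

```python
# Placeholder hyperparameters, purely for illustration.
common_kwargs = dict(
  model_dim=1024,
  num_heads=8,
  seq_len=2048,
  num_layers=12,
  vocab_size=50257,
  expanded_model_dim=4096,
  rmsnorm_epsilon=1e-6,
  tie_embeddings=True,
)

# JAX workload:     cfg = ModelConfig(**common_kwargs, dtype=jnp.float32); TransformerDo(cfg)
# PyTorch workload: cfg = ModelConfig(**common_kwargs);                    Transformer(cfg)
```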