Commit a58fbd5

Fix init in both models to be the same, add lm model diff test
1 parent bb4a380 commit a58fbd5

File tree

3 files changed: +893 −20 lines changed
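Both model files make the same two changes, visible in the diffs below: every weight matrix and embedding is now drawn from a normal distribution with std 0.02, and the output projections of the attention and MLP blocks (the layers that write into the residual stream) are additionally scaled by 1 / sqrt(2 * N) when use_residual_scaling is enabled. The sketch below illustrates why that factor keeps the residual stream's variance roughly depth-independent; the values of N and d are illustrative, not taken from the workload config.

import jax
import jax.numpy as jnp

N, d = 12, 512                              # illustrative depth and width
base_std = 0.02                             # std used for attention/linear/embed weights
residual_std = base_std / jnp.sqrt(2 * N)   # what DoConfig.__post_init__ computes

# A forward pass adds 2*N projections into the residual stream (one attention
# and one MLP output per block). The std of the summed contributions is
# residual_std * sqrt(2*N) = base_std, independent of depth.
key = jax.random.PRNGKey(0)
contribs = residual_std * jax.random.normal(key, (2 * N, d))
print(jnp.std(contribs.sum(axis=0)))        # ≈ 0.02 for any N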

algoperf/workloads/lm/lm_jax/nanodo_model.py
Lines changed: 15 additions & 10 deletions

@@ -21,14 +21,17 @@ class DoConfig:
   N: int  # number of transformer block layers
   V: int  # vocab size
   F: int  # FF inner dimension
-  kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform()
-  embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling(
-      1.0, 'fan_in', 'normal', out_axis=0
-  )
+  attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
+  linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
+  embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
+  use_residual_scaling: bool = True
   dtype: jnp.dtype = jnp.float32
   rmsnorm_epsilon: float = 1e-6
   multiple_of: int = 256
-  tie_embeddings: bool = True  # Whether to tie input and output embeddings
+  tie_embeddings: bool = True  # Whether to tie input and output embed
+
+  def __post_init__(self):
+    self.residual_init = nn.initializers.normal(stddev=0.02 / jnp.sqrt(2 * self.N))


 class Mlp(nn.Module):
@@ -40,9 +43,8 @@ class Mlp(nn.Module):
   def __call__(self, x_BxLxD: jax.Array):
     cfg = self.cfg
     # Use Xavier uniform initialization explicitly
-    xavier_init = nn.initializers.xavier_uniform()
     linear = partial(
-        nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype
+        nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype
     )
     # Adjust hidden dimension to keep the number of parameters invariant to
     # the activation function used since the GLU MLP has 3 * hidden_dim * D
@@ -55,7 +57,7 @@ def __call__(self, x_BxLxD: jax.Array):
     x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD)
     # Apply GLU activation
     x_BxLxF = nn.glu(x_BxLx2F, axis=-1)
-    x_BxLxD = linear(cfg.D)(x_BxLxF)
+    x_BxLxD = nn.Dense(cfg.D, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF)
     return x_BxLxD


@@ -122,7 +124,7 @@ def setup(self):
         nn.DenseGeneral,
         axis=-1,
         features=(cfg.H, self.Dh),
-        kernel_init=cfg.kernel_init,
+        kernel_init=cfg.attention_init,
         use_bias=False,
         dtype=cfg.dtype,
     )
@@ -134,7 +136,7 @@ def setup(self):
         features=cfg.D,
         name='attn_out_proj',
         # axis=(-2, -1), #
-        kernel_init=cfg.kernel_init,
+        kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init,
         use_bias=False,
         dtype=cfg.dtype,
     )
@@ -265,6 +267,9 @@ def predict(self, y_BxL: jax.Array, k: int = 1):

     # Get the logits for the last token in each sequence
     next_token_logits = logits[:, -1, :]
+    last_token_id = y_BxL[:, -1]
+    # Prevent predicting the same token consecutively
+    next_token_logits = next_token_logits.at[jnp.arange(len(last_token_id)), last_token_id].set(float('-inf'))

     # Get the most likely token
     next_token = jnp.argmax(next_token_logits, axis=-1)
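The predict change above uses JAX's functional index update to forbid emitting the same token twice in a row: each row's previous token has its logit set to -inf before the argmax. A tiny standalone example of that indexing pattern, with made-up values:

import jax.numpy as jnp

logits = jnp.array([[0.1, 0.9, 0.3],
                    [0.8, 0.2, 0.5]])       # (batch=2, vocab=3), made-up values
last_token_id = jnp.array([1, 0])           # previously generated token per row
masked = logits.at[jnp.arange(2), last_token_id].set(float('-inf'))
print(jnp.argmax(masked, axis=-1))          # [2 2]: the previous token is never repeated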

algoperf/workloads/lm/lm_pytorch/plainlm_model.py
Lines changed: 10 additions & 10 deletions

@@ -23,6 +23,7 @@ class ModelConfig:
   n_heads: int
   rmsnorm_eps: float = 1e-6
   tie_embeddings: bool = True
+  use_residual_scaling: bool = True


 class MLP(nn.Module):
@@ -32,10 +33,8 @@ def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256):
     self.fc1 = nn.Linear(dim, 2 * hidden_dim, bias=False)
     self.fc2 = nn.Linear(hidden_dim, dim, bias=False)
     self.glu = nn.GLU(dim=2)
-
-    # Initialize with Xavier uniform
-    nn.init.xavier_uniform_(self.fc1.weight)
-    nn.init.xavier_uniform_(self.fc2.weight)
+    nn.init.normal_(self.fc1.weight, std=0.02)
+    nn.init.normal_(self.fc2.weight, std=0.02)

   def forward(self, x):
     # x: (bsz, T, dim)
@@ -89,6 +88,11 @@ def __init__(self, cfg: ModelConfig):

     self.w_qkv = nn.Linear(cfg.dim, 3 * cfg.dim, bias=False)
     self.w_out = nn.Linear(cfg.dim, cfg.dim, bias=False)
+    # Split into Q, K, V sections
+    wq, wk, wv = torch.chunk(self.w_qkv.weight, 3, dim=0)
+    for w in [wq, wk, wv]:
+      nn.init.normal_(w, std=0.02)
+    nn.init.normal_(self.w_out.weight, std=0.02)

   def forward(self, x, freqs_cis):
     bsz, seqlen, d = x.shape  # (bsz, seqlen, d)
@@ -254,15 +258,11 @@ def _init_weights(self, module):
       if module.bias is not None:
         torch.nn.init.zeros_(module.bias)
     elif isinstance(module, nn.Embedding):
-      torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
+      torch.nn.init.normal_(module.weight, std=0.02)

   def _scale_residual_branches(self):
     for n, p in self.named_parameters():
-      if n.endswith('fc2.weight'):  # mlp/glu output layer
-        torch.nn.init.normal_(
-            p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers)
-        )
-      if n.endswith('w_out.weight'):  # attn output layer
+      if n.endswith('fc2.weight') or n.endswith('w_out.weight'):  # mlp/glu output layer
         torch.nn.init.normal_(
             p, mean=0.0, std=0.02 / math.sqrt(2 * self.n_layers)
         )
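The third changed file (presumably the new lm model diff test named in the commit message) is not shown in this excerpt. The property the commit establishes is easy to state, though: after _scale_residual_branches, the PyTorch fc2 and w_out weights carry the same std as the JAX residual_init, while every other weight sits at 0.02. A standalone sketch of that number on the PyTorch side, with illustrative sizes rather than the workload's:

import math
import torch
import torch.nn as nn

n_layers, dim = 12, 512                    # illustrative sizes, not the workload's
target = 0.02 / math.sqrt(2 * n_layers)    # same value as the JAX residual_init std

w_out = nn.Linear(dim, dim, bias=False)    # stand-in for an attention/MLP output projection
nn.init.normal_(w_out.weight, mean=0.0, std=target)

# The empirical std of the initialized tensor sits at the shared target value.
print(target, w_out.weight.std().item())   # both ≈ 0.0041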
