@@ -21,14 +21,17 @@ class DoConfig:
   N: int  # number of transformer block layers
   V: int  # vocab size
   F: int  # FF inner dimension
-  kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform()
-  embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling(
-      1.0, 'fan_in', 'normal', out_axis=0
-  )
+  attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
+  linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
+  embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
+  use_residual_scaling: bool = True
   dtype: jnp.dtype = jnp.float32
   rmsnorm_epsilon: float = 1e-6
   multiple_of: int = 256
-  tie_embeddings: bool = True  # Whether to tie input and output embeddings
+  tie_embeddings: bool = True  # Whether to tie input and output embed
+
+  def __post_init__(self):
+    self.residual_init = nn.initializers.normal(stddev=0.02 / jnp.sqrt(2 * self.N))
 
 
 class Mlp(nn.Module):
@@ -40,9 +43,8 @@ class Mlp(nn.Module):
   def __call__(self, x_BxLxD: jax.Array):
     cfg = self.cfg
-    # Use Xavier uniform initialization explicitly
+    # Use the initializers configured on DoConfig
-    xavier_init = nn.initializers.xavier_uniform()
     linear = partial(
-        nn.Dense, kernel_init=xavier_init, use_bias=False, dtype=cfg.dtype
+        nn.Dense, kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype
     )
     # Adjust hidden dimension to keep the number of parameters invariant to
     # the activation function used since the GLU MLP has 3 * hidden_dim * D
@@ -55,7 +57,7 @@ def __call__(self, x_BxLxD: jax.Array):
     x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD)
     # Apply GLU activation
     x_BxLxF = nn.glu(x_BxLx2F, axis=-1)
-    x_BxLxD = linear(cfg.D)(x_BxLxF)
+    x_BxLxD = nn.Dense(cfg.D, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF)
     return x_BxLxD
 
 
@@ -122,7 +124,7 @@ def setup(self):
         nn.DenseGeneral,
         axis=-1,
         features=(cfg.H, self.Dh),
-        kernel_init=cfg.kernel_init,
+        kernel_init=cfg.attention_init,
         use_bias=False,
         dtype=cfg.dtype,
     )
@@ -134,7 +136,7 @@ def setup(self):
         features=cfg.D,
         name='attn_out_proj',
         # axis=(-2, -1), #
-        kernel_init=cfg.kernel_init,
+        kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init,
         use_bias=False,
         dtype=cfg.dtype,
     )
@@ -265,6 +267,9 @@ def predict(self, y_BxL: jax.Array, k: int = 1):
 
     # Get the logits for the last token in each sequence
     next_token_logits = logits[:, -1, :]
+    last_token_id = y_BxL[:, -1]
+    # Prevent predicting the same token consecutively
+    next_token_logits = next_token_logits.at[jnp.arange(len(last_token_id)), last_token_id].set(float('-inf'))
 
     # Get the most likely token
     next_token = jnp.argmax(next_token_logits, axis=-1)
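
For reference, here is a minimal, self-contained sketch of the depth-scaled initialization these hunks introduce: ordinary weight matrices are drawn with stddev 0.02, while projections that write into the residual stream use 0.02 / sqrt(2N), which matches the GPT-2 recipe for keeping the residual-stream variance roughly constant with depth. TinyCfg and TinyBlock below are hypothetical stand-ins for illustration only, not the repo's DoConfig/Mlp/attention modules.

import dataclasses

import jax
import jax.numpy as jnp
from flax import linen as nn


@dataclasses.dataclass
class TinyCfg:  # hypothetical stand-in for DoConfig
  D: int = 64             # model width
  N: int = 12             # number of transformer blocks
  base_std: float = 0.02  # stddev for ordinary weight matrices

  def __post_init__(self):
    # Residual-branch projections get stddev 0.02 / sqrt(2N): each of the N
    # blocks writes into the residual stream twice (attention out-proj, MLP out-proj).
    self.linear_init = nn.initializers.normal(stddev=self.base_std)
    self.residual_init = nn.initializers.normal(
        stddev=self.base_std / jnp.sqrt(2 * self.N))


class TinyBlock(nn.Module):  # hypothetical toy block, not the repo's code
  cfg: TinyCfg

  @nn.compact
  def __call__(self, x_BxLxD: jax.Array):
    cfg = self.cfg
    h = nn.Dense(4 * cfg.D, use_bias=False, kernel_init=cfg.linear_init)(x_BxLxD)
    h = nn.gelu(h)
    # Only the projection writing back into the residual stream is depth-scaled.
    h = nn.Dense(cfg.D, use_bias=False, kernel_init=cfg.residual_init)(h)
    return x_BxLxD + h


cfg = TinyCfg()
params = TinyBlock(cfg).init(jax.random.PRNGKey(0), jnp.zeros((2, 8, cfg.D)))

This mirrors the flag-guarded choices in the diff: the attention output projection and the MLP down-projection pick residual_init when use_residual_scaling is set, while QKV, the MLP up-projection, and the embeddings keep the plain 0.02 normal initializer.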
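And a small sketch of the repetition mask added to predict(): the logit of each sequence's most recent token is set to -inf before the argmax, so greedy decoding can never emit the same token twice in a row. The shapes (B, L, V) and names y_BxL / next_token_logits follow the diff; the logits here are random stand-ins rather than model output.

import jax
import jax.numpy as jnp

B, L, V = 2, 5, 11                                   # toy batch, length, vocab sizes
key_ids, key_logits = jax.random.split(jax.random.PRNGKey(0))
y_BxL = jax.random.randint(key_ids, (B, L), 0, V)    # token ids decoded so far
logits_BxLxV = jax.random.normal(key_logits, (B, L, V))

next_token_logits = logits_BxLxV[:, -1, :]           # (B, V) logits at the last position
last_token_id = y_BxL[:, -1]                         # (B,) most recent token per sequence

# Functional update: -inf on each sequence's previously emitted token.
next_token_logits = next_token_logits.at[
    jnp.arange(B), last_token_id].set(-jnp.inf)

next_token = jnp.argmax(next_token_logits, axis=-1)  # (B,) greedy pick
assert not bool(jnp.any(next_token == last_token_id))

Note that this only blocks an immediate repeat; longer repetition loops (A B A B ...) would still need a window-based mask or a frequency/repetition penalty.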