 
 
 @dataclasses.dataclass
-class DoConfig:
+class ModelConfig:
   """Hyper-parameters for Transformer decoder-only."""
 
-  D: int  # model/embed dim = qkv dim
-  H: int  # num attention heads
-  L: int  # max context/sequence length
-  N: int  # number of transformer block layers
-  V: int  # vocab size
-  F: int  # FF inner dimension
+  model_dim: int  # model/embed dim = qkv dim
+  num_heads: int  # num attention heads
+  seq_len: int  # max context/sequence length
+  num_layers: int  # number of transformer block layers
+  vocab_size: int  # vocab size
+  expanded_model_dim: int  # FF inner dimension
+  multiple_of: int = 256
+  rmsnorm_epsilon: float = 1e-6
+  use_residual_scaling: bool = True
+  tie_embeddings: bool = True  # Whether to tie input and output embed
+
+  dtype: jnp.dtype = jnp.float32
   attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
   linear_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
   embed_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
-  use_residual_scaling: bool = True
-  dtype: jnp.dtype = jnp.float32
-  rmsnorm_epsilon: float = 1e-6
-  multiple_of: int = 256
-  tie_embeddings: bool = True  # Whether to tie input and output embed
 
   def __post_init__(self):
-    self.residual_init = nn.initializers.normal(stddev=0.02 / jnp.sqrt(2 * self.N))
+    self.residual_init = nn.initializers.normal(stddev=0.02 / jnp.sqrt(2 * self.num_layers))
 
 
 class Mlp(nn.Module):
   """Multilayer perceptron with GLU activation."""
 
-  cfg: DoConfig
+  cfg: ModelConfig
 
   @nn.compact
   def __call__(self, x_BxLxD: jax.Array):
@@ -49,15 +50,15 @@ def __call__(self, x_BxLxD: jax.Array):
     # Adjust hidden dimension to keep the number of parameters invariant to
     # the activation function used since the GLU MLP has 3 * hidden_dim * D
     # parameters instead of 2 * hidden_dim * D parameters
-    hidden_dim = cfg.F * 2 / 3
+    hidden_dim = cfg.expanded_model_dim * 2 / 3
     hidden_dim = cfg.multiple_of * (
-        (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of
+        (cfg.expanded_model_dim + cfg.multiple_of - 1) // cfg.multiple_of
     )
     # Double the hidden dimension for GLU
     x_BxLx2F = linear(2 * hidden_dim)(x_BxLxD)
     # Apply GLU activation
     x_BxLxF = nn.glu(x_BxLx2F, axis=-1)
-    x_BxLxD = nn.Dense(cfg.D, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF)
+    x_BxLxD = nn.Dense(cfg.model_dim, use_bias=False, dtype=cfg.dtype, kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init)(x_BxLxF)
     return x_BxLxD
 
 
@@ -109,21 +110,21 @@ def rotate_tensor(x):
 class CausalAttn(nn.Module):
   """Causal attention layer with rotary embeddings."""
 
-  cfg: DoConfig
+  cfg: ModelConfig
 
   def setup(self):
     cfg = self.cfg
-    assert cfg.D % cfg.H == 0, f'D {cfg.D} not divisible by H {cfg.H}'
-    self.Dh = cfg.D // cfg.H
+    assert cfg.model_dim % cfg.num_heads == 0, f'D {cfg.model_dim} not divisible by H {cfg.num_heads}'
+    self.Dh = cfg.model_dim // cfg.num_heads
 
     # Initialize rotary embeddings
-    self.freqs_cis = init_rope(cfg.D, cfg.L, cfg.H)
+    self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads)
 
     # Maps D -> (H, Dh)
     self.multilinear = partial(
         nn.DenseGeneral,
         axis=-1,
-        features=(cfg.H, self.Dh),
+        features=(cfg.num_heads, self.Dh),
         kernel_init=cfg.attention_init,
         use_bias=False,
         dtype=cfg.dtype,
@@ -133,7 +134,7 @@ def setup(self):
     self.multilinear_key = self.multilinear(name='key')
     self.multilinear_value = self.multilinear(name='value')
     self.output_projection = nn.DenseGeneral(
-        features=cfg.D,
+        features=cfg.model_dim,
         name='attn_out_proj',
         # axis=(-2, -1),  #
         kernel_init=cfg.residual_init if cfg.use_residual_scaling else cfg.linear_init,
@@ -183,7 +184,7 @@ def __call__(self, x_BxLxD: jax.Array):
 class TBlock(nn.Module):
   """Transformer Block."""
 
-  docfg: DoConfig
+  docfg: ModelConfig
 
   @nn.compact
   def __call__(self, in_BxLxD: jax.Array):
@@ -208,25 +209,25 @@ def __call__(self, in_BxLxD: jax.Array):
 class TransformerDo(nn.Module):
   """Transformer decoder-only."""
 
-  docfg: DoConfig
+  docfg: ModelConfig
 
   def setup(self):
     cfg = self.docfg
     self.embed = nn.Embed(
-        num_embeddings=cfg.V,
-        features=cfg.D,
+        num_embeddings=cfg.vocab_size,
+        features=cfg.model_dim,
         embedding_init=cfg.embed_init,
     )
 
-    self.blocks = [TBlock(cfg) for _ in range(cfg.N)]
+    self.blocks = [TBlock(cfg) for _ in range(cfg.num_layers)]
     self.out_ln = nn.RMSNorm(param_dtype=cfg.dtype, epsilon=cfg.rmsnorm_epsilon)
 
     # Output projection - tied to input embeddings if configured
     if cfg.tie_embeddings:
       self.output_proj = lambda x: self.embed.attend(x.astype(jnp.float32))
     else:
       self.output_proj = nn.Dense(
-          cfg.V, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj'
+          cfg.vocab_size, kernel_init=cfg.embed_init, dtype=cfg.dtype, name='output_proj'
       )
 
   def __call__(self, y_BxL: jax.Array):
@@ -255,9 +256,9 @@ def predict(self, y_BxL: jax.Array, k: int = 1):
     original_input = y_BxL
 
     # Make sure we don't exceed the model's context length
-    if seq_len + k > cfg.L:
+    if seq_len + k > cfg.seq_len:
       raise ValueError(
-          f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.L})"
+          f"Total sequence length ({seq_len + k}) exceeds model's context length ({cfg.seq_len})"
       )
 
     # Generate k tokens autoregressively
@@ -288,25 +289,25 @@ def main():
288289 """Create and run the DecoderOnly Transformer model."""
289290 # Initialize model configuration with smaller parameters for demo
290291 B , L = (2 , 128 ) # Batch size, sequence length
291- cfg = DoConfig ( D = 128 , H = 4 , L = L , N = 2 , V = 256 , F = 4 * 128 )
292+ cfg = ModelConfig ( model_dim = 128 , num_heads = 4 , seq_len = L , num_layers = 2 , vocab_size = 256 , expanded_model_dim = 4 * 128 )
292293 model = TransformerDo (cfg )
293294
294295 # Print model info
295296 print ('\n Model Configuration:' )
296- print (f' - Model dimension (D): { cfg .D } ' )
297- print (f' - Number of heads (H): { cfg .H } ' )
298- print (f' - Max sequence length (L): { cfg .L } ' )
299- print (f' - Number of layers (N): { cfg .N } ' )
300- print (f' - Vocabulary size (V): { cfg .V } ' )
301- print (f' - Feed forward dimension (F): { cfg .F } ' )
297+ print (f' - Model dimension (D): { cfg .model_dim } ' )
298+ print (f' - Number of heads (H): { cfg .num_heads } ' )
299+ print (f' - Max sequence length (L): { cfg .seq_len } ' )
300+ print (f' - Number of layers (N): { cfg .num_layers } ' )
301+ print (f' - Vocabulary size (V): { cfg .vocab_size } ' )
302+ print (f' - Feed forward dimension (F): { cfg .expanded_model_dim } ' )
302303
303304 # Create random input tokens (simulated token IDs)
304305 rng_key = jax .random .PRNGKey (42 )
305306 input_rng , init_rng = jax .random .split (rng_key )
306307
307308 # Generate random token IDs (integers between 0 and vocab_size-1)
308309 x_BxL = jax .random .randint (
309- input_rng , shape = (B , L ), minval = 0 , maxval = cfg .V , dtype = jnp .int32
310+ input_rng , shape = (B , L ), minval = 0 , maxval = cfg .vocab_size , dtype = jnp .int32
310311 )
311312
312313 # Initialize model parameters
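
For reference, here is a small standalone sketch of the parameter-count arithmetic behind the GLU hidden-dimension comment in Mlp. It is not part of this commit; the variable names and the concrete values (model_dim=128, expanded_model_dim=4 * 128, as in the demo config in main()) are only illustrative. A GLU MLP carries 3 * hidden_dim * model_dim weights versus 2 * expanded_model_dim * model_dim for a plain two-layer MLP, so scaling the hidden width by 2/3 keeps the totals matched before rounding to multiple_of.

# Illustrative arithmetic only; mirrors the comment in Mlp.__call__ with the renamed fields.
model_dim = 128
expanded_model_dim = 4 * 128

standard_mlp_params = 2 * expanded_model_dim * model_dim  # up-projection + down-projection
glu_hidden_dim = expanded_model_dim * 2 / 3               # hidden width scaled by 2/3 for GLU
glu_mlp_params = 3 * glu_hidden_dim * model_dim           # gate + up + down projections

assert round(glu_mlp_params) == standard_mlp_params  # both are 131072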