@@ -25,6 +25,7 @@ class ModelConfig:
   rmsnorm_epsilon: float = 1e-6
   use_residual_scaling: bool = True
   tie_embeddings: bool = True  # Whether to tie input and output embeddings
+  qknorm_epsilon: float = 1e-6

   dtype: jnp.dtype = jnp.float32
   attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
@@ -116,6 +117,7 @@ def setup(self):
     cfg = self.cfg
     assert cfg.model_dim % cfg.num_heads == 0, f'D={cfg.model_dim} must be divisible by H={cfg.num_heads}'
     self.Dh = cfg.model_dim // cfg.num_heads
+    self.eps = cfg.qknorm_epsilon

     # Initialize rotary embeddings
     self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads)
@@ -129,10 +131,13 @@ def setup(self):
       use_bias=False,
       dtype=cfg.dtype,
     )
-
     self.multilinear_query = self.multilinear(name='query')
     self.multilinear_key = self.multilinear(name='key')
     self.multilinear_value = self.multilinear(name='value')
+    # See Henry et al. (2020), "Query-Key Normalization for Transformers".
+    seq_len = cfg.seq_len
+    attn_scale0 = jnp.log2(seq_len**2 - seq_len)
+    self.attn_scale = self.param('attn_scale', nn.initializers.constant(attn_scale0), ())
     self.output_projection = nn.DenseGeneral(
       features=cfg.model_dim,
       name='attn_out_proj',
@@ -153,8 +158,9 @@ def __call__(self, x_BxLxD: jax.Array):
     # Apply rotary embeddings to Q and K
     q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis)

-    # Scale queries
-    q_BxLxHxDh /= self.Dh**0.5
+    # Apply QK normalization
+    q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps
+    k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps

     # Compute attention scores
     att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh)
@@ -166,6 +172,7 @@ def __call__(self, x_BxLxD: jax.Array):
     # Apply mask and softmax
     _NEG_INF = jnp.finfo(cfg.dtype).min
     att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF)
+    att_BxHxLxL = self.attn_scale * att_BxHxLxL  # Learned scaling factor for QK norm
    att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1)
    att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype)