Commit bbde48b

Add query-key normalization to CausalAttn and Attention classes, including learned scaling factor
1 parent b59afa0
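In brief, the commit replaces the usual 1/sqrt(Dh) query scaling with L2 normalization of queries and keys along the head dimension (with a small epsilon for numerical stability), so the raw attention logits become cosine similarities in [-1, 1]; a single learned scalar, initialized to log2(L**2 - L) following Henry et al. (2020), then rescales the logits before the softmax. A minimal framework-agnostic sketch of that computation (the function and variable names below are illustrative, not taken from the commit):

import numpy as np

def qknorm_logits(q, k, attn_scale, eps=1e-6):
  """Sketch of QK-normalized attention logits; q and k have shape (..., L, H, Dh)."""
  # L2-normalize queries and keys along the head dimension Dh.
  q = q / (np.linalg.norm(q, axis=-1, keepdims=True) + eps)
  k = k / (np.linalg.norm(k, axis=-1, keepdims=True) + eps)
  # Cosine-similarity logits, rescaled by one learned scalar shared across heads.
  return attn_scale * np.einsum('...qhd,...khd->...hqk', q, k)

# Initialization of the learned scale from Henry et al. (2020), for sequence length L.
L = 2048  # illustrative; the commit reads this from cfg.seq_len
attn_scale0 = np.log2(L**2 - L)

The two diffs below implement this in the JAX and PyTorch copies of the workload model, respectively.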

2 files changed: +23 −5 lines

algoperf/workloads/lm/lm_jax/nanodo_model.py

Lines changed: 10 additions & 3 deletions
@@ -25,6 +25,7 @@ class ModelConfig:
   rmsnorm_epsilon: float = 1e-6
   use_residual_scaling: bool = True
   tie_embeddings: bool = True  # Whether to tie input and output embed
+  qknorm_epsilon: float = 1e-6

   dtype: jnp.dtype = jnp.float32
   attention_init: nn.initializers.Initializer = nn.initializers.normal(stddev=0.02)
@@ -116,6 +117,7 @@ def setup(self):
     cfg = self.cfg
     assert cfg.model_dim % cfg.num_heads == 0, f'D {cfg.model_dim} not divisible by H {cfg.num_heads}'
     self.Dh = cfg.model_dim // cfg.num_heads
+    self.eps = cfg.qknorm_epsilon

     # Initialize rotary embeddings
     self.freqs_cis = init_rope(cfg.model_dim, cfg.seq_len, cfg.num_heads)
@@ -129,10 +131,13 @@ def setup(self):
         use_bias=False,
         dtype=cfg.dtype,
     )
-
     self.multilinear_query = self.multilinear(name='query')
     self.multilinear_key = self.multilinear(name='key')
     self.multilinear_value = self.multilinear(name='value')
+    # See Henry et al. (2020) "Query Key Normalization for Transformers"
+    seq_len = cfg.seq_len
+    attn_scale0 = jnp.log2(seq_len**2 - seq_len)
+    self.attn_scale = self.param('attn_scale', nn.initializers.constant(attn_scale0), ())
     self.output_projection = nn.DenseGeneral(
         features=cfg.model_dim,
         name='attn_out_proj',
@@ -153,8 +158,9 @@ def __call__(self, x_BxLxD: jax.Array):
     # Apply rotary embeddings to Q and K
     q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh, self.freqs_cis)

-    # Scale queries
-    q_BxLxHxDh /= self.Dh**0.5
+    # Apply QK normalization
+    q_BxLxHxDh /= jnp.linalg.norm(q_BxLxHxDh, axis=-1, keepdims=True) + self.eps
+    k_BxLxHxDh /= jnp.linalg.norm(k_BxLxHxDh, axis=-1, keepdims=True) + self.eps

     # Compute attention scores
     att_BxHxLxL = jnp.einsum('...qhd,...khd->...hqk', q_BxLxHxDh, k_BxLxHxDh)
@@ -166,6 +172,7 @@ def __call__(self, x_BxLxD: jax.Array):
     # Apply mask and softmax
     _NEG_INF = jnp.finfo(cfg.dtype).min
     att_BxHxLxL = jnp.where(mask_1x1xLxL, att_BxHxLxL, _NEG_INF)
+    att_BxHxLxL = self.attn_scale * att_BxHxLxL  # Learned scaling factor for QK norm
     att_BxHxLxL = jax.nn.softmax(att_BxHxLxL, axis=-1)
     att_BxHxLxL = att_BxHxLxL.astype(cfg.dtype)
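As a quick sanity check on the JAX changes (a standalone sketch with made-up shapes, not part of the commit): after normalization every query and key vector has roughly unit norm, so the pre-scale logits are bounded by 1 in magnitude, and the constant initializer gives attn_scale the starting value log2(seq_len**2 - seq_len).

import jax.numpy as jnp

L, H, Dh = 8, 2, 4                      # illustrative shapes
q = jnp.arange(float(L * H * Dh)).reshape(1, L, H, Dh) + 1.0
q = q / (jnp.linalg.norm(q, axis=-1, keepdims=True) + 1e-6)
print(jnp.linalg.norm(q, axis=-1))      # ~1.0 for every (position, head)
print(jnp.log2(L**2 - L))               # initial attn_scale, here log2(56) ~ 5.81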
algoperf/workloads/lm/lm_pytorch/plainlm_model.py

Lines changed: 13 additions & 2 deletions
@@ -23,6 +23,7 @@ class ModelConfig:
   expanded_model_dim: int
   multiple_of: int = 256
   rmsnorm_epsilon: float = 1e-6
+  qknorm_epsilon: float = 1e-6
   use_residual_scaling: bool = True
   tie_embeddings: bool = True

@@ -95,6 +96,12 @@ def __init__(self, cfg: ModelConfig):
       nn.init.normal_(w, std=0.02)
     nn.init.normal_(self.w_out.weight, std=0.02)

+    self.eps = cfg.qknorm_epsilon  # e.g., 1e-6
+    seq_len = cfg.seq_len
+    attn_scale0 = math.log2(seq_len**2 - seq_len)
+    self.attn_scale = nn.Parameter(torch.tensor(attn_scale0))
+
+
   def forward(self, x, freqs_cis):
     bsz, seqlen, d = x.shape  # (bsz, seqlen, d)

@@ -117,10 +124,14 @@ def forward(self, x, freqs_cis):
     k = k.transpose(1, 2)  # (bsz, nh, seqlen, h_dim)
     v = v.transpose(1, 2)  # (bsz, nh, seqlen, h_dim)

+    # Apply QK normalization
+    q = q / (torch.norm(q, dim=-1, keepdim=True) + self.eps)
+    k = k / (torch.norm(k, dim=-1, keepdim=True) + self.eps)
+    q *= self.attn_scale
+
     out = F.scaled_dot_product_attention(
-        q, k, v, is_causal=True
+        q, k, v, is_causal=True, scale=1.0
     )  # (bsz, nh, seqlen, h_dim)
-
     out = (
         out.transpose(1, 2).contiguous().view(bsz, seqlen, d)
     )  # (bsz, seqlen, d)
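For context on the scale=1.0 argument (assuming a PyTorch version whose F.scaled_dot_product_attention accepts the scale keyword, i.e. 2.1+): it disables SDPA's default 1/sqrt(head_dim) scaling, so the learned attn_scale folded into q is the only scaling applied to the logits. Folding the scalar into q is mathematically the same as scaling the logits directly, as this illustrative snippet (not part of the commit) checks:

import math
import torch
import torch.nn.functional as F

bsz, nh, L, hd = 1, 2, 8, 4                       # illustrative shapes
q, k, v = (torch.randn(bsz, nh, L, hd) for _ in range(3))
eps = 1e-6
attn_scale = math.log2(L**2 - L)                  # initial value of the learned parameter

# Normalize q and k, fold the scale into q, and disable SDPA's own scaling.
qn = q / (torch.norm(q, dim=-1, keepdim=True) + eps)
kn = k / (torch.norm(k, dim=-1, keepdim=True) + eps)
out_a = F.scaled_dot_product_attention(qn * attn_scale, kn, v, is_causal=True, scale=1.0)

# Equivalent formulation: leave q unscaled and pass the scalar via SDPA's `scale`.
out_b = F.scaled_dot_product_attention(qn, kn, v, is_causal=True, scale=attn_scale)
print(torch.allclose(out_a, out_b, atol=1e-5))    # True (up to floating-point error)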
