Commit 2b25adf

Avoid JAX dtype error on gpt3-6b.
PiperOrigin-RevId: 809189775
1 parent 0a364fe commit 2b25adf

1 file changed: +2 -2 lines changed

src/MaxText/optimizers.py

Lines changed: 2 additions & 2 deletions
@@ -123,8 +123,8 @@ def _update_momentum(update, mu, nu):
     # `bias_corrected_dacay` is calculated as calculating `jnp.power(decay, t)` in low
     # precision can result in it being rounded to 1 and subsequently a
     # "division by zero" error.
-    beta1_decay = bias_corrected_decay(count, beta1).astype(update)
-    beta2_decay = bias_corrected_decay(count, beta2).astype(update)
+    beta1_decay = bias_corrected_decay(count, beta1).astype(update.dtype)
+    beta2_decay = bias_corrected_decay(count, beta2).astype(update.dtype)
     mu = (1.0 - beta1_decay) * update + beta1_decay * mu
     nu = (1.0 - beta2_decay) * (update**2) + beta2_decay * nu
     return _slot_opt_state(mu=mu, nu=nu)
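
The change can be exercised in isolation with the minimal sketch below. It is not the MaxText implementation: `bias_corrected_decay_sketch` is a hypothetical stand-in for `bias_corrected_decay`, whose body does not appear in this diff, and `update_momentum_sketch`, the default `beta1`/`beta2` values, and the test tensors are likewise assumptions made for illustration. The sketch shows the two points the code comment and the fix rely on: a decay such as 0.999 rounds to 1 in bfloat16, so the decay is computed in float32 first, and the result is then cast with `.astype(update.dtype)`, passing a dtype rather than the array itself.

```python
import jax.numpy as jnp


def bias_corrected_decay_sketch(step, decay):
  """Hypothetical stand-in for `bias_corrected_decay` (assumed form, not from the diff)."""
  t = step.astype(jnp.float32) + 1.0  # do the power computation in float32
  return decay * (1.0 - jnp.power(decay, t - 1.0)) / (1.0 - jnp.power(decay, t))


def update_momentum_sketch(update, mu, nu, count, beta1=0.9, beta2=0.999):
  """Mirrors the patched lines; beta defaults are illustrative assumptions."""
  # Cast only after the float32 computation, and pass a dtype, not an array.
  beta1_decay = bias_corrected_decay_sketch(count, beta1).astype(update.dtype)
  beta2_decay = bias_corrected_decay_sketch(count, beta2).astype(update.dtype)
  mu = (1.0 - beta1_decay) * update + beta1_decay * mu
  nu = (1.0 - beta2_decay) * (update**2) + beta2_decay * nu
  return mu, nu


# In bfloat16, 0.999 is not representable and rounds to exactly 1.0, which is
# why a denominator like 1 - decay**t can become zero in low precision.
rounded_beta2 = jnp.asarray(0.999, dtype=jnp.bfloat16)

# Illustrative bfloat16 gradient and optimizer slots.
update = jnp.ones((4,), dtype=jnp.bfloat16)
mu = jnp.zeros_like(update)
nu = jnp.zeros_like(update)
count = jnp.array(0, dtype=jnp.int32)

new_mu, new_nu = update_momentum_sketch(update, mu, nu, count)
```

With these inputs the optimizer slots stay in bfloat16 (`new_mu.dtype` and `new_nu.dtype` match `update.dtype`), which is the property the cast is meant to preserve.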
