@@ -250,7 +250,7 @@ def loss_fn(params):
   if state.dynamic_scale:
     # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
     # params should be restored (= skip this step).
-    select_fn = functools.partial(jnp.where, is_fin)
+    select_fn = functools.partial(jnp.where, is_fin)  # pylint: disable=undefined-variable
     new_state = new_state.replace(
         opt_state=jax.tree_util.tree_map(
             select_fn, new_state.opt_state, state.opt_state
@@ -259,7 +259,7 @@ def loss_fn(params):
             select_fn, new_state.params, state.params
         ),
     )
-    metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"]
+    metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"]  # pylint: disable=undefined-variable
 
   return new_state, metrics
 
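The two suppressions above are presumably needed because `is_fin` and `dynamic_scale` are only bound on the path where `state.dynamic_scale` is set, so pylint cannot prove they are defined at the use sites. The pattern itself is the usual mixed-precision "skip the step on non-finite gradients" trick: `functools.partial(jnp.where, is_fin)` is mapped leaf-wise over the optimizer state and params, keeping the freshly updated values only when every gradient was finite and restoring the previous ones otherwise. A minimal standalone sketch of that selection, with made-up trees (`select_finite_update`, `old_params`, `new_params` are illustrative names, not taken from the example):

import functools

import jax
import jax.numpy as jnp


def select_finite_update(is_fin, new_tree, old_tree):
  """Keeps `new_tree` leaves when `is_fin` is True, else falls back to `old_tree`."""
  select_fn = functools.partial(jnp.where, is_fin)  # scalar bool broadcasts against each leaf
  return jax.tree_util.tree_map(select_fn, new_tree, old_tree)


# Hypothetical toy update that produced a NaN, so the step should be skipped.
old_params = {"w": jnp.ones(3), "b": jnp.zeros(3)}
new_params = {"w": jnp.full(3, jnp.nan), "b": jnp.ones(3)}
is_fin = jnp.all(
    jnp.asarray([jnp.all(jnp.isfinite(x)) for x in jax.tree_util.tree_leaves(new_params)])
)
restored = select_finite_update(is_fin, new_params, old_params)
print(restored["w"])  # [1. 1. 1.] -- the old params were restored
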
@@ -649,8 +649,8 @@ def decode_tokens(toks):
         metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
         denominator = metrics_sums.pop("denominator")
         summary = jax.tree_util.tree_map(
-            lambda x: x / denominator, metrics_sums
-        )  # pylint: disable=cell-var-from-loop
+            lambda x: x / denominator, metrics_sums  # pylint: disable=cell-var-from-loop
+        )
         summary["learning_rate"] = lr
         summary = {"train_" + k: v for k, v in summary.items()}
         writer.write_scalars(step, summary)
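The last hunk only moves the `cell-var-from-loop` suppression onto the line that actually triggers it: the `lambda x: x / denominator` closes over `denominator`, which is rebound on each pass of the surrounding training loop, but since the lambda is consumed immediately by `tree_map` the warning is a false positive. The computation it guards is a weighted average of accumulated per-step metrics. A self-contained sketch under assumed toy values (the metric names and numbers are made up):

import jax
import jax.numpy as jnp

# Hypothetical per-step metric dicts; in the example these are collected from train_step.
train_metrics = [
    {"loss": jnp.asarray(4.0), "accuracy": jnp.asarray(7.0), "denominator": jnp.asarray(10.0)},
    {"loss": jnp.asarray(9.0), "accuracy": jnp.asarray(15.0), "denominator": jnp.asarray(20.0)},
]

# Stack the list of dicts into a dict of stacked arrays, then sum each metric over steps.
stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *train_metrics)
metrics_sums = jax.tree_util.tree_map(jnp.sum, stacked)
denominator = metrics_sums.pop("denominator")  # total weight (e.g. token count)
summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums)
print(summary)  # {'accuracy': ~0.733, 'loss': ~0.433}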