
Commit c29b6d6
danielsuo authored and Flax Authors committed
[flax:examples:wmt] Small linter fixes.
PiperOrigin-RevId: 838880356
1 parent 697f4e5 · commit c29b6d6

2 files changed (+7, −6 lines)

examples/wmt/models.py (3 additions, 2 deletions)

@@ -20,8 +20,8 @@
 # pytype: disable=wrong-keyword-args
 # pytype: disable=attribute-error
 
-from typing import Any, Optional
 from collections.abc import Callable
+from typing import Any
 
 from flax import linen as nn
 from flax import struct
@@ -549,7 +549,8 @@ def decode(
 
     # Make padding attention masks.
     if config.decode:
-      # for fast autoregressive decoding only a special encoder-decoder mask is used
+      # for fast autoregressive decoding only a special encoder-decoder mask is
+      # used
       decoder_mask = None
       encoder_decoder_mask = nn.make_attention_mask(
           jnp.ones_like(targets) > 0, inputs > 0, dtype=config.dtype
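
Notes on these hunks: the import change drops the presumably unused Optional and keeps the typing import sorted after collections.abc; the second hunk only re-wraps an over-long comment. For context on what that comment describes: during fast autoregressive decoding the causal decoder mask is handled by the decode cache, so only an encoder-decoder padding mask is built. A minimal sketch of that call, with made-up token arrays (shapes and values are illustrative, not from the commit):

    import jax.numpy as jnp
    from flax import linen as nn

    inputs = jnp.array([[5, 8, 0, 0]])  # source token ids, 0 = padding
    targets = jnp.array([[7, 3, 9]])    # target token ids being decoded

    # Each query (target) position may attend to every non-padding key
    # (source) position; the mask has shape [batch, 1, q_len, kv_len].
    encoder_decoder_mask = nn.make_attention_mask(
        jnp.ones_like(targets) > 0,
        inputs > 0,
    )
    print(encoder_decoder_mask.shape)  # (1, 1, 3, 4)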

examples/wmt/train.py (4 additions, 4 deletions)

@@ -250,7 +250,7 @@ def loss_fn(params):
   if state.dynamic_scale:
     # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
     # params should be restored (= skip this step).
-    select_fn = functools.partial(jnp.where, is_fin)
+    select_fn = functools.partial(jnp.where, is_fin)  # pylint: disable=undefined-variable
     new_state = new_state.replace(
         opt_state=jax.tree_util.tree_map(
             select_fn, new_state.opt_state, state.opt_state
@@ -259,7 +259,7 @@ def loss_fn(params):
             select_fn, new_state.params, state.params
         ),
     )
-    metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"]
+    metrics["loss_scale"] = dynamic_scale.scale * metrics["denominator"]  # pylint: disable=undefined-variable
 
   return new_state, metrics
 
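
Note on the two disables above: is_fin and dynamic_scale are bound a few lines earlier in the same "if state.dynamic_scale:" branch by the dynamic-scale gradient function, so the undefined-variable warnings are false positives raised because pylint cannot see the conditional binding. A hedged sketch of the surrounding pattern, using flax.training.dynamic_scale with a hypothetical loss_fn and a simplified TrainState (an illustration of the technique, not the example's exact code):

    import functools
    import jax
    import jax.numpy as jnp
    from flax.training import dynamic_scale as dynamic_scale_lib

    def sketch_train_step(state, loss_fn):
      # value_and_grad from DynamicScale also reports whether all
      # gradients are finite at the current loss scale.
      grad_fn = dynamic_scale_lib.DynamicScale().value_and_grad(loss_fn)
      dynamic_scale, is_fin, loss, grads = grad_fn(state.params)
      new_state = state.apply_gradients(grads=grads)
      # Keep the old leaves wherever is_fin is False, i.e. skip the step;
      # functools.partial pins the predicate so tree_map sees leaf pairs.
      select_fn = functools.partial(jnp.where, is_fin)
      new_state = new_state.replace(
          opt_state=jax.tree_util.tree_map(
              select_fn, new_state.opt_state, state.opt_state
          ),
          params=jax.tree_util.tree_map(
              select_fn, new_state.params, state.params
          ),
      )
      return new_state, loss
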
@@ -649,8 +649,8 @@ def decode_tokens(toks):
       metrics_sums = jax.tree_util.tree_map(jnp.sum, train_metrics)
       denominator = metrics_sums.pop("denominator")
       summary = jax.tree_util.tree_map(
-          lambda x: x / denominator, metrics_sums
-      )  # pylint: disable=cell-var-from-loop
+          lambda x: x / denominator, metrics_sums  # pylint: disable=cell-var-from-loop
+      )
       summary["learning_rate"] = lr
       summary = {"train_" + k: v for k, v in summary.items()}
       writer.write_scalars(step, summary)
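
Note on the last hunk: it moves the cell-var-from-loop disable onto the line where pylint reports it, i.e. the lambda rather than the closing parenthesis. The warning exists because a lambda defined inside a loop closes over denominator by reference (Python closures bind late); here tree_map applies the lambda immediately, so the capture is safe and the comment merely suppresses a false positive. A small self-contained illustration with made-up metric sums (not the example's code):

    import jax
    import jax.numpy as jnp

    # Hypothetical accumulated sums; in the example they come from the
    # training loop and include a "denominator" entry that is popped off.
    metrics_sums = {"loss": jnp.array(12.0), "accuracy": jnp.array(3.0)}
    denominator = jnp.array(4.0)

    # Safe despite the late-binding pitfall: tree_map calls the lambda
    # before the enclosing loop (and hence denominator) can move on.
    summary = jax.tree_util.tree_map(lambda x: x / denominator, metrics_sums)
    print(summary)  # {'accuracy': Array(0.75, ...), 'loss': Array(3., ...)}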
