Commit 16677b3

Fixed lm1b_nnx example training script
1 parent 12a29ec commit 16677b3

5 files changed: +41 −30 lines changed

5 files changed

+41
-30
lines changed

examples/lm1b_nnx/README.md

Lines changed: 5 additions & 6 deletions
@@ -52,7 +52,7 @@ Then install Flax + the example dependencies:
 git clone --depth=1 --branch=main https://github.com/google/flax
 cd flax
 pip install -e .
-cd examples/lm1b
+cd examples/lm1b_nnx
 pip install -r requirements.txt
 ```
 
@@ -75,9 +75,9 @@ tensorboard --logdir=$HOME/logs
 You should expect to get numbers similar to these:
 
 
-Hardware | config | Training time | Loss | TensorBoard.dev | Workdir
--------- | ------- | ------------- | -------------- | ------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------
-TPU v3-8 | default | 13h18m | 3.127 | [2021-08-08](https://tensorboard.dev/experiment/n30WkNOZTJq3RHWD7wNslg/) | [gs://flax_public/examples/lm1b/default](https://console.cloud.google.com/storage/browser/flax_public/examples/lm1b/default)
+Hardware | config | Training time | Loss | Workdir
+-------- | ------- | ------------- | -------------- | --------------------------------------------------------------------------------------------------------------------------
+TPU v3-8 | default | 13h18m | 3.127 | [gs://flax_public/examples/lm1b/default](https://console.cloud.google.com/storage/browser/flax_public/examples/lm1b/default)
 
 ### Downloading the LM1B Datasets
 
@@ -87,6 +87,5 @@ data on a storage bucket, from where it can be loaded directly. Set the
 `TFDS_DATA_DIR` to your storage bucket path (`gs://<bucket name>`).
 
 You can download and prepare LM1B datasets using TFDS directly:
-`python -m tensorflow_datasets.scripts.download_and_prepare
---datasets=lm1b`
+`python -m tensorflow_datasets.scripts.download_and_prepare --datasets=lm1b`
 
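Note: the `download_and_prepare` CLI shown in the README can also be driven from Python via the TFDS builder API. A minimal sketch (the `data_dir` bucket path is illustrative, not part of this commit):

```python
import tensorflow_datasets as tfds

# Prepare LM1B once; point data_dir at the same location that TFDS_DATA_DIR
# will reference at training time (the GCS path below is illustrative).
builder = tfds.builder('lm1b', data_dir='gs://<bucket name>/tensorflow_datasets')
builder.download_and_prepare()
```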

examples/lm1b_nnx/input_pipeline_test.py

Lines changed: 4 additions & 2 deletions
@@ -48,9 +48,11 @@ def _get_datasets(self):
     vocab_path = os.path.join(tempfile.mkdtemp(), 'sentencepiece_model')
 
     # Go two directories up to the root of the flax directory.
-    flax_root_dir = pathlib.Path(__file__).parents[4]
+    try:
+      flax_root_dir = pathlib.Path(__file__).parents[4]
+    except IndexError:
+      flax_root_dir = "/"
     data_dir = str(flax_root_dir) + '/.tfds/metadata' # pylint: disable=unused-variable
-
     with tfds.testing.mock_data(num_examples=128, data_dir=data_dir):
       train_ds, eval_ds, predict_ds, _ = input_pipeline.get_datasets(
         n_devices=2, config=config, vocab_path=vocab_path
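
The `try/except IndexError` above guards against checkouts that sit fewer than five directories below the filesystem root, where `pathlib.Path.parents[4]` does not exist. A minimal sketch of the pattern in isolation (the helper name is hypothetical, not part of this commit):

```python
import pathlib

def repo_root_or_fs_root(file_path: str, levels_up: int = 4) -> pathlib.Path:
  """Return the ancestor `levels_up` directories above `file_path`.

  Path.parents[i] raises IndexError when the path has fewer ancestors
  (e.g. /tmp/x.py only has /tmp and /), so fall back to the filesystem root.
  """
  try:
    return pathlib.Path(file_path).parents[levels_up]
  except IndexError:
    return pathlib.Path('/')

print(repo_root_or_fs_root('/a/b/c/d/e/test.py'))  # -> /a
print(repo_root_or_fs_root('/tmp/test.py'))        # -> /
```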

examples/lm1b_nnx/main.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@
     'File path to the training hyperparameter configuration.',
     lock_config=True,
 )
-flags.mark_flags_as_required(['config', 'workdir'])
+flags.mark_flags_as_required(['workdir'])
 
 
 def main(argv):
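
For context, `flags.mark_flags_as_required` only controls which flags must be passed on the command line; dropping `'config'` relies on the config flag carrying a default. A minimal sketch of such a flag setup (the default config path is illustrative, not taken from this commit):

```python
from absl import flags
from ml_collections import config_flags

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file(
    'config',
    'configs/default.py',  # illustrative default; lets --config be omitted
    'File path to the training hyperparameter configuration.',
    lock_config=True,
)
# Only --workdir has no usable default, so it stays required.
flags.mark_flags_as_required(['workdir'])
```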

examples/lm1b_nnx/train.py

Lines changed: 24 additions & 19 deletions
@@ -41,7 +41,6 @@
 from jax.sharding import PartitionSpec as P
 from utils import HasCache, TrainState
 
-from flax import linen as nn
 from flax import nnx
 from flax.training import checkpoints, common_utils
 
@@ -115,7 +114,7 @@ def compute_weighted_cross_entropy(
       targets, vocab_size, on_value=confidence, off_value=low_confidence
   )
 
-  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
+  loss = -jnp.sum(soft_targets * nnx.log_softmax(logits), axis=-1)
   loss = loss - normalizing_constant
 
   normalizing_factor = np.prod(targets.shape)
@@ -191,9 +190,9 @@ def train_step(
 
   dropout_rng = jax.random.fold_in(dropout_rng, state.step)
 
-  def loss_fn(params):
+  def loss_fn(params, other_variables):
     """loss function used for training."""
-    module = nnx.merge(state.graphdef, params)
+    module = nnx.merge(state.graphdef, params, other_variables)
     module.set_attributes(deterministic=False, decode=False)
     logits = module(
       inputs,
@@ -211,7 +210,7 @@ def loss_fn(params):
   step = state.step
   lr = learning_rate_fn(step)
   grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
-  (_, logits), grads = grad_fn(state.params)
+  (_, logits), grads = grad_fn(state.params, state.other_variables)
   new_state = state.apply_gradients(grads=grads)
   metrics = compute_metrics(logits, inputs, weights)
   metrics['learning_rate'] = lr
@@ -221,14 +220,15 @@ def loss_fn(params):
 
 def eval_step(
   params: nnx.State,
+  other_variables: nnx.State,
   batch,
   graphdef: nnx.GraphDef[models.TransformerLM],
   label_smoothing=0.0,
 ):
   """Calculate evaluation metrics on a batch."""
   inputs = batch['inputs']
   weights = jnp.where(inputs > 0, 1.0, 0.0)
-  module = nnx.merge(graphdef, params)
+  module = nnx.merge(graphdef, params, other_variables)
   module.set_attributes(deterministic=True, decode=False)
   logits = module(inputs)
 
@@ -238,6 +238,7 @@ def eval_step(
 def predict_step(
   inputs,
   params: nnx.State,
+  other_variables: nnx.State,
   rngkey: jax.Array,
   graphdef: nnx.GraphDef[models.TransformerLM],
   eos_id: int,
@@ -247,20 +248,20 @@ def predict_step(
   top_k: int,
 ):
   """Predict language model on a batch."""
-  module = nnx.merge(graphdef, params)
+  module = nnx.merge(graphdef, params, other_variables)
 
   # TODO(cgarciae): check how pytorch does this.
   for _path, m in module.iter_modules():
     if isinstance(m, HasCache):
       input_shape = (inputs.shape[0], max_decode_len, config.emb_dim)
       m.init_cache(input_shape, dtype=config.dtype)
 
-  graphdef, params, cache = nnx.split(module, nnx.Param, nnx.Cache)
+  graphdef, params, cache, other_variables = nnx.split(module, nnx.Param, nnx.Cache, ...)
 
   def tokens_ids_to_logits(flat_ids, cache: nnx.State):
     """Token slice to logits from decoder model."""
     # --> [batch * beam, 1, vocab]
-    module = nnx.merge(graphdef, params, cache)
+    module = nnx.merge(graphdef, params, cache, other_variables)
     module.set_attributes(deterministic=True, decode=True)
     logits = module(flat_ids)
     cache = nnx.state(module, nnx.Cache)
@@ -313,7 +314,7 @@ def evaluate(
   eval_iter = iter(eval_ds) # pytype: disable=wrong-arg-types
   for _, eval_batch in zip(range(num_eval_steps), eval_iter):
     eval_batch = jax.tree.map(lambda x: x._numpy(), eval_batch) # pylint: disable=protected-access
-    metrics = jit_eval_step(state.params, eval_batch, state.graphdef)
+    metrics = jit_eval_step(state.params, state.other_variables, eval_batch, state.graphdef)
     eval_metrics.append(metrics)
   eval_metrics = common_utils.stack_forest(eval_metrics)
   eval_metrics_sums = jax.tree.map(jnp.sum, eval_metrics)
@@ -330,6 +331,7 @@ def generate_prediction(
   jit_pred_step,
   graphdef: nnx.GraphDef[models.TransformerLM],
   params: nnx.State,
+  other_variables: nnx.State,
   tokenized_prompts,
   eos_id,
   inference_rng,
@@ -359,6 +361,7 @@ def generate_prediction(
     predicted = jit_pred_step(
       pred_batch,
       params,
+      other_variables,
       inference_rngs,
       graphdef,
       eos_id,
@@ -389,6 +392,7 @@ def train_and_evaluate(config: default.Config, workdir: str):
     workdir: Working directory for checkpoints and TF summaries. If this
       contains checkpoint training will be resumed from the latest checkpoint.
   """
+  workdir = os.path.abspath(workdir)
   tf.io.gfile.makedirs(workdir)
 
   vocab_path = config.vocab_path
@@ -440,18 +444,15 @@ def encode_strings(strs, max_len):
     max_len=max(config.max_target_length, config.max_eval_target_length),
     dropout_rate=config.dropout_rate,
     attention_dropout_rate=config.attention_dropout_rate,
-    kernel_init=nn.initializers.xavier_uniform(),
-    bias_init=nn.initializers.normal(stddev=1e-6),
+    kernel_init=nnx.initializers.xavier_uniform(),
+    bias_init=nnx.initializers.normal(stddev=1e-6),
     axis_rules=config.axis_rules,
   )
 
   # Mesh definition
   devices_array = utils.create_device_mesh(config)
   mesh = Mesh(devices_array, config.mesh_axes)
 
-  # print(mesh.shape)
-  # exit()
-
   start_step = 0
   rng = jax.random.PRNGKey(config.seed)
   rng, init_rng = jax.random.split(rng)
498499
None,
499500
), # type: ignore
500501
out_shardings=(state_sharding, None), # type: ignore
501-
static_argnums=(2, 3),
502+
static_argnames=("learning_rate_fn", "label_smoothing"),
502503
donate_argnums=0,
503504
)
504505

505506
jit_eval_step = jax.jit(
506507
eval_step,
507508
in_shardings=(
508509
state_sharding.params,
510+
state_sharding.other_variables,
509511
data_sharding,
510512
), # type: ignore
511513
out_shardings=None, # type: ignore
512-
static_argnums=(2, 3),
514+
static_argnames=("graphdef", "label_smoothing"),
513515
)
514516

515517
# Since the inputs and rngkey args for predict_step will be batched,
@@ -520,6 +522,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
520522
in_axes=(
521523
0,
522524
jax.tree.map(lambda x: None, state.params),
525+
jax.tree.map(lambda x: None, state.other_variables),
523526
0,
524527
None,
525528
None,
@@ -532,10 +535,11 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
532535
in_shardings=(
533536
data_sharding,
534537
state_sharding.params,
538+
state_sharding.other_variables,
535539
data_sharding,
536540
), # type: ignore
537541
out_shardings=data_sharding, # type: ignore
538-
static_argnums=tuple(range(3, 9)),
542+
static_argnums=tuple(range(4, 10)),
539543
)
540544

541545
# Main Train Loop
@@ -575,7 +579,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
         h(step)
 
       # Periodic metric handling.
-      if step % config.eval_every_steps == 0 or is_last_step:
+      if (step > 0 and step % config.eval_every_steps == 0) or is_last_step:
        with report_progress.timed('training_metrics'):
          logging.info('Gathering training metrics.')
          train_metrics = common_utils.stack_forest(train_metrics)
@@ -609,6 +613,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
           jit_pred_step=jit_pred_step,
           graphdef=state.graphdef,
           params=state.params,
+          other_variables=state.other_variables,
           tokenized_prompts=tokenized_prompts,
           eos_id=eos_id,
           inference_rng=inference_rng,
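
Most of the changes in this file thread a new `other_variables` state through the train, eval, and predict steps. They rely on `flax.nnx`'s filter-based split/merge: `nnx.split` returns one `State` per filter, with `...` as a catch-all for anything not matched by an earlier filter, and `nnx.merge` reassembles the module from the graphdef plus those states. A minimal standalone sketch (toy module, recent `flax.nnx` assumed; not code from this commit):

```python
import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 4, rngs=rngs)
    self.bn = nnx.BatchNorm(4, use_running_average=True, rngs=rngs)

  def __call__(self, x):
    return self.bn(self.linear(x))

model = Block(nnx.Rngs(0))
# One state per filter; `...` catches everything that is not a Param
# (e.g. BatchNorm statistics), mirroring `other_variables` in train.py.
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)
# Merging the same pieces back yields an equivalent, callable module.
model2 = nnx.merge(graphdef, params, other_variables)
y = model2(jnp.ones((1, 4)))
```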

examples/lm1b_nnx/utils.py

Lines changed: 7 additions & 2 deletions
@@ -35,6 +35,7 @@
 
 class TrainState(train_state.TrainState):
   graphdef: nnx.GraphDef[TransformerLM]
+  other_variables: nnx.State
 
 
 @runtime_checkable
@@ -157,9 +158,13 @@ def setup_initial_state(
 
   with mesh:
     model = constructor(config, rng)
-    graphdef, params = nnx.split(model, nnx.Param)
+    graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)
     state = TrainState.create(
-      apply_fn=graphdef.apply, params=params, tx=tx, graphdef=graphdef
+      apply_fn=graphdef.apply,
+      params=params,
+      other_variables=other_variables,
+      tx=tx,
+      graphdef=graphdef,
     )
     state = jax.tree.map(_to_array, state)
     state_spec = nnx.get_partition_spec(state)
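
The extra `other_variables` field works because `flax.training.train_state.TrainState` is a struct dataclass: subclasses can add pytree fields, and `TrainState.create` forwards unknown keyword arguments to the constructor, so the new state travels through `apply_gradients` and checkpointing alongside `step`, `params`, and `opt_state`. A minimal sketch with toy contents (values illustrative, not code from this commit):

```python
from typing import Any

import jax.numpy as jnp
import optax
from flax.training import train_state

class TrainState(train_state.TrainState):
  # Extra pytree leaf carried alongside step, params, and opt_state.
  other_variables: Any = None

params = {'w': jnp.zeros((2, 2))}
state = TrainState.create(
    apply_fn=lambda p, x: x @ p['w'],
    params=params,
    other_variables={'stats': jnp.zeros(())},  # toy non-param state
    tx=optax.sgd(1e-3),
)
state = state.apply_gradients(grads={'w': jnp.ones((2, 2))})
# other_variables is untouched by the optimizer update.
```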
