Skip to content

Commit 12431d7

Browse files
danielsuoFlax Authors
authored andcommitted
[jax:benchmarks] Add tracing/lowering benchmarks for a few flax examples.
PiperOrigin-RevId: 799217137
1 parent 771eadb commit 12431d7

File tree

8 files changed

+392
-281
lines changed

8 files changed

+392
-281
lines changed

examples/gemma/input_pipeline.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,11 @@
1515
"""Input pipeline for a LM1B dataset."""
1616

1717
import os
18-
import typing
18+
from typing import Any
1919

2020
import tensorflow as tf
2121
import tensorflow_datasets as tfds
2222
import tokenizer
23-
from clu import deterministic_data
24-
25-
if typing.TYPE_CHECKING:
26-
from train import TrainConfig
2723

2824
AUTOTUNE = tf.data.experimental.AUTOTUNE
2925
Features = dict[str, tf.Tensor]
@@ -324,7 +320,7 @@ def filter_fn(x):
324320

325321

326322
def get_datasets(
327-
config: "TrainConfig",
323+
config: Any,
328324
*,
329325
n_devices: int,
330326
vocab_path: str | None = None,

examples/gemma/main.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,24 @@
1818
that can be easily tested and imported in Colab.
1919
"""
2020

21-
import jax
22-
import tensorflow as tf
23-
import train
24-
from absl import app, flags, logging
21+
from absl import app
22+
from absl import flags
23+
from absl import logging
2524
from clu import platform
25+
import train
26+
import jax
2627
from ml_collections import config_flags
28+
import tensorflow as tf
29+
2730

2831
FLAGS = flags.FLAGS
2932

3033
flags.DEFINE_string('workdir', None, 'Directory to store model data.')
3134
config_flags.DEFINE_config_file(
32-
'config',
33-
'configs/default.py',
34-
'File path to the training hyperparameter configuration.',
35-
lock_config=True,
35+
'config',
36+
'configs/default.py',
37+
'File path to the training hyperparameter configuration.',
38+
lock_config=True,
3639
)
3740
flags.mark_flags_as_required(['workdir'])
3841

@@ -51,11 +54,11 @@ def main(argv):
5154
# Add a note so that we can tell which task is which JAX host.
5255
# (Depending on the platform task 0 is not guaranteed to be host 0)
5356
platform.work_unit().set_task_status(
54-
f'process_index: {jax.process_index()}, '
55-
f'process_count: {jax.process_count()}'
57+
f'process_index: {jax.process_index()}, '
58+
f'process_count: {jax.process_count()}'
5659
)
5760
platform.work_unit().create_artifact(
58-
platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
61+
platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
5962
)
6063

6164
train.train_and_evaluate(FLAGS.config, FLAGS.workdir)

examples/gemma/tokenizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
import tensorflow as tf
2626
import tensorflow_text as tftxt
2727
from absl import logging
28-
from sentencepiece import SentencePieceTrainer, SentencePieceProcessor
28+
from sentencepiece import SentencePieceProcessor
29+
from sentencepiece import SentencePieceTrainer
2930

3031
Features = dict[str, tf.Tensor]
3132

@@ -190,5 +191,5 @@ def __call__(self, features: Features) -> Features:
190191

191192
def load_sentencepiece_processor(vocab_path: str):
192193
spp = SentencePieceProcessor()
193-
spp.load(vocab_path)
194+
spp.Load(vocab_path)
194195
return spp

0 commit comments

Comments
 (0)