examples/gemma/input_pipeline.py (2 additions, 6 deletions)
@@ -15,15 +15,11 @@
 """Input pipeline for a LM1B dataset."""
 
 import os
-import typing
-
+from typing import Any
 import tensorflow as tf
 import tensorflow_datasets as tfds
 import tokenizer
 from clu import deterministic_data
 
-if typing.TYPE_CHECKING:
-  from train import TrainConfig
-
 AUTOTUNE = tf.data.experimental.AUTOTUNE
 Features = dict[str, tf.Tensor]
@@ -324,7 +320,7 @@ def filter_fn(x):
 
 
 def get_datasets(
-    config: "TrainConfig",
+    config: Any,
     *,
     n_devices: int,
     vocab_path: str | None = None,
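Note on the annotation change above: the deleted typing.TYPE_CHECKING block is the usual way to keep a precise annotation on config without importing train at runtime (which would create a circular import between train.py and input_pipeline.py), and the replacement trades that precision for a dependency-free Any. A minimal sketch of the two patterns, with hypothetical function names so they can sit side by side:

# Before: the import runs only under static type checkers, so the
# annotation must be the string forward reference 'TrainConfig'.
import typing

if typing.TYPE_CHECKING:
  from train import TrainConfig

def get_datasets_before(config: 'TrainConfig'):
  ...

# After: no import of train at all; any config-like object is accepted,
# at the cost of losing static checking on config's attributes.
from typing import Any

def get_datasets_after(config: Any):
  ...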
examples/gemma/main.py (14 additions, 11 deletions)
@@ -18,21 +18,24 @@
 that can be easily tested and imported in Colab.
 """
 
-import jax
-import tensorflow as tf
-import train
-from absl import app, flags, logging
+from absl import app
+from absl import flags
+from absl import logging
 from clu import platform
+import train
+import jax
 from ml_collections import config_flags
+import tensorflow as tf
+
 
 FLAGS = flags.FLAGS
 
 flags.DEFINE_string('workdir', None, 'Directory to store model data.')
 config_flags.DEFINE_config_file(
-  'config',
-  'configs/default.py',
-  'File path to the training hyperparameter configuration.',
-  lock_config=True,
+    'config',
+    'configs/default.py',
+    'File path to the training hyperparameter configuration.',
+    lock_config=True,
 )
 flags.mark_flags_as_required(['workdir'])
 
@@ -51,11 +54,11 @@ def main(argv):
   # Add a note so that we can tell which task is which JAX host.
   # (Depending on the platform task 0 is not guaranteed to be host 0)
   platform.work_unit().set_task_status(
-    f'process_index: {jax.process_index()}, '
-    f'process_count: {jax.process_count()}'
+      f'process_index: {jax.process_index()}, '
+      f'process_count: {jax.process_count()}'
   )
   platform.work_unit().create_artifact(
-    platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
+      platform.ArtifactType.DIRECTORY, FLAGS.workdir, 'workdir'
   )
 
   train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
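As a usage note for the flag block above, here is a minimal sketch of how a script wired this way is typically driven. Flag names and defaults are taken from the diff; the learning_rate override is a hypothetical example of ml_collections' --config.<field> syntax, assuming configs/default.py defines get_config():

# Run with the required workdir flag and, optionally, a config override:
#   python main.py --workdir=/tmp/gemma --config=configs/default.py \
#       --config.learning_rate=0.0003
from absl import app
from absl import flags
from ml_collections import config_flags

FLAGS = flags.FLAGS

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file('config', 'configs/default.py', 'Config file.')

def main(argv):
  # FLAGS.config is the ml_collections.ConfigDict returned by get_config().
  print(FLAGS.workdir, FLAGS.config)

if __name__ == '__main__':
  app.run(main)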
examples/gemma/tokenizer.py (3 additions, 2 deletions)
@@ -25,7 +25,8 @@
 import tensorflow as tf
 import tensorflow_text as tftxt
 from absl import logging
-from sentencepiece import SentencePieceTrainer, SentencePieceProcessor
+from sentencepiece import SentencePieceProcessor
+from sentencepiece import SentencePieceTrainer
 
 Features = dict[str, tf.Tensor]
 
@@ -190,5 +191,5 @@ def __call__(self, features: Features) -> Features:
 
 def load_sentencepiece_processor(vocab_path: str):
   spp = SentencePieceProcessor()
-  spp.load(vocab_path)
+  spp.Load(vocab_path)
   return spp
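The spp.load to spp.Load fix in the last hunk reflects that SentencePieceProcessor's canonical API is the CamelCase set inherited from the C++ bindings (Load, EncodeAsIds, DecodeIds, and so on); the lowercase snake_case aliases were added to the Python wrapper later and are not guaranteed on older sentencepiece releases, so Load is the portable call. A quick sketch with a hypothetical model path:

from sentencepiece import SentencePieceProcessor

spp = SentencePieceProcessor()
spp.Load('/tmp/tokenizer.model')      # hypothetical path to a trained model
ids = spp.EncodeAsIds('hello world')  # text -> token ids
print(spp.DecodeIds(ids))             # token ids -> text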