Skip to content

Commit 9102894

Browse files
Ahmed Khaled and copybara-github
authored and committed
internal
PiperOrigin-RevId: 876603098
1 parent 57b228a commit 9102894

2 files changed

Lines changed: 489 additions & 4 deletions

File tree

init2winit/dataset_lib/ogbg_molpcba.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
import itertools
4040
from absl import logging
4141
from init2winit.dataset_lib import data_utils
42-
from init2winit.dataset_lib.data_utils import Dataset
4342
import jax
4443
import jraph
4544
from ml_collections.config_dict import config_dict
@@ -513,7 +512,7 @@ def _get_dynamic_batch_iterator(
513512
add_virtual_node,
514513
num_shards=None,
515514
):
516-
"""Turns a TFDS per-example iterator into a batched iterator in the init2winit format.
515+
"""Turns a TFDS per-example iterator into a batched iterator.
517516
518517
Constructs the batch from num_shards smaller batches, so that we can easily
519518
shard the batch to multiple devices during training. We use
@@ -535,7 +534,6 @@ def _get_dynamic_batch_iterator(
535534
536535
Yields:
537536
Batch in the init2winit format.
538-
539537
"""
540538
if not num_shards:
541539
num_shards = jax.local_device_count()
@@ -678,7 +676,9 @@ def _eval_epoch(ds, num_batches=None):
678676
valid_epoch = functools.partial(_eval_epoch, valid_ds)
679677
test_epoch = functools.partial(_eval_epoch, test_ds)
680678

681-
return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
679+
return data_utils.Dataset(
680+
train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch
681+
)
682682

683683

684684
def get_fake_batch(hps):

0 commit comments

Comments (0)