init2winit/dataset_lib/ogbg_molpcba.py (8 changes: 4 additions & 4 deletions)
@@ -39,7 +39,6 @@
 import itertools
 from absl import logging
 from init2winit.dataset_lib import data_utils
-from init2winit.dataset_lib.data_utils import Dataset
 import jax
 import jraph
 from ml_collections.config_dict import config_dict
@@ -513,7 +512,7 @@ def _get_dynamic_batch_iterator(
     add_virtual_node,
     num_shards=None,
 ):
-  """Turns a TFDS per-example iterator into a batched iterator in the init2winit format.
+  """Turns a TFDS per-example iterator into a batched iterator.
 
   Constructs the batch from num_shards smaller batches, so that we can easily
   shard the batch to multiple devices during training. We use
@@ -535,7 +534,6 @@ def _get_dynamic_batch_iterator(
 
   Yields:
     Batch in the init2winit format.
-
   """
   if not num_shards:
     num_shards = jax.local_device_count()
@@ -678,7 +676,9 @@ def _eval_epoch(ds, num_batches=None):
   valid_epoch = functools.partial(_eval_epoch, valid_ds)
   test_epoch = functools.partial(_eval_epoch, test_ds)
 
-  return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
+  return data_utils.Dataset(
+      train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch
+  )
 
 
 def get_fake_batch(hps):
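
Note on the docstring hunks: _get_dynamic_batch_iterator builds each training batch out of num_shards smaller jraph batches so the result can be sharded across local devices, with num_shards defaulting to jax.local_device_count(). Below is a minimal sketch of that pattern, not the init2winit implementation; per_example_iterator, max_n_nodes, max_n_edges, and batch_size are hypothetical names, and the real function additionally converts each group of shards into the init2winit batch format before yielding it.

# A minimal sketch (not the init2winit code) of the sharded dynamic-batching
# pattern described in the docstring above.
import itertools

import jax
import jraph


def dynamic_batch_iterator_sketch(
    per_example_iterator, max_n_nodes, max_n_edges, batch_size, num_shards=None
):
  """Yields lists of num_shards jraph batches, one batch per local device."""
  if not num_shards:
    num_shards = jax.local_device_count()
  # jraph.dynamically_batch packs whole graphs into padded batches that never
  # exceed the given node/edge/graph budgets.
  batched = jraph.dynamically_batch(
      per_example_iterator,
      n_node=max_n_nodes,
      n_edge=max_n_edges,
      n_graph=batch_size // num_shards,
  )
  while True:
    # One smaller batch per shard; training code can then stack these along a
    # leading num_shards axis before handing them to jax.pmap.
    shards = list(itertools.islice(batched, num_shards))
    if len(shards) < num_shards:
      return
    yield shards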
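
The last hunk replaces the directly imported Dataset with the module-qualified data_utils.Dataset, keeping a single data_utils import and making the type's origin visible at the return site. The sketch below only illustrates what such a container might look like; the field names are assumptions inferred from the positional arguments in the diff, not taken from data_utils.py.

# Assumed shape of data_utils.Dataset: a plain namedtuple of callables.
# Field names are guesses based on the call site in the diff.
import collections

Dataset = collections.namedtuple(
    'Dataset',
    ['train_iterator_fn', 'eval_train_epoch', 'valid_epoch', 'test_epoch'],
)

# Hypothetical usage by a trainer:
#   dataset = get_ogbg_molpcba(...)
#   for batch in dataset.train_iterator_fn():
#     ...
#   metrics = dataset.valid_epoch(num_batches)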