Skip to content

Commit 413e0a2

Browse files
Ahmed Khaled and copybara-github
authored and committed
internal
PiperOrigin-RevId: 876604747
1 parent 9102894 commit 413e0a2

1 file changed

Lines changed: 36 additions & 15 deletions

File tree

init2winit/dataset_lib/ogbg_molpcba.py

Lines changed: 36 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737

3838
import functools
3939
import itertools
40+
import resource
4041
from absl import logging
4142
from init2winit.dataset_lib import data_utils
4243
import jax
@@ -612,18 +613,35 @@ def get_ogbg_molpcba(shuffle_rng, batch_size, eval_batch_size, hps=None):
612613

613614
shuffle_buffer_size = 2**15
614615
shuffle_rng_train, shuffle_rng_eval_train = jax.random.split(shuffle_rng)
616+
def _log_mem(label):
617+
rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
618+
logging.info('[ogbg] %s — RSS: %.1f MB', label, rss_mb)
619+
620+
_log_mem('Before loading any splits')
621+
615622
train_ds = _load_dataset(
616623
'train',
617624
should_shuffle=True,
618625
shuffle_seed=shuffle_rng_train,
619626
shuffle_buffer_size=shuffle_buffer_size)
620-
eval_train_ds = _load_dataset(
621-
'train',
622-
should_shuffle=True,
623-
shuffle_seed=shuffle_rng_eval_train,
624-
shuffle_buffer_size=shuffle_buffer_size)
627+
_log_mem('After loading train split')
628+
eval_train_size = min(hps.valid_size, len(train_ds))
629+
# Use a random subset of the training data for eval_train.
630+
# This data is already loaded into memory, so this is cheap.
631+
# We just access it and wrap it in an _InMemoryDataset.
632+
eval_train_rng = np.random.default_rng(int(shuffle_rng_eval_train[0]))
633+
subset_indices = eval_train_rng.choice(
634+
len(train_ds), size=eval_train_size, replace=False
635+
)
636+
eval_train_data = [train_ds.data[i] for i in subset_indices]
637+
eval_train_ds = _InMemoryDataset(eval_train_data, should_shuffle=False)
638+
_log_mem('After creating eval_train subset')
639+
625640
valid_ds = _load_dataset('validation')
641+
_log_mem('After loading validation split')
642+
626643
test_ds = _load_dataset('test')
644+
_log_mem('After loading test split')
627645

628646
max_nodes_multiplier = hps.batch_nodes_multiplier * hps.avg_nodes_per_graph
629647
max_edges_multiplier = hps.batch_edges_multiplier * hps.avg_edges_per_graph
@@ -664,18 +682,21 @@ def train_iterator_fn():
664682
dataset_iter=iter(train_ds), batch_size=per_host_batch_size
665683
)
666684

667-
def _eval_epoch(ds, num_batches=None):
668-
return itertools.islice(
669-
eval_iterator_from_ds(
670-
dataset_iter=iter(ds), batch_size=per_host_eval_batch_size
671-
),
672-
num_batches,
673-
)
685+
def _make_eval_epoch(ds):
686+
687+
def epoch(num_batches=None):
688+
return itertools.islice(
689+
eval_iterator_from_ds(
690+
dataset_iter=iter(ds), batch_size=per_host_eval_batch_size
691+
),
692+
num_batches,
693+
)
674694

675-
eval_train_epoch = functools.partial(_eval_epoch, eval_train_ds)
676-
valid_epoch = functools.partial(_eval_epoch, valid_ds)
677-
test_epoch = functools.partial(_eval_epoch, test_ds)
695+
return data_utils.CachedIteratorFactory(epoch(), split_name='eval')
678696

697+
eval_train_epoch = _make_eval_epoch(eval_train_ds)
698+
valid_epoch = _make_eval_epoch(valid_ds)
699+
test_epoch = _make_eval_epoch(test_ds)
679700
return data_utils.Dataset(
680701
train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch
681702
)

0 commit comments

Comments (0)