51 changes: 36 additions & 15 deletions init2winit/dataset_lib/ogbg_molpcba.py
@@ -37,6 +37,7 @@
 
 import functools
 import itertools
+import resource
 from absl import logging
 from init2winit.dataset_lib import data_utils
 import jax
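
Note on the new resource import: resource.getrusage(resource.RUSAGE_SELF).ru_maxrss reports kilobytes on Linux but bytes on macOS, so the division by 1024 in the _log_mem helper below yields megabytes only on Linux. A minimal portable sketch; the peak_rss_mb helper name is hypothetical and not part of this PR:

    import resource
    import sys

    def peak_rss_mb():
      # ru_maxrss is in kilobytes on Linux but in bytes on macOS (Darwin),
      # so normalize both platforms to megabytes.
      rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
      return rss / (1024 * 1024) if sys.platform == 'darwin' else rss / 1024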
@@ -612,18 +613,35 @@ def get_ogbg_molpcba(shuffle_rng, batch_size, eval_batch_size, hps=None):
 
   shuffle_buffer_size = 2**15
   shuffle_rng_train, shuffle_rng_eval_train = jax.random.split(shuffle_rng)
+  def _log_mem(label):
+    rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
+    logging.info('[ogbg] %s — RSS: %.1f MB', label, rss_mb)
+
+  _log_mem('Before loading any splits')
+
   train_ds = _load_dataset(
       'train',
       should_shuffle=True,
       shuffle_seed=shuffle_rng_train,
       shuffle_buffer_size=shuffle_buffer_size)
-  eval_train_ds = _load_dataset(
-      'train',
-      should_shuffle=True,
-      shuffle_seed=shuffle_rng_eval_train,
-      shuffle_buffer_size=shuffle_buffer_size)
+  _log_mem('After loading train split')
+  eval_train_size = min(hps.valid_size, len(train_ds))
+  # Use a random subset of the training data for eval_train. The training
+  # data is already resident in memory, so taking the subset is cheap: we
+  # index into it and wrap the result in an _InMemoryDataset.
+  eval_train_rng = np.random.default_rng(int(shuffle_rng_eval_train[0]))
+  subset_indices = eval_train_rng.choice(
+      len(train_ds), size=eval_train_size, replace=False
+  )
+  eval_train_data = [train_ds.data[i] for i in subset_indices]
+  eval_train_ds = _InMemoryDataset(eval_train_data, should_shuffle=False)
+  _log_mem('After creating eval_train subset')
 
   valid_ds = _load_dataset('validation')
+  _log_mem('After loading validation split')
 
   test_ds = _load_dataset('test')
+  _log_mem('After loading test split')
 
   max_nodes_multiplier = hps.batch_nodes_multiplier * hps.avg_nodes_per_graph
   max_edges_multiplier = hps.batch_edges_multiplier * hps.avg_edges_per_graph
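
The subset draw is deterministic: the NumPy generator is seeded from the first word of the JAX PRNG key (shuffle_rng_eval_train), so reruns with the same shuffle_rng select the same eval_train graphs. A self-contained sketch of the same pattern, with a plain list standing in for train_ds.data and toy sizes standing in for hps.valid_size:

    import jax
    import numpy as np

    key = jax.random.PRNGKey(0)
    _, eval_key = jax.random.split(key)  # mirrors shuffle_rng_eval_train

    data = list(range(100))           # stands in for train_ds.data
    subset_size = min(10, len(data))  # min(hps.valid_size, len(train_ds))

    # Seed NumPy from the raw key material, as the PR does.
    rng = np.random.default_rng(int(eval_key[0]))
    indices = rng.choice(len(data), size=subset_size, replace=False)
    subset = [data[i] for i in indices]  # distinct items, stable across runs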
@@ -664,18 +682,21 @@ def train_iterator_fn():
         dataset_iter=iter(train_ds), batch_size=per_host_batch_size
     )
 
-  def _eval_epoch(ds, num_batches=None):
-    return itertools.islice(
-        eval_iterator_from_ds(
-            dataset_iter=iter(ds), batch_size=per_host_eval_batch_size
-        ),
-        num_batches,
-    )
+  def _make_eval_epoch(ds):
+
+    def epoch(num_batches=None):
+      return itertools.islice(
+          eval_iterator_from_ds(
+              dataset_iter=iter(ds), batch_size=per_host_eval_batch_size
+          ),
+          num_batches,
+      )
+
+    return data_utils.CachedIteratorFactory(epoch, split_name='eval')
 
-  eval_train_epoch = functools.partial(_eval_epoch, eval_train_ds)
-  valid_epoch = functools.partial(_eval_epoch, valid_ds)
-  test_epoch = functools.partial(_eval_epoch, test_ds)
+  eval_train_epoch = _make_eval_epoch(eval_train_ds)
+  valid_epoch = _make_eval_epoch(valid_ds)
+  test_epoch = _make_eval_epoch(test_ds)
   return data_utils.Dataset(
       train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch
   )
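
The refactor routes every eval split through data_utils.CachedIteratorFactory instead of a bare functools.partial, so the epoch closure is built once per split and the factory can reuse work across evaluations. The class itself is outside this diff; below is a rough sketch of the caching idea it suggests, under the assumption that the factory caches one full epoch of batches on first use. The real implementation in data_utils may differ:

    import itertools

    class CachedIteratorFactory:
      # Wraps an iterator-producing callable; caches its batches on the
      # first call and replays slices of the cache on later calls.

      def __init__(self, iterator_fn, split_name=None):
        self._iterator_fn = iterator_fn
        self._split_name = split_name
        self._cache = None

      def __call__(self, num_batches=None):
        if self._cache is None:
          self._cache = list(self._iterator_fn())  # materialize one full epoch
        return itertools.islice(self._cache, num_batches)

Under that reading, valid_epoch(num_batches=k) pays the batching cost once and later evaluations replay cached batches.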