|
37 | 37 |
|
38 | 38 | import functools |
39 | 39 | import itertools |
| 40 | +import resource |
40 | 41 | from absl import logging |
41 | 42 | from init2winit.dataset_lib import data_utils |
42 | 43 | import jax |
@@ -612,18 +613,35 @@ def get_ogbg_molpcba(shuffle_rng, batch_size, eval_batch_size, hps=None): |
612 | 613 |
|
613 | 614 | shuffle_buffer_size = 2**15 |
614 | 615 | shuffle_rng_train, shuffle_rng_eval_train = jax.random.split(shuffle_rng) |
| 616 | + def _log_mem(label): |
| 617 | + rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 |
| 618 | + logging.info('[ogbg] %s — RSS: %.1f MB', label, rss_mb) |
| 619 | + |
_log_mem('Before loading any splits')

train_ds = _load_dataset(
    'train',
    should_shuffle=True,
    shuffle_seed=shuffle_rng_train,
    shuffle_buffer_size=shuffle_buffer_size)
_log_mem('After loading train split')
# eval_train is built as a fixed random subset of the (already in-memory)
# train split, capped at hps.valid_size examples, rather than loading a
# second full copy of the train data — this avoids doubling host RSS.
# NOTE(review): the subset is drawn once and never reshuffled, unlike the
# shuffled eval_train stream it replaces — confirm downstream eval metrics
# tolerate a static subset.
eval_train_size = min(hps.valid_size, len(train_ds))
# Use a random subset of the training data for eval_train.
# This data is already loaded into memory, so this is cheap.
# We just access it and wrap it in an _InMemoryDataset.
# NOTE(review): seeding NumPy from the first word of a jax PRNG key is ad
# hoc — confirm this yields the intended independence from
# shuffle_rng_train.
eval_train_rng = np.random.default_rng(int(shuffle_rng_eval_train[0]))
subset_indices = eval_train_rng.choice(
    len(train_ds), size=eval_train_size, replace=False
)
eval_train_data = [train_ds.data[i] for i in subset_indices]
eval_train_ds = _InMemoryDataset(eval_train_data, should_shuffle=False)
_log_mem('After creating eval_train subset')

valid_ds = _load_dataset('validation')
_log_mem('After loading validation split')

test_ds = _load_dataset('test')
_log_mem('After loading test split')
627 | 645 |
|
628 | 646 | max_nodes_multiplier = hps.batch_nodes_multiplier * hps.avg_nodes_per_graph |
629 | 647 | max_edges_multiplier = hps.batch_edges_multiplier * hps.avg_edges_per_graph |
@@ -664,18 +682,21 @@ def train_iterator_fn(): |
664 | 682 | dataset_iter=iter(train_ds), batch_size=per_host_batch_size |
665 | 683 | ) |
666 | 684 |
|
667 | | - def _eval_epoch(ds, num_batches=None): |
668 | | - return itertools.islice( |
669 | | - eval_iterator_from_ds( |
670 | | - dataset_iter=iter(ds), batch_size=per_host_eval_batch_size |
671 | | - ), |
672 | | - num_batches, |
673 | | - ) |
def _make_eval_epoch(ds):
  """Builds a cached evaluation-epoch factory over dataset `ds`."""

  def _epoch_batches(num_batches=None):
    # One full evaluation pass over `ds`, truncated to `num_batches`
    # batches when a limit is given (islice with None takes everything).
    batch_iter = eval_iterator_from_ds(
        dataset_iter=iter(ds), batch_size=per_host_eval_batch_size
    )
    return itertools.islice(batch_iter, num_batches)

  # NOTE(review): the *result* of _epoch_batches() — a concrete, one-shot
  # iterator — is handed to CachedIteratorFactory, not the callable itself.
  # Confirm CachedIteratorFactory expects an iterator it can cache/replay.
  return data_utils.CachedIteratorFactory(_epoch_batches(), split_name='eval')
678 | 696 |
|
# NOTE(review): eval_train_epoch / valid_epoch / test_epoch are now
# CachedIteratorFactory objects rather than callables taking num_batches —
# confirm every consumer of data_utils.Dataset handles the new type.
eval_train_epoch = _make_eval_epoch(eval_train_ds)
valid_epoch = _make_eval_epoch(valid_ds)
test_epoch = _make_eval_epoch(test_ds)
return data_utils.Dataset(
    train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch
)
|
0 commit comments