diff --git a/init2winit/dataset_lib/ogbg_molpcba.py b/init2winit/dataset_lib/ogbg_molpcba.py
index 79578273..176a60b4 100644
--- a/init2winit/dataset_lib/ogbg_molpcba.py
+++ b/init2winit/dataset_lib/ogbg_molpcba.py
@@ -37,7 +37,7 @@
 
 import functools
 import itertools
-
+from absl import logging
 from init2winit.dataset_lib import data_utils
 from init2winit.dataset_lib.data_utils import Dataset
 import jax
@@ -75,6 +75,50 @@
 }
 
 
+class _InMemoryDataset:
+  """In-memory dataset that supports shuffling and repeating.
+
+  With shuffling enabled, iteration never terminates: every pass over the
+  data uses a fresh permutation drawn from a generator seeded with
+  `shuffle_seed`. Without shuffling, iteration is one in-order pass.
+  """
+
+  def __init__(self, data, should_shuffle=False, shuffle_seed=None):
+    """Stores the examples and the shuffling configuration.
+
+    Args:
+      data: sequence of examples supporting len() and integer indexing.
+      should_shuffle: whether to shuffle and repeat the data indefinitely.
+      shuffle_seed: seed for the permutation generator, convertible to int.
+        If None, the generator is seeded from fresh entropy and the order is
+        not reproducible.
+    """
+    self._data = data
+    self._should_shuffle = should_shuffle
+    self._shuffle_seed = shuffle_seed
+
+  def __iter__(self):
+    if not self._should_shuffle:
+      yield from self._data
+      return
+
+    num_examples = len(self._data)
+    if num_examples == 0:
+      # An empty permutation would make the repeat loop below spin forever
+      # without yielding, hanging the consumer.
+      return
+
+    # int(None) raises TypeError, so only convert an actual seed.
+    seed = None if self._shuffle_seed is None else int(self._shuffle_seed)
+    rng = np.random.default_rng(seed)
+    while True:
+      for i in rng.permutation(num_examples):
+        yield self._data[i]
+
+  def __len__(self):
+    return len(self._data)
+
+
 def _load_dataset(split,
                   should_shuffle=False,
                   shuffle_seed=None,
@@ -96,19 +140,19 @@ def _load_dataset(split,
       split=split,
       shuffle_files=should_shuffle,
       read_config=read_config)
+  logging.info('Loading in memory dataset...')
+  dataset = list(tfds.as_numpy(dataset))
 
-  if should_shuffle:
-    dataset = dataset.shuffle(
-        seed=dataset_shuffle_seed, buffer_size=shuffle_buffer_size)
-    dataset = dataset.repeat()
-
-  return dataset
+  return _InMemoryDataset(dataset, should_shuffle, dataset_shuffle_seed)
 
 
 def _to_jraph(example, add_bidirectional_edges, add_virtual_node,
               add_self_loops):
   """Converts an example graph to jraph.GraphsTuple."""
-  example = data_utils.tf_to_numpy(example)
+  # Examples from the in-memory dataset were produced by tfds.as_numpy, so
+  # only inputs that still expose `_numpy` are passed through tf_to_numpy.
+  if hasattr(example['edge_feat'], '_numpy'):
+    example = data_utils.tf_to_numpy(example)
   edge_feat = example['edge_feat']
   node_feat = example['node_feat']
   edge_index = example['edge_index']