From 05bec4b97129225dee80d19fc583d166eceecfc0 Mon Sep 17 00:00:00 2001 From: Ahmed Khaled Date: Fri, 27 Feb 2026 16:18:27 -0800 Subject: [PATCH] internal PiperOrigin-RevId: 876457573 --- init2winit/dataset_lib/ogbg_molpcba.py | 86 ++++++++++++++------------ 1 file changed, 48 insertions(+), 38 deletions(-) diff --git a/init2winit/dataset_lib/ogbg_molpcba.py b/init2winit/dataset_lib/ogbg_molpcba.py index 144670a6..fd300b4c 100644 --- a/init2winit/dataset_lib/ogbg_molpcba.py +++ b/init2winit/dataset_lib/ogbg_molpcba.py @@ -68,7 +68,8 @@ add_bidirectional_edges=False, add_virtual_node=False, add_self_loops=False, - )) + ) +) METADATA = { 'apply_one_hot_in_loss': False, @@ -199,14 +200,35 @@ def _get_weights_by_nan_and_padding(labels, padding_mask): return replaced_labels, weights -def _get_batch_iterator(dataset_iter, - batch_size, - nodes_per_graph, - edges_per_graph, - add_bidirectional_edges, - add_self_loops, - add_virtual_node, - num_shards=None): +def _augmented_sizes( + num_nodes, + num_edges, + add_bidirectional_edges, + add_virtual_node, + add_self_loops, +): + """Fix the number of nodes and edges based on the hps.""" + + if add_bidirectional_edges: + num_edges *= 2 + if add_self_loops: + num_edges += num_nodes + if add_virtual_node: + num_edges += num_nodes + num_nodes += 1 + return num_nodes, num_edges + + +def _get_batch_iterator( + dataset_iter, + batch_size, + nodes_per_graph, + edges_per_graph, + add_bidirectional_edges, + add_self_loops, + add_virtual_node, + num_shards=None, +): """Turns a TFDS per-example iterator into a batched iterator in the init2winit format. Constructs the batch from num_shards smaller batches, so that we can easily @@ -246,6 +268,7 @@ def _get_batch_iterator(dataset_iter, add_bidirectional_edges=add_bidirectional_edges, add_virtual_node=add_virtual_node, add_self_loops=add_self_loops) + jraph_iter = map(to_jraph_partial, dataset_iter) batched_iter = jraph.dynamically_batch(jraph_iter, max_n_nodes + 1, max_n_edges, max_n_graphs + 1) @@ -322,16 +345,13 @@ def get_ogbg_molpcba(shuffle_rng, batch_size, eval_batch_size, hps=None): max_nodes_multiplier = hps.batch_nodes_multiplier * hps.avg_nodes_per_graph max_edges_multiplier = hps.batch_edges_multiplier * hps.avg_edges_per_graph - - if hps.add_bidirectional_edges: - max_edges_multiplier *= 2 - - if hps.add_self_loops: - max_edges_multiplier += max_nodes_multiplier - - if hps.add_virtual_node: - max_edges_multiplier += max_nodes_multiplier - max_nodes_multiplier += 1 + max_nodes_multiplier, max_edges_multiplier = _augmented_sizes( + max_nodes_multiplier, + max_edges_multiplier, + hps.add_bidirectional_edges, + hps.add_virtual_node, + hps.add_self_loops, + ) iterator_from_ds = functools.partial( _get_batch_iterator, @@ -339,37 +359,26 @@ def get_ogbg_molpcba(shuffle_rng, batch_size, eval_batch_size, hps=None): edges_per_graph=int(max_edges_multiplier), add_bidirectional_edges=hps.add_bidirectional_edges, add_virtual_node=hps.add_virtual_node, - add_self_loops=hps.add_self_loops) + add_self_loops=hps.add_self_loops, + ) def train_iterator_fn(): return iterator_from_ds( dataset_iter=iter(train_ds), batch_size=per_host_batch_size ) - def eval_train_epoch(num_batches=None): + def _eval_epoch(ds, num_batches=None): return itertools.islice( iterator_from_ds( - dataset_iter=iter(eval_train_ds), + dataset_iter=iter(ds), batch_size=per_host_eval_batch_size, ), num_batches, ) - def valid_epoch(num_batches=None): - return itertools.islice( - iterator_from_ds( - dataset_iter=iter(valid_ds), batch_size=per_host_eval_batch_size - ), - num_batches, - ) - - def test_epoch(num_batches=None): - return itertools.islice( - iterator_from_ds( - dataset_iter=iter(test_ds), batch_size=per_host_eval_batch_size - ), - num_batches, - ) + eval_train_epoch = functools.partial(_eval_epoch, eval_train_ds) + valid_epoch = functools.partial(_eval_epoch, valid_ds) + test_epoch = functools.partial(_eval_epoch, test_ds) return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch) @@ -422,6 +431,7 @@ def dataset_iterator(): edges_per_graph=hps.batch_edges_multiplier * hps.avg_edges_per_graph, add_bidirectional_edges=hps.add_bidirectional_edges, add_virtual_node=hps.add_virtual_node, - add_self_loops=hps.add_self_loops) + add_self_loops=hps.add_self_loops, + ) return next(batch_iterator)