Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 48 additions & 38 deletions init2winit/dataset_lib/ogbg_molpcba.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@
add_bidirectional_edges=False,
add_virtual_node=False,
add_self_loops=False,
))
)
)

METADATA = {
'apply_one_hot_in_loss': False,
Expand Down Expand Up @@ -199,14 +200,35 @@ def _get_weights_by_nan_and_padding(labels, padding_mask):
return replaced_labels, weights


def _get_batch_iterator(dataset_iter,
batch_size,
nodes_per_graph,
edges_per_graph,
add_bidirectional_edges,
add_self_loops,
add_virtual_node,
num_shards=None):
def _augmented_sizes(
num_nodes,
num_edges,
add_bidirectional_edges,
add_virtual_node,
add_self_loops,
):
"""Fix the number of nodes and edges based on the hps."""

if add_bidirectional_edges:
num_edges *= 2
if add_self_loops:
num_edges += num_nodes
if add_virtual_node:
num_edges += num_nodes
num_nodes += 1
return num_nodes, num_edges


def _get_batch_iterator(
dataset_iter,
batch_size,
nodes_per_graph,
edges_per_graph,
add_bidirectional_edges,
add_self_loops,
add_virtual_node,
num_shards=None,
):
"""Turns a TFDS per-example iterator into a batched iterator in the init2winit format.

Constructs the batch from num_shards smaller batches, so that we can easily
Expand Down Expand Up @@ -246,6 +268,7 @@ def _get_batch_iterator(dataset_iter,
add_bidirectional_edges=add_bidirectional_edges,
add_virtual_node=add_virtual_node,
add_self_loops=add_self_loops)

jraph_iter = map(to_jraph_partial, dataset_iter)
batched_iter = jraph.dynamically_batch(jraph_iter, max_n_nodes + 1,
max_n_edges, max_n_graphs + 1)
Expand Down Expand Up @@ -322,54 +345,40 @@ def get_ogbg_molpcba(shuffle_rng, batch_size, eval_batch_size, hps=None):

max_nodes_multiplier = hps.batch_nodes_multiplier * hps.avg_nodes_per_graph
max_edges_multiplier = hps.batch_edges_multiplier * hps.avg_edges_per_graph

if hps.add_bidirectional_edges:
max_edges_multiplier *= 2

if hps.add_self_loops:
max_edges_multiplier += max_nodes_multiplier

if hps.add_virtual_node:
max_edges_multiplier += max_nodes_multiplier
max_nodes_multiplier += 1
max_nodes_multiplier, max_edges_multiplier = _augmented_sizes(
max_nodes_multiplier,
max_edges_multiplier,
hps.add_bidirectional_edges,
hps.add_virtual_node,
hps.add_self_loops,
)

iterator_from_ds = functools.partial(
_get_batch_iterator,
nodes_per_graph=int(max_nodes_multiplier),
edges_per_graph=int(max_edges_multiplier),
add_bidirectional_edges=hps.add_bidirectional_edges,
add_virtual_node=hps.add_virtual_node,
add_self_loops=hps.add_self_loops)
add_self_loops=hps.add_self_loops,
)

def train_iterator_fn():
return iterator_from_ds(
dataset_iter=iter(train_ds), batch_size=per_host_batch_size
)

def eval_train_epoch(num_batches=None):
def _eval_epoch(ds, num_batches=None):
return itertools.islice(
iterator_from_ds(
dataset_iter=iter(eval_train_ds),
dataset_iter=iter(ds),
batch_size=per_host_eval_batch_size,
),
num_batches,
)

def valid_epoch(num_batches=None):
return itertools.islice(
iterator_from_ds(
dataset_iter=iter(valid_ds), batch_size=per_host_eval_batch_size
),
num_batches,
)

def test_epoch(num_batches=None):
return itertools.islice(
iterator_from_ds(
dataset_iter=iter(test_ds), batch_size=per_host_eval_batch_size
),
num_batches,
)
eval_train_epoch = functools.partial(_eval_epoch, eval_train_ds)
valid_epoch = functools.partial(_eval_epoch, valid_ds)
test_epoch = functools.partial(_eval_epoch, test_ds)

return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)

Expand Down Expand Up @@ -422,6 +431,7 @@ def dataset_iterator():
edges_per_graph=hps.batch_edges_multiplier * hps.avg_edges_per_graph,
add_bidirectional_edges=hps.add_bidirectional_edges,
add_virtual_node=hps.add_virtual_node,
add_self_loops=hps.add_self_loops)
add_self_loops=hps.add_self_loops,
)

return next(batch_iterator)