diff --git a/init2winit/dataset_lib/ogbg_molpcba.py b/init2winit/dataset_lib/ogbg_molpcba.py
index 144670a6..66458104 100644
--- a/init2winit/dataset_lib/ogbg_molpcba.py
+++ b/init2winit/dataset_lib/ogbg_molpcba.py
@@ -68,7 +68,10 @@
         add_bidirectional_edges=False,
         add_virtual_node=False,
         add_self_loops=False,
-    ))
+        train_batching='static',
+        eval_batching='dynamic',
+    )
+)
 
 METADATA = {
     'apply_one_hot_in_loss': False,
@@ -199,14 +202,319 @@ def _get_weights_by_nan_and_padding(labels, padding_mask):
   return replaced_labels, weights
 
 
-def _get_batch_iterator(dataset_iter,
-                        batch_size,
-                        nodes_per_graph,
-                        edges_per_graph,
-                        add_bidirectional_edges,
-                        add_self_loops,
-                        add_virtual_node,
-                        num_shards=None):
+def _ensure_numpy(example):
+  if hasattr(example.get('edge_feat'), '_numpy'):
+    return data_utils.tf_to_numpy(example)
+  return example
+
+
+def _augmented_sizes(
+    num_nodes,
+    num_edges,
+    add_bidirectional_edges,
+    add_virtual_node,
+    add_self_loops,
+):
+  """Adjusts the number of nodes and edges based on the augmentation hps."""
+
+  if add_bidirectional_edges:
+    num_edges *= 2
+  if add_self_loops:
+    num_edges += num_nodes
+  if add_virtual_node:
+    num_edges += num_nodes
+    num_nodes += 1
+  return num_nodes, num_edges
+
+
+def _process_example(
+    example: dict[str, np.ndarray],
+    add_bidirectional_edges: bool,
+    add_virtual_node: bool,
+    add_self_loops: bool,
+    nodes_buf: np.ndarray,
+    edges_buf: np.ndarray,
+    senders_buf: np.ndarray,
+    receivers_buf: np.ndarray,
+    node_offset: int,
+    edge_offset: int,
+) -> tuple[int, int]:
+  """Processes a raw example and fills it into the pre-allocated buffers.
+
+  Extracts node features, edge features, and connectivity from a single raw
+  OGBG graph and writes them into the corresponding pre-allocated numpy
+  buffers at the given offsets. Optionally adds bidirectional edges,
+  self-loops, and a virtual node depending on the flags. This is how multiple
+  graphs are packed into one large batched graph for the model.
+
+  Args:
+    example: A graph dict from ogbg_molpcba with the keys ``edge_feat``,
+      ``node_feat``, ``edge_index``, and ``num_nodes``.
+    add_bidirectional_edges: If True, duplicate every edge in the reverse
+      direction so each original edge becomes bidirectional.
+    add_virtual_node: If True, append an extra node connected to all original
+      nodes (edges from every original node to the virtual node).
+    add_self_loops: If True, add a self-loop edge for every node.
+    nodes_buf: Pre-allocated numpy array for node features, written in-place.
+    edges_buf: Pre-allocated numpy array for edge features, written in-place.
+    senders_buf: Pre-allocated numpy array for sender indices, written in-place.
+    receivers_buf: Pre-allocated numpy array for receiver indices, written
+      in-place.
+    node_offset: Starting index in the node buffers for this example.
+    edge_offset: Starting index in the edge buffers for this example.
+
+  Returns:
+    A tuple (num_nodes, num_edges) giving the total number of nodes and edges
+    written for this example (including any added by the augmentation flags).
+  """
+
+  edge_feat = example['edge_feat']
+  node_feat = example['node_feat']
+  edge_index = example['edge_index']
+  num_nodes = int(np.squeeze(example['num_nodes']))
+  num_edges = len(edge_index)
+
+  nodes_buf[node_offset : node_offset + num_nodes] = node_feat
+
+  eo = edge_offset
+  senders_buf[eo : eo + num_edges] = edge_index[:, 0] + node_offset
+  receivers_buf[eo : eo + num_edges] = edge_index[:, 1] + node_offset
+  edges_buf[eo : eo + num_edges] = edge_feat
+  eo += num_edges
+
+  if add_bidirectional_edges:
+    senders_buf[eo : eo + num_edges] = edge_index[:, 1] + node_offset
+    receivers_buf[eo : eo + num_edges] = edge_index[:, 0] + node_offset
+    edges_buf[eo : eo + num_edges] = edge_feat
+    eo += num_edges
+
+  if add_self_loops:
+    self_range = np.arange(num_nodes, dtype=np.int32) + node_offset
+    senders_buf[eo : eo + num_nodes] = self_range
+    receivers_buf[eo : eo + num_nodes] = self_range
+    eo += num_nodes
+
+  if add_virtual_node:
+    vn_senders = np.arange(num_nodes, dtype=np.int32) + node_offset
+    senders_buf[eo : eo + num_nodes] = vn_senders
+    receivers_buf[eo : eo + num_nodes] = node_offset + num_nodes
+    eo += num_nodes
+    num_nodes += 1
+
+  return num_nodes, eo - edge_offset
+
+
+def _build_full_batch(
+    examples: list[dict[str, np.ndarray]],
+    graphs_per_shard: int,
+    num_shards: int,
+    max_nodes_per_shard: int,
+    max_edges_per_shard: int,
+    add_bidirectional_edges: bool,
+    add_virtual_node: bool,
+    add_self_loops: bool,
+    node_feat_dim: int,
+    edge_feat_dim: int,
+    num_labels: int,
+) -> dict[str, jraph.GraphsTuple | np.ndarray]:
+  """Builds a fully padded batch from raw examples in a single allocation.
+
+  Pre-allocates node, edge, sender, and receiver buffers for the entire batch
+  and fills them by calling _process_example on each graph sequentially.
+  Graphs are packed into per-shard regions with a budget check; any graph that
+  would exceed a shard's node or edge budget is skipped. Each shard is then
+  closed with a padding graph so that every shard has the same fixed size.
+
+  Args:
+    examples: List of graph dicts from ogbg_molpcba.
+    graphs_per_shard: Maximum number of real graphs packed into each shard.
+    num_shards: Number of shards (typically the number of devices).
+    max_nodes_per_shard: Maximum number of nodes allocated per shard.
+    max_edges_per_shard: Maximum number of edges allocated per shard.
+    add_bidirectional_edges: If True, duplicate every edge in reverse.
+    add_virtual_node: If True, add a virtual node connected to all real nodes.
+    add_self_loops: If True, add a self-loop edge for every node.
+    node_feat_dim: Dimensionality of node features.
+    edge_feat_dim: Dimensionality of edge features.
+    num_labels: Number of label columns per graph.
+
+  Returns:
+    A dict containing:
+      inputs: a jraph.GraphsTuple for the entire batch.
+      targets: float32 label array with NaNs replaced by 0.
+      weights: a per-label mask that is 0 for NaN labels and padding graphs.
+  """
+  total_graphs = num_shards * (graphs_per_shard + 1)
+  total_nodes = num_shards * max_nodes_per_shard
+  total_edges = num_shards * max_edges_per_shard
+
+  nodes_buf = np.zeros((total_nodes, node_feat_dim), dtype=np.float32)
+  edges_buf = np.zeros((total_edges, edge_feat_dim), dtype=np.float32)
+  senders_buf = np.zeros(total_edges, dtype=np.int32)
+  receivers_buf = np.zeros(total_edges, dtype=np.int32)
+  n_node = np.zeros(total_graphs, dtype=np.int32)
+  n_edge = np.zeros(total_graphs, dtype=np.int32)
+  labels = np.zeros((total_graphs, num_labels), dtype=np.float32)
+
+  global_node_offset = 0
+  global_edge_offset = 0
+  example_idx = 0
+
+  for shard_idx in range(num_shards):
+    graph_base = shard_idx * (graphs_per_shard + 1)
+    shard_node_start = shard_idx * max_nodes_per_shard
+    shard_edge_start = shard_idx * max_edges_per_shard
+    node_budget = max_nodes_per_shard - 1
+    edge_budget = max_edges_per_shard
+
+    for local_idx in range(graphs_per_shard):
+      ex = examples[example_idx]
+      example_idx += 1
+
+      nn, ne = _augmented_sizes(
+          int(np.squeeze(ex['num_nodes'])),
+          len(ex['edge_index']),
+          add_bidirectional_edges,
+          add_virtual_node,
+          add_self_loops,
+      )
+      nodes_used = global_node_offset - shard_node_start
+      edges_used = global_edge_offset - shard_edge_start
+      if nodes_used + nn > node_budget or edges_used + ne > edge_budget:
+        continue
+
+      nn, ne = _process_example(
+          ex,
+          add_bidirectional_edges,
+          add_virtual_node,
+          add_self_loops,
+          nodes_buf,
+          edges_buf,
+          senders_buf,
+          receivers_buf,
+          global_node_offset,
+          global_edge_offset,
+      )
+
+      n_node[graph_base + local_idx] = nn
+      n_edge[graph_base + local_idx] = ne
+      labels[graph_base + local_idx] = ex['labels']
+
+      global_node_offset += nn
+      global_edge_offset += ne
+
+    shard_node_end = shard_node_start + max_nodes_per_shard
+    shard_edge_end = shard_edge_start + max_edges_per_shard
+    pad_nodes = shard_node_end - global_node_offset
+    pad_edges = shard_edge_end - global_edge_offset
+
+    pad_graph_idx = graph_base + graphs_per_shard
+    n_node[pad_graph_idx] = pad_nodes
+    n_edge[pad_graph_idx] = pad_edges
+
+    global_node_offset = shard_node_end
+    global_edge_offset = shard_edge_end
+
+  graph = jraph.GraphsTuple(
+      n_node=n_node,
+      n_edge=n_edge,
+      nodes=nodes_buf,
+      edges=edges_buf,
+      senders=senders_buf,
+      receivers=receivers_buf,
+      globals={},
+  )
+
+  padding_mask = jraph.get_graph_padding_mask(graph)
+  nan_mask = np.isnan(labels)
+  replaced_labels = np.where(nan_mask, 0.0, labels)
+  weights = (1.0 - nan_mask) * padding_mask[:, None]
+
+  return {
+      'inputs': graph,
+      'targets': replaced_labels,
+      'weights': weights,
+  }
+
+
+def _get_static_batch_iterator(
+    dataset_iter,
+    batch_size,
+    nodes_per_graph,
+    edges_per_graph,
+    add_bidirectional_edges,
+    add_self_loops,
+    add_virtual_node,
+    node_feat_dim,
+    edge_feat_dim,
+    num_labels,
+    num_shards=None,
+):
+  """Constructs a static batch iterator.
+
+  Static batching packs a fixed number of graphs per shard, with node and edge
+  budgets sized from the average number of nodes and edges per graph. This is
+  in contrast to dynamic batching, which keeps adding graphs until the node or
+  edge budget is full. Static batching is more efficient, but graphs that do
+  not fit into their shard's budget are dropped.
+
+  Args:
+    dataset_iter: An iterator over the dataset.
+    batch_size: The total batch size (graphs per step across all shards).
+    nodes_per_graph: Per-graph node budget used to size each shard.
+    edges_per_graph: Per-graph edge budget used to size each shard.
+    add_bidirectional_edges: Whether to add bidirectional edges.
+    add_self_loops: Whether to add self loops.
+    add_virtual_node: Whether to add a virtual node.
+    node_feat_dim: The dimension of the node features.
+    edge_feat_dim: The dimension of the edge features.
+    num_labels: The number of labels.
+    num_shards: The number of shards (defaults to jax.local_device_count()).
+
+  Yields:
+    A batch dict with 'inputs', 'targets', and 'weights'.
+  """
+  if not num_shards:
+    num_shards = jax.local_device_count()
+
+  graphs_per_shard = int(batch_size / num_shards)
+  max_nodes_per_shard = int(nodes_per_graph * graphs_per_shard) + 1
+  max_edges_per_shard = int(edges_per_graph * graphs_per_shard)
+
+  total_graphs_needed = graphs_per_shard * num_shards
+
+  while True:
+    examples = [
+        _ensure_numpy(ex)
+        for ex in itertools.islice(dataset_iter, total_graphs_needed)
+    ]
+    if len(examples) < total_graphs_needed:
+      break
+    yield _build_full_batch(
+        examples,
+        graphs_per_shard,
+        num_shards,
+        max_nodes_per_shard,
+        max_edges_per_shard,
+        add_bidirectional_edges,
+        add_virtual_node,
+        add_self_loops,
+        node_feat_dim,
+        edge_feat_dim,
+        num_labels,
+    )
+
+
+def _get_dynamic_batch_iterator(
+    dataset_iter,
+    batch_size,
+    nodes_per_graph,
+    edges_per_graph,
+    add_bidirectional_edges,
+    add_self_loops,
+    add_virtual_node,
+    num_shards=None,
+):
   """Turns a TFDS per-example iterator into a batched iterator in the init2winit format.
 
   Constructs the batch from num_shards smaller batches, so that we can easily
@@ -246,6 +554,7 @@ def _get_batch_iterator(dataset_iter,
       add_bidirectional_edges=add_bidirectional_edges,
       add_virtual_node=add_virtual_node,
       add_self_loops=add_self_loops)
+
   jraph_iter = map(to_jraph_partial, dataset_iter)
   batched_iter = jraph.dynamically_batch(jraph_iter, max_n_nodes + 1,
                                          max_n_edges, max_n_graphs + 1)
@@ -322,54 +631,54 @@ def get_ogbg_molpcba(shuffle_rng, batch_size, eval_batch_size, hps=None):
 
   max_nodes_multiplier = hps.batch_nodes_multiplier * hps.avg_nodes_per_graph
   max_edges_multiplier = hps.batch_edges_multiplier * hps.avg_edges_per_graph
-
-  if hps.add_bidirectional_edges:
-    max_edges_multiplier *= 2
-
-  if hps.add_self_loops:
-    max_edges_multiplier += max_nodes_multiplier
-
-  if hps.add_virtual_node:
-    max_edges_multiplier += max_nodes_multiplier
-    max_nodes_multiplier += 1
-
-  iterator_from_ds = functools.partial(
-      _get_batch_iterator,
+  max_nodes_multiplier, max_edges_multiplier = _augmented_sizes(
+      max_nodes_multiplier,
+      max_edges_multiplier,
+      hps.add_bidirectional_edges,
+      hps.add_virtual_node,
+      hps.add_self_loops,
+  )
+
+  common_kwargs = dict(
       nodes_per_graph=int(max_nodes_multiplier),
       edges_per_graph=int(max_edges_multiplier),
       add_bidirectional_edges=hps.add_bidirectional_edges,
       add_virtual_node=hps.add_virtual_node,
-      add_self_loops=hps.add_self_loops)
+      add_self_loops=hps.add_self_loops,
+  )
+  static_kwargs = dict(
+      node_feat_dim=hps.input_node_shape[0],
+      edge_feat_dim=hps.input_edge_shape[0],
+      num_labels=hps.output_shape[0],
+  )
+
+  def _make_iterator(strategy):
+    if strategy == 'static':
+      return functools.partial(
+          _get_static_batch_iterator, **common_kwargs, **static_kwargs
+      )
+    else:
+      return functools.partial(_get_dynamic_batch_iterator, **common_kwargs)
+
+  train_iterator_from_ds = _make_iterator(hps.train_batching)
+  eval_iterator_from_ds = _make_iterator(hps.eval_batching)
 
   def train_iterator_fn():
-    return iterator_from_ds(
+    return train_iterator_from_ds(
         dataset_iter=iter(train_ds), batch_size=per_host_batch_size
     )
 
-  def eval_train_epoch(num_batches=None):
+  def _eval_epoch(ds, num_batches=None):
     return itertools.islice(
-        iterator_from_ds(
-            dataset_iter=iter(eval_train_ds),
-            batch_size=per_host_eval_batch_size,
+        eval_iterator_from_ds(
+            dataset_iter=iter(ds), batch_size=per_host_eval_batch_size
         ),
         num_batches,
     )
 
-  def valid_epoch(num_batches=None):
-    return itertools.islice(
-        iterator_from_ds(
-            dataset_iter=iter(valid_ds), batch_size=per_host_eval_batch_size
-        ),
-        num_batches,
-    )
-
-  def test_epoch(num_batches=None):
-    return itertools.islice(
-        iterator_from_ds(
-            dataset_iter=iter(test_ds), batch_size=per_host_eval_batch_size
-        ),
-        num_batches,
-    )
+  eval_train_epoch = functools.partial(_eval_epoch, eval_train_ds)
+  valid_epoch = functools.partial(_eval_epoch, valid_ds)
+  test_epoch = functools.partial(_eval_epoch, test_ds)
 
   return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch)
 
@@ -415,13 +724,29 @@ def dataset_iterator():
           'num_nodes': tf.constant([num_nodes], dtype=tf.int32),
       }
 
-  batch_iterator = _get_batch_iterator(
-      dataset_iterator(),
-      batch_size=hps.batch_size,
-      nodes_per_graph=hps.batch_nodes_multiplier * hps.avg_nodes_per_graph,
-      edges_per_graph=hps.batch_edges_multiplier * hps.avg_edges_per_graph,
-      add_bidirectional_edges=hps.add_bidirectional_edges,
-      add_virtual_node=hps.add_virtual_node,
-      add_self_loops=hps.add_self_loops)
+  batching = getattr(hps, 'train_batching', 'static')
+  if batching == 'static':
+    batch_iterator = _get_static_batch_iterator(
+        dataset_iterator(),
+        batch_size=hps.batch_size,
+        nodes_per_graph=hps.batch_nodes_multiplier * hps.avg_nodes_per_graph,
+        edges_per_graph=hps.batch_edges_multiplier * hps.avg_edges_per_graph,
+        add_bidirectional_edges=hps.add_bidirectional_edges,
+        add_virtual_node=hps.add_virtual_node,
+        add_self_loops=hps.add_self_loops,
+        node_feat_dim=9,
+        edge_feat_dim=3,
+        num_labels=128,
+    )
+  else:
+    batch_iterator = _get_dynamic_batch_iterator(
+        dataset_iterator(),
+        batch_size=hps.batch_size,
+        nodes_per_graph=hps.batch_nodes_multiplier * hps.avg_nodes_per_graph,
+        edges_per_graph=hps.batch_edges_multiplier * hps.avg_edges_per_graph,
+        add_bidirectional_edges=hps.add_bidirectional_edges,
+        add_virtual_node=hps.add_virtual_node,
+        add_self_loops=hps.add_self_loops,
+    )
 
   return next(batch_iterator)
diff --git a/init2winit/dataset_lib/test_datasets_multihost.py b/init2winit/dataset_lib/test_datasets_multihost.py
index 9de322e0..3e3b209c 100644
--- a/init2winit/dataset_lib/test_datasets_multihost.py
+++ b/init2winit/dataset_lib/test_datasets_multihost.py
@@ -37,8 +37,8 @@ def ogbg_molpcba_mock(*args, **kwargs):
   num_examples = 10
   num_nodes = 8
   num_edges = 8
-  node_dim = 3
-  edge_dim = 2
+  node_dim = 9
+  edge_dim = 3
   return tf.data.Dataset.from_generator(
       lambda: (
           {
diff --git a/init2winit/dataset_lib/test_ogbg_molpcba.py b/init2winit/dataset_lib/test_ogbg_molpcba.py
index 659e3765..4e0ca9ce 100644
--- a/init2winit/dataset_lib/test_ogbg_molpcba.py
+++ b/init2winit/dataset_lib/test_ogbg_molpcba.py
@@ -93,6 +93,7 @@ def _get_dataset(shuffle_seed, additional_hps=None):
   hps.train_size = 4
   hps.valid_size = 4
   hps.test_size = 4
+  hps.output_shape = (NUM_LABELS,)
  hps.avg_nodes_per_graph = NODES_SIZE_MULTIPLIER
  hps.avg_edges_per_graph = EDGES_SIZE_MULTIPLIER
  hps.batch_nodes_multiplier = 1.0
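
Note (not part of the patch): a minimal standalone sketch of the per-shard budget arithmetic that _get_static_batch_iterator and _build_full_batch above rely on. The numeric values below are illustrative placeholders, not the dataset's real hyperparameters; the point is that each shard gets a fixed graph, node, and edge budget plus one padding graph, which keeps the batched shapes identical on every step.

# Standalone sketch of the static-batching shard budgets (illustrative values).
batch_size = 256          # total real graphs per step, e.g. hps.batch_size
num_shards = 8            # e.g. jax.local_device_count()
nodes_per_graph = 26      # assumed per-graph node budget (already augmented)
edges_per_graph = 56      # assumed per-graph edge budget (already augmented)

graphs_per_shard = batch_size // num_shards
# One extra node is reserved so the per-shard padding graph is never empty,
# mirroring `node_budget = max_nodes_per_shard - 1` in _build_full_batch.
max_nodes_per_shard = nodes_per_graph * graphs_per_shard + 1
max_edges_per_shard = edges_per_graph * graphs_per_shard

# Every shard holds graphs_per_shard real slots plus one padding graph, so the
# batched shapes do not depend on which graphs were sampled.
total_graphs = num_shards * (graphs_per_shard + 1)
total_nodes = num_shards * max_nodes_per_shard
total_edges = num_shards * max_edges_per_shard
print(total_graphs, total_nodes, total_edges)  # 264 6664 14336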