-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8d41451
commit 79cf1b6
Showing
1 changed file
with
115 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
"""Test the collate function.""" | ||
import hydra | ||
from hydra import compose, initialize | ||
from omegaconf import OmegaConf | ||
|
||
import torch | ||
|
||
from topobenchmarkx.data.dataloaders import to_data_list, DefaultDataModule | ||
|
||
from topobenchmarkx.utils.config_resolvers import ( | ||
get_default_transform, | ||
get_monitor_metric, | ||
get_monitor_mode, | ||
infer_in_channels, | ||
) | ||
|
||
import rootutils | ||
|
||
rootutils.setup_root("./", indicator=".project-root", pythonpath=True) | ||
|
||
class TestCollateFunction:
    """Test collate_fn."""

    def setup_method(self):
        """Set up the test.

        Registers the custom OmegaConf resolvers used by the configs,
        composes the ``train.yaml`` Hydra configuration, loads the dataset
        selected by ``cfg.dataset`` (MUTAG, per the original author's note)
        and builds the validation dataloader with a batch size of 2.

        Parameters
        ----------
        None
        """
        resolvers = {
            "get_default_transform": get_default_transform,
            "get_monitor_metric": get_monitor_metric,
            "get_monitor_mode": get_monitor_mode,
            "infer_in_channels": infer_in_channels,
            "parameter_multiplication": lambda x, y: int(int(x) * int(y)),
        }
        for name, resolver in resolvers.items():
            # pytest calls setup_method before every test; replace=True keeps
            # a repeated registration from raising "already registered".
            OmegaConf.register_new_resolver(name, resolver, replace=True)

        # Used as a context manager, initialize() restores Hydra's global
        # state on exit, so a later initialize() call does not fail with
        # "GlobalHydra is already initialized".
        with initialize(
            version_base="1.3", config_path="../../configs", job_name="job"
        ):
            cfg = compose(config_name="train.yaml")
            graph_loader = hydra.utils.instantiate(
                cfg.dataset, _recursive_=False
            )
            datasets = graph_loader.load()

        self.batch_size = 2
        datamodule = DefaultDataModule(
            dataset_train=datasets[0],
            dataset_val=datasets[1],
            dataset_test=datasets[2],
            batch_size=self.batch_size,
        )
        self.val_dataloader = datamodule.val_dataloader()
        self.val_dataset = datasets[1]

    def test_lift_features(self):
        """Test the collate function.

        A dataloader created by ``DefaultDataModule`` (which uses the collate
        function) yields one batch. The test first checks that the batched
        tensors have the expected shapes, then that the batched incidence
        matrices are block-separated, and finally that converting the batch
        back to a list reproduces the original data.

        Parameters
        ----------
        None
        """

        def check_separation(matrix, n_elems_0_row, n_elems_0_col):
            """Check that the matrix is separated into two parts diagonally concatenated."""
            assert torch.all(matrix[:n_elems_0_row, n_elems_0_col:] == 0)
            assert torch.all(matrix[n_elems_0_row:, :n_elems_0_col] == 0)

        batch = next(iter(self.val_dataloader))
        elems = [self.val_dataset.data_lst[i] for i in range(self.batch_size)]

        def stacked_size(key, dim):
            """Sum the size of ``key`` along ``dim`` over all batched elements."""
            return sum(elem[key].shape[dim] for elem in elems)

        # Check that the batched data has the expected shape
        for key in batch.keys():
            if key in elems[0].keys():
                if "x_" in key or key == "x":
                    # Features are stacked along dim 0; feature width is kept.
                    assert batch[key].shape[0] == stacked_size(key, 0)
                    assert batch[key].shape[1] == elems[0][key].shape[1]
                elif "edge_index" in key:
                    assert batch[key].shape[0] == 2
                    assert batch[key].shape[1] == stacked_size(key, 1)
                else:
                    # Every other shared key is expected to grow along every
                    # dimension (e.g. diagonally concatenated matrices).
                    for dim in range(len(batch[key].shape)):
                        assert batch[key].shape[dim] == stacked_size(key, dim)
            elif "batch_" in key:
                rank = int(key.split("_")[1])
                assert batch[key].shape[0] == stacked_size(f"x_{rank}", 0)

        # Check that the batched data is separated correctly
        for key in batch.keys():
            if "incidence_" in key:
                rank = int(key.split("_")[1])
                if rank == 0:
                    n0_row = 1
                else:
                    n0_row = int(torch.sum(batch[f"batch_{rank - 1}"] == 0))
                n0_col = int(torch.sum(batch[f"batch_{rank}"] == 0))
                check_separation(batch[key].to_dense(), n0_row, n0_col)

        # Check that going back to a list of data gives the same data
        batch_list = to_data_list(batch)
        assert len(batch_list) == len(elems)
        for restored, original in zip(batch_list, elems):
            for key in original.keys():
                if key not in restored.keys():
                    continue
                if restored[key].is_sparse:
                    got = restored[key].coalesce()
                    expected = original[key].coalesce()
                    # Compare the shapes for equality. The previous
                    # ``assert a.shape, b.shape`` form only tested that the
                    # shape was truthy and used the second shape as message.
                    assert got.shape == expected.shape
                    assert torch.all(got.indices() == expected.indices())
                    assert torch.allclose(got.values(), expected.values())
                else:
                    assert torch.allclose(restored[key], original[key])