Skip to content

Commit

Permalink
added test for collate_function
Browse files Browse the repository at this point in the history
  • Loading branch information
Coerulatus committed May 14, 2024
1 parent 8d41451 commit 79cf1b6
Showing 1 changed file with 115 additions and 0 deletions.
115 changes: 115 additions & 0 deletions test/data/test_Dataloaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Test the collate function."""
import hydra
from hydra import compose, initialize
from omegaconf import OmegaConf

import torch

from topobenchmarkx.data.dataloaders import to_data_list, DefaultDataModule

from topobenchmarkx.utils.config_resolvers import (
get_default_transform,
get_monitor_metric,
get_monitor_mode,
infer_in_channels,
)

import rootutils

rootutils.setup_root("./", indicator=".project-root", pythonpath=True)

class TestCollateFunction:
"""Test collate_fn."""

def setup_method(self):
"""Setup the test.
For this test we load the MUTAG dataset.
Parameters
----------
None
"""
OmegaConf.register_new_resolver("get_default_transform", get_default_transform)
OmegaConf.register_new_resolver("get_monitor_metric", get_monitor_metric)
OmegaConf.register_new_resolver("get_monitor_mode", get_monitor_mode)
OmegaConf.register_new_resolver("infer_in_channels", infer_in_channels)
OmegaConf.register_new_resolver(
"parameter_multiplication", lambda x, y: int(int(x) * int(y))
)

initialize(version_base="1.3", config_path="../../configs", job_name="job")
cfg = compose(config_name="train.yaml")

graph_loader = hydra.utils.instantiate(cfg.dataset, _recursive_=False)
datasets = graph_loader.load()
self.batch_size = 2
datamodule = DefaultDataModule(
dataset_train=datasets[0],
dataset_val=datasets[1],
dataset_test=datasets[2],
batch_size=self.batch_size
)
self.val_dataloader = datamodule.val_dataloader()
self.val_dataset = datasets[1]

def test_lift_features(self):
"""Test the collate funciton.
To test the collate function we use the DefaultDataModule class to create a dataloader that uses the collate function. We then first check that the batched data has the expected shape. We then convert the batched data back to a list and check that the data in the list is the same as the original data.
Parameters
----------
None
"""
def check_separation(matrix, n_elems_0_row, n_elems_0_col):
"""Check that the matrix is separated into two parts diagonally concatenated."""
assert torch.all(matrix[:n_elems_0_row, n_elems_0_col:] == 0)
assert torch.all(matrix[n_elems_0_row:, :n_elems_0_col] == 0)


batch = next(iter(self.val_dataloader))
elems = []
for i in range(self.batch_size):
elems.append(self.val_dataset.data_lst[i])

# Check that the batched data has the expected shape
for key in batch.keys():
if key in elems[0].keys():
if 'x_' in key or 'x'==key:
assert batch[key].shape[0] == elems[0][key].shape[0]+elems[1][key].shape[0]
assert batch[key].shape[1] == elems[0][key].shape[1]
elif 'edge_index' in key:
assert batch[key].shape[0] == 2
assert batch[key].shape[1] == elems[0][key].shape[1]+elems[1][key].shape[1]
else:
for i in range(len(batch[key].shape)):
assert batch[key].shape[i] == elems[0][key].shape[i]+elems[1][key].shape[i]
else:
if 'batch_' in key:
i = int(key.split('_')[1])
assert batch[key].shape[0] == elems[0][f'x_{i}'].shape[0]+elems[1][f'x_{i}'].shape[0]

# Check that the batched data is separated correctly
for key in batch.keys():
if 'incidence_' in key:
i = int(key.split('_')[1])
if i==0:
n0_row = 1
else:
n0_row = torch.sum(batch[f'batch_{i-1}']==0)
n0_col = torch.sum(batch[f'batch_{i}']==0)
check_separation(batch[key].to_dense(), n0_row, n0_col)

# Check that going back to a list of data gives the same data
batch_list = to_data_list(batch)
assert len(batch_list) == len(elems)
for i in range(len(batch_list)):
for key in elems[i].keys():
if key in batch_list[i].keys():
if batch_list[i][key].is_sparse:
assert torch.all(batch_list[i][key].coalesce().indices() == elems[i][key].coalesce().indices())
assert torch.allclose(batch_list[i][key].coalesce().values(), elems[i][key].coalesce().values())
assert batch_list[i][key].shape, elems[i][key].shape
else:
assert torch.allclose(batch_list[i][key], elems[i][key])

0 comments on commit 79cf1b6

Please sign in to comment.