From 79cf1b656f35d03ea8acace42fb6508ad5940e06 Mon Sep 17 00:00:00 2001 From: Coerulatus Date: Tue, 14 May 2024 21:48:37 +0000 Subject: [PATCH] added test for collate_function --- test/data/test_Dataloaders.py | 115 ++++++++++++++++++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 test/data/test_Dataloaders.py diff --git a/test/data/test_Dataloaders.py b/test/data/test_Dataloaders.py new file mode 100644 index 00000000..c595ed13 --- /dev/null +++ b/test/data/test_Dataloaders.py @@ -0,0 +1,115 @@ +"""Test the collate function.""" +import hydra +from hydra import compose, initialize +from omegaconf import OmegaConf + +import torch + +from topobenchmarkx.data.dataloaders import to_data_list, DefaultDataModule + +from topobenchmarkx.utils.config_resolvers import ( + get_default_transform, + get_monitor_metric, + get_monitor_mode, + infer_in_channels, +) + +import rootutils + +rootutils.setup_root("./", indicator=".project-root", pythonpath=True) + +class TestCollateFunction: + """Test collate_fn.""" + + def setup_method(self): + """Setup the test. + + For this test we load the MUTAG dataset. + + Parameters + ---------- + None + """ + OmegaConf.register_new_resolver("get_default_transform", get_default_transform) + OmegaConf.register_new_resolver("get_monitor_metric", get_monitor_metric) + OmegaConf.register_new_resolver("get_monitor_mode", get_monitor_mode) + OmegaConf.register_new_resolver("infer_in_channels", infer_in_channels) + OmegaConf.register_new_resolver( + "parameter_multiplication", lambda x, y: int(int(x) * int(y)) + ) + + initialize(version_base="1.3", config_path="../../configs", job_name="job") + cfg = compose(config_name="train.yaml") + + graph_loader = hydra.utils.instantiate(cfg.dataset, _recursive_=False) + datasets = graph_loader.load() + self.batch_size = 2 + datamodule = DefaultDataModule( + dataset_train=datasets[0], + dataset_val=datasets[1], + dataset_test=datasets[2], + batch_size=self.batch_size + ) + self.val_dataloader = datamodule.val_dataloader() + self.val_dataset = datasets[1] + + def test_lift_features(self): + """Test the collate funciton. + + To test the collate function we use the DefaultDataModule class to create a dataloader that uses the collate function. We then first check that the batched data has the expected shape. We then convert the batched data back to a list and check that the data in the list is the same as the original data. + + Parameters + ---------- + None + """ + def check_separation(matrix, n_elems_0_row, n_elems_0_col): + """Check that the matrix is separated into two parts diagonally concatenated.""" + assert torch.all(matrix[:n_elems_0_row, n_elems_0_col:] == 0) + assert torch.all(matrix[n_elems_0_row:, :n_elems_0_col] == 0) + + + batch = next(iter(self.val_dataloader)) + elems = [] + for i in range(self.batch_size): + elems.append(self.val_dataset.data_lst[i]) + + # Check that the batched data has the expected shape + for key in batch.keys(): + if key in elems[0].keys(): + if 'x_' in key or 'x'==key: + assert batch[key].shape[0] == elems[0][key].shape[0]+elems[1][key].shape[0] + assert batch[key].shape[1] == elems[0][key].shape[1] + elif 'edge_index' in key: + assert batch[key].shape[0] == 2 + assert batch[key].shape[1] == elems[0][key].shape[1]+elems[1][key].shape[1] + else: + for i in range(len(batch[key].shape)): + assert batch[key].shape[i] == elems[0][key].shape[i]+elems[1][key].shape[i] + else: + if 'batch_' in key: + i = int(key.split('_')[1]) + assert batch[key].shape[0] == elems[0][f'x_{i}'].shape[0]+elems[1][f'x_{i}'].shape[0] + + # Check that the batched data is separated correctly + for key in batch.keys(): + if 'incidence_' in key: + i = int(key.split('_')[1]) + if i==0: + n0_row = 1 + else: + n0_row = torch.sum(batch[f'batch_{i-1}']==0) + n0_col = torch.sum(batch[f'batch_{i}']==0) + check_separation(batch[key].to_dense(), n0_row, n0_col) + + # Check that going back to a list of data gives the same data + batch_list = to_data_list(batch) + assert len(batch_list) == len(elems) + for i in range(len(batch_list)): + for key in elems[i].keys(): + if key in batch_list[i].keys(): + if batch_list[i][key].is_sparse: + assert torch.all(batch_list[i][key].coalesce().indices() == elems[i][key].coalesce().indices()) + assert torch.allclose(batch_list[i][key].coalesce().values(), elems[i][key].coalesce().values()) + assert batch_list[i][key].shape, elems[i][key].shape + else: + assert torch.allclose(batch_list[i][key], elems[i][key]) \ No newline at end of file