diff --git a/topobenchmarkx/data/dataloaders.py b/topobenchmarkx/data/dataloaders.py index bddc1014..2578c4af 100755 --- a/topobenchmarkx/data/dataloaders.py +++ b/topobenchmarkx/data/dataloaders.py @@ -35,11 +35,10 @@ def to_data_list(batch): sparse_data = batch[key].coalesce() batch[key] = SparseTensor.from_torch_sparse_coo_tensor(sparse_data) data_list = batch.to_data_list() - for i, data in enumerate(data_list): - for key in data: + for key, d in data: if isinstance(data[key], SparseTensor): - data_list[i][key] = data[key].to_torch_sparse_coo_tensor() + data_list[i][key] = d.to_torch_sparse_coo_tensor() return data_list