From f76957f9901e5f33760c4e2b835a4893df07ccd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lu=C3=ADs=20F=2E=20Pereira?= Date: Thu, 23 Jan 2025 14:54:24 -0800 Subject: [PATCH] Update SimplicialLineLifting to work with new design --- .../simplicial/test_SimplicialLineLifting.py | 29 ++++-- .../liftings/graph2simplicial/line_lifting.py | 98 +++++++++++++------ 2 files changed, 88 insertions(+), 39 deletions(-) diff --git a/test/transforms/liftings/simplicial/test_SimplicialLineLifting.py b/test/transforms/liftings/simplicial/test_SimplicialLineLifting.py index 1a7732f1..e81a1f1c 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialLineLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialLineLifting.py @@ -3,7 +3,8 @@ import torch import torch_geometric -from modules.transforms.liftings.graph2simplicial.line_lifting import ( +from topobenchmark.transforms.liftings import ( + Graph2SimplicialLiftingTransform, SimplicialLineLifting, ) @@ -31,9 +32,17 @@ def setup_method(self): # Load the graph self.data = create_test_graph() # load_manual_graph() + lifting_map = SimplicialLineLifting() + # Initialise the SimplicialCliqueLifting class - self.lifting_signed = SimplicialLineLifting(signed=True) - self.lifting_unsigned = SimplicialLineLifting(signed=False) + self.lifting_signed = Graph2SimplicialLiftingTransform( + lifting=lifting_map, + signed=True, + ) + self.lifting_unsigned = Graph2SimplicialLiftingTransform( + lifting=lifting_map, + signed=False, + ) def test_lift_topology(self): """Test the lift_topology method.""" @@ -56,8 +65,11 @@ def test_lift_topology(self): print(lifted_data_signed.incidence_1.to_dense()) assert ( - abs(expected_incidence_1) == lifted_data_unsigned.incidence_1.to_dense() - ).all(), "Something is wrong with unsigned incidence_1 (nodes to edges)." + abs(expected_incidence_1) + == lifted_data_unsigned.incidence_1.to_dense() + ).all(), ( + "Something is wrong with unsigned incidence_1 (nodes to edges)." + ) assert ( expected_incidence_1 == lifted_data_signed.incidence_1.to_dense() ).all(), "Something is wrong with signed incidence_1 (nodes to edges)." @@ -77,8 +89,11 @@ def test_lift_topology(self): ) assert ( - abs(expected_incidence_2) == lifted_data_unsigned.incidence_2.to_dense() + abs(expected_incidence_2) + == lifted_data_unsigned.incidence_2.to_dense() ).all(), "Something is wrong with unsigned incidence_2 (edges to triangles)." assert ( expected_incidence_2 == lifted_data_signed.incidence_2.to_dense() - ).all(), "Something is wrong with signed incidence_2 (edges to triangles)." + ).all(), ( + "Something is wrong with signed incidence_2 (edges to triangles)." + ) diff --git a/topobenchmark/transforms/liftings/graph2simplicial/line_lifting.py b/topobenchmark/transforms/liftings/graph2simplicial/line_lifting.py index 7c178653..b350781a 100644 --- a/topobenchmark/transforms/liftings/graph2simplicial/line_lifting.py +++ b/topobenchmark/transforms/liftings/graph2simplicial/line_lifting.py @@ -1,73 +1,107 @@ +r"""This module implements the line lifting. + +This lifting constructs a simplicial complex called the *Line simplicial complex*. +This is a generalization of the so-called +`Line graph dict: - r"""Lifts the topology of a graph to simplicial domain via line simplicial complex construction. + def lift(self, domain): + r"""Lift the topology of a graph to a simplicial complex. Parameters ---------- - data : torch_geometric.data.Data - The input data to be lifted. + domain : nx.Graph + Graph to be lifted. Returns - ---------- - dict - The lifted topology. + ------- + toponetx.SimplicialComplex + Lifted simplicial complex. """ - - graph = self._generate_graph_from_data(data) + graph = domain line_graph = nx.line_graph(graph) node_features = { - node: ((data.x[node[0], :] + data.x[node[1], :]) / 2) + node: ( + ( + graph.nodes[node[0]]["features"] + + graph.nodes[node[1]]["features"] + ) + / 2 + ) for node in list(line_graph.nodes) } cliques = nx.find_cliques(line_graph) - simplices = list(cliques) # list(map(lambda x: set(x), cliques)) + simplices = list(cliques) # we need to rename simplices here since now vertices are named as pairs - self.rename_vertices_dict = {node: i for i, node in enumerate(line_graph.nodes)} - self.rename_vertices_dict_inverse = { - i: node for node, i in self.rename_vertices_dict.items() + rename_vertices_dict = { + node: i for i, node in enumerate(line_graph.nodes) } - renamed_line_graph = nx.relabel_nodes(line_graph, self.rename_vertices_dict) - renamed_simplices = [ - {self.rename_vertices_dict[vertex] for vertex in simplex} + {rename_vertices_dict[vertex] for vertex in simplex} for simplex in simplices ] renamed_node_features = { - self.rename_vertices_dict[node]: value + rename_vertices_dict[node]: value for node, value in node_features.items() } simplicial_complex = SimplicialComplex(simplices=renamed_simplices) - self.complex_dim = simplicial_complex.dim - simplicial_complex.set_simplex_attributes( renamed_node_features, name="features" ) - return self._get_lifted_topology(simplicial_complex, renamed_line_graph) + return simplicial_complex