Skip to content

Commit

Permalink
Update NeighborhoodLifting to work with new design
Browse files Browse the repository at this point in the history
  • Loading branch information
luisfpereira committed Jan 23, 2025
1 parent 8e8b90b commit 24da9d8
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 380 deletions.
9 changes: 5 additions & 4 deletions test/transforms/liftings/cell/test_NeighborhoodLifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@

import torch

from modules.data.utils.utils import load_manual_graph
from modules.transforms.liftings.graph2cell.neighborhood_lifting import (
from topobenchmark.data.utils.utils import load_manual_graph
from topobenchmark.transforms.liftings import (
Graph2CellLiftingTransform,
NeighborhoodLifting,
)


class TestCellCyclesLifting:
class TestNeighborhoodLifting:
"""Test the NeighborhoodLifting class."""

def setup_method(self):
# Load the graph
self.data = load_manual_graph()

# Initialise the NeighborhoodLifting class
self.lifting = NeighborhoodLifting()
self.lifting = Graph2CellLiftingTransform(NeighborhoodLifting())

def test_lift_topology(self):
# Test the lift_topology method
Expand Down
Original file line number Diff line number Diff line change
@@ -1,53 +1,62 @@
import torch_geometric
"""This module implements the neighborhood lifting for graphs to cell complexes.
Definition:
* 0-cells: Vertices of the graph.
* 1-cells: Edges of the graph.
* Higher-dimensional cells: Defined based on the neighborhoods of vertices.
A 2-cell is added for each vertex and its immediate neighbors.
Characteristics:
Star-like Structure: Star-like structures centered around a vertex and include all its adjacent vertices.
Flexibility: This approach can generate higher-dimensional cells even in graphs that do not have cycles.
Local Connectivity: The focus is on local connectivity rather than global cycles.
"""

from toponetx.classes import CellComplex

from modules.transforms.liftings.graph2cell.base import Graph2CellLifting
from topobenchmark.transforms.liftings.base import LiftingMap


class NeighborhoodLifting(Graph2CellLifting):
r"""Lifts graphs to cell complexes by identifying the cycles as 2-cells.
class NeighborhoodLifting(LiftingMap):
"""Lifts graphs to cell complexes by identifying the cycles as 2-cells.
Parameters
----------
max_cell_length : int, optional
The maximum length of the cycles to be lifted. Default is None.
**kwargs : optional
Additional arguments for the class.
"""

def __init__(self, max_cell_length=None, **kwargs):
super().__init__(**kwargs)
self.complex_dim = 2
def __init__(self, max_cell_length=None):
super().__init__()
self.max_cell_length = max_cell_length

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
r"""Finds the cycles of a graph and lifts them to 2-cells.
def lift(self, domain):
"""Finds the cycles of a graph and lifts them to 2-cells.
Parameters
----------
data : torch_geometric.data.Data
The input data to be lifted.
domain : nx.Graph
Graph to be lifted.
Returns
-------
dict
The lifted topology.
CellComplex
Lifted cell complex.
"""
graph = domain

G = self._generate_graph_from_data(data)

cell_complex = CellComplex(G)
cell_complex = CellComplex(graph)

vertices = list(G.nodes())
vertices = list(graph.nodes())
for v in vertices:
cell_complex.add_node(v, rank=0)

edges = list(G.edges())
edges = list(graph.edges())
for edge in edges:
cell_complex.add_cell(edge, rank=1)

for v in vertices:
neighbors = list(G.neighbors(v))
neighbors = list(graph.neighbors(v))
if len(neighbors) > 1:
two_cell = [v, *neighbors]
if (
Expand All @@ -58,4 +67,4 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
else:
cell_complex.add_cell(two_cell, rank=2)

return self._get_lifted_topology(cell_complex, G)
return cell_complex
354 changes: 0 additions & 354 deletions tutorials/tutorial_neighborhood_lifting.ipynb

This file was deleted.

0 comments on commit 24da9d8

Please sign in to comment.