Skip to content

Commit

Permalink
fix(sim/graph): improve networkx integration
Browse files Browse the repository at this point in the history
  • Loading branch information
mirkolenz committed Nov 7, 2024
1 parent b8e2289 commit ead9c71
Showing 1 changed file with 27 additions and 30 deletions.
57 changes: 27 additions & 30 deletions cbrkit/sim/graphs/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,7 @@ def from_rustworkx[N, E](g: "rustworkx.PyDiGraph[N, E]") -> Graph[int, N, E, Any
)
edges = immutables.Map(
(edge_id, Edge(edge_id, nodes[source_id], nodes[target_id], edge_data))
for edge_id, (
source_id,
target_id,
edge_data,
) in g.edge_index_map().items()
for edge_id, (source_id, target_id, edge_data) in g.edge_index_map().items()
)

return Graph(nodes, edges, g.attrs)
Expand All @@ -164,43 +160,44 @@ def from_rustworkx[N, E](g: "rustworkx.PyDiGraph[N, E]") -> Graph[int, N, E, Any
try:
import networkx as nx

def to_networkx[N, E](g: Graph[Any, N, E, Any]) -> "nx.DiGraph":
def to_networkx(g: Graph) -> nx.DiGraph:
ng = nx.DiGraph()
# Set graph attributes
ng.graph.update(g.data)
ng.graph = g.data

# Add nodes with their data
for node in g.nodes.values():
ng.add_node(node.key, data=node.data)
ng.add_nodes_from(
(
node.key,
(node.data if isinstance(node.data, Mapping) else {"data": node.data}),
)
for node in g.nodes
)

# Add edges with their data
for edge in g.edges.values():
ng.add_edge(edge.source.key, edge.target.key, key=edge.key, data=edge.data)
ng.add_edges_from(
(
edge.source.key,
edge.target.key,
(
{**edge.data, "key": edge.key}
if isinstance(edge.data, Mapping)
else {"data": edge.data, "key": edge.key}
),
)
for edge in g.edges.values()
)

return ng

def from_networkx[N, E](g: "nx.DiGraph") -> Graph[Any, N, E, Any]:
# Create nodes
def from_networkx(g: nx.DiGraph) -> Graph:
nodes = immutables.Map(
(node_id, Node(node_id, g.nodes[node_id].get("data")))
for node_id in g.nodes
(idx, Node(idx, data)) for idx, data in g.nodes(data=True)
)

# Create edges
edges = immutables.Map(
(
edge_data.get("key", idx),
Edge(
edge_data.get("key", idx),
nodes[source],
nodes[target],
edge_data.get("data"),
),
)
for idx, (source, target, edge_data) in enumerate(g.edges(data=True))
(idx, Edge(idx, nodes[source_id], nodes[target_id], edge_data))
for idx, (source_id, target_id, edge_data) in enumerate(g.edges(data=True))
)

return Graph(nodes, edges, dict(g.graph))
return Graph(nodes, edges, g.graph)

__all__ += ["to_networkx", "from_networkx"]

Expand Down

0 comments on commit ead9c71

Please sign in to comment.