Skip to content

Commit 7b54e06

Browse files
adamnsch and FlorentinD committed
Sample large graphs by default with from_gds
Co-Authored-By: Florentin Dörre <[email protected]>
1 parent f11d1be commit 7b54e06

File tree

2 files changed

+60
-17
lines changed

2 files changed

+60
-17
lines changed

python-wrapper/src/neo4j_viz/gds.py

Lines changed: 44 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -2,15 +2,17 @@
22

33
from itertools import chain
44
from typing import Optional
5+
from uuid import uuid4
56

67
import pandas as pd
78
from graphdatascience import Graph, GraphDataScience
9+
from pandas import Series
810

911
from .pandas import _from_dfs
1012
from .visualization_graph import VisualizationGraph
1113

1214

13-
def _node_dfs(
15+
def _fetch_node_dfs(
1416
gds: GraphDataScience, G: Graph, node_properties: list[str], node_labels: list[str]
1517
) -> dict[str, pd.DataFrame]:
1618
return {
@@ -21,17 +23,17 @@ def _node_dfs(
2123
}
2224

2325

24-
def _rel_df(gds: GraphDataScience, G: Graph) -> pd.DataFrame:
26+
def _fetch_rel_df(gds: GraphDataScience, G: Graph) -> pd.DataFrame:
2527
relationship_properties = G.relationship_properties()
28+
assert isinstance(relationship_properties, Series)
2629

27-
if len(relationship_properties) > 0:
28-
if isinstance(relationship_properties, pd.Series):
29-
relationship_properties_per_type = relationship_properties.tolist()
30-
property_set: set[str] = set()
31-
for props in relationship_properties_per_type:
32-
if props:
33-
property_set.update(props)
30+
relationship_properties_per_type = relationship_properties.tolist()
31+
property_set: set[str] = set()
32+
for props in relationship_properties_per_type:
33+
if props:
34+
property_set.update(props)
3435

36+
if len(property_set) > 0:
3537
return gds.graph.relationshipProperties.stream(
3638
G, relationship_properties=list(property_set), separate_property_columns=True
3739
)
@@ -45,6 +47,7 @@ def from_gds(
4547
size_property: Optional[str] = None,
4648
additional_node_properties: Optional[list[str]] = None,
4749
node_radius_min_max: Optional[tuple[float, float]] = (3, 60),
50+
max_node_count: int = 10_000,
4851
) -> VisualizationGraph:
4952
"""
5053
Create a VisualizationGraph from a GraphDataScience object and a Graph object.
@@ -68,6 +71,9 @@ def from_gds(
6871
node_radius_min_max : tuple[float, float], optional
6972
Minimum and maximum node radius, by default (3, 60).
7073
To avoid tiny or huge nodes in the visualization, the node sizes are scaled to fit in the given range.
74+
max_node_count : int, optional
75+
The maximum number of nodes to fetch from the graph. The graph will be sampled using random walk with restarts
76+
if its node count exceeds this number.
7177
"""
7278
node_properties_from_gds = G.node_properties()
7379
assert isinstance(node_properties_from_gds, pd.Series)
@@ -86,14 +92,40 @@ def from_gds(
8692
node_properties = set()
8793
if additional_node_properties is not None:
8894
node_properties.update(additional_node_properties)
89-
9095
if size_property is not None:
9196
node_properties.add(size_property)
92-
9397
node_properties = list(node_properties)
94-
node_dfs = _node_dfs(gds, G, node_properties, G.node_labels())
98+
99+
node_count = G.node_count()
100+
if node_count > max_node_count:
101+
sampling_ratio = float(max_node_count) / node_count
102+
sample_name = f"neo4j-viz_sample_{uuid4()}"
103+
G_fetched, _ = gds.graph.sample.rwr(sample_name, G, samplingRatio=sampling_ratio, nodeLabelStratification=True)
104+
else:
105+
G_fetched = G
106+
107+
property_name = None
108+
try:
109+
# Since GDS does not allow us to only fetch node IDs, we add the degree property
110+
# as a temporary property to ensure that we have at least one property to fetch
111+
if len(actual_node_properties) == 0:
112+
property_name = f"neo4j-viz_property_{uuid4()}"
113+
gds.degree.mutate(G_fetched, mutateProperty=property_name)
114+
node_properties = [property_name]
115+
116+
node_dfs = _fetch_node_dfs(gds, G_fetched, node_properties, G_fetched.node_labels())
117+
rel_df = _fetch_rel_df(gds, G_fetched)
118+
finally:
119+
if G_fetched.name() != G.name():
120+
G_fetched.drop()
121+
elif property_name is not None:
122+
gds.graph.nodeProperties.drop(G_fetched, node_properties=[property_name])
123+
95124
for df in node_dfs.values():
96125
df.rename(columns={"nodeId": "id"}, inplace=True)
126+
if property_name is not None and property_name in df.columns:
127+
df.drop(columns=[property_name], inplace=True)
128+
rel_df.rename(columns={"sourceNodeId": "source", "targetNodeId": "target"}, inplace=True)
97129

98130
node_props_df = pd.concat(node_dfs.values(), ignore_index=True, axis=0).drop_duplicates()
99131
if size_property is not None:
@@ -114,9 +146,6 @@ def from_gds(
114146
if "caption" not in actual_node_properties:
115147
node_df["caption"] = node_df["labels"].astype(str)
116148

117-
rel_df = _rel_df(gds, G)
118-
rel_df.rename(columns={"sourceNodeId": "source", "targetNodeId": "target"}, inplace=True)
119-
120149
try:
121150
return _from_dfs(node_df, rel_df, node_radius_min_max=node_radius_min_max, rename_properties={"__size": "size"})
122151
except ValueError as e:

python-wrapper/tests/test_gds.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -170,9 +170,10 @@ def test_from_gds_mocked(mocker: MockerFixture) -> None:
170170
lambda x: pd.Series({lbl: node_properties for lbl in nodes.keys()}),
171171
)
172172
mocker.patch("graphdatascience.Graph.node_labels", lambda x: list(nodes.keys()))
173+
mocker.patch("graphdatascience.Graph.node_count", lambda x: sum(len(df) for df in nodes.values()))
173174
mocker.patch("graphdatascience.GraphDataScience.__init__", lambda x: None)
174-
mocker.patch("neo4j_viz.gds._node_dfs", return_value=nodes)
175-
mocker.patch("neo4j_viz.gds._rel_df", return_value=rels)
175+
mocker.patch("neo4j_viz.gds._fetch_node_dfs", return_value=nodes)
176+
mocker.patch("neo4j_viz.gds._fetch_rel_df", return_value=rels)
176177

177178
gds = GraphDataScience() # type: ignore[call-arg]
178179
G = Graph() # type: ignore[call-arg]
@@ -244,3 +245,16 @@ def test_from_gds_node_errors(gds: Any) -> None:
244245
additional_node_properties=["component", "size"],
245246
node_radius_min_max=None,
246247
)
248+
249+
250+
@pytest.mark.requires_neo4j_and_gds
251+
def test_from_gds_sample(gds: Any) -> None:
252+
from neo4j_viz.gds import from_gds
253+
254+
with gds.graph.generate("hello", node_count=11_000, average_degree=1) as G:
255+
VG = from_gds(gds, G)
256+
257+
assert len(VG.nodes) >= 9_500
258+
assert len(VG.nodes) <= 10_500
259+
assert len(VG.relationships) >= 9_500
260+
assert len(VG.relationships) <= 10_500

0 commit comments

Comments
 (0)