1414
1515
def _fetch_node_dfs(
    gds: GraphDataScience,
    G: Graph,
    node_properties_by_label: dict[str, list[str]],
    node_labels: list[str],
) -> dict[str, pd.DataFrame]:
    """Stream node properties from GDS, one result per node label.

    For every label in ``node_labels`` a separate
    ``gds.graph.nodeProperties.stream`` call is issued, restricted to that
    single label and to the properties listed for it in
    ``node_properties_by_label``.

    Args:
        gds: Client object exposing ``graph.nodeProperties.stream``.
        G: The graph object to stream node properties from.
        node_properties_by_label: Maps each node label to the list of
            property names to request for nodes carrying that label.
        node_labels: The labels to fetch; one stream call is made per entry.

    Returns:
        A dict keyed by node label whose values are the results of the
        corresponding stream calls (requested with
        ``separate_property_columns=True``).

    Raises:
        KeyError: If a label in ``node_labels`` has no entry in
            ``node_properties_by_label``.
    """
    return {
        lbl: gds.graph.nodeProperties.stream(
            G,
            # Only request the properties registered for this label.
            node_properties=node_properties_by_label[lbl],
            node_labels=[lbl],
            separate_property_columns=True,
        )
        for lbl in node_labels
    }
@@ -79,24 +79,31 @@ def from_gds(
7979 """
8080 node_properties_from_gds = G .node_properties ()
8181 assert isinstance (node_properties_from_gds , pd .Series )
82- actual_node_properties = list (chain .from_iterable (node_properties_from_gds .to_dict ().values ()))
82+ actual_node_properties = node_properties_from_gds .to_dict ()
83+ all_actual_node_properties = list (chain .from_iterable (actual_node_properties .values ()))
8384
84- if size_property is not None and size_property not in actual_node_properties :
85- raise ValueError (f"There is no node property '{ size_property } ' in graph '{ G .name ()} '" )
85+ if size_property is not None :
86+ if size_property not in all_actual_node_properties :
87+ raise ValueError (f"There is no node property '{ size_property } ' in graph '{ G .name ()} '" )
8688
8789 if additional_node_properties is None :
88- additional_node_properties = actual_node_properties
90+ node_properties_by_label = { k : set ( v ) for k , v in actual_node_properties . items ()}
8991 else :
9092 for prop in additional_node_properties :
91- if prop not in actual_node_properties :
93+ if prop not in all_actual_node_properties :
9294 raise ValueError (f"There is no node property '{ prop } ' in graph '{ G .name ()} '" )
9395
94- node_properties = set ()
95- if additional_node_properties is not None :
96- node_properties .update (additional_node_properties )
96+ node_properties_by_label = {}
97+ for label , props in actual_node_properties .items ():
98+ node_properties_by_label [label ] = {
99+ prop for prop in actual_node_properties [label ] if prop in additional_node_properties
100+ }
101+
97102 if size_property is not None :
98- node_properties .add (size_property )
99- node_properties = list (node_properties )
103+ for label , props in node_properties_by_label .items ():
104+ props .add (size_property )
105+
106+ node_properties_by_label = {k : list (v ) for k , v in node_properties_by_label .items ()}
100107
101108 node_count = G .node_count ()
102109 if node_count > max_node_count :
@@ -112,13 +119,14 @@ def from_gds(
112119 property_name = None
113120 try :
114121 # Since GDS does not allow us to only fetch node IDs, we add the degree property
115- # as a temporary property to ensure that we have at least one property to fetch
116- if len (actual_node_properties ) == 0 :
122+ # as a temporary property to ensure that we have at least one property for each label to fetch
123+ if sum ([ len (props ) == 0 for props in node_properties_by_label . values ()]) > 0 :
117124 property_name = f"neo4j-viz_property_{ uuid4 ()} "
118125 gds .degree .mutate (G_fetched , mutateProperty = property_name )
119- node_properties = [property_name ]
126+ for props in node_properties_by_label .values ():
127+ props .append (property_name )
120128
121- node_dfs = _fetch_node_dfs (gds , G_fetched , node_properties , G_fetched .node_labels ())
129+ node_dfs = _fetch_node_dfs (gds , G_fetched , node_properties_by_label , G_fetched .node_labels ())
122130 if property_name is not None :
123131 for df in node_dfs .values ():
124132 df .drop (columns = [property_name ], inplace = True )
@@ -131,35 +139,35 @@ def from_gds(
131139 gds .graph .nodeProperties .drop (G_fetched , node_properties = [property_name ])
132140
133141 for df in node_dfs .values ():
134- df .rename (columns = {"nodeId" : "id" }, inplace = True )
135142 if property_name is not None and property_name in df .columns :
136143 df .drop (columns = [property_name ], inplace = True )
137- rel_df .rename (columns = {"sourceNodeId" : "source" , "targetNodeId" : "target" }, inplace = True )
138144
139145 node_props_df = pd .concat (node_dfs .values (), ignore_index = True , axis = 0 ).drop_duplicates ()
140146 if size_property is not None :
141- if "size" in actual_node_properties and size_property != "size" :
147+ if "size" in all_actual_node_properties and size_property != "size" :
142148 node_props_df .rename (columns = {"size" : "__size" }, inplace = True )
143149 node_props_df .rename (columns = {size_property : "size" }, inplace = True )
144150
145151 for lbl , df in node_dfs .items ():
146- if "labels" in actual_node_properties :
152+ if "labels" in all_actual_node_properties :
147153 df .rename (columns = {"labels" : "__labels" }, inplace = True )
148154 df ["labels" ] = lbl
149155
150- node_labels_df = pd .concat ([df [["id " , "labels" ]] for df in node_dfs .values ()], ignore_index = True , axis = 0 )
151- node_labels_df = node_labels_df .groupby ("id " ).agg ({"labels" : list })
156+ node_labels_df = pd .concat ([df [["nodeId " , "labels" ]] for df in node_dfs .values ()], ignore_index = True , axis = 0 )
157+ node_labels_df = node_labels_df .groupby ("nodeId " ).agg ({"labels" : list })
152158
153- node_df = node_props_df .merge (node_labels_df , on = "id " )
159+ node_df = node_props_df .merge (node_labels_df , on = "nodeId " )
154160
155- if "caption" not in actual_node_properties :
161+ if "caption" not in all_actual_node_properties :
156162 node_df ["caption" ] = node_df ["labels" ].astype (str )
157163
158164 if "caption" not in rel_df .columns :
159165 rel_df ["caption" ] = rel_df ["relationshipType" ]
160166
161167 try :
162- return _from_dfs (node_df , rel_df , node_radius_min_max = node_radius_min_max , rename_properties = {"__size" : "size" })
168+ return _from_dfs (
169+ node_df , rel_df , node_radius_min_max = node_radius_min_max , rename_properties = {"__size" : "size" }, dropna = True
170+ )
163171 except ValueError as e :
164172 err_msg = str (e )
165173 if "column" in err_msg :
0 commit comments