Skip to content

Commit 30aa9f9

Browse files
committed
Allow overriding caption using dedicated column
1 parent 9e906c5 commit 30aa9f9

File tree

2 files changed

+78
-8
lines changed

2 files changed

+78
-8
lines changed

python-wrapper/src/neo4j_viz/snowflake.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,8 @@ def from_snowflake(
318318
319319
By default:
320320
321-
* The caption of the nodes will be set to the table name.
322-
* The caption of the relationships will be set to the table name.
321+
* The caption of the nodes will be read from a 'CAPTION' column or set to the table name if the colum does not exist.
322+
* The caption of the relationships will be read from a 'CAPTION' column or set to the table name if the column does not exist.
323323
* The color of the nodes will be set based on the caption, unless there are more than 12 node tables used.
324324
325325
Otherwise, columns will be included as properties on the nodes and relationships.
@@ -334,9 +334,9 @@ def from_snowflake(
334334
node_dfs, rel_dfs, rel_table_names = _map_tables(session, project_model)
335335

336336
for i, node_df in enumerate(node_dfs):
337-
node_df["table"] = project_model.nodeTables[i].split(".")[-1]
337+
node_df["table"] = node_df.get("CAPTION", project_model.nodeTables[i].split(".")[-1])
338338
for i, rel_df in enumerate(rel_dfs):
339-
rel_df["table"] = rel_table_names[i].split(".")[-1]
339+
rel_df["table"] = rel_df.get("CAPTION", rel_table_names[i].split(".")[-1])
340340

341341
VG = from_dfs(node_dfs, rel_dfs)
342342

python-wrapper/tests/test_snowflake.py

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import pytest
22
from snowflake.snowpark import Session
3-
from snowflake.snowpark.types import LongType, StructField, StructType
3+
from snowflake.snowpark.types import LongType, StringType, StructField, StructType
44

5+
from neo4j_viz import Relationship
56
from neo4j_viz.node import Node
67
from neo4j_viz.snowflake import from_snowflake
78

@@ -60,13 +61,82 @@ def test_from_snowflake(session_with_minimal_graph: Session) -> None:
6061
)
6162

6263
assert VG.nodes == [
63-
Node(id=0, caption="NODES", color="#ffdf81", properties={"SNOWFLAKEID": 6, "table": "NODES"}),
64-
Node(id=1, caption="NODES", color="#ffdf81", properties={"SNOWFLAKEID": 7, "table": "NODES"}),
64+
Node(id=0, caption="NODES", color="#ffdf81", properties={"SNOWFLAKEID": 6}),
65+
Node(id=1, caption="NODES", color="#ffdf81", properties={"SNOWFLAKEID": 7}),
6566
]
6667

6768
assert len(VG.relationships) == 1
6869

6970
assert VG.relationships[0].source == 0
7071
assert VG.relationships[0].target == 1
7172
assert VG.relationships[0].caption == "RELS"
72-
assert VG.relationships[0].properties == {"table": "RELS"}
73+
assert VG.relationships[0].properties == {}
74+
75+
76+
@pytest.fixture
77+
def session_with_custom_caption_graph(session: Session) -> Session:
78+
"""
79+
Create a minimal graph with two nodes and one relationship.
80+
"""
81+
node_df = session.create_dataframe(
82+
data=[
83+
[6, "Alice"],
84+
[7, "Bob"],
85+
[8, "Charlie"],
86+
],
87+
schema=StructType(
88+
[
89+
StructField("NODEID", LongType()),
90+
StructField("CAPTION", StringType()),
91+
]
92+
),
93+
)
94+
node_df.write.save_as_table("NODES")
95+
96+
rel_df = session.create_dataframe(
97+
data=[
98+
[6, 7, "KNOWS"],
99+
[6, 8, "LIKES"],
100+
],
101+
schema=StructType(
102+
[
103+
StructField("SOURCENODEID", LongType()),
104+
StructField("TARGETNODEID", LongType()),
105+
StructField("CAPTION", StringType()),
106+
]
107+
),
108+
)
109+
rel_df.write.save_as_table("RELS")
110+
111+
return session
112+
113+
114+
def test_from_snowflake_with_caption(session_with_custom_caption_graph: Session) -> None:
115+
VG = from_snowflake(
116+
session_with_custom_caption_graph,
117+
{
118+
"nodeTables": ["NODES"],
119+
"relationshipTables": {
120+
"RELS": {
121+
"sourceTable": "NODES",
122+
"targetTable": "NODES",
123+
},
124+
},
125+
},
126+
)
127+
128+
assert VG.nodes == [
129+
Node(id=0, caption="Alice", color="#ffdf81", properties={"CAPTION": "Alice", "SNOWFLAKEID": 6}),
130+
Node(id=1, caption="Bob", color="#c990c0", properties={"CAPTION": "Bob", "SNOWFLAKEID": 7}),
131+
Node(id=2, caption="Charlie", color="#f79767", properties={"CAPTION": "Charlie", "SNOWFLAKEID": 8}),
132+
]
133+
134+
assert len(VG.relationships) == 2
135+
# ignore uuid in comparison
136+
VG.relationships[0].id = 0
137+
VG.relationships[1].id = 1
138+
139+
assert VG.relationships == [
140+
Relationship(id=0, source=0, target=1, caption="KNOWS", properties={"CAPTION": "KNOWS"}),
141+
Relationship(id=1, source=0, target=2, caption="LIKES", properties={"CAPTION": "LIKES"}),
142+
]

0 commit comments

Comments
 (0)