Skip to content

Commit c1d7eae

Browse files
committed
refector show graph, allow for explicit calling of formats, add interactive widget for jupyter
Signed-off-by: Tim Paine <[email protected]>
1 parent 964c77e commit c1d7eae

File tree

8 files changed

+167
-55
lines changed

8 files changed

+167
-55
lines changed

conda/dev-environment-unix.yml

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies:
1818
- gtest
1919
- httpx>=0.20,<1
2020
- isort>=5,<6
21+
- ipydagred3
2122
- libarrow=15
2223
- librdkafka
2324
- libboost-headers

csp/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828
from csp.impl.wiring.context import clear_global_context, new_global_context
2929
from csp.math import *
30-
from csp.showgraph import show_graph
30+
from csp.showgraph import *
3131

3232
from . import stats
3333

csp/adapters/symphony.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _client_cert_post(host: str, request_url: str, cert_file: str, key_file: str
4747
response = connection.getresponse()
4848

4949
if response.status != 200:
50-
raise Exception(
50+
raise RuntimeError(
5151
f"Cannot connect for symphony handshake to https://{host}{request_url}: {response.status}:{response.reason}"
5252
)
5353
data = response.read().decode("utf-8")
@@ -275,7 +275,7 @@ def start(self, starttime, endtime):
275275
# get symphony session
276276
resp, datafeed_id = _sync_create_data_feed(self._datafeed_create_url, self._header)
277277
if resp.status_code not in (200, 201, 204):
278-
raise Exception(
278+
raise RuntimeError(
279279
f"ERROR: bad status ({resp.status_code}) from _sync_create_data_feed, cannot start Symphony reader"
280280
)
281281
else:
@@ -285,7 +285,7 @@ def start(self, starttime, endtime):
285285
for room in self._rooms:
286286
room_id = self._room_mapper.get_room_id(room)
287287
if not room_id:
288-
raise Exception(f"ERROR: unable to find Symphony room named {room}")
288+
raise RuntimeError(f"ERROR: unable to find Symphony room named {room}")
289289
self._room_ids.add(room_id)
290290

291291
# start reader thread

csp/dataframe.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import csp.baselib
55
from csp.impl.wiring.edge import Edge
6+
from csp.showgraph import show_graph
67

78
# Lazy declaration below to avoid perspective import
89
RealtimePerspectiveWidget = None
@@ -143,12 +144,7 @@ def _eval(self, starttime: datetime, endtime: datetime = None, realtime: bool =
143144
return csp.run(self._eval_graph, starttime=starttime, endtime=endtime, realtime=realtime)
144145

145146
def show_graph(self):
146-
from PIL import Image
147-
148-
import csp.showgraph
149-
150-
buffer = csp.showgraph.generate_graph(self._eval_graph)
151-
return Image.open(buffer)
147+
show_graph(self._eval_graph, graph_filename=None)
152148

153149
def to_pandas(self, starttime: datetime, endtime: datetime):
154150
import pandas
@@ -222,7 +218,9 @@ def join(self):
222218
self._runner.join()
223219

224220
except ImportError:
225-
raise ImportError("eval_perspective requires perspective-python installed")
221+
raise ModuleNotFoundError(
222+
"eval_perspective requires perspective-python installed. See https://perspective.finos.org for installation instructions."
223+
)
226224

227225
if not realtime:
228226
df = self.to_pandas(starttime, endtime)

csp/impl/pandas_accessor.py

+3-12
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from csp.impl.pandas_ext_type import TsDtype, is_csp_type
1111
from csp.impl.struct import defineNestedStruct
1212
from csp.impl.wiring.edge import Edge
13+
from csp.showgraph import show_graph
1314

1415
T = TypeVar("T")
1516

@@ -375,12 +376,7 @@ def show_graph(self):
375376
"""Show the graph corresponding to the evaluation of all the edges.
376377
For large series, this may be very large, so it may be helpful to call .head() first.
377378
"""
378-
from PIL import Image
379-
380-
import csp.showgraph
381-
382-
buffer = csp.showgraph.generate_graph(self._eval_graph, "png")
383-
return Image.open(buffer)
379+
return show_graph(self._eval_graph, graph_filename=None)
384380

385381

386382
@register_series_accessor("to_csp")
@@ -626,12 +622,7 @@ def show_graph(self):
626622
"""Show the graph corresponding to the evaluation of all the edges.
627623
For large series, this may be very large, so it may be helpful to call .head() first.
628624
"""
629-
from PIL import Image
630-
631-
import csp.showgraph
632-
633-
buffer = csp.showgraph.generate_graph(self._eval_graph, "png")
634-
return Image.open(buffer)
625+
show_graph(self._eval_graph, graph_filename=None)
635626

636627

637628
@register_dataframe_accessor("to_csp")

csp/profiler.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ def initialize(self, adapter: GenericPushAdapter, display_graphs: bool):
382382
try:
383383
import matplotlib # noqa: F401
384384
except ImportError:
385-
raise Exception("You must have matplotlib installed to display profiling data graphs.")
385+
raise ModuleNotFoundError("You must have matplotlib installed to display profiling data graphs.")
386386

387387
def get(self):
388388
try:
@@ -478,7 +478,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
478478
def init_profiler(self):
479479
if self.http_port is not None:
480480
if not HAS_TORNADO:
481-
raise Exception("You must have tornado installed to use the HTTP profiling extension.")
481+
raise ModuleNotFoundError("You must have tornado installed to use the HTTP profiling extension.")
482482

483483
adapter = GenericPushAdapter(Future)
484484
application = tornado.web.Application(

csp/showgraph.py

+151-29
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,54 @@
11
from collections import deque, namedtuple
22
from io import BytesIO
3+
from typing import Dict, Literal
34

45
from csp.impl.wiring.runtime import build_graph
56

6-
NODE = namedtuple("NODE", ["name", "label", "color", "shape"])
7-
EDGE = namedtuple("EDGE", ["start", "end"])
7+
_KIND = Literal["output", "input", ""]
8+
_NODE = namedtuple("NODE", ["name", "label", "kind"])
9+
_EDGE = namedtuple("EDGE", ["start", "end"])
810

11+
_GRAPHVIZ_COLORMAP: Dict[_KIND, str] = {"output": "red", "input": "cadetblue1", "": "white"}
912

10-
def _build_graphviz_graph(graph_func, *args, **kwargs):
11-
from graphviz import Digraph
13+
_GRAPHVIZ_SHAPEMAP: Dict[_KIND, str] = {"output": "rarrow", "input": "rarrow", "": "box"}
14+
15+
_DAGRED3_COLORMAP: Dict[_KIND, str] = {
16+
"output": "red",
17+
"input": "#98f5ff",
18+
"": "lightgrey",
19+
}
20+
_DAGRED3_SHAPEMAP: Dict[_KIND, str] = {"output": "diamond", "input": "diamond", "": "rect"}
21+
22+
_NOTEBOOK_KIND = Literal["", "terminal", "notebook"]
23+
24+
__all__ = (
25+
"generate_graph",
26+
"show_graph_pil",
27+
"show_graph_graphviz",
28+
"show_graph_widget",
29+
"show_graph",
30+
)
31+
32+
33+
def _notebook_kind() -> _NOTEBOOK_KIND:
34+
try:
35+
from IPython import get_ipython
1236

37+
shell = get_ipython().__class__.__name__
38+
if shell == "ZMQInteractiveShell":
39+
return "notebook"
40+
elif shell == "TerminalInteractiveShell":
41+
return "terminal"
42+
else:
43+
return ""
44+
except ImportError:
45+
return ""
46+
except NameError:
47+
return ""
48+
49+
50+
def _build_graph_for_viz(graph_func, *args, **kwargs):
1351
graph = build_graph(graph_func, *args, **kwargs)
14-
digraph = Digraph(strict=True)
15-
digraph.attr(rankdir="LR", size="150,150")
1652

1753
rootnames = set()
1854
q = deque()
@@ -29,56 +65,142 @@ def _build_graphviz_graph(graph_func, *args, **kwargs):
2965
name = str(id(nodedef))
3066
visited.add(nodedef)
3167
if name in rootnames: # output node
32-
color = "red"
33-
shape = "rarrow"
68+
kind = "output"
3469
elif not sum(1 for _ in nodedef.ts_inputs()): # input node
35-
color = "cadetblue1"
36-
shape = "rarrow"
70+
kind = "input"
3771
else:
38-
color = "white"
39-
shape = "box"
72+
kind = ""
4073

4174
label = nodedef.__name__ if hasattr(nodedef, "__name__") else type(nodedef).__name__
42-
nodes.append(NODE(name=name, label=label, color=color, shape=shape))
75+
nodes.append(_NODE(name=name, label=label, kind=kind))
4376

4477
for input in nodedef.ts_inputs():
4578
if input[1].nodedef not in visited:
4679
q.append(input[1].nodedef)
47-
edges.append(EDGE(start=str(id(input[1].nodedef)), end=name))
80+
edges.append(_EDGE(start=str(id(input[1].nodedef)), end=name))
81+
return nodes, edges
82+
83+
84+
def _build_graphviz_graph(graph_func, *args, **kwargs):
85+
from graphviz import Digraph
86+
87+
nodes, edges = _build_graph_for_viz(graph_func, *args, **kwargs)
88+
89+
digraph = Digraph(strict=True)
90+
digraph.attr(rankdir="LR", size="150,150")
4891

4992
for node in nodes:
5093
digraph.node(
5194
node.name,
5295
node.label,
5396
style="filled",
54-
fillcolor=node.color,
55-
shape=node.shape,
97+
fillcolor=_GRAPHVIZ_COLORMAP[node.kind],
98+
shape=_GRAPHVIZ_SHAPEMAP[node.kind],
5699
)
57100
for edge in edges:
58101
digraph.edge(edge.start, edge.end)
59102

60103
return digraph
61104

62105

106+
def _graphviz_to_buffer(digraph, image_format="png") -> BytesIO:
107+
from graphviz import ExecutableNotFound
108+
109+
digraph.format = image_format
110+
buffer = BytesIO()
111+
112+
try:
113+
buffer.write(digraph.pipe())
114+
buffer.seek(0)
115+
return buffer
116+
except ExecutableNotFound as exc:
117+
raise ModuleNotFoundError(
118+
"Must install graphviz and have `dot` available on your PATH. See https://graphviz.org for installation instructions"
119+
) from exc
120+
121+
63122
def generate_graph(graph_func, *args, image_format="png", **kwargs):
64123
"""Generate a BytesIO image representation of the given graph"""
65124
digraph = _build_graphviz_graph(graph_func, *args, **kwargs)
66-
digraph.format = image_format
67-
buffer = BytesIO()
68-
buffer.write(digraph.pipe())
69-
buffer.seek(0)
70-
return buffer
125+
return _graphviz_to_buffer(digraph=digraph, image_format=image_format)
71126

72127

73-
def show_graph(graph_func, *args, graph_filename=None, **kwargs):
128+
def show_graph_pil(graph_func, *args, **kwargs):
129+
buffer = generate_graph(graph_func, *args, image_format="png", **kwargs)
130+
try:
131+
from PIL import Image
132+
except ImportError:
133+
raise ModuleNotFoundError(
134+
"csp requires `pillow` to display images. Install `pillow` with your python package manager, or pass `graph_filename` to generate a file output."
135+
)
136+
image = Image.open(buffer)
137+
image.show()
138+
139+
140+
def show_graph_graphviz(graph_func, *args, graph_filename=None, interactive=False, **kwargs):
141+
# extract the format of the image
74142
image_format = graph_filename.split(".")[-1] if graph_filename else "png"
75-
buffer = generate_graph(graph_func, *args, image_format=image_format, **kwargs)
76143

77-
if graph_filename:
78-
with open(graph_filename, "wb") as f:
79-
f.write(buffer.read())
144+
# Generate graph with graphviz
145+
digraph = _build_graphviz_graph(graph_func, *args, **kwargs)
146+
147+
# if we're in a notebook, return it directly for rendering
148+
if interactive:
149+
return digraph
150+
151+
# otherwise output to file
152+
buffer = _graphviz_to_buffer(digraph=digraph, image_format=image_format)
153+
with open(graph_filename, "wb") as f:
154+
f.write(buffer.read())
155+
return digraph
156+
157+
158+
def show_graph_widget(graph_func, *args, **kwargs):
159+
try:
160+
import ipydagred3
161+
except ImportError:
162+
raise ModuleNotFoundError(
163+
"csp requires `ipydagred3` to display graph widget. Install `ipydagred3` with your python package manager, or pass `graph_filename` to generate a file output."
164+
)
165+
166+
nodes, edges = _build_graph_for_viz(graph_func=graph_func, *args, **kwargs)
167+
168+
graph = ipydagred3.Graph(directed=True, attrs=dict(rankdir="LR"))
169+
170+
for node in nodes:
171+
graph.addNode(
172+
ipydagred3.Node(
173+
name=node.name,
174+
label=node.label,
175+
shape=_DAGRED3_SHAPEMAP[node.kind],
176+
style=f"fill: {_DAGRED3_COLORMAP[node.kind]}",
177+
)
178+
)
179+
for edge in edges:
180+
graph.addEdge(edge.start, edge.end)
181+
return ipydagred3.DagreD3Widget(graph=graph)
182+
183+
184+
def show_graph(graph_func, *args, graph_filename=None, **kwargs):
185+
# check if we're in jupyter
186+
if _notebook_kind() == "notebook":
187+
_HAVE_INTERACTIVE = True
80188
else:
81-
from PIL import Image
189+
_HAVE_INTERACTIVE = False
190+
191+
# display graph via pillow or ipydagred3
192+
if graph_filename in (None, "widget"):
193+
if graph_filename == "widget" and not _HAVE_INTERACTIVE:
194+
raise RuntimeError("Interactive graph viewer only works in Jupyter.")
195+
196+
# render with ipydagred3
197+
if graph_filename == "widget":
198+
return show_graph_widget(graph_func, *args, **kwargs)
199+
200+
# render with pillow
201+
return show_graph_pil(graph_func, *args, **kwargs)
82202

83-
image = Image.open(buffer)
84-
image.show()
203+
# TODO we can show graphviz in jupyter without a filename, but preserving existing behavior for now
204+
return show_graph_graphviz(
205+
graph_func, *args, graph_filename=graph_filename, interactive=_HAVE_INTERACTIVE, **kwargs
206+
)

vcpkg

Submodule vcpkg updated 1242 files

0 commit comments

Comments
 (0)