1
1
from collections import deque , namedtuple
2
2
from io import BytesIO
3
+ from typing import Dict , Literal
3
4
4
5
from csp .impl .wiring .runtime import build_graph
5
6
6
- NODE = namedtuple ("NODE" , ["name" , "label" , "color" , "shape" ])
7
- EDGE = namedtuple ("EDGE" , ["start" , "end" ])
7
+ _KIND = Literal ["output" , "input" , "" ]
8
+ _NODE = namedtuple ("NODE" , ["name" , "label" , "kind" ])
9
+ _EDGE = namedtuple ("EDGE" , ["start" , "end" ])
8
10
11
+ _GRAPHVIZ_COLORMAP : Dict [_KIND , str ] = {"output" : "red" , "input" : "cadetblue1" , "" : "white" }
9
12
10
- def _build_graphviz_graph (graph_func , * args , ** kwargs ):
11
- from graphviz import Digraph
13
+ _GRAPHVIZ_SHAPEMAP : Dict [_KIND , str ] = {"output" : "rarrow" , "input" : "rarrow" , "" : "box" }
14
+
15
+ _DAGRED3_COLORMAP : Dict [_KIND , str ] = {
16
+ "output" : "red" ,
17
+ "input" : "#98f5ff" ,
18
+ "" : "lightgrey" ,
19
+ }
20
+ _DAGRED3_SHAPEMAP : Dict [_KIND , str ] = {"output" : "diamond" , "input" : "diamond" , "" : "rect" }
21
+
22
+ _NOTEBOOK_KIND = Literal ["" , "terminal" , "notebook" ]
23
+
24
+ __all__ = (
25
+ "generate_graph" ,
26
+ "show_graph_pil" ,
27
+ "show_graph_graphviz" ,
28
+ "show_graph_widget" ,
29
+ "show_graph" ,
30
+ )
31
+
32
+
33
+ def _notebook_kind () -> _NOTEBOOK_KIND :
34
+ try :
35
+ from IPython import get_ipython
12
36
37
+ shell = get_ipython ().__class__ .__name__
38
+ if shell == "ZMQInteractiveShell" :
39
+ return "notebook"
40
+ elif shell == "TerminalInteractiveShell" :
41
+ return "terminal"
42
+ else :
43
+ return ""
44
+ except ImportError :
45
+ return ""
46
+ except NameError :
47
+ return ""
48
+
49
+
50
+ def _build_graph_for_viz (graph_func , * args , ** kwargs ):
13
51
graph = build_graph (graph_func , * args , ** kwargs )
14
- digraph = Digraph (strict = True )
15
- digraph .attr (rankdir = "LR" , size = "150,150" )
16
52
17
53
rootnames = set ()
18
54
q = deque ()
@@ -29,56 +65,142 @@ def _build_graphviz_graph(graph_func, *args, **kwargs):
29
65
name = str (id (nodedef ))
30
66
visited .add (nodedef )
31
67
if name in rootnames : # output node
32
- color = "red"
33
- shape = "rarrow"
68
+ kind = "output"
34
69
elif not sum (1 for _ in nodedef .ts_inputs ()): # input node
35
- color = "cadetblue1"
36
- shape = "rarrow"
70
+ kind = "input"
37
71
else :
38
- color = "white"
39
- shape = "box"
72
+ kind = ""
40
73
41
74
label = nodedef .__name__ if hasattr (nodedef , "__name__" ) else type (nodedef ).__name__
42
- nodes .append (NODE (name = name , label = label , color = color , shape = shape ))
75
+ nodes .append (_NODE (name = name , label = label , kind = kind ))
43
76
44
77
for input in nodedef .ts_inputs ():
45
78
if input [1 ].nodedef not in visited :
46
79
q .append (input [1 ].nodedef )
47
- edges .append (EDGE (start = str (id (input [1 ].nodedef )), end = name ))
80
+ edges .append (_EDGE (start = str (id (input [1 ].nodedef )), end = name ))
81
+ return nodes , edges
82
+
83
+
84
+ def _build_graphviz_graph (graph_func , * args , ** kwargs ):
85
+ from graphviz import Digraph
86
+
87
+ nodes , edges = _build_graph_for_viz (graph_func , * args , ** kwargs )
88
+
89
+ digraph = Digraph (strict = True )
90
+ digraph .attr (rankdir = "LR" , size = "150,150" )
48
91
49
92
for node in nodes :
50
93
digraph .node (
51
94
node .name ,
52
95
node .label ,
53
96
style = "filled" ,
54
- fillcolor = node .color ,
55
- shape = node .shape ,
97
+ fillcolor = _GRAPHVIZ_COLORMAP [ node .kind ] ,
98
+ shape = _GRAPHVIZ_SHAPEMAP [ node .kind ] ,
56
99
)
57
100
for edge in edges :
58
101
digraph .edge (edge .start , edge .end )
59
102
60
103
return digraph
61
104
62
105
106
+ def _graphviz_to_buffer (digraph , image_format = "png" ) -> BytesIO :
107
+ from graphviz import ExecutableNotFound
108
+
109
+ digraph .format = image_format
110
+ buffer = BytesIO ()
111
+
112
+ try :
113
+ buffer .write (digraph .pipe ())
114
+ buffer .seek (0 )
115
+ return buffer
116
+ except ExecutableNotFound as exc :
117
+ raise ModuleNotFoundError (
118
+ "Must install graphviz and have `dot` available on your PATH. See https://graphviz.org for installation instructions"
119
+ ) from exc
120
+
121
+
63
122
def generate_graph (graph_func , * args , image_format = "png" , ** kwargs ):
64
123
"""Generate a BytesIO image representation of the given graph"""
65
124
digraph = _build_graphviz_graph (graph_func , * args , ** kwargs )
66
- digraph .format = image_format
67
- buffer = BytesIO ()
68
- buffer .write (digraph .pipe ())
69
- buffer .seek (0 )
70
- return buffer
125
+ return _graphviz_to_buffer (digraph = digraph , image_format = image_format )
71
126
72
127
73
- def show_graph (graph_func , * args , graph_filename = None , ** kwargs ):
128
+ def show_graph_pil (graph_func , * args , ** kwargs ):
129
+ buffer = generate_graph (graph_func , * args , image_format = "png" , ** kwargs )
130
+ try :
131
+ from PIL import Image
132
+ except ImportError :
133
+ raise ModuleNotFoundError (
134
+ "csp requires `pillow` to display images. Install `pillow` with your python package manager, or pass `graph_filename` to generate a file output."
135
+ )
136
+ image = Image .open (buffer )
137
+ image .show ()
138
+
139
+
140
+ def show_graph_graphviz (graph_func , * args , graph_filename = None , interactive = False , ** kwargs ):
141
+ # extract the format of the image
74
142
image_format = graph_filename .split ("." )[- 1 ] if graph_filename else "png"
75
- buffer = generate_graph (graph_func , * args , image_format = image_format , ** kwargs )
76
143
77
- if graph_filename :
78
- with open (graph_filename , "wb" ) as f :
79
- f .write (buffer .read ())
144
+ # Generate graph with graphviz
145
+ digraph = _build_graphviz_graph (graph_func , * args , ** kwargs )
146
+
147
+ # if we're in a notebook, return it directly for rendering
148
+ if interactive :
149
+ return digraph
150
+
151
+ # otherwise output to file
152
+ buffer = _graphviz_to_buffer (digraph = digraph , image_format = image_format )
153
+ with open (graph_filename , "wb" ) as f :
154
+ f .write (buffer .read ())
155
+ return digraph
156
+
157
+
158
+ def show_graph_widget (graph_func , * args , ** kwargs ):
159
+ try :
160
+ import ipydagred3
161
+ except ImportError :
162
+ raise ModuleNotFoundError (
163
+ "csp requires `ipydagred3` to display graph widget. Install `ipydagred3` with your python package manager, or pass `graph_filename` to generate a file output."
164
+ )
165
+
166
+ nodes , edges = _build_graph_for_viz (graph_func = graph_func , * args , ** kwargs )
167
+
168
+ graph = ipydagred3 .Graph (directed = True , attrs = dict (rankdir = "LR" ))
169
+
170
+ for node in nodes :
171
+ graph .addNode (
172
+ ipydagred3 .Node (
173
+ name = node .name ,
174
+ label = node .label ,
175
+ shape = _DAGRED3_SHAPEMAP [node .kind ],
176
+ style = f"fill: { _DAGRED3_COLORMAP [node .kind ]} " ,
177
+ )
178
+ )
179
+ for edge in edges :
180
+ graph .addEdge (edge .start , edge .end )
181
+ return ipydagred3 .DagreD3Widget (graph = graph )
182
+
183
+
184
+ def show_graph (graph_func , * args , graph_filename = None , ** kwargs ):
185
+ # check if we're in jupyter
186
+ if _notebook_kind () == "notebook" :
187
+ _HAVE_INTERACTIVE = True
80
188
else :
81
- from PIL import Image
189
+ _HAVE_INTERACTIVE = False
190
+
191
+ # display graph via pillow or ipydagred3
192
+ if graph_filename in (None , "widget" ):
193
+ if graph_filename == "widget" and not _HAVE_INTERACTIVE :
194
+ raise RuntimeError ("Interactive graph viewer only works in Jupyter." )
195
+
196
+ # render with ipydagred3
197
+ if graph_filename == "widget" :
198
+ return show_graph_widget (graph_func , * args , ** kwargs )
199
+
200
+ # render with pillow
201
+ return show_graph_pil (graph_func , * args , ** kwargs )
82
202
83
- image = Image .open (buffer )
84
- image .show ()
203
+ # TODO we can show graphviz in jupyter without a filename, but preserving existing behavior for now
204
+ return show_graph_graphviz (
205
+ graph_func , * args , graph_filename = graph_filename , interactive = _HAVE_INTERACTIVE , ** kwargs
206
+ )
0 commit comments