-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathastutils.py
242 lines (216 loc) · 8.31 KB
/
astutils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
# python AST extensions
# TODO: rely on _ast, not ast
from ast import *
from lispy.runtime import gensym
#-----------------------------------------------------------------------
# Do nodes are a way of wrapping statements in an expression.
# the last child node of a Do node must be an expression.
# the others must be statements.
class Do(expr):
    """Expression node that wraps a sequence of statements.

    ``body`` is a list whose last element is an expression (the value of
    the whole Do node) and whose other elements are statements.  Do nodes
    are meant to be spliced out of the tree by unroll_ast.
    """
    # the only child field of a Do node.
    _fields = ("body",)
    # no position attributes (lineno/col_offset) are declared for this node.
    # NOTE(review): presumably fine because Do nodes are unrolled before the
    # tree is compiled -- confirm if a Do ever reaches compile().
    _attributes = ()

    def __init__(self, body=None):
        self.body = body
# a name that doesn't appear in the original code, and so may be
# discarded if unused. note that names that come from the original code
# must not be discarded, since evaluating them can have side effects
# (e.g. NameError may be raised).
def DiscardableName(id, ctx):
    """Build a Name node marked with ``discardable = True``.

    Intended for names introduced by the compiler rather than taken from
    the original code, so that a later pass may drop them when unused.
    """
    node = Name(id=id, ctx=ctx)
    node.discardable = True
    return node
#-----------------------------------------------------------------------
# Maps an AST node class name to the names of its fields that hold lists
# of statements. This may be incomplete, but it should cover the nodes
# generated by the compiler. Both Python 2 names (Exec, Suite, TryExcept,
# TryFinally) and Python 3 names (Try, TryStar, AsyncFor, AsyncFunctionDef,
# AsyncWith, match_case) are listed so the table works with either
# version's ast module; names absent from the running version never match.
# NOTE: in the Python 2 grammar Exec.body is a single expression, not a
# list, so that entry is inert (the visitors only consult this table for
# list-valued fields); it is kept so the existing keys are unchanged.
stmt_fields = {
    "AsyncFor": ("body", "orelse"),
    "AsyncFunctionDef": ("body",),
    "AsyncWith": ("body",),
    "ClassDef": ("body",),
    "ExceptHandler": ("body",),
    "Exec": ("body",),
    "For": ("body", "orelse"),
    "FunctionDef": ("body",),
    "If": ("body", "orelse"),
    "Interactive": ("body",),
    "Module": ("body",),
    "Suite": ("body",),
    "Try": ("body", "orelse", "finalbody"),
    "TryExcept": ("body", "orelse"),
    "TryFinally": ("body", "finalbody"),
    "TryStar": ("body", "orelse", "finalbody"),
    "While": ("body", "orelse"),
    "With": ("body",),
    "match_case": ("body",),
}


def is_stmt_list(node, field):
    """Return True if field ``field`` of ``node`` holds a list of statements.

    Looks up the class name of ``node`` in ``stmt_fields``; node types not
    listed there have no statement-list fields.
    """
    return field in stmt_fields.get(node.__class__.__name__, ())
# node categories treated as expression-like: ``expr`` itself plus the
# boolean, binary and unary operator node types. comparison operators
# (cmpop) and expression contexts are not included.
expr_types = (boolop, expr, operator, unaryop)


def is_expr_node(node):
    """Return True if node is expression-like and not in a non-Load context.

    Nodes that carry a ``ctx`` attribute qualify only when it is a Load
    context, which excludes assignment and deletion targets.
    """
    loads_value = not hasattr(node, "ctx") or isinstance(node.ctx, Load)
    return loads_value and isinstance(node, expr_types)
class ExtendedNodeVisitor:
    """Node visitor with a richer dispatch than ast.NodeVisitor.

    For each node, visit() calls the first handler that exists among:
      1. ``visit_<ClassName>``, e.g. ``visit_Name``;
      2. ``visit_<BaseClassName>`` for the immediate base class only,
         e.g. ``visit_expr``;
      3. ``visit_expression``, if is_expr_node(node) is true;
      4. generic_visit, which recurses into the children.
    In generic_visit, a list field that is_stmt_list reports as a statement
    list is handed whole to ``visit_stmt_list`` when the subclass defines
    it; all other AST children are visited one by one.
    """

    def visit(self, node):
        """Dispatch node to the most specific handler and return its result."""
        node_class = node.__class__
        # steps 1 and 2: a handler named after the class, then after its
        # immediate base class (still not sure step 2 is needed).
        for klass in (node_class, node_class.__base__):
            handler_name = 'visit_' + klass.__name__
            if hasattr(self, handler_name):
                return getattr(self, handler_name)(node)
        # step 3: expression-like nodes in Load context.
        if is_expr_node(node) and hasattr(self, "visit_expression"):
            return self.visit_expression(node)
        # step 4: no specific handler, visit the children.
        return self.generic_visit(node)

    def generic_visit(self, node):
        """Visit the children of node; called if no specific handler exists."""
        for field, value in iter_fields(node):
            if isinstance(value, AST):
                self.visit(value)
            elif isinstance(value, list):
                if is_stmt_list(node, field) and hasattr(self, "visit_stmt_list"):
                    # the whole statement list goes to one handler.
                    self.visit_stmt_list(value)
                    continue
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)
class ExtendedNodeTransformer(ExtendedNodeVisitor):
    """ExtendedNodeVisitor that rewrites the tree in place.

    Handlers return the replacement for the node they were given:
      * an AST node replaces it;
      * None removes it (a single-node field is deleted with delattr);
      * inside a list field, any other iterable is spliced in its place.
    Statement-list fields are passed to ``visit_stmt_list`` when the
    subclass defines it; its result (None meaning "no statements") becomes
    the new list contents. Lists are always updated via slice assignment,
    so other references to the same list object see the change.
    """

    def generic_visit(self, node):
        """Transform every child of node and return node itself."""
        for field, old_value in iter_fields(node):
            if isinstance(old_value, list):
                if is_stmt_list(node, field) and hasattr(self, "visit_stmt_list"):
                    new_stmts = self.visit_stmt_list(old_value)
                    old_value[:] = [] if new_stmts is None else new_stmts
                    continue
                new_values = []
                for value in old_value:
                    if isinstance(value, AST):
                        value = self.visit(value)
                        if value is None:
                            # handler asked for the node to be removed.
                            continue
                        if not isinstance(value, AST):
                            # handler returned several nodes: splice them.
                            new_values.extend(value)
                            continue
                    # non-AST items (e.g. identifier strings) pass through.
                    new_values.append(value)
                old_value[:] = new_values
            elif isinstance(old_value, AST):
                new_node = self.visit(old_value)
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node
#-----------------------------------------------------------------------
class DoNodeCounter(ExtendedNodeVisitor):
    """Count the Do nodes of a statement, including Do nodes nested inside
    other Do nodes, but without descending into nested statement lists."""

    def __init__(self):
        self.count = 0

    def visit_stmt_list(self, nodes):
        # deliberately empty: statement lists are not searched.
        return None

    def visit_Do(self, node):
        self.count = self.count + 1
        # a Do body may itself contain Do nodes; keep counting inside it.
        self.generic_visit(node)
def count_do_nodes(ast):
    """Return the number of Do nodes in ``ast``, not counting those inside
    nested statement lists."""
    visitor = DoNodeCounter()
    visitor.visit(ast)
    total = visitor.count
    return total
class StmtNodeUnroller(ExtendedNodeTransformer):
    """Unroll a single statement containing Do nodes.

    The statements of every Do body are collected in ``self.nodes`` (to be
    injected into the enclosing statement list before the statement), and
    each Do node is replaced by its tail expression. To preserve evaluation
    order, expressions visited while Do nodes are still pending are
    evaluated in advance into gensym-named temporaries whose assignments are
    also appended to ``self.nodes``.

    ``count`` is the number of Do nodes left to unroll in the statement; it
    is decremented by visit_Do, and once it reaches 0 nothing is hoisted.
    """

    def __init__(self, count):
        self.count = count
        self.nodes = []

    def visit_stmt_list(self, nodes):
        """Unroll a nested statement list on its own and return it."""
        if not nodes:
            return nodes
        # temporarily wrap the statement list in a module node in order
        # to unroll potential Do nodes inside it.
        tmp = Module(nodes)
        unroll_ast(tmp)
        return tmp.body

    def visit_Do(self, do_node):
        """Emit the Do body's statements and return its (visited) tail."""
        self.count -= 1
        assert do_node.body, "Do with empty body"
        stmts, tail = do_node.body[:-1], do_node.body[-1]
        for node in stmts:
            self.visit(node)
            self.nodes.append(node)
        assert is_expr_node(tail), "Do tail must be an expression"
        # temporarily wrap the tail in an Expr so it could be replaced
        # with a name by visit_expression if necessary.
        tmp = Expr(tail)
        self.visit(tmp)
        return tmp.value

    # these are optimizations that avoid the visit_expression treatment:
    # literals have no side effects, so they need not be evaluated early.
    def visit_Num(self, node):
        return node

    def visit_Str(self, node):
        return node

    def visit_Constant(self, node):
        # Python 3.8+ represents literals as Constant instead of Num/Str.
        return node

    # for container displays only the elements may need hoisting.
    def visit_Tuple(self, node):
        self.generic_visit(node)
        return node

    def visit_List(self, node):
        self.generic_visit(node)
        return node

    def visit_Dict(self, node):
        self.generic_visit(node)
        return node

    def visit_expression(self, node):
        """Evaluate node in advance to preserve evaluation order.

        While Do nodes are pending, node is assigned to a gensym temporary
        (the assignment is appended to self.nodes) and a Name loading that
        temporary is returned in its place.
        """
        if self.count == 0:
            return node
        if not isinstance(node, expr):
            # is_expr_node also accepts operator nodes (boolop, operator,
            # unaryop), e.g. the ``op`` of a BinOp. they are not values:
            # hoisting one would emit ``tmp = Add()`` and leave a Name in
            # the ``op`` field, which cannot be compiled. keep them as is.
            return node
        self.generic_visit(node)
        tmp_name = str(gensym())
        self.nodes.append(
            Assign([Name(tmp_name, Store())],
                   node))
        return Name(tmp_name, Load())
def unroll_stmt_node(node):
    """Unroll the Do nodes of one statement, for splicing. Used by Unroller.

    Returns a list of statements: those produced by unrolling the
    statement's Do nodes, followed by the statement itself.
    """
    pending = count_do_nodes(node)
    if not pending:
        # no directly-unrollable Do nodes in the statement, but statement
        # lists nested inside it (e.g. a loop or function body) may still
        # contain some. StmtNodeUnroller.visit_stmt_list takes care of that
        # in the other case; here it has to be done explicitly.
        return [unroll_ast(node)]
    unroller = StmtNodeUnroller(pending)
    unroller.visit(node)
    return unroller.nodes + [node]
class Unroller(ExtendedNodeTransformer):
    """Unroll Do nodes into their enclosing statement lists.

    Every statement of every statement list is replaced by the statements
    returned by unroll_stmt_node. A Do node reached by this visitor outside
    of a statement list cannot be unrolled and raises AssertionError.
    """

    def visit_Do(self, ast):
        # call form works on Python 2 and 3; ``raise E, "msg"`` is Py2-only.
        raise AssertionError("Do node outside of statement list")

    def visit_stmt_list(self, old_nodes):
        nodes = []
        for node in old_nodes:
            nodes.extend(unroll_stmt_node(node))
        return nodes
def unroll_ast(ast):
    """Unroll all Do nodes in ``ast`` into their enclosing statement lists.

    The tree is modified in place and the same object is returned.
    """
    Unroller().visit(ast)
    return ast
#-----------------------------------------------------------------------
class _YieldFound(StopIteration):
    """Signal raised by YieldFinder when it reaches a yield.

    Subclasses StopIteration so callers that catch StopIteration from
    YieldFinder.visit keep working, while has_yield can catch exactly this
    signal instead of any StopIteration.
    """


class YieldFinder(ExtendedNodeVisitor):
    """Stop with _YieldFound at the first yield outside nested functions.

    Nested function scopes (def, async def, lambda) are not searched: a
    yield inside them makes the nested function a generator, not the
    enclosing one.
    """

    def visit_FunctionDef(self, node):
        # stop recursion, since we are only interested in yields which
        # appear directly in a function body, not in nested functions.
        pass

    # async def and lambda also open a new function scope.
    visit_AsyncFunctionDef = visit_FunctionDef
    visit_Lambda = visit_FunctionDef

    def visit_Yield(self, node):
        raise _YieldFound()

    # ``yield from`` (Python 3.3+) also makes a function a generator.
    visit_YieldFrom = visit_Yield


def has_yield(ast):
    """Return True if ``ast`` contains a yield outside nested functions."""
    finder = YieldFinder()
    try:
        finder.visit(ast)
    except _YieldFound:
        return True
    return False