Skip to content
72 changes: 51 additions & 21 deletions lib/astunparse/unparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
# We unparse those infinities to INFSTR.
INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1)

def merge_dicts(x, *args):
z = x.copy()
for d in args:
z.update(d)
return z

def interleave(inter, f, seq):
"""Call f on each item in seq, calling inter() in between.
"""
Expand Down Expand Up @@ -56,14 +62,18 @@ def leave(self):
"Decrease the indentation level."
self._indent -= 1

def dispatch(self, tree):
def dispatch(self, tree, **kw):
"Dispatcher function, dispatching tree type T to method _T."
if isinstance(tree, list):
for t in tree:
self.dispatch(t)
return
meth = getattr(self, "_"+tree.__class__.__name__)
meth(tree)
cname = getattr(tree, '_class', tree.__class__.__name__)
meth = getattr(self, "_"+cname, None)
if meth:
meth(tree, **(kw if meth.__name__ in ["_Tuple"] else {}))
else:
self.write('"<' + cname + '>"')


############### Unparsing methods ######################
Expand Down Expand Up @@ -122,7 +132,7 @@ def _Assign(self, t):
def _AugAssign(self, t):
self.fill()
self.dispatch(t.target)
self.write(" "+self.binop[t.op.__class__.__name__]+"= ")
self.write(" "+self.getop(t.op)+"= ")
self.dispatch(t.value)

def _AnnAssign(self, t):
Expand Down Expand Up @@ -275,7 +285,7 @@ def _TryExcept(self, t):
self.leave()

def _TryFinally(self, t):
if len(t.body) == 1 and isinstance(t.body[0], ast.TryExcept):
if len(t.body) == 1 and (isinstance(t.body[0], ast.TryExcept) or getattr(t.body[0], "_class", "") == "TryExcept"):
# try-except-finally
self.dispatch(t.body)
else:
Expand Down Expand Up @@ -393,7 +403,8 @@ def _If(self, t):
self.leave()
# collapse nested ifs into equivalent elifs.
while (t.orelse and len(t.orelse) == 1 and
isinstance(t.orelse[0], ast.If)):
(isinstance(t.orelse[0], ast.If) or
getattr(t.orelse[0], '_class', '') == 'If')):
t = t.orelse[0]
self.fill("elif ")
self.dispatch(t.test)
Expand Down Expand Up @@ -482,13 +493,17 @@ def _FormattedValue(self, t):
# FormattedValue(expr value, int? conversion, expr? format_spec)
self.write("f")
string = StringIO()
self._fstring_JoinedStr(t, string.write)
self._fstring_JoinedStr1(t, string.write)
self.write(repr(string.getvalue()))

def _fstring_JoinedStr1(self, value, write):
cname = getattr(value, '_class', type(value).__name__)
meth = getattr(self, "_fstring_" + cname)
meth(value, write)

def _fstring_JoinedStr(self, t, write):
for value in t.values:
meth = getattr(self, "_fstring_" + type(value).__name__)
meth(value, write)
self._fstring_JoinedStr1(value, write)

def _fstring_Str(self, t, write):
value = t.s.replace("{", "{{").replace("}", "}}")
Expand All @@ -513,7 +528,8 @@ def _fstring_FormattedValue(self, t, write):
write("!{conversion}".format(conversion=conversion))
if t.format_spec:
write(":")
meth = getattr(self, "_fstring_" + type(t.format_spec).__name__)
cname = getattr(t.format_spec, '_class', type(t.format_spec).__name__)
meth = getattr(self, "_fstring_" + cname)
meth(t.format_spec, write)
write("}")

Expand Down Expand Up @@ -648,22 +664,23 @@ def write_item(item):
interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values))
self.write("}")

def _Tuple(self, t):
self.write("(")
def _Tuple(self, t, noparen=False, **kw):
if not noparen: self.write("(")
if len(t.elts) == 1:
elt = t.elts[0]
self.dispatch(elt)
self.write(",")
else:
interleave(lambda: self.write(", "), self.dispatch, t.elts)
self.write(")")
if not noparen: self.write(")")

unop = {"Invert":"~", "Not": "not", "UAdd":"+", "USub":"-"}
def _UnaryOp(self, t):
self.write("(")
self.write(self.unop[t.op.__class__.__name__])
self.write(self.getop(t.op))
self.write(" ")
if six.PY2 and isinstance(t.op, ast.USub) and isinstance(t.operand, ast.Num):
if (six.PY2 and (isinstance(t.op, ast.USub) or t.op == '-')
and (isinstance(t.operand, ast.Num) or getattr(t.operand, '_class', '') == 'Num')):
# If we're applying unary minus to a number, parenthesize the number.
# This is necessary: -2147483648 is different from -(2147483648) on
# a 32-bit machine (the first is an int, the second a long), and
Expand All @@ -682,7 +699,7 @@ def _UnaryOp(self, t):
def _BinOp(self, t):
self.write("(")
self.dispatch(t.left)
self.write(" " + self.binop[t.op.__class__.__name__] + " ")
self.write(" " + self.getop(t.op) + " ")
self.dispatch(t.right)
self.write(")")

Expand All @@ -692,23 +709,36 @@ def _Compare(self, t):
self.write("(")
self.dispatch(t.left)
for o, e in zip(t.ops, t.comparators):
self.write(" " + self.cmpops[o.__class__.__name__] + " ")
self.write(" " + self.getop(o) + " ")
self.dispatch(e)
self.write(")")

boolops = {ast.And: 'and', ast.Or: 'or'}
boolops = {'ast.And': 'and', 'ast.Or': 'or', 'And': 'and', 'Or': 'or'}
def _BoolOp(self, t):
self.write("(")
s = " %s " % self.boolops[t.op.__class__]
s = " %s " % self.getop(t.op)
interleave(lambda: self.write(s), self.dispatch, t.values)
self.write(")")

allops = merge_dicts(binop, unop, cmpops, boolops)
@classmethod
def getop(self, op):
opcode = ''
cname = op.__class__.__name__
if cname == "str" or cname == "unicode":
opcode = op
else:
opcode = self.allops[cname]
return opcode

def _Attribute(self,t):
self.dispatch(t.value)
# Special case: 3.__abs__() is a syntax error, so if t.value
# is an integer literal then we need to either parenthesize
# it or add an extra space to get 3 .__abs__().
if isinstance(t.value, getattr(ast, 'Constant', getattr(ast, 'Num', None))) and isinstance(t.value.n, int):
if ((isinstance(t.value, getattr(ast, 'Constant', getattr(ast, 'Num', None)))
or getattr(t.value, '_class', '') in ['Constant', 'Num'])
and isinstance(getattr(t.value, 'n', t.value), int)):
self.write(" ")
self.write(".")
self.write(t.attr)
Expand Down Expand Up @@ -741,7 +771,7 @@ def _Call(self, t):
def _Subscript(self, t):
self.dispatch(t.value)
self.write("[")
self.dispatch(t.slice)
self.dispatch(t.slice, noparen=True)
self.write("]")

def _Starred(self, t):
Expand Down
12 changes: 12 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,18 @@ def test_raise_from(self):
def test_bytes(self):
self.check_roundtrip("b'123'")

def test_index(self):
self.check_roundtrip("r[0]")
self.check_roundtrip("r[0:5]")
self.check_roundtrip("r[0:5:2]")
self.check_roundtrip("r[1::2]")

def test_index2(self):
self.check_roundtrip("r[0,i]")
self.check_roundtrip("r[:,0:5]")
self.check_roundtrip("r[i:1:-1, 2]")
self.check_roundtrip("r[i:1:-1, :]")

@unittest.skipIf(sys.version_info < (3, 6), "Not supported < 3.6")
def test_formatted_value(self):
self.check_roundtrip('f"{value}"')
Expand Down