diff --git a/lib/astunparse/unparser.py b/lib/astunparse/unparser.py index 0ef6fd8..1bd5eef 100644 --- a/lib/astunparse/unparser.py +++ b/lib/astunparse/unparser.py @@ -11,6 +11,12 @@ # We unparse those infinities to INFSTR. INFSTR = "1e" + repr(sys.float_info.max_10_exp + 1) +def merge_dicts(x, *args): + z = x.copy() + for d in args: + z.update(d) + return z + def interleave(inter, f, seq): """Call f on each item in seq, calling inter() in between. """ @@ -56,14 +62,18 @@ def leave(self): "Decrease the indentation level." self._indent -= 1 - def dispatch(self, tree): + def dispatch(self, tree, **kw): "Dispatcher function, dispatching tree type T to method _T." if isinstance(tree, list): for t in tree: self.dispatch(t) return - meth = getattr(self, "_"+tree.__class__.__name__) - meth(tree) + cname = getattr(tree, '_class', tree.__class__.__name__) + meth = getattr(self, "_"+cname, None) + if meth: + meth(tree, **(kw if meth.__name__ in ["_Tuple"] else {})) + else: + self.write('"<' + cname + '>"') ############### Unparsing methods ###################### @@ -122,7 +132,7 @@ def _Assign(self, t): def _AugAssign(self, t): self.fill() self.dispatch(t.target) - self.write(" "+self.binop[t.op.__class__.__name__]+"= ") + self.write(" "+self.getop(t.op)+"= ") self.dispatch(t.value) def _AnnAssign(self, t): @@ -275,7 +285,7 @@ def _TryExcept(self, t): self.leave() def _TryFinally(self, t): - if len(t.body) == 1 and isinstance(t.body[0], ast.TryExcept): + if len(t.body) == 1 and (isinstance(t.body[0], ast.TryExcept) or getattr(t.body[0], "_class", "") == "TryExcept"): # try-except-finally self.dispatch(t.body) else: @@ -393,7 +403,8 @@ def _If(self, t): self.leave() # collapse nested ifs into equivalent elifs. while (t.orelse and len(t.orelse) == 1 and - isinstance(t.orelse[0], ast.If)): + (isinstance(t.orelse[0], ast.If) or + getattr(t.orelse[0], '_class', '') == 'If')): t = t.orelse[0] self.fill("elif ") self.dispatch(t.test) @@ -482,13 +493,17 @@ def _FormattedValue(self, t): # FormattedValue(expr value, int? conversion, expr? format_spec) self.write("f") string = StringIO() - self._fstring_JoinedStr(t, string.write) + self._fstring_JoinedStr1(t, string.write) self.write(repr(string.getvalue())) + def _fstring_JoinedStr1(self, value, write): + cname = getattr(value, '_class', type(value).__name__) + meth = getattr(self, "_fstring_" + cname) + meth(value, write) + def _fstring_JoinedStr(self, t, write): for value in t.values: - meth = getattr(self, "_fstring_" + type(value).__name__) - meth(value, write) + self._fstring_JoinedStr1(value, write) def _fstring_Str(self, t, write): value = t.s.replace("{", "{{").replace("}", "}}") @@ -513,7 +528,8 @@ def _fstring_FormattedValue(self, t, write): write("!{conversion}".format(conversion=conversion)) if t.format_spec: write(":") - meth = getattr(self, "_fstring_" + type(t.format_spec).__name__) + cname = getattr(t.format_spec, '_class', type(t.format_spec).__name__) + meth = getattr(self, "_fstring_" + cname) meth(t.format_spec, write) write("}") @@ -648,22 +664,23 @@ def write_item(item): interleave(lambda: self.write(", "), write_item, zip(t.keys, t.values)) self.write("}") - def _Tuple(self, t): - self.write("(") + def _Tuple(self, t, noparen=False, **kw): + if not noparen: self.write("(") if len(t.elts) == 1: elt = t.elts[0] self.dispatch(elt) self.write(",") else: interleave(lambda: self.write(", "), self.dispatch, t.elts) - self.write(")") + if not noparen: self.write(")") unop = {"Invert":"~", "Not": "not", "UAdd":"+", "USub":"-"} def _UnaryOp(self, t): self.write("(") - self.write(self.unop[t.op.__class__.__name__]) + self.write(self.getop(t.op)) self.write(" ") - if six.PY2 and isinstance(t.op, ast.USub) and isinstance(t.operand, ast.Num): + if (six.PY2 and (isinstance(t.op, ast.USub) or t.op == '-') + and (isinstance(t.operand, ast.Num) or getattr(t.operand, '_class', '') == 'Num')): # If we're applying unary minus to a number, parenthesize the number. # This is necessary: -2147483648 is different from -(2147483648) on # a 32-bit machine (the first is an int, the second a long), and @@ -682,7 +699,7 @@ def _UnaryOp(self, t): def _BinOp(self, t): self.write("(") self.dispatch(t.left) - self.write(" " + self.binop[t.op.__class__.__name__] + " ") + self.write(" " + self.getop(t.op) + " ") self.dispatch(t.right) self.write(")") @@ -692,23 +709,36 @@ def _Compare(self, t): self.write("(") self.dispatch(t.left) for o, e in zip(t.ops, t.comparators): - self.write(" " + self.cmpops[o.__class__.__name__] + " ") + self.write(" " + self.getop(o) + " ") self.dispatch(e) self.write(")") - boolops = {ast.And: 'and', ast.Or: 'or'} + boolops = {'ast.And': 'and', 'ast.Or': 'or', 'And': 'and', 'Or': 'or'} def _BoolOp(self, t): self.write("(") - s = " %s " % self.boolops[t.op.__class__] + s = " %s " % self.getop(t.op) interleave(lambda: self.write(s), self.dispatch, t.values) self.write(")") + allops = merge_dicts(binop, unop, cmpops, boolops) + @classmethod + def getop(self, op): + opcode = '' + cname = op.__class__.__name__ + if cname == "str" or cname == "unicode": + opcode = op + else: + opcode = self.allops[cname] + return opcode + def _Attribute(self,t): self.dispatch(t.value) # Special case: 3.__abs__() is a syntax error, so if t.value # is an integer literal then we need to either parenthesize # it or add an extra space to get 3 .__abs__(). - if isinstance(t.value, getattr(ast, 'Constant', getattr(ast, 'Num', None))) and isinstance(t.value.n, int): + if ((isinstance(t.value, getattr(ast, 'Constant', getattr(ast, 'Num', None))) + or getattr(t.value, '_class', '') in ['Constant', 'Num']) + and isinstance(getattr(t.value, 'n', t.value), int)): self.write(" ") self.write(".") self.write(t.attr) @@ -741,7 +771,7 @@ def _Call(self, t): def _Subscript(self, t): self.dispatch(t.value) self.write("[") - self.dispatch(t.slice) + self.dispatch(t.slice, noparen=True) self.write("]") def _Starred(self, t): diff --git a/tests/common.py b/tests/common.py index 95b9755..0c14b13 100644 --- a/tests/common.py +++ b/tests/common.py @@ -293,6 +293,18 @@ def test_raise_from(self): def test_bytes(self): self.check_roundtrip("b'123'") + def test_index(self): + self.check_roundtrip("r[0]") + self.check_roundtrip("r[0:5]") + self.check_roundtrip("r[0:5:2]") + self.check_roundtrip("r[1::2]") + + def test_index2(self): + self.check_roundtrip("r[0,i]") + self.check_roundtrip("r[:,0:5]") + self.check_roundtrip("r[i:1:-1, 2]") + self.check_roundtrip("r[i:1:-1, :]") + @unittest.skipIf(sys.version_info < (3, 6), "Not supported < 3.6") def test_formatted_value(self): self.check_roundtrip('f"{value}"')