diff --git a/src/mccode_antlr/common/expression.py b/src/mccode_antlr/common/expression.py index 01b804b..bede08f 100644 --- a/src/mccode_antlr/common/expression.py +++ b/src/mccode_antlr/common/expression.py @@ -500,7 +500,7 @@ def __post_init__(self): def as_type(self, pdt): value = [x.as_type(pdt) for x in self.value] - return UnaryOp(self.op, value) + return UnaryOp(data_type=self.data_type, style=self.style, op=self.op, value=value) def _str_repr_(self, vstr): c_style = self.style == OpStyle.C diff --git a/tests/test_expression.py b/tests/test_expression.py index d2aadea..9d81648 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -267,6 +267,15 @@ def test_UnaryOp(self): not_val.style = OpStyle.PYTHON self.assertEqual(str(not_val), 'not val') + def test_UnaryOp_as_type(self): + uop = UnaryOp(DataType.undefined, OpStyle.C, '__not__', Value.id('val')) + self.assertEqual(str(uop), '!val') + iop = uop.as_type(DataType.int) + self.assertEqual(str(iop), '!val') + fop = uop.as_type(DataType.float) + self.assertEqual(str(fop), '!val') + + def test_numeric_BinaryOp(self): f = [Value.float(x) for x in range(3)] i = [Value.float(x) for x in range(3)]