Skip to content

Commit 97d96b0

Browse files
authored
Add support for the matmul (@) operator. (#270)
1 parent 821ea97 commit 97d96b0

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

CHANGES.rst

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ Changes
77
- Allow to use the package with Python 3.13 -- Caution: No security
88
audit has been done so far.
99

10+
- Add support for the matmul (``@``) operator.
11+
1012

1113
7.0 (2023-11-17)
1214
----------------

src/RestrictedPython/transformer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -768,8 +768,8 @@ def visit_BitAnd(self, node):
768768
return self.node_contents_visit(node)
769769

770770
def visit_MatMult(self, node):
771-
"""Matrix multiplication (`@`) is currently not allowed."""
772-
self.not_allowed(node)
771+
"""Allow multiplication (`@`)."""
772+
return self.node_contents_visit(node)
773773

774774
def visit_BoolOp(self, node):
775775
"""Allow bool operator without restrictions."""

tests/transformer/operators/test_arithmetic_operators.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from RestrictedPython import compile_restricted_eval
21
from tests.helper import restricted_eval
32

43

@@ -33,8 +32,12 @@ def test_FloorDiv():
3332

3433

3534
def test_MatMult():
36-
result = compile_restricted_eval('(8, 3, 5) @ (2, 7, 1)')
37-
assert result.errors == (
38-
'Line None: MatMult statements are not allowed.',
39-
)
40-
assert result.code is None
35+
class Vector:
36+
def __init__(self, values):
37+
self.values = values
38+
39+
def __matmul__(self, other):
40+
return sum(x * y for x, y in zip(self.values, other.values))
41+
42+
assert restricted_eval(
43+
'Vector((8, 3, 5)) @ Vector((2, 7, 1))', {'Vector': Vector}) == 42

0 commit comments

Comments
 (0)