Skip to content

Commit dd974f3

Browse files
committed
Enable sympy-symengine round-trip for Dummy
1 parent aceeb64 commit dd974f3

File tree

4 files changed

+49
-10
lines changed

4 files changed

+49
-10
lines changed

symengine/lib/symengine.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@ cdef extern from "<symengine/basic.h>" namespace "SymEngine":
323323
rcp_const_basic make_rcp_Symbol "SymEngine::make_rcp<const SymEngine::Symbol>"(string name) nogil
324324
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"() nogil
325325
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"(string name) nogil
326+
rcp_const_basic make_rcp_Dummy "SymEngine::make_rcp<const SymEngine::Dummy>"(string name, size_t index) nogil
326327
rcp_const_basic make_rcp_PySymbol "SymEngine::make_rcp<const SymEngine::PySymbol>"(string name, PyObject * pyobj, bool use_pickle) except +
327328
rcp_const_basic make_rcp_Constant "SymEngine::make_rcp<const SymEngine::Constant>"(string name) nogil
328329
rcp_const_basic make_rcp_Infty "SymEngine::make_rcp<const SymEngine::Infty>"(RCP[const Number] i) nogil

symengine/lib/symengine_wrapper.in.pyx

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -278,10 +278,10 @@ def sympy2symengine(a, raise_error=False):
278278
"""
279279
import sympy
280280
from sympy.core.function import AppliedUndef as sympy_AppliedUndef
281-
if isinstance(a, sympy.Symbol):
281+
if isinstance(a, sympy.Dummy):
282+
return Dummy(a.name, a.dummy_index)
283+
elif isinstance(a, sympy.Symbol):
282284
return Symbol(a.name)
283-
elif isinstance(a, sympy.Dummy):
284-
return Dummy(a.name)
285285
elif isinstance(a, sympy.Mul):
286286
return mul(*[sympy2symengine(x, raise_error) for x in a.args])
287287
elif isinstance(a, sympy.Add):
@@ -1337,15 +1337,18 @@ cdef class Symbol(Expr):
13371337

13381338
cdef class Dummy(Symbol):
13391339

1340-
def __init__(Basic self, name=None, *args, **kwargs):
1341-
if name is None:
1342-
self.thisptr = symengine.make_rcp_Dummy()
1340+
def __init__(Basic self, name=None, dummy_index=None, *args, **kwargs):
1341+
if dummy_index is None:
1342+
if name is None:
1343+
self.thisptr = symengine.make_rcp_Dummy()
1344+
else:
1345+
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"))
13431346
else:
1344-
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"))
1347+
self.thisptr = symengine.make_rcp_Dummy(name.encode("utf-8"), dummy_index)
13451348

13461349
@property
13471350
def name(self):
1348-
return self.__str__()[1:]
1351+
return self.__str__()
13491352

13501353
def _sympy_(self):
13511354
import sympy

symengine/tests/test_pickling.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol
1+
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol, Dummy
22
from symengine.test_utilities import raises
33
import pickle
44
import unittest
@@ -57,3 +57,17 @@ def test_llvm_double():
5757
ll = pickle.loads(ss)
5858
inp = [1, 2, 3]
5959
assert np.allclose(l(inp), ll(inp))
60+
61+
62+
def _check_pickling_roundtrip(arg):
63+
s2 = pickle.dumps(arg)
64+
arg2 = pickle.loads(s2)
65+
assert arg == arg2
66+
s3 = pickle.dump(arg2)
67+
arg3 = pickle.loads(s3)
68+
assert arg == arg3
69+
70+
def test_pickling_roundtrip():
71+
x, y, z = symbols('x y z')
72+
_check_pickling_roundtrip(x+y)
73+
_check_pickling_roundtrip(Dummy('d'))

symengine/tests/test_sympy_conv.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from symengine import (Symbol, Integer, sympify, SympifyError, log,
22
function_symbol, I, E, pi, oo, zoo, nan, true, false,
3-
exp, gamma, have_mpfr, have_mpc, DenseMatrix, sin, cos, tan, cot,
3+
exp, gamma, have_mpfr, have_mpc, DenseMatrix, Dummy, sin, cos, tan, cot,
44
csc, sec, asin, acos, atan, acot, acsc, asec, sinh, cosh, tanh, coth,
55
asinh, acosh, atanh, acoth, atan2, Add, Mul, Pow, diff, GoldenRatio,
66
Catalan, EulerGamma, UnevaluatedExpr, RealDouble)
@@ -833,3 +833,24 @@ def test_conv_large_integers():
833833
if have_sympy:
834834
c = a._sympy_()
835835
d = sympify(c)
836+
837+
def _check_sympy_roundtrip(arg):
838+
arg_sy1 = sympy.sympify(arg)
839+
arg_se2 = sympify(arg_sy1)
840+
print(f"{type(arg)=} {type(arg_se2)=}")
841+
assert arg == arg_se2
842+
arg_sy2 = sympy.sympify(arg_se2)
843+
assert arg_sy2 == arg_sy1
844+
arg_se3 = sympify(arg_sy2)
845+
assert arg_se3 == arg
846+
847+
848+
@unittest.skipIf(not have_sympy, "SymPy not installed")
849+
def test_sympy_roundtrip():
850+
x = Symbol("x")
851+
y = Symbol("y")
852+
d = Dummy("d")
853+
_check_sympy_roundtrip(x)
854+
_check_sympy_roundtrip(x+y)
855+
_check_sympy_roundtrip(x**y)
856+
_check_sympy_roundtrip(d)

0 commit comments

Comments
 (0)