Skip to content

Commit abf6db9

Browse files
committed
Use a symbolic function
1 parent 495dfc5 commit abf6db9

File tree

3 files changed

+31
-43
lines changed

3 files changed

+31
-43
lines changed

sumpy/codegen.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -592,39 +592,6 @@ def map_sum(self, expr, *args):
592592
# }}}
593593

594594

595-
# {{{ helmholtz rewrite
596-
class HelmholtzRewriter(CSECachingIdentityMapper, CallExternalRecMapper):
597-
def __init__(self, k, ik):
598-
self.k = k
599-
self.ik = ik
600-
601-
def map_variable(self, expr, *args):
602-
if expr.name == self.ik.name:
603-
return 1j*self.k
604-
else:
605-
return expr
606-
607-
def map_call(self, expr, *args):
608-
if isinstance(expr.function, prim.Variable) \
609-
and expr.function.name == "exp":
610-
params = expr.parameters
611-
assert len(params) == 1
612-
param = self.rec(params[0])
613-
if isinstance(param, prim.Product) and 1j in param.children:
614-
children = list(param.children)
615-
del children[children.index(1j)]
616-
params = (prim.Product(tuple(children)),)
617-
return prim.Call(prim.Variable("cos"), params) + \
618-
1j * prim.Call(prim.Variable("sin"), params)
619-
620-
return super().map_call(expr, *args)
621-
622-
map_common_subexpression_uncached = IdentityMapper.map_common_subexpression
623-
624-
625-
# }}}
626-
627-
628595
class MathConstantRewriter(CSECachingIdentityMapper, CallExternalRecMapper):
629596
def map_variable(self, expr, *args):
630597
if expr.name == "pi":

sumpy/kernel.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,12 @@ def __init__(self, dim, helmholtz_k_name="k",
522522
if allow_evanescent:
523523
expr = var("exp")(var("I")*k*r)/r
524524
else:
525-
expr = var("exp")(var("Ik")*r)/r
525+
# expi is a function that takes in a real and returns a
526+
# complex number such that
527+
# expi(x) = exp(I * x)
528+
# Retaining the information that the input is real leads
529+
# to better code generation
530+
expr = var("expi")(k*r)/r
526531
scaling = 1/(4*var("pi"))
527532
else:
528533
raise RuntimeError("unsupported dimensionality")
@@ -579,15 +584,6 @@ def get_pde_as_diff_op(self):
579584
k = sym.Symbol(self.helmholtz_k_name)
580585
return (laplacian(w) + k**2 * w)
581586

582-
def get_code_transformer(self):
583-
k = SpatialConstant(self.helmholtz_k_name)
584-
585-
if self.allow_evanescent:
586-
return lambda expr: expr
587-
else:
588-
from sumpy.codegen import HelmholtzRewriter
589-
return HelmholtzRewriter(k, var("Ik"))
590-
591587

592588
class YukawaKernel(ExpressionKernel):
593589
init_arg_names = ("dim", "yukawa_lambda_name")

sumpy/symbolic.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,13 @@ def map_Mul(self, expr): # noqa: N802
313313

314314
return math.prod(num_args) / math.prod(den_args)
315315

316+
def map_FunctionSymbol(self, expr):
317+
if expr.get_name() == "ExpI":
318+
arg = self.rec(expr.args[0])
319+
return prim.Variable("cos")(arg) + 1j * prim.Variable("sin")(arg)
320+
else:
321+
return SympyToPymbolicMapperBase.map_FunctionSymbol(self, expr)
322+
316323

317324
class PymbolicToSympyMapperWithSymbols(PymbolicToSympyMapper):
318325
def map_variable(self, expr):
@@ -338,6 +345,9 @@ def map_call(self, expr):
338345
args = [self.rec(param) for param in expr.parameters]
339346
args.append(0)
340347
return BesselJ(*args)
348+
elif expr.function.name == "expi":
349+
args = [self.rec(param) for param in expr.parameters]
350+
return ExpI(*args)
341351
else:
342352
return PymbolicToSympyMapper.map_call(self, expr)
343353

@@ -369,8 +379,20 @@ class Hankel1(_BesselOrHankel):
369379
pass
370380

371381

382+
class ExpI(sympy.Function):
383+
"""A symbolic function that takes a real value as an
384+
input and returns a complex number such that
385+
expi(x) = exp(i*x).
386+
"""
387+
nargs = (1,)
388+
389+
def fdiff(self, argindex=1):
390+
return self.func(self.args[0]) * sympy.I
391+
392+
372393
_SympyBesselJ = BesselJ
373394
_SympyHankel1 = Hankel1
395+
_SympyExpI = ExpI
374396

375397
if USE_SYMENGINE:
376398
def BesselJ(*args): # noqa: N802 # pylint: disable=function-redefined
@@ -379,4 +401,7 @@ def BesselJ(*args): # noqa: N802 # pylint: disable=function-redefined
379401
def Hankel1(*args): # noqa: N802 # pylint: disable=function-redefined
380402
return sym.sympify(_SympyHankel1(*args))
381403

404+
def ExpI(*args): # noqa: N802 # pylint: disable=function-redefined
405+
return sym.sympify(_SympyExpI(*args))
406+
382407
# vim: fdm=marker

0 commit comments

Comments
 (0)