Skip to content

Commit 62c20a8

Browse files
committed
feat: update to new pymbolic
1 parent 6a3ffc7 commit 62c20a8

File tree

5 files changed

+232
-164
lines changed

5 files changed

+232
-164
lines changed

pytential/symbolic/mappers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def rec_int_g_arguments(mapper, expr):
6161
name: mapper.rec(arg) for name, arg in expr.kernel_arguments.items()
6262
}
6363

64-
changed = (
64+
changed = not (
6565
all(d is orig for d, orig in zip(densities, expr.densities, strict=True))
6666
and all(
6767
arg is orig for arg, orig in zip(
@@ -280,7 +280,12 @@ def map_common_subexpression(self, expr):
280280
# {{{ FlattenMapper
281281

282282
class FlattenMapper(FlattenMapperBase, IdentityMapper):
283-
pass
283+
def map_int_g(self, expr):
284+
densities, kernel_arguments, changed = rec_int_g_arguments(self, expr)
285+
if not changed:
286+
return expr
287+
288+
return replace(expr, densities=densities, kernel_arguments=kernel_arguments)
284289

285290

286291
def flatten(expr):

pytential/symbolic/pde/system_utils.py

Lines changed: 65 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,21 @@
2121
"""
2222

2323
import logging
24-
import warnings
2524
from collections.abc import Mapping, Sequence
26-
from dataclasses import dataclass
27-
from typing import Any
25+
from dataclasses import dataclass, replace
26+
from typing import Any, cast
2827

2928
import numpy as np
30-
import pymbolic
31-
import sumpy.symbolic as sym
32-
from pytools import \
33-
generate_nonnegative_integer_tuples_summing_to_at_most as gnitstam
29+
import sumpy.symbolic as sp
30+
from pytools import generate_nonnegative_integer_tuples_summing_to_at_most as gnitstam
3431
from pytools import memoize_on_first_arg
3532
from sumpy.kernel import (AxisSourceDerivative, AxisTargetDerivative,
3633
DirectionalSourceDerivative, ExpressionKernel,
3734
Kernel, KernelWrapper, TargetPointMultiplier)
38-
from pymbolic.typing import ExpressionT, ArithmeticExpressionT
35+
from pymbolic.typing import ArithmeticExpressionT
3936

40-
import pytential
41-
from pytential.symbolic.mappers import IdentityMapper
42-
from pytential.symbolic.primitives import (DEFAULT_SOURCE, IntG,
43-
NodeCoordinateComponent,
44-
hashable_kernel_args)
37+
from pytential import sym
38+
from pytential.symbolic.mappers import IdentityMapper, flatten
4539
from pytential.utils import chop, solve_from_lu
4640

4741
logger = logging.getLogger(__name__)
@@ -72,8 +66,8 @@ class RewriteFailedError(RuntimeError):
7266

7367

7468
def rewrite_using_base_kernel(
75-
exprs: Sequence[ExpressionT],
76-
base_kernel: Kernel = _NO_ARG_SENTINEL) -> list[ExpressionT]:
69+
exprs: Sequence[ArithmeticExpressionT],
70+
base_kernel: Kernel = _NO_ARG_SENTINEL) -> list[ArithmeticExpressionT]:
7771
"""
7872
Rewrites a list of expressions with :class:`~pytential.symbolic.primitives.IntG`
7973
objects using *base_kernel*.
@@ -101,7 +95,7 @@ def rewrite_using_base_kernel(
10195
raise NotImplementedError
10296

10397
mapper = RewriteUsingBaseKernelMapper(base_kernel)
104-
return [mapper(expr) for expr in exprs]
98+
return [cast(ArithmeticExpressionT, mapper(expr)) for expr in exprs]
10599

106100

107101
class RewriteUsingBaseKernelMapper(IdentityMapper):
@@ -132,7 +126,7 @@ def map_int_g(self, expr):
132126

133127

134128
def _get_sympy_kernel_expression(expr: ArithmeticExpressionT,
135-
kernel_arguments: Mapping[str, Any]) -> sym.Basic:
129+
kernel_arguments: Mapping[str, Any]) -> sp.Basic:
136130
"""Convert a :mod:`pymbolic` expression to :mod:`sympy` expression
137131
after substituting kernel arguments.
138132
@@ -149,8 +143,8 @@ def _get_sympy_kernel_expression(expr: ArithmeticExpressionT,
149143

150144

151145
def _monom_to_expr(monom: Sequence[int],
152-
variables: Sequence[sym.Basic | ArithmeticExpressionT]
153-
) -> sym.Basic | ArithmeticExpressionT:
146+
variables: Sequence[sp.Basic | ArithmeticExpressionT]
147+
) -> sp.Basic | ArithmeticExpressionT:
154148
"""Convert a monomial to an expression using given variables.
155149
156150
For example, ``[3, 2, 1]`` with variables ``[x, y, z]`` is converted to
@@ -164,7 +158,7 @@ def _monom_to_expr(monom: Sequence[int],
164158
return prod
165159

166160

167-
def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
161+
def convert_target_transformation_to_source(int_g: sym.IntG) -> list[sym.IntG]:
168162
r"""Convert an ``IntG`` with :class:`~sumpy.kernel.AxisTargetDerivative`
169163
or :class:`~sumpy.kernel.TargetPointMultiplier` to a list
170164
of ``IntG``\ s without them and only source dependent transformations.
@@ -182,8 +176,9 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
182176

183177
knl = int_g.target_kernel
184178
if not knl.is_translation_invariant:
185-
warnings.warn(f"Translation variant kernel ({knl}) found.",
186-
stacklevel=2)
179+
from warnings import warn
180+
181+
warn(f"Translation variant kernel ({knl}) found.", stacklevel=2)
187182
return [int_g]
188183

189184
# we use a symbol for d = (x - y)
@@ -204,7 +199,9 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
204199
expr = expr.diff(ds[knl.axis])
205200
found = True
206201
else:
207-
warnings.warn(
202+
from warnings import warn
203+
204+
warn(
208205
f"Unknown target kernel ({knl}) found. "
209206
"Returning IntG expression unchanged.", stacklevel=2)
210207
return [int_g]
@@ -213,9 +210,9 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
213210
if not found:
214211
return [int_g]
215212

216-
int_g = int_g.copy(target_kernel=knl)
213+
int_g = replace(int_g, target_kernel=knl)
217214

218-
sources_pymbolic = [NodeCoordinateComponent(i) for i in range(knl.dim)]
215+
sources_pymbolic = sym.nodes(knl.dim).as_vector()
219216
expr = expr.expand()
220217
# Now the expr is an Add and looks like
221218
# u_{d[0], d[1]}(d, y)*d[0]*y[1] + u(d, y) * d[1]
@@ -255,7 +252,7 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
255252
for _ in range(nrepeats):
256253
knl = AxisSourceDerivative(axis, knl)
257254
new_source_kernels.append(knl)
258-
new_int_g = int_g.copy(source_kernels=new_source_kernels)
255+
new_int_g = replace(int_g, source_kernels=tuple(new_source_kernels))
259256

260257
(monom, coeff,) = remaining_factors.terms()[0]
261258
# Now from d[0]*y[1], we separate the two terms
@@ -266,66 +263,73 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
266263
* conv(coeff)
267264
# since d/d(d) = - d/d(y), we multiply by -1 to get source derivatives
268265
density_multiplier *= (-1)**int(sum(nrepeats for _, nrepeats in derivatives))
269-
new_int_gs = _multiply_int_g(new_int_g, sym.sympify(expr_multiplier),
266+
new_int_gs = _multiply_int_g(new_int_g, sp.sympify(expr_multiplier),
270267
density_multiplier)
271268
result.extend(new_int_gs)
272269
return result
273270

274271

275-
def _multiply_int_g(int_g: IntG, expr_multiplier: sym.Basic,
276-
density_multiplier: ArithmeticExpressionT) -> list[IntG]:
272+
def _multiply_int_g(int_g: sym.IntG, expr_multiplier: sp.Basic,
273+
density_multiplier: ArithmeticExpressionT) -> list[sym.IntG]:
277274
"""Multiply the expression in ``IntG`` with the *expr_multiplier*
278275
which is a symbolic (:mod:`sympy` or :mod:`symengine`) expression and
279276
multiply the densities with *density_multiplier* which is a :mod:`pymbolic`
280277
expression.
281278
"""
279+
from pymbolic import substitute
280+
282281
result = []
283282

284283
base_kernel = int_g.target_kernel.get_base_kernel()
285-
sym_d = sym.make_sym_vector("d", base_kernel.dim)
284+
sym_d = sp.make_sym_vector("d", base_kernel.dim)
286285
base_kernel_expr = _get_sympy_kernel_expression(base_kernel.expression,
287286
int_g.kernel_arguments)
288-
subst = {pymbolic.var(f"d{i}"): pymbolic.var("d")[i] for i in
287+
subst = {sym.var(f"d{i}"): sym.var("d")[i] for i in
289288
range(base_kernel.dim)}
290-
conv = sym.SympyToPymbolicMapper()
289+
conv = sp.SympyToPymbolicMapper()
291290

292291
if expr_multiplier == 1:
293292
# if there's no expr_multiplier, only multiply the densities
294-
return [int_g.copy(densities=tuple(density*density_multiplier
295-
for density in int_g.densities))]
293+
return [replace(
294+
int_g,
295+
densities=tuple(density*density_multiplier for density in int_g.densities))
296+
]
296297

297298
for knl, density in zip(int_g.source_kernels, int_g.densities, strict=True):
298299
if expr_multiplier == 1:
299300
new_knl = knl.get_base_kernel()
300301
else:
301302
new_expr = conv(knl.postprocess_at_source(base_kernel_expr, sym_d)
302303
* expr_multiplier)
303-
new_expr = pymbolic.substitute(new_expr, subst)
304-
new_knl = ExpressionKernel(knl.dim, new_expr,
304+
new_expr = substitute(new_expr, subst)
305+
new_knl = ExpressionKernel(knl.dim, flatten(new_expr),
305306
knl.get_base_kernel().global_scaling_const,
306307
knl.is_complex_valued)
307-
result.append(int_g.copy(target_kernel=new_knl,
308+
result.append(replace(
309+
int_g,
310+
target_kernel=new_knl,
308311
densities=(density*density_multiplier,),
309-
source_kernels=(new_knl,)))
312+
source_kernels=(new_knl,)
313+
))
310314
return result
311315

312316

313317
def rewrite_int_g_using_base_kernel(
314-
int_g: IntG, base_kernel: ExpressionKernel) -> ArithmeticExpressionT:
318+
int_g: sym.IntG, base_kernel: ExpressionKernel) -> ArithmeticExpressionT:
315319
r"""Rewrite an ``IntG`` to an expression with ``IntG``\ s having the
316320
base kernel *base_kernel*.
317321
"""
318322
result: ArithmeticExpressionT = 0
319323
for knl, density in zip(int_g.source_kernels, int_g.densities, strict=True):
320324
result += _rewrite_int_g_using_base_kernel(
321-
int_g.copy(source_kernels=(knl,), densities=(density,)),
325+
replace(int_g, source_kernels=(knl,), densities=(density,)),
322326
base_kernel)
323327

324328
return result
325329

326330

327331
def _rewrite_int_g_using_base_kernel(
328-
int_g: IntG, base_kernel: ExpressionKernel) -> ArithmeticExpressionT:
332+
int_g: sym.IntG, base_kernel: ExpressionKernel) -> ArithmeticExpressionT:
329333
r"""Rewrites an ``IntG`` with only one source kernel to an expression with
330334
``IntG``\ s having the base kernel *base_kernel*.
331335
"""
@@ -338,17 +342,17 @@ def _rewrite_int_g_using_base_kernel(
338342
source_kernel, = int_g.source_kernels
339343
deriv_relation = get_deriv_relation_kernel(source_kernel.get_base_kernel(),
340344
base_kernel, hashable_kernel_arguments=(
341-
hashable_kernel_args(int_g.kernel_arguments)))
345+
sym.hashable_kernel_args(int_g.kernel_arguments)))
342346

343347
const = deriv_relation.const
344348
# NOTE: we set a dofdesc here to force the evaluation of this integral
345349
# on the source instead of the target when using automatic tagging
346350
# see :meth:`pytential.symbolic.mappers.LocationTagger._default_dofdesc`
347351
if int_g.source.geometry is None:
348-
dd = int_g.source.copy(geometry=DEFAULT_SOURCE)
352+
dd = int_g.source.copy(geometry=sym.DEFAULT_SOURCE)
349353
else:
350354
dd = int_g.source
351-
const *= pytential.sym.integral(dim, dim-1, density, dofdesc=dd)
355+
const *= sym.integral(dim, dim-1, density, dofdesc=dd)
352356

353357
if const != 0 and target_kernel != target_kernel.get_base_kernel():
354358
# There might be some TargetPointMultipliers hanging around.
@@ -377,8 +381,13 @@ def _rewrite_int_g_using_base_kernel(
377381
for _ in range(val):
378382
knl = AxisSourceDerivative(d, knl)
379383
c *= -1
380-
result += int_g.copy(source_kernels=(knl,), target_kernel=target_kernel,
381-
densities=(density * c,), kernel_arguments=new_kernel_args)
384+
result += replace(
385+
int_g,
386+
source_kernels=(knl,),
387+
target_kernel=target_kernel,
388+
densities=(density * c,),
389+
kernel_arguments=new_kernel_args)
390+
382391
return result
383392

384393

@@ -432,7 +441,7 @@ def get_deriv_relation(
432441
res = []
433442
for knl in kernels:
434443
res.append(get_deriv_relation_kernel(knl, base_kernel,
435-
hashable_kernel_arguments=hashable_kernel_args(kernel_arguments),
444+
hashable_kernel_arguments=sym.hashable_kernel_args(kernel_arguments),
436445
tol=tol, order=order))
437446
return res
438447

@@ -460,14 +469,14 @@ def get_deriv_relation_kernel(
460469
order=order,
461470
hashable_kernel_arguments=hashable_kernel_arguments)
462471
dim = base_kernel.dim
463-
sym_vec = sym.make_sym_vector("d", dim)
464-
sympy_conv = sym.SympyToPymbolicMapper()
472+
sym_vec = sp.make_sym_vector("d", dim)
473+
sympy_conv = sp.SympyToPymbolicMapper()
465474

466475
expr = _get_sympy_kernel_expression(kernel.expression, kernel_arguments)
467476
vec = []
468477
for i in range(len(mis)):
469478
vec.append(evalf(expr.xreplace(dict(zip(sym_vec, rand[:, i], strict=True)))))
470-
vec = sym.Matrix(vec)
479+
vec = sp.Matrix(vec)
471480
result = []
472481
const = 0
473482
logger.debug("%s = ", kernel)
@@ -493,8 +502,8 @@ def get_deriv_relation_kernel(
493502

494503
@dataclass
495504
class LUFactorization:
496-
L: sym.Matrix
497-
U: sym.Matrix
505+
L: sp.Matrix
506+
U: sp.Matrix
498507
perm: Sequence[tuple[int, int]]
499508

500509

@@ -539,8 +548,8 @@ def _get_base_kernel_matrix_lu_factorization(
539548
rand: np.ndarray = rng.integers(1, 10**15, size=(dim, len(mis))).astype(object)
540549
for i in range(rand.shape[0]):
541550
for j in range(rand.shape[1]):
542-
rand[i, j] = sym.sympify(rand[i, j])/10**15
543-
sym_vec = sym.make_sym_vector("d", dim)
551+
rand[i, j] = sp.sympify(rand[i, j])/10**15
552+
sym_vec = sp.make_sym_vector("d", dim)
544553

545554
base_expr = _get_sympy_kernel_expression(base_kernel.expression,
546555
dict(hashable_kernel_arguments))
@@ -560,7 +569,7 @@ def _get_base_kernel_matrix_lu_factorization(
560569
row.append(1)
561570
mat.append(row)
562571

563-
sym_mat = sym.Matrix(mat)
572+
sym_mat = sp.Matrix(mat)
564573
failed = False
565574
try:
566575
L, U, perm = sym_mat.LUdecomposition()
@@ -569,7 +578,7 @@ def _get_base_kernel_matrix_lu_factorization(
569578
# and sympy returns U with last row zero
570579
failed = True
571580

572-
if not sym.USE_SYMENGINE and all(expr == 0 for expr in U[-1, :]):
581+
if not sp.USE_SYMENGINE and all(expr == 0 for expr in U[-1, :]):
573582
failed = True
574583

575584
if failed:

pytential/symbolic/primitives.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from pymbolic.geometric_algebra.primitives import (
3838
NablaComponent, Derivative as DerivativeBase)
3939
from pymbolic.primitives import make_sym_vector
40+
from pymbolic.typing import ArithmeticExpressionT
4041

4142
from pytools.obj_array import make_obj_array, flat_obj_array
4243
from sumpy.kernel import Kernel, SpatialConstant
@@ -1486,7 +1487,7 @@ class IntG(Expression):
14861487
derivatives attached. k-th elements represents the k-th source derivative
14871488
operator above.
14881489
"""
1489-
densities: tuple[Expression, ...]
1490+
densities: tuple[ArithmeticExpressionT, ...]
14901491
"""A tuple of density expressions. Length of this tuple must match the length
14911492
of the *source_kernels* arguments.
14921493
"""

0 commit comments

Comments
 (0)