2121"""
2222
2323import logging
24- import warnings
2524from collections .abc import Mapping , Sequence
26- from dataclasses import dataclass
27- from typing import Any
25+ from dataclasses import dataclass , replace
26+ from typing import Any , cast
2827
2928import numpy as np
30- import pymbolic
31- import sumpy .symbolic as sym
32- from pytools import \
33- generate_nonnegative_integer_tuples_summing_to_at_most as gnitstam
29+ import sumpy .symbolic as sp
30+ from pytools import generate_nonnegative_integer_tuples_summing_to_at_most as gnitstam
3431from pytools import memoize_on_first_arg
3532from sumpy .kernel import (AxisSourceDerivative , AxisTargetDerivative ,
3633 DirectionalSourceDerivative , ExpressionKernel ,
3734 Kernel , KernelWrapper , TargetPointMultiplier )
38- from pymbolic .typing import ExpressionT , ArithmeticExpressionT
35+ from pymbolic .typing import ArithmeticExpressionT
3936
40- import pytential
41- from pytential .symbolic .mappers import IdentityMapper
42- from pytential .symbolic .primitives import (DEFAULT_SOURCE , IntG ,
43- NodeCoordinateComponent ,
44- hashable_kernel_args )
37+ from pytential import sym
38+ from pytential .symbolic .mappers import IdentityMapper , flatten
4539from pytential .utils import chop , solve_from_lu
4640
4741logger = logging .getLogger (__name__ )
@@ -72,8 +66,8 @@ class RewriteFailedError(RuntimeError):
7266
7367
7468def rewrite_using_base_kernel (
75- exprs : Sequence [ExpressionT ],
76- base_kernel : Kernel = _NO_ARG_SENTINEL ) -> list [ExpressionT ]:
69+ exprs : Sequence [ArithmeticExpressionT ],
70+ base_kernel : Kernel = _NO_ARG_SENTINEL ) -> list [ArithmeticExpressionT ]:
7771 """
7872 Rewrites a list of expressions with :class:`~pytential.symbolic.primitives.IntG`
7973 objects using *base_kernel*.
@@ -101,7 +95,7 @@ def rewrite_using_base_kernel(
10195 raise NotImplementedError
10296
10397 mapper = RewriteUsingBaseKernelMapper (base_kernel )
104- return [mapper (expr ) for expr in exprs ]
98+ return [cast ( ArithmeticExpressionT , mapper (expr ) ) for expr in exprs ]
10599
106100
107101class RewriteUsingBaseKernelMapper (IdentityMapper ):
@@ -132,7 +126,7 @@ def map_int_g(self, expr):
132126
133127
134128def _get_sympy_kernel_expression (expr : ArithmeticExpressionT ,
135- kernel_arguments : Mapping [str , Any ]) -> sym .Basic :
129+ kernel_arguments : Mapping [str , Any ]) -> sp .Basic :
136130 """Convert a :mod:`pymbolic` expression to :mod:`sympy` expression
137131 after substituting kernel arguments.
138132
@@ -149,8 +143,8 @@ def _get_sympy_kernel_expression(expr: ArithmeticExpressionT,
149143
150144
151145def _monom_to_expr (monom : Sequence [int ],
152- variables : Sequence [sym .Basic | ArithmeticExpressionT ]
153- ) -> sym .Basic | ArithmeticExpressionT :
146+ variables : Sequence [sp .Basic | ArithmeticExpressionT ]
147+ ) -> sp .Basic | ArithmeticExpressionT :
154148 """Convert a monomial to an expression using given variables.
155149
156150 For example, ``[3, 2, 1]`` with variables ``[x, y, z]`` is converted to
@@ -164,7 +158,7 @@ def _monom_to_expr(monom: Sequence[int],
164158 return prod
165159
166160
167- def convert_target_transformation_to_source (int_g : IntG ) -> list [IntG ]:
161+ def convert_target_transformation_to_source (int_g : sym . IntG ) -> list [sym . IntG ]:
168162 r"""Convert an ``IntG`` with :class:`~sumpy.kernel.AxisTargetDerivative`
169163 or :class:`~sumpy.kernel.TargetPointMultiplier` to a list
170164 of ``IntG``\ s without them and only source dependent transformations.
@@ -182,8 +176,9 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
182176
183177 knl = int_g .target_kernel
184178 if not knl .is_translation_invariant :
185- warnings .warn (f"Translation variant kernel ({ knl } ) found." ,
186- stacklevel = 2 )
179+ from warnings import warn
180+
181+ warn (f"Translation variant kernel ({ knl } ) found." , stacklevel = 2 )
187182 return [int_g ]
188183
189184 # we use a symbol for d = (x - y)
@@ -204,7 +199,9 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
204199 expr = expr .diff (ds [knl .axis ])
205200 found = True
206201 else :
207- warnings .warn (
202+ from warnings import warn
203+
204+ warn (
208205 f"Unknown target kernel ({ knl } ) found. "
209206 "Returning IntG expression unchanged." , stacklevel = 2 )
210207 return [int_g ]
@@ -213,9 +210,9 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
213210 if not found :
214211 return [int_g ]
215212
216- int_g = int_g . copy ( target_kernel = knl )
213+ int_g = replace ( int_g , target_kernel = knl )
217214
218- sources_pymbolic = [ NodeCoordinateComponent ( i ) for i in range ( knl .dim )]
215+ sources_pymbolic = sym . nodes ( knl .dim ). as_vector ()
219216 expr = expr .expand ()
220217 # Now the expr is an Add and looks like
221218 # u_{d[0], d[1]}(d, y)*d[0]*y[1] + u(d, y) * d[1]
@@ -255,7 +252,7 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
255252 for _ in range (nrepeats ):
256253 knl = AxisSourceDerivative (axis , knl )
257254 new_source_kernels .append (knl )
258- new_int_g = int_g . copy ( source_kernels = new_source_kernels )
255+ new_int_g = replace ( int_g , source_kernels = tuple ( new_source_kernels ) )
259256
260257 (monom , coeff ,) = remaining_factors .terms ()[0 ]
261258 # Now from d[0]*y[1], we separate the two terms
@@ -266,66 +263,73 @@ def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
266263 * conv (coeff )
267264 # since d/d(d) = - d/d(y), we multiply by -1 to get source derivatives
268265 density_multiplier *= (- 1 )** int (sum (nrepeats for _ , nrepeats in derivatives ))
269- new_int_gs = _multiply_int_g (new_int_g , sym .sympify (expr_multiplier ),
266+ new_int_gs = _multiply_int_g (new_int_g , sp .sympify (expr_multiplier ),
270267 density_multiplier )
271268 result .extend (new_int_gs )
272269 return result
273270
274271
275- def _multiply_int_g (int_g : IntG , expr_multiplier : sym .Basic ,
276- density_multiplier : ArithmeticExpressionT ) -> list [IntG ]:
272+ def _multiply_int_g (int_g : sym . IntG , expr_multiplier : sp .Basic ,
273+ density_multiplier : ArithmeticExpressionT ) -> list [sym . IntG ]:
277274 """Multiply the expression in ``IntG`` with the *expr_multiplier*
278275 which is a symbolic (:mod:`sympy` or :mod:`symengine`) expression and
279276 multiply the densities with *density_multiplier* which is a :mod:`pymbolic`
280277 expression.
281278 """
279+ from pymbolic import substitute
280+
282281 result = []
283282
284283 base_kernel = int_g .target_kernel .get_base_kernel ()
285- sym_d = sym .make_sym_vector ("d" , base_kernel .dim )
284+ sym_d = sp .make_sym_vector ("d" , base_kernel .dim )
286285 base_kernel_expr = _get_sympy_kernel_expression (base_kernel .expression ,
287286 int_g .kernel_arguments )
288- subst = {pymbolic .var (f"d{ i } " ): pymbolic .var ("d" )[i ] for i in
287+ subst = {sym .var (f"d{ i } " ): sym .var ("d" )[i ] for i in
289288 range (base_kernel .dim )}
290- conv = sym .SympyToPymbolicMapper ()
289+ conv = sp .SympyToPymbolicMapper ()
291290
292291 if expr_multiplier == 1 :
293292 # if there's no expr_multiplier, only multiply the densities
294- return [int_g .copy (densities = tuple (density * density_multiplier
295- for density in int_g .densities ))]
293+ return [replace (
294+ int_g ,
295+ densities = tuple (density * density_multiplier for density in int_g .densities ))
296+ ]
296297
297298 for knl , density in zip (int_g .source_kernels , int_g .densities , strict = True ):
298299 if expr_multiplier == 1 :
299300 new_knl = knl .get_base_kernel ()
300301 else :
301302 new_expr = conv (knl .postprocess_at_source (base_kernel_expr , sym_d )
302303 * expr_multiplier )
303- new_expr = pymbolic . substitute (new_expr , subst )
304- new_knl = ExpressionKernel (knl .dim , new_expr ,
304+ new_expr = substitute (new_expr , subst )
305+ new_knl = ExpressionKernel (knl .dim , flatten ( new_expr ) ,
305306 knl .get_base_kernel ().global_scaling_const ,
306307 knl .is_complex_valued )
307- result .append (int_g .copy (target_kernel = new_knl ,
308+ result .append (replace (
309+ int_g ,
310+ target_kernel = new_knl ,
308311 densities = (density * density_multiplier ,),
309- source_kernels = (new_knl ,)))
312+ source_kernels = (new_knl ,)
313+ ))
310314 return result
311315
312316
313317def rewrite_int_g_using_base_kernel (
314- int_g : IntG , base_kernel : ExpressionKernel ) -> ArithmeticExpressionT :
318+ int_g : sym . IntG , base_kernel : ExpressionKernel ) -> ArithmeticExpressionT :
315319 r"""Rewrite an ``IntG`` to an expression with ``IntG``\ s having the
316320 base kernel *base_kernel*.
317321 """
318322 result : ArithmeticExpressionT = 0
319323 for knl , density in zip (int_g .source_kernels , int_g .densities , strict = True ):
320324 result += _rewrite_int_g_using_base_kernel (
321- int_g . copy ( source_kernels = (knl ,), densities = (density ,)),
325+ replace ( int_g , source_kernels = (knl ,), densities = (density ,)),
322326 base_kernel )
323327
324328 return result
325329
326330
327331def _rewrite_int_g_using_base_kernel (
328- int_g : IntG , base_kernel : ExpressionKernel ) -> ArithmeticExpressionT :
332+ int_g : sym . IntG , base_kernel : ExpressionKernel ) -> ArithmeticExpressionT :
329333 r"""Rewrites an ``IntG`` with only one source kernel to an expression with
330334 ``IntG``\ s having the base kernel *base_kernel*.
331335 """
@@ -338,17 +342,17 @@ def _rewrite_int_g_using_base_kernel(
338342 source_kernel , = int_g .source_kernels
339343 deriv_relation = get_deriv_relation_kernel (source_kernel .get_base_kernel (),
340344 base_kernel , hashable_kernel_arguments = (
341- hashable_kernel_args (int_g .kernel_arguments )))
345+ sym . hashable_kernel_args (int_g .kernel_arguments )))
342346
343347 const = deriv_relation .const
344348 # NOTE: we set a dofdesc here to force the evaluation of this integral
345349 # on the source instead of the target when using automatic tagging
346350 # see :meth:`pytential.symbolic.mappers.LocationTagger._default_dofdesc`
347351 if int_g .source .geometry is None :
348- dd = int_g .source .copy (geometry = DEFAULT_SOURCE )
352+ dd = int_g .source .copy (geometry = sym . DEFAULT_SOURCE )
349353 else :
350354 dd = int_g .source
351- const *= pytential . sym .integral (dim , dim - 1 , density , dofdesc = dd )
355+ const *= sym .integral (dim , dim - 1 , density , dofdesc = dd )
352356
353357 if const != 0 and target_kernel != target_kernel .get_base_kernel ():
354358 # There might be some TargetPointMultipliers hanging around.
@@ -377,8 +381,13 @@ def _rewrite_int_g_using_base_kernel(
377381 for _ in range (val ):
378382 knl = AxisSourceDerivative (d , knl )
379383 c *= - 1
380- result += int_g .copy (source_kernels = (knl ,), target_kernel = target_kernel ,
381- densities = (density * c ,), kernel_arguments = new_kernel_args )
384+ result += replace (
385+ int_g ,
386+ source_kernels = (knl ,),
387+ target_kernel = target_kernel ,
388+ densities = (density * c ,),
389+ kernel_arguments = new_kernel_args )
390+
382391 return result
383392
384393
@@ -432,7 +441,7 @@ def get_deriv_relation(
432441 res = []
433442 for knl in kernels :
434443 res .append (get_deriv_relation_kernel (knl , base_kernel ,
435- hashable_kernel_arguments = hashable_kernel_args (kernel_arguments ),
444+ hashable_kernel_arguments = sym . hashable_kernel_args (kernel_arguments ),
436445 tol = tol , order = order ))
437446 return res
438447
@@ -460,14 +469,14 @@ def get_deriv_relation_kernel(
460469 order = order ,
461470 hashable_kernel_arguments = hashable_kernel_arguments )
462471 dim = base_kernel .dim
463- sym_vec = sym .make_sym_vector ("d" , dim )
464- sympy_conv = sym .SympyToPymbolicMapper ()
472+ sym_vec = sp .make_sym_vector ("d" , dim )
473+ sympy_conv = sp .SympyToPymbolicMapper ()
465474
466475 expr = _get_sympy_kernel_expression (kernel .expression , kernel_arguments )
467476 vec = []
468477 for i in range (len (mis )):
469478 vec .append (evalf (expr .xreplace (dict (zip (sym_vec , rand [:, i ], strict = True )))))
470- vec = sym .Matrix (vec )
479+ vec = sp .Matrix (vec )
471480 result = []
472481 const = 0
473482 logger .debug ("%s = " , kernel )
@@ -493,8 +502,8 @@ def get_deriv_relation_kernel(
493502
494503@dataclass
495504class LUFactorization :
496- L : sym .Matrix
497- U : sym .Matrix
505+ L : sp .Matrix
506+ U : sp .Matrix
498507 perm : Sequence [tuple [int , int ]]
499508
500509
@@ -539,8 +548,8 @@ def _get_base_kernel_matrix_lu_factorization(
539548 rand : np .ndarray = rng .integers (1 , 10 ** 15 , size = (dim , len (mis ))).astype (object )
540549 for i in range (rand .shape [0 ]):
541550 for j in range (rand .shape [1 ]):
542- rand [i , j ] = sym .sympify (rand [i , j ])/ 10 ** 15
543- sym_vec = sym .make_sym_vector ("d" , dim )
551+ rand [i , j ] = sp .sympify (rand [i , j ])/ 10 ** 15
552+ sym_vec = sp .make_sym_vector ("d" , dim )
544553
545554 base_expr = _get_sympy_kernel_expression (base_kernel .expression ,
546555 dict (hashable_kernel_arguments ))
@@ -560,7 +569,7 @@ def _get_base_kernel_matrix_lu_factorization(
560569 row .append (1 )
561570 mat .append (row )
562571
563- sym_mat = sym .Matrix (mat )
572+ sym_mat = sp .Matrix (mat )
564573 failed = False
565574 try :
566575 L , U , perm = sym_mat .LUdecomposition ()
@@ -569,7 +578,7 @@ def _get_base_kernel_matrix_lu_factorization(
569578 # and sympy returns U with last row zero
570579 failed = True
571580
572- if not sym .USE_SYMENGINE and all (expr == 0 for expr in U [- 1 , :]):
581+ if not sp .USE_SYMENGINE and all (expr == 0 for expr in U [- 1 , :]):
573582 failed = True
574583
575584 if failed :
0 commit comments