2222
2323import logging
2424import warnings
25+ from collections .abc import Mapping , Sequence
2526from dataclasses import dataclass
26- from typing import Any , List , Mapping , Optional , Sequence , Tuple , Union
27+ from typing import Any
2728
2829import numpy as np
2930import pymbolic
3435from sumpy .kernel import (AxisSourceDerivative , AxisTargetDerivative ,
3536 DirectionalSourceDerivative , ExpressionKernel ,
3637 Kernel , KernelWrapper , TargetPointMultiplier )
38+ from pymbolic .typing import ExpressionT , ArithmeticExpressionT
3739
3840import pytential
3941from pytential .symbolic .mappers import IdentityMapper
4042from pytential .symbolic .primitives import (DEFAULT_SOURCE , IntG ,
4143 NodeCoordinateComponent ,
4244 hashable_kernel_args )
43- from pytential .symbolic .typing import ExpressionT
4445from pytential .utils import chop , solve_from_lu
4546
4647logger = logging .getLogger (__name__ )
@@ -72,7 +73,7 @@ class RewriteFailedError(RuntimeError):
7273
7374def rewrite_using_base_kernel (
7475 exprs : Sequence [ExpressionT ],
75- base_kernel : Kernel = _NO_ARG_SENTINEL ) -> List [ExpressionT ]:
76+ base_kernel : Kernel = _NO_ARG_SENTINEL ) -> list [ExpressionT ]:
7677 """
7778 Rewrites a list of expressions with :class:`~pytential.symbolic.primitives.IntG`
7879 objects using *base_kernel*.
@@ -130,7 +131,7 @@ def map_int_g(self, expr):
130131 self .base_kernel ) for new_int_g in new_int_gs )
131132
132133
133- def _get_sympy_kernel_expression (expr : ExpressionT ,
134+ def _get_sympy_kernel_expression (expr : ArithmeticExpressionT ,
134135 kernel_arguments : Mapping [str , Any ]) -> sym .Basic :
135136 """Convert a :mod:`pymbolic` expression to :mod:`sympy` expression
136137 after substituting kernel arguments.
@@ -148,22 +149,22 @@ def _get_sympy_kernel_expression(expr: ExpressionT,
148149
149150
150151def _monom_to_expr (monom : Sequence [int ],
151- variables : Sequence [Union [ sym .Basic , ExpressionT ]]) \
152- -> Union [ sym .Basic , ExpressionT ] :
152+ variables : Sequence [sym .Basic | ArithmeticExpressionT ]
153+ ) -> sym .Basic | ArithmeticExpressionT :
153154 """Convert a monomial to an expression using given variables.
154155
155156 For example, ``[3, 2, 1]`` with variables ``[x, y, z]`` is converted to
156157 ``x^3 y^2 z``.
157158 """
158- prod : ExpressionT = 1
159+ prod : ArithmeticExpressionT = 1
159160 for i , nrepeats in enumerate (monom ):
160161 for _ in range (nrepeats ):
161162 prod *= variables [i ]
162163
163164 return prod
164165
165166
166- def convert_target_transformation_to_source (int_g : IntG ) -> List [IntG ]:
167+ def convert_target_transformation_to_source (int_g : IntG ) -> list [IntG ]:
167168 r"""Convert an ``IntG`` with :class:`~sumpy.kernel.AxisTargetDerivative`
168169 or :class:`~sumpy.kernel.TargetPointMultiplier` to a list
169170 of ``IntG``\ s without them and only source dependent transformations.
@@ -189,7 +190,7 @@ def convert_target_transformation_to_source(int_g: IntG) -> List[IntG]:
189190 ds = sympy .symbols (f"d0:{ knl .dim } " )
190191 sources = sympy .symbols (f"y0:{ knl .dim } " )
191192 # instead of just x, we use x = (d + y)
192- targets = [d + source for d , source in zip (ds , sources )]
193+ targets = [d + source for d , source in zip (ds , sources , strict = True )]
193194 orig_expr = sympy .Function ("f" )(* ds ) # pylint: disable=not-callable
194195 expr = orig_expr
195196 found = False
@@ -272,7 +273,7 @@ def convert_target_transformation_to_source(int_g: IntG) -> List[IntG]:
272273
273274
274275def _multiply_int_g (int_g : IntG , expr_multiplier : sym .Basic ,
275- density_multiplier : ExpressionT ) -> List [IntG ]:
276+ density_multiplier : ArithmeticExpressionT ) -> list [IntG ]:
276277 """Multiply the expression in ``IntG`` with the *expr_multiplier*
277278 which is a symbolic (:mod:`sympy` or :mod:`symengine`) expression and
278279 multiply the densities with *density_multiplier* which is a :mod:`pymbolic`
@@ -293,7 +294,7 @@ def _multiply_int_g(int_g: IntG, expr_multiplier: sym.Basic,
293294 return [int_g .copy (densities = tuple (density * density_multiplier
294295 for density in int_g .densities ))]
295296
296- for knl , density in zip (int_g .source_kernels , int_g .densities ):
297+ for knl , density in zip (int_g .source_kernels , int_g .densities , strict = True ):
297298 if expr_multiplier == 1 :
298299 new_knl = knl .get_base_kernel ()
299300 else :
@@ -310,12 +311,12 @@ def _multiply_int_g(int_g: IntG, expr_multiplier: sym.Basic,
310311
311312
312313def rewrite_int_g_using_base_kernel (
313- int_g : IntG , base_kernel : ExpressionKernel ) -> ExpressionT :
314+ int_g : IntG , base_kernel : ExpressionKernel ) -> ArithmeticExpressionT :
314315 r"""Rewrite an ``IntG`` to an expression with ``IntG``\ s having the
315316 base kernel *base_kernel*.
316317 """
317- result : ExpressionT = 0
318- for knl , density in zip (int_g .source_kernels , int_g .densities ):
318+ result : ArithmeticExpressionT = 0
319+ for knl , density in zip (int_g .source_kernels , int_g .densities , strict = True ):
319320 result += _rewrite_int_g_using_base_kernel (
320321 int_g .copy (source_kernels = (knl ,), densities = (density ,)),
321322 base_kernel )
@@ -324,14 +325,14 @@ def rewrite_int_g_using_base_kernel(
324325
325326
326327def _rewrite_int_g_using_base_kernel (
327- int_g : IntG , base_kernel : ExpressionKernel ) -> ExpressionT :
328+ int_g : IntG , base_kernel : ExpressionKernel ) -> ArithmeticExpressionT :
328329 r"""Rewrites an ``IntG`` with only one source kernel to an expression with
329330 ``IntG``\ s having the base kernel *base_kernel*.
330331 """
331332 target_kernel = int_g .target_kernel .replace_base_kernel (base_kernel )
332333 dim = target_kernel .dim
333334
334- result = 0
335+ result : ArithmeticExpressionT = 0
335336
336337 density , = int_g .densities
337338 source_kernel , = int_g .source_kernels
@@ -360,7 +361,7 @@ def _rewrite_int_g_using_base_kernel(
360361 knl = source_kernel
361362 while isinstance (knl , KernelWrapper ):
362363 if not isinstance (knl ,
363- ( AxisSourceDerivative , DirectionalSourceDerivative ) ):
364+ AxisSourceDerivative | DirectionalSourceDerivative ):
364365 return int_g
365366 knl = knl .inner_kernel
366367 const = 0
@@ -393,18 +394,19 @@ class DerivRelation:
393394 .. autoattribute:: linear_combination
394395 """
395396
396- const : ExpressionT
397+ const : ArithmeticExpressionT
397398 """A constant to add to the combination."""
398- linear_combination : Sequence [Tuple [ Tuple [int , ...], ExpressionT ]]
399+ linear_combination : Sequence [tuple [ tuple [int , ...], ArithmeticExpressionT ]]
399400 """A list of pairs ``(mi, coeffs)``."""
400401
401402
402- def get_deriv_relation (kernels : Sequence [ExpressionKernel ],
403+ def get_deriv_relation (
404+ kernels : Sequence [ExpressionKernel ],
403405 base_kernel : ExpressionKernel ,
404406 kernel_arguments : Mapping [str , Any ],
405407 tol : float = 1e-10 ,
406- order : Optional [ int ] = None ) \
407- -> List [DerivRelation ]:
408+ order : int | None = None ,
409+ ) -> list [DerivRelation ]:
408410 r"""
409411 Given a sequence of *kernels*, a *base_kernel* and an *order*, this
410412 gives a relation between the *base_kernel* and each of the *kernels*.
@@ -436,12 +438,13 @@ def get_deriv_relation(kernels: Sequence[ExpressionKernel],
436438
437439
438440@memoize_on_first_arg
439- def get_deriv_relation_kernel (kernel : ExpressionKernel ,
441+ def get_deriv_relation_kernel (
442+ kernel : ExpressionKernel ,
440443 base_kernel : ExpressionKernel ,
441- hashable_kernel_arguments : Tuple [ Tuple [str , Any ], ...],
444+ hashable_kernel_arguments : tuple [ tuple [str , Any ], ...],
442445 tol : float = 1e-10 ,
443- order : Optional [ int ] = None ) \
444- -> DerivRelation :
446+ order : int | None = None ,
447+ ) -> DerivRelation :
445448 """Takes a *kernel* and a base_kernel* as input and re-writes the
446449 *kernel* as a linear combination of derivatives of *base_kernel* up-to
447450 order *order* and a constant.
@@ -463,7 +466,7 @@ def get_deriv_relation_kernel(kernel: ExpressionKernel,
463466 expr = _get_sympy_kernel_expression (kernel .expression , kernel_arguments )
464467 vec = []
465468 for i in range (len (mis )):
466- vec .append (evalf (expr .xreplace (dict (zip (sym_vec , rand [:, i ])))))
469+ vec .append (evalf (expr .xreplace (dict (zip (sym_vec , rand [:, i ], strict = True )))))
467470 vec = sym .Matrix (vec )
468471 result = []
469472 const = 0
@@ -492,15 +495,16 @@ def get_deriv_relation_kernel(kernel: ExpressionKernel,
492495class LUFactorization :
493496 L : sym .Matrix
494497 U : sym .Matrix
495- perm : Sequence [Tuple [int , int ]]
498+ perm : Sequence [tuple [int , int ]]
496499
497500
498501@memoize_on_first_arg
499502def _get_base_kernel_matrix_lu_factorization (
500503 base_kernel : ExpressionKernel ,
501- hashable_kernel_arguments : Tuple [Tuple [str , Any ], ...],
502- order : Optional [int ] = None , retries : int = 3 ) \
503- -> Tuple [LUFactorization , np .ndarray , List [Tuple [int , ...]]]:
504+ hashable_kernel_arguments : tuple [tuple [str , Any ], ...],
505+ order : int | None = None ,
506+ retries : int = 3 ,
507+ ) -> tuple [LUFactorization , np .ndarray , list [tuple [int , ...]]]:
504508 """
505509 Takes a *base_kernel* and samples the kernel and its derivatives upto
506510 order *order*.
@@ -550,7 +554,7 @@ def _get_base_kernel_matrix_lu_factorization(
550554 if nderivs == 0 :
551555 continue
552556 expr = expr .diff (sym_vec [var_idx ], nderivs )
553- replace_dict = dict (zip (sym_vec , rand [:, rand_vec_idx ]))
557+ replace_dict = dict (zip (sym_vec , rand [:, rand_vec_idx ], strict = True ))
554558 eval_expr = evalf (expr .xreplace (replace_dict ))
555559 row .append (eval_expr )
556560 row .append (1 )
0 commit comments