Skip to content

Commit 6a3ffc7

Browse files
committed
fix new ruff issues
1 parent 0194a67 commit 6a3ffc7

File tree

8 files changed

+78
-97
lines changed

8 files changed

+78
-97
lines changed

pytential/qbx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,7 @@ def get_flat_strengths_from_densities(
971971

972972
density_dofarrays = [evaluate(density) for density in densities]
973973
for i, ary in enumerate(density_dofarrays):
974-
if not isinstance(ary, (DOFArray, Number)):
974+
if not isinstance(ary, DOFArray | Number):
975975
raise ValueError(
976976
f"DOFArray expected for density '{densities[i]}', "
977977
f"{type(ary)} received instead")

pytential/symbolic/elasticity.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from functools import cached_property
3030

3131
import numpy as np
32+
from pymbolic.typing import ArithmeticExpressionT
3233
from sumpy.kernel import (AxisSourceDerivative, AxisTargetDerivative,
3334
BiharmonicKernel, ElasticityKernel, Kernel,
3435
LaplaceKernel, StokesletKernel, StressletKernel,
@@ -37,7 +38,6 @@
3738

3839
from pytential import sym
3940
from pytential.symbolic.pde.system_utils import rewrite_using_base_kernel
40-
from pytential.symbolic.typing import ExpressionT
4141

4242
__doc__ = """
4343
.. autoclass:: Method
@@ -95,9 +95,9 @@ class ElasticityWrapperBase(ABC):
9595

9696
dim: int
9797
"""Ambient dimension of the representation."""
98-
mu: ExpressionT
98+
mu: ArithmeticExpressionT
9999
r"""Expression or value for the shear modulus :math:`\mu`."""
100-
nu: ExpressionT
100+
nu: ArithmeticExpressionT
101101
r"""Expression or value for Poisson's ratio :math:`\nu`."""
102102

103103
@abstractmethod
@@ -159,9 +159,9 @@ class ElasticityDoubleLayerWrapperBase(ABC):
159159

160160
dim: int
161161
"""Ambient dimension of the representation."""
162-
mu: ExpressionT
162+
mu: ArithmeticExpressionT
163163
r"""Expression or value for the shear modulus :math:`\mu`."""
164-
nu: ExpressionT
164+
nu: ArithmeticExpressionT
165165
r"""Expression or value for Poisson's ration :math:`\nu`."""
166166

167167
@abstractmethod
@@ -227,8 +227,8 @@ def _create_int_g(knl, deriv_dirs, density, **kwargs):
227227
@dataclass
228228
class _ElasticityWrapperNaiveOrBiharmonic:
229229
dim: int
230-
mu: ExpressionT
231-
nu: ExpressionT
230+
mu: ArithmeticExpressionT
231+
nu: ArithmeticExpressionT
232232
base_kernel: Kernel
233233

234234
def __post_init__(self):
@@ -315,8 +315,8 @@ def __init__(self, dim, mu, nu):
315315
@dataclass
316316
class _ElasticityDoubleLayerWrapperNaiveOrBiharmonic:
317317
dim: int
318-
mu: ExpressionT
319-
nu: ExpressionT
318+
mu: ArithmeticExpressionT
319+
nu: ArithmeticExpressionT
320320
base_kernel: Kernel
321321

322322
def __post_init__(self):
@@ -373,16 +373,20 @@ def _get_int_g(self, idx, density_sym, dir_vec_sym, qbx_forced_limit,
373373
coeffs[-1] = 0
374374

375375
result = 0
376-
for kernel_idx, dir_vec_idx, coeff, extra_deriv_dirs in \
377-
zip(kernel_indices, dir_vec_indices, coeffs,
378-
extra_deriv_dirs_vec):
376+
for kernel_idx, dir_vec_idx, coeff, extra_deriv_dirs in zip(
377+
kernel_indices,
378+
dir_vec_indices,
379+
coeffs,
380+
extra_deriv_dirs_vec, strict=True):
379381
if coeff == 0:
380382
continue
383+
381384
knl = self.kernel_dict[kernel_idx]
382-
result += _create_int_g(knl, tuple(deriv_dirs) + tuple(extra_deriv_dirs),
385+
result += coeff * _create_int_g(
386+
knl, tuple(deriv_dirs) + tuple(extra_deriv_dirs),
383387
density=density_sym*dir_vec_sym[dir_vec_idx],
384-
qbx_forced_limit=qbx_forced_limit, mu=self.mu, nu=self.nu) * \
385-
coeff
388+
qbx_forced_limit=qbx_forced_limit, mu=self.mu, nu=self.nu)
389+
386390
return result/(2*(1 - nu))
387391

388392
def apply(self, density_vec_sym, dir_vec_sym, qbx_forced_limit,
@@ -470,8 +474,8 @@ class Method(Enum):
470474

471475
def make_elasticity_wrapper(
472476
dim: int,
473-
mu: ExpressionT = _MU_SYM_DEFAULT,
474-
nu: ExpressionT = _NU_SYM_DEFAULT,
477+
mu: ArithmeticExpressionT = _MU_SYM_DEFAULT,
478+
nu: ArithmeticExpressionT = _NU_SYM_DEFAULT,
475479
method: Method = Method.Naive) -> ElasticityWrapperBase:
476480
"""Creates an appropriate :class:`ElasticityWrapperBase` object.
477481
@@ -498,8 +502,8 @@ def make_elasticity_wrapper(
498502

499503
def make_elasticity_double_layer_wrapper(
500504
dim: int,
501-
mu: ExpressionT = _MU_SYM_DEFAULT,
502-
nu: ExpressionT = _NU_SYM_DEFAULT,
505+
mu: ArithmeticExpressionT = _MU_SYM_DEFAULT,
506+
nu: ArithmeticExpressionT = _NU_SYM_DEFAULT,
503507
method: Method = Method.Naive) -> ElasticityDoubleLayerWrapperBase:
504508
"""Creates an appropriate :class:`ElasticityDoubleLayerWrapperBase` object.
505509

pytential/symbolic/pde/system_utils.py

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222

2323
import logging
2424
import warnings
25+
from collections.abc import Mapping, Sequence
2526
from dataclasses import dataclass
26-
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
27+
from typing import Any
2728

2829
import numpy as np
2930
import pymbolic
@@ -34,13 +35,13 @@
3435
from sumpy.kernel import (AxisSourceDerivative, AxisTargetDerivative,
3536
DirectionalSourceDerivative, ExpressionKernel,
3637
Kernel, KernelWrapper, TargetPointMultiplier)
38+
from pymbolic.typing import ExpressionT, ArithmeticExpressionT
3739

3840
import pytential
3941
from pytential.symbolic.mappers import IdentityMapper
4042
from pytential.symbolic.primitives import (DEFAULT_SOURCE, IntG,
4143
NodeCoordinateComponent,
4244
hashable_kernel_args)
43-
from pytential.symbolic.typing import ExpressionT
4445
from pytential.utils import chop, solve_from_lu
4546

4647
logger = logging.getLogger(__name__)
@@ -72,7 +73,7 @@ class RewriteFailedError(RuntimeError):
7273

7374
def rewrite_using_base_kernel(
7475
exprs: Sequence[ExpressionT],
75-
base_kernel: Kernel = _NO_ARG_SENTINEL) -> List[ExpressionT]:
76+
base_kernel: Kernel = _NO_ARG_SENTINEL) -> list[ExpressionT]:
7677
"""
7778
Rewrites a list of expressions with :class:`~pytential.symbolic.primitives.IntG`
7879
objects using *base_kernel*.
@@ -130,7 +131,7 @@ def map_int_g(self, expr):
130131
self.base_kernel) for new_int_g in new_int_gs)
131132

132133

133-
def _get_sympy_kernel_expression(expr: ExpressionT,
134+
def _get_sympy_kernel_expression(expr: ArithmeticExpressionT,
134135
kernel_arguments: Mapping[str, Any]) -> sym.Basic:
135136
"""Convert a :mod:`pymbolic` expression to :mod:`sympy` expression
136137
after substituting kernel arguments.
@@ -148,22 +149,22 @@ def _get_sympy_kernel_expression(expr: ExpressionT,
148149

149150

150151
def _monom_to_expr(monom: Sequence[int],
151-
variables: Sequence[Union[sym.Basic, ExpressionT]]) \
152-
-> Union[sym.Basic, ExpressionT]:
152+
variables: Sequence[sym.Basic | ArithmeticExpressionT]
153+
) -> sym.Basic | ArithmeticExpressionT:
153154
"""Convert a monomial to an expression using given variables.
154155
155156
For example, ``[3, 2, 1]`` with variables ``[x, y, z]`` is converted to
156157
``x^3 y^2 z``.
157158
"""
158-
prod: ExpressionT = 1
159+
prod: ArithmeticExpressionT = 1
159160
for i, nrepeats in enumerate(monom):
160161
for _ in range(nrepeats):
161162
prod *= variables[i]
162163

163164
return prod
164165

165166

166-
def convert_target_transformation_to_source(int_g: IntG) -> List[IntG]:
167+
def convert_target_transformation_to_source(int_g: IntG) -> list[IntG]:
167168
r"""Convert an ``IntG`` with :class:`~sumpy.kernel.AxisTargetDerivative`
168169
or :class:`~sumpy.kernel.TargetPointMultiplier` to a list
169170
of ``IntG``\ s without them and only source dependent transformations.
@@ -189,7 +190,7 @@ def convert_target_transformation_to_source(int_g: IntG) -> List[IntG]:
189190
ds = sympy.symbols(f"d0:{knl.dim}")
190191
sources = sympy.symbols(f"y0:{knl.dim}")
191192
# instead of just x, we use x = (d + y)
192-
targets = [d + source for d, source in zip(ds, sources)]
193+
targets = [d + source for d, source in zip(ds, sources, strict=True)]
193194
orig_expr = sympy.Function("f")(*ds) # pylint: disable=not-callable
194195
expr = orig_expr
195196
found = False
@@ -272,7 +273,7 @@ def convert_target_transformation_to_source(int_g: IntG) -> List[IntG]:
272273

273274

274275
def _multiply_int_g(int_g: IntG, expr_multiplier: sym.Basic,
275-
density_multiplier: ExpressionT) -> List[IntG]:
276+
density_multiplier: ArithmeticExpressionT) -> list[IntG]:
276277
"""Multiply the expression in ``IntG`` with the *expr_multiplier*
277278
which is a symbolic (:mod:`sympy` or :mod:`symengine`) expression and
278279
multiply the densities with *density_multiplier* which is a :mod:`pymbolic`
@@ -293,7 +294,7 @@ def _multiply_int_g(int_g: IntG, expr_multiplier: sym.Basic,
293294
return [int_g.copy(densities=tuple(density*density_multiplier
294295
for density in int_g.densities))]
295296

296-
for knl, density in zip(int_g.source_kernels, int_g.densities):
297+
for knl, density in zip(int_g.source_kernels, int_g.densities, strict=True):
297298
if expr_multiplier == 1:
298299
new_knl = knl.get_base_kernel()
299300
else:
@@ -310,12 +311,12 @@ def _multiply_int_g(int_g: IntG, expr_multiplier: sym.Basic,
310311

311312

312313
def rewrite_int_g_using_base_kernel(
313-
int_g: IntG, base_kernel: ExpressionKernel) -> ExpressionT:
314+
int_g: IntG, base_kernel: ExpressionKernel) -> ArithmeticExpressionT:
314315
r"""Rewrite an ``IntG`` to an expression with ``IntG``\ s having the
315316
base kernel *base_kernel*.
316317
"""
317-
result: ExpressionT = 0
318-
for knl, density in zip(int_g.source_kernels, int_g.densities):
318+
result: ArithmeticExpressionT = 0
319+
for knl, density in zip(int_g.source_kernels, int_g.densities, strict=True):
319320
result += _rewrite_int_g_using_base_kernel(
320321
int_g.copy(source_kernels=(knl,), densities=(density,)),
321322
base_kernel)
@@ -324,14 +325,14 @@ def rewrite_int_g_using_base_kernel(
324325

325326

326327
def _rewrite_int_g_using_base_kernel(
327-
int_g: IntG, base_kernel: ExpressionKernel) -> ExpressionT:
328+
int_g: IntG, base_kernel: ExpressionKernel) -> ArithmeticExpressionT:
328329
r"""Rewrites an ``IntG`` with only one source kernel to an expression with
329330
``IntG``\ s having the base kernel *base_kernel*.
330331
"""
331332
target_kernel = int_g.target_kernel.replace_base_kernel(base_kernel)
332333
dim = target_kernel.dim
333334

334-
result = 0
335+
result: ArithmeticExpressionT = 0
335336

336337
density, = int_g.densities
337338
source_kernel, = int_g.source_kernels
@@ -360,7 +361,7 @@ def _rewrite_int_g_using_base_kernel(
360361
knl = source_kernel
361362
while isinstance(knl, KernelWrapper):
362363
if not isinstance(knl,
363-
(AxisSourceDerivative, DirectionalSourceDerivative)):
364+
AxisSourceDerivative | DirectionalSourceDerivative):
364365
return int_g
365366
knl = knl.inner_kernel
366367
const = 0
@@ -393,18 +394,19 @@ class DerivRelation:
393394
.. autoattribute:: linear_combination
394395
"""
395396

396-
const: ExpressionT
397+
const: ArithmeticExpressionT
397398
"""A constant to add to the combination."""
398-
linear_combination: Sequence[Tuple[Tuple[int, ...], ExpressionT]]
399+
linear_combination: Sequence[tuple[tuple[int, ...], ArithmeticExpressionT]]
399400
"""A list of pairs ``(mi, coeffs)``."""
400401

401402

402-
def get_deriv_relation(kernels: Sequence[ExpressionKernel],
403+
def get_deriv_relation(
404+
kernels: Sequence[ExpressionKernel],
403405
base_kernel: ExpressionKernel,
404406
kernel_arguments: Mapping[str, Any],
405407
tol: float = 1e-10,
406-
order: Optional[int] = None) \
407-
-> List[DerivRelation]:
408+
order: int | None = None,
409+
) -> list[DerivRelation]:
408410
r"""
409411
Given a sequence of *kernels*, a *base_kernel* and an *order*, this
410412
gives a relation between the *base_kernel* and each of the *kernels*.
@@ -436,12 +438,13 @@ def get_deriv_relation(kernels: Sequence[ExpressionKernel],
436438

437439

438440
@memoize_on_first_arg
439-
def get_deriv_relation_kernel(kernel: ExpressionKernel,
441+
def get_deriv_relation_kernel(
442+
kernel: ExpressionKernel,
440443
base_kernel: ExpressionKernel,
441-
hashable_kernel_arguments: Tuple[Tuple[str, Any], ...],
444+
hashable_kernel_arguments: tuple[tuple[str, Any], ...],
442445
tol: float = 1e-10,
443-
order: Optional[int] = None) \
444-
-> DerivRelation:
446+
order: int | None = None,
447+
) -> DerivRelation:
445448
"""Takes a *kernel* and a base_kernel* as input and re-writes the
446449
*kernel* as a linear combination of derivatives of *base_kernel* up-to
447450
order *order* and a constant.
@@ -463,7 +466,7 @@ def get_deriv_relation_kernel(kernel: ExpressionKernel,
463466
expr = _get_sympy_kernel_expression(kernel.expression, kernel_arguments)
464467
vec = []
465468
for i in range(len(mis)):
466-
vec.append(evalf(expr.xreplace(dict(zip(sym_vec, rand[:, i])))))
469+
vec.append(evalf(expr.xreplace(dict(zip(sym_vec, rand[:, i], strict=True)))))
467470
vec = sym.Matrix(vec)
468471
result = []
469472
const = 0
@@ -492,15 +495,16 @@ def get_deriv_relation_kernel(kernel: ExpressionKernel,
492495
class LUFactorization:
493496
L: sym.Matrix
494497
U: sym.Matrix
495-
perm: Sequence[Tuple[int, int]]
498+
perm: Sequence[tuple[int, int]]
496499

497500

498501
@memoize_on_first_arg
499502
def _get_base_kernel_matrix_lu_factorization(
500503
base_kernel: ExpressionKernel,
501-
hashable_kernel_arguments: Tuple[Tuple[str, Any], ...],
502-
order: Optional[int] = None, retries: int = 3) \
503-
-> Tuple[LUFactorization, np.ndarray, List[Tuple[int, ...]]]:
504+
hashable_kernel_arguments: tuple[tuple[str, Any], ...],
505+
order: int | None = None,
506+
retries: int = 3,
507+
) -> tuple[LUFactorization, np.ndarray, list[tuple[int, ...]]]:
504508
"""
505509
Takes a *base_kernel* and samples the kernel and its derivatives upto
506510
order *order*.
@@ -550,7 +554,7 @@ def _get_base_kernel_matrix_lu_factorization(
550554
if nderivs == 0:
551555
continue
552556
expr = expr.diff(sym_vec[var_idx], nderivs)
553-
replace_dict = dict(zip(sym_vec, rand[:, rand_vec_idx]))
557+
replace_dict = dict(zip(sym_vec, rand[:, rand_vec_idx], strict=True))
554558
eval_expr = evalf(expr.xreplace(replace_dict))
555559
row.append(eval_expr)
556560
row.append(1)

pytential/symbolic/primitives.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,9 @@ class NodeCoordinateComponent(DiscretizationProperty):
520520
"""The axis index this node coordinate represents, i.e. 0 for $x$, etc."""
521521

522522
# FIXME: this is added for backwards compatibility with pre-dataclass expressions
523-
def __init__(self, ambient_axis: int, dofdesc: DOFDescriptorLike) -> None:
523+
def __init__(self,
524+
ambient_axis: int,
525+
dofdesc: DOFDescriptorLike | None = None) -> None:
524526
object.__setattr__(self, "ambient_axis", ambient_axis)
525527
super().__init__(dofdesc) # type: ignore[arg-type]
526528

0 commit comments

Comments
 (0)