Skip to content

Commit 2dccf70

Browse files
committed
fix porting
1 parent e30a86e commit 2dccf70

File tree

4 files changed

+19
-14
lines changed

4 files changed

+19
-14
lines changed

pytential/qbx/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -930,19 +930,21 @@ def _flat_centers(dofdesc, qbx_forced_limit):
930930
p2p = self.get_p2p(actx, insn.target_kernels, insn.source_kernels)
931931
lpot_applier_on_tgt_subset = self.get_lpot_applier_on_tgt_subset(
932932
insn.target_kernels, insn.source_kernels)
933+
else:
934+
p2p = lpot_applier_on_tgt_subset = None
933935

934936
for (target_name, qbx_forced_limit), outputs in other_outputs.items():
935937
target_discr = bound_expr.places.get_discretization(
936938
target_name.geometry, target_name.discr_stage)
937939
flat_target_nodes = _flat_nodes(target_name)
938940

939941
# FIXME: (Somewhat wastefully) compute P2P for all targets
940-
_, output_for_each_kernel = p2p(
941-
actx,
942-
targets=flat_target_nodes,
943-
sources=flat_source_nodes,
944-
strength=flat_strengths,
945-
**flat_kernel_args)
942+
assert p2p is not None
943+
output_for_each_kernel = p2p(actx,
944+
targets=flat_target_nodes,
945+
sources=flat_source_nodes,
946+
strength=flat_strengths,
947+
**flat_kernel_args)
946948

947949
target_discrs_and_qbx_sides = ((target_discr, qbx_forced_limit),)
948950
geo_data = self.qbx_fmm_geometry_data(
@@ -980,6 +982,7 @@ def _flat_centers(dofdesc, qbx_forced_limit):
980982
tgt_subset_kwargs[f"result_{i}"] = res_i
981983

982984
if qbx_tgt_count:
985+
assert lpot_applier_on_tgt_subset is not None
983986
lpot_applier_on_tgt_subset(
984987
actx,
985988
targets=flat_target_nodes,

pytential/qbx/fmmlib.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"""
2525

2626
import logging
27-
from typing import TYPE_CHECKING
27+
from typing import TYPE_CHECKING, Any
2828

2929
import numpy as np
3030

@@ -66,7 +66,7 @@ def __init__(self, actx: PyOpenCLArrayContext, *,
6666
# {{{ digest target_kernels
6767

6868
ifgrad = False
69-
outputs = []
69+
outputs: list[tuple[Any, ...]] = []
7070
source_deriv_names = []
7171
k_names = []
7272

pytential/qbx/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ class TreeWithQBXMetadata(Tree):
258258
box_to_qbx_target_lists: Array
259259

260260
qbx_element_to_source_starts: Array
261-
qbx_element_to_center_starts: Array
261+
qbx_element_to_center_starts: Array | None
262262

263263
qbx_user_source_slice: slice
264264
qbx_user_center_slice: slice
@@ -399,16 +399,16 @@ def _make_centers(discr):
399399
del box_to_class
400400

401401
# Compute element => source relation
402-
qbx_element_to_source_starts = cast("cl_array.Array",
403-
actx.np.zeros(nelements + 1, tree.particle_id_dtype))
402+
qbx_element_to_source_starts = actx.np.zeros(nelements + 1, tree.particle_id_dtype)
403+
404404
el_offset = 0
405405
node_nr_base = 0
406406
for group in density_discr.groups:
407407
group_element_starts = np.arange(
408408
node_nr_base, node_nr_base + group.ndofs, group.nunit_dofs,
409409
dtype=tree.particle_id_dtype)
410-
qbx_element_to_source_starts[el_offset:el_offset + group.nelements] = \
411-
actx.from_numpy(group_element_starts)
410+
qbx_element_to_source_starts[el_offset:el_offset + group.nelements] = (
411+
actx.from_numpy(group_element_starts))
412412

413413
node_nr_base += group.ndofs
414414
el_offset += group.nelements

pytential/unregularized.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from meshmode.dof_array import DOFArray
3838
from pytools import memoize_method
3939

40+
from pytential.array_context import dataclass_array_container
4041
from pytential.source import LayerPotentialSourceBase
4142

4243

@@ -337,7 +338,7 @@ def copy_targets_kernel(self):
337338
knl = lp.tag_array_axes(knl, "targets", "stride:auto, stride:1")
338339
knl = lp.tag_inames(knl, {"dim": "ilp"})
339340

340-
return knl.executor(self.cl_context)
341+
return knl.executor(self.array_context.context)
341342

342343
@property
343344
@memoize_method
@@ -352,6 +353,7 @@ def build_traversal(self):
352353
return FMMTraversalBuilder(self.array_context)
353354

354355

356+
@dataclass_array_container
355357
@dataclass(frozen=True)
356358
class _TargetInfo:
357359
"""

0 commit comments

Comments
 (0)