Skip to content

Commit 2fc1e9a

Browse files
committed
Splitting some mesher functions, updating some docstrings
1 parent 1586523 commit 2fc1e9a

File tree

1 file changed

+131
-61
lines changed

1 file changed

+131
-61
lines changed

tidy3d/components/grid/mesher.py

Lines changed: 131 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from scipy.optimize import root_scalar
99

1010
from ..base import Tidy3dBaseModel
11-
from ..types import Axis
11+
from ..types import Axis, Array
1212
from ..structure import Structure
1313
from ...log import SetupError, ValidationError
1414
from ...constants import C_0, fp_eps
@@ -24,32 +24,32 @@ def parse_structures(
2424
structures: List[Structure],
2525
wavelength: pd.PositiveFloat,
2626
min_steps_per_wvl: pd.NonNegativeInt,
27-
) -> Tuple[np.ndarray, np.ndarray]:
27+
) -> Tuple[Array[float], Array[float]]:
2828
"""Calculate the positions of all bounding box interfaces along a given axis."""
2929

3030
@abstractmethod
3131
def make_grid_multiple_intervals(
3232
self,
33-
max_dl_list: np.ndarray,
34-
len_interval_list: np.ndarray,
33+
max_dl_list: Array[float],
34+
len_interval_list: Array[float],
3535
max_scale: float,
3636
is_periodic: bool,
37-
) -> List[np.ndarray]:
37+
) -> List[Array[float]]:
3838
"""Create grid steps in multiple connecting intervals."""
3939

4040

4141
class GradedMesher(Mesher):
4242
"""Implements automatic nonuniform meshing with a set minimum steps per wavelength and
4343
a graded mesh expanding from higher- to lower-resolution regions."""
4444

45-
# pylint:disable=too-many-statements,too-many-locals,too-many-branches
45+
# pylint:disable=too-many-statements,too-many-locals
4646
def parse_structures(
4747
self,
4848
axis: Axis,
4949
structures: List[Structure],
5050
wavelength: pd.PositiveFloat,
5151
min_steps_per_wvl: pd.NonNegativeInt,
52-
) -> Tuple[np.ndarray, np.ndarray]:
52+
) -> Tuple[Array[float], Array[float]]:
5353
"""Calculate the positions of all bounding box interfaces along a given axis.
5454
In this implementation, in most cases the complexity should be O(len(structures)**2),
5555
although the worst-case complexity may approach O(len(structures)**3).
@@ -59,7 +59,7 @@ def parse_structures(
5959
----------
6060
axis : Axis
6161
Axis index along which to operate.
62-
structures : List[Structures]
62+
structures : List[Structure]
6363
List of structures, with the simulation structure being the first item.
6464
wavelength : pd.PositiveFloat
6565
Wavelength to use for the step size and for dispersive media epsilon.
@@ -84,27 +84,15 @@ def parse_structures(
8484
sim_bmin, sim_bmax = structures[0].geometry.bounds
8585
domain_bounds = np.array([sim_bmin[axis], sim_bmax[axis]])
8686

87-
# Required minimum steps in every material
88-
medium_steps = []
89-
for structure in structures:
90-
n, k = structure.medium.eps_complex_to_nk(structure.medium.eps_model(C_0 / wavelength))
91-
index = max(abs(n), abs(k))
92-
medium_steps.append(wavelength / index / min_steps_per_wvl)
93-
medium_steps = np.array(medium_steps)
87+
# Required maximum steps in every structure
88+
structure_steps = self.structure_steps(structures, wavelength, min_steps_per_wvl)
9489

9590
# If empty simulation, return
9691
if len(structures) == 1:
97-
return (domain_bounds, medium_steps)
92+
return (domain_bounds, structure_steps)
9893

9994
# Bounding boxes with the meshing axis rotated to z
100-
struct_bbox = []
101-
for structure in structures:
102-
# Get 3D bounding box and rotate axes
103-
bmin, bmax = structure.geometry.bounds
104-
bmin_ax, bmin_plane = structure.geometry.pop_axis(bmin, axis=axis)
105-
bmax_ax, bmax_plane = structure.geometry.pop_axis(bmax, axis=axis)
106-
bounds = np.array([list(bmin_plane) + [bmin_ax], list(bmax_plane) + [bmax_ax]])
107-
struct_bbox.append(bounds)
95+
struct_bbox = self.rotate_structure_bounds(structures, axis)
10896

10997
# Array of coordinates of all intervals; add the simulation domain bounds already
11098
interval_coords = np.array(domain_bounds)
@@ -114,20 +102,10 @@ def parse_structures(
114102
struct_contains = [] # will have len equal to len(structures)
115103

116104
for struct_ind in range(len(structures) - 1, 0, -1):
117-
structure = structures[struct_ind]
118105
bbox = struct_bbox[struct_ind]
119106

120107
# indexes of structures that the current structure contains in 2D
121-
struct_contains_inds = []
122-
for ind, bounds in enumerate(struct_bbox[: struct_ind - 1]):
123-
if (
124-
bbox[0, 0] <= bounds[0, 0]
125-
and bbox[0, 1] <= bounds[0, 1]
126-
and bbox[1, 0] >= bounds[1, 0]
127-
and bbox[1, 1] >= bounds[1, 1]
128-
):
129-
struct_contains_inds.append(ind)
130-
struct_contains.append(struct_contains_inds)
108+
struct_contains.append(self.contains_2d(bbox, struct_bbox[:struct_ind]))
131109

132110
# Figure out where to place the bounding box coordinates of current structure
133111
indsmin = np.argwhere(bbox[0, 2] < interval_coords)
@@ -169,7 +147,7 @@ def parse_structures(
169147
interval_structs = [interval_structs[int(i)] for i in in_domain if i < b_array.size - 1]
170148

171149
# Remove intervals that are smaller than the absolute smallest min_step
172-
min_step = np.amin(medium_steps)
150+
min_step = np.amin(structure_steps)
173151
coords_filter = [interval_coords[0]]
174152
structs_filter = []
175153
for coord_ind, coord in enumerate(interval_coords[1:]):
@@ -197,18 +175,80 @@ def parse_structures(
197175
struct_list = [ind for ind in struct_list if ind not in contains]
198176

199177
# Define the max step as the minimum over all medium steps of media in this interval
200-
max_step = np.amin(medium_steps[struct_list_filter])
178+
max_step = np.amin(structure_steps[struct_list_filter])
201179
max_steps.append(float(max_step))
202180

203181
return interval_coords, np.array(max_steps)
204182

205183
@staticmethod
206-
def is_contained(bbox0, bbox_list):
207-
"""Return True if bbox0 is contained in any of the bbox_list, or False otherwise.
208-
It can be much faster to write out the conditions one by one than to use e.g. np.all.
184+
def structure_steps(
185+
structures: List[Structure], wavelength: float, min_steps_per_wvl: float
186+
) -> Array[float]:
187+
"""Get the minimum mesh required in each structure.
188+
189+
Parameters
190+
----------
191+
structures : List[Structure]
192+
List of structures, with the simulation structure being the first item.
193+
wavelength : float
194+
Wavelength to use for the step size and for dispersive media epsilon.
195+
min_steps_per_wvl : float
196+
Minimum requested steps per wavelength.
197+
"""
198+
min_steps = []
199+
for structure in structures:
200+
n, k = structure.medium.eps_complex_to_nk(structure.medium.eps_model(C_0 / wavelength))
201+
index = max(abs(n), abs(k))
202+
min_steps.append(wavelength / index / min_steps_per_wvl)
203+
return np.array(min_steps)
204+
205+
@staticmethod
206+
def rotate_structure_bounds(structures: List[Structure], axis: Axis) -> List[Array[float]]:
207+
"""Get sturcture bounding boxes with a given ``axis`` rotated to z.
208+
209+
Parameters
210+
----------
211+
structures : List[Structure]
212+
List of structures, with the simulation structure being the first item.
213+
axis : Axis
214+
Axis index to place last.
215+
216+
Returns
217+
-------
218+
List[Array[float]]
219+
A list of the bounding boxes of shape ``(2, 3)`` for each structure, with the bounds
220+
along ``axis`` being ``(:, 2)``.
221+
"""
222+
struct_bbox = []
223+
for structure in structures:
224+
# Get 3D bounding box and rotate axes
225+
bmin, bmax = structure.geometry.bounds
226+
bmin_ax, bmin_plane = structure.geometry.pop_axis(bmin, axis=axis)
227+
bmax_ax, bmax_plane = structure.geometry.pop_axis(bmax, axis=axis)
228+
bounds = np.array([list(bmin_plane) + [bmin_ax], list(bmax_plane) + [bmax_ax]])
229+
struct_bbox.append(bounds)
230+
return struct_bbox
231+
232+
@staticmethod
233+
def is_contained(bbox0: Array[float], bbox_list: List[Array[float]]) -> bool:
234+
"""Check if a bounding box is contained in any of a list of bounding boxes.
235+
236+
Parameters
237+
----------
238+
bbox0 : Array[float]
239+
Bounding box to check.
240+
bbox_list : List[Array[float]]
241+
List of bounding boxes to check if they contain ``bbox0``.
242+
243+
Returns
244+
-------
245+
contained : bool
246+
``True`` if ``bbox0`` is contained in any of the boxes in the list.
209247
"""
210248
contained = False
211249
for bounds in bbox_list:
250+
# It can be much faster to write out the conditions one by one than e.g. to use np.all
251+
# on the bottom values and np.all on the top values
212252
if all(
213253
[
214254
bbox0[0, 0] >= bounds[0, 0],
@@ -220,25 +260,55 @@ def is_contained(bbox0, bbox_list):
220260
]
221261
):
222262
contained = True
223-
224263
return contained
225264

265+
@staticmethod
266+
def contains_2d(bbox0: Array[float], bbox_list: List[Array[float]]) -> List[int]:
267+
"""Check if a bounding box contains along the first two dimensions any of a list of
268+
bounding boxes.
269+
270+
Parameters
271+
----------
272+
bbox0 : Array[float]
273+
Bounding box to check.
274+
bbox_list : List[Array[float]]
275+
List of bounding boxes to check if they are contained in ``bbox0``.
276+
277+
Returns
278+
-------
279+
List[int]
280+
A list with all the indexes into the ``bbox_list`` that are contained in ``bbox0``
281+
along the first two dimensions.
282+
"""
283+
struct_contains_inds = []
284+
for ind, bounds in enumerate(bbox_list):
285+
if all(
286+
[
287+
bbox0[0, 0] <= bounds[0, 0],
288+
bbox0[1, 0] >= bounds[1, 0],
289+
bbox0[0, 1] <= bounds[0, 1],
290+
bbox0[1, 1] >= bounds[1, 1],
291+
]
292+
):
293+
struct_contains_inds.append(ind)
294+
return struct_contains_inds
295+
226296
def make_grid_multiple_intervals( # pylint:disable=too-many-locals
227297
self,
228-
max_dl_list: np.ndarray,
229-
len_interval_list: np.ndarray,
298+
max_dl_list: Array[float],
299+
len_interval_list: Array[float],
230300
max_scale: float,
231301
is_periodic: bool,
232-
) -> List[np.ndarray]:
302+
) -> List[Array[float]]:
233303
"""Create grid steps in multiple connecting intervals of length specified by
234304
``len_interval_list``. The maximal allowed step size in each interval is given by
235305
``max_dl_list``. The maximum ratio between neighboring steps is bounded by ``max_scale``.
236306
237307
Parameters
238308
----------
239-
max_dl_list : np.ndarray
309+
max_dl_list : Array[float]
240310
Maximal allowed step size of each interval.
241-
len_interval_list : np.ndarray
311+
len_interval_list : Array[float]
242312
A list of interval lengths
243313
max_scale : float
244314
Maximal ratio between consecutive steps.
@@ -247,7 +317,7 @@ def make_grid_multiple_intervals( # pylint:disable=too-many-locals
247317
248318
Returns
249319
-------
250-
List[np.ndarray]
320+
List[Array[float]]
251321
A list of of step sizes in each interval.
252322
"""
253323

@@ -322,19 +392,19 @@ def make_grid_multiple_intervals( # pylint:disable=too-many-locals
322392

323393
def grid_multiple_interval_analy_refinement(
324394
self,
325-
max_dl_list: np.ndarray,
326-
len_interval_list: np.ndarray,
395+
max_dl_list: Array[float],
396+
len_interval_list: Array[float],
327397
max_scale: float,
328398
is_periodic: bool,
329-
) -> Tuple[np.ndarray, np.ndarray]:
399+
) -> Tuple[Array[float], Array[float]]:
330400
"""Analytical refinement for multiple intervals. "analytical" meaning we allow
331401
non-integar step sizes, so that we don't consider snapping here.
332402
333403
Parameters
334404
----------
335-
max_dl_list : np.ndarray
405+
max_dl_list : Array[float]
336406
Maximal allowed step size of each interval.
337-
len_interval_list : np.ndarray
407+
len_interval_list : Array[float]
338408
A list of interval lengths
339409
max_scale : float
340410
Maximal ratio between consecutive steps.
@@ -343,7 +413,7 @@ def grid_multiple_interval_analy_refinement(
343413
344414
Returns
345415
-------
346-
Tuple[np.ndarray, np.ndarray]
416+
Tuple[Array[float], Array[float]]
347417
left and right step sizes of each interval.
348418
"""
349419

@@ -411,7 +481,7 @@ def make_grid_in_interval(
411481
max_dl: float,
412482
max_scale: float,
413483
len_interval: float,
414-
) -> np.ndarray:
484+
) -> Array[float]:
415485
"""Create a set of grid steps in an interval of length ``len_interval``,
416486
with first step no larger than ``max_scale * left_neighbor_dl`` and last step no larger than
417487
``max_scale * right_neighbor_dl``, with maximum ratio ``max_scale`` between
@@ -432,7 +502,7 @@ def make_grid_in_interval(
432502
433503
Returns
434504
-------
435-
np.ndarray
505+
Array[float]
436506
A list of step sizes in the interval.
437507
"""
438508

@@ -517,7 +587,7 @@ def grid_grow_plateau_decrease_in_interval(
517587
max_dl: float,
518588
max_scale: float,
519589
len_interval: float,
520-
) -> np.ndarray:
590+
) -> Array[float]:
521591
"""In an interval, grid grows, plateau, and decrease, resembling Lambda letter but
522592
with plateau in the connection part..
523593
@@ -586,7 +656,7 @@ def grid_grow_decrease_in_interval(
586656
right_dl: float,
587657
max_scale: float,
588658
len_interval: float,
589-
) -> np.ndarray:
659+
) -> Array[float]:
590660
"""In an interval, grid grows, and decrease, resembling Lambda letter.
591661
592662
Parameters
@@ -674,7 +744,7 @@ def grid_grow_plateau_in_interval(
674744
large_dl: float,
675745
max_scale: float,
676746
len_interval: float,
677-
) -> np.ndarray:
747+
) -> Array[float]:
678748
"""In an interval, grid grows, then plateau.
679749
680750
Parameters
@@ -690,7 +760,7 @@ def grid_grow_plateau_in_interval(
690760
691761
Returns
692762
-------
693-
np.ndarray
763+
Array[float]
694764
A list of step sizes in the interval, in ascending order.
695765
"""
696766
# steps for scaling
@@ -728,7 +798,7 @@ def grid_grow_in_interval(
728798
small_dl: float,
729799
max_scale: float,
730800
len_interval: float,
731-
) -> np.ndarray:
801+
) -> Array[float]:
732802
"""Mesh simply grows in an interval.
733803
734804
Parameters
@@ -742,7 +812,7 @@ def grid_grow_in_interval(
742812
743813
Returns
744814
-------
745-
np.ndarray
815+
Array[float]
746816
A list of step sizes in the interval, in ascending order.
747817
"""
748818

0 commit comments

Comments
 (0)