Skip to content

Commit e1e1972

Browse files
tylerflexmomchil-flex
authored andcommitted
fix adjoint plugin with complex-valued permittivity inputs
1 parent 790cab9 commit e1e1972

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
### Added
99
- Support for differentiating with respect to `JaxMedium.conductivity`.
10-
1110
- Validating that every surface (unless excluded in ``exclude_surfaces``) of a 3D ``SurfaceIntegrationMonitor`` (flux monitor or field projection monitor) is not completely outside the simulation domain.
1211

1312
### Changed
@@ -26,6 +25,7 @@ that the fields match exactly except for a ``pi`` phase shift. This interpretati
2625
- Cleaner display of `ArrayLike` in docs.
2726
- `ArrayLike` validation properly fails with `None` or `nan` contents.
2827
- Apply finite grid correction to the fields when calculating the Poynting vector from 2D monitors.
28+
- `JaxCustomMedium` properly handles complex-valued permittivity.
2929

3030
## [2.3.0] - 2023-6-30
3131

tests/test_plugins/test_adjoint.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,11 @@ def make_sim(
227227

228228
jax_box_custom = JaxBox(size=size, center=(1, 0, 2))
229229
values = base_eps_val + np.random.random((Nx, Ny, Nz, 1))
230+
231+
# adding this line breaks things without enforcing that the vjp for custom medium is complex
232+
values = (1 + 1j) * values
233+
values = values + (1 + 1j) * values / 0.5
234+
230235
eps_ii = JaxDataArray(values=values, coords=coords)
231236
field_components = {f"eps_{dim}{dim}": eps_ii for dim in "xyz"}
232237
jax_eps_dataset = JaxPermittivityDataset(**field_components)

tidy3d/plugins/adjoint/components/medium.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,14 @@ def store_vjp(
523523

524524
# reshape values to the expected vjp shape to be more safe
525525
vjp_shape = tuple(len(coord) for _, coord in coords.items())
526-
vjp_values = e_dotted.real.values.reshape(vjp_shape)
526+
527+
# make sure this has the same dtype as the original
528+
dtype_orig = np.array(orig_data_array.values).dtype
529+
530+
vjp_values = e_dotted.values.reshape(vjp_shape)
531+
if dtype_orig.kind == "f":
532+
vjp_values = vjp_values.real
533+
vjp_values = vjp_values.astype(dtype_orig)
527534

528535
# construct a DataArray storing the vjp
529536
vjp_data_array = JaxDataArray(values=vjp_values, coords=coords)

0 commit comments

Comments
 (0)