Skip to content

Commit 790cab9

Browse files
tylerflexmomchil-flex
authored andcommitted
add conductivity support in adjoint plugin
1 parent e8059ad commit 790cab9

File tree

3 files changed

+34
-8
lines changed

3 files changed

+34
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66
## [Unreleased]
77

88
### Added
9+
- Support for differentiating with respect to `JaxMedium.conductivity`.
910

1011
- Validating that every surface (unless excluded in ``exclude_surfaces``) of a 3D ``SurfaceIntegrationMonitor`` (flux monitor or field projection monitor) is not completely outside the simulation domain.
1112

tests/test_plugins/test_adjoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def make_sim(
201201

202202
# JaxBox
203203
jax_box1 = JaxBox(size=size, center=(1, 0, 2))
204-
jax_med1 = JaxMedium(permittivity=permittivity)
204+
jax_med1 = JaxMedium(permittivity=permittivity, conductivity=permittivity * 0.1)
205205
jax_struct1 = JaxStructure(geometry=jax_box1, medium=jax_med1)
206206

207207
jax_box2 = JaxBox(size=size, center=(-1, 0, -3))

tidy3d/plugins/adjoint/components/medium.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import pydantic as pd
88
import numpy as np
9-
import jax.numpy as jnp
109
from jax.tree_util import register_pytree_node_class
1110
import xarray as xr
1211

@@ -17,6 +16,7 @@
1716
from ....components.data.dataset import PermittivityDataset
1817
from ....components.data.data_array import ScalarFieldDataArray
1918
from ....exceptions import SetupError
19+
from ....constants import CONDUCTIVITY
2020

2121
from .base import JaxObject
2222
from .types import JaxFloat, validate_jax_float
@@ -92,7 +92,7 @@ def e_mult_volume(
9292
e_fwd = grad_data_fwd.field_components[field]
9393
e_adj = grad_data_adj.field_components[field]
9494

95-
e_dotted = (e_fwd * e_adj).real
95+
e_dotted = e_fwd * e_adj
9696

9797
inside_mask = self.make_inside_mask(vol_coords=vol_coords, inside_fn=inside_fn)
9898

@@ -103,7 +103,7 @@ def e_mult_volume(
103103
}
104104
interp_kwargs = {key: value for key, value in vol_coords.items() if key not in isel_kwargs}
105105

106-
fields_eval = e_dotted.isel(f=0, **isel_kwargs).interp(**interp_kwargs, assume_sorted=True)
106+
fields_eval = e_dotted.isel(**isel_kwargs).interp(**interp_kwargs, assume_sorted=True)
107107
inside_mask = inside_mask.isel(**isel_kwargs)
108108

109109
return inside_mask * d_vol * fields_eval
@@ -148,12 +148,22 @@ class JaxMedium(Medium, AbstractJaxMedium):
148148
jax_field=True,
149149
)
150150

151+
conductivity: JaxFloat = pd.Field(
152+
0.0,
153+
title="Conductivity",
154+
description="Electric conductivity. Defined such that the imaginary part of the complex "
155+
"permittivity at angular frequency omega is given by conductivity/omega.",
156+
units=CONDUCTIVITY,
157+
jax_field=True,
158+
)
159+
151160
@pd.validator("conductivity", always=True)
152161
def _passivity_validation(cls, val, values):
153162
"""Override of inherited validator."""
154163
return val
155164

156165
_sanitize_permittivity = validate_jax_float("permittivity")
166+
_sanitize_conductivity = validate_jax_float("conductivity")
157167

158168
def to_medium(self) -> Medium:
159169
"""Convert :class:`.JaxMedium` instance to :class:`.Medium`"""
@@ -180,8 +190,17 @@ def store_vjp(
180190
inside_fn=inside_fn,
181191
)
182192

183-
vjp_permittivty = jnp.sum(d_eps_map.values)
184-
return self.copy(update=dict(permittivity=vjp_permittivty))
193+
vjp_eps_complex = np.sum(d_eps_map.values)
194+
195+
freq = d_eps_map.coords["f"][0]
196+
vjp_eps, vjp_sigma = self.eps_complex_to_eps_sigma(vjp_eps_complex, freq)
197+
198+
return self.copy(
199+
update=dict(
200+
permittivity=vjp_eps,
201+
conductivity=vjp_sigma,
202+
)
203+
)
185204

186205

187206
@register_pytree_node_class
@@ -255,8 +274,14 @@ def store_vjp(
255274
inside_fn=inside_fn,
256275
)
257276

258-
vjp_ii = jnp.sum(e_mult_dim.real.values)
259-
vjp_fields[component_name] = JaxMedium(permittivity=vjp_ii)
277+
vjp_eps_complex_ii = np.sum(e_mult_dim.values)
278+
freq = e_mult_dim.coords["f"][0]
279+
vjp_eps_ii, vjp_sigma_ii = self.eps_complex_to_eps_sigma(vjp_eps_complex_ii, freq)
280+
281+
vjp_fields[component_name] = JaxMedium(
282+
permittivity=vjp_eps_ii,
283+
conductivity=vjp_sigma_ii,
284+
)
260285

261286
return self.copy(update=vjp_fields)
262287

0 commit comments

Comments
 (0)