-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathobjective.py
428 lines (333 loc) · 14 KB
/
objective.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
#!/usr/bin/env python
# coding: utf-8
import os
import tempfile
import gc
import logging
from shutil import copytree, move
import numpy as np
from numpy.linalg import norm as vecnorm
from simnibs import cond
from simnibs.msh import mesh_io
from simnibs.simulation.fem import tms_coil
from fieldopt import geolib
# Configure the root logger at import time: INFO level and a format that
# shows the level, the logger name and the emitting function for each message.
logging.basicConfig(
format='[%(levelname)s - %(name)s.%(funcName)5s() ] %(message)s',
level=logging.INFO)
# Module-level logger used by everything defined in this file.
logger = logging.getLogger(__name__)
class FieldFunc():
    '''
    Objective-function wrapper around SimNIBS TMS simulations.

    The details of running a simulation and extracting a score are
    hidden behind `evaluate` / `run_simulation` so that any general
    optimization package can treat a coil placement `(x, y, theta)`
    as a plain numeric input.

    State kept on the instance:
        1. The mesh (path and an in-memory cached copy) to simulate on
        2. The parameteric sampling surface (quadric coefficients,
           inverse affine and bounds) built around the initial centroid
        3. Simulation settings (coil, dI/dt, cpus, solver options)
    '''

    # (dimension, tag) pairs; only the tag (index 1) is compared against
    # `elm.tag1` when selecting elements.
    FIELD_ENTITY = (3, 2)
    PIAL_ENTITY = (2, 1002)
    # Field names requested from the SimNIBS solver.
    FIELDS = ['E', 'e', 'J', 'j']

    def __init__(self,
                 mesh_file,
                 initial_centroid,
                 tet_weights,
                 field_dir,
                 coil,
                 span=35,
                 local_span=8,
                 distance=1,
                 didt=1e6,
                 cpus=1,
                 solver_options=None):
        '''
        Standard constructor

        Arguments:
            mesh_file           Path to FEM model
            initial_centroid    Initial point to grow sampling region
            tet_weights         Weighting scores for each tetrahedron
                                (1D array ordered by node ID)
            field_dir           Directory to perform simulation
                                experiments in
            coil                TMS coil file (either dA/dt volume or
                                coil geometry)
            span                Radius of points to include in
                                sampling surface
            local_span          Radius of points to include in
                                construction of local geometry
                                for normal and curvature estimation
            distance            Distance from coil to head surface
            didt                Intensity of stimulation
            cpus                Number of cpus to use for simulation
            solver_options      Passed through unchanged to the SimNIBS
                                `tms_coil` call as `solver_options`
        '''
        self.mesh = mesh_file
        self.tw = tet_weights
        self.field_dir = field_dir
        self.coil = coil
        self.didt = didt
        logger.info("Configured to use %s cpus...", cpus)
        self.cpus = cpus
        self.geo_radius = local_span
        self.distance = distance
        self.solver_opt = solver_options

        logger.info('Loading in coordinate data from mesh file...')
        self.nodes, self.coords, _ = geolib.load_gmsh_nodes(self.mesh, (2, 5))
        _, _, trigs = geolib.load_gmsh_elems(self.mesh, (2, 5))
        self.trigs = np.array(trigs).reshape(-1, 3)
        logger.info('Successfully pulled in node and element data!')

        # Construct basis of sampling space using centroid
        logger.info('Constructing initial sampling surface...')
        C, iR, bounds = self._initialize(initial_centroid, span)
        self.C = C
        self.iR = iR
        self.bounds = bounds
        logger.info('Successfully constructed initial sampling surface')

        # Store single read in memory, this will prevent GC issues
        logger.info('Caching mesh file on instance construction...')
        self.cached_mesh = mesh_io.read_msh(mesh_file)
        self.cached_mesh.fix_surface_labels()
        logger.info('Successfully cached mesh file')

        logger.info('Storing standard conductivity values...')
        condlist = [c.value for c in cond.standard_cond()]
        self.cond = cond.cond2elmdata(self.cached_mesh, condlist)
        logger.info('Successfully stored conductivity values...')

        # Control for coil file-type and SimNIBS changing convention:
        # the coil normal is used as-is for '.ccd' files and flipped
        # for every other coil file type.
        if self.coil.endswith('.ccd'):
            self.normflip = 1
        else:
            self.normflip = -1

    def __repr__(self):
        '''
        Return a short multi-line description of the configuration.

        The text is returned rather than printed, so `repr(obj)` has no
        side effects and `print(obj)` shows the same three lines.
        '''
        return '\n'.join([
            f'Mesh: {self.mesh}',
            f'Coil: {self.coil}',
            f'Field Directory: {self.field_dir}',
        ])

    def get_bounds(self):
        '''
        Return the bounds of the sampling domain computed at construction
        '''
        return self.bounds

    def _initialize(self, centroid, span):
        '''
        Construct quadratic basis and rotation at centroid point
        to use for sampling

        Arguments:
            centroid    Point around which the sampling surface is grown
            span        Radius of mesh points used to compute the bounds;
                        the quadric itself is fit on 0.75 * span

        Returns:
            C           Quadric coefficients of the sampling surface
            iR          Inverse affine (flattened space -> mesh space)
            bounds      Array with per-axis minima in column 0 and
                        maxima in column 1, in the flattened space
        '''
        v = geolib.closest_point2surf(centroid, self.coords)
        C, R, iR = self._construct_local_quadric(v, 0.75 * span)

        # Calculate neighbours, rotate to flatten on XY plane
        neighbours_ind = np.where(vecnorm(self.coords - v, axis=1) < span)
        neighbours = self.coords[neighbours_ind]
        r_neighbours = geolib.affine(R, neighbours)

        minarr = np.min(r_neighbours, axis=0)
        maxarr = np.max(r_neighbours, axis=0)
        bounds = np.c_[minarr.T, maxarr.T]
        return C, iR, bounds

    def _construct_local_quadric(self, p, tol=1e-3):
        '''
        Given a single point construct a local quadric
        surface on a given mesh

        Arguments:
            p       Point on the mesh to centre the fit on
            tol     Radius around p of mesh nodes included in the fit

        Returns:
            C           Quadric coefficients from `geolib.quad_fit`
            affine      4x4 matrix mapping mesh space into the local
                        frame (translate by -p, then rotate the normal
                        onto +z)
            i_affine    4x4 inverse of `affine`
        '''
        # Get local neighbourhood
        neighbours_ind = np.where(vecnorm(self.coords - p, axis=1) < tol)
        neighbours = self.coords[neighbours_ind]

        # Calculate normals
        normals = geolib.get_normals(self.nodes[neighbours_ind], self.nodes,
                                     self.coords, self.trigs)

        # Use average of normals for alignment
        n = normals / vecnorm(normals)

        # Make transformation matrix
        z = np.array([0, 0, 1])
        R = np.eye(4)
        R[:3, :3] = geolib.rotate_vec2vec(n, z)
        T = np.eye(4)
        T[:3, 3] = -p
        affine = R @ T

        # Build the inverses in fresh arrays instead of aliasing and
        # mutating R and T in place; the values are unchanged but the
        # result no longer depends on NumPy's handling of overlapping
        # reads and writes.
        # Create inverse rotation
        iR = np.eye(4)
        iR[:3, :3] = R[:3, :3].T

        # Create inverse translation
        iT = np.eye(4)
        iT[:3, 3] = p
        i_affine = iT @ iR

        # Perform quadratic fitting
        r_neighbours = geolib.affine(affine, neighbours)
        C = geolib.quad_fit(r_neighbours[:, :2], r_neighbours[:, 2])
        return C, affine, i_affine

    def _construct_sample(self, x, y):
        '''
        Given a sampling point, estimate local geometry to
        get accurate normals/curvatures

        Arguments:
            x, y    Coordinates on the parameteric sampling surface

        Returns:
            sample  Point pushed out from the mesh by `self.distance`
                    along the local normal
            iR      Inverse affine of the local quadric fit
            C       Coefficients of the local quadric fit
            n       Normal in the local (pre-affine) frame
        '''
        pp = geolib.map_param_2_surf(x, y, self.C)[np.newaxis, :]
        p = geolib.affine(self.iR, pp)
        v = geolib.closest_point2surf(p, self.coords)
        C, _, iR = self._construct_local_quadric(v, self.geo_radius)
        _, _, n = geolib.compute_principal_dir(0, 0, C)

        # Map normal to coordinate space
        n_r = iR[:3, :3] @ n
        n_r = n_r / vecnorm(n_r)

        # Push sample out by set distance
        sample = v + (n_r * self.distance)
        return sample, iR, C, n

    def _transform_input(self, x, y, theta):
        '''
        Generates a coil orientation matrix given inputs from a
        quadratic surface sampling domain

        Arguments:
            x, y    Coordinates on the parameteric sampling surface
            theta   Rotational interpolation angle

        Returns:
            o_matrix    Coil orientation matrix from
                        `geolib.define_coil_orientation`
        '''
        sample, R, C, _ = self._construct_sample(x, y)
        preaff_rot, preaff_norm = geolib.map_rot_2_surf(0, 0, theta, C)
        rot = R[:3, :3] @ preaff_rot
        n = R[:3, :3] @ preaff_norm
        o_matrix = geolib.define_coil_orientation(sample, rot,
                                                  self.normflip * n)
        return o_matrix

    def _get_simulation_outnames(self, num_sims, sim_dir):
        '''
        Build the output file names for `num_sims` simulations

        Arguments:
            num_sims    Number of simulations
            sim_dir     Directory the simulations write into

        Returns:
            output_names    Paths of the '...scalar.msh' result files
            geo_names       Paths of the '...coil_pos.geo' files
        '''
        simu_name = os.path.join(sim_dir, 'TMS_{}'.format(1))
        coil_name = os.path.splitext(os.path.basename(self.coil))[0]
        # Simulations are numbered from 1 and zero-padded to 4 digits
        fn_simu = [
            "{0}-{1:0=4d}_{2}_".format(simu_name, i + 1, coil_name)
            for i in range(num_sims)
        ]
        output_names = [f + 'scalar.msh' for f in fn_simu]
        geo_names = [f + 'coil_pos.geo' for f in fn_simu]
        return output_names, geo_names

    def _run_simulation(self, matsimnibs, sim_dir):
        '''
        Run one SimNIBS simulation per coil orientation matrix

        Arguments:
            matsimnibs  A single orientation matrix or a list of them
            sim_dir     Directory to write simulation outputs into

        Returns:
            Sorted list of simulation output file paths
        '''
        if not isinstance(matsimnibs, list):
            matsimnibs = [matsimnibs]

        # Construct standard inputs
        logger.info('Constructing inputs for simulation...')
        logger.info('Using didt=%s', self.didt)
        didt_list = [self.didt] * len(matsimnibs)
        output_names, geo_names = self._get_simulation_outnames(
            len(matsimnibs), sim_dir)

        logger.info('Starting SimNIBS simulations...')
        tms_coil(self.cached_mesh,
                 self.cond,
                 self.coil,
                 self.FIELDS,
                 matsimnibs,
                 didt_list,
                 output_names,
                 geo_names,
                 n_workers=self.cpus,
                 solver_options=self.solver_opt)
        logger.info('Successfully completed simulations!')
        return sorted(output_names)

    def run_simulation(self, coord, out_sim, out_geo):
        '''
        Given a quadratic surface input (x,y) and a rotational
        interpolation angle (theta) run a single simulation and
        move the resulting outputs to the requested destinations

        Arguments:
            coord       An iterable (x, y, theta)
            out_sim     Destination path for the simulation .msh file
            out_geo     Destination path for the coil position .geo file

        Returns:
            scores      Score of the simulated field
            matsimnibs  Coil orientation matrix used for the simulation
        '''
        logger.info('Transforming inputs...')
        matsimnibs = self._transform_input(*coord)

        # Run simulation in a temporary directory
        with tempfile.TemporaryDirectory(dir=self.field_dir) as sim_dir:
            logger.info('Running simulation')
            sim_file = self._run_simulation(matsimnibs, sim_dir)[0]

            logger.info('Calculating Score...')
            scores = self._calculate_score(sim_file)
            logger.info('Successfully pulled scores!')

            sim_names, geo_names = self._get_simulation_outnames(1, sim_dir)

            # Transfer files to destination before the temporary
            # directory is removed
            logger.info('Transferring files to destination')
            move(geo_names[0], out_geo)
            move(sim_names[0], out_sim)
            logger.info('Succesfully transferred files')

        return scores, matsimnibs

    def _get_tet_ids(self, entity):
        '''
        Pull list of element IDs for a given gmsh entity

        Arguments:
            entity  (dimension, tag) pair; only the tag is used

        Returns:
            The `np.where` index tuple of elements of the cached mesh
            whose `tag1` equals the entity tag
        '''
        tet_ids = np.where(self.cached_mesh.elm.tag1 == entity[1])
        return tet_ids

    def _calculate_score(self, sim_file):
        '''
        Given a simulation output file, compute the score

        Volumes are integrated into the score function to deal with
        score inflation due to a larger number of elements near
        more complex geometry.

        Arguments:
            sim_file    Path to a simulation output .msh file

        Returns:
            Sum over FIELD_ENTITY elements of
            weight * field magnitude * element volume
        '''
        logger.info('Loading gmsh elements from %s...', sim_file)
        tet_ids = self._get_tet_ids(self.FIELD_ENTITY)

        logger.info('Pulling field values from %s...', sim_file)
        normE = get_field_subset(sim_file, tet_ids)
        logger.info('Successfully pulled field values!')

        # Negative values are zeroed so they cannot reduce the score
        neg_ind = np.where(normE < 0)
        normE[neg_ind] = 0

        vols = self.cached_mesh.elements_volumes_and_areas().value[tet_ids]
        scores = self.tw * normE * vols
        return scores.sum()

    def evaluate(self, input_list, out_dir=None):
        '''
        Given a quadratic surface input (x,y) and rotational
        interpolation angle (theta) compute the resulting field score
        over a region of interest

        Arguments:
            [(x,y,theta),...]   A iterable of iterable (x,y,theta)
            out_dir             Optional directory; when given, all
                                simulation outputs are copied into it

        Returns:
            scores              An array of scores in order of inputs
        '''
        with tempfile.TemporaryDirectory(dir=self.field_dir) as sim_dir:
            logger.info('Transforming inputs...')
            matsimnibs = [
                self._transform_input(x, y, t) for x, y, t in input_list
            ]

            logger.info('Running simulations...')
            sim_files = self._run_simulation(matsimnibs, sim_dir)

            logger.info('Calculating Scores...')
            scores = np.array([self._calculate_score(s) for s in sim_files])
            logger.info('Successfully pulled scores!')

            if out_dir is not None:
                logger.info('Storing outputs!')
                # dirs_exist_ok requires Python >= 3.8
                copytree(sim_dir, out_dir, dirs_exist_ok=True)
                logger.info('Successfully stored outputs in %s', out_dir)

        return scores

    def get_coil2cortex_distance(self, input_coord):
        '''
        Given an input sampling coordinate on the parameteric
        mesh calculate the distance from the coil to target triangle
        on the cortical surface mesh.

        This function wraps:
            `simnibs.Msh.intercept_ray`

        The proposed ray is drawn from the coil centre and is drawn
        with large magnitude in the direction of the coil normal

        Arguments:
            input_coord : a 1 dimensional 3-iterable containing
                          x, y, and rotation

        Returns:
            d : The distance between the proposed coil location
                and the target ROI on the cortex
        '''
        # Compute the coil affine matrix
        coil_affine = self._transform_input(*input_coord)
        n = coil_affine[:3, 2]
        p0 = coil_affine[:3, 3]
        p1 = p0 + (n * 200)

        # Calculate interception with pial surface
        # NOTE(review): SimNIBS documents this method as `crop_mesh`;
        # confirm that `crop_msh` exists in the pinned SimNIBS version.
        pial_mesh = self.cached_mesh.crop_msh(self.PIAL_ENTITY[1],
                                              self.PIAL_ENTITY[0])
        _, pos = pial_mesh.intercept_ray(p0, p1)

        # Compute distance
        # NOTE(review): if `intercept_ray` can return several hits, this
        # norm is taken over all of them at once - confirm that a single
        # interception is guaranteed here.
        return np.linalg.norm(p0 - pos)
def get_field_subset(field_msh, tag_list):
    '''
    Read a TMS simulation result and return the field values of a
    subset of its elements.

    Arguments:
        field_msh   Path to .msh file result from TMS simulation
        tag_list    Index (e.g. the output of `np.where`) selecting
                    the elements to keep

    Output:
        Values of the second element-data field of the mesh
        (`elmdata[1]`, presumably the electric field magnitude -
        confirm against the field order requested from the solver),
        subsetted according to tag_list
    '''
    field_mesh = mesh_io.read_msh(field_msh)
    selected = field_mesh.elmdata[1].value[tag_list]

    # Release the mesh right away; only the selected values are needed
    del field_mesh
    gc.collect()
    return selected