Skip to content

Commit 8ab8a78

Browse files
authored
Merge pull request #351 from keflavich/memmaped_coadd
Generalize reproject_and_coadd for N-dimensional data, and add option to specify blank pixel value and progress bar
2 parents 9ba352c + f5e3016 commit 8ab8a78

File tree

6 files changed

+198
-125
lines changed

6 files changed

+198
-125
lines changed

reproject/array_utils.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
__all__ = ["map_coordinates"]
3+
__all__ = ["map_coordinates", "sample_array_edges"]
44

55

66
def map_coordinates(image, coords, **kwargs):
@@ -35,3 +35,22 @@ def map_coordinates(image, coords, **kwargs):
3535
values[reset] = kwargs.get("cval", 0.0)
3636

3737
return values
38+
39+
40+
def sample_array_edges(shape, *, n_samples):
    """
    Sample points along every edge of an N-dimensional array.

    The array is taken to span ``-0.5`` to ``shape[i] - 0.5`` along each
    dimension ``i``. Every edge of that box is sampled with ``n_samples``
    evenly spaced points; when ``n_samples >= 2`` the two end points of each
    edge (the vertices) are included.

    Parameters
    ----------
    shape : sequence of int
        Shape of the array whose edges should be sampled.
    n_samples : int
        Number of samples to place along each edge.

    Returns
    -------
    positions : `~numpy.ndarray`
        Array of shape ``(ndim, n_points)`` containing the unique sampled
        positions. Row ``i`` holds the coordinate along dimension ``i`` of
        ``shape``, and the points are sorted lexicographically.

    Raises
    ------
    ValueError
        If ``shape`` has no dimensions.
    """
    ndim = len(shape)
    if ndim == 0:
        raise ValueError("shape should have at least one dimension")

    shape = np.array(shape)
    bits = 2 ** np.arange(ndim)

    all_positions = []
    for idim in range(ndim):
        samples = np.linspace(-0.5, shape[idim] - 0.5, n_samples)
        for vertex in range(2**ndim):
            # The coordinate along idim is overwritten by the samples below,
            # so two vertices that differ only in bit idim describe the same
            # edge. Only visit the one with that bit cleared.
            if (vertex >> idim) & 1:
                continue
            # Each bit of 'vertex' selects the lower (-0.5) or upper
            # (shape - 0.5) bound along the corresponding dimension.
            positions = -0.5 + shape * ((vertex & bits) > 0).astype(int)
            positions = np.broadcast_to(positions, (n_samples, ndim)).copy()
            positions[:, idim] = samples
            all_positions.append(positions)

    # Vertices are shared between edges, so remove the duplicates.
    return np.unique(np.vstack(all_positions), axis=0).T

reproject/mosaicking/coadd.py

+100-66
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,18 @@
44
from astropy.wcs import WCS
55
from astropy.wcs.wcsapi import SlicedLowLevelWCS
66

7+
from ..array_utils import sample_array_edges
78
from ..utils import parse_input_data, parse_input_weights, parse_output_projection
89
from .background import determine_offset_matrix, solve_corrections_sgd
910
from .subset_array import ReprojectedArraySubset
1011

1112
__all__ = ["reproject_and_coadd"]
1213

1314

15+
def _noop(iterable):
    """Return ``iterable`` unchanged; stand-in used when no progress bar is given."""
    return iterable
17+
18+
1419
def reproject_and_coadd(
1520
input_data,
1621
output_projection,
@@ -24,14 +29,15 @@ def reproject_and_coadd(
2429
background_reference=None,
2530
output_array=None,
2631
output_footprint=None,
32+
block_sizes=None,
33+
progress_bar=None,
34+
blank_pixel_value=0,
2735
**kwargs,
2836
):
2937
"""
30-
Given a set of input images, reproject and co-add these to a single
38+
Given a set of input data, reproject and co-add these to a single
3139
final image.
3240
33-
This currently only works with 2-d images with celestial WCS.
34-
3541
Parameters
3642
----------
3743
input_data : iterable
@@ -77,7 +83,7 @@ def reproject_and_coadd(
7783
`~astropy.io.fits.HDUList` instance, specifies the HDU to use.
7884
reproject_function : callable
7985
The function to use for the reprojection.
80-
combine_function : { 'mean', 'sum', 'median', 'first', 'last', 'min', 'max' }
86+
combine_function : { 'mean', 'sum', 'first', 'last', 'min', 'max' }
8187
The type of function to use for combining the values into the final
8288
image. For 'first' and 'last', respectively, the reprojected images are
8389
simply overlaid on top of each other. With respect to the order of the
@@ -92,11 +98,22 @@ def reproject_and_coadd(
9298
output_array : array or None
9399
The final output array. Specify this if you already have an
94100
appropriately-shaped array to store the data in. Must match shape
95-
specified with ``shape_out`` or derived from the output projection.
101+
specified with ``shape_out`` or derived from the output
102+
projection.
96103
output_footprint : array or None
97104
The final output footprint array. Specify this if you already have an
98105
appropriately-shaped array to store the data in. Must match shape
99-
specified with ``shape_out`` or derived from the output projection.
106+
specified with ``shape_out`` or derived from the output projection.
107+
block_sizes : list of tuples or None
108+
The block size to use for each dataset. Could also be a single tuple
109+
if you want the same block size for all data sets.
110+
progress_bar : callable, optional
111+
If specified, use this as a progress_bar to track loop iterations over
112+
data sets.
113+
blank_pixel_value : float, optional
114+
Value to use for areas of the resulting mosaic that do not have input
115+
data.
116+
100117
**kwargs
101118
Keyword arguments to be passed to the reprojection function.
102119
@@ -116,34 +133,49 @@ def reproject_and_coadd(
116133

117134
# Validate inputs
118135

119-
if combine_function not in ("mean", "sum", "median", "first", "last", "min", "max"):
120-
raise ValueError("combine_function should be one of mean/sum/median/first/last/min/max")
136+
if combine_function not in ("mean", "sum", "first", "last", "min", "max"):
137+
raise ValueError("combine_function should be one of mean/sum/first/last/min/max")
121138

122139
if reproject_function is None:
123140
raise ValueError(
124141
"reprojection function should be specified with the reproject_function argument"
125142
)
126143

144+
if progress_bar is None:
145+
progress_bar = _noop
146+
127147
# Parse the output projection once up front to avoid having to do it for each input
128148

129149
wcs_out, shape_out = parse_output_projection(output_projection, shape_out=shape_out)
130150

131-
if output_array is not None and output_array.shape != shape_out:
151+
if output_array is None:
152+
output_array = np.zeros(shape_out)
153+
elif output_array.shape != shape_out:
132154
raise ValueError(
133155
"If you specify an output array, it must have a shape matching "
134156
f"the output shape {shape_out}"
135157
)
136-
if output_footprint is not None and output_footprint.shape != shape_out:
158+
159+
if output_footprint is None:
160+
output_footprint = np.zeros(shape_out)
161+
elif output_footprint.shape != shape_out:
137162
raise ValueError(
138163
"If you specify an output footprint array, it must have a shape matching "
139164
f"the output shape {shape_out}"
140165
)
141166

167+
# Define 'on-the-fly' mode: in the case where we don't need to match
168+
# the backgrounds and we are combining with 'mean' or 'sum', we don't
169+
# have to keep track of the intermediate arrays and can just modify
170+
# the output array on-the-fly
171+
on_the_fly = not match_background and combine_function in ("mean", "sum")
172+
142173
# Start off by reprojecting individual images to the final projection
143174

144-
arrays = []
175+
if not on_the_fly:
176+
arrays = []
145177

146-
for idata in range(len(input_data)):
178+
for idata in progress_bar(range(len(input_data))):
147179
# We need to pre-parse the data here since we need to figure out how to
148180
# optimize/minimize the size of each output tile (see below).
149181
array_in, wcs_in = parse_input_data(input_data[idata], hdu_in=hdu_in)
@@ -166,42 +198,48 @@ def reproject_and_coadd(
166198
# significant distortion (when the edges of the input image become
167199
# convex in the output projection), and transforming every edge pixel,
168200
# which provides a lot of redundant information.
169-
ny, nx = array_in.shape
170-
n_per_edge = 11
171-
xs = np.linspace(-0.5, nx - 0.5, n_per_edge)
172-
ys = np.linspace(-0.5, ny - 0.5, n_per_edge)
173-
xs = np.concatenate((xs, np.full(n_per_edge, xs[-1]), xs, np.full(n_per_edge, xs[0])))
174-
ys = np.concatenate((np.full(n_per_edge, ys[0]), ys, np.full(n_per_edge, ys[-1]), ys))
175-
xc_out, yc_out = wcs_out.world_to_pixel(wcs_in.pixel_to_world(xs, ys))
201+
202+
edges = sample_array_edges(array_in.shape, n_samples=11)[::-1]
203+
edges_out = wcs_out.world_to_pixel(wcs_in.pixel_to_world(*edges))[::-1]
176204

177205
# Determine the cutout parameters
178206

179207
# In some cases, images might not have valid coordinates in the corners,
180208
# such as all-sky images or full solar disk views. In this case we skip
181209
# this step and just use the full output WCS for reprojection.
182210

183-
if np.any(np.isnan(xc_out)) or np.any(np.isnan(yc_out)):
184-
imin = 0
185-
imax = shape_out[1]
186-
jmin = 0
187-
jmax = shape_out[0]
188-
else:
189-
imin = max(0, int(np.floor(xc_out.min() + 0.5)))
190-
imax = min(shape_out[1], int(np.ceil(xc_out.max() + 0.5)))
191-
jmin = max(0, int(np.floor(yc_out.min() + 0.5)))
192-
jmax = min(shape_out[0], int(np.ceil(yc_out.max() + 0.5)))
211+
ndim_out = len(shape_out)
193212

194-
if imax < imin or jmax < jmin:
213+
skip_data = False
214+
if np.any(np.isnan(edges_out)):
215+
bounds = list(zip([0] * ndim_out, shape_out))
216+
else:
217+
bounds = []
218+
for idim in range(ndim_out):
219+
imin = max(0, int(np.floor(edges_out[idim].min() + 0.5)))
220+
imax = min(shape_out[idim], int(np.ceil(edges_out[idim].max() + 0.5)))
221+
bounds.append((imin, imax))
222+
if imax < imin:
223+
skip_data = True
224+
break
225+
226+
if skip_data:
195227
continue
196228

229+
slice_out = tuple([slice(imin, imax) for (imin, imax) in bounds])
230+
197231
if isinstance(wcs_out, WCS):
198-
wcs_out_indiv = wcs_out[jmin:jmax, imin:imax]
232+
wcs_out_indiv = wcs_out[slice_out]
199233
else:
200-
wcs_out_indiv = SlicedLowLevelWCS(
201-
wcs_out.low_level_wcs, (slice(jmin, jmax), slice(imin, imax))
202-
)
234+
wcs_out_indiv = SlicedLowLevelWCS(wcs_out.low_level_wcs, slice_out)
203235

204-
shape_out_indiv = (jmax - jmin, imax - imin)
236+
shape_out_indiv = [imax - imin for (imin, imax) in bounds]
237+
238+
if block_sizes is not None:
239+
if len(block_sizes) == len(input_data) and len(block_sizes[idata]) == len(shape_out):
240+
kwargs["block_size"] = block_sizes[idata]
241+
else:
242+
kwargs["block_size"] = block_sizes
205243

206244
# TODO: optimize handling of weights by making reprojection functions
207245
# able to handle weights, and make the footprint become the combined
@@ -235,12 +273,20 @@ def reproject_and_coadd(
235273
weights[reset] = 0.0
236274
footprint *= weights
237275

238-
array = ReprojectedArraySubset(array, footprint, imin, imax, jmin, jmax)
276+
array = ReprojectedArraySubset(array, footprint, bounds)
239277

240278
# TODO: make sure we gracefully handle the case where the
241279
# output image is empty (due e.g. to no overlap).
242280

243-
arrays.append(array)
281+
if on_the_fly:
282+
# By default, values outside of the footprint are set to NaN
283+
# but we set these to 0 here to avoid getting NaNs in the
284+
# means/sums.
285+
array.array[array.footprint == 0] = 0
286+
output_array[array.view_in_original_array] += array.array * array.footprint
287+
output_footprint[array.view_in_original_array] += array.footprint
288+
else:
289+
arrays.append(array)
244290

245291
# If requested, try and match the backgrounds.
246292
if match_background and len(arrays) > 1:
@@ -251,37 +297,32 @@ def reproject_and_coadd(
251297
for array, correction in zip(arrays, corrections, strict=True):
252298
array.array -= correction
253299

254-
# At this point, the images are now ready to be co-added.
255-
256-
if output_array is None:
257-
output_array = np.zeros(shape_out)
258-
if output_footprint is None:
259-
output_footprint = np.zeros(shape_out)
260-
261-
if combine_function == "min":
262-
output_array[...] = np.inf
263-
elif combine_function == "max":
264-
output_array[...] = -np.inf
265-
266300
if combine_function in ("mean", "sum"):
267-
for array in arrays:
268-
# By default, values outside of the footprint are set to NaN
269-
# but we set these to 0 here to avoid getting NaNs in the
270-
# means/sums.
271-
array.array[array.footprint == 0] = 0
301+
if match_background:
302+
# if we're not matching the background, this part has already been done
303+
for array in arrays:
304+
# By default, values outside of the footprint are set to NaN
305+
# but we set these to 0 here to avoid getting NaNs in the
306+
# means/sums.
307+
array.array[array.footprint == 0] = 0
272308

273-
output_array[array.view_in_original_array] += array.array * array.footprint
274-
output_footprint[array.view_in_original_array] += array.footprint
309+
output_array[array.view_in_original_array] += array.array * array.footprint
310+
output_footprint[array.view_in_original_array] += array.footprint
275311

276312
if combine_function == "mean":
277313
with np.errstate(invalid="ignore"):
278314
output_array /= output_footprint
279-
output_array[output_footprint == 0] = 0
315+
output_array[output_footprint == 0] = blank_pixel_value
280316

281317
elif combine_function in ("first", "last", "min", "max"):
318+
if combine_function == "min":
319+
output_array[...] = np.inf
320+
elif combine_function == "max":
321+
output_array[...] = -np.inf
322+
282323
for array in arrays:
283324
if combine_function == "first":
284-
mask = (output_footprint[array.view_in_original_array] == 0) & (array.footprint > 0)
325+
mask = output_footprint[array.view_in_original_array] == 0
285326
elif combine_function == "last":
286327
mask = array.footprint > 0
287328
elif combine_function == "min":
@@ -300,13 +341,6 @@ def reproject_and_coadd(
300341
mask, array.array, output_array[array.view_in_original_array]
301342
)
302343

303-
elif combine_function == "median":
304-
# Here we need to operate in chunks since we could otherwise run
305-
# into memory issues
306-
307-
raise NotImplementedError("combine_function='median' is not yet implemented")
308-
309-
if combine_function in ("min", "max"):
310-
output_array[output_footprint == 0] = 0.0
344+
output_array[output_footprint == 0] = blank_pixel_value
311345

312346
return output_array, output_footprint

0 commit comments

Comments
 (0)