Skip to content

Commit 37777f5

Browse files
committed
covariance in tabular fits
1 parent cdf5720 commit 37777f5

File tree

5 files changed

+144
-36
lines changed

5 files changed

+144
-36
lines changed

specutils/io/default_loaders/tabular_fits.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ def tabular_fits_loader(file_obj, column_mapping=None, hdu=1, store_data_header=
4747
4848
Parameters
4949
----------
50-
file_obj : str, file-like, or :class:`~astropy.io.fits.HDUList`
51-
FITS file name, object (provided from name by Astropy I/O Registry),
52-
or HDU list (as resulting from `~astropy.io.fits.open`).
50+
file_obj : str, file-like, or HDUList
51+
FITS file name, object (provided from name by Astropy I/O Registry),
52+
or HDUList (as resulting from astropy.io.fits.open()).
5353
hdu : int
5454
The HDU of the fits file (default: 1st extension) to read from
5555
store_data_header : bool
@@ -73,8 +73,8 @@ def tabular_fits_loader(file_obj, column_mapping=None, hdu=1, store_data_header=
7373
7474
Returns
7575
-------
76-
data : :class:`Spectrum1D`
77-
The spectrum that is represented by the data in the input table.
76+
data : Spectrum1D
77+
The spectrum that is represented by the data in this table.
7878
"""
7979
# Parse the wcs information. The wcs will be passed to the column finding
8080
# routines to search for spectral axis information in the file.
@@ -86,6 +86,12 @@ def tabular_fits_loader(file_obj, column_mapping=None, hdu=1, store_data_header=
8686
else:
8787
tab.meta = hdulist[0].header
8888

89+
# Determine if there is a correlation matrix
90+
correl = None
91+
if 'CORREL' in [h.name for h in hdulist]:
92+
correl = Table.read(hdulist['CORREL'])
93+
correl.meta = hdulist['CORREL'].header
94+
8995
# Minimal checks for wcs consistency with table data -
9096
# assume 1D spectral axis (having shape (0, NAXIS1),
9197
# or alternatively compare against shape of 1st column.
@@ -96,9 +102,9 @@ def tabular_fits_loader(file_obj, column_mapping=None, hdu=1, store_data_header=
96102
# If no column mapping is given, attempt to parse the file using
97103
# unit information
98104
if column_mapping is None:
99-
return generic_spectrum_from_table(tab, wcs=wcs)
105+
return generic_spectrum_from_table(tab, wcs=wcs, correl=correl, **kwargs)
100106

101-
return spectrum_from_column_mapping(tab, column_mapping, wcs=wcs)
107+
return spectrum_from_column_mapping(tab, column_mapping, wcs=wcs, correl=correl)
102108

103109

104110
@custom_writer("tabular-fits")
@@ -173,7 +179,7 @@ def tabular_fits_writer(spectrum, file_name, hdu=1, update_header=False, store_d
173179
if spectrum.uncertainty is not None:
174180
if isinstance(spectrum.uncertainty, Covariance):
175181
var, correl = spectrum.uncertainty.to_tables()
176-
columns.append(np.sqrt(var))
182+
columns.append(np.sqrt(var) * funit)
177183
colnames.append("uncertainty")
178184
else:
179185
try:

specutils/io/parsing_utils.py

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import contextlib
77

88
from astropy.io import fits
9-
from astropy.nddata import StdDevUncertainty
9+
from astropy.nddata import StdDevUncertainty, Covariance
1010
from astropy.utils.exceptions import AstropyUserWarning
1111
import astropy.units as u
1212
import warnings
@@ -52,7 +52,7 @@ def read_fileobj_or_hdulist(*args, **kwargs):
5252
hdulist.close()
5353

5454

55-
def spectrum_from_column_mapping(table, column_mapping, wcs=None, verbose=False):
55+
def spectrum_from_column_mapping(table, column_mapping, wcs=None, correl=None, verbose=False):
5656
"""
5757
Given a table and a mapping of the table column names to attributes
5858
on the Spectrum1D object, parse the information into a Spectrum1D.
@@ -76,6 +76,11 @@ def spectrum_from_column_mapping(table, column_mapping, wcs=None, verbose=False)
7676
wcs : :class:`~astropy.wcs.WCS` or :class:`gwcs.WCS`
7777
WCS object passed to the Spectrum1D initializer.
7878
79+
correl : `astropy.table.Table`, optional
80+
Table with correlation matrix for the uncertainties in coordinate
81+
format; see `~astropy.nddata.Covariance`. If None, uncertainties are
82+
assumed to be uncorrelated.
83+
7984
verbose : bool
8085
Print extra info.
8186
@@ -131,14 +136,28 @@ def spectrum_from_column_mapping(table, column_mapping, wcs=None, verbose=False)
131136
spec_kwargs.setdefault(kwarg_name, kwarg_val)
132137

133138
# Ensure that the uncertainties are a subclass of NDUncertainty
139+
if spec_kwargs.get('uncertainty') is None and correl is not None:
140+
warnings.warn('Unable to parse uncertainty from provided table. Ignoring provided '
141+
'correlation matrix data.')
142+
correl = None
134143
if spec_kwargs.get('uncertainty') is not None:
135-
spec_kwargs['uncertainty'] = StdDevUncertainty(
136-
spec_kwargs.get('uncertainty'))
144+
if correl is not None:
145+
err = spec_kwargs.get('uncertainty')
146+
try:
147+
spec_kwargs['uncertainty'] = Covariance.from_tables(err**2, correl, quiet=True)
148+
except (ValueError, TypeError):
149+
warnings.warn('Unable to parse correlation table into a Covariance object. '
150+
'Ignoring correlation matrix data.')
151+
correl = None
152+
# NOTE: This is not an `else` or `elif` block in order to catch the
153+
# change to correl=None when handling the excemption above.
154+
if correl is None:
155+
spec_kwargs['uncertainty'] = StdDevUncertainty(spec_kwargs.get('uncertainty'))
137156

138157
return Spectrum1D(**spec_kwargs, wcs=wcs, meta={'header': table.meta})
139158

140159

141-
def generic_spectrum_from_table(table, wcs=None):
160+
def generic_spectrum_from_table(table, wcs=None, correl=None, **kwargs):
142161
"""
143162
Load spectrum from an Astropy table into a Spectrum1D object.
144163
Uses the following logic to figure out which column is which:
@@ -155,12 +174,15 @@ def generic_spectrum_from_table(table, wcs=None):
155174
156175
Parameters
157176
----------
158-
table : :class:`~astropy.table.Table`
159-
Table containing a column of ``flux``, and optionally ``spectral_axis``
160-
and ``uncertainty`` as defined above.
161-
wcs : :class:`~astropy.wcs.WCS`
177+
table : `astropy.table.QTable`
178+
Tabulated data with units
179+
wcs : :class:`~astropy.wcs.WCS`, optional
162180
A FITS WCS object. If this is present, the machinery will fall back
163-
and default to using the ``wcs`` to find the dispersion information.
181+
to using the wcs to find the dispersion information.
182+
correl : `astropy.table.Table`, optional
183+
Table with correlation matrix for the uncertainties in coordinate
184+
format; see `~astropy.nddata.Covariance`. If None, uncertainties are
185+
assumed to be uncorrelated.
164186
165187
Returns
166188
-------
@@ -275,12 +297,27 @@ def _find_spectral_column(table, columns_to_search, spectral_axis):
275297
if table[err_column].ndim > 1:
276298
err = table[err_column].T
277299
elif flux.ndim > 1: # Repeat uncertainties over all flux columns
300+
if correl is not None:
301+
warnings.warn('When applying correlated errors, the dimensionality of the error '
302+
'array must match the dimensionality of the flux array. Ignoring '
303+
'correlated errors.')
304+
correl = None
278305
err = np.tile(table[err_column], flux.shape[0], 1)
279306
else:
280307
err = table[err_column]
281-
err = StdDevUncertainty(err.to(err.unit))
282-
if np.min(table[err_column]) <= 0.:
283-
warnings.warn("Standard Deviation has values of 0 or less", AstropyUserWarning)
308+
if correl is not None:
309+
try:
310+
err = Covariance.from_tables(err**2, correl, quiet=True)
311+
except (ValueError, TypeError):
312+
warnings.warn('Unable to parse correlation table into a Covariance object. '
313+
'Ignoring correlation matrix data.')
314+
correl = None
315+
# NOTE: This is not an `else` or `elif` block in order to catch the
316+
# change to correl=None when handling the excemption above.
317+
if correl is None:
318+
err = StdDevUncertainty(err.to(err.unit))
319+
if np.min(table[err_column]) <= 0.:
320+
warnings.warn("Standard Deviation has values of 0 or less", AstropyUserWarning)
284321
else:
285322
err = None
286323

specutils/manipulation/extract_spectral_region.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numpy as np
66

77
from astropy import units as u
8+
from astropy.nddata import Covariance
89
from ..spectra import Spectrum1D, SpectralRegion
910

1011
__all__ = ['extract_region', 'extract_bounding_spectral_region', 'spectral_slab']
@@ -189,6 +190,9 @@ def _get_joined_value(sps, key, unique_inds=None):
189190
uncert = sps[0].uncertainty
190191
if uncert is None:
191192
return None
193+
if isinstance(uncert, Covariance):
194+
raise NotImplementedError("Cannot yet combine spectral regions with "
195+
"covariant uncertainties.")
192196
uncert._array = np.concatenate([sp.uncertainty._array for sp in sps])
193197
return uncert[unique_inds] if unique_inds is not None else uncert
194198
elif key in concat_keys or key == 'spectral_axis':

specutils/spectra/spectrum1d.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from astropy import units as u
66
from astropy.utils.decorators import lazyproperty
77
from astropy.utils.decorators import deprecated
8-
from astropy.nddata import NDUncertainty, NDIOMixin, NDArithmeticMixin
8+
from astropy.nddata import NDUncertainty, NDIOMixin, NDArithmeticMixin, Covariance
99

1010
from .spectral_axis import SpectralAxis
1111
from .spectrum_mixin import OneDSpectrumMixin
@@ -36,8 +36,10 @@ class Spectrum1D(OneDSpectrumMixin, NDCube, NDIOMixin, NDArithmeticMixin):
3636
Parameters
3737
----------
3838
flux : `~astropy.units.Quantity`
39-
The flux data for this spectrum. This can be a simple `~astropy.units.Quantity`,
40-
or an existing `~Spectrum1D` or `~ndcube.NDCube` object.
39+
The flux data for this spectrum. This can be a simple
40+
`~astropy.units.Quantity`, or an existing `~Spectrum1D` or
41+
`~ndcube.NDCube` object. If an `~ndcube.NDCube` object, all other
42+
arguments are ignored.
4143
spectral_axis : `~astropy.units.Quantity` or `~specutils.SpectralAxis`
4244
Dispersion information with the same shape as the last (or only)
4345
dimension of flux, or one greater than the last dimension of flux
@@ -60,8 +62,10 @@ class Spectrum1D(OneDSpectrumMixin, NDCube, NDIOMixin, NDArithmeticMixin):
6062
values represent edges of the wavelength bin, or centers of the bin.
6163
uncertainty : `~astropy.nddata.NDUncertainty`
6264
Contains uncertainty information along with propagation rules for
63-
spectrum arithmetic. Can take a unit, but if none is given, will use
64-
the unit defined in the flux.
65+
spectrum arithmetic. Can take a unit, but if none is given, will use the
66+
unit defined in the flux. Note that functionality is limited for
67+
`~astropy.nddata.Covariance` instances, particularly for
68+
multidimensional data.
6569
mask : `~numpy.ndarray`-like
6670
Array where values in the flux to be masked are those that
6771
``astype(bool)`` converts to True. (For example, integer arrays are not
@@ -211,7 +215,11 @@ def __init__(self, flux=None, spectral_axis=None, wcs=None,
211215
len(kwargs["mask"].shape)-temp_axes[0]-1, -1)
212216
if "uncertainty" in kwargs:
213217
if kwargs["uncertainty"] is not None:
214-
if isinstance(kwargs["uncertainty"], NDUncertainty):
218+
if isinstance(kwargs["uncertainty"], Covariance):
219+
raise NotImplementedError("Cannot yet handle covariance for "
220+
"multidimensional Spectrum1D objects "
221+
"with a WCS coordinate system.")
222+
elif isinstance(kwargs["uncertainty"], NDUncertainty):
215223
# Account for Astropy uncertainty types
216224
unc_len = len(kwargs["uncertainty"].array.shape)
217225
temp_unc = np.swapaxes(kwargs["uncertainty"].array,
@@ -302,10 +310,13 @@ def __init__(self, flux=None, spectral_axis=None, wcs=None,
302310
raise ValueError('Spectral axis must be strictly increasing or decreasing.')
303311

304312
if hasattr(self, 'uncertainty') and self.uncertainty is not None:
305-
if not flux.shape == self.uncertainty.array.shape:
306-
raise ValueError(
307-
"Flux axis ({}) and uncertainty ({}) shapes must be the "
308-
"same.".format(flux.shape, self.uncertainty.array.shape))
313+
if isinstance(self.uncertainty, Covariance):
314+
uncertainty_shape = self.uncertainty.data_shape
315+
else:
316+
uncertainty_shape = self.uncertainty.array.shape
317+
if not flux.shape == uncertainty_shape:
318+
raise ValueError(f"Flux axis ({flux.shape}) and uncertainty ({uncertainty_shape}) "
319+
"shapes must be the same.")
309320

310321
def __getitem__(self, item):
311322
"""
@@ -322,7 +333,6 @@ def __getitem__(self, item):
322333
The first case is handled by the parent class, while the second is
323334
handled here.
324335
"""
325-
326336
if self.flux.ndim > 1 or (isinstance(item, tuple) and item[0] is Ellipsis):
327337
if isinstance(item, tuple):
328338
if len(item) == len(self.flux.shape) or item[0] is Ellipsis:
@@ -371,11 +381,17 @@ def __getitem__(self, item):
371381
else:
372382
new_meta = deepcopy(self.meta)
373383

384+
if isinstance(self.uncertainty, Covariance):
385+
new_unc = self.uncertainty.sub_matrix(item)
386+
elif self.uncertainty is not None:
387+
new_unc = self.uncertainty[item]
388+
else:
389+
new_unc = None
390+
374391
return self._copy(
375392
flux=self.flux[item],
376393
spectral_axis=self.spectral_axis[spec_item],
377-
uncertainty=self.uncertainty[item]
378-
if self.uncertainty is not None else None,
394+
uncertainty=new_unc,
379395
mask=self.mask[item] if self.mask is not None else None,
380396
meta=new_meta, wcs=None)
381397

@@ -748,8 +764,13 @@ def __str__(self):
748764

749765
# Add information about uncertainties if available
750766
if self.uncertainty:
767+
_arr = (
768+
self.uncertainty.toarray()
769+
if isinstance(self.uncertainty, Covariance)
770+
else self.uncertainty.array
771+
)
751772
result += (f'\nUncertainty={type(self.uncertainty).__name__} '
752-
f'({np.array2string(self.uncertainty.array, threshold=8)}'
773+
f'({np.array2string(_arr, threshold=8)}'
753774
f' {self.uncertainty.unit})')
754775

755776
return result

specutils/tests/test_loaders.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,15 @@
66
import pytest
77
import astropy.units as u
88
import numpy as np
9+
from scipy import sparse
910
from astropy.io import fits
1011
from astropy.io.fits.verify import VerifyWarning
1112
from astropy.table import Table
1213
from astropy.units import UnitsWarning
1314
from astropy.wcs import FITSFixedWarning, WCS
1415
from astropy.io.registry import IORegistryError
1516
from astropy.modeling import models
16-
from astropy.nddata import StdDevUncertainty, InverseVariance, VarianceUncertainty
17+
from astropy.nddata import StdDevUncertainty, InverseVariance, VarianceUncertainty, Covariance
1718
from astropy.tests.helper import quantity_allclose
1819
from astropy.utils.exceptions import AstropyUserWarning
1920

@@ -584,6 +585,45 @@ def test_tabular_fits_writer(tmp_path, spectral_axis):
584585
spectrum.uncertainty.quantity)
585586

586587

588+
def test_tabular_fits_cov_io(tmp_path):
589+
# Create a fake spectrum
590+
wave = np.arange(4500., 5500., 1.) * u.AA
591+
flux = np.full(len(wave), 10) * u.Jy
592+
# And covariance
593+
cov_diags = [
594+
np.ones(1000, dtype=float),
595+
np.full(1000-1, 0.5, dtype=float),
596+
np.full(1000-2, 0.2, dtype=float),
597+
]
598+
cov = Covariance(sparse.diags(cov_diags, [0, 1, 2]), unit=u.Jy**2)
599+
600+
# Create the Spectrum1D
601+
spectrum = Spectrum1D(flux=flux, spectral_axis=wave, uncertainty=cov)
602+
# Simple test of ingestion
603+
assert np.array_equal(spectrum.uncertainty.toarray(), cov.toarray()), \
604+
'Covariance not included correctly'
605+
# Write it
606+
tmpfile = str(tmp_path / '_tst.fits')
607+
spectrum.write(tmpfile, format='tabular-fits')
608+
609+
# Check the output file
610+
with fits.open(tmpfile) as hdu:
611+
assert len(hdu) == 3, 'Should contain 3 extensions, primary, DATA, CORREL'
612+
assert np.array_equal([h.name for h in hdu], ['PRIMARY', 'DATA', 'CORREL']), \
613+
'Extension names are wrong'
614+
assert all([h.__class__.__name__ == 'BinTableHDU' for h in hdu[1:]]), \
615+
'Data extensions should both be BinTableHDU'
616+
assert len(hdu['CORREL'].data) == cov.nnz, 'Number of non-zero cov elements mismatch'
617+
618+
# Read it
619+
_spectrum = Spectrum1D.read(tmpfile)
620+
assert spectrum.flux.unit == _spectrum.flux.unit
621+
assert spectrum.spectral_axis.unit == _spectrum.spectral_axis.unit
622+
assert quantity_allclose(spectrum.spectral_axis, _spectrum.spectral_axis)
623+
assert quantity_allclose(spectrum.flux, _spectrum.flux)
624+
assert np.array_equal(spectrum.uncertainty.toarray(), _spectrum.uncertainty.toarray())
625+
626+
587627
@pytest.mark.parametrize("ndim", range(1, 4))
588628
@pytest.mark.parametrize("spectral_axis",
589629
['wavelength', 'frequency', 'energy', 'wavenumber'])

0 commit comments

Comments
 (0)