diff --git a/CHANGES.rst b/CHANGES.rst index 793b13105..a20aa28d1 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -6,9 +6,15 @@ New Features - Added ``CompoundSpectralRegion`` class to enable combining ``SpectralRegion`` with operators. [#1282] +- Using ``Spectrum.shift_spectrum_to()`` will now also update the WCS when applying the redshift or radial + velocity if a FITS WCS is present. A GWCS will be replaced with the original stored in an + ``_original_wcs`` attribute. [#1287] + Bug Fixes ^^^^^^^^^ +- Doing arithmetic with ``Spectrum`` objects no longer improperly redshifts the spectral axis in some cases. [#1287] + Other Changes and Additions ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/specutils/spectra/spectrum.py b/specutils/spectra/spectrum.py index 01d73dffd..e10d11cec 100644 --- a/specutils/spectra/spectrum.py +++ b/specutils/spectra/spectrum.py @@ -7,6 +7,7 @@ from astropy.utils.decorators import lazyproperty from astropy.utils.decorators import deprecated from astropy.nddata import NDUncertainty, NDIOMixin, NDArithmeticMixin, NDDataArray +from astropy.wcs import WCS from gwcs.wcs import WCS as GWCS from .spectral_axis import SpectralAxis @@ -802,11 +803,16 @@ def shift_spectrum_to(self, *, redshift=None, radial_velocity=None): raise ValueError( "Only one of redshift or radial_velocity can be used." ) + + old_redshift = self.redshift + if redshift is not None: - new_spec_coord = self.spectral_axis.with_radial_velocity_shift( + # with_radial_velocity_shift(redshift) looks wrong but astropy SpectralCoord handles + # redshift input to that method + new_spectral_axis = self.spectral_axis.with_radial_velocity_shift( -self.spectral_axis.radial_velocity ).with_radial_velocity_shift(redshift) - self._spectral_axis = new_spec_coord + self._spectral_axis = new_spectral_axis elif radial_velocity is not None: if radial_velocity is not None: if not radial_velocity.unit.is_equivalent(u.km/u.s): @@ -816,9 +822,32 @@ def shift_spectrum_to(self, *, redshift=None, radial_velocity=None): -self.spectral_axis.radial_velocity ).with_radial_velocity_shift(radial_velocity) self._spectral_axis = new_spectral_axis + redshift = radial_velocity.to(u.Unit(''), u.doppler_redshift()) else: raise ValueError("One of redshift or radial_velocity must be set.") + # Also store an updated WCS if we can update it. + if isinstance(self.wcs, WCS): + wcs_spectral_index = self.wcs.wcs.spec + 1 + h = self.wcs.to_header() + spec_ctype = h[f'CTYPE{wcs_spectral_index}'] + z_factor = (1 + redshift) / (1 + old_redshift) + if spec_ctype[0:4] != 'WAVE': + # Frequency, wavenumber and energy all invert this factor + z_factor = 1 / z_factor + h[f'CRVAL{wcs_spectral_index}'] *= z_factor + h[f'PC{wcs_spectral_index}_{wcs_spectral_index}'] *= z_factor + # WCS doesn't allow updating, but you can set it to None and then assign a new value + self.wcs = None + self.wcs = WCS(h) + else: + # I don't know how to update a GWCS cleanly so for now, we replace it and store the + # old one to retain any spatial information in the original + self._original_wcs = self.wcs + self.wcs = None + self.wcs = gwcs_from_array(new_spectral_axis, self.flux.shape, + spectral_axis_index=self.spectral_axis_index) + def with_spectral_axis_last(self): """ Convenience method to return a new copy of the Spectrum with the spectral axis last. @@ -827,12 +856,6 @@ def with_spectral_axis_last(self): mask=self.mask, uncertainty=self.uncertainty, redshift=self.redshift, move_spectral_axis="last") - def _return_with_redshift(self, result): - # We need actual spectral units to shift - if result.spectral_axis.unit not in ('', 'pix', 'pixels'): - result.shift_spectrum_to(redshift=self.redshift) - return result - def _check_input(self, other, force_quantity=False): # NDArithmetic mixin will try to turn other into a Spectrum, which will fail # sometimes because of not specifiying the spectral axis index @@ -856,12 +879,13 @@ def _do_flux_arithmetic(self, other, arith_func): func = getattr(operand1, arith_func) new_flux = func(other) - return self._return_with_redshift(Spectrum(new_flux.data*new_flux.unit, - wcs=self.wcs, - meta=self.meta, - uncertainty=new_flux.uncertainty, - mask = new_flux.mask, - spectral_axis_index=self.spectral_axis_index)) + return Spectrum(new_flux.data*new_flux.unit, + wcs=self.wcs, + meta=self.meta, + uncertainty=new_flux.uncertainty, + mask = new_flux.mask, + spectral_axis_index=self.spectral_axis_index, + redshift = self.redshift) def __add__(self, other): other = self._check_input(other, force_quantity=True) diff --git a/specutils/tests/test_arithmetic.py b/specutils/tests/test_arithmetic.py index cac1574c0..236dd9e14 100644 --- a/specutils/tests/test_arithmetic.py +++ b/specutils/tests/test_arithmetic.py @@ -135,6 +135,16 @@ def test_with_constants(simulated_spectra): assert_quantity_allclose(r_sub_result.flux, l_sub_result.flux) +def test_arithmetic_with_redshift(): + spec1 = Spectrum(flux=np.ones(20) * u.Jy, + spectral_axis=np.arange(1, 21) * u.nm, + redshift=1) + spec2 = spec1 * 2 + + assert_quantity_allclose(spec2.spectral_axis, spec1.spectral_axis) + assert_quantity_allclose(spec2.flux, 2*u.Jy) + + def test_arithmetic_after_shift(simulated_spectra): spec = simulated_spectra.s1_um_mJy_e1 spec.shift_spectrum_to(redshift = 1)