Skip to content

Commit 4661805

Browse files
committed
Added direct port of linetools continuum
1 parent 6e26d79 commit 4661805

File tree

1 file changed

+318
-5
lines changed

1 file changed

+318
-5
lines changed

specutils/manipulation/continuum.py

Lines changed: 318 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1-
from __future__ import division
1+
from __future__ import print_function, division, absolute_import
22

33
from astropy import modeling
44
from astropy.modeling import models, fitting
5+
from astropy.nddata import StdDevUncertainty
56
from ..spectra import Spectrum1D
67

78
import numpy as np
9+
from scipy import interpolate
810

911
import logging
12+
import warnings
1013

11-
__all__ = ['fit_continuum_generic']
14+
__all__ = ['fit_continuum_generic', 'fit_continuum_linetools']
1215

1316
def fit_continuum_generic(spectrum,
1417
model=None, fitter=None,
@@ -124,10 +127,320 @@ def fit_continuum_generic(spectrum,
124127
good[sigma_difference > sigma_upper] = False
125128
good[sigma_difference < -sigma_lower] = False
126129

127-
## Update model iteratively: it is initialized at the previous fit's values
128-
#model = new_model
129-
130130
model = new_model
131131
if full_output:
132132
return model, good
133133
return model
134+
135+
def fit_continuum_linetools(spec, edges=None, ax=None, debug=False, kind="QSO", **kwargs):
136+
"""
137+
A direct port of the linetools continuum normalization algorithm by X Prochaska
138+
https://github.com/linetools/linetools/blob/master/linetools/analysis/continuum.py
139+
140+
The only changes are switching to Scipy's Akima1D interpolator and changing the relevant syntax
141+
"""
142+
assert kind in ["QSO"], kind
143+
if not isinstance(spec, Spectrum1D):
144+
raise ValueError('The spectrum parameter must be a Spectrum1D object')
145+
146+
### To start, we define all the functions here to avoid namespace bloat, but this can be fixed later
147+
### The goal is to have the same algorithm but with flexible wavelength chunks for other object types
148+
149+
def make_chunks_qso(wa, redshift, divmult=1, forest_divmult=1, debug=False):
150+
""" Generate a series of wavelength chunks for use by
151+
prepare_knots, assuming a QSO spectrum.
152+
"""
153+
154+
cond = np.isnan(wa)
155+
if np.any(cond):
156+
warnings.warn('Some wavelengths are NaN, ignoring these pixels.')
157+
wa = wa[~cond]
158+
assert len(wa) > 0
159+
160+
zp1 = 1 + redshift
161+
div = np.rec.fromrecords([(200. , 500. , 25),
162+
(500. , 800. , 25),
163+
(800. , 1190., 25),
164+
(1190., 1213., 4),
165+
(1213., 1230., 6),
166+
(1230., 1263., 6),
167+
(1263., 1290., 5),
168+
(1290., 1340., 5),
169+
(1340., 1370., 2),
170+
(1370., 1410., 5),
171+
(1410., 1515., 5),
172+
(1515., 1600., 15),
173+
(1600., 1800., 8),
174+
(1800., 1900., 5),
175+
(1900., 1940., 5),
176+
(1940., 2240., 15),
177+
(2240., 3000., 25),
178+
(3000., 6000., 80),
179+
(6000., 20000., 100),
180+
], names=str('left,right,num'))
181+
182+
div.num[2:] = np.ceil(div.num[2:] * divmult)
183+
div.num[:2] = np.ceil(div.num[:2] * forest_divmult)
184+
div.left *= zp1
185+
div.right *= zp1
186+
if debug:
187+
print(div.tolist())
188+
temp = [np.linspace(left, right, n+1)[:-1] for left,right,n in div]
189+
edges = np.concatenate(temp)
190+
191+
i0,i1,i2 = edges.searchsorted([wa[0], 1210*zp1, wa[-1]])
192+
if debug:
193+
print(i0,i1,i2)
194+
return edges[i0:i2]
195+
196+
def update_knots(knots, indices, fl, masked):
197+
""" Calculate the y position of each knot.
198+
199+
Updates `knots` inplace.
200+
201+
Parameters
202+
----------
203+
knots: list of [xpos, ypos, bool] with length N
204+
bool says whether the knot should kept unchanged.
205+
indices: list of (i0,i1) index pairs
206+
The start and end indices into fl and masked of each
207+
spectrum chunk (xpos of each knot are the chunk centres).
208+
fl, masked: arrays shape (M,)
209+
The flux, and boolean arrays showing which pixels are
210+
masked.
211+
"""
212+
213+
iy, iflag = 1, 2
214+
for iknot,(i1,i2) in enumerate(indices):
215+
if knots[iknot][iflag]:
216+
continue
217+
218+
f0 = fl[i1:i2]
219+
m0 = masked[i1:i2]
220+
f1 = f0[~m0]
221+
knots[iknot][iy] = np.median(f1)
222+
223+
def linear_co(wa, knots):
224+
"""linear interpolation through the spline knots.
225+
226+
Add extra points on either end to give
227+
a nice slope at the end points."""
228+
wavc, mfl = list(zip(*knots))[:2]
229+
extwavc = ([wavc[0] - (wavc[1] - wavc[0])] + list(wavc) +
230+
[wavc[-1] + (wavc[-1] - wavc[-2])])
231+
extmfl = ([mfl[0] - (mfl[1] - mfl[0])] + list(mfl) +
232+
[mfl[-1] + (mfl[-1] - mfl[-2])])
233+
co = np.interp(wa, extwavc, extmfl)
234+
return co
235+
236+
def Akima_co(wa, knots):
237+
"""Akima interpolation through the spline knots."""
238+
x,y,_ = zip(*knots)
239+
spl = interpolate.Akima1DInterpolator(x, y)
240+
return spl(wa)
241+
242+
def remove_bad_knots(knots, indices, masked, fl, er, debug=False):
243+
""" Remove knots in chunks without any good pixels. Modifies
244+
inplace."""
245+
idelknot = []
246+
for iknot,(i,j) in enumerate(indices):
247+
if np.all(masked[i:j]) or np.median(fl[i:j]) <= 2*np.median(er[i:j]):
248+
if debug:
249+
print('Deleting knot', iknot, 'near {:.1f} Angstroms'.format(
250+
knots[iknot][0]))
251+
idelknot.append(iknot)
252+
253+
for i in reversed(idelknot):
254+
del knots[i]
255+
del indices[i]
256+
257+
def chisq_chunk(model, fl, er, masked, indices, knots, chithresh=1.5):
258+
""" Calc chisq per chunk, update knots flags inplace if chisq is
259+
acceptable. """
260+
chisq = []
261+
FLAG = 2
262+
for iknot,(i1,i2) in enumerate(indices):
263+
if knots[iknot][FLAG]:
264+
continue
265+
266+
f0 = fl[i1:i2]
267+
e0 = er[i1:i2]
268+
m0 = masked[i1:i2]
269+
f1 = f0[~m0]
270+
e1 = e0[~m0]
271+
mod0 = model[i1:i2]
272+
mod1 = mod0[~m0]
273+
resid = (mod1 - f1) / e1
274+
chisq = np.sum(resid*resid)
275+
rchisq = chisq / len(f1)
276+
if rchisq < chithresh:
277+
#print (good reduced chisq in knot', iknot)
278+
knots[iknot][FLAG] = True
279+
280+
def prepare_knots(wa, fl, er, edges, ax=None, debug=False):
281+
""" Make initial knots for the continuum estimation.
282+
283+
Parameters
284+
----------
285+
wa, fl, er : arrays
286+
Wavelength, flux, error.
287+
edges : The edges of the wavelength chunks. Splines knots are to be
288+
places at the centre of these chunks.
289+
ax : Matplotlib Axes
290+
If not None, use to plot debugging info.
291+
292+
Returns
293+
-------
294+
knots, indices, masked
295+
* knots: A list of [x, y, flag] lists giving the x and y position
296+
of each knot.
297+
* indices: A list of tuples (i,j) giving the start and end index
298+
of each chunk.
299+
* masked: An array the same shape as wa.
300+
"""
301+
indices = wa.searchsorted(edges)
302+
indices = [(i0,i1) for i0,i1 in zip(indices[:-1],indices[1:])]
303+
wavc = [0.5*(w1 + w2) for w1,w2 in zip(edges[:-1],edges[1:])]
304+
305+
knots = [[wavc[i], 0, False] for i in range(len(wavc))]
306+
307+
masked = np.zeros(len(wa), bool)
308+
masked[~(er > 0)] = True
309+
310+
# remove bad knots
311+
remove_bad_knots(knots, indices, masked, fl, er, debug=debug)
312+
313+
if ax is not None:
314+
yedge = np.interp(edges, wa, fl)
315+
ax.vlines(edges, 0, yedge + 100, color='c', zorder=10)
316+
317+
# set the knot flux values
318+
update_knots(knots, indices, fl, masked)
319+
320+
if ax is not None:
321+
x,y = list(zip(*knots))[:2]
322+
ax.plot(x, y, 'o', mfc='none', mec='c', ms=10, mew=1, zorder=10)
323+
324+
return knots, indices, masked
325+
326+
327+
def unmask(masked, indices, wa, fl, er, minpix=3):
328+
""" Forces each chunk to use at least minpix pixels.
329+
330+
Sometimes all pixels can become masked in a chunk. We don't want
331+
this! This forces there to be at least minpix pixels used in each
332+
chunk.
333+
"""
334+
for iknot,(i,j) in enumerate(indices):
335+
#print(iknot, wa[i], wa[j], (~masked[i:j]).sum())
336+
if np.sum(~masked[i:j]) < minpix:
337+
#print('unmasking pixels')
338+
# need to unmask minpix
339+
f0 = fl[i:j]
340+
e0 = er[i:j]
341+
ind = np.arange(i,j)
342+
f1 = f0[e0 > 0]
343+
isort = np.argsort(f1)
344+
ind1 = ind[e0 > 0][isort[-minpix:]]
345+
# print(wa[i], wa[j])
346+
# print(wa[ind1])
347+
masked[ind1] = False
348+
349+
350+
def estimate_continuum(s, knots, indices, masked, ax=None, maxiter=1000,
351+
nsig=1.5, debug=False):
352+
""" Iterate to estimate the continuum.
353+
"""
354+
count = 0
355+
while True:
356+
if debug:
357+
print('iteration', count)
358+
update_knots(knots, indices, s.fl, masked)
359+
model = linear_co(s.wa, knots)
360+
model_a = Akima_co(s.wa, knots)
361+
chisq_chunk(model_a, s.fl, s.er, masked,
362+
indices, knots, chithresh=1)
363+
flags = list(zip(*knots))[-1]
364+
if np.all(flags):
365+
if debug:
366+
print('All regions have satisfactory fit, stopping')
367+
break
368+
# remove outliers
369+
c0 = ~masked
370+
resid = (model - s.fl) / s.er
371+
oldmasked = masked.copy()
372+
masked[(resid > nsig) & ~masked] = True
373+
unmask(masked, indices, s.wa, s.fl, s.er)
374+
if np.all(oldmasked == masked):
375+
if debug:
376+
print('No further points masked, stopping')
377+
break
378+
if count > maxiter:
379+
raise RuntimeError('Exceeded maximum iterations')
380+
381+
count +=1
382+
383+
co = Akima_co(s.wa, knots)
384+
c0 = co <= 0
385+
co[c0] = 0
386+
387+
if ax is not None:
388+
ax.plot(s.wa, linear_co(s.wa, knots), color='0.7', lw=2)
389+
ax.plot(s.wa, co, 'k', lw=2, zorder=10)
390+
x,y = list(zip(*knots))[:2]
391+
ax.plot(x, y, 'o', mfc='none', mec='k', ms=10, mew=1, zorder=10)
392+
393+
return co
394+
395+
### Here starts the actual fitting
396+
## Pull uncertainty from spectrum
397+
## TODO this is very hacky right now
398+
if not hasattr(spec, "uncertainty"):
399+
logging.info("No uncertainty, assuming all are equal (continuum will probably fail)")
400+
error = np.ones(len(spec.wavelength.value))
401+
else:
402+
if isinstance(spec.uncertainty, StdDevUncertainty):
403+
error = spec.uncertainty.array
404+
else:
405+
raise ValueError("Could not understand uncertainty type: {}".format(
406+
spec.uncertainty))
407+
408+
s = np.rec.fromarrays([spec.wavelength.value,
409+
spec.flux.value,
410+
error], names=["wa","fl","er"])
411+
412+
if edges is not None:
413+
edges = list(edges)
414+
elif kind.upper() == 'QSO':
415+
if 'redshift' in kwargs:
416+
z = kwargs['redshift']
417+
elif 'redshift' in spec.meta:
418+
z = spec.meta['redshift']
419+
else:
420+
raise RuntimeError(
421+
"I need the emission redshift for kind='qso'; please\
422+
provide redshift using `redshift` keyword.")
423+
424+
divmult = kwargs.get('divmult', 2)
425+
forest_divmult = kwargs.get('forest_divmult', 2)
426+
edges = make_chunks_qso(
427+
s.wa, z, debug=debug, divmult=divmult,
428+
forest_divmult=forest_divmult)
429+
430+
431+
if ax is not None:
432+
ax.plot(s.wa, s.fl, '-', color='0.4', drawstyle='steps-mid')
433+
ax.plot(s.wa, s.er, 'g')
434+
435+
knots, indices, masked = prepare_knots(s.wa, s.fl, s.er, edges,
436+
ax=ax, debug=debug)
437+
438+
# Note this modifies knots and masked inplace
439+
co = estimate_continuum(s, knots, indices, masked, ax=ax, debug=debug)
440+
441+
if ax is not None:
442+
ax.plot(s.wa[~masked], s.fl[~masked], '.y')
443+
ymax = np.percentile(s.fl[~np.isnan(s.fl)], 95)
444+
ax.set_ylim(-0.02*ymax, 1.1*ymax)
445+
446+
return co, [k[:2] for k in knots]

0 commit comments

Comments
 (0)