|
1 | | -from __future__ import division |
| 1 | +from __future__ import print_function, division, absolute_import |
2 | 2 |
|
3 | 3 | from astropy import modeling |
4 | 4 | from astropy.modeling import models, fitting |
| 5 | +from astropy.nddata import StdDevUncertainty |
5 | 6 | from ..spectra import Spectrum1D |
6 | 7 |
|
7 | 8 | import numpy as np |
| 9 | +from scipy import interpolate |
8 | 10 |
|
9 | 11 | import logging |
| 12 | +import warnings |
10 | 13 |
|
11 | | -__all__ = ['fit_continuum_generic'] |
| 14 | +__all__ = ['fit_continuum_generic', 'fit_continuum_linetools'] |
12 | 15 |
|
13 | 16 | def fit_continuum_generic(spectrum, |
14 | 17 | model=None, fitter=None, |
@@ -124,10 +127,320 @@ def fit_continuum_generic(spectrum, |
124 | 127 | good[sigma_difference > sigma_upper] = False |
125 | 128 | good[sigma_difference < -sigma_lower] = False |
126 | 129 |
|
127 | | - ## Update model iteratively: it is initialized at the previous fit's values |
128 | | - #model = new_model |
129 | | - |
130 | 130 | model = new_model |
131 | 131 | if full_output: |
132 | 132 | return model, good |
133 | 133 | return model |
| 134 | + |
def fit_continuum_linetools(spec, edges=None, ax=None, debug=False, kind="QSO", **kwargs):
    """
    Estimate the continuum of a spectrum by iteratively fitting spline knots.

    A direct port of the linetools continuum normalization algorithm by X Prochaska
    https://github.com/linetools/linetools/blob/master/linetools/analysis/continuum.py

    The only changes are switching to Scipy's Akima1D interpolator and changing the relevant syntax

    The spectrum is split into wavelength chunks, one knot is placed at the
    centre of each chunk at the median of the unmasked flux, and pixels lying
    more than ``nsig`` sigma below a linear interpolation through the knots are
    masked iteratively. The returned continuum is an Akima interpolation
    through the final knots, with non-positive values clipped to 0.

    Parameters
    ----------
    spec : Spectrum1D
        Spectrum to fit. ``spec.wavelength.value`` and ``spec.flux.value`` are
        read; ``spec.uncertainty`` is used when it is a `StdDevUncertainty`.
    edges : sequence of float, optional
        Wavelength chunk edges. If None, they are generated from the QSO
        chunk table, which needs a redshift (see below).
    ax : Matplotlib Axes, optional
        If not None, used to plot debugging info.
    debug : bool
        If True, print progress information.
    kind : str
        Object type. Only ``"QSO"`` (case-insensitive) is implemented.
    **kwargs
        ``redshift``: emission redshift, used when `edges` is None (falls back
        to ``spec.meta['redshift']``). ``divmult`` and ``forest_divmult``
        (both default 2): multipliers for the number of chunks per table row.

    Returns
    -------
    co : array
        Continuum evaluated at every wavelength of `spec`.
    knots : list of [x, y]
        Position of each knot that was kept.

    Raises
    ------
    ValueError
        If `kind` is not supported, `spec` is not a Spectrum1D, or the
        uncertainty is of a type other than `StdDevUncertainty`.
    RuntimeError
        If no redshift is available when it is needed, or the iteration does
        not converge.
    """
    # Explicit check rather than ``assert`` (asserts vanish under ``python -O``);
    # case-insensitive to agree with the ``kind.upper()`` convention used below
    # and with the error message, which quotes kind='qso'.
    kind_key = str(kind).upper()
    if kind_key not in ("QSO",):
        raise ValueError(
            "Unsupported kind {!r}; only 'QSO' is implemented".format(kind))
    if not isinstance(spec, Spectrum1D):
        raise ValueError('The spectrum parameter must be a Spectrum1D object')

    ### To start, we define all the functions here to avoid namespace bloat, but this can be fixed later
    ### The goal is to have the same algorithm but with flexible wavelength chunks for other object types

    def make_chunks_qso(wa, redshift, divmult=1, forest_divmult=1, debug=False):
        """ Generate a series of wavelength chunks for use by
        prepare_knots, assuming a QSO spectrum.

        Returns the chunk edges that fall between the first and last
        non-NaN wavelength in `wa`.
        """
        nan_wa = np.isnan(wa)
        if np.any(nan_wa):
            warnings.warn('Some wavelengths are NaN, ignoring these pixels.')
            wa = wa[~nan_wa]
        if len(wa) == 0:
            raise ValueError('No non-NaN wavelengths to build chunks from')

        zp1 = 1 + redshift
        # (left, right, number of chunks); left/right are scaled by 1 + z below.
        div = np.rec.fromrecords([(200. , 500. , 25),
                                  (500. , 800. , 25),
                                  (800. , 1190., 25),
                                  (1190., 1213., 4),
                                  (1213., 1230., 6),
                                  (1230., 1263., 6),
                                  (1263., 1290., 5),
                                  (1290., 1340., 5),
                                  (1340., 1370., 2),
                                  (1370., 1410., 5),
                                  (1410., 1515., 5),
                                  (1515., 1600., 15),
                                  (1600., 1800., 8),
                                  (1800., 1900., 5),
                                  (1900., 1940., 5),
                                  (1940., 2240., 15),
                                  (2240., 3000., 25),
                                  (3000., 6000., 80),
                                  (6000., 20000., 100),
                                  ], names=str('left,right,num'))

        # NOTE(review): forest_divmult only scales the first two rows; the
        # 800-1190 row is scaled by divmult -- confirm this is intended.
        div.num[2:] = np.ceil(div.num[2:] * divmult)
        div.num[:2] = np.ceil(div.num[:2] * forest_divmult)
        div.left *= zp1
        div.right *= zp1
        if debug:
            print(div.tolist())
        # Drop each row's right edge: it is the next row's left edge.
        edges = np.concatenate(
            [np.linspace(left, right, n + 1)[:-1] for left, right, n in div])

        i0, i1, i2 = edges.searchsorted([wa[0], 1210*zp1, wa[-1]])
        if debug:
            print(i0,i1,i2)
        return edges[i0:i2]

    def update_knots(knots, indices, fl, masked):
        """ Calculate the y position of each knot.

        Updates `knots` inplace.

        Parameters
        ----------
        knots: list of [xpos, ypos, bool] with length N
          bool says whether the knot should be kept unchanged.
        indices: list of (i0,i1) index pairs
          The start and end indices into fl and masked of each
          spectrum chunk (xpos of each knot are the chunk centres).
        fl, masked: arrays shape (M,)
          The flux, and boolean arrays showing which pixels are
          masked.
        """
        iy, iflag = 1, 2
        for knot, (start, stop) in zip(knots, indices):
            if knot[iflag]:
                continue
            chunk_fl = fl[start:stop]
            knot[iy] = np.median(chunk_fl[~masked[start:stop]])

    def linear_co(wa, knots):
        """linear interpolation through the spline knots.

        Add extra points on either end to give
        a nice slope at the end points."""
        wavc, mfl = list(zip(*knots))[:2]
        extwavc = ([wavc[0] - (wavc[1] - wavc[0])] + list(wavc) +
                   [wavc[-1] + (wavc[-1] - wavc[-2])])
        extmfl = ([mfl[0] - (mfl[1] - mfl[0])] + list(mfl) +
                  [mfl[-1] + (mfl[-1] - mfl[-2])])
        return np.interp(wa, extwavc, extmfl)

    def Akima_co(wa, knots):
        """Akima interpolation through the spline knots.

        NOTE(review): scipy's Akima1DInterpolator presumably returns NaN
        outside the first/last knot by default (no extrapolation) -- confirm
        that is acceptable at the ends of the spectrum.
        """
        x, y, _ = zip(*knots)
        spl = interpolate.Akima1DInterpolator(x, y)
        return spl(wa)

    def remove_bad_knots(knots, indices, masked, fl, er, debug=False):
        """ Remove knots in chunks without any good pixels, or whose median
        flux is not above twice the median error. Modifies `knots` and
        `indices` inplace."""
        idelknot = []
        for iknot, (i, j) in enumerate(indices):
            if np.all(masked[i:j]) or np.median(fl[i:j]) <= 2*np.median(er[i:j]):
                if debug:
                    print('Deleting knot', iknot, 'near {:.1f} Angstroms'.format(
                        knots[iknot][0]))
                idelknot.append(iknot)

        # Delete from the end so the remaining indices stay valid.
        for i in reversed(idelknot):
            del knots[i]
            del indices[i]

    def chisq_chunk(model, fl, er, masked, indices, knots, chithresh=1.5):
        """ Calc chisq per chunk, update knots flags inplace if chisq is
        acceptable. """
        FLAG = 2
        for iknot, (i1, i2) in enumerate(indices):
            if knots[iknot][FLAG]:
                continue

            keep = ~masked[i1:i2]
            f1 = fl[i1:i2][keep]
            if len(f1) == 0:
                # No usable pixels: the reduced chisq would be 0/0, so the
                # knot is simply left unflagged.
                continue
            e1 = er[i1:i2][keep]
            mod1 = model[i1:i2][keep]
            resid = (mod1 - f1) / e1
            rchisq = np.sum(resid*resid) / len(f1)
            if rchisq < chithresh:
                knots[iknot][FLAG] = True

    def prepare_knots(wa, fl, er, edges, ax=None, debug=False):
        """ Make initial knots for the continuum estimation.

        Parameters
        ----------
        wa, fl, er : arrays
          Wavelength, flux, error.
        edges : The edges of the wavelength chunks. Splines knots are to be
          placed at the centre of these chunks.
        ax : Matplotlib Axes
          If not None, use to plot debugging info.

        Returns
        -------
        knots, indices, masked
          * knots: A list of [x, y, flag] lists giving the x and y position
            of each knot.
          * indices: A list of tuples (i,j) giving the start and end index
            of each chunk.
          * masked: An array the same shape as wa.
        """
        edge_idx = wa.searchsorted(edges)
        indices = list(zip(edge_idx[:-1], edge_idx[1:]))
        knots = [[0.5*(w1 + w2), 0, False]
                 for w1, w2 in zip(edges[:-1], edges[1:])]

        # Pixels without a positive error (including NaN errors) start masked.
        masked = ~(er > 0)

        # remove bad knots
        remove_bad_knots(knots, indices, masked, fl, er, debug=debug)

        if ax is not None:
            yedge = np.interp(edges, wa, fl)
            ax.vlines(edges, 0, yedge + 100, color='c', zorder=10)

        # set the knot flux values
        update_knots(knots, indices, fl, masked)

        if ax is not None:
            x, y = list(zip(*knots))[:2]
            ax.plot(x, y, 'o', mfc='none', mec='c', ms=10, mew=1, zorder=10)

        return knots, indices, masked

    def unmask(masked, indices, wa, fl, er, minpix=3):
        """ Forces each chunk to use at least minpix pixels.

        Sometimes all pixels can become masked in a chunk. We don't want
        this! This forces there to be at least minpix pixels used in each
        chunk. Modifies `masked` inplace; `wa` is not used.
        """
        for i, j in indices:
            if np.sum(~masked[i:j]) < minpix:
                # Unmask the minpix highest-flux pixels that have a
                # positive error.
                has_err = er[i:j] > 0
                isort = np.argsort(fl[i:j][has_err])
                ind1 = np.arange(i, j)[has_err][isort[-minpix:]]
                masked[ind1] = False

    def estimate_continuum(s, knots, indices, masked, ax=None, maxiter=1000,
                           nsig=1.5, debug=False):
        """ Iterate to estimate the continuum.

        `s` must expose ``wa``, ``fl`` and ``er`` arrays. `knots` and
        `masked` are modified inplace. Raises RuntimeError if the loop
        neither converges nor stops masking within the iteration limit.
        """
        count = 0
        while True:
            if debug:
                print('iteration', count)
            update_knots(knots, indices, s.fl, masked)
            model = linear_co(s.wa, knots)
            model_a = Akima_co(s.wa, knots)
            chisq_chunk(model_a, s.fl, s.er, masked,
                        indices, knots, chithresh=1)
            flags = list(zip(*knots))[-1]
            if np.all(flags):
                if debug:
                    print('All regions have satisfactory fit, stopping')
                break
            # remove outliers: pixels well below the linear model
            resid = (model - s.fl) / s.er
            oldmasked = masked.copy()
            masked[(resid > nsig) & ~masked] = True
            unmask(masked, indices, s.wa, s.fl, s.er)
            if np.all(oldmasked == masked):
                if debug:
                    print('No further points masked, stopping')
                break
            if count > maxiter:
                raise RuntimeError('Exceeded maximum iterations')

            count += 1

        co = Akima_co(s.wa, knots)
        co[co <= 0] = 0

        if ax is not None:
            ax.plot(s.wa, linear_co(s.wa, knots), color='0.7', lw=2)
            ax.plot(s.wa, co, 'k', lw=2, zorder=10)
            x, y = list(zip(*knots))[:2]
            ax.plot(x, y, 'o', mfc='none', mec='k', ms=10, mew=1, zorder=10)

        return co

    ### Here starts the actual fitting
    ## Pull uncertainty from spectrum
    ## TODO this is very hacky right now
    # The attribute may exist but be None, so test the value rather than
    # ``hasattr``; otherwise the fallback below can never be reached.
    uncertainty = getattr(spec, "uncertainty", None)
    if uncertainty is None:
        logging.info("No uncertainty, assuming all are equal (continuum will probably fail)")
        error = np.ones(len(spec.wavelength.value))
    elif isinstance(uncertainty, StdDevUncertainty):
        error = uncertainty.array
    else:
        raise ValueError("Could not understand uncertainty type: {}".format(
            uncertainty))

    s = np.rec.fromarrays([spec.wavelength.value,
                           spec.flux.value,
                           error], names=["wa","fl","er"])

    if edges is not None:
        edges = list(edges)
    elif kind_key == 'QSO':
        if 'redshift' in kwargs:
            z = kwargs['redshift']
        elif 'redshift' in spec.meta:
            z = spec.meta['redshift']
        else:
            raise RuntimeError(
                "I need the emission redshift for kind='qso'; please "
                "provide redshift using `redshift` keyword.")

        divmult = kwargs.get('divmult', 2)
        forest_divmult = kwargs.get('forest_divmult', 2)
        edges = make_chunks_qso(
            s.wa, z, debug=debug, divmult=divmult,
            forest_divmult=forest_divmult)

    if ax is not None:
        ax.plot(s.wa, s.fl, '-', color='0.4', drawstyle='steps-mid')
        ax.plot(s.wa, s.er, 'g')

    knots, indices, masked = prepare_knots(s.wa, s.fl, s.er, edges,
                                           ax=ax, debug=debug)

    # Note this modifies knots and masked inplace
    co = estimate_continuum(s, knots, indices, masked, ax=ax, debug=debug)

    if ax is not None:
        ax.plot(s.wa[~masked], s.fl[~masked], '.y')
        ymax = np.percentile(s.fl[~np.isnan(s.fl)], 95)
        ax.set_ylim(-0.02*ymax, 1.1*ymax)

    # Drop the per-knot flag; callers only get [x, y].
    return co, [k[:2] for k in knots]
0 commit comments