Finish most of the preprocess #2
Open
AnningGao wants to merge 21 commits into ZechangSun:main from AnningGao:main
Changes shown from 6 of 21 commits.

Commits:
f263296 add the dla_cnn folder (AnningGao)
058bedb add the preprocess.py (AnningGao)
cad7d08 add catalog.py (AnningGao)
6b92a32 Add error processing (AnningGao)
30a8dbf add the rebin function and the comments for each function. (AnningGao)
901704c Merge remote-tracking branch 'upstream/main' into main (AnningGao)
7892ec0 Delete the dla_cnn but keep the necessary .py (AnningGao)
969224d optimize the organization of files (AnningGao)
c88d06f Optimize the functions again. (AnningGao)
3fef38d optimize the clip() function. (AnningGao)
d644b74 improve the way of loading wavelengths (AnningGao)
dbb1deb Add the DLA-masking function. (AnningGao)
9a3ce5e Add the modified normalize funciton (AnningGao)
8018aca Add the fuction that can calculate bolometric luminosity. (AnningGao)
570a594 Merge branch 'ZechangSun:main' into main (AnningGao)
581b966 In rebin(), change the definition of step. (AnningGao)
7757174 Change the definition of the parameter loglam_start in rebin() (AnningGao)
7195012 Fix the bug in clip() (AnningGao)
1ea9bca Improve clip() in preprocess (AnningGao)
873ffef Merge branch 'ZechangSun:main' into main (AnningGao)
0fae41a Merge branch 'ZechangSun:main' into main (AnningGao)
New file (+89 lines):

```python
import os

import numpy as np
import pandas as pd
from tqdm import tqdm

from dla_cnn.desi.DesiMock import DesiMock


# Rest-frame wavelengths of some important emission lines.
# Wavelengths here may be WRONG. Please check them before using.
LyALPHA = 1215.6701
LyBETA = 1025.7220
MgII1 = 1482.890
MgII2 = 1737.628
CIV1 = 1550
CIV2 = 1910
lams = np.array([LyBETA, LyALPHA, MgII1, CIV1, MgII2, CIV2])
names = ['LyBETA', 'LyALPHA', 'MgII1', 'CIV1', 'MgII2', 'CIV2']
lines = dict(zip(names, lams))


# Prepare the data paths: map each top-level folder to its subfolders.
def generate_suffix(prefix):
    suffix = {}
    for preid in os.listdir(prefix):
        suffix[preid] = os.listdir(prefix + preid)
    return suffix


# Generate a catalog (csv format) under each folder.
def generate_separated_catalog(prefix):
    # prefix = './desi-0.2-100/spectra-16/'  # this needs to be specified
    suffix = generate_suffix(prefix=prefix)
    for suffix1 in tqdm(suffix.keys()):
        for suffix2 in tqdm(suffix[suffix1]):
            path = prefix + suffix1 + '/' + suffix2 + '/'
            if len(os.listdir(path)) == 3:
                path_spectra = path + 'spectra-16-' + suffix2 + '.fits'
                path_truth = path + 'truth-16-' + suffix2 + '.fits'
                path_zbest = path + 'zbest-16-' + suffix2 + '.fits'
                data = DesiMock()
                data.read_fits_file(path_spectra, path_truth, path_zbest)
                total = pd.DataFrame()
                for id in data.data:
                    sline = data.get_sightline(id=id)
                    # Rest-frame wavelength coverage of this sightline.
                    wav_max = 10**np.max(sline.loglam - np.log10(1 + sline.z_qso))
                    wav_min = 10**np.min(sline.loglam - np.log10(1 + sline.z_qso))
                    info = pd.DataFrame()
                    info['id'] = np.ones(1, dtype='i8') * int(id)
                    info['z_qso'] = np.ones(1) * sline.z_qso
                    info['snr'] = np.ones(1) * sline.s2n
                    for name in names:
                        info[name] = [wav_min <= lines[name] <= wav_max]
                    total = pd.concat([total, info])
                total['file'] = np.ones(len(total), dtype='i8') * int(suffix2)
                total = total[['file', 'id', 'z_qso', 'snr', 'LyBETA', 'LyALPHA',
                               'MgII1', 'CIV1', 'MgII2', 'CIV2']]
                total.to_csv(path + 'catalog.csv', index=False)


# Delete all the generated catalogs.
def delete_all_catalog(prefix):
    suffix = generate_suffix(prefix=prefix)
    for suffix1 in suffix.keys():
        for suffix2 in suffix[suffix1]:
            path = prefix + suffix1 + '/' + suffix2 + '/'
            files = os.listdir(path)
            if len(files) == 4:
                for file in files:
                    if '.csv' in file:
                        os.remove(path + file)
    if 'catalog_total.csv' in os.listdir(prefix):
        os.remove(prefix + 'catalog_total.csv')


# Generate a total catalog.
# This should be done AFTER the catalog of each folder has been generated.
def generate_total_catalog(prefix):
    suffix = generate_suffix(prefix=prefix)
    catalog = pd.DataFrame()
    for suffix1 in suffix.keys():
        for suffix2 in suffix[suffix1]:
            path = prefix + suffix1 + '/' + suffix2 + '/'
            files = os.listdir(path)
            if len(files) == 4:
                for file in files:
                    if '.csv' in file:
                        this = pd.read_csv(path + file)
                        catalog = pd.concat([catalog, this])
    catalog.to_csv(prefix + 'catalog_total.csv')
```
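The per-sightline loop above reduces to a simple coverage test: convert the observed wavelength grid to the rest frame and check whether each emission line falls inside it. A minimal, self-contained sketch of that test (the grid, redshift, and `covered_lines` helper are illustrative, not part of the PR):

```python
import numpy as np

# A line is "covered" when its rest-frame wavelength falls inside the
# observed wavelength range divided by (1 + z_qso). All values synthetic.
lines = {'LyBETA': 1025.7220, 'LyALPHA': 1215.6701, 'CIV1': 1550.0}

def covered_lines(loglam, z_qso, lines):
    rest = 10**(loglam - np.log10(1 + z_qso))  # rest-frame wavelengths
    wav_min, wav_max = rest.min(), rest.max()
    return {name: wav_min <= lam <= wav_max for name, lam in lines.items()}

# Observed grid 3600-9800 A for a z = 2.5 quasar: rest frame ~1029-2800 A,
# so LyALPHA and CIV1 are covered but LyBETA falls just below the blue edge.
loglam = np.linspace(np.log10(3600), np.log10(9800), 1000)
flags = covered_lines(loglam, z_qso=2.5, lines=lines)
```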
Submodule dla_cnn added at 06ea0d
New file (+302 lines):
```python
from dla_cnn.desi.DesiMock import DesiMock
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from astropy.io import fits
from scipy.interpolate import interp1d
from astropy.stats import sigma_clip
from scipy.optimize import curve_fit

# Rest-frame wavelengths of some important emission lines.
LyALPHA = 1215.6701
LyBETA = 1025.7220
MgII1 = 1482.890
MgII2 = 1737.628
CIV1 = 1550
CIV2 = 1910


def overlap(sightline, data: DesiMock, id):
    '''
    Merge the spectra of the b, r, z cameras so that the result does not
    contain the overlapping regions between cameras.

    ---

    ### Parameters
    `sightline`: the sightline whose `loglam`, `flux` and `error` will be
    replaced by the merged, overlap-free spectrum.
    `data` and `id`: the original dataset from which the sightline is
    extracted by `DesiMock().get_sightline()`, and the id of this sightline.
    '''

    def get_split_point(data: DesiMock, id):
        # Use the midpoint of each overlap region as the split point.
        line_b = data.get_sightline(id=id, camera='b')
        line_z = data.get_sightline(id=id, camera='z')
        line_r = data.get_sightline(id=id, camera='r')
        split_loglam_br = np.average([np.max(line_b.loglam), np.min(line_r.loglam)])
        split_loglam_rz = np.average([np.min(line_z.loglam), np.max(line_r.loglam)])
        return split_loglam_br, split_loglam_rz

    def get_between(array, max, min, maxif=False, minif=False):
        # Indices of elements between min and max; `maxif`/`minif` control
        # whether the respective bound is inclusive.
        if maxif:
            if minif:
                return np.intersect1d(np.where(array >= min)[0], np.where(array <= max)[0])
            else:
                return np.intersect1d(np.where(array > min)[0], np.where(array <= max)[0])
        else:
            if minif:
                return np.intersect1d(np.where(array >= min)[0], np.where(array < max)[0])
            else:
                return np.intersect1d(np.where(array > min)[0], np.where(array < max)[0])

    split_loglam_br, split_loglam_rz = get_split_point(data, id)
    line_r = data.get_sightline(id=id, camera='r')
    line_b = data.get_sightline(id=id, camera='b')
    line_z = data.get_sightline(id=id, camera='z')

    loglam_r = line_r.loglam[0:np.where(line_r.loglam == np.max(line_r.loglam))[0][0]]
    indice_r = get_between(loglam_r, max=split_loglam_rz, min=split_loglam_br)
    indice_b = get_between(line_b.loglam, max=split_loglam_br, min=0, maxif=True)
    indice_z = get_between(line_z.loglam, max=np.inf, min=split_loglam_rz, minif=True)

    loglam_r, loglam_b, loglam_z = loglam_r[indice_r], line_b.loglam[indice_b], line_z.loglam[indice_z]
    flux_r, flux_b, flux_z = line_r.flux[indice_r], line_b.flux[indice_b], line_z.flux[indice_z]
    error_r, error_b, error_z = line_r.error[indice_r], line_b.error[indice_b], line_z.error[indice_z]
    sightline.loglam = np.concatenate((loglam_b, loglam_r, loglam_z))
    sightline.flux = np.concatenate((flux_b, flux_r, flux_z))
    sightline.error = np.concatenate((error_b, error_r, error_z))


# The previous version of clip(), kept for reference:
# def clip(sightline, unit, plot=False):
#     wavs = 10**sightline.loglam
#     flux = sightline.flux
#     zero_point = np.where(wavs / (1+sightline.z_qso) >= LyALPHA)[0][0]
#     i = 0
#
#     wavs_new = wavs[0:zero_point]
#     flux_new = flux[0:zero_point]
#
#     if plot:
#         sigmaup, sigmadown = np.zeros(zero_point), np.zeros(zero_point)
#
#     judge = True
#     while judge:
#         start = zero_point + i * unit
#         end = zero_point + (i+1) * unit
#         if end >= len(wavs):
#             end = len(wavs) - 1
#             judge = False
#         if start == end:
#             break
#         subwavs = wavs[start:end]
#         subflux = flux[start:end]
#         mask = np.invert(sigma_clip(subflux, sigma=3).mask)
#         flux_cliped = subflux[mask]
#         wavs_cliped = subwavs[mask]
#         wavs_new = np.concatenate((wavs_new, wavs_cliped))
#         flux_new = np.concatenate((flux_new, flux_cliped))
#         if plot:
#             sigma = np.std(subflux)
#             mean = np.average(subflux)
#             sigmaup = np.concatenate((sigmaup, np.ones_like(wavs_cliped)*(mean+3*sigma)))
#             sigmadown = np.concatenate((sigmadown, np.ones_like(wavs_cliped)*(mean-3*sigma)))
#         i = i + 1
#
#     sightline.loglam_cliped = np.log10(wavs_new)
#     sightline.flux_cliped = flux_new
#
#     if plot:
#         plt.plot(wavs, flux)
#         plt.plot(wavs_new[zero_point:], sigmaup[zero_point:])
#         plt.plot(wavs_new[zero_point:], sigmadown[zero_point:])
#         plt.axvline(LyALPHA*(1+sightline.z_qso), linestyle='--')


def clip(sightline, unit_default=100, slope=2e-3, ratio=0.5, plot=False):
    '''
    Clip the abnormal points in the spectrum.

    ---

    ### Parameters
    `sightline`: the spectrum to be clipped.
    `unit_default`: the default length of each bin used for sigma clipping.
    `slope`: the critical value deciding whether a smaller bin is applied.
    If the fitted slope of the current bin exceeds this value, a smaller
    bin is used.
    `ratio`: how small the smaller bin is compared with the default bin length.
    `plot`: if true, generate a plot showing every bin's upper and lower
    clipping limits together with the original, unclipped spectrum, which
    makes it clear which points are clipped. Often used in a Jupyter notebook.
    '''

    def line_fit(xdata, ydata):

        def linear(x, *args):
            return args[0] * x + args[1]

        popt, pcov = curve_fit(f=linear, xdata=xdata, ydata=ydata, p0=(1e-3, 0))
        return popt

    wavs = 10**sightline.loglam
    flux = sightline.flux
    error = sightline.error
    zero_point = np.where(wavs / (1 + sightline.z_qso) >= LyALPHA)[0][0]
    sightline.points_num = len(wavs)

    wavs_new = wavs[0:zero_point]
    flux_new = flux[0:zero_point]
    error_new = error[0:zero_point]

    if plot:
        sigmaup, sigmadown = np.zeros(zero_point), np.zeros(zero_point)

    unit = unit_default
    judge, start, end = True, zero_point, zero_point + unit
    while judge:

        if end >= len(wavs):
            end = len(wavs) - 1
            judge = False
        if start == end:
            break
        subwavs, subflux, suberror = wavs[start:end], flux[start:end], error[start:end]
        if end - start >= 3:
            slope_fit = line_fit(subwavs, subflux)[0]

            if np.abs(slope_fit) >= slope:
                unit = int(unit_default * ratio)
                end = start + unit
                if end >= len(wavs):
                    end = len(wavs) - 1
                    judge = False
                if start == end:
                    break
                subwavs, subflux, suberror = wavs[start:end], flux[start:end], error[start:end]

            elif np.abs(slope_fit) < slope and unit != unit_default:
                unit = unit_default
                end = start + unit
                if end >= len(wavs):
                    end = len(wavs) - 1
                    judge = False
                if start == end:
                    break
                subwavs, subflux, suberror = wavs[start:end], flux[start:end], error[start:end]

            mask = np.invert(sigma_clip(subflux, sigma=3).mask)
            flux_cliped = subflux[mask]
            wavs_cliped = subwavs[mask]
            error_cliped = suberror[mask]
        else:
            flux_cliped, wavs_cliped, error_cliped = subflux, subwavs, suberror
        wavs_new = np.concatenate((wavs_new, wavs_cliped))
        flux_new = np.concatenate((flux_new, flux_cliped))
        error_new = np.concatenate((error_new, error_cliped))
        start = start + unit
        end = end + unit
        if plot:
            sigma = np.std(subflux)
            mean = np.average(subflux)
            sigmaup = np.concatenate((sigmaup, np.ones_like(wavs_cliped) * (mean + 3 * sigma)))
            sigmadown = np.concatenate((sigmadown, np.ones_like(wavs_cliped) * (mean - 3 * sigma)))

    sightline.loglam_cliped = np.log10(wavs_new)
    sightline.flux_cliped = flux_new
    sightline.error_cliped = error_new

    if plot:
        plt.plot(wavs, flux)
        plt.plot(wavs_new[zero_point:], sigmaup[zero_point:])
        plt.plot(wavs_new[zero_point:], sigmadown[zero_point:])
        plt.axvline(LyALPHA * (1 + sightline.z_qso), linestyle='--')
        plt.show()
```
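The core of each bin iteration in clip() is astropy's `sigma_clip`. A numpy-only sketch of the same one-pass 3-sigma rejection on synthetic data (median center, global standard deviation; the data and variable names are illustrative only):

```python
import numpy as np

# Synthetic data: a flat continuum with one obvious spike.
rng = np.random.default_rng(0)
flux = np.full(100, 1.0) + 0.05 * rng.standard_normal(100)
flux[50] = 10.0                            # artificial outlier

# Keep points within 3 sigma of the median; the outlier is rejected.
center, sigma = np.median(flux), np.std(flux)
mask = np.abs(flux - center) <= 3 * sigma  # True = keep
flux_clipped = flux[mask]
```

Note that astropy's `sigma_clip` iterates this rejection until convergence, which tightens the limits after the spike is removed; the single pass above captures only the first iteration.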
```python
def get_dlnlambda(sightline):
    '''
    Generate the step length of the restframe grid used in `rebin()`.

    ----

    ### Attention
    For the mock data, this function generates the same value for every
    spectrum. I am not sure whether this characteristic will remain the
    same for the actual data.
    '''
    wavelength = 10**sightline.loglam_cliped
    pixels_number = sightline.points_num
    max_wavelength = wavelength[-1]
    min_wavelength = wavelength[0]
    dlnlambda = np.log(max_wavelength / min_wavelength) / pixels_number
    return dlnlambda
```
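The formula can be checked on a synthetic log-uniform grid: `np.log(max/min) / N` recovers the constant step in ln(lambda), up to a factor of (N-1)/N because the file's convention divides by the pixel count rather than the number of intervals:

```python
import numpy as np

# Synthetic log-uniform grid with a known step in ln(lambda).
dln_true = 1e-4
wavelength = 3600.0 * np.exp(dln_true * np.arange(1000))

# Estimate the step the same way get_dlnlambda() does.
dln_est = np.log(wavelength[-1] / wavelength[0]) / len(wavelength)
```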
```python
def rebin(sightline, loglam_start, dlnlambda, max_index: int = int(1e6)):
    '''
    Rebin to the same restframe grid.

    --------

    ### Parameters:
    `sightline`: the spectrum to be rebinned. It must have been clipped
    by `clip()`.
    `loglam_start`: the start point of the restframe grid. Usually the
    starting RESTFRAME wavelength of the spectrum with the largest redshift.
    `dlnlambda`: the step length of the restframe grid. It can be derived
    with `get_dlnlambda()`.
    `max_index`: different spectra cover different restframe wavelength
    ranges, so the restframe grid must be large enough to contain all of
    them. This parameter is the size of that grid, which is usually very
    large. You can change the default value if you think it is too big.
    '''
    def get_between(array, max, min):
        if max >= min:
            if max >= np.min(array) and min <= np.max(array):
                return np.intersect1d(np.where(array >= min)[0], np.where(array <= max)[0])
            else:
                raise ValueError('min~max out of range')
        else:
            raise ValueError('max < min, will return nothing')

    wavelength = 10**sightline.loglam_cliped / (1 + sightline.z_qso)
    flux = sightline.flux_cliped
    error = sightline.error_cliped

    max_wavelength = wavelength[-1]
    min_wavelength = wavelength[0]
    new_wavelength_total = 10**loglam_start * np.exp(dlnlambda * np.arange(max_index))
    indices = get_between(new_wavelength_total, max_wavelength, min_wavelength)
    new_wavelength = new_wavelength_total[indices]

    # The code below is adapted from a senior classmate's code :)
    npix = len(wavelength)
    wvh = (wavelength + np.roll(wavelength, -1)) / 2.
    wvh[npix - 1] = wavelength[npix - 1] + \
        (wavelength[npix - 1] - wavelength[npix - 2]) / 2.
    dwv = wvh - np.roll(wvh, 1)
    dwv[0] = 2 * (wvh[0] - wavelength[0])
    med_dwv = np.median(dwv)

    cumsum = np.cumsum(flux * dwv)
    cumvar = np.cumsum(error * dwv, dtype=np.float64)

    fcum = interp1d(wvh, cumsum, bounds_error=False)
    fvar = interp1d(wvh, cumvar, bounds_error=False)

    nnew = len(new_wavelength)
    nwvh = (new_wavelength + np.roll(new_wavelength, -1)) / 2.
    nwvh[nnew - 1] = new_wavelength[nnew - 1] + \
        (new_wavelength[nnew - 1] - new_wavelength[nnew - 2]) / 2.

    bwv = np.zeros(nnew + 1)
    bwv[0] = new_wavelength[0] - (new_wavelength[1] - new_wavelength[0]) / 2.
    bwv[1:] = nwvh

    newcum = fcum(bwv)
    newvar = fvar(bwv)

    new_fx = (np.roll(newcum, -1) - newcum)[:-1]
    new_var = (np.roll(newvar, -1) - newvar)[:-1]

    # Normalize (preserve counts and flambda)
    new_dwv = bwv - np.roll(bwv, 1)
    new_fx = new_fx / new_dwv[1:]
    # Preserve S/N (crudely)
    med_newdwv = np.median(new_dwv)
    new_var = new_var / (med_newdwv / med_dwv) / new_dwv[1:]

    # Trim leading/trailing NaNs produced by out-of-range interpolation.
    left = 0
    while np.isnan(new_fx[left]) | np.isnan(new_var[left]):
        left = left + 1
    right = len(new_fx)
    while np.isnan(new_fx[right - 1]) | np.isnan(new_var[right - 1]):
        right = right - 1

    test = np.sum((np.isnan(new_fx[left:right])) | (np.isnan(new_var[left:right])))
    assert test == 0, 'Missing value in this spectrum!'

    sightline.loglam_rebin_restframe = np.log10(new_wavelength[left:right])
    sightline.flux_rebin_restframe = new_fx[left:right]
    sightline.error_rebin_restframe = new_var[left:right]
```
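The rebinning above uses the cumulative-integral trick: integrate flux over the old pixels, interpolate the cumulative sum at the new bin edges, then difference and divide by the new bin widths, which (approximately) conserves the integrated flux. A numpy-only sketch on synthetic grids (variable names are illustrative, not from the PR; `np.interp` stands in for `scipy.interpolate.interp1d`):

```python
import numpy as np

# Old grid: 501 pixel centers spaced 2 A apart, with a smooth test spectrum.
old_wav = np.linspace(1000.0, 2000.0, 501)
flux = 1.0 + 0.3 * np.sin(old_wav / 50.0)

# Edges and widths of the old pixels, and the cumulative flux integral
# evaluated at each edge.
old_edges = np.concatenate(([old_wav[0] - 1.0],
                            (old_wav[:-1] + old_wav[1:]) / 2,
                            [old_wav[-1] + 1.0]))
old_dwv = np.diff(old_edges)
cum = np.concatenate(([0.0], np.cumsum(flux * old_dwv)))

# New (coarser) bin edges strictly inside the old range: difference the
# interpolated cumulative integral and divide by the new bin widths.
new_edges = np.linspace(1050.0, 1950.0, 101)
new_flux = np.diff(np.interp(new_edges, old_edges, cum)) / np.diff(new_edges)
```

Because the differences telescope, the total flux over the new grid equals the cumulative integral between the first and last new edge exactly, which is the conservation property the method relies on.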
Review comment: consider defining `get_split_point` and `get_between` outside the `overlap` function.