Add the catalog generation and part of preprocessing #1

Open · wants to merge 7 commits into base: desi

Changes from all commits
7 changes: 7 additions & 0 deletions QFA/config.py
@@ -15,6 +15,7 @@

# Base config files
_C.BASE = ['']
+_C.TYPE = "train"
_C.GPU = 0

#------------------
@@ -46,6 +47,7 @@
#------------------
_C.MODEL = CN()
_C.MODEL.NH = 8
+_C.MODEL.TAU = 'becker'
_C.MODEL.RESUME = ''


@@ -61,6 +63,7 @@
_C.TRAIN.WINDOW_LENGTH_FOR_MU = 16



def _update_config_from_file(config, cfg_file):
    config.defrost()
    with open(cfg_file, 'r') as f:
@@ -131,6 +134,10 @@ def _check_args(name):
        config.DATA.NPROCS = args.nprocs
    if _check_args('validation'):
        config.DATA.VALIDATION = args.validation
+    if _check_args('tau'):
+        config.MODEL.TAU = args.tau
+    if _check_args('type'):
+        config.TYPE = args.type


    config.freeze()
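For reviewers who want to poke at the two new knobs, here is a self-contained sketch of the yacs pattern this file follows. Only the keys `_C.TYPE` and `_C.MODEL.TAU` come from this PR; the argparse flags and the inline merge are stand-ins for the repo's own `update_config` path, whose full body is outside this diff.

# Sketch of the new --tau/--type options reaching the config; the argparse
# setup is an assumption, only the config keys come from this PR.
import argparse
from yacs.config import CfgNode as CN

_C = CN()
_C.TYPE = "train"
_C.MODEL = CN()
_C.MODEL.TAU = 'becker'

parser = argparse.ArgumentParser()
parser.add_argument('--tau', type=str, default=None)
parser.add_argument('--type', type=str, default=None)
args = parser.parse_args([])          # empty list: use defaults for the demo

config = _C.clone()
config.defrost()
if args.tau is not None:
    config.MODEL.TAU = args.tau       # mirrors the _check_args('tau') branch
if args.type is not None:
    config.TYPE = args.type           # mirrors the _check_args('type') branch
config.freeze()
print(config.MODEL.TAU, config.TYPE)  # -> becker train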
84 changes: 61 additions & 23 deletions QFA/dataloader.py
@@ -2,11 +2,13 @@
import torch
import numpy as np
import pandas as pd
+import warnings
from tqdm import tqdm
import multiprocessing
from .utils import smooth
-from .utils import tau as default_tau
-from typing import Tuple, Callable
+from .utils import tau as taufunc
+from typing import Tuple, Callable, List
+from functools import partial
from yacs.config import CfgNode as CN


@@ -15,7 +17,7 @@

def _read_npz_file(path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, str]:
"""
load spectra from npz files
load spectra from npz file
NOTE:
(1) all spectra should have same wavelength grid
(2) spectra are preprocessed as in the paper
@@ -25,21 +27,32 @@ def _read_npz_file(path: str)->Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
    flux, error, z = file['flux'], file['error'], float(file['z'])
    mask = (flux!=-999.)&(error!=-999.)
    file.close()
-    return flux, error, mask, z
+    return flux, error, mask, z, path


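`_read_npz_file` pins down an on-disk contract worth spelling out for reviewers. Below is a sketch of a conforming file: the keys ('flux', 'error', 'z') and the -999. bad-pixel sentinel come from the reader above, while the 4096-pixel grid length is an invented placeholder.

# Sketch: create a spectrum file in the layout _read_npz_file expects.
import numpy as np

npix = 4096
flux = np.random.rand(npix).astype(np.float32)
error = np.full(npix, 0.05, dtype=np.float32)
flux[:10] = -999.      # bad pixels carry the sentinel in both arrays
error[:10] = -999.
np.savez('spec-example.npz', flux=flux, error=error, z=2.5)

with np.load('spec-example.npz') as f:
    mask = (f['flux'] != -999.) & (f['error'] != -999.)   # as in the reader
assert mask.sum() == npix - 10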
-def _read_from_catalog(flux, error, mask, zqso, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs):
-    catalog = pd.read_csv(catalog)
-    criteria = (catalog['snr']>=snr_min) & (catalog['snr']<=snr_max) & (catalog['z']>=z_min) & (catalog['z']<=z_max) & (catalog['num_mask']<=num_mask)
-    files = np.random.choice(catalog['file'][criteria].values, size=(num,), replace=(np.sum(criteria)<num))
-    paths = [os.path.join(data_dir, x) for x in files]
+def _read_npz_files(flux: List[np.ndarray], error: List[np.ndarray], mask: List[np.ndarray], zqso: List[float], pathlist: List[str], paths: List[str], nprocs: int) -> None:
+    """
+    load spectra from npz files
+    """
    with multiprocessing.Pool(nprocs) as p:
        data = p.map(_read_npz_file, paths)
-    for f, e, m, z in tqdm(data):
+    for f, e, m, z, p in tqdm(data):
        flux.append(f)
        error.append(e)
        mask.append(m)
        zqso.append(z)
+        pathlist.append(p)


+def _read_from_catalog(flux, error, mask, zqso, pathlist, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs, output_dir, prefix='train'):
+    catalog = pd.read_csv(catalog)
+    criteria = (catalog['snr']>=snr_min) & (catalog['snr']<=snr_max) & (catalog['z']>=z_min) & (catalog['z']<=z_max) & (catalog['num_mask']<=num_mask)
+    files = np.random.choice(catalog['file'][criteria].values, size=(num,), replace=(np.sum(criteria)<num))
+    if not os.path.exists(output_dir):
+        os.mkdir(output_dir)
+    pd.Series(files).to_csv(os.path.join(output_dir, f'{prefix}-catalog.csv'), header=False, index=False)
+    paths = [os.path.join(data_dir, x) for x in files]
+    _read_npz_files(flux, error, mask, zqso, pathlist, paths, nprocs)

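On the catalog-generation side, `_read_from_catalog` filters on four columns and now also records the drawn file list as `{prefix}-catalog.csv` under `output_dir`. A sketch of a conforming input catalog follows; the column names come from the code, the rows are invented.

# Sketch: a catalog CSV with the columns _read_from_catalog filters on.
import pandas as pd

pd.DataFrame({
    'file':     ['spec-0001.npz', 'spec-0002.npz'],   # relative to DATA_DIR
    'snr':      [5.2, 11.8],                          # kept if SNR_MIN <= snr <= SNR_MAX
    'z':        [2.4, 3.1],                           # kept if Z_MIN <= z <= Z_MAX
    'num_mask': [12, 3],                              # kept if num_mask <= NUM_MASK
}).to_csv('catalog.csv', index=False)

One detail worth noting: `np.random.choice` switches to sampling with replacement only when fewer than `num` spectra survive the cuts, so the requested sample size is always met.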

class Dataloader(object):
@@ -48,6 +61,7 @@ def __init__(self, config: CN):
        self.wav_grid = 10**np.arange(np.log10(config.DATA.LAMMIN), np.log10(config.DATA.LAMMAX), config.DATA.LOGLAM_DELTA)
        self.Nb = np.sum(self.wav_grid<_lya_peak)
        self.Nr = len(self.wav_grid) - self.Nb
+        self.type = config.TYPE

        self.batch_size = config.DATA.BATCH_SIZE

@@ -56,34 +70,45 @@ def __init__(self, config: CN):
        self.mask = []
        self.zabs = []
        self.zqso = []
+        self.pathlist = []

+        if self.type == 'train':
+            print("=> Load Data...")
+            _read_from_catalog(self.flux, self.error, self.mask, self.zqso, self.pathlist, config.DATA.CATALOG,
+                config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
+                config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'train')
+
+            if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR) and config.DATA.VALIDATION:
+                print("=> Load Validation Data...")
+                _read_from_catalog(self.flux, self.error, self.mask, self.zqso, self.pathlist, config.DATA.VALIDATION_CATALOG,
+                    config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX,
+                    config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'validation')

-        print("=> Load Data...")
-        _read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.CATALOG,
-            config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
-            config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS)
+        elif self.type == 'predict':
+            print("=> Load Data...")
+            paths = pd.read_csv(config.DATA.CATALOG).values.squeeze()
+            paths = list(map(lambda x: os.path.join(config.DATA.DATA_DIR, x), paths))
+            _read_npz_files(self.flux, self.error, self.mask, self.zqso, self.pathlist, paths, config.DATA.NPROCS)

-        if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR) and config.DATA.VALIDATION:
-            print("=> Load Validation Data...")
-            _read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.VALIDATION_CATALOG,
-                config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX,
-                config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS)
+        else:
+            raise NotImplementedError("TYPE should be in ['train', 'predict']!")


        self.flux = np.array(self.flux)
        self.error = np.array(self.error)
        self.zqso = np.array(self.zqso)
        self.mask = np.array(self.mask)
+        self.pathlist = np.array(self.pathlist)
        self.zabs = (self.zqso + 1).reshape(-1, 1)*self.wav_grid[:self.Nb]/1215.67 - 1


        self.cur = 0
        self._device = None
-        self._tau = default_tau
-        self.validation_dir = None
+        self._tau = partial(taufunc, which=config.MODEL.TAU)
        self.data_size = self.flux.shape[0]

        s = np.hstack((np.exp(1*self._tau(self.zabs)), np.ones((self.data_size, self.Nr), dtype=float)))
-        self._mu = np.sum(self.flux*s, axis=0)/np.sum(self.flux!=0., axis=0)
+        self._mu = np.sum(self.flux*s*self.mask, axis=0)/np.sum(self.flux!=-999., axis=0)
        self._mu = smooth(self._mu, window_len=config.TRAIN.WINDOW_LENGTH_FOR_MU)

    def have_next_batch(self):
@@ -93,6 +118,7 @@ def have_next_batch(self):
        Returns:
            sig (bool): whether this dataloader has a next batch
        """
+        if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
        return self.cur < self.data_size

    def next_batch(self):
@@ -102,6 +128,7 @@ def next_batch(self):
        Returns:
            delta, error, redshift, mask (torch.tensor): batch data
        """
+        if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
        start = self.cur
        end = self.cur + self.batch_size if self.cur + self.batch_size < self.data_size else self.data_size
        self.cur = end
@@ -117,6 +144,7 @@ def sample(self):
        Returns:
            delta, error, redshift, mask (torch.tensor): sampled data
        """
+        if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
        sig = np.random.randint(0, self.data_size, size=(self.batch_size, ))
        s = np.hstack((np.exp(-1.*self._tau(self.zabs[sig])), np.ones((self.batch_size, self.Nr), dtype=float)))
        return torch.tensor(self.flux[sig]-self._mu*s, dtype=torch.float32).to(self._device),\
@@ -127,6 +155,7 @@ def rewind(self):
"""
shuffle all the data and reset the dataloader
"""
if self.type == 'test': warnings.warn('dataloader is in test mode...')
idx = np.arange(self.data_size)
np.random.shuffle(idx)
self.cur = 0
@@ -135,6 +164,7 @@ def rewind(self):
        self.zqso = self.zqso[idx]
        self.zabs = self.zabs[idx]
        self.mask = self.mask[idx]
+        self.pathlist = self.pathlist[idx]

    def set_tau(self, tau: Callable[[torch.Tensor, ], torch.Tensor]) -> None:
        """
@@ -148,6 +178,14 @@ def set_device(self, device: torch.device)->None:
"""
self._device = device

+    def __len__(self):
+        return len(self.flux)
+
+    def __getitem__(self, idx):
+        return torch.tensor(self.flux[idx], dtype=torch.float32).to(self._device),\
+            torch.tensor(self.error[idx], dtype=torch.float32).to(self._device), torch.tensor(self.zabs[idx], dtype=torch.float32).to(self._device), \
+            torch.tensor(self.mask[idx], dtype=bool).to(self._device), self.pathlist[idx]

    @property
    def mu(self):
        return self._mu
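Because the class is a cursor-based loader rather than a torch `DataLoader`, a usage sketch may save reviewers a trip to the call sites. `config` stands for a frozen yacs node built as in `QFA/config.py` (an assumption of this sketch); everything else names only methods visible in this diff.

# Sketch: one pass over the data with the cursor API (train mode) and
# per-spectrum access (predict mode).
import torch

loader = Dataloader(config)               # config.TYPE selects the branch above
loader.set_device(torch.device('cpu'))

if config.TYPE == 'train':
    loader.rewind()                       # shuffles flux/error/zqso/zabs/mask/pathlist together
    while loader.have_next_batch():
        delta, error, zabs, mask = loader.next_batch()
        ...                               # one optimizer step per batch
else:                                     # 'predict': use the new __len__/__getitem__
    for i in range(len(loader)):
        flux, error, zabs, mask, path = loader[i]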
18 changes: 12 additions & 6 deletions QFA/model.py
@@ -11,12 +11,14 @@
import os
import numpy as np
from .utils import MatrixInverse, MatrixLogDet, tauHI, omega_func
-from .utils import tau as default_tau
+from .utils import tau as taufunc
+from functools import partial

from torch.nn import functional as F


log2pi = 1.8378770664093453
+default_tau = partial(taufunc, which='becker')


class QFA(object):
@@ -129,9 +131,9 @@ def loglikelihood_and_gradient_for_single_spectra(self, delta: torch.Tensor, err
        diag = Psi + omega + masked_error*masked_error
        invSigma = MatrixInverse(F, diag, self.device)
        logDet = MatrixLogDet(F, diag, self.device)
-        loglikelihood = 0.5*(masked_delta.T @ invSigma @ masked_delta + Npix * log2pi + logDet)
-        masked_delta = masked_delta.reshape((-1, 1))
-        partialSigma = 0.5*(invSigma-invSigma@masked_delta@masked_delta.T@invSigma)
+        masked_delta = masked_delta[:, None]
+        loglikelihood = 0.5*(masked_delta.mT @ invSigma @ masked_delta + Npix * log2pi + logDet)
+        partialSigma = 0.5*(invSigma-invSigma@masked_delta@masked_delta.mT@invSigma)
        partialF = 2*diagA@partialSigma@diagA@F
        diagPartialSigma = torch.diag(partialSigma)
        partialPsi = A*diagPartialSigma*A
@@ -170,11 +172,12 @@ def prediction_for_single_spectra(self, flux: torch.Tensor, error: torch.Tensor,
        diag = Psi + omega + masked_error*masked_error
        invSigma = MatrixInverse(F, diag, self.device)
        logDet = MatrixLogDet(F, diag, self.device)
-        loglikelihood = 0.5*(masked_delta.T @ invSigma @ masked_delta + Npix * log2pi + logDet)
+        masked_delta = masked_delta[:, None]
+        loglikelihood = 0.5*(masked_delta.mT @ invSigma @ masked_delta + Npix * log2pi + logDet)
        Sigma_e = torch.diag(1./diag)
        hcov = torch.linalg.inv(torch.eye(self.Nh, dtype=torch.float).to(self.device) + F.T@Sigma_e@F)
        hmean = hcov@F.T@Sigma_e@masked_delta
-        return loglikelihood, hmean, hcov, self.F@hmean + self.mu, torch.diag(self.F@hcov@self.F.T)**0.5
+        return loglikelihood, hmean, hcov, (self.F@hmean).squeeze() + self.mu, torch.diag(self.F@hcov@self.F.T)**0.5


    def train(self, optimizer, dataloader, n_epochs, output_dir="./result", save_interval=5, smooth_interval=5, quiet=False, logger=None):
@@ -193,6 +196,9 @@ def train(self, optimizer, dataloader, n_epochs, output_dir="./result", save_int
        Returns:
            None
        """
+        if not os.path.exists(output_dir):
+            os.mkdir(output_dir)
+        output_dir = os.path.join(output_dir, 'checkpoints')
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
        self.mu = torch.tensor(dataloader.mu, dtype=torch.float32).to(self.device)
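Both per-spectrum methods lean on the same structure: the pixel covariance is low-rank plus diagonal, Sigma = F F^T + diag(Psi + omega + error^2), which `MatrixInverse` and `MatrixLogDet` presumably exploit via a Woodbury-style identity (their bodies are outside this diff). Below is a dense reference sketch, useful for unit-testing the likelihood and the Sigma-gradient used in the training path; the shapes are arbitrary placeholders.

# Sketch: check the low-rank-plus-diagonal likelihood against dense algebra.
import torch

Npix, Nh = 512, 8
F = torch.randn(Npix, Nh)                # factor loadings
diag = torch.rand(Npix) + 0.1            # Psi + omega + error**2, all positive
delta = torch.randn(Npix, 1)             # one mean-subtracted spectrum

Sigma = F @ F.T + torch.diag(diag)       # dense covariance
invSigma = torch.linalg.inv(Sigma)
logDet = torch.logdet(Sigma)

log2pi = 1.8378770664093453
nll = 0.5 * (delta.mT @ invSigma @ delta + Npix * log2pi + logDet)

# Gradient wrt Sigma of the negative log-likelihood above:
# d(log|Sigma|)/dSigma = Sigma^-1 and d(delta^T Sigma^-1 delta)/dSigma
# = -Sigma^-1 delta delta^T Sigma^-1, matching partialSigma in the code.
partialSigma = 0.5 * (invSigma - invSigma @ delta @ delta.mT @ invSigma)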
59 changes: 58 additions & 1 deletion QFA/utils.py
@@ -92,7 +92,7 @@ def omega_func(z: torch.Tensor, tau0: torch.float32, beta: torch.float32, c0: to
    return root*root


-def tau(z: torch.Tensor)->torch.Tensor:
+def _tau_becker(z: torch.Tensor)->torch.Tensor:
    """
    mean optical depth measured by Becker et al. 2012 [https://arxiv.org/abs/1208.2584]
    ----------------------------------------------------------------------
@@ -106,6 +106,63 @@ def tau(z: torch.Tensor)->torch.Tensor:
    return tau0 * ((1+z)/(1+z0)) ** beta + C


+def _tau_fg(z: torch.Tensor)->torch.Tensor:
+    """
+    mean optical depth measured by Faucher-Giguere et al. 2008 [https://iopscience.iop.org/article/10.1086/588648]
+    -------------------------------------------------------------------------
+    Args:
+        z (torch.Tensor (shape=(N, ), dtype=torch.float32)): redshift array
+
+    Returns:
+        effective optical depth: (torch.Tensor (shape=(N, ), dtype=torch.float32))
+    """
+    tau0, beta = 0.0018, 3.92
+    return tau0 * (1+z) ** beta
+
+
+def _tau_kamble(z: torch.Tensor)->torch.Tensor:
+    """
+    mean optical depth measured by Kamble et al. 2020 [https://ui.adsabs.harvard.edu/abs/2020ApJ...892...70K/abstract]
+    ------------------------------------------------------------------------------------------
+    Args:
+        z (torch.Tensor (shape=(N, ), dtype=torch.float32)): redshift array
+
+    Returns:
+        effective optical depth: (torch.Tensor (shape=(N, ), dtype=torch.float32))
+    """
+    tau0, beta = 5.54*1e-3, 3.182
+    return tau0 * (1+z) ** beta
+
+
+def _tau_mock(z: torch.Tensor)->torch.Tensor:
+    """
+    mean optical depth from mock literature, Bautista et al. 2015 [https://iopscience.iop.org/article/10.1088/1475-7516/2015/05/060]
+    """
+    return 0.2231435513142097*((1+z)/3.25)**3.2
+
+
+def tau(z: torch.Tensor, which: Optional[str]='becker') -> torch.Tensor:
+    """
+    mean optical depth function
+    ---------------------------------------------------
+    Args:
+        z (torch.Tensor (shape=(N, ), dtype=torch.float32)): redshift array
+        which (str): which measurement to use ['becker', 'fg', 'kamble', 'mock']
+    Returns:
+        effective optical depth: (torch.Tensor (shape=(N, ), dtype=torch.float32))
+    """
+    if which == 'becker':
+        return _tau_becker(z)
+    elif which == 'fg':
+        return _tau_fg(z)
+    elif which == 'kamble':
+        return _tau_kamble(z)
+    elif which == 'mock':
+        return _tau_mock(z)
+    else:
+        raise NotImplementedError("currently available mean optical depth functions: ['becker', 'fg', 'kamble', 'mock']")

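`Dataloader.__init__` pins the measurement once via `partial(taufunc, which=config.MODEL.TAU)` and thereafter calls a plain one-argument function. The same pattern in isolation:

# Sketch: pin the measurement once, then call tau as a one-argument function.
from functools import partial
import torch
from QFA.utils import tau

z = torch.linspace(2.0, 4.0, 5)
for which in ('becker', 'fg', 'kamble', 'mock'):
    tau_z = partial(tau, which=which)    # as Dataloader.__init__ does
    print(which, tau_z(z))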

def smooth(s: np.ndarray, window_len: Optional[int]=32):
"""Smooth curve s with corresponding window length

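`smooth` itself is untouched by this PR and its body is not shown here. For intuition only, a moving-average stand-in with the same signature; this is an assumption, not the repo's implementation.

# Stand-in only: QFA.utils.smooth's real body is outside this diff.
import numpy as np

def smooth_stand_in(s: np.ndarray, window_len: int = 32) -> np.ndarray:
    kernel = np.ones(window_len) / window_len
    return np.convolve(s, kernel, mode='same')   # simple moving average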
1 change: 1 addition & 0 deletions README.md
@@ -2,6 +2,7 @@

**An Unsupervised and Probabilistic Quasar Continuum Prediction Algorithm with Latent Factor Analysis**

+![](https://img.shields.io/github/license/zechangsun/QFA)


Welcome! This is the source code for *Quasar Factor Analysis*. 😉