Commit: update for better user-interface

ZechangSun committed Feb 28, 2023
1 parent ead0ec4 commit 16ade10
Showing 5 changed files with 127 additions and 43 deletions.
4 changes: 4 additions & 0 deletions QFA/config.py
@@ -15,6 +15,7 @@

# Base config files
_C.BASE = ['']
_C.TYPE = "train"
_C.GPU = 0

#------------------
@@ -62,6 +63,7 @@
_C.TRAIN.WINDOW_LENGTH_FOR_MU = 16



def _update_config_from_file(config, cfg_file):
config.defrost()
with open(cfg_file, 'r') as f:
@@ -134,6 +136,8 @@ def _check_args(name):
config.DATA.VALIDATION = args.validation
if _check_args('tau'):
config.MODEL.TAU = args.tau
if _check_args('type'):
config.TYPE = args.type


config.freeze()
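For context, the new TYPE key follows the usual yacs precedence: the built-in default above, then the config file, then the command-line override. A minimal sketch of that behavior (values illustrative):

from yacs.config import CfgNode as CN

cfg = CN()
cfg.TYPE = "train"                        # built-in default, as _C.TYPE above
cfg.merge_from_list(["TYPE", "predict"])  # override, analogous to the _check_args('type') branch
print(cfg.TYPE)                           # -> "predict"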
81 changes: 58 additions & 23 deletions QFA/dataloader.py
@@ -2,11 +2,13 @@
import torch
import numpy as np
import pandas as pd
import warnings
from tqdm import tqdm
import multiprocessing
from .utils import smooth
from .utils import tau as default_tau
from typing import Tuple, Callable
from .utils import tau as taufunc
from typing import Tuple, Callable, List
from functools import partial
from yacs.config import CfgNode as CN


@@ -15,7 +17,7 @@

def _read_npz_file(path: str)->Tuple[np.ndarray, np.ndarray, np.ndarray, float, str]:
"""
load spectra from npz files
load a single spectrum from an npz file
NOTE:
(1) all spectra should have same wavelength grid
(2) spectra are preprocessed as in the paper
@@ -25,24 +27,32 @@ def _read_npz_file(path: str)->Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
flux, error, z = file['flux'], file['error'], float(file['z'])
mask = (flux!=-999.)&(error!=-999.)
file.close()
return flux, error, mask, z
return flux, error, mask, z, path
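For reference, the reader above assumes each npz file carries at least the keys 'flux', 'error' and 'z', with -999. flagging bad pixels; a hypothetical writer-side sketch:

import numpy as np

flux = np.random.rand(4000)      # spectrum on the common wavelength grid
error = np.full(4000, 0.05)
flux[:10] = error[:10] = -999.   # -999. marks pixels excluded by the mask above
np.savez('spec-example.npz', flux=flux, error=error, z=2.85)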


def _read_npz_files(flux: List[np.ndarray], error: List[np.ndarray], mask: List[np.ndarray], zqso: List[float], pathlist: List[str], paths: List[str], nprocs: int)->None:
"""
load spectra from npz files
"""
with multiprocessing.Pool(nprocs) as pool:
data = pool.map(_read_npz_file, paths)
for f, e, m, z, pth in tqdm(data):
flux.append(f)
error.append(e)
mask.append(m)
zqso.append(z)
pathlist.append(pth)


def _read_from_catalog(flux, error, mask, zqso, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs, output_dir, prefix='train'):
def _read_from_catalog(flux, error, mask, zqso, pathlist, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs, output_dir, prefix='train'):
catalog = pd.read_csv(catalog)
criteria = (catalog['snr']>=snr_min) & (catalog['snr']<=snr_max) & (catalog['z']>=z_min) & (catalog['z']<=z_max) & (catalog['num_mask']<=num_mask)
files = np.random.choice(catalog['file'][criteria].values, size=(num,), replace=(np.sum(criteria)<num))
if not os.path.exists(output_dir):
os.mkdir(output_dir)
pd.Series(files).to_csv(os.path.join(output_dir, f'{prefix}-catalog.csv'), header=False, index=False)
paths = [os.path.join(data_dir, x) for x in files]
with multiprocessing.Pool(nprocs) as p:
data = p.map(_read_npz_file, paths)
for f, e, m, z in tqdm(data):
flux.append(f)
error.append(e)
mask.append(m)
zqso.append(z)
_read_npz_files(flux, error, mask, zqso, pathlist, paths, nprocs)
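For reference, the selection above expects the catalog csv to provide at least the columns 'file', 'snr', 'z' and 'num_mask'; a hypothetical one-row catalog for illustration:

import pandas as pd

row = {'file': 'spec-example.npz', 'snr': 8.3, 'z': 2.95, 'num_mask': 57}
pd.DataFrame([row]).to_csv('catalog.csv', index=False)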


class Dataloader(object):
@@ -51,6 +61,7 @@ def __init__(self, config: CN):
self.wav_grid = 10**np.arange(np.log10(config.DATA.LAMMIN), np.log10(config.DATA.LAMMAX), config.DATA.LOGLAM_DELTA)
self.Nb = np.sum(self.wav_grid<_lya_peak)
self.Nr = len(self.wav_grid) - self.Nb
self.type = config.TYPE

self.batch_size = config.DATA.BATCH_SIZE

@@ -59,30 +70,41 @@ def __init__(self, config: CN):
self.mask = []
self.zabs = []
self.zqso = []
self.pathlist = []

if self.type == 'train':
print("=> Load Data...")
_read_from_catalog(self.flux, self.error, self.mask, self.zqso, self.pathlist, config.DATA.CATALOG,
config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'train')

if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR) and config.DATA.VALIDATION:
print("=> Load Validation Data...")
_read_from_catalog(self.flux, self.error, self.mask, self.zqso, self.pathlist, config.DATA.VALIDATION_CATALOG,
config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX,
config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'validation')

print("=> Load Data...")
_read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.CATALOG,
config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'train')
elif self.type == 'predict':
print("=> Load Data...")
paths = pd.read_csv(config.DATA.CATALOG).values.squeeze()
paths = list(map(lambda x: os.path.join(config.DATA.DATA_DIR, x), paths))
_read_npz_files(self.flux, self.error, self.mask, self.zqso, self.pathlist, paths, config.DATA.NPROCS)

if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR) and config.DATA.VALIDATION:
print("=> Load Validation Data...")
_read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.VALIDATION_CATALOG,
config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX,
config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'validation')
else:
raise NotImplementedError("TYPE should be in ['train', 'test']!")


self.flux = np.array(self.flux)
self.error = np.array(self.error)
self.zqso = np.array(self.zqso)
self.mask = np.array(self.mask)
self.pathlist = np.array(self.pathlist)
self.zabs = (self.zqso + 1).reshape(-1, 1)*self.wav_grid[:self.Nb]/1215.67 - 1


self.cur = 0
self._device = None
self._tau = default_tau
self.validation_dir = None
self._tau = partial(taufunc, which=config.MODEL.TAU)
self.data_size = self.flux.shape[0]

s = np.hstack((np.exp(1*self._tau(self.zabs)), np.ones((self.data_size, self.Nr), dtype=float)))
@@ -96,6 +118,7 @@ def have_next_batch(self):
Returns:
sig (bool): whether this dataloader have next batch
"""
if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
return self.cur < self.data_size

def next_batch(self):
@@ -105,6 +128,7 @@ def next_batch(self):
Returns:
delta, error, redshift, mask (torch.tensor): batch data
"""
if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
start = self.cur
end = self.cur + self.batch_size if self.cur + self.batch_size < self.data_size else self.data_size
self.cur = end
@@ -120,6 +144,7 @@ def sample(self):
Returns:
delta, error, redshift, mask (torch.tensor): sampled data
"""
if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
sig = np.random.randint(0, self.data_size, size=(self.batch_size, ))
s = np.hstack((np.exp(-1.*self._tau(self.zabs[sig])), np.ones((self.batch_size, self.Nr), dtype=float)))
return torch.tensor(self.flux[sig]-self._mu*s, dtype=torch.float32).to(self._device),\
@@ -130,6 +155,7 @@ def rewind(self):
"""
shuffle all the data and reset the dataloader
"""
if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
idx = np.arange(self.data_size)
np.random.shuffle(idx)
self.cur = 0
@@ -138,6 +164,7 @@ def rewind(self):
self.zqso = self.zqso[idx]
self.zabs = self.zabs[idx]
self.mask = self.mask[idx]
self.pathlist = self.pathlist[idx]

def set_tau(self, tau: Callable[[torch.Tensor], torch.Tensor])->None:
"""
@@ -151,6 +178,14 @@ def set_device(self, device: torch.device)->None:
"""
self._device = device

def __len__(self):
return len(self.flux)

def __getitem__(self, idx):
return torch.tensor(self.flux[idx], dtype=torch.float32).to(self._device),\
torch.tensor(self.error[idx], dtype=torch.float32).to(self._device), torch.tensor(self.zabs[idx], dtype=torch.float32).to(self._device), \
torch.tensor(self.mask[idx], dtype=bool).to(self._device), self.pathlist[idx]
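Worth noting: together with __len__, this integer-indexed __getitem__ makes the loader iterable through Python's sequence protocol (iteration stops once the underlying numpy arrays raise IndexError), which is what the prediction loop in main.py relies on; a minimal sketch:

for flux, error, zabs, mask, path in dataloader:
    pass  # one spectrum per step, tensors already moved to self._device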

@property
def mu(self):
return self._mu
7 changes: 4 additions & 3 deletions QFA/model.py
@@ -133,7 +133,7 @@ def loglikelihood_and_gradient_for_single_spectra(self, delta: torch.Tensor, err
logDet = MatrixLogDet(F, diag, self.device)
masked_delta = masked_delta[:, None]
loglikelihood = 0.5*(masked_delta.mT @ invSigma @ masked_delta + Npix * log2pi + logDet)
partialSigma = 0.5*(invSigma-invSigma@masked_delta@masked_delta.T@invSigma)
partialSigma = 0.5*(invSigma-invSigma@masked_delta@masked_delta.mT@invSigma)
partialF = 2*diagA@partialSigma@diagA@F
diagPartialSigma = torch.diag(partialSigma)
partialPsi = A*diagPartialSigma*A
@@ -172,11 +172,12 @@ def prediction_for_single_spectra(self, flux: torch.Tensor, error: torch.Tensor,
diag = Psi + omega + masked_error*masked_error
invSigma = MatrixInverse(F, diag, self.device)
logDet = MatrixLogDet(F, diag, self.device)
loglikelihood = 0.5*(masked_delta.T @ invSigma @ masked_delta + Npix * log2pi + logDet)
masked_delta = masked_delta[:, None]
loglikelihood = 0.5*(masked_delta.mT @ invSigma @ masked_delta + Npix * log2pi + logDet)
Sigma_e = torch.diag(1./diag)
hcov = torch.linalg.inv(torch.eye(self.Nh, dtype=torch.float).to(self.device) + F.T@Sigma_e@F)
hmean = hcov@F.T@Sigma_e@masked_delta
return loglikelihood, hmean, hcov, self.F@hmean + self.mu, torch.diag(self.F@hcov@self.F.T)**0.5
return loglikelihood, hmean, hcov, (self.F@hmean).squeeze() + self.mu, torch.diag(self.F@hcov@self.F.T)**0.5
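For readers of this hunk: with Sigma_e holding the inverse of the diagonal noise term (Psi + omega + error^2), the hcov/hmean lines above are the standard factor-analysis posterior over the hidden variables h,

hcov = (I + F^T Sigma_e F)^{-1},   hmean = hcov @ F^T @ Sigma_e @ masked_delta,

and the returned continuum is F @ hmean + mu, with pixelwise uncertainty sqrt(diag(F @ hcov @ F^T)).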


def train(self, optimizer, dataloader, n_epochs, output_dir="./result", save_interval=5, smooth_interval=5, quiet=False, logger=None):
9 changes: 9 additions & 0 deletions QFA/utils.py
@@ -134,6 +134,13 @@ def _tau_kamble(z: torch.Tensor)->torch.Tensor:
return tau0 * (1+z) ** beta


def _tau_mock(z: torch.Tensor)->torch.Tensor:
"""
mean optical depth used in the Bautista et al. 2015 mock spectra [https://iopscience.iop.org/article/10.1088/1475-7516/2015/05/060]
"""
return 0.2231435513142097*((1+z)/3.25)**3.2
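As a quick sanity check, the prefactor equals -ln(0.8), so the implied mean transmitted flux exp(-tau) is exactly 0.8 at z = 2.25, where (1+z)/3.25 = 1:

import numpy as np

z = 2.25
tau0 = 0.2231435513142097                       # == -np.log(0.8)
print(np.exp(-tau0 * ((1 + z) / 3.25) ** 3.2))  # -> 0.8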


def tau(z: torch.Tensor, which: Optional[str]='becker') -> torch.Tensor:
"""
mean optical depth function
@@ -150,6 +157,8 @@ def tau(z: torch.Tensor, which: Optional[str]='becker') -> torch.Tensor:
return _tau_fg(z)
elif which == 'kamble':
return _tau_kamble(z)
elif which == 'mock':
return _tau_mock(z)
else:
raise NotImplementedError("currently available mean optical depth function: ['becker', 'fg', 'kamble']")

69 changes: 52 additions & 17 deletions train.py → main.py
@@ -1,22 +1,29 @@
import argparse
from functools import partial
from QFA.model import QFA
from QFA.dataloader import Dataloader
from QFA.utils import tau as taufunc
from QFA.optimizer import Adam, step_scheduler
import logging
import os
import numpy as np
import torch
import time
from tqdm import tqdm
from QFA.config import get_config


parser = argparse.ArgumentParser()
parser.add_argument("--cfg", type=str, required=False, help='configuration file')
parser.add_argument("--catalog", type=str, required=False, help="csv file which records the meta info for each spectra, see the example 'catalog.csv'")
parser.add_argument("--type", type=str, required=False, help='which mode to use [train or predict]...')
parser.add_argument("--data_num", type=int, required=False, help="number of quasar spectra for model training")
parser.add_argument("--validation_catalog", type=str, required=False, help="csv file which records the meta info for each spectra, see the example 'catalog.csv'")
parser.add_argument("--validation_num", type=int, required=False, help="number of quasar spectra for validation set")
parser.add_argument("--batch_size", type=int, required=False, help='batch size for model training')
parser.add_argument("--n_epochs", type=int, required=False, help='number of training epochs')
parser.add_argument("--Nh", type=int, required=False, help="number of hidden variables")
parser.add_argument("--tau", type=str, required=False, help='mean optical depth function')
parser.add_argument("--learning_rate", type=float, required=False, help="model learning rate (suggestion: learning_rate should be lower than 1e-2)")
parser.add_argument("--gpu", type=int, required=False, help="specify the GPU number for model training")
parser.add_argument("--snr_min", type=float, required=False, help="the lowest signal-to-noise ratio in the training set")
@@ -38,30 +45,58 @@


if __name__ == "__main__":
# save config
if not os.path.exists(config.DATA.OUTPUT_DIR):
os.mkdir(config.DATA.OUTPUT_DIR)

with open(os.path.join(config.DATA.OUTPUT_DIR, 'config.yaml'), 'w') as f:
f.write(config.dump())

# gpu settings
os.environ['CUDA_VISIBLE_DEVICES'] = str(config.GPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

assert config.TYPE in ['train', 'predict'], "TYPE must be in ['train', 'predict']!"

#load data
dataloader = Dataloader(config)
dataloader.set_device(device)

# set up logger
logger = logging.getLogger(__name__)
logger.setLevel(level = logging.INFO)
handler = logging.FileHandler(os.path.join(config.DATA.OUTPUT_DIR, "log.txt"))
handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
print("training...")
if config.TYPE == 'train':
# set up logger
logger = logging.getLogger(__name__)
logger.setLevel(level = logging.INFO)
handler = logging.FileHandler(os.path.join(config.DATA.OUTPUT_DIR, "log.txt"))
handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
print("training...")

# training
model = QFA(dataloader.Nb, dataloader.Nr, config.MODEL.NH, device=device)
if os.path.exists(config.MODEL.RESUME):
print(f"=> Resume from {config.MODEL.RESUME}")
# training
model = QFA(dataloader.Nb, dataloader.Nr, config.MODEL.NH, device=device, tau=partial(taufunc, which=config.MODEL.TAU))
if os.path.exists(config.MODEL.RESUME):
print(f"=> Resume from {config.MODEL.RESUME}")
model.load_from_npz(config.MODEL.RESUME)
scheduler = step_scheduler(config.TRAIN.DECAY_ALPHA, config.TRAIN.DECAY_STEP)
optimizer = Adam(params=model.parameters, learning_rate=config.TRAIN.LEARNING_RATE, device=device, scheduler=scheduler, weight_decay=config.TRAIN.WEIGHT_DECAY)
model.random_init_func()
model.train(optimizer, dataloader, config.TRAIN.NEPOCHS, config.DATA.OUTPUT_DIR, logger=logger)
elif config.TYPE == 'predict':
print(f"try to predict {len(dataloader)} spectra...")
model = QFA(dataloader.Nb, dataloader.Nr, config.MODEL.NH, device=device, tau=partial(taufunc, which=config.MODEL.TAU))
print(f'=> Resume from {config.MODEL.RESUME}')
model.load_from_npz(config.MODEL.RESUME)
scheduler = step_scheduler(config.TRAIN.DECAY_ALPHA, config.TRAIN.DECAY_STEP)
optimizer = Adam(params=model.parameters, learning_rate=config.TRAIN.LEARNING_RATE, device=device, scheduler=scheduler, weight_decay=config.TRAIN.WEIGHT_DECAY)
model.random_init_func()
model.train(optimizer, dataloader, config.TRAIN.NEPOCHS, config.DATA.OUTPUT_DIR, logger=logger)
ts = time.time()
output_dir = os.path.join(config.DATA.OUTPUT_DIR, 'predict')
if not os.path.exists(output_dir):
os.mkdir(output_dir)
for (f, e, z, m, p) in tqdm(dataloader):
ll, hmean, hcov, cont, uncertainty = model.prediction_for_single_spectra(f, e, z, m)
# collect the tensors into numpy arrays by name (avoids the fragile eval())
keys = ('ll', 'hmean', 'hcov', 'cont', 'uncertainty')
result_dict = {k: v.cpu().detach().numpy() for k, v in zip(keys, (ll, hmean, hcov, cont, uncertainty))}
name = os.path.basename(p)
np.savez(os.path.join(output_dir, name), **result_dict)
te = time.time()
print(f'Finished predicting {len(dataloader)} spectra in {te - ts:.1f} seconds')
else:
raise NotImplementedError(f"Mode {config.TYPE} hasn't been implemented!")
