diff --git a/QFA/config.py b/QFA/config.py
index 28d7e20..fdac277 100644
--- a/QFA/config.py
+++ b/QFA/config.py
@@ -15,6 +15,7 @@
 # Base config files
 _C.BASE = ['']
+_C.TYPE = "train"
 _C.GPU = 0
 
 #------------------
@@ -62,6 +63,7 @@
 _C.TRAIN.WINDOW_LENGTH_FOR_MU = 16
 
+
 def _update_config_from_file(config, cfg_file):
     config.defrost()
     with open(cfg_file, 'r') as f:
@@ -134,6 +136,8 @@ def _check_args(name):
         config.DATA.VALIDATION = args.validation
     if _check_args('tau'):
         config.MODEL.TAU = args.tau
+    if _check_args('type'):
+        config.TYPE = args.type
 
     config.freeze()
diff --git a/QFA/dataloader.py b/QFA/dataloader.py
index cc63829..ae7a32d 100644
--- a/QFA/dataloader.py
+++ b/QFA/dataloader.py
@@ -2,11 +2,13 @@
 import torch
 import numpy as np
 import pandas as pd
+import warnings
 from tqdm import tqdm
 import multiprocessing
 from .utils import smooth
-from .utils import tau as default_tau
-from typing import Tuple, Callable
+from .utils import tau as taufunc
+from typing import Tuple, Callable, List
+from functools import partial
 from yacs.config import CfgNode as CN
 
 
@@ -15,7 +17,7 @@ def _read_npz_file(path: str)->Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
     """
-    load spectra from npz files
+    load a spectrum from a single npz file
     NOTE:
     (1) all spectra should have same wavelength grid
     (2) spectra are preprocessed as in the paper
@@ -25,10 +27,24 @@ def _read_npz_file(path: str)->Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
     flux, error, z = file['flux'], file['error'], float(file['z'])
     mask = (flux!=-999.)&(error!=-999.)
     file.close()
-    return flux, error, mask, z
+    return flux, error, mask, z, path
+
+
+def _read_npz_files(flux: List[np.ndarray], error: List[np.ndarray], mask: List[np.ndarray], zqso: List[float], pathlist: List[str], paths: List[str], nprocs: int)->None:
+    """
+    load spectra from npz files in parallel; results are appended to the given lists in place
+    """
+    with multiprocessing.Pool(nprocs) as pool:
+        data = pool.map(_read_npz_file, paths)
+    for f, e, m, z, path in tqdm(data):
+        flux.append(f)
+        error.append(e)
+        mask.append(m)
+        zqso.append(z)
+        pathlist.append(path)
 
 
-def _read_from_catalog(flux, error, mask, zqso, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs, output_dir, prefix='train'):
+def _read_from_catalog(flux, error, mask, zqso, pathlist, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs, output_dir, prefix='train'):
     catalog = pd.read_csv(catalog)
     criteria = (catalog['snr']>=snr_min) & (catalog['snr']<=snr_max) & (catalog['z']>=z_min) & (catalog['z']<=z_max) & (catalog['num_mask']<=num_mask)
     files = np.random.choice(catalog['file'][criteria].values, size=(num,), replace=(np.sum(criteria)<num))
@@ ... @@
+        if self.type == 'train':
+            print("=> Load Data...")
+            _read_from_catalog(self.flux, self.error, self.mask, self.zqso, self.pathlist, config.DATA.CATALOG,
+                               config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
+                               config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'train')
+
+            if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR) and config.DATA.VALIDATION:
+                print("=> Load Validation Data...")
+                _read_from_catalog(self.flux, self.error, self.mask, self.zqso, self.pathlist, config.DATA.VALIDATION_CATALOG,
+                                   config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX,
+                                   config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'validation')
-        print("=> Load Data...")
-        _read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.CATALOG,
-            config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
-            config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'train')
+        elif self.type == 'predict':
+            print("=> Load Data...")
+            paths = pd.read_csv(config.DATA.CATALOG).values.squeeze()
+            paths = list(map(lambda x: os.path.join(config.DATA.DATA_DIR, x), paths))
+            _read_npz_files(self.flux, self.error, self.mask, self.zqso, self.pathlist, paths, config.DATA.NPROCS)
-        if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR) and config.DATA.VALIDATION:
-            print("=> Load Validation Data...")
-            _read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.VALIDATION_CATALOG,
-                config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX,
-                config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'validation')
+        else:
+            raise NotImplementedError("TYPE should be in ['train', 'predict']!")
 
         self.flux = np.array(self.flux)
         self.error = np.array(self.error)
         self.zqso = np.array(self.zqso)
         self.mask = np.array(self.mask)
+        self.pathlist = np.array(self.pathlist)
         self.zabs = (self.zqso + 1).reshape(-1, 1)*self.wav_grid[:self.Nb]/1215.67 - 1
         self.cur = 0
         self._device = None
-        self._tau = default_tau
-        self.validation_dir = None
+        self._tau = partial(taufunc, which=config.MODEL.TAU)
         self.data_size = self.flux.shape[0]
         s = np.hstack((np.exp(-1.*self._tau(self.zabs)), np.ones((self.data_size, self.Nr), dtype=float)))
@@ -96,6 +118,7 @@ def have_next_batch(self):
         Returns:
             sig (bool): whether this dataloader has a next batch
         """
+        if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
         return self.cur < self.data_size
 
     def next_batch(self):
@@ -105,6 +128,7 @@ def next_batch(self):
         Returns:
             delta, error, redshift, mask (torch.tensor): batch data
         """
+        if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
         start = self.cur
         end = self.cur + self.batch_size if self.cur + self.batch_size < self.data_size else self.data_size
         self.cur = end
@@ -120,6 +144,7 @@ def sample(self):
         Returns:
             delta, error, redshift, mask (torch.tensor): sampled data
         """
+        if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
         sig = np.random.randint(0, self.data_size, size=(self.batch_size, ))
         s = np.hstack((np.exp(-1.*self._tau(self.zabs[sig])), np.ones((self.batch_size, self.Nr), dtype=float)))
         return torch.tensor(self.flux[sig]-self._mu*s, dtype=torch.float32).to(self._device),\
@@ -130,6 +155,7 @@ def rewind(self):
         """
         shuffle all the data and reset the dataloader
         """
+        if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
         idx = np.arange(self.data_size)
         np.random.shuffle(idx)
         self.cur = 0
@@ -138,6 +164,7 @@ def rewind(self):
         self.zqso = self.zqso[idx]
         self.zabs = self.zabs[idx]
         self.mask = self.mask[idx]
+        self.pathlist = self.pathlist[idx]
 
     def set_tau(self, tau:Callable[[torch.tensor, ], torch.tensor])->None:
         """
@@ -151,6 +178,14 @@ def set_device(self, device: torch.device)->None:
         """
         self._device = device
 
+    def __len__(self):
+        return len(self.flux)
+
+    def __getitem__(self, idx):
+        return torch.tensor(self.flux[idx], dtype=torch.float32).to(self._device),\
+               torch.tensor(self.error[idx], dtype=torch.float32).to(self._device), torch.tensor(self.zabs[idx], dtype=torch.float32).to(self._device), \
+               torch.tensor(self.mask[idx], dtype=torch.bool).to(self._device), self.pathlist[idx]
+
     @property
     def mu(self):
         return self._mu
\ No newline at end of file
diff --git a/QFA/model.py b/QFA/model.py
index c153681..ae0f0f1 100644
--- a/QFA/model.py
+++ b/QFA/model.py
@@ -133,7 +133,7 @@ def loglikelihood_and_gradient_for_single_spectra(self, delta: torch.Tensor, error: torch.Tensor,
         logDet = MatrixLogDet(F, diag, self.device)
         masked_delta = masked_delta[:, None]
         loglikelihood = 0.5*(masked_delta.mT @ invSigma @ masked_delta + Npix * log2pi + logDet)
-        partialSigma = 0.5*(invSigma-invSigma@masked_delta@masked_delta.T@invSigma)
+        partialSigma = 0.5*(invSigma-invSigma@masked_delta@masked_delta.mT@invSigma)
         partialF = 2*diagA@partialSigma@diagA@F
         diagPartialSigma = torch.diag(partialSigma)
         partialPsi = A*diagPartialSigma*A
@@ -172,11 +172,12 @@ def prediction_for_single_spectra(self, flux: torch.Tensor, error: torch.Tensor,
         diag = Psi + omega + masked_error*masked_error
         invSigma = MatrixInverse(F, diag, self.device)
         logDet = MatrixLogDet(F, diag, self.device)
-        loglikelihood = 0.5*(masked_delta.T @ invSigma @ masked_delta + Npix * log2pi + logDet)
+        masked_delta = masked_delta[:, None]
+        loglikelihood = 0.5*(masked_delta.mT @ invSigma @ masked_delta + Npix * log2pi + logDet)
         Sigma_e = torch.diag(1./diag)
         hcov = torch.linalg.inv(torch.eye(self.Nh, dtype=torch.float).to(self.device) + F.T@Sigma_e@F)
         hmean = hcov@F.T@Sigma_e@masked_delta
-        return loglikelihood, hmean, hcov, self.F@hmean + self.mu, torch.diag(self.F@hcov@self.F.T)**0.5
+        return loglikelihood, hmean, hcov, (self.F@hmean).squeeze() + self.mu, torch.diag(self.F@hcov@self.F.T)**0.5
 
     def train(self, optimizer, dataloader, n_epochs, output_dir="./result", save_interval=5, smooth_interval=5, quiet=False, logger=None):
diff --git a/QFA/utils.py b/QFA/utils.py
index 98732d6..0d6a185 100644
--- a/QFA/utils.py
+++ b/QFA/utils.py
@@ -134,6 +134,13 @@ def _tau_kamble(z: torch.Tensor)->torch.Tensor:
     return tau0 * (1+z) ** beta
 
 
+def _tau_mock(z: torch.Tensor)->torch.Tensor:
+    """
+    mean optical depth used for the mock spectra, following Bautista et al.
+    2015 [https://iopscience.iop.org/article/10.1088/1475-7516/2015/05/060]
+    """
+    return 0.2231435513142097*((1+z)/3.25)**3.2
+
+
 def tau(z: torch.Tensor, which: Optional[str]='becker') -> torch.Tensor:
     """
     mean optical depth function
@@ -150,6 +157,8 @@ def tau(z: torch.Tensor, which: Optional[str]='becker') -> torch.Tensor:
         return _tau_fg(z)
     elif which == 'kamble':
         return _tau_kamble(z)
+    elif which == 'mock':
+        return _tau_mock(z)
     else:
-        raise NotImplementedError("currently available mean optical depth function: ['becker', 'fg', 'kamble']")
+        raise NotImplementedError("currently available mean optical depth functions: ['becker', 'fg', 'kamble', 'mock']")
diff --git a/train.py b/main.py
similarity index 53%
rename from train.py
rename to main.py
index e7d3837..5436196 100644
--- a/train.py
+++ b/main.py
@@ -1,22 +1,29 @@
 import argparse
+from functools import partial
 from QFA.model import QFA
 from QFA.dataloader import Dataloader
+from QFA.utils import tau as taufunc
 from QFA.optimizer import Adam, step_scheduler
 import logging
 import os
+import numpy as np
 import torch
+import time
+from tqdm import tqdm
 from QFA.config import get_config
 
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--cfg", type=str, required=False, help='configuration file')
 parser.add_argument("--catalog", type=str, required=False, help="csv file which records the meta info for each spectrum, see the example 'catalog.csv'")
+parser.add_argument("--type", type=str, required=False, help='which mode to run: train or predict')
 parser.add_argument("--data_num", type=int, required=False, help="number of quasar spectra for model training")
 parser.add_argument("--validation_catalog", type=str, required=False, help="csv file which records the meta info for each spectrum, see the example 'catalog.csv'")
 parser.add_argument("--validation_num", type=int, required=False, help="number of quasar spectra for the validation set")
 parser.add_argument("--batch_size", type=int, required=False, help='batch size for model training')
 parser.add_argument("--n_epochs", type=int, required=False, help='number of training epochs')
 parser.add_argument("--Nh", type=int, required=False, help="number of hidden variables")
+parser.add_argument("--tau", type=str, required=False, help='mean optical depth function')
 parser.add_argument("--learning_rate", type=float, required=False, help="model learning rate (suggestion: learning_rate should be lower than 1e-2)")
 parser.add_argument("--gpu", type=int, required=False, help="specify the GPU number for model training")
 parser.add_argument("--snr_min", type=float, required=False, help="the lowest signal-to-noise ratio in the training set")
@@ -38,30 +45,58 @@
 if __name__ == "__main__":
+    # save config
+    if not os.path.exists(config.DATA.OUTPUT_DIR):
+        os.mkdir(config.DATA.OUTPUT_DIR)
+
+    with open(os.path.join(config.DATA.OUTPUT_DIR, 'config.yaml'), 'w') as f:
+        f.write(config.dump())
+
     # gpu settings
     os.environ['CUDA_VISIBLE_DEVICES'] = str(config.GPU)
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    assert config.TYPE in ['train', 'predict'], "TYPE must be in ['train', 'predict']!"
 
+    # load data
     dataloader = Dataloader(config)
     dataloader.set_device(device)
 
-    # set up logger
-    logger = logging.getLogger(__name__)
-    logger.setLevel(level = logging.INFO)
-    handler = logging.FileHandler(os.path.join(config.DATA.OUTPUT_DIR, "log.txt"))
-    handler.setLevel(logging.INFO)
-    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-    handler.setFormatter(formatter)
-    logger.addHandler(handler)
-    print("training...")
+    if config.TYPE == 'train':
+        # set up logger
+        logger = logging.getLogger(__name__)
+        logger.setLevel(level=logging.INFO)
+        handler = logging.FileHandler(os.path.join(config.DATA.OUTPUT_DIR, "log.txt"))
+        handler.setLevel(logging.INFO)
+        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        print("training...")
 
-    # training
-    model = QFA(dataloader.Nb, dataloader.Nr, config.MODEL.NH, device=device)
-    if os.path.exists(config.MODEL.RESUME):
-        print(f"=> Resume from {config.MODEL.RESUME}")
+        # training
+        model = QFA(dataloader.Nb, dataloader.Nr, config.MODEL.NH, device=device, tau=partial(taufunc, which=config.MODEL.TAU))
+        if os.path.exists(config.MODEL.RESUME):
+            print(f"=> Resume from {config.MODEL.RESUME}")
+            model.load_from_npz(config.MODEL.RESUME)
+        scheduler = step_scheduler(config.TRAIN.DECAY_ALPHA, config.TRAIN.DECAY_STEP)
+        optimizer = Adam(params=model.parameters, learning_rate=config.TRAIN.LEARNING_RATE, device=device, scheduler=scheduler, weight_decay=config.TRAIN.WEIGHT_DECAY)
+        model.random_init_func()
+        model.train(optimizer, dataloader, config.TRAIN.NEPOCHS, config.DATA.OUTPUT_DIR, logger=logger)
+    elif config.TYPE == 'predict':
+        print(f"try to predict {len(dataloader)} spectra...")
+        model = QFA(dataloader.Nb, dataloader.Nr, config.MODEL.NH, device=device, tau=partial(taufunc, which=config.MODEL.TAU))
+        print(f'=> Resume from {config.MODEL.RESUME}')
         model.load_from_npz(config.MODEL.RESUME)
-    scheduler = step_scheduler(config.TRAIN.DECAY_ALPHA, config.TRAIN.DECAY_STEP)
-    optimizer = Adam(params=model.parameters, learning_rate=config.TRAIN.LEARNING_RATE, device=device, scheduler=scheduler, weight_decay=config.TRAIN.WEIGHT_DECAY)
-    model.random_init_func()
-    model.train(optimizer, dataloader, config.TRAIN.NEPOCHS, config.DATA.OUTPUT_DIR, logger=logger)
+        ts = time.time()
+        output_dir = os.path.join(config.DATA.OUTPUT_DIR, 'predict')
+        if not os.path.exists(output_dir):
+            os.mkdir(output_dir)
+        for (f, e, z, m, p) in tqdm(dataloader):
+            ll, hmean, hcov, cont, uncertainty = model.prediction_for_single_spectra(f, e, z, m)
+            # pair each output with its key explicitly instead of eval()
+            keys = ('ll', 'hmean', 'hcov', 'cont', 'uncertainty')
+            values = (ll, hmean, hcov, cont, uncertainty)
+            result_dict = {key: value.cpu().detach().numpy() for key, value in zip(keys, values)}
+            name = os.path.basename(p)
+            np.savez(os.path.join(output_dir, name), **result_dict)
+        te = time.time()
+        print(f'Finish predicting {len(dataloader)} spectra in {te - ts:.2f} seconds...')
+    else:
+        raise NotImplementedError(f"Mode {config.TYPE} hasn't been implemented!")
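
# A minimal, self-contained sketch of the posterior update implemented in
# prediction_for_single_spectra above: with Sigma_e = diag(1/diag), the
# hidden-factor posterior is hcov = (I + F^T Sigma_e F)^{-1} and
# hmean = hcov F^T Sigma_e delta, so the continuum is F @ hmean + mu with
# per-pixel uncertainty diag(F hcov F^T)**0.5. Shapes and names here are
# illustrative, not the repo's API.
import torch

Npix, Nh = 100, 8
F = torch.randn(Npix, Nh)             # factor loading matrix
diag = torch.rand(Npix) + 0.1         # Psi + omega + error**2 on unmasked pixels
delta = torch.randn(Npix, 1)          # mean-subtracted flux, as a column vector
Sigma_e = torch.diag(1.0 / diag)
hcov = torch.linalg.inv(torch.eye(Nh) + F.T @ Sigma_e @ F)
hmean = hcov @ F.T @ Sigma_e @ delta  # posterior mean of hidden factors
cont_uncertainty = torch.diag(F @ hcov @ F.T) ** 0.5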
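
# A short sketch of how the shared mean-optical-depth function is selected
# after this patch: both the Dataloader and QFA receive
# partial(taufunc, which=config.MODEL.TAU), and 'mock' is now a valid choice.
# The redshift values below are illustrative.
from functools import partial

import torch
from QFA.utils import tau as taufunc

tau = partial(taufunc, which='mock')  # 0.2231435513142097*((1+z)/3.25)**3.2
z = torch.linspace(2.0, 4.0, 5)
mean_flux = torch.exp(-tau(z))        # mean transmitted flux exp(-tau)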
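
# A sketch of consuming the predict-mode outputs written by main.py: one npz
# file per spectrum under <OUTPUT_DIR>/predict/, keyed as in result_dict
# above. The file name below is hypothetical.
import numpy as np

result = np.load('result/predict/spec-0001.npz')
ll = result['ll']                     # log-likelihood of the fit
cont = result['cont']                 # predicted continuum, (F @ hmean).squeeze() + mu
uncertainty = result['uncertainty']   # per-pixel continuum uncertainty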