
Commit 16ade10
update for better user-interface
1 parent ead0ec4

File tree

  QFA/config.py
  QFA/dataloader.py
  QFA/model.py
  QFA/utils.py
  train.py → main.py

5 files changed: +127 -43 lines


QFA/config.py (+4)
@@ -15,6 +15,7 @@
 
 # Base config files
 _C.BASE = ['']
+_C.TYPE = "train"
 _C.GPU = 0
 
 #------------------
@@ -62,6 +63,7 @@
 _C.TRAIN.WINDOW_LENGTH_FOR_MU = 16
 
 
+
 def _update_config_from_file(config, cfg_file):
     config.defrost()
     with open(cfg_file, 'r') as f:
@@ -134,6 +136,8 @@ def _check_args(name):
         config.DATA.VALIDATION = args.validation
     if _check_args('tau'):
         config.MODEL.TAU = args.tau
+    if _check_args('type'):
+        config.TYPE = args.type
 
 
     config.freeze()
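
Note: the new TYPE option defaults to "train" and can be overridden with --type on the command line. A minimal sketch of how such a yacs override flows into the frozen config (the values here are illustrative, not from the commit):

    from yacs.config import CfgNode as CN

    _C = CN()
    _C.TYPE = "train"            # default added by this commit

    config = _C.clone()
    config.defrost()
    config.TYPE = "predict"      # what passing --type predict sets via _check_args('type')
    config.freeze()              # config is immutable from here on
    print(config.TYPE)           # -> predict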

QFA/dataloader.py (+58 -23)
@@ -2,11 +2,13 @@
 import torch
 import numpy as np
 import pandas as pd
+import warnings
 from tqdm import tqdm
 import multiprocessing
 from .utils import smooth
-from .utils import tau as default_tau
-from typing import Tuple, Callable
+from .utils import tau as taufunc
+from typing import Tuple, Callable, List
+from functools import partial
 from yacs.config import CfgNode as CN
 
 
@@ -15,7 +17,7 @@
 
 def _read_npz_file(path: str)->Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
     """
-    load spectra from npz files
+    load spectra from npz file
     NOTE:
         (1) all spectra should have same wavelength grid
         (2) spectra are preprocessed as in the paper
@@ -25,24 +27,32 @@ def _read_npz_file(path: str)->Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
     flux, error, z = file['flux'], file['error'], float(file['z'])
     mask = (flux!=-999.)&(error!=-999.)
     file.close()
-    return flux, error, mask, z
+    return flux, error, mask, z, path
+
 
+def _read_npz_files(flux: List[np.ndarray], error: List[np.ndarray], mask: List[np.ndarray], zqso: List[np.ndarray], pathlist: List[np.ndarray], paths: str, nprocs: int)->Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
+    """
+    load spectra from npz files
+    """
+    with multiprocessing.Pool(nprocs) as p:
+        data = p.map(_read_npz_file, paths)
+    for f, e, m, z, p in tqdm(data):
+        flux.append(f)
+        error.append(e)
+        mask.append(m)
+        zqso.append(z)
+        pathlist.append(p)
+
 
-def _read_from_catalog(flux, error, mask, zqso, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs, output_dir, prefix='train'):
+def _read_from_catalog(flux, error, mask, zqso, pathlist, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs, output_dir, prefix='train'):
     catalog = pd.read_csv(catalog)
     criteria = (catalog['snr']>=snr_min) & (catalog['snr']<=snr_max) & (catalog['z']>=z_min) & (catalog['z']<=z_max) & (catalog['num_mask']<=num_mask)
     files = np.random.choice(catalog['file'][criteria].values, size=(num,), replace=(np.sum(criteria)<num))
     if not os.path.exists(output_dir):
         os.mkdir(output_dir)
     pd.Series(files).to_csv(os.path.join(output_dir, f'{prefix}-catalog.csv'), header=False, index=False)
     paths = [os.path.join(data_dir, x) for x in files]
-    with multiprocessing.Pool(nprocs) as p:
-        data = p.map(_read_npz_file, paths)
-    for f, e, m, z in tqdm(data):
-        flux.append(f)
-        error.append(e)
-        mask.append(m)
-        zqso.append(z)
+    _read_npz_files(flux, error, mask, zqso, pathlist, paths, nprocs)
 
 
 class Dataloader(object):
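
Note: _read_npz_files factors the old inline loop into a helper shared by the catalog path and the new predict path; each file now also reports its own path for bookkeeping. A self-contained sketch of the same Pool.map fan-out (the file names here are hypothetical):

    import multiprocessing
    import numpy as np

    def load_one(path):
        with np.load(path) as f:
            flux, error, z = f['flux'], f['error'], float(f['z'])
        mask = (flux != -999.) & (error != -999.)
        return flux, error, mask, z, path            # carry the path along with the data

    if __name__ == '__main__':
        paths = ['spectrum0.npz', 'spectrum1.npz']   # hypothetical input files
        with multiprocessing.Pool(2) as p:
            data = p.map(load_one, paths)            # parallel read, input order preserved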
@@ -51,6 +61,7 @@ def __init__(self, config: CN):
         self.wav_grid = 10**np.arange(np.log10(config.DATA.LAMMIN), np.log10(config.DATA.LAMMAX), config.DATA.LOGLAM_DELTA)
         self.Nb = np.sum(self.wav_grid<_lya_peak)
         self.Nr = len(self.wav_grid) - self.Nb
+        self.type = config.TYPE
 
         self.batch_size = config.DATA.BATCH_SIZE
 
@@ -59,30 +70,41 @@ def __init__(self, config: CN):
         self.mask = []
         self.zabs = []
         self.zqso = []
+        self.pathlist = []
+
+        if self.type == 'train':
+            print("=> Load Data...")
+            _read_from_catalog(self.flux, self.error, self.mask, self.zqso, self.pathlist, config.DATA.CATALOG,
+                config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
+                config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'train')
+
+            if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR) and config.DATA.VALIDATION:
+                print("=> Load Validation Data...")
+                _read_from_catalog(self.flux, self.error, self.mask, self.zqso, self.pathlist, config.DATA.VALIDATION_CATALOG,
+                    config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX,
+                    config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'validation')
 
-        print("=> Load Data...")
-        _read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.CATALOG,
-            config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
-            config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'train')
+        elif self.type == 'predict':
+            print("=> Load Data...")
+            paths = pd.read_csv(config.DATA.CATALOG).values.squeeze()
+            paths = list(map(lambda x: os.path.join(config.DATA.DATA_DIR, x), paths))
+            _read_npz_files(self.flux, self.error, self.mask, self.zqso, self.pathlist, paths, config.DATA.NPROCS)
 
-        if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR) and config.DATA.VALIDATION:
-            print("=> Load Validation Data...")
-            _read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.VALIDATION_CATALOG,
-                config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX,
-                config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'validation')
+        else:
+            raise NotImplementedError("TYPE should be in ['train', 'predict']!")
 
 
         self.flux = np.array(self.flux)
         self.error = np.array(self.error)
         self.zqso = np.array(self.zqso)
         self.mask = np.array(self.mask)
+        self.pathlist = np.array(self.pathlist)
         self.zabs = (self.zqso + 1).reshape(-1, 1)*self.wav_grid[:self.Nb]/1215.67 - 1
 
 
         self.cur = 0
         self._device = None
-        self._tau = default_tau
-        self.validation_dir = None
+        self._tau = partial(taufunc, which=config.MODEL.TAU)
         self.data_size = self.flux.shape[0]
 
         s = np.hstack((np.exp(1*self._tau(self.zabs)), np.ones((self.data_size, self.Nr), dtype=float)))
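
Note: the zabs line in this hunk converts each blue-side pixel to an absorber redshift via 1 + z_abs = (1 + z_qso) * lam_rest / 1215.67. A quick numerical check (values illustrative):

    import numpy as np

    z_qso = 3.0
    wav_rest = np.array([1100.0, 1215.67])        # rest-frame pixels bluewards of Lya
    z_abs = (z_qso + 1) * wav_rest / 1215.67 - 1
    print(z_abs)                                  # -> [2.6194... 3.0]; the Lya pixel maps back to z_qso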
@@ -96,6 +118,7 @@ def have_next_batch(self):
         Returns:
             sig (bool): whether this dataloader have next batch
         """
+        if self.type == 'test': warnings.warn('dataloader is in test mode...')
         return self.cur < self.data_size
 
     def next_batch(self):
@@ -105,6 +128,7 @@ def next_batch(self):
         Returns:
             delta, error, redshift, mask (torch.tensor): batch data
         """
+        if self.type == 'test': warnings.warn('dataloader is in test mode...')
         start = self.cur
         end = self.cur + self.batch_size if self.cur + self.batch_size < self.data_size else self.data_size
         self.cur = end
@@ -120,6 +144,7 @@ def sample(self):
         Returns:
             delta, error, redshift, mask (torch.tensor): sampled data
         """
+        if self.type == 'test': warnings.warn('dataloader is in test mode...')
         sig = np.random.randint(0, self.data_size, size=(self.batch_size, ))
         s = np.hstack((np.exp(-1.*self._tau(self.zabs[sig])), np.ones((self.batch_size, self.Nr), dtype=float)))
         return torch.tensor(self.flux[sig]-self._mu*s, dtype=torch.float32).to(self._device),\
@@ -130,6 +155,7 @@ def rewind(self):
         """
         shuffle all the data and reset the dataloader
         """
+        if self.type == 'test': warnings.warn('dataloader is in test mode...')
         idx = np.arange(self.data_size)
         np.random.shuffle(idx)
         self.cur = 0
@@ -138,6 +164,7 @@ def rewind(self):
         self.zqso = self.zqso[idx]
         self.zabs = self.zabs[idx]
         self.mask = self.mask[idx]
+        self.pathlist = self.pathlist[idx]
 
     def set_tau(self, tau:Callable[[torch.tensor, ], torch.tensor])->None:
         """
@@ -151,6 +178,14 @@ def set_device(self, device: torch.device)->None:
         """
         self._device = device
 
+    def __len__(self):
+        return len(self.flux)
+
+    def __getitem__(self, idx):
+        return torch.tensor(self.flux[idx], dtype=torch.float32).to(self._device),\
+            torch.tensor(self.error[idx], dtype=torch.float32).to(self._device), torch.tensor(self.zabs[idx], dtype=torch.float32).to(self._device), \
+            torch.tensor(self.mask[idx], dtype=bool).to(self._device), self.pathlist[idx]
+
     @property
     def mu(self):
         return self._mu
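
Note: with __len__ and __getitem__ the dataloader now supports the plain-Python sequence protocol, which the predict loop in main.py relies on. A usage sketch (assumes a Dataloader built with TYPE = 'predict' and a device already set):

    print(len(dataloader))                            # number of loaded spectra
    for flux, error, zabs, mask, path in dataloader:  # iteration falls back to __getitem__
        pass                                          # one spectrum per step, tensors already on device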

QFA/model.py (+4 -3)
@@ -133,7 +133,7 @@ def loglikelihood_and_gradient_for_single_spectra(self, delta: torch.Tensor, err
         logDet = MatrixLogDet(F, diag, self.device)
         masked_delta = masked_delta[:, None]
         loglikelihood = 0.5*(masked_delta.mT @ invSigma @ masked_delta + Npix * log2pi + logDet)
-        partialSigma = 0.5*(invSigma-invSigma@masked_delta@masked_delta.T@invSigma)
+        partialSigma = 0.5*(invSigma-invSigma@masked_delta@masked_delta.mT@invSigma)
         partialF = 2*diagA@partialSigma@diagA@F
         diagPartialSigma = torch.diag(partialSigma)
         partialPsi = A*diagPartialSigma*A
@@ -172,11 +172,12 @@ def prediction_for_single_spectra(self, flux: torch.Tensor, error: torch.Tensor,
         diag = Psi + omega + masked_error*masked_error
         invSigma = MatrixInverse(F, diag, self.device)
         logDet = MatrixLogDet(F, diag, self.device)
-        loglikelihood = 0.5*(masked_delta.T @ invSigma @ masked_delta + Npix * log2pi + logDet)
+        masked_delta = masked_delta[:, None]
+        loglikelihood = 0.5*(masked_delta.mT @ invSigma @ masked_delta + Npix * log2pi + logDet)
         Sigma_e = torch.diag(1./diag)
         hcov = torch.linalg.inv(torch.eye(self.Nh, dtype=torch.float).to(self.device) + F.T@Sigma_e@F)
         hmean = hcov@F.T@Sigma_e@masked_delta
-        return loglikelihood, hmean, hcov, self.F@hmean + self.mu, torch.diag(self.F@hcov@self.F.T)**0.5
+        return loglikelihood, hmean, hcov, (self.F@hmean).squeeze() + self.mu, torch.diag(self.F@hcov@self.F.T)**0.5
 
 
     def train(self, optimizer, dataloader, n_epochs, output_dir="./result", save_interval=5, smooth_interval=5, quiet=False, logger=None):
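
Note on the .T to .mT switch: both agree for the 2-D case here, but PyTorch deprecates Tensor.T as a batched-matrix transpose, while .mT always swaps the last two dimensions, matching the (n, 1) column created by masked_delta[:, None]. A minimal check:

    import torch

    d = torch.randn(5, 1)       # column vector, as after masked_delta[:, None]
    S = torch.eye(5)
    quad = d.mT @ S @ d         # the quadratic form d^T S d
    print(quad.shape)           # -> torch.Size([1, 1])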

QFA/utils.py (+9)
@@ -134,6 +134,13 @@ def _tau_kamble(z: torch.Tensor)->torch.Tensor:
     return tau0 * (1+z) ** beta
 
 
+def _tau_mock(z: torch.Tensor)->torch.Tensor:
+    """
+    mean optical depth from mock literature, Bautista et al. 2015 [https://iopscience.iop.org/article/10.1088/1475-7516/2015/05/060]
+    """
+    return 0.2231435513142097*((1+z)/3.25)**3.2
+
+
 def tau(z: torch.Tensor, which: Optional[str]='becker') -> torch.Tensor:
     """
     mean optical depth function
@@ -150,6 +157,8 @@ def tau(z: torch.Tensor, which: Optional[str]='becker') -> torch.Tensor:
         return _tau_fg(z)
     elif which == 'kamble':
         return _tau_kamble(z)
+    elif which == 'mock':
+        return _tau_mock(z)
     else:
         raise NotImplementedError("currently available mean optical depth function: ['becker', 'fg', 'kamble', 'mock']")
 
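Note: the hard-coded prefactor in _tau_mock is ln(1.25), so the mean transmitted flux exp(-tau) is exactly 0.8 at z = 2.25, where (1+z)/3.25 = 1. A quick check:

    import numpy as np

    tau0 = np.log(1.25)                       # 0.2231435513142097, as hard-coded above
    tau = tau0 * ((1 + 2.25) / 3.25) ** 3.2
    print(np.exp(-tau))                       # -> 0.8 (to float precision)
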
train.py → main.py (+52 -17)
@@ -1,22 +1,29 @@
 import argparse
+from functools import partial
 from QFA.model import QFA
 from QFA.dataloader import Dataloader
+from QFA.utils import tau as taufunc
 from QFA.optimizer import Adam, step_scheduler
 import logging
 import os
+import numpy as np
 import torch
+import time
+from tqdm import tqdm
 from QFA.config import get_config
 
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--cfg", type=str, required=False, help='configuration file')
 parser.add_argument("--catalog", type=str, required=False, help="csv file which records the meta info for each spectrum, see the example 'catalog.csv'")
+parser.add_argument("--type", type=str, required=False, help='which mode to use [train or predict]')
 parser.add_argument("--data_num", type=int, required=False, help="number of quasar spectra for model training")
 parser.add_argument("--validation_catalog", type=str, required=False, help="csv file which records the meta info for each spectrum, see the example 'catalog.csv'")
 parser.add_argument("--validation_num", type=int, required=False, help="number of quasar spectra for validation set")
 parser.add_argument("--batch_size", type=int, required=False, help='batch size for model training')
 parser.add_argument("--n_epochs", type=int, required=False, help='number of training epochs')
 parser.add_argument("--Nh", type=int, required=False, help="number of hidden variables")
+parser.add_argument("--tau", type=str, required=False, help='mean optical depth function')
 parser.add_argument("--learning_rate", type=float, required=False, help="model learning rate (suggestion: learning_rate should be lower than 1e-2)")
 parser.add_argument("--gpu", type=int, required=False, help="specify the GPU number for model training")
 parser.add_argument("--snr_min", type=float, required=False, help="the lowest signal-to-noise ratio in the training set")
@@ -38,30 +45,58 @@
 
 
 if __name__ == "__main__":
+    # save config
+    if not os.path.exists(config.DATA.OUTPUT_DIR):
+        os.mkdir(config.DATA.OUTPUT_DIR)
+
+    with open(os.path.join(config.DATA.OUTPUT_DIR, 'config.yaml'), 'w') as f:
+        f.write(config.dump())
+
     # gpu settings
     os.environ['CUDA_VISIBLE_DEVICES'] = str(config.GPU)
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
+    assert config.TYPE in ['train', 'predict'], "TYPE must be in ['train', 'predict']!"
+
     #load data
     dataloader = Dataloader(config)
    dataloader.set_device(device)
 
-    # set up logger
-    logger = logging.getLogger(__name__)
-    logger.setLevel(level = logging.INFO)
-    handler = logging.FileHandler(os.path.join(config.DATA.OUTPUT_DIR, "log.txt"))
-    handler.setLevel(logging.INFO)
-    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-    handler.setFormatter(formatter)
-    logger.addHandler(handler)
-    print("training...")
+    if config.TYPE == 'train':
+        # set up logger
+        logger = logging.getLogger(__name__)
+        logger.setLevel(level = logging.INFO)
+        handler = logging.FileHandler(os.path.join(config.DATA.OUTPUT_DIR, "log.txt"))
+        handler.setLevel(logging.INFO)
+        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        print("training...")
 
-    # training
-    model = QFA(dataloader.Nb, dataloader.Nr, config.MODEL.NH, device=device)
-    if os.path.exists(config.MODEL.RESUME):
-        print(f"=> Resume from {config.MODEL.RESUME}")
+        # training
+        model = QFA(dataloader.Nb, dataloader.Nr, config.MODEL.NH, device=device, tau=partial(taufunc, which=config.MODEL.TAU))
+        if os.path.exists(config.MODEL.RESUME):
+            print(f"=> Resume from {config.MODEL.RESUME}")
+            model.load_from_npz(config.MODEL.RESUME)
+        scheduler = step_scheduler(config.TRAIN.DECAY_ALPHA, config.TRAIN.DECAY_STEP)
+        optimizer = Adam(params=model.parameters, learning_rate=config.TRAIN.LEARNING_RATE, device=device, scheduler=scheduler, weight_decay=config.TRAIN.WEIGHT_DECAY)
+        model.random_init_func()
+        model.train(optimizer, dataloader, config.TRAIN.NEPOCHS, config.DATA.OUTPUT_DIR, logger=logger)
+    elif config.TYPE == 'predict':
+        print(f"try to predict {len(dataloader)} spectra...")
+        model = QFA(dataloader.Nb, dataloader.Nr, config.MODEL.NH, device=device, tau=partial(taufunc, which=config.MODEL.TAU))
+        print(f'=> Resume from {config.MODEL.RESUME}')
         model.load_from_npz(config.MODEL.RESUME)
-    scheduler = step_scheduler(config.TRAIN.DECAY_ALPHA, config.TRAIN.DECAY_STEP)
-    optimizer = Adam(params=model.parameters, learning_rate=config.TRAIN.LEARNING_RATE, device=device, scheduler=scheduler, weight_decay=config.TRAIN.WEIGHT_DECAY)
-    model.random_init_func()
-    model.train(optimizer, dataloader, config.TRAIN.NEPOCHS, config.DATA.OUTPUT_DIR, logger=logger)
+        ts = time.time()
+        output_dir = os.path.join(config.DATA.OUTPUT_DIR, 'predict')
+        if not os.path.exists(output_dir):
+            os.mkdir(output_dir)
+        for (f, e, z, m, p) in tqdm(dataloader):
+            ll, hmean, hcov, cont, uncertainty = model.prediction_for_single_spectra(f, e, z, m)
+            result_dict = {key: eval(f'{key}.cpu().detach().numpy()') for key in ['ll', 'hmean', 'hcov', 'cont', 'uncertainty']}
+            name = os.path.basename(p)
+            np.savez(os.path.join(output_dir, name), **result_dict)
+        te = time.time()
+        print(f'Finish predicting {len(dataloader)} spectra in {te-ts} seconds...')
+    else:
+        raise NotImplementedError(f"Mode {config.TYPE} hasn't been implemented!")
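
Note: the eval-based comprehension in the predict loop works because the five names are in scope, but an explicit mapping is the more common way to write it (a sketch, not part of the commit):

    tensors = {'ll': ll, 'hmean': hmean, 'hcov': hcov, 'cont': cont, 'uncertainty': uncertainty}
    result_dict = {k: v.cpu().detach().numpy() for k, v in tensors.items()}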
