import torch
import numpy as np
import pandas as pd
+ import warnings
from tqdm import tqdm
import multiprocessing
from .utils import smooth
- from .utils import tau as default_tau
- from typing import Tuple, Callable
+ from .utils import tau as taufunc
+ from typing import Tuple, Callable, List
+ from functools import partial
from yacs.config import CfgNode as CN


@@ ... @@

def _read_npz_file(path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, str]:
    """
-     load spectra from npz files
+     load spectra from npz file
    NOTE:
        (1) all spectra should share the same wavelength grid
        (2) spectra are preprocessed as in the paper
@@ -25,24 +27,32 @@ def _read_npz_file(path: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray, float, str]:
    flux, error, z = file['flux'], file['error'], float(file['z'])
    mask = (flux != -999.) & (error != -999.)
    file.close()
-     return flux, error, mask, z
+     return flux, error, mask, z, path


+ def _read_npz_files(flux: List[np.ndarray], error: List[np.ndarray], mask: List[np.ndarray], zqso: List[float], pathlist: List[str], paths: List[str], nprocs: int) -> None:
+     """
+     load spectra from npz files with a multiprocessing pool
+     """
+     with multiprocessing.Pool(nprocs) as pool:
+         data = pool.map(_read_npz_file, paths)
+     for f, e, m, z, p in tqdm(data):
+         flux.append(f)
+         error.append(e)
+         mask.append(m)
+         zqso.append(z)
+         pathlist.append(p)
+
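
A quick sketch of an input file these helpers can consume. The `flux`/`error`/`z` keys and the `-999.` sentinel come from `_read_npz_file` above; the grid size here is illustrative:

```python
# Illustrative only: build a spectrum file that _read_npz_file can load.
import numpy as np

flux = np.random.rand(4000)   # preprocessed flux on the shared wavelength grid
error = np.full(4000, 0.1)    # per-pixel uncertainty
flux[:50] = -999.             # masked pixels carry the -999. sentinel
np.savez('spec-example.npz', flux=flux, error=error, z=2.8)
```

`_read_npz_file('spec-example.npz')` then returns `(flux, error, mask, z, path)`, and `_read_npz_files` fans the same call out over a process pool and appends the results to the caller's lists.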

- def _read_from_catalog(flux, error, mask, zqso, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs, output_dir, prefix='train'):
+ def _read_from_catalog(flux, error, mask, zqso, pathlist, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs, output_dir, prefix='train'):
    catalog = pd.read_csv(catalog)
    criteria = (catalog['snr'] >= snr_min) & (catalog['snr'] <= snr_max) & (catalog['z'] >= z_min) & (catalog['z'] <= z_max) & (catalog['num_mask'] <= num_mask)
    files = np.random.choice(catalog['file'][criteria].values, size=(num,), replace=(np.sum(criteria) < num))
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)
    pd.Series(files).to_csv(os.path.join(output_dir, f'{prefix}-catalog.csv'), header=False, index=False)
    paths = [os.path.join(data_dir, x) for x in files]
-     with multiprocessing.Pool(nprocs) as p:
-         data = p.map(_read_npz_file, paths)
-     for f, e, m, z in tqdm(data):
-         flux.append(f)
-         error.append(e)
-         mask.append(m)
-         zqso.append(z)
+     _read_npz_files(flux, error, mask, zqso, pathlist, paths, nprocs)
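
For reference, the catalog CSV only needs the columns the filter above touches; hypothetical rows:

```python
# Hypothetical catalog rows with the columns _read_from_catalog filters on.
import pandas as pd

pd.DataFrame({
    'file': ['spec-0001.npz', 'spec-0002.npz'],  # npz filenames under data_dir
    'snr': [12.3, 4.1],                          # signal-to-noise ratio
    'z': [2.9, 3.4],                             # quasar redshift
    'num_mask': [25, 310],                       # number of masked pixels
}).to_csv('catalog.csv', index=False)
```

Note the `replace=(np.sum(criteria) < num)` argument: if fewer spectra pass the cuts than `num` requests, the selection silently falls back to sampling with replacement.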


class Dataloader(object):
@@ -51,6 +61,7 @@ def __init__(self, config: CN):
        self.wav_grid = 10 ** np.arange(np.log10(config.DATA.LAMMIN), np.log10(config.DATA.LAMMAX), config.DATA.LOGLAM_DELTA)
        self.Nb = np.sum(self.wav_grid < _lya_peak)
        self.Nr = len(self.wav_grid) - self.Nb
+         self.type = config.TYPE

        self.batch_size = config.DATA.BATCH_SIZE

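
A worked instance of the grid construction, with illustrative bounds (the 1215.67 Å Lyα peak matches the constant used for `zabs` below):

```python
# Illustrative numbers; the real bounds come from config.DATA.
import numpy as np

LAMMIN, LAMMAX, LOGLAM_DELTA = 1030.0, 1600.0, 1e-4
wav_grid = 10 ** np.arange(np.log10(LAMMIN), np.log10(LAMMAX), LOGLAM_DELTA)
Nb = int(np.sum(wav_grid < 1215.67))  # pixels bluewards of the Lya peak
Nr = len(wav_grid) - Nb               # pixels redwards
```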
@@ -59,30 +70,41 @@ def __init__(self, config: CN):
        self.mask = []
        self.zabs = []
        self.zqso = []
+         self.pathlist = []
+
+         if self.type == 'train':
+             print("=> Load Data...")
+             _read_from_catalog(self.flux, self.error, self.mask, self.zqso, self.pathlist, config.DATA.CATALOG,
+                                config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
+                                config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'train')
+
+             if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR) and config.DATA.VALIDATION:
+                 print("=> Load Validation Data...")
+                 _read_from_catalog(self.flux, self.error, self.mask, self.zqso, self.pathlist, config.DATA.VALIDATION_CATALOG,
+                                    config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX,
+                                    config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'validation')

-         print("=> Load Data...")
-         _read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.CATALOG,
-                            config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
-                            config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'train')
+         elif self.type == 'predict':
+             print("=> Load Data...")
+             paths = pd.read_csv(config.DATA.CATALOG).values.squeeze()
+             paths = list(map(lambda x: os.path.join(config.DATA.DATA_DIR, x), paths))
+             _read_npz_files(self.flux, self.error, self.mask, self.zqso, self.pathlist, paths, config.DATA.NPROCS)

-         if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR) and config.DATA.VALIDATION:
-             print("=> Load Validation Data...")
-             _read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.VALIDATION_CATALOG,
-                                config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX,
-                                config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'validation')
+         else:
+             raise NotImplementedError("TYPE should be in ['train', 'predict']!")


        self.flux = np.array(self.flux)
        self.error = np.array(self.error)
        self.zqso = np.array(self.zqso)
        self.mask = np.array(self.mask)
+         self.pathlist = np.array(self.pathlist)
        self.zabs = (self.zqso + 1).reshape(-1, 1) * self.wav_grid[:self.Nb] / 1215.67 - 1

        self.cur = 0
        self._device = None
-         self._tau = default_tau
-         self.validation_dir = None
+         self._tau = partial(taufunc, which=config.MODEL.TAU)
        self.data_size = self.flux.shape[0]

        s = np.hstack((np.exp(-1. * self._tau(self.zabs)), np.ones((self.data_size, self.Nr), dtype=float)))
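
Bluewards of the Lyα peak the mean continuum is attenuated by the mean forest transmission exp(-τ(z_abs)); redwards the factor is one, which is what the `hstack` above assembles. The τ model is now selected via `config.MODEL.TAU` through `functools.partial`. A sketch of that hook, using a hypothetical stand-in for `.utils.tau` (its real `which` options are not shown in this diff):

```python
# Stand-in for .utils.tau; the power law is illustrative, not necessarily
# what taufunc implements for any given `which` option.
import numpy as np
from functools import partial

def taufunc(z, which='power_law'):
    return 0.0018 * (1.0 + z) ** 3.92   # effective optical depth tau_eff(z)

tau = partial(taufunc, which='power_law')
print(np.exp(-tau(np.array([2.0, 3.0]))))  # suppression factors bluewards of Lya
```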
@@ -96,6 +118,7 @@ def have_next_batch(self):
        Returns:
            sig (bool): whether this dataloader has a next batch
        """
+         if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
        return self.cur < self.data_size

    def next_batch(self):
@@ -105,6 +128,7 @@ def next_batch(self):
        Returns:
            delta, error, redshift, mask (torch.tensor): batch data
        """
+         if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
        start = self.cur
        end = self.cur + self.batch_size if self.cur + self.batch_size < self.data_size else self.data_size
        self.cur = end
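
Taken together, these methods support a conventional epoch loop; a minimal sketch, assuming `config` is a yacs `CfgNode` built elsewhere:

```python
import torch

# Sketch of the intended consumption pattern.
loader = Dataloader(config)
loader.set_device(torch.device('cpu'))
while loader.have_next_batch():
    delta, error, zabs, mask = loader.next_batch()
    # ... one optimization step on this batch ...
loader.rewind()  # reshuffle and reset the cursor before the next epoch
```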
@@ -120,6 +144,7 @@ def sample(self):
        Returns:
            delta, error, redshift, mask (torch.tensor): sampled data
        """
+         if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
        sig = np.random.randint(0, self.data_size, size=(self.batch_size,))
        s = np.hstack((np.exp(-1. * self._tau(self.zabs[sig])), np.ones((self.batch_size, self.Nr), dtype=float)))
        return torch.tensor(self.flux[sig] - self._mu * s, dtype=torch.float32).to(self._device),\
@@ -130,6 +155,7 @@ def rewind(self):
        """
        shuffle all the data and reset the dataloader
        """
+         if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
        idx = np.arange(self.data_size)
        np.random.shuffle(idx)
        self.cur = 0
@@ -138,6 +164,7 @@ def rewind(self):
        self.zqso = self.zqso[idx]
        self.zabs = self.zabs[idx]
        self.mask = self.mask[idx]
+         self.pathlist = self.pathlist[idx]

    def set_tau(self, tau: Callable[[torch.Tensor], torch.Tensor]) -> None:
        """
@@ -151,6 +178,14 @@ def set_device(self, device: torch.device) -> None:
        """
        self._device = device

+     def __len__(self):
+         return len(self.flux)
+
+     def __getitem__(self, idx):
+         return torch.tensor(self.flux[idx], dtype=torch.float32).to(self._device),\
+             torch.tensor(self.error[idx], dtype=torch.float32).to(self._device), torch.tensor(self.zabs[idx], dtype=torch.float32).to(self._device),\
+             torch.tensor(self.mask[idx], dtype=torch.bool).to(self._device), self.pathlist[idx]
+
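
The new `__len__`/`__getitem__` pair makes the loader indexable, which is the natural access path in predict mode since each item now carries its source path. Continuing the sketch above:

```python
# Index-style access; __getitem__ returns five items per spectrum.
flux0, error0, zabs0, mask0, path0 = loader[0]
print(len(loader), path0)  # dataset size and the npz file this spectrum came from
```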
    @property
    def mu(self):
        return self._mu