From 7363dc41f75c4eeab8ef4b4162cc877814afdb80 Mon Sep 17 00:00:00 2001
From: Anning Gao
Date: Sun, 29 Jan 2023 11:40:29 +0800
Subject: [PATCH 1/6] Add the catalog generation and part of the preprocessing.

---
 catalog.ipynb    | 179 +++++++++++++++++++++++++++++++++++++++++++++++
 preprocess.ipynb | 147 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 326 insertions(+)
 create mode 100644 catalog.ipynb
 create mode 100644 preprocess.ipynb

diff --git a/catalog.ipynb b/catalog.ipynb
new file mode 100644
index 0000000..33a91d5
--- /dev/null
+++ b/catalog.ipynb
@@ -0,0 +1,179 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Generate the catalogs"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This notebook generates the catalogs of the DESI data. Each catalog contains 10 columns:\n",
+    "\n",
+    "* ``file``: the file that the spectrum is stored in.\n",
+    "* ``id``: the ID of the spectrum.\n",
+    "* ``snr``: the signal-to-noise ratio of the spectrum.\n",
+    "* ``z_qso``: the redshift of the quasar.\n",
+    "* ``LyBETA, LyALPHA, MgII1, CIV1, MgII2, CIV2``: whether each emission line falls inside the rest-frame wavelength coverage of the spectrum."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "from tqdm import tqdm\n",
+    "from dla_cnn.desi.DesiMock import DesiMock"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# prepare the rest-frame wavelengths of some important emission lines\n",
+    "\n",
+    "LyALPHA = 1215.6701\n",
+    "LyBETA = 1025.7220\n",
+    "MgII1 = 1482.890\n",
+    "MgII2 = 1737.628\n",
+    "CIV1 = 1550\n",
+    "CIV2 = 1910\n",
+    "lams = np.array([LyBETA, LyALPHA, MgII1, CIV1, MgII2, CIV2])\n",
+    "names = ['LyBETA', 'LyALPHA', 'MgII1', 'CIV1', 'MgII2', 'CIV2']\n",
+    "lines = {}\n",
+    "for i, name in enumerate(names):\n",
+    "    lines[name] = lams[i]\n",
+    "print(lines)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# prepare the data paths\n",
+    "\n",
+    "prefix = './desi-0.2-100/spectra-16/' # this needs to be set to your own data directory\n",
+    "suffix = {}\n",
+    "for preid in os.listdir(prefix):\n",
+    "    suffix[preid] = os.listdir(prefix+preid)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# generate a catalog (csv format) under each folder\n",
+    "\n",
+    "for suffix1 in tqdm(suffix.keys()):\n",
+    "    for suffix2 in tqdm(suffix[suffix1]):\n",
+    "        path = prefix + suffix1 + '/' + suffix2 + '/'\n",
+    "        if len(os.listdir(path)) == 3:\n",
+    "            path_spectra = path + 'spectra-16-' + suffix2 + '.fits'\n",
+    "            path_truth = path + 'truth-16-' + suffix2 + '.fits'\n",
+    "            path_zbest = path + 'zbest-16-' + suffix2 + '.fits'\n",
+    "            data = DesiMock()\n",
+    "            data.read_fits_file(path_spectra, path_truth, path_zbest)\n",
+    "            total = pd.DataFrame()\n",
+    "            for id in data.data:\n",
+    "                sline = data.get_sightline(id=id)\n",
+    "                wav_max, wav_min = 10**np.max(sline.loglam - np.log10(1+sline.z_qso)), 10**np.min(sline.loglam - np.log10(1+sline.z_qso))\n",
+    "                info = pd.DataFrame()\n",
+    "                info['id'] = np.ones(1, dtype='i8') * int(id)\n",
+    "                info['z_qso'] = np.ones(1) * sline.z_qso\n",
+    "                info['snr'] = np.ones(1) * sline.s2n\n",
+    "                for name in names:\n",
+    "                    info[name] = [wav_min <= lines[name] <= wav_max]\n",
+    "                total = pd.concat([total, info])\n",
+    "            total['file'] = np.ones(len(total), dtype='i8') * int(suffix2)\n",
+    "            total = total[['file', 'id', 'z_qso', 'snr', 'LyBETA', 'LyALPHA', 'MgII1', 'CIV1', 'MgII2', 'CIV2']]\n",
+    "            total.to_csv(prefix + suffix1 + '/' + suffix2 + '/catalog.csv', index=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# delete all the catalogs\n",
+    "\n",
+    "for suffix1 in suffix.keys():\n",
+    "    for suffix2 in suffix[suffix1]:\n",
+    "        path = prefix + suffix1 + '/' + suffix2 + '/'\n",
+    "        files = os.listdir(path)\n",
+    "        if len(files) == 4:\n",
+    "            for file in files:\n",
+    "                if '.csv' in file:\n",
+    "                    os.remove(path + file)\n",
+    "if 'catalog_total.csv' in os.listdir(prefix):\n",
+    "    os.remove(prefix+'catalog_total.csv')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# generate a total catalog\n",
+    "# this should be done AFTER the catalog of each folder has been generated\n",
+    "\n",
+    "catalog = pd.DataFrame()\n",
+    "for suffix1 in suffix.keys():\n",
+    "    for suffix2 in suffix[suffix1]:\n",
+    "        path = prefix + suffix1 + '/' + suffix2 + '/'\n",
+    "        files = os.listdir(path)\n",
+    "        if len(files) == 4:\n",
+    "            for file in files:\n",
+    "                if '.csv' in file:\n",
+    "                    this = pd.read_csv(path+file)\n",
+    "                    catalog = pd.concat([catalog, this])\n",
+    "\n",
+    "catalog.to_csv(prefix+'catalog_total.csv', index=False)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "base",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.15 (main, Nov 24 2022, 14:31:59) \n[GCC 11.2.0]"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "5822ad6f7f6b9f3b64cbf7e0583dcb95803179c963db37d0bbfcb06c41bbd518"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
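The coverage columns written above come down to a single rest-frame range test per sightline. A minimal, self-contained sketch of that test (not part of the patch; the wavelength grid and quasar redshift below are toy values rather than DESI data):

```python
import numpy as np

# Rest-frame wavelengths (Angstrom) of a few of the lines tracked in the catalog.
lines = {'LyBETA': 1025.7220, 'LyALPHA': 1215.6701, 'CIV1': 1550.0}

loglam = np.log10(np.linspace(3600.0, 9800.0, 4000))  # toy observed-frame grid
z_qso = 2.8                                           # toy quasar redshift

# Shift the observed grid into the rest frame and take its limits,
# exactly as the notebook cell does for each sightline.
wav_min = 10**np.min(loglam - np.log10(1 + z_qso))
wav_max = 10**np.max(loglam - np.log10(1 + z_qso))

coverage = {name: wav_min <= lam <= wav_max for name, lam in lines.items()}
print(coverage)  # here: {'LyBETA': True, 'LyALPHA': True, 'CIV1': True}
```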
diff --git a/preprocess.ipynb b/preprocess.ipynb
new file mode 100644
index 0000000..72391cd
--- /dev/null
+++ b/preprocess.ipynb
@@ -0,0 +1,147 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Preprocessing of DESI spectrum data\n",
+    "\n",
+    "This notebook preprocesses the DESI spectrum data. We make several changes to the original data:\n",
+    "\n",
+    "* Rebin the spectrum in the rest frame.\n",
+    "* Clip abnormal points from the spectrum (3-sigma clipping).\n",
+    "* Normalize.\n",
+    "\n",
+    "**This notebook is not complete.**"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 114,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dla_cnn.desi.DesiMock import DesiMock\n",
+    "import numpy as np\n",
+    "import pandas as pd\n",
+    "import matplotlib.pyplot as plt\n",
+    "from astropy.io import fits\n",
+    "from scipy.interpolate import interp1d\n",
+    "from matplotlib.backends.backend_pdf import PdfPages\n",
+    "from tqdm import tqdm\n",
+    "from astropy.stats import sigma_clip\n",
+    "from scipy.optimize import curve_fit"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 76,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_split_point(data:DesiMock, id):\n",
+    "    line_b = data.get_sightline(id=id, camera='b')\n",
+    "    line_z = data.get_sightline(id=id, camera='z')\n",
+    "    line_r = data.get_sightline(id=id, camera='r')\n",
+    "    split_loglam_br = np.average([np.max(line_b.loglam), np.min(line_r.loglam)])\n",
+    "    split_loglam_rz = np.average([np.min(line_z.loglam), np.max(line_r.loglam)])\n",
+    "    return split_loglam_br, split_loglam_rz\n",
+    "\n",
+    "def get_between(array, max, min, maxif=False, minif=False):\n",
+    "    if maxif:\n",
+    "        if minif:\n",
+    "            return np.intersect1d(np.where(array>=min)[0], np.where(array<=max)[0])\n",
+    "        else:\n",
+    "            return np.intersect1d(np.where(array>min)[0], np.where(array<=max)[0])\n",
+    "    else:\n",
+    "        if minif:\n",
+    "            return np.intersect1d(np.where(array>=min)[0], np.where(array<max)[0])\n",
+    "        else:\n",
+    "            return np.intersect1d(np.where(array>min)[0], np.where(array<max)[0])"
+   ]
+  },

From: Zechang Sun
Date: Wed, 22 Feb 2023 17:48:37 +0800
Subject: [PATCH 2/6] update for more flexible configuration

---
 QFA/config.py     |  3 +++
 QFA/dataloader.py |  9 ++++++---
 QFA/model.py      | 11 ++++++++---
 QFA/utils.py      | 50 ++++++++++++++++++++++++++++++++++++++++++++++-
 4 files changed, 66 insertions(+), 7 deletions(-)

diff --git a/QFA/config.py b/QFA/config.py
index 67550b4..28d7e20 100644
--- a/QFA/config.py
+++ b/QFA/config.py
@@ -46,6 +46,7 @@
 #------------------
 _C.MODEL = CN()
 _C.MODEL.NH = 8
+_C.MODEL.TAU = 'becker'
 _C.MODEL.RESUME = ''
@@ -131,6 +132,8 @@ def _check_args(name):
         config.DATA.NPROCS = args.nprocs
     if _check_args('validation'):
         config.DATA.VALIDATION = args.validation
+    if _check_args('tau'):
+        config.MODEL.TAU = args.tau
     config.freeze()
diff --git a/QFA/dataloader.py b/QFA/dataloader.py
index b50e52d..cc63829 100644
--- a/QFA/dataloader.py
+++ b/QFA/dataloader.py
@@ -28,10 +28,13 @@ def _read_npz_file(path: str)->Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
     return flux, error, mask, z
-def _read_from_catalog(flux, error, mask, zqso, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs):
+def _read_from_catalog(flux, error, mask, zqso, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs, output_dir, prefix='train'):
     catalog = pd.read_csv(catalog)
     criteria = (catalog['snr']>=snr_min) & (catalog['snr']<=snr_max) & (catalog['z']>=z_min) & (catalog['z']<=z_max) & (catalog['num_mask']<=num_mask)
     files = np.random.choice(catalog['file'][criteria].values, size=(num,), replace=(np.sum(criteria)<num))
         print("=> Load Data...")
         _read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.CATALOG,
             config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
-            config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS)
+            config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'train')
         if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR)
and config.DATA.VALIDATION: print("=> Load Validation Data...") _read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.VALIDATION_CATALOG, config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, - config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS) + config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'validation') self.flux = np.array(self.flux) diff --git a/QFA/model.py b/QFA/model.py index 0bebacf..c153681 100644 --- a/QFA/model.py +++ b/QFA/model.py @@ -11,12 +11,14 @@ import os import numpy as np from .utils import MatrixInverse, MatrixLogDet, tauHI, omega_func -from .utils import tau as default_tau +from .utils import tau as taufunc +from functools import partial from torch.nn import functional as F log2pi = 1.8378770664093453 +default_tau = partial(taufunc, which='becker') class QFA(object): @@ -129,8 +131,8 @@ def loglikelihood_and_gradient_for_single_spectra(self, delta: torch.Tensor, err diag = Psi + omega + masked_error*masked_error invSigma = MatrixInverse(F, diag, self.device) logDet = MatrixLogDet(F, diag, self.device) - loglikelihood = 0.5*(masked_delta.T @ invSigma @ masked_delta + Npix * log2pi + logDet) - masked_delta = masked_delta.reshape((-1, 1)) + masked_delta = masked_delta[:, None] + loglikelihood = 0.5*(masked_delta.mT @ invSigma @ masked_delta + Npix * log2pi + logDet) partialSigma = 0.5*(invSigma-invSigma@masked_delta@masked_delta.T@invSigma) partialF = 2*diagA@partialSigma@diagA@F diagPartialSigma = torch.diag(partialSigma) @@ -193,6 +195,9 @@ def train(self, optimizer, dataloader, n_epochs, output_dir="./result", save_int Returns: None """ + if not os.path.exists(output_dir): + os.mkdir(output_dir) + output_dir = os.path.join(output_dir, 'checkpoints') if not os.path.exists(output_dir): os.mkdir(output_dir) self.mu = torch.tensor(dataloader.mu, dtype=torch.float32).to(self.device) diff --git a/QFA/utils.py b/QFA/utils.py index 82513fe..98732d6 100644 --- a/QFA/utils.py +++ b/QFA/utils.py @@ -92,7 +92,7 @@ def omega_func(z: torch.Tensor, tau0: torch.float32, beta: torch.float32, c0: to return root*root -def tau(z: torch.Tensor)->torch.Tensor: +def _tau_becker(z: torch.Tensor)->torch.Tensor: """ mean optical depth measured by Becker et al. 2012 [https://arxiv.org/abs/1208.2584] ---------------------------------------------------------------------- @@ -106,6 +106,54 @@ def tau(z: torch.Tensor)->torch.Tensor: return tau0 * ((1+z)/(1+z0)) ** beta + C +def _tau_fg(z: torch.Tensor)->torch.Tensor: + """ + mean optical depth measured by Faucher Giguere et al. 2008 [https://iopscience.iop.org/article/10.1086/588648] + ------------------------------------------------------------------------- + Args: + z (torch.Tensor (shape=(N, ), dtype=torch.float32)): redshift array + + Returns: + effective optical depth: (torch.Tensor (shape=(N, ), dtype=torch.float32)) + """ + tau0, beta = 0.0018, 3.92 + return tau0 * (1+z) ** beta + + +def _tau_kamble(z: torch.Tensor)->torch.Tensor: + """ + mean optical depth measured by Kamble et al. 
2020 [https://ui.adsabs.harvard.edu/abs/2020ApJ...892...70K/abstract]
+    ------------------------------------------------------------------------------------------
+    Args:
+        z (torch.Tensor (shape=(N, ), dtype=torch.float32)): redshift array
+
+    Returns:
+        effective optical depth: (torch.Tensor (shape=(N, ), dtype=torch.float32))
+    """
+    tau0, beta = 5.54*1e-3, 3.182
+    return tau0 * (1+z) ** beta
+
+
+def tau(z: torch.Tensor, which: Optional[str]='becker') -> torch.Tensor:
+    """
+    mean optical depth function
+    ---------------------------------------------------
+    Args:
+        z (torch.Tensor (shape=(N, ), dtype=torch.float32)): redshift array
+        which (str): which measurement to use ['becker', 'fg', 'kamble']
+    Returns:
+        effective optical depth: (torch.Tensor (shape=(N, ), dtype=torch.float32))
+    """
+    if which == 'becker':
+        return _tau_becker(z)
+    elif which == 'fg':
+        return _tau_fg(z)
+    elif which == 'kamble':
+        return _tau_kamble(z)
+    else:
+        raise NotImplementedError("currently available mean optical depth function: ['becker', 'fg', 'kamble']")
+
+
 def smooth(s: np.ndarray, window_len: Optional[int]=32):
     """Smooth curve s with corresponding window length
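The `which` switch added to `tau` above is consumed elsewhere in this series through `functools.partial` (this patch binds `default_tau = partial(taufunc, which='becker')` in `QFA/model.py`; the next patch does the same in the dataloader). A small usage sketch, assuming the patched `QFA` package is on the import path; the redshift values are illustrative only:

```python
import torch
from functools import partial
from QFA.utils import tau as taufunc

z = torch.tensor([2.0, 2.5, 3.0])  # illustrative absorber redshifts

# Bind the measurement choice once, then use the result as a plain tau(z) callable.
tau_becker = partial(taufunc, which='becker')  # Becker et al. 2012
tau_fg = partial(taufunc, which='fg')          # Faucher-Giguere et al. 2008

print(tau_becker(z))
print(tau_fg(z))
# taufunc(z, which='unknown')  # would raise NotImplementedError
```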
From 16ade105b5da7621e43f97e107ff2280857d5a0d Mon Sep 17 00:00:00 2001
From: Zechang Sun
Date: Tue, 28 Feb 2023 22:08:06 +0800
Subject: [PATCH 3/6] update for better user-interface

---
 QFA/config.py       |  4 +++
 QFA/dataloader.py   | 81 ++++++++++++++++++++++++++++++++------------
 QFA/model.py        |  7 ++--
 QFA/utils.py        |  9 +++++
 train.py => main.py | 69 ++++++++++++++++++++++++++++----------
 5 files changed, 127 insertions(+), 43 deletions(-)
 rename train.py => main.py (53%)

diff --git a/QFA/config.py b/QFA/config.py
index 28d7e20..fdac277 100644
--- a/QFA/config.py
+++ b/QFA/config.py
@@ -15,6 +15,7 @@
 # Base config files
 _C.BASE = ['']
+_C.TYPE = "train"
 _C.GPU = 0
 #------------------
@@ -62,6 +63,7 @@
 _C.TRAIN.WINDOW_LENGTH_FOR_MU = 16
+
 def _update_config_from_file(config, cfg_file):
     config.defrost()
     with open(cfg_file, 'r') as f:
@@ -134,6 +136,8 @@ def _check_args(name):
         config.DATA.VALIDATION = args.validation
     if _check_args('tau'):
         config.MODEL.TAU = args.tau
+    if _check_args('type'):
+        config.TYPE = args.type
     config.freeze()
diff --git a/QFA/dataloader.py b/QFA/dataloader.py
index cc63829..ae7a32d 100644
--- a/QFA/dataloader.py
+++ b/QFA/dataloader.py
@@ -2,11 +2,13 @@
 import torch
 import numpy as np
 import pandas as pd
+import warnings
 from tqdm import tqdm
 import multiprocessing
 from .utils import smooth
-from .utils import tau as default_tau
-from typing import Tuple, Callable
+from .utils import tau as taufunc
+from typing import Tuple, Callable, List
+from functools import partial
 from yacs.config import CfgNode as CN
@@ -15,7 +17,7 @@
 def _read_npz_file(path: str)->Tuple[np.ndarray, np.ndarray, np.ndarray, float]:
     """
-    load spectra from npz files
+    load spectra from npz file
     NOTE: (1) all spectra should have same wavelength grid
           (2) spectra are preprocessed as in the paper
     flux, error, z = file['flux'], file['error'], float(file['z'])
     mask = (flux!=-999.)&(error!=-999.)
     file.close()
-    return flux, error, mask, z
+    return flux, error, mask, z, path
+
+
+def _read_npz_files(flux: List[np.ndarray], error: List[np.ndarray], mask: List[np.ndarray], zqso: List[float], pathlist: List[str], paths: List[str], nprocs: int)->None:
+    """
+    load spectra from npz files
+    """
+    with multiprocessing.Pool(nprocs) as p:
+        data = p.map(_read_npz_file, paths)
+    for f, e, m, z, p in tqdm(data):
+        flux.append(f)
+        error.append(e)
+        mask.append(m)
+        zqso.append(z)
+        pathlist.append(p)
+
+
-def _read_from_catalog(flux, error, mask, zqso, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs, output_dir, prefix='train'):
+def _read_from_catalog(flux, error, mask, zqso, pathlist, catalog, data_dir, num, snr_min, snr_max, z_min, z_max, num_mask, nprocs, output_dir, prefix='train'):
     catalog = pd.read_csv(catalog)
     criteria = (catalog['snr']>=snr_min) & (catalog['snr']<=snr_max) & (catalog['z']>=z_min) & (catalog['z']<=z_max) & (catalog['num_mask']<=num_mask)
     files = np.random.choice(catalog['file'][criteria].values, size=(num,), replace=(np.sum(criteria)<num))
+        if self.type == 'train':
+            print("=> Load Data...")
+            _read_from_catalog(self.flux, self.error, self.mask, self.zqso, self.pathlist, config.DATA.CATALOG,
+                config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
+                config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'train')
+
+            if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR) and config.DATA.VALIDATION:
+                print("=> Load Validation Data...")
+                _read_from_catalog(self.flux, self.error, self.mask, self.zqso, self.pathlist, config.DATA.VALIDATION_CATALOG,
+                    config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX,
+                    config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'validation')
-        print("=> Load Data...")
-        _read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.CATALOG,
-            config.DATA.DATA_DIR, config.DATA.DATA_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX, config.DATA.Z_MIN,
-            config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'train')
+        elif self.type == 'predict':
+            print("=> Load Data...")
+            paths = pd.read_csv(config.DATA.CATALOG).values.squeeze()
+            paths = list(map(lambda x: os.path.join(config.DATA.DATA_DIR, x), paths))
+            _read_npz_files(self.flux, self.error, self.mask, self.zqso, self.pathlist, paths, config.DATA.NPROCS)
-        if os.path.exists(config.DATA.VALIDATION_CATALOG) and os.path.exists(config.DATA.VALIDATION_DIR) and config.DATA.VALIDATION:
-            print("=> Load Validation Data...")
-            _read_from_catalog(self.flux, self.error, self.mask, self.zqso, config.DATA.VALIDATION_CATALOG,
-                config.DATA.VALIDATION_DIR, config.DATA.VALIDATION_NUM, config.DATA.SNR_MIN, config.DATA.SNR_MAX,
-                config.DATA.Z_MIN, config.DATA.Z_MAX, config.DATA.NUM_MASK, config.DATA.NPROCS, config.DATA.OUTPUT_DIR, 'validation')
+        else:
+            raise NotImplementedError("TYPE should be in ['train', 'predict']!")
         self.flux = np.array(self.flux)
         self.error = np.array(self.error)
         self.zqso = np.array(self.zqso)
         self.mask = np.array(self.mask)
+        self.pathlist = np.array(self.pathlist)
         self.zabs = (self.zqso + 1).reshape(-1, 1)*self.wav_grid[:self.Nb]/1215.67 - 1
         self.cur = 0
         self._device = None
-        self._tau = default_tau
-        self.validation_dir = None
+        self._tau = partial(taufunc, which=config.MODEL.TAU)
         self.data_size = self.flux.shape[0]
         s = np.hstack((np.exp(1*self._tau(self.zabs)), np.ones((self.data_size, self.Nr), dtype=float)))
@@ -96,6 +118,7 @@ def have_next_batch(self):
         Returns:
             sig (bool): whether this dataloader have next batch
         """
+        if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
         return self.cur < self.data_size
@@ -105,6 +128,7 @@ def next_batch(self):
         Returns:
             delta, error, redshift, mask (torch.tensor): batch data
         """
+        if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
         start = self.cur
         end = self.cur + self.batch_size if self.cur + self.batch_size < self.data_size else self.data_size
         self.cur = end
@@ -120,6 +144,7 @@ def sample(self):
         Returns:
             delta, error, redshift, mask (torch.tensor): sampled data
         """
+        if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
         sig = np.random.randint(0, self.data_size, size=(self.batch_size, ))
         s = np.hstack((np.exp(-1.*self._tau(self.zabs[sig])), np.ones((self.batch_size, self.Nr), dtype=float)))
         return torch.tensor(self.flux[sig]-self._mu*s, dtype=torch.float32).to(self._device),\
@@ -130,6 +155,7 @@ def rewind(self):
         """
         shuffle all the data and reset the dataloader
         """
+        if self.type == 'predict': warnings.warn('dataloader is in predict mode...')
         idx = np.arange(self.data_size)
         np.random.shuffle(idx)
         self.cur = 0
@@ -138,6 +164,7 @@
         self.zqso = self.zqso[idx]
         self.zabs = self.zabs[idx]
         self.mask = self.mask[idx]
+        self.pathlist = self.pathlist[idx]
 def set_tau(self, tau:Callable[[torch.tensor, ], torch.tensor])->None:
     """
@@ -151,6 +178,14 @@ def set_device(self, device: torch.device)->None:
         """
         self._device = device
+    def __len__(self):
+        return len(self.flux)
+
+    def __getitem__(self, idx):
+        return torch.tensor(self.flux[idx], dtype=torch.float32).to(self._device),\
+            torch.tensor(self.error[idx], dtype=torch.float32).to(self._device), torch.tensor(self.zabs[idx], dtype=torch.float32).to(self._device), \
+            torch.tensor(self.mask[idx], dtype=bool).to(self._device), self.pathlist[idx]
+
     @property
     def mu(self):
         return self._mu
\ No newline at end of file
diff --git a/QFA/model.py b/QFA/model.py
index c153681..ae0f0f1 100644
--- a/QFA/model.py
+++ b/QFA/model.py
@@ -133,7 +133,7 @@ def loglikelihood_and_gradient_for_single_spectra(self, delta: torch.Tensor, err
         logDet = MatrixLogDet(F, diag, self.device)
         masked_delta = masked_delta[:, None]
         loglikelihood = 0.5*(masked_delta.mT @ invSigma @ masked_delta + Npix * log2pi + logDet)
-        partialSigma = 0.5*(invSigma-invSigma@masked_delta@masked_delta.T@invSigma)
+        partialSigma = 0.5*(invSigma-invSigma@masked_delta@masked_delta.mT@invSigma)
         partialF = 2*diagA@partialSigma@diagA@F
         diagPartialSigma = torch.diag(partialSigma)
         partialPsi = A*diagPartialSigma*A
@@ -172,11 +172,12 @@ def prediction_for_single_spectra(self, flux: torch.Tensor, error: torch.Tensor,
         diag = Psi + omega + masked_error*masked_error
         invSigma = MatrixInverse(F, diag, self.device)
         logDet = MatrixLogDet(F, diag, self.device)
-        loglikelihood = 0.5*(masked_delta.T @ invSigma @ masked_delta + Npix * log2pi + logDet)
+        masked_delta = masked_delta[:, None]
+        loglikelihood = 0.5*(masked_delta.mT @ invSigma @ masked_delta + Npix * log2pi + logDet)
         Sigma_e = torch.diag(1./diag)
         hcov = torch.linalg.inv(torch.eye(self.Nh, dtype=torch.float).to(self.device) + F.T@Sigma_e@F)
         hmean = hcov@F.T@Sigma_e@masked_delta
-        return loglikelihood, hmean, hcov, self.F@hmean + self.mu, torch.diag(self.F@hcov@self.F.T)**0.5
+        return loglikelihood, hmean, hcov,
(self.F@hmean).squeeze() + self.mu, torch.diag(self.F@hcov@self.F.T)**0.5 def train(self, optimizer, dataloader, n_epochs, output_dir="./result", save_interval=5, smooth_interval=5, quiet=False, logger=None): diff --git a/QFA/utils.py b/QFA/utils.py index 98732d6..0d6a185 100644 --- a/QFA/utils.py +++ b/QFA/utils.py @@ -134,6 +134,13 @@ def _tau_kamble(z: torch.Tensor)->torch.Tensor: return tau0 * (1+z) ** beta +def _tau_mock(z: torch.Tensor)->torch.Tensor: + """ + mean optical depth from mock literature, Bautista et al. 2015 [https://iopscience.iop.org/article/10.1088/1475-7516/2015/05/060] + """ + return 0.2231435513142097*((1+z)/3.25)**3.2 + + def tau(z: torch.Tensor, which: Optional[str]='becker') -> torch.Tensor: """ mean optical depth function @@ -150,6 +157,8 @@ def tau(z: torch.Tensor, which: Optional[str]='becker') -> torch.Tensor: return _tau_fg(z) elif which == 'kamble': return _tau_kamble(z) + elif which == 'mock': + return _tau_mock(z) else: raise NotImplementedError("currently available mean optical depth function: ['becker', 'fg', 'kamble']") diff --git a/train.py b/main.py similarity index 53% rename from train.py rename to main.py index e7d3837..5436196 100644 --- a/train.py +++ b/main.py @@ -1,22 +1,29 @@ import argparse +from functools import partial from QFA.model import QFA from QFA.dataloader import Dataloader +from QFA.utils import tau as taufunc from QFA.optimizer import Adam, step_scheduler import logging import os +import numpy as np import torch +import time +from tqdm import tqdm from QFA.config import get_config parser = argparse.ArgumentParser() parser.add_argument("--cfg", type=str, required=False, help='configuration file') parser.add_argument("--catalog", type=str, required=False, help="csv file which records the meta info for each spectra, see the example 'catalog.csv'") +parser.add_argument("--type", type=str, required=False, help='which mode to use [train or predict]...') parser.add_argument("--data_num", type=int, required=False, help="number of quasar spectra for model training") parser.add_argument("--validation_catalog", type=str, required=False, help="csv file which records the meta info for each spectra, see the example 'catalog.csv'") parser.add_argument("--validation_num", type=int, required=False, help="number of quasar spectra for validation set") parser.add_argument("--batch_size", type=int, required=False, help='batch size for model training') parser.add_argument("--n_epochs", type=int, required=False, help='number of training epochs') parser.add_argument("--Nh", type=int, required=False, help="number of hidden variables") +parser.add_argument("--tau", type=str, required=False, help='mean optical depth function') parser.add_argument("--learning_rate", type=float, required=False, help="model learning rate (suggestion: learning_rate should be lower than 1e-2)") parser.add_argument("--gpu", type=int, required=False, help="specify the GPU number for model training") parser.add_argument("--snr_min", type=float, required=False, help="the lowest signal-to-noise ratio in the training set") @@ -38,30 +45,58 @@ if __name__ == "__main__": + # save config + if not os.path.exists(config.DATA.OUTPUT_DIR): + os.mkdir(config.DATA.OUTPUT_DIR) + + with open(os.path.join(config.DATA.OUTPUT_DIR, 'config.yaml'), 'w') as f: + f.write(config.dump()) + # gpu settings os.environ['CUDA_VISIBLE_DEVICES'] = str(config.GPU) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + assert config.TYPE in ['train', 'predict'], "TYPE must be in ['train', 
'predict']!"
+
     #load data
     dataloader = Dataloader(config)
     dataloader.set_device(device)
-    # set up logger
-    logger = logging.getLogger(__name__)
-    logger.setLevel(level = logging.INFO)
-    handler = logging.FileHandler(os.path.join(config.DATA.OUTPUT_DIR, "log.txt"))
-    handler.setLevel(logging.INFO)
-    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
-    handler.setFormatter(formatter)
-    logger.addHandler(handler)
-    print("training...")
+    if config.TYPE == 'train':
+        # set up logger
+        logger = logging.getLogger(__name__)
+        logger.setLevel(level = logging.INFO)
+        handler = logging.FileHandler(os.path.join(config.DATA.OUTPUT_DIR, "log.txt"))
+        handler.setLevel(logging.INFO)
+        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        print("training...")
-    # training
-    model = QFA(dataloader.Nb, dataloader.Nr, config.MODEL.NH, device=device)
-    if os.path.exists(config.MODEL.RESUME):
-        print(f"=> Resume from {config.MODEL.RESUME}")
+        # training
+        model = QFA(dataloader.Nb, dataloader.Nr, config.MODEL.NH, device=device, tau=partial(taufunc, which=config.MODEL.TAU))
+        if os.path.exists(config.MODEL.RESUME):
+            print(f"=> Resume from {config.MODEL.RESUME}")
+            model.load_from_npz(config.MODEL.RESUME)
+        scheduler = step_scheduler(config.TRAIN.DECAY_ALPHA, config.TRAIN.DECAY_STEP)
+        optimizer = Adam(params=model.parameters, learning_rate=config.TRAIN.LEARNING_RATE, device=device, scheduler=scheduler, weight_decay=config.TRAIN.WEIGHT_DECAY)
+        model.random_init_func()
+        model.train(optimizer, dataloader, config.TRAIN.NEPOCHS, config.DATA.OUTPUT_DIR, logger=logger)
+    elif config.TYPE == 'predict':
+        print(f"try to predict {len(dataloader)} spectra...")
+        model = QFA(dataloader.Nb, dataloader.Nr, config.MODEL.NH, device=device, tau=partial(taufunc, which=config.MODEL.TAU))
+        print(f'=> Resume from {config.MODEL.RESUME}')
         model.load_from_npz(config.MODEL.RESUME)
-    scheduler = step_scheduler(config.TRAIN.DECAY_ALPHA, config.TRAIN.DECAY_STEP)
-    optimizer = Adam(params=model.parameters, learning_rate=config.TRAIN.LEARNING_RATE, device=device, scheduler=scheduler, weight_decay=config.TRAIN.WEIGHT_DECAY)
-    model.random_init_func()
-    model.train(optimizer, dataloader, config.TRAIN.NEPOCHS, config.DATA.OUTPUT_DIR, logger=logger)
+        ts = time.time()
+        output_dir = os.path.join(config.DATA.OUTPUT_DIR, 'predict')
+        if not os.path.exists(output_dir):
+            os.mkdir(output_dir)
+        for (f, e, z, m, p) in tqdm(dataloader):
+            ll, hmean, hcov, cont, uncertainty = model.prediction_for_single_spectra(f, e, z, m)
+            result_dict = {'ll': ll, 'hmean': hmean, 'hcov': hcov, 'cont': cont, 'uncertainty': uncertainty}
+            result_dict = {key: value.cpu().detach().numpy() for key, value in result_dict.items()}
+            name = os.path.basename(p)
+            np.savez(os.path.join(output_dir, name), **result_dict)
+        te = time.time()
+        print(f'Finish predicting {len(dataloader)} spectra in {te-ts} seconds...')
+    else:
+        raise NotImplementedError(f"Mode {config.TYPE} hasn't been implemented!")
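The predict branch above writes one `.npz` file per input spectrum into a `predict/` subfolder of the output directory, keyed by `ll`, `hmean`, `hcov`, `cont`, and `uncertainty`. A sketch of reading one result back; the file path is a placeholder, since actual names come from the input catalog via `os.path.basename(p)`:

```python
import numpy as np

# 'result/predict/spec-0001.npz' is a placeholder path.
result = np.load('result/predict/spec-0001.npz')

ll = result['ll']              # log-likelihood of the fit
hmean = result['hmean']        # posterior mean of the hidden factors
hcov = result['hcov']          # posterior covariance of the hidden factors
cont = result['cont']          # predicted continuum, F @ hmean + mu
sigma = result['uncertainty']  # per-pixel continuum uncertainty
```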
From 2bf0fde60e2cb3670567f5c1c1cdef9e0328ef83 Mon Sep 17 00:00:00 2001
From: Zechang <60177057+ZechangSun@users.noreply.github.com>
Date: Mon, 3 Apr 2023 13:09:15 +0800
Subject: [PATCH 4/6] Update dataloader.py

fix bugs for mean vector estimation
---
 QFA/dataloader.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/QFA/dataloader.py b/QFA/dataloader.py
index ae7a32d..2c3b664 100644
--- a/QFA/dataloader.py
+++ b/QFA/dataloader.py
@@ -108,7 +108,7 @@ def __init__(self, config: CN):
         self.data_size = self.flux.shape[0]
         s = np.hstack((np.exp(1*self._tau(self.zabs)), np.ones((self.data_size, self.Nr), dtype=float)))
-        self._mu = np.sum(self.flux*s, axis=0)/np.sum(self.flux!=0., axis=0)
+        self._mu = np.sum(self.flux*s*self.mask, axis=0)/np.sum(self.flux!=-999., axis=0)
         self._mu = smooth(self._mu, window_len=config.TRAIN.WINDOW_LENGTH_FOR_MU)
@@ -188,4 +188,4 @@ def __getitem__(self, idx):
     @property
     def mu(self):
-        return self._mu
\ No newline at end of file
+        return self._mu

From 0ce58af9b55eba0b9372839cdf579648e1f665da Mon Sep 17 00:00:00 2001
From: Zechang <60177057+ZechangSun@users.noreply.github.com>
Date: Mon, 12 Jun 2023 11:37:10 +0800
Subject: [PATCH 5/6] Update README.md

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 79fb672..c3db19e 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,7 @@
 **An Unsupervised and Probabilistic Quasar Continuum Prediction Algorithm with Latent Factor Analysis**
+https://img.shields.io/github/license/zechangsun/QFA
 Welcome! This is the source code for *Quasar Factor Analysis*. 😉

From 8bfdc99ea364df2c18f021f329830ed915b7a3f6 Mon Sep 17 00:00:00 2001
From: Zechang <60177057+ZechangSun@users.noreply.github.com>
Date: Mon, 12 Jun 2023 11:40:01 +0800
Subject: [PATCH 6/6] Update README.md

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index c3db19e..73ac934 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 **An Unsupervised and Probabilistic Quasar Continuum Prediction Algorithm with Latent Factor Analysis**
-https://img.shields.io/github/license/zechangsun/QFA
+![](https://img.shields.io/github/license/zechangsun/QFA)
 Welcome! This is the source code for *Quasar Factor Analysis*. 😉
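As a closing note on PATCH 4: masked pixels carry the sentinel value `-999.`, so the old mean-vector estimate both let sentinels pollute the numerator and mis-counted pixels by testing against `0.` in the denominator. A toy numpy illustration (the real code also multiplies the numerator by `exp(tau)` to de-absorb the forest region, omitted here for clarity):

```python
import numpy as np

# Two toy spectra over three pixels; -999. marks a masked pixel.
flux = np.array([[1.0, -999.0,    2.0],
                 [3.0,    4.0, -999.0]])
mask = (flux != -999.)

# old: sentinel values enter the sum, and the count tests against 0.
mu_old = np.sum(flux, axis=0) / np.sum(flux != 0., axis=0)            # [2., -497.5, -498.5]
# fixed (PATCH 4): mask the numerator and count only unmasked pixels
mu_new = np.sum(flux * mask, axis=0) / np.sum(flux != -999., axis=0)  # [2., 4., 2.]
```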