diff --git a/src/toast/instrument.py b/src/toast/instrument.py index 19411b404..a0d7f6c06 100644 --- a/src/toast/instrument.py +++ b/src/toast/instrument.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2023 by the parties listed in the AUTHORS file. +# Copyright (c) 2019-2024 by the parties listed in the AUTHORS file. # All rights reserved. Use of this source code is governed by # a BSD-style license that can be found in the LICENSE file. @@ -267,6 +267,7 @@ def _position_velocity(self, times): vel_z = np.zeros(n_sparse, np.float64) for i, t in enumerate(sparse_times): atime = astime.Time(t, format="unix") + # Get the satellite position and velocity in the equatorial frame (ICRS) p, v = coord.get_body_barycentric_posvel("earth", atime) # FIXME: apply translation from earth center to L2. pm = p.xyz.to_value(u.kilometer) diff --git a/src/toast/ops/mapmaker.py b/src/toast/ops/mapmaker.py index 4939fb910..2fa57c44c 100644 --- a/src/toast/ops/mapmaker.py +++ b/src/toast/ops/mapmaker.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2023 by the parties listed in the AUTHORS file. +# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file. # All rights reserved. Use of this source code is governed by # a BSD-style license that can be found in the LICENSE file. @@ -214,25 +214,125 @@ def __init__(self, **kwargs): super().__init__(**kwargs) @function_timer - def _exec(self, data, detectors=None, use_accel=None, **kwargs): + def _write_del(self, prod_key, prod_write, force, rootname): + """Write data object to file and delete it from cache""" log = Logger.get() - timer = Timer() - log_prefix = "MapMaker" - memreport = MemoryCounter() - if not self.report_memory: - memreport.enabled = False + # FIXME: This I/O technique assumes "known" types of pixel representations. + # Instead, we should associate read / write functions to a particular pixel + # class. 
- memreport.prefix = "Start of mapmaking" - memreport.apply(data, use_accel=use_accel) + if self.map_binning is not None and self.map_binning.enabled: + map_binning = self.map_binning + else: + map_binning = self.binning + + if hasattr(map_binning.pixel_pointing, "wcs"): + is_pix_wcs = True + else: + is_pix_wcs = False + is_hpix_nest = map_binning.pixel_pointing.nest + + wtimer = Timer() + wtimer.start() + product = prod_key.replace(f"{self.name}_", "") + if prod_write: + if is_pix_wcs: + fname = os.path.join(self.output_dir, f"{rootname}_{product}.fits") + if self.mc_mode and not force and os.path.isfile(fname): + log.info_rank(f"Skipping existing file: {fname}", comm=self._comm) + else: + write_wcs_fits(self._data[prod_key], fname) + else: + if self.write_hdf5: + # Non-standard HDF5 output + fname = os.path.join(self.output_dir, f"{rootname}_{product}.h5") + if self.mc_mode and not force and os.path.isfile(fname): + log.info_rank( + f"Skipping existing file: {fname}", comm=self._comm + ) + else: + write_healpix_hdf5( + self._data[prod_key], + fname, + nest=is_hpix_nest, + single_precision=True, + force_serial=self.write_hdf5_serial, + ) + else: + # Standard FITS output + fname = os.path.join(self.output_dir, f"{rootname}_{product}.fits") + if self.mc_mode and not force and os.path.isfile(fname): + log.info_rank( + f"Skipping existing file: {fname}", comm=self._comm + ) + else: + write_healpix_fits( + self._data[prod_key], + fname, + nest=is_hpix_nest, + report_memory=self.report_memory, + ) + log.info_rank(f"Wrote {fname} in", comm=self._comm, timer=wtimer) + + if not self.keep_final_products and not self.mc_mode: + if prod_key in self._data: + self._data[prod_key].clear() + del self._data[prod_key] + + self._memreport.prefix = f"After writing/deleting {prod_key}" + self._memreport.apply(self._data, use_accel=self._use_accel) + + return + + @function_timer + def _setup(self, data, detectors, use_accel): + """Set up convenience members used in the _exec() method""" + + self._log = Logger.get() + self._timer = Timer() + self._log_prefix = "MapMaker" + + self._mc_root = self.name + if self.mc_mode: + if self.mc_root is not None: + self._mc_root += f"_{self.mc_root}" + if self.mc_index is not None: + self._mc_root += f"_{self.mc_index:05d}" + + self._data = data + self._detectors = detectors + self._use_accel = use_accel + self._memreport = MemoryCounter() + if not self.report_memory: + self._memreport.enabled = False # The global communicator we are using (or None) - comm = data.comm.comm_world - rank = data.comm.world_rank - timer.start() + self._comm = data.comm.comm_world + self._rank = data.comm.world_rank + + # Data names of outputs + + self.hits_name = f"{self.name}_hits" + self.cov_name = f"{self.name}_cov" + self.invcov_name = f"{self.name}_invcov" + self.rcond_name = f"{self.name}_rcond" + self.det_flag_name = f"{self.name}_flags" + + self.clean_name = f"{self.name}_cleaned" + self.binmap_name = f"{self.name}_binmap" + self.map_name = f"{self.name}_map" + self.noiseweighted_map_name = f"{self.name}_noiseweighted_map" + + self._timer.start() + + return + + @function_timer + def _fit_templates(self): + """Solve for template amplitudes""" - # Solve for template amplitudes amplitudes_solve = SolveAmplitudes( name=self.name, det_data=self.det_data, @@ -253,47 +353,47 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): reset_pix_dist=self.reset_pix_dist, report_memory=self.report_memory, ) - amplitudes_solve.apply(data, detectors=detectors, use_accel=use_accel) - - 
log.info_rank( - f"{log_prefix} finished template amplitude solve in", - comm=comm, - timer=timer, + amplitudes_solve.apply( + self._data, detectors=self._detectors, use_accel=self._use_accel ) + template_amplitudes = amplitudes_solve.amplitudes - # Data names of outputs + self._log.info_rank( + f"{self._log_prefix} finished template amplitude solve in", + comm=self._comm, + timer=self._timer, + ) - self.hits_name = "{}_hits".format(self.name) - self.cov_name = "{}_cov".format(self.name) - self.invcov_name = "{}_invcov".format(self.name) - self.rcond_name = "{}_rcond".format(self.name) - self.det_flag_name = "{}_flags".format(self.name) + self._memreport.prefix = "After solving amplitudes" + self._memreport.apply(self._data, use_accel=self._use_accel) - self.clean_name = "{}_cleaned".format(self.name) - self.binmap_name = "{}_binmap".format(self.name) - self.map_name = "{}_map".format(self.name) - self.noiseweighted_map_name = "{}_noiseweighted_map".format(self.name) + return template_amplitudes - # Check map binning + @function_timer + def _prepare_binning(self): + """Set up the final map binning""" - map_binning = self.map_binning - if self.map_binning is None or not self.map_binning.enabled: + # Map binning operator + if self.map_binning is not None and self.map_binning.enabled: + map_binning = self.map_binning + else: # Use the same binning used in the solver. map_binning = self.binning map_binning.pre_process = None map_binning.covariance = self.cov_name + # Pixel distribution if self.reset_pix_dist: - if map_binning.pixel_dist in data: - del data[map_binning.pixel_dist] - if map_binning.covariance in data: + if map_binning.pixel_dist in self._data: + del self._data[map_binning.pixel_dist] + if map_binning.covariance in self._data: # Cannot trust earlier covariance - del data[map_binning.covariance] + del self._data[map_binning.covariance] - if map_binning.pixel_dist not in data: - log.info_rank( - f"{log_prefix} Caching pixel distribution", - comm=comm, + if map_binning.pixel_dist not in self._data: + self._log.info_rank( + f"{self._log_prefix} Caching pixel distribution", + comm=self._comm, ) pix_dist = BuildPixelDistribution( pixel_dist=map_binning.pixel_dist, @@ -301,68 +401,107 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): shared_flags=map_binning.shared_flags, shared_flag_mask=map_binning.shared_flag_mask, ) - pix_dist.apply(data) - log.info_rank( - f"{log_prefix} finished build of pixel distribution in", - comm=comm, - timer=timer, + pix_dist.apply(self._data) + self._log.info_rank( + f"{self._log_prefix} finished build of pixel distribution in", + comm=self._comm, + timer=self._timer, ) - if map_binning.covariance not in data: - # Construct the noise covariance, hits, and condition number - # mask for the final binned map. 
+ self._memreport.prefix = "After pixel distribution" + self._memreport.apply(self._data, use_accel=self._use_accel) - log.info_rank( - f"{log_prefix} begin build of final binning covariance", - comm=comm, - ) + return map_binning - final_cov = CovarianceAndHits( - pixel_dist=map_binning.pixel_dist, - covariance=map_binning.covariance, - inverse_covariance=self.invcov_name, - hits=self.hits_name, - rcond=self.rcond_name, - det_mask=map_binning.det_mask, - det_flags=map_binning.det_flags, - det_flag_mask=map_binning.det_flag_mask, - det_data_units=map_binning.det_data_units, - shared_flags=map_binning.shared_flags, - shared_flag_mask=map_binning.shared_flag_mask, - pixel_pointing=map_binning.pixel_pointing, - stokes_weights=map_binning.stokes_weights, - noise_model=map_binning.noise_model, - rcond_threshold=self.map_rcond_threshold, - sync_type=map_binning.sync_type, - save_pointing=map_binning.full_pointing, - ) + @function_timer + def _build_pixel_covariance(self, map_binning): + """Accumulate hits and pixel covariance""" - final_cov.apply(data, detectors=detectors, use_accel=use_accel) + if map_binning.covariance in self._data and self.mc_mode: + # Covariance is already cached + return - log.info_rank( - f"{log_prefix} finished build of final covariance in", - comm=comm, - timer=timer, - ) + # Construct the noise covariance, hits, and condition number + # mask for the final binned map. - memreport.prefix = "After constructing final covariance and hits" - memreport.apply(data, use_accel=use_accel) + self._log.info_rank( + f"{self._log_prefix} begin build of final binning covariance", + comm=self._comm, + ) - if self.write_binmap: - map_binning.det_data = self.det_data - map_binning.binned = self.binmap_name - map_binning.noiseweighted = None - log.info_rank( - f"{log_prefix} begin map binning", - comm=comm, - ) - map_binning.apply(data, detectors=detectors, use_accel=use_accel) - log.info_rank( - f"{log_prefix} finished binning in", - comm=comm, - timer=timer, - ) + final_cov = CovarianceAndHits( + pixel_dist=map_binning.pixel_dist, + covariance=map_binning.covariance, + inverse_covariance=self.invcov_name, + hits=self.hits_name, + rcond=self.rcond_name, + det_mask=map_binning.det_mask, + det_flags=map_binning.det_flags, + det_flag_mask=map_binning.det_flag_mask, + det_data_units=map_binning.det_data_units, + shared_flags=map_binning.shared_flags, + shared_flag_mask=map_binning.shared_flag_mask, + pixel_pointing=map_binning.pixel_pointing, + stokes_weights=map_binning.stokes_weights, + noise_model=map_binning.noise_model, + rcond_threshold=self.map_rcond_threshold, + sync_type=map_binning.sync_type, + save_pointing=map_binning.full_pointing, + ) + final_cov.apply( + self._data, detectors=self._detectors, use_accel=self._use_accel + ) + + self._log.info_rank( + f"{self._log_prefix} finished build of final covariance in", + comm=self._comm, + timer=self._timer, + ) + + self._memreport.prefix = "After constructing final covariance and hits" + self._memreport.apply(self._data, use_accel=self._use_accel) + + # These data products are not needed later so they can be + # written out and purged + + self._write_del(self.hits_name, self.write_hits, False, self.name) + self._write_del(self.rcond_name, self.write_rcond, False, self.name) + self._write_del(self.invcov_name, self.write_invcov, False, self.name) + + return + + @function_timer + def _bin_and_write_raw_signal(self, map_binning): + """Optionally bin and save an undestriped map""" + + if not self.write_binmap: + return + + 
map_binning.det_data = self.det_data + map_binning.binned = self.binmap_name + map_binning.noiseweighted = None + self._log.info_rank( + f"{self._log_prefix} begin map binning", + comm=self._comm, + ) + map_binning.apply( + self._data, detectors=self._detectors, use_accel=self._use_accel + ) + self._log.info_rank( + f"{self._log_prefix} finished binning in", + comm=self._comm, + timer=self._timer, + ) + self._write_del(self.binmap_name, self.write_binmap, True, self._mc_root) + + self._memreport.prefix = "After binning final map" + self._memreport.apply(self._data, use_accel=self._use_accel) + + return + + @function_timer + def _clean_signal(self, template_amplitudes): if ( self.template_matrix is None or self.template_matrix.n_enabled_templates == 0 @@ -372,9 +511,9 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): else: # Apply (subtract) solved amplitudes. - log.info_rank( - f"{log_prefix} begin apply template amplitudes", - comm=comm, + self._log.info_rank( + f"{self._log_prefix} begin apply template amplitudes", + comm=self._comm, ) out_cleaned = self.clean_name @@ -385,18 +524,37 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): amplitudes_apply = ApplyAmplitudes( op="subtract", det_data=self.det_data, - amplitudes=amplitudes_solve.amplitudes, + amplitudes=template_amplitudes, template_matrix=self.template_matrix, output=out_cleaned, ) - amplitudes_apply.apply(data, detectors=detectors, use_accel=use_accel) + amplitudes_apply.apply( + self._data, detectors=self._detectors, use_accel=self._use_accel + ) + + if not self.keep_solver_products: + del self._data[template_amplitudes] - log.info_rank( - f"{log_prefix} finished apply template amplitudes in", - comm=comm, - timer=timer, + self._log.info_rank( + f"{self._log_prefix} finished apply template amplitudes in", + comm=self._comm, + timer=self._timer, ) + self._memreport.prefix = "After subtracting templates" + self._memreport.apply(self._data, use_accel=self._use_accel) + + return out_cleaned + + @function_timer + def _bin_cleaned_signal(self, map_binning, out_cleaned): + """Bin and save a map of the destriped signal""" + + self._log.info_rank( + f"{self._log_prefix} begin final map binning", + comm=self._comm, + ) + if out_cleaned is None: map_binning.det_data = self.det_data else: @@ -405,123 +563,102 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): map_binning.noiseweighted = self.noiseweighted_map_name map_binning.binned = self.map_name - log.info_rank( - f"{log_prefix} begin final map binning", - comm=comm, - ) - # Do the final binning - map_binning.apply(data, detectors=detectors, use_accel=use_accel) + map_binning.apply( + self._data, detectors=self._detectors, use_accel=self._use_accel + ) - log.info_rank( - f"{log_prefix} finished final binning in", - comm=comm, - timer=timer, + self._log.info_rank( + f"{self._log_prefix} finished final binning in", + comm=self._comm, + timer=self._timer, ) - memreport.prefix = "After binning final map" - memreport.apply(data, use_accel=use_accel) + self._memreport.prefix = "After binning final map" + self._memreport.apply(self._data, use_accel=self._use_accel) - # Write and delete the outputs + return - if not self.save_cleaned: - Delete( - detdata=[ - self.clean_name, - ] - ).apply(data, use_accel=use_accel) + @function_timer + def _purge_cleaned_tod(self): + """If the cleaned TOD is not being returned, purge it""" - # FIXME: This I/O technique assumes "known" types of pixel representations. 
- # Instead, we should associate read / write functions to a particular pixel - # class. + if self.save_cleaned: + return - is_pix_wcs = hasattr(map_binning.pixel_pointing, "wcs") - is_hpix_nest = None - if not is_pix_wcs: - is_hpix_nest = map_binning.pixel_pointing.nest + del_tod = Delete(detdata=[self.clean_name]) + del_tod.apply(self._data, use_accel=self._use_accel) - mc_root = self.name - if self.mc_mode: - if self.mc_root is not None: - mc_root += f"_{self.mc_root}" - if self.mc_index is not None: - mc_root += f"_{self.mc_index:05d}" + self._memreport.prefix = "After purging cleaned TOD" + self._memreport.apply(self._data, use_accel=self._use_accel) - write_del = list() - write_del.append((self.hits_name, self.write_hits, False, self.name)) - write_del.append((self.rcond_name, self.write_rcond, False, self.name)) - write_del.append( - (self.noiseweighted_map_name, self.write_noiseweighted_map, True, mc_root) - ) - write_del.append((self.binmap_name, self.write_binmap, True, mc_root)) - write_del.append((self.map_name, self.write_map, True, mc_root)) - write_del.append((self.invcov_name, self.write_invcov, False, self.name)) - write_del.append((self.cov_name, self.write_cov, False, self.name)) - wtimer = Timer() - wtimer.start() - for prod_key, prod_write, force, rootname in write_del: - product = prod_key.replace(f"{self.name}_", "") - if prod_write: - if is_pix_wcs: - fname = os.path.join(self.output_dir, f"{rootname}_{product}.fits") - if self.mc_mode and not force: - if os.path.isfile(fname): - log.info_rank(f"Skipping existing file: {fname}", comm=comm) - continue - write_wcs_fits(data[prod_key], fname) - else: - if self.write_hdf5: - # Non-standard HDF5 output - fname = os.path.join( - self.output_dir, f"{rootname}_{product}.h5" - ) - if self.mc_mode and not force: - if os.path.isfile(fname): - log.info_rank( - f"Skipping existing file: {fname}", comm=comm - ) - continue - write_healpix_hdf5( - data[prod_key], - fname, - nest=is_hpix_nest, - single_precision=True, - force_serial=self.write_hdf5_serial, - ) - else: - # Standard FITS output - fname = os.path.join( - self.output_dir, f"{rootname}_{product}.fits" - ) - if self.mc_mode and not force: - if os.path.isfile(fname): - log.info_rank( - f"Skipping existing file: {fname}", comm=comm - ) - continue - write_healpix_fits( - data[prod_key], - fname, - nest=is_hpix_nest, - report_memory=self.report_memory, - ) - log.info_rank(f"Wrote {fname} in", comm=comm, timer=wtimer) - if not self.keep_final_products and not self.mc_mode: - if prod_key in data: - data[prod_key].clear() - del data[prod_key] + return - memreport.prefix = f"After writing/deleting {prod_key}" - memreport.apply(data, use_accel=use_accel) + @function_timer + def _write_maps(self): + """Write and delete the outputs""" + + self._write_del( + self.noiseweighted_map_name, + self.write_noiseweighted_map, + True, + self._mc_root, + ) + self._write_del(self.map_name, self.write_map, True, self._mc_root) + self._write_del(self.cov_name, self.write_cov, False, self.name) - log.info_rank( - f"{log_prefix} finished output write in", - comm=comm, - timer=timer, + self._log.info_rank( + f"{self._log_prefix} finished output write in", + comm=self._comm, + timer=self._timer, ) - memreport.prefix = "End of mapmaking" - memreport.apply(data, use_accel=use_accel) + return + + @function_timer + def _closeout(self): + """Explicitly delete members used by the _exec() method""" + + del self._log + del self._timer + del self._log_prefix + del self._mc_root + del self._data + del 
self._detectors + del self._use_accel + del self._memreport + del self._comm + del self._rank + + return + + @function_timer + def _exec(self, data, detectors=None, use_accel=None, **kwargs): + self._setup(data, detectors, use_accel) + + self._memreport.prefix = "Start of mapmaking" + self._memreport.apply(self._data, use_accel=self._use_accel) + + template_amplitudes = self._fit_templates() + + map_binning = self._prepare_binning() + + self._build_pixel_covariance(map_binning) + + self._bin_and_write_raw_signal(map_binning) + + out_cleaned = self._clean_signal(template_amplitudes) + + self._bin_cleaned_signal(map_binning, out_cleaned) + + self._purge_cleaned_tod() # Potentially frees memory for writing maps + + self._write_maps() + + self._memreport.prefix = "End of mapmaking" + self._memreport.apply(self._data, use_accel=self._use_accel) + + self._closeout() return diff --git a/src/toast/ops/mapmaker_solve.py b/src/toast/ops/mapmaker_solve.py index 1fb27899a..894e8d90e 100644 --- a/src/toast/ops/mapmaker_solve.py +++ b/src/toast/ops/mapmaker_solve.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2020 by the parties listed in the AUTHORS file. +# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file. # All rights reserved. Use of this source code is governed by # a BSD-style license that can be found in the LICENSE file. @@ -112,15 +112,11 @@ def _exec(self, data, detectors=None, **kwargs): comm = data.comm.comm_world rank = data.comm.world_rank - # Check that the inputs are set - if self.det_data is None: - raise RuntimeError("You must set the det_data trait before calling exec()") - if self.binning is None: - raise RuntimeError("You must set the binning trait before calling exec()") - if self.template_matrix is None: - raise RuntimeError( - "You must set the template_matrix trait before calling exec()" - ) + # Check that input traits are set + for trait in ("det_data", "binning", "template_matrix"): + if getattr(self, trait) is None: + msg = f"You must set the '{trait}' trait before calling exec()" + raise RuntimeError(msg) # Make a binned map @@ -352,14 +348,10 @@ def _exec(self, data, detectors=None, **kwargs): rank = data.comm.world_rank # Check that input traits are set - if self.binning is None: - raise RuntimeError("You must set the binning trait before calling exec()") - if self.template_matrix is None: - raise RuntimeError( - "You must set the template_matrix trait before calling exec()" - ) - if self.out is None: - raise RuntimeError("You must set the 'out' trait before calling exec()") + for trait in ("binning", "template_matrix", "out"): + if getattr(self, trait) is None: + msg = f"You must set the '{trait}' trait before calling exec()" + raise RuntimeError(msg) # Clear temp detector data if it exists for ob in data.obs: @@ -571,7 +563,7 @@ def solve( rank = data.comm.world_rank if rhs_key not in data: - msg = "rhs_key '{}' does not exist in data".format(rhs_key) + msg = f"rhs_key '{rhs_key}' does not exist in data" log.error(msg) raise RuntimeError(msg) rhs = data[rhs_key] @@ -590,18 +582,10 @@ def solve( raise RuntimeError("starting guess must have same keys as RHS") for k, v in result.items(): if v.n_global != rhs[k].n_global: - msg = ( - "starting guess['{}'] has different n_global than rhs['{}']".format( - k, k - ) - ) + msg = f"starting guess['{k}'] has different n_global than rhs['{k}']" raise RuntimeError(msg) if v.n_local != rhs[k].n_local: - msg = ( - "starting guess['{}'] has different n_global than rhs['{}']".format( - k, k - ) - ) + msg = f"starting 
guess['{k}'] has different n_local than rhs['{k}']"
                 raise RuntimeError(msg)
 
     # Solving A * x = b ...
@@ -614,7 +598,7 @@ def solve(
         residual = None
 
     # The result of the LHS operator "q"
-    lhs_out_key = "{}_out".format(lhs_op.name)
+    lhs_out_key = f"{lhs_op.name}_out"
     if lhs_out_key in data:
         data[lhs_out_key].clear()
         del data[lhs_out_key]
@@ -625,7 +609,7 @@ def solve(
         precond = None
 
     # The new proposed direction "d"
-    proposal_key = "{}_in".format(lhs_op.name)
+    proposal_key = f"{lhs_op.name}_in"
     if proposal_key in data:
         data[proposal_key].clear()
         del data[proposal_key]
@@ -673,16 +657,9 @@ def solve(
     delta = proposal.dot(residual)
     delta_init = delta
 
-    if comm is not None:
-        comm.barrier()
-    timer.stop()
-    if rank == 0:
-        msg = "MapMaker initial residual = {}, {:0.2f} s".format(
-            sqsum_init, timer.seconds()
-        )
-        log.info(msg)
-    timer.clear()
-    timer.start()
+    log.info_rank(
+        f"MapMaker initial residual = {sqsum_init} in", comm=comm, timer=timer
+    )
 
     for iter in range(n_iter_max):
         if not np.isfinite(sqsum):
@@ -720,26 +697,20 @@ def solve(
         sqsum = residual.dot(residual)
         # print(f"{comm.rank} LHS {iter}: sqsum = {sqsum}", flush=True)
 
-        if comm is not None:
-            comm.barrier()
-        timer.stop()
-        if rank == 0:
-            msg = "MapMaker iteration {:4d}, relative residual = {:0.6e}, {:0.2f} s".format(
-                iter, sqsum / sqsum_init, timer.seconds()
-            )
-            log.info(msg)
-        timer.clear()
-        timer.start()
+        relative = sqsum / sqsum_init
+        log.info_rank(
+            f"MapMaker iteration {iter:4d}, relative residual = {relative:0.6e} in",
+            comm=comm,
+            timer=timer,
+        )
 
         # Check for convergence
-        if (sqsum / sqsum_init) < convergence or sqsum < 1e-30:
-            timer.stop()
-            timer_full.stop()
-            if rank == 0:
-                msg = "MapMaker PCG converged after {:4d} iterations and {:0.2f} seconds".format(
-                    iter, timer_full.seconds()
-                )
-                log.info(msg)
+        if relative < convergence or sqsum < 1e-30:
+            log.info_rank(
+                f"MapMaker PCG converged after {iter:4d} iterations and",
+                comm=comm,
+                timer=timer_full,
+            )
             break
 
         sqsum_best = min(sqsum, sqsum_best)
@@ -747,13 +718,11 @@ def solve(
         # Check for stall / divergence
         if iter % 10 == 0 and iter >= n_iter_min:
             if last_best < sqsum_best * 2:
-                timer.stop()
-                timer_full.stop()
-                if rank == 0:
-                    msg = "MapMaker PCG stalled after {:4d} iterations and {:0.2f} seconds".format(
-                        iter, timer_full.seconds()
-                    )
-                    log.info(msg)
+                log.info_rank(
+                    f"MapMaker PCG stalled after {iter:4d} iterations and",
+                    comm=comm,
+                    timer=timer_full,
+                )
                 break
             last_best = sqsum_best
@@ -774,3 +743,12 @@ def solve(
     # d = s + beta * d
     proposal *= beta
     proposal += precond
+
+    # Delete the temporary objects
+    del temp
+    del proposal
+    del data[proposal_key]
+    del lhs_out
+    del data[lhs_out_key]
+
+    return
diff --git a/src/toast/ops/mapmaker_templates.py b/src/toast/ops/mapmaker_templates.py
index 1883ad583..0fcfcd62d 100644
--- a/src/toast/ops/mapmaker_templates.py
+++ b/src/toast/ops/mapmaker_templates.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2015-2020 by the parties listed in the AUTHORS file.
+# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file.
 # All rights reserved. Use of this source code is governed by
 # a BSD-style license that can be found in the LICENSE file.
 
@@ -470,7 +470,8 @@ class SolveAmplitudes(Operator):
     mask = Unicode(
        None,
        allow_none=True,
-        help="Data key for pixel mask to use in solving. First bit of pixel values is tested",
+        help="Data key for pixel mask to use in solving. 
" + "First bit of pixel values is tested", ) binning = Instance( @@ -548,333 +549,423 @@ def __init__(self, **kwargs): super().__init__(**kwargs) @function_timer - def _exec(self, data, detectors=None, **kwargs): - log = Logger.get() - timer = Timer() - log_prefix = "SolveAmplitudes" + def _write_del(self, prod_key): + """Write and optionally delete map object""" - # Check if we have any templates - if ( - self.template_matrix is None - or self.template_matrix.n_enabled_templates == 0 - ): - return + # FIXME: This I/O technique assumes "known" types of pixel representations. + # Instead, we should associate read / write functions to a particular pixel + # class. - memreport = MemoryCounter() - if not self.report_memory: - memreport.enabled = False + is_pix_wcs = hasattr(self.binning.pixel_pointing, "wcs") + is_hpix_nest = None + if not is_pix_wcs: + is_hpix_nest = self.binning.pixel_pointing.nest + + if self.write_solver_products: + if is_pix_wcs: + fname = os.path.join(self.output_dir, f"{prod_key}.fits") + write_wcs_fits(self._data[prod_key], fname) + else: + if self.write_hdf5: + # Non-standard HDF5 output + fname = os.path.join(self.output_dir, f"{prod_key}.h5") + write_healpix_hdf5( + self._data[prod_key], + fname, + nest=is_hpix_nest, + single_precision=True, + force_serial=self.write_hdf5_serial, + ) + else: + # Standard FITS output + fname = os.path.join(self.output_dir, f"{prod_key}.fits") + write_healpix_fits( + self._data[prod_key], + fname, + nest=is_hpix_nest, + report_memory=self.report_memory, + ) - memreport.prefix = "Start of amplitude solve" - memreport.apply(data) + if not self.mc_mode and not self.keep_solver_products: + if prod_key in self._data: + self._data[prod_key].clear() + del self._data[prod_key] - # The global communicator we are using (or None) - comm = data.comm.comm_world - rank = data.comm.world_rank + self._memreport.prefix = f"After writing/deleting {prod_key}" + self._memreport.apply(self._data, use_accel=self._use_accel) - # Optionally destroy existing pixel distributions (useful if calling - # repeatedly with different data objects) - if self.reset_pix_dist: - if self.binning.pixel_dist in data: - del data[self.binning.pixel_dist] + return - memreport.prefix = "After resetting pixel distribution" - memreport.apply(data) + @function_timer + def _setup(self, data, detectors, use_accel): + """Set up convenience members used in the _exec() method""" + + self._log = Logger.get() + self._timer = Timer() + self._log_prefix = "SolveAmplitudes" + + self._data = data + self._detectors = detectors + self._use_accel = use_accel + self._memreport = MemoryCounter() + if not self.report_memory: + self._memreport.enabled = False + + # The global communicator we are using (or None) + self._comm = data.comm.comm_world + self._rank = data.comm.world_rank # Get the units used across the distributed data for our desired # input detector data - det_data_units = data.detector_units(self.det_data) + self._det_data_units = data.detector_units(self.det_data) # We use the input binning operator to define the flags that the user has # specified. We will save the name / bit mask for these and restore them later. # Then we will use the binning operator with our solver flags. These input # flags are combined to the first bit (== 1) of the solver flags. 
-        save_det_flags = self.binning.det_flags
-        save_det_mask = self.binning.det_mask
-        save_det_flag_mask = self.binning.det_flag_mask
-        save_shared_flags = self.binning.shared_flags
-        save_shared_flag_mask = self.binning.shared_flag_mask
-        save_binned = self.binning.binned
-        save_covariance = self.binning.covariance
+        self._save_det_flags = self.binning.det_flags
+        self._save_det_mask = self.binning.det_mask
+        self._save_det_flag_mask = self.binning.det_flag_mask
+        self._save_shared_flags = self.binning.shared_flags
+        self._save_shared_flag_mask = self.binning.shared_flag_mask
+        self._save_binned = self.binning.binned
+        self._save_covariance = self.binning.covariance
 
-        save_tmpl_flags = self.template_matrix.det_flags
-        save_tmpl_mask = self.template_matrix.det_mask
-        save_tmpl_det_mask = self.template_matrix.det_flag_mask
+        self._save_tmpl_flags = self.template_matrix.det_flags
+        self._save_tmpl_mask = self.template_matrix.det_mask
+        self._save_tmpl_det_mask = self.template_matrix.det_flag_mask
 
-        # The pointing matrix used for the solve.  The per-detector flags
-        # are normally reset when the binner is run, but here we set them
-        # explicitly since we will use these pointing matrix operators for
-        # setting up the solver flags below.
-        solve_pixels = self.binning.pixel_pointing
-        solve_weights = self.binning.stokes_weights
-        solve_pixels.detector_pointing.det_mask = save_det_mask
-        solve_pixels.detector_pointing.det_flag_mask = save_det_flag_mask
-        if hasattr(solve_weights, "detector_pointing"):
-            solve_weights.detector_pointing.det_mask = save_det_mask
-            solve_weights.detector_pointing.det_flag_mask = save_det_flag_mask
+        # Use the same data view as the pointing operator in binning
+        self._solve_view = self.binning.pixel_pointing.view
 
         # Output data products, prefixed with the name of the operator and optionally
         # the MC index.
 
-        mc_root = None
         if self.mc_mode and self.mc_index is not None:
-            mc_root = "{}_{:05d}".format(self.name, self.mc_index)
+            self._mc_root = f"{self.name}_{self.mc_index:05d}"
         else:
-            mc_root = self.name
+            self._mc_root = self.name
 
-        self.solver_flags = "{}_solve_flags".format(self.name)
-        self.solver_hits_name = "{}_solve_hits".format(self.name)
-        self.solver_cov_name = "{}_solve_cov".format(self.name)
-        self.solver_rcond_name = "{}_solve_rcond".format(self.name)
-        self.solver_rcond_mask_name = "{}_solve_rcond_mask".format(self.name)
-        self.solver_rhs = "{}_solve_rhs".format(mc_root)
-        self.solver_bin = "{}_solve_bin".format(mc_root)
+        self.solver_flags = f"{self.name}_solve_flags"
+        self.solver_hits_name = f"{self.name}_solve_hits"
+        self.solver_cov_name = f"{self.name}_solve_cov"
+        self.solver_rcond_name = f"{self.name}_solve_rcond"
+        self.solver_rcond_mask_name = f"{self.name}_solve_rcond_mask"
+        self.solver_rhs = f"{self._mc_root}_solve_rhs"
+        self.solver_bin = f"{self._mc_root}_solve_bin"
 
         if self.amplitudes is None:
-            self.amplitudes = "{}_solve_amplitudes".format(mc_root)
+            self.amplitudes = f"{self._mc_root}_solve_amplitudes"
+
+        return
+
+    @function_timer
+    def _prepare_pixels(self):
+        """Optionally destroy existing pixel distributions (useful if calling
+        repeatedly with different data objects)
+        """
 
-        timer.start()
+        if self.reset_pix_dist:
+            if self.binning.pixel_dist in self._data:
+                del self._data[self.binning.pixel_dist]
 
-        # Flagging. 
We create a new set of data flags for the solver that includes:
-        #  - one bit for a bitwise OR of all detector / shared flags
-        #  - one bit for any pixel mask, projected to TOD
-        #  - one bit for any poorly conditioned pixels, projected to TOD
+        self._memreport.prefix = "After resetting pixel distribution"
+        self._memreport.apply(self._data)
+
+        # The pointing matrix used for the solve.  The per-detector flags
+        # are normally reset when the binner is run, but here we set them
+        # explicitly since we will use these pointing matrix operators for
+        # setting up the solver flags below.
+        solve_pixels = self.binning.pixel_pointing
+        solve_weights = self.binning.stokes_weights
+        solve_pixels.detector_pointing.det_mask = self._save_det_mask
+        solve_pixels.detector_pointing.det_flag_mask = self._save_det_flag_mask
+        if hasattr(solve_weights, "detector_pointing"):
+            solve_weights.detector_pointing.det_mask = self._save_det_mask
+            solve_weights.detector_pointing.det_flag_mask = self._save_det_flag_mask
+
+        # Set up a pipeline to scan processing and condition number masks
+        self._scanner = ScanMask(
+            det_flags=self.solver_flags,
+            det_mask=self._save_det_mask,
+            det_flag_mask=self._save_det_flag_mask,
+            pixels=solve_pixels.pixels,
+            view=self._solve_view,
+        )
+        scan_pipe = Pipeline(
+            detector_sets=["SINGLE"], operators=[solve_pixels, self._scanner]
+        )
+
+        return solve_pixels, solve_weights, scan_pipe
+
+    @function_timer
+    def _prepare_flagging_ob(self, ob):
+        """Process a single observation, used by _prepare_flagging
+
+        Copies and masks existing flags
+        """
+
+        # Get the detectors we are using for this observation
+        dets = ob.select_local_detectors(self._detectors, flagmask=self._save_det_mask)
+        if len(dets) == 0:
+            # Nothing to do for this observation
+            return
 
         if self.mc_mode:
-            # Verify that our flags exist
-            for ob in data.obs:
-                # Get the detectors we are using for this observation
-                dets = ob.select_local_detectors(detectors, flagmask=save_det_mask)
-                if len(dets) == 0:
-                    # Nothing to do for this observation
-                    continue
-                if self.solver_flags not in ob.detdata:
-                    msg = "In MC mode, solver flags missing for observation {}".format(
-                        ob.name
-                    )
-                    log.error(msg)
+            # Shortcut, just verify that our flags exist
+            if self.solver_flags not in ob.detdata:
+                msg = f"In MC mode, solver flags missing for observation {ob.name}"
+                self._log.error(msg)
+                raise RuntimeError(msg)
+            det_check = set(ob.detdata[self.solver_flags].detectors)
+            for d in dets:
+                if d not in det_check:
+                    msg = "In MC mode, solver flags missing for "
+                    msg += f"observation {ob.name}, det {d}"
+                    self._log.error(msg)
                     raise RuntimeError(msg)
-                det_check = set(ob.detdata[self.solver_flags].detectors)
-                for d in dets:
-                    if d not in det_check:
-                        msg = "In MC mode, solver flags missing for observation {}, det {}".format(
-                            ob.name, d
+            return
+
+        # Create the new solver flags
+        exists = ob.detdata.ensure(self.solver_flags, dtype=np.uint8, detectors=dets)
+
+        # The data views
+        views = ob.view[self._solve_view]
+        # For each view...
+ for vw in range(len(views)): + view_samples = None + if views[vw].start is None: + # There is one view of the whole obs + view_samples = ob.n_local_samples + else: + view_samples = views[vw].stop - views[vw].start + starting_flags = np.zeros(view_samples, dtype=np.uint8) + if self._save_shared_flags is not None: + starting_flags[:] = np.where( + ( + views.shared[self._save_shared_flags][vw] + & self._save_shared_flag_mask + ) + > 0, + 1, + 0, + ) + for d in dets: + views.detdata[self.solver_flags][vw][d, :] = starting_flags + if self._save_det_flags is not None: + views.detdata[self.solver_flags][vw][d, :] |= np.where( + ( + views.detdata[self._save_det_flags][vw][d] + & self._save_det_flag_mask ) - log.error(msg) - raise RuntimeError(msg) - log.info_rank(f"{log_prefix} MC mode, reusing flags for solver", comm=comm) + > 0, + 1, + 0, + ).astype(views.detdata[self.solver_flags][vw].dtype) + + return + + @function_timer + def _prepare_flagging(self, solve_pixels): + """Flagging. We create a new set of data flags for the solver that includes: + - one bit for a bitwise OR of all detector / shared flags + - one bit for any pixel mask, projected to TOD + - one bit for any poorly conditioned pixels, projected to TOD + """ + + if self.mc_mode: + msg = f"{self._log_prefix} begin verifying flags for solver" else: - log.info_rank(f"{log_prefix} begin building flags for solver", comm=comm) + msg = f"{self._log_prefix} begin building flags for solver" + self._log.info_rank(msg, comm=self._comm) - # Use the same data view as the pointing operator in binning - solve_view = self.binning.pixel_pointing.view + for ob in self._data.obs: + self._prepare_flagging_ob(ob) - for ob in data.obs: - # Get the detectors we are using for this observation - dets = ob.select_local_detectors(detectors, flagmask=save_det_mask) - if len(dets) == 0: - # Nothing to do for this observation - continue - # Create the new solver flags - exists = ob.detdata.ensure( - self.solver_flags, dtype=np.uint8, detectors=dets - ) - # The data views - views = ob.view[solve_view] - # For each view... - for vw in range(len(views)): - view_samples = None - if views[vw].start is None: - # There is one view of the whole obs - view_samples = ob.n_local_samples - else: - view_samples = views[vw].stop - views[vw].start - starting_flags = np.zeros(view_samples, dtype=np.uint8) - if save_shared_flags is not None: - starting_flags[:] = np.where( - ( - views.shared[save_shared_flags][vw] - & save_shared_flag_mask - ) - > 0, - 1, - 0, - ) - for d in dets: - views.detdata[self.solver_flags][vw][d, :] = starting_flags - if save_det_flags is not None: - views.detdata[self.solver_flags][vw][d, :] |= np.where( - ( - views.detdata[save_det_flags][vw][d] - & save_det_flag_mask - ) - > 0, - 1, - 0, - ).astype(views.detdata[self.solver_flags][vw].dtype) - - # Now scan any input mask to this same flag field. We use the second - # bit (== 2) for these mask flags. For the input mask bit we check the - # first bit of the pixel values. This is noted in the help string for - # the mask trait. Note that we explicitly expand the pointing once - # here and do not save it. Even if we are eventually saving the - # pointing, we want to do that later when building the covariance and - # the pixel distribution. 
-
-        scanner = ScanMask(
-            det_flags=self.solver_flags,
-            det_mask=save_det_mask,
-            det_flag_mask=save_det_flag_mask,
-            pixels=solve_pixels.pixels,
-            view=solve_view,
+        if self.mc_mode:
+            # Shortcut, we just verified that our flags exist
+            self._log.info_rank(
+                f"{self._log_prefix} MC mode, reusing flags for solver",
+                comm=self._comm,
             )
+            return
 
-        scanner.det_flags_value = 2
-        scanner.mask_key = self.mask
+        # Now scan any input mask to this same flag field.  We use the second
+        # bit (== 2) for these mask flags.  For the input mask bit we check the
+        # first bit of the pixel values.  This is noted in the help string for
+        # the mask trait.  Note that we explicitly expand the pointing once
+        # here and do not save it.  Even if we are eventually saving the
+        # pointing, we want to do that later when building the covariance and
+        # the pixel distribution.
+
+        if self.mask is not None:
+            # We have a mask.  Scan it.
+            self._scanner.det_flags_value = 2
+            self._scanner.mask_key = self.mask
+            scan_pipe = Pipeline(
+                detector_sets=["SINGLE"], operators=[solve_pixels, self._scanner]
+            )
+            scan_pipe.apply(self._data, detectors=self._detectors)
+
+        self._log.info_rank(
+            f"{self._log_prefix} finished flag building in",
+            comm=self._comm,
+            timer=self._timer,
+        )
 
-        scan_pipe = Pipeline(
-            detector_sets=["SINGLE"], operators=[solve_pixels, scanner]
-        )
+        self._memreport.prefix = "After building flags"
+        self._memreport.apply(self._data)
 
-        if self.mask is not None:
-            # We have a mask.  Scan it.
-            scan_pipe.apply(data, detectors=detectors)
+        return
 
-        log.info_rank(
-            f"{log_prefix} finished flag building in",
-            comm=comm,
-            timer=timer,
+    def _count_cut_data(self):
+        """Collect and report statistics about cut data"""
+        local_total = 0
+        local_cut = 0
+        for ob in self._data.obs:
+            # Get the detectors we are using for this observation
+            dets = ob.select_local_detectors(
+                self._detectors, flagmask=self._save_det_mask
             )
+            if len(dets) == 0:
+                # Nothing to do for this observation
+                continue
+            for vw in ob.view[self._solve_view].detdata[self.solver_flags]:
+                for d in dets:
+                    local_total += len(vw[d])
+                    local_cut += np.count_nonzero(vw[d])
+
+        if self._comm is None:
+            total = local_total
+            cut = local_cut
+        else:
+            total = self._comm.allreduce(local_total, op=MPI.SUM)
+            cut = self._comm.allreduce(local_cut, op=MPI.SUM)
 
-        memreport.prefix = "After building flags"
-        memreport.apply(data)
+        frac = 100.0 * (cut / total)
+        msg = f"Solver flags cut {cut} / {total} = {frac:0.2f}% of samples"
+        self._log.info_rank(f"{self._log_prefix} {msg}", comm=self._comm)
 
-        # Now construct the noise covariance, hits, and condition number mask for
-        # the solver.
+        return
+
+    @function_timer
+    def _get_pixel_covariance(self, solve_pixels, solve_weights):
+        """Construct the noise covariance, hits, and condition number map for
+        the solver.
+        """
 
         if self.mc_mode:
-            # Verify that our covariance and other products exist.
-            if self.binning.pixel_dist not in data:
-                msg = f"MC mode, pixel distribution '{self.binning.pixel_dist}' does not exist"
-                log.error(msg)
+            # Shortcut, verify that our covariance and other products exist. 
+            if self.binning.pixel_dist not in self._data:
+                msg = "MC mode, pixel distribution "
+                msg += f"'{self.binning.pixel_dist}' does not exist"
+                self._log.error(msg)
                 raise RuntimeError(msg)
-            if self.solver_cov_name not in data:
+            if self.solver_cov_name not in self._data:
                 msg = f"MC mode, covariance '{self.solver_cov_name}' does not exist"
-                log.error(msg)
+                self._log.error(msg)
                 raise RuntimeError(msg)
-
-            log.info_rank(
-                f"{log_prefix} MC mode, reusing covariance for solver",
-                comm=comm,
-            )
-        else:
-            log.info_rank(
-                f"{log_prefix} begin build of solver covariance",
-                comm=comm,
+            self._log.info_rank(
+                f"{self._log_prefix} MC mode, reusing covariance for solver",
+                comm=self._comm,
             )
+            return
 
-            solver_cov = CovarianceAndHits(
-                pixel_dist=self.binning.pixel_dist,
-                covariance=self.solver_cov_name,
-                hits=self.solver_hits_name,
-                rcond=self.solver_rcond_name,
-                det_data_units=det_data_units,
-                det_mask=save_det_mask,
-                det_flags=self.solver_flags,
-                det_flag_mask=255,
-                pixel_pointing=solve_pixels,
-                stokes_weights=solve_weights,
-                noise_model=self.binning.noise_model,
-                rcond_threshold=self.solve_rcond_threshold,
-                sync_type=self.binning.sync_type,
-                save_pointing=self.binning.full_pointing,
-            )
+        self._log.info_rank(
+            f"{self._log_prefix} begin build of solver covariance",
+            comm=self._comm,
+        )
 
-            solver_cov.apply(data, detectors=detectors)
+        solver_cov = CovarianceAndHits(
+            pixel_dist=self.binning.pixel_dist,
+            covariance=self.solver_cov_name,
+            hits=self.solver_hits_name,
+            rcond=self.solver_rcond_name,
+            det_data_units=self._det_data_units,
+            det_mask=self._save_det_mask,
+            det_flags=self.solver_flags,
+            det_flag_mask=255,
+            pixel_pointing=solve_pixels,
+            stokes_weights=solve_weights,
+            noise_model=self.binning.noise_model,
+            rcond_threshold=self.solve_rcond_threshold,
+            sync_type=self.binning.sync_type,
+            save_pointing=self.binning.full_pointing,
+        )
 
-            memreport.prefix = "After constructing covariance and hits"
-            memreport.apply(data)
+        solver_cov.apply(self._data, detectors=self._detectors)
 
-            data[self.solver_rcond_mask_name] = PixelData(
-                data[self.binning.pixel_dist], dtype=np.uint8, n_value=1
-            )
-            n_bad = np.count_nonzero(
-                data[self.solver_rcond_name].data < self.solve_rcond_threshold
-            )
-            n_good = data[self.solver_rcond_name].data.size - n_bad
-            data[self.solver_rcond_mask_name].data[
-                data[self.solver_rcond_name].data < self.solve_rcond_threshold
-            ] = 1
+        self._memreport.prefix = "After constructing covariance and hits"
+        self._memreport.apply(self._data)
 
-            memreport.prefix = "After constructing rcond mask"
-            memreport.apply(data)
+        return
 
-            # Re-use our mask scanning pipeline, setting third bit (== 4)
-            scanner.det_flags_value = 4
-            scanner.mask_key = self.solver_rcond_mask_name
+    @function_timer
+    def _get_rcond_mask(self, scan_pipe):
+        """Convert the rcond map into a mask and scan it into the solver flags
+        """
-
-            scan_pipe.apply(data, detectors=detectors)
+        if self.mc_mode:
+            # The flags are already cached
+            return
 
-            log.info_rank(
-                f"{log_prefix} finished build of solver covariance in",
-                comm=comm,
-                timer=timer,
-            )
+        self._log.info_rank(
+            f"{self._log_prefix} begin build of rcond flags",
+            comm=self._comm,
+        )
 
-        local_total = 0
-        local_cut = 0
-        for ob in data.obs:
-            # Get the detectors we are using for this observation
-            dets = ob.select_local_detectors(detectors, flagmask=save_det_mask)
-            if len(dets) == 0:
-                # Nothing to do for this observation
-                continue
-            for vw in ob.view[solve_view].detdata[self.solver_flags]:
-                for d in dets:
-                    local_total += len(vw[d])
-                    local_cut += np.count_nonzero(vw[d])
-        total = None
-        cut = None
-        msg = None
-        if comm is None:
-            total = local_total
-            cut = local_cut
-            msg = "Solver flags cut {} / {} = {:0.2f}% of samples".format(
-                cut, total, 100.0 * (cut / total)
-            )
-        else:
-            total = comm.reduce(local_total, op=MPI.SUM, root=0)
-            cut = comm.reduce(local_cut, op=MPI.SUM, root=0)
-            if comm.rank == 0:
-                msg = "Solver flags cut {} / {} = {:0.2f}% of samples".format(
-                    cut, total, 100.0 * (cut / total)
-                )
-        log.info_rank(
-            f"{log_prefix} {msg}",
-            comm=comm,
-        )
+        # Translate the rcond map into a mask
+        self._data[self.solver_rcond_mask_name] = PixelData(
+            self._data[self.binning.pixel_dist], dtype=np.uint8, n_value=1
+        )
+        rcond = self._data[self.solver_rcond_name].data
+        rcond_mask = self._data[self.solver_rcond_mask_name].data
+        bad = rcond < self.solve_rcond_threshold
+        n_bad = np.count_nonzero(bad)
+        n_good = rcond.size - n_bad
+        rcond_mask[bad] = 1
+
+        # No more need for the rcond map
+        self._write_del(self.solver_rcond_name)
+
+        self._memreport.prefix = "After constructing rcond mask"
+        self._memreport.apply(self._data)
 
+        # Re-use our mask scanning pipeline, setting third bit (== 4)
+        self._scanner.det_flags_value = 4
+        self._scanner.mask_key = self.solver_rcond_mask_name
+        scan_pipe.apply(self._data, detectors=self._detectors)
+
+        self._log.info_rank(
+            f"{self._log_prefix} finished build of solver covariance and rcond flags in",
+            comm=self._comm,
+            timer=self._timer,
+        )
 
-        # Compute the RHS.  Overwrite inputs, either the original or the copy.
+        self._count_cut_data()  # Report statistics
 
-        log.info_rank(
-            f"{log_prefix} begin RHS calculation",
-            comm=comm,
+        return
+
+    @function_timer
+    def _get_rhs(self):
+        """Compute the RHS.  Overwrite inputs, either the original or the copy"""
+
+        self._log.info_rank(
+            f"{self._log_prefix} begin RHS calculation", comm=self._comm
         )
 
         # Initialize the template matrix
         self.template_matrix.det_data = self.det_data
-        self.template_matrix.det_data_units = det_data_units
+        self.template_matrix.det_data_units = self._det_data_units
         self.template_matrix.det_flags = self.solver_flags
-        self.template_matrix.det_mask = save_det_mask
+        self.template_matrix.det_mask = self._save_det_mask
         self.template_matrix.det_flag_mask = 255
         self.template_matrix.view = self.binning.pixel_pointing.view
-        self.template_matrix.initialize(data)
+        self.template_matrix.initialize(self._data)
 
         # Set our binning operator to use only our new solver flags
         self.binning.shared_flag_mask = 0
         self.binning.det_flags = self.solver_flags
         self.binning.det_flag_mask = 255
-        self.binning.det_data_units = det_data_units
+        self.binning.det_data_units = self._det_data_units
 
         # Set the binning operator to output to temporary map.  This will be
         # overwritten on each iteration of the solver. 
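
[Reviewer note, not part of the patch]  For readers following this solver
refactor: the template amplitudes are obtained by running preconditioned
conjugate gradients on the destriping normal equations, with SolverRHS
building the right hand side and SolverLHS applying the system matrix on
each iteration (see the solve() changes in mapmaker_solve.py above).  Below
is a minimal NumPy sketch of that iteration, using the same variable names
as solve(); apply_lhs and apply_precond are hypothetical stand-ins for the
SolverLHS operator and the template preconditioner, with the amplitude
containers flattened to plain arrays.

    import numpy as np

    def pcg_solve(apply_lhs, apply_precond, rhs, convergence=1e-12, n_iter_max=100):
        # Zero starting guess, so the initial residual r = b - A x equals b
        x = np.zeros_like(rhs)
        residual = rhs.copy()
        sqsum_init = residual.dot(residual)
        # Preconditioned residual s = M^-1 r, and first proposal d = s
        precond = apply_precond(residual)
        proposal = precond.copy()
        delta = proposal.dot(residual)
        for iteration in range(n_iter_max):
            q = apply_lhs(proposal)            # q = A d
            alpha = delta / proposal.dot(q)
            x += alpha * proposal              # advance the solution
            residual -= alpha * q              # update the residual
            sqsum = residual.dot(residual)
            if sqsum / sqsum_init < convergence:
                break
            precond = apply_precond(residual)  # s = M^-1 r
            delta, delta_last = precond.dot(residual), delta
            beta = delta / delta_last
            proposal = precond + beta * proposal  # d = s + beta * d
        return x
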
@@ -886,31 +977,37 @@ def _exec(self, data, detectors=None, **kwargs): rhs_calc = SolverRHS( name=f"{self.name}_rhs", det_data=self.det_data, - det_data_units=det_data_units, + det_data_units=self._det_data_units, binning=self.binning, template_matrix=self.template_matrix, ) - rhs_calc.apply(data, detectors=detectors) + rhs_calc.apply(self._data, detectors=self._detectors) - log.info_rank( - f"{log_prefix} finished RHS calculation in", - comm=comm, - timer=timer, + self._log.info_rank( + f"{self._log_prefix} finished RHS calculation in", + comm=self._comm, + timer=self._timer, ) - memreport.prefix = "After constructing RHS" - memreport.apply(data) + self._memreport.prefix = "After constructing RHS" + self._memreport.apply(self._data) + + return + + @function_timer + def _solve_amplitudes(self): + """Solve the destriping equation""" # Set up the LHS operator. - log.info_rank( - f"{log_prefix} begin PCG solver", - comm=comm, + self._log.info_rank( + f"{self._log_prefix} begin PCG solver", + comm=self._comm, ) lhs_calc = SolverLHS( - name="{}_lhs".format(self.name), - det_data_units=det_data_units, + name=f"{self.name}_lhs", + det_data_units=self._det_data_units, binning=self.binning, template_matrix=self.template_matrix, ) @@ -921,8 +1018,8 @@ def _exec(self, data, detectors=None, **kwargs): # Solve for amplitudes. solve( - data, - detectors, + self._data, + self._detectors, lhs_calc, self.solver_rhs, self.amplitudes, @@ -931,90 +1028,105 @@ def _exec(self, data, detectors=None, **kwargs): n_iter_max=self.iter_max, ) - log.info_rank( - f"{log_prefix} finished solver in", - comm=comm, - timer=timer, + self._log.info_rank( + f"{self._log_prefix} finished solver in", + comm=self._comm, + timer=self._timer, ) - memreport.prefix = "After solving for amplitudes" - memreport.apply(data) - - # FIXME: This I/O technique assumes "known" types of pixel representations. - # Instead, we should associate read / write functions to a particular pixel - # class. 
- - is_pix_wcs = hasattr(self.binning.pixel_pointing, "wcs") - is_hpix_nest = None - if not is_pix_wcs: - is_hpix_nest = self.binning.pixel_pointing.nest + self._memreport.prefix = "After solving for amplitudes" + self._memreport.apply(self._data) - write_del = [ - self.solver_hits_name, - self.solver_cov_name, - self.solver_rcond_name, - self.solver_rcond_mask_name, - self.solver_bin, - ] - for prod_key in write_del: - if self.write_solver_products: - if is_pix_wcs: - fname = os.path.join(self.output_dir, "{}.fits".format(prod_key)) - write_wcs_fits(data[prod_key], fname) - else: - if self.write_hdf5: - # Non-standard HDF5 output - fname = os.path.join(self.output_dir, "{}.h5".format(prod_key)) - write_healpix_hdf5( - data[prod_key], - fname, - nest=is_hpix_nest, - single_precision=True, - force_serial=self.write_hdf5_serial, - ) - else: - # Standard FITS output - fname = os.path.join( - self.output_dir, "{}.fits".format(prod_key) - ) - write_healpix_fits( - data[prod_key], - fname, - nest=is_hpix_nest, - report_memory=self.report_memory, - ) - if not self.mc_mode and not self.keep_solver_products: - if prod_key in data: - data[prod_key].clear() - del data[prod_key] + return - if not self.mc_mode and not self.keep_solver_products: - if self.solver_rhs in data: - data[self.solver_rhs].clear() - del data[self.solver_rhs] - for ob in data.obs: - del ob.detdata[self.solver_flags] + @function_timer + def _cleanup(self): + """Clean up convenience members for _exec()""" # Restore flag names and masks to binning operator, in case it is being used # for the final map making or for other external operations. - self.binning.det_flags = save_det_flags - self.binning.det_mask = save_det_mask - self.binning.det_flag_mask = save_det_flag_mask - self.binning.shared_flags = save_shared_flags - self.binning.shared_flag_mask = save_shared_flag_mask - self.binning.binned = save_binned - self.binning.covariance = save_covariance - - self.template_matrix.det_flags = save_tmpl_flags - self.template_matrix.det_flag_mask = save_tmpl_det_mask - self.template_matrix.det_mask = save_tmpl_mask + self.binning.det_flags = self._save_det_flags + self.binning.det_mask = self._save_det_mask + self.binning.det_flag_mask = self._save_det_flag_mask + self.binning.shared_flags = self._save_shared_flags + self.binning.shared_flag_mask = self._save_shared_flag_mask + self.binning.binned = self._save_binned + self.binning.covariance = self._save_covariance + + self.template_matrix.det_flags = self._save_tmpl_flags + self.template_matrix.det_flag_mask = self._save_tmpl_det_mask + self.template_matrix.det_mask = self._save_tmpl_mask # FIXME: this reset does not seem needed # if not self.mc_mode: # self.template_matrix.reset_templates() - memreport.prefix = "End of amplitude solve" - memreport.apply(data) + del self._solve_view + + # Delete members used by the _exec() method + del self._log + del self._timer + del self._log_prefix + + del self._data + del self._detectors + del self._use_accel + del self._memreport + + del self._comm + del self._rank + + del self._det_data_units + + del self._mc_root + + del self._scanner + + return + + @function_timer + def _exec(self, data, detectors=None, use_accel=None, **kwargs): + # Check if we have any templates + if ( + self.template_matrix is None + or self.template_matrix.n_enabled_templates == 0 + ): + return + + self._setup(data, detectors, use_accel) + + self._memreport.prefix = "Start of amplitude solve" + self._memreport.apply(self._data) + + solve_pixels, solve_weights, scan_pipe = 
self._prepare_pixels() + + self._timer.start() + + self._prepare_flagging(solve_pixels) + + self._get_pixel_covariance(solve_pixels, solve_weights) + self._write_del(self.solver_hits_name) + + self._get_rcond_mask(scan_pipe) + self._write_del(self.solver_rcond_mask_name) + + self._get_rhs() + self._solve_amplitudes() + + self._write_del(self.solver_cov_name) + self._write_del(self.solver_bin) + + if not self.mc_mode and not self.keep_solver_products: + if self.solver_rhs in self._data: + self._data[self.solver_rhs].clear() + del self._data[self.solver_rhs] + for ob in self._data.obs: + del ob.detdata[self.solver_flags] + + self._memreport.prefix = "End of amplitude solve" + self._memreport.apply(self._data) + + self._cleanup() return @@ -1022,7 +1134,7 @@ def _finalize(self, data, **kwargs): return def _requires(self): - # This operator requires everything that its sub-operators needs. + # This operator requires everything that its sub-operators need. req = self.binning.requires() if self.template_matrix is not None: req.update(self.template_matrix.requires()) @@ -1156,7 +1268,7 @@ def _finalize(self, data, **kwargs): return def _requires(self): - # This operator requires everything that its sub-operators needs. + # This operator requires everything that its sub-operators need. req = dict() req["global"] = [self.amplitudes] req["detdata"] = list() diff --git a/src/toast/ops/pointing_detector/pointing_detector.py b/src/toast/ops/pointing_detector/pointing_detector.py index 1f24ddab5..eaea7ecdb 100644 --- a/src/toast/ops/pointing_detector/pointing_detector.py +++ b/src/toast/ops/pointing_detector/pointing_detector.py @@ -105,6 +105,7 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): implementation, use_accel = self.select_kernels(use_accel=use_accel) coord_rot = None + bore_suffix = "" if self.coord_in is None: if self.coord_out is not None: msg = "Input and output coordinate systems should both be None or valid" @@ -115,19 +116,61 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): raise RuntimeError(msg) if self.coord_in == "C": if self.coord_out == "E": - coord_rot = qa.equ2ecl + coord_rot = qa.equ2ecl() + bore_suffix = "_C2E" elif self.coord_out == "G": - coord_rot = qa.equ2gal + coord_rot = qa.equ2gal() + bore_suffix = "_C2G" elif self.coord_in == "E": if self.coord_out == "G": - coord_rot = qa.ecl2gal + coord_rot = qa.ecl2gal() + bore_suffix = "_E2G" elif self.coord_out == "C": - coord_rot = qa.inv(qa.equ2ecl) + coord_rot = qa.inv(qa.equ2ecl()) + bore_suffix = "_E2C" elif self.coord_in == "G": if self.coord_out == "C": - coord_rot = qa.inv(qa.equ2gal) + coord_rot = qa.inv(qa.equ2gal()) + bore_suffix = "_G2C" if self.coord_out == "E": - coord_rot = qa.inv(qa.ecl2gal) + coord_rot = qa.inv(qa.ecl2gal()) + bore_suffix = "_G2E" + + # Ensure that we have boresight pointing in the required coordinate + # frame. We will potentially re-use this boresight pointing for every + # iteration of the amplitude solver, so it makes sense to compute and + # store this. 
+ bore_name = self.boresight + if bore_suffix != "": + bore_name = f"{self.boresight}{bore_suffix}" + for ob in data.obs: + if bore_name not in ob.shared: + # Does not yet exist, create it + ob.shared.create_column( + bore_name, + ob.shared[self.boresight].shape, + ob.shared[self.boresight].dtype, + ) + # First process in each column computes the quaternions + bore_rot = None + if ob.comm_col_rank == 0: + bore_rot = qa.mult(coord_rot, ob.shared[self.boresight].data) + ob.shared[bore_name].set(bore_rot, fromrank=0) + + # Ensure that our boresight data is on the right device. In the case of + # no coordinate rotation, this would already be done by the outer pipeline. + for ob in data.obs: + if use_accel: + if not ob.shared.accel_in_use(bore_name): + # Not currently on the device + if not ob.shared.accel_exists(bore_name): + # Does not even exist yet on the device + ob.shared.accel_create(bore_name) + ob.shared.accel_update_device(bore_name) + else: + if ob.shared.accel_in_use(bore_name): + # Back to host + ob.shared.accel_update_host(bore_name) for ob in data.obs: # Get the detectors we are using for this observation @@ -171,11 +214,9 @@ def _exec(self, data, detectors=None, use_accel=None, **kwargs): comm=data.comm.comm_group, ) - # FIXME: handle coordinate transforms here too... - pointing_detector( fp_quats, - ob.shared[self.boresight].data, + ob.shared[bore_name].data, quat_indx, ob.detdata[self.quats].data, ob.intervals[self.view].data, diff --git a/src/toast/ops/sim_satellite.py b/src/toast/ops/sim_satellite.py index 56f8a3f0a..22115f4f0 100644 --- a/src/toast/ops/sim_satellite.py +++ b/src/toast/ops/sim_satellite.py @@ -1,4 +1,4 @@ -# Copyright (c) 2015-2020 by the parties listed in the AUTHORS file. +# Copyright (c) 2015-2023 by the parties listed in the AUTHORS file. # All rights reserved. Use of this source code is governed by # a BSD-style license that can be found in the LICENSE file. @@ -188,6 +188,9 @@ class SimSatellite(Operator): may have some gaps in between for cooler cycles or other events. The precession axis (anti-sun direction) is continuously slewed. + To be consistent with the ground simulation facilities, the satellite pointing + is expressed in the ICRS (equatorial) system by default. Detector pointing + expansion can rotate the output pointing to any other reference frame. """ # Class traits @@ -229,7 +232,8 @@ class SimSatellite(Operator): detset_key = Unicode( None, allow_none=True, - help="If specified, use this column of the focalplane detector_data to group detectors", + help="If specified, use this column of the focalplane " + "detector_data to group detectors", ) times = Unicode(defaults.times, help="Observation shared key for timestamps") @@ -248,6 +252,10 @@ class SimSatellite(Operator): defaults.boresight_radec, help="Observation shared key for boresight" ) + coord = Unicode( + "C", help="Coordinate system to use for pointing. 
One of ('C', 'E', 'G')" + ) + position = Unicode(defaults.position, help="Observation shared key for position") velocity = Unicode(defaults.velocity, help="Observation shared key for velocity") @@ -268,6 +276,14 @@ class SimSatellite(Operator): help="Observation detdata key for flags to initialize", ) + @traitlets.validate("coord") + def _check_coord(self, proposal): + check = proposal["value"] + if check is not None: + if check not in ["E", "C", "G"]: + raise traitlets.TraitError("coordinate system must be 'E', 'C', or 'G'") + return check + @traitlets.validate("telescope") def _check_telescope(self, proposal): tele = proposal["value"] @@ -336,9 +352,23 @@ def _check_hwp_step(self, proposal): def __init__(self, **kwargs): super().__init__(**kwargs) + def _get_coord_rot(self): + """ Get an optional coordinate rotation quaternion to return satellite + pointing and velocity in the user-specified frame + """ + if self.coord == "C": + coord_rot = None + elif self.coord == "E": + coord_rot = qa.equ2ecl() + elif self.coord == "G": + coord_rot = qa.equ2gal() + return coord_rot + + @function_timer def _exec(self, data, detectors=None, **kwargs): zaxis = np.array([0, 0, 1], dtype=np.float64) + coord_rot = self._get_coord_rot() log = Logger.get() if self.telescope is None: raise RuntimeError( @@ -395,7 +425,8 @@ def _exec(self, data, detectors=None, **kwargs): n_detset = len(detsets) if det_ranks > n_detset: if comm.group_rank == 0: - msg = f"Group {comm.group} has {comm.group_size} processes but {n_detset} detector sets." + msg = f"Group {comm.group} has {comm.group_size} " + msg += f"processes but {n_detset} detector sets." log.error(msg) raise RuntimeError(msg) @@ -507,6 +538,10 @@ def _exec(self, data, detectors=None, **kwargs): # Get the motion of the site for these times. position, velocity = site.position_velocity(stamps) + if coord_rot is not None: + # `site` always returns ICRS (celestial) position + position = qa.rotate(coord_rot, position) + velocity = qa.rotate(coord_rot, velocity) # Get the quaternions for the precession axis. 
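The qa.rotate call above accepts a single quaternion together with an (N, 3) array of vectors. A small self-contained sketch of the same frame change, using the qa.equ2gal() accessor:

    import numpy as np

    import toast.qarray as qa

    # Rotate (N, 3) ICRS position vectors into the Galactic frame with a
    # single quaternion, mirroring the position / velocity handling above.
    c2g = qa.equ2gal()
    position = np.array(
        [
            [1.0, 0.0, 0.0],  # toward the vernal equinox
            [0.0, 0.0, 1.0],  # toward the north celestial pole
        ]
    )
    gal_position = qa.rotate(c2g, position)

    # The rotation is orthogonal, so vector norms are preserved.
    assert np.allclose(
        np.linalg.norm(gal_position, axis=1),
        np.linalg.norm(position, axis=1),
    )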
        # Get the quaternions for the precession axis.  For now, assume that
        # it simply points away from the solar system barycenter
diff --git a/src/toast/qarray.py b/src/toast/qarray.py
index e1d4ffcbb..b401d7c00 100644
--- a/src/toast/qarray.py
+++ b/src/toast/qarray.py
@@ -723,7 +723,7 @@ def equ2ecl():
         ]
     ).reshape([3, 3])
     if _equ2ecl is None:
-        _equ2ecl = from_rotmat(coordmat_J2000radec2ecl)
+        _equ2ecl = from_rotmat(_coordmat_J2000radec2ecl)
     return _equ2ecl
diff --git a/src/toast/tests/template_periodic.py b/src/toast/tests/template_periodic.py
index 93c6b5f30..acad6fa13 100644
--- a/src/toast/tests/template_periodic.py
+++ b/src/toast/tests/template_periodic.py
@@ -505,7 +505,7 @@ def test_satellite_hwp(self):
             write_map=True,
             write_cov=False,
             write_rcond=False,
-            keep_solver_products=False,
+            keep_solver_products=True,
             keep_final_products=False,
             save_cleaned=True,
             overwrite_cleaned=True,
diff --git a/workflows/toast_sim_satellite.py b/workflows/toast_sim_satellite.py
index 63036d832..59119754e 100644
--- a/workflows/toast_sim_satellite.py
+++ b/workflows/toast_sim_satellite.py
@@ -26,6 +26,7 @@
 """
 import argparse
+import datetime
 import os
 import sys
 import traceback
@@ -65,6 +66,20 @@ def parse_config(operators, templates, comm):
         help="The output directory",
     )

+    parser.add_argument(
+        "--sample_rate",
+        required=False,
+        type=float,
+        help="Override focalplane sampling rate [Hz]",
+    )
+
+    parser.add_argument(
+        "--thinfp",
+        required=False,
+        type=int,
+        help="Thin the focalplane by keeping only every N-th pixel",
+    )
+
     # Build a config dictionary starting from the operator defaults, overriding with any
     # config files specified with the '--config' commandline option, followed by any
     # individually specified parameter overrides.
@@ -91,13 +106,33 @@ def load_instrument_and_schedule(args, comm):
     # Load a generic focalplane file.  NOTE: again, this is just using the
     # built-in Focalplane class.  In a workflow for a specific experiment we would
     # have a custom class.
-    focalplane = toast.instrument.Focalplane()
+    log = toast.utils.Logger.get()
+    timer = toast.timing.Timer()
+    timer.start()
+
+    if args.sample_rate is not None:
+        sample_rate = args.sample_rate * u.Hz
+    else:
+        sample_rate = None
+    focalplane = toast.instrument.Focalplane(
+        sample_rate=sample_rate,
+        thinfp=args.thinfp,
+    )
+
     with toast.io.H5File(args.focalplane, "r", comm=comm, force_serial=True) as f:
         focalplane.load_hdf5(f.handle, comm=comm)
+    log.info_rank("Loaded focalplane in", comm=comm, timer=timer)
+    log.info_rank(f"Focalplane: {str(focalplane)}", comm=comm)
+    mem = toast.utils.memreport(msg="(whole node)", comm=comm, silent=True)
+    log.info_rank(f"After loading focalplane: {mem}", comm)

     # Load the schedule file
     schedule = toast.schedule.SatelliteSchedule()
     schedule.read(args.schedule, comm=comm)
+    log.info_rank("Loaded schedule in", comm=comm, timer=timer)
+    log.info_rank(f"Schedule: {str(schedule)}", comm=comm)
+    mem = toast.utils.memreport(msg="(whole node)", comm=comm, silent=True)
+    log.info_rank(f"After loading schedule: {mem}", comm)

     # Create a telescope for the simulation.  Again, for a specific experiment we
     # would use custom classes for the site.
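With the new options, an invocation such as `python toast_sim_satellite.py --focalplane fp.h5 --schedule sched.txt --sample_rate 10.0 --thinfp 16` (assuming the existing --focalplane and --schedule arguments; file names are placeholders) is equivalent to constructing the focalplane with these overrides before loading:

    import astropy.units as u

    import toast

    # Sketch of the override path above: resample the loaded focalplane at
    # 10 Hz and thin it; both values pass straight through to the
    # Focalplane constructor.  Running without MPI here, so comm is None.
    focalplane = toast.instrument.Focalplane(
        sample_rate=10.0 * u.Hz,
        thinfp=16,
    )
    with toast.io.H5File("fp.h5", "r", comm=None, force_serial=True) as f:
        focalplane.load_hdf5(f.handle, comm=None)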
@@ -211,12 +246,18 @@ def simulate_data(job, toast_comm, telescope, schedule):
     ops.sim_noise.apply(data)
     log.info_rank("Simulated detector noise in", comm=world_comm, timer=timer)

+    mem = toast.utils.memreport(msg="(whole node)", comm=world_comm, silent=True)
+    log.info_rank(f"After simulating data: {mem}", world_comm)
+
     # Optionally write out the data
     if ops.save_hdf5.volume is None:
         ops.save_hdf5.volume = os.path.join(args.out_dir, "data")
     ops.save_hdf5.apply(data)
     log.info_rank("Saved HDF5 data in", comm=world_comm, timer=timer)

+    mem = toast.utils.memreport(msg="(whole node)", comm=world_comm, silent=True)
+    log.info_rank(f"After saving data: {mem}", world_comm)
+
     return data

@@ -269,10 +310,22 @@ def main():
     log = toast.utils.Logger.get()
     gt = toast.timing.GlobalTimers.get()
     gt.start("toast_satellite_sim (total)")
+    timer0 = toast.timing.Timer()
+    timer0.start()

     # Get optional MPI parameters
     comm, procs, rank = toast.get_world()

+    if "OMP_NUM_THREADS" in os.environ:
+        nthread = os.environ["OMP_NUM_THREADS"]
+    else:
+        nthread = "an unknown number of"
+    log.info_rank(
+        f"Executing workflow with {procs} MPI tasks, each with "
+        f"{nthread} OpenMP threads at {datetime.datetime.now()}",
+        comm,
+    )
+
     # The operators we want to configure from the command line or a parameter file.
     # We will use other operators, but these are the ones that the user can configure.
     # The "name" of each operator instance controls what the commandline and config
@@ -335,6 +388,8 @@ def main():
     out = os.path.join(args.out_dir, "timing")
     toast.timing.dump(alltimers, out)

+    log.info_rank("Workflow completed in", comm=comm, timer=timer0)
+

 if __name__ == "__main__":
     world, procs, rank = toast.mpi.get_world()
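For reference, when the workflow is launched as, e.g., `OMP_NUM_THREADS=4 mpirun -np 8 python toast_sim_satellite.py ...`, the startup banner added above reports both values, roughly:

    Executing workflow with 8 MPI tasks, each with 4 OpenMP threads at 2024-01-01 12:00:00.000000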