From bee1e5f2659cf7a4dc3f9cc2e0ac9954178fc371 Mon Sep 17 00:00:00 2001
From: Avinash Anand <36325275+anand-avinash@users.noreply.github.com>
Date: Wed, 27 Nov 2024 17:37:48 +0900
Subject: [PATCH] renamed comm_grid to MPI_COMM_GRID

---
 litebird_sim/__init__.py            |  4 ++--
 litebird_sim/mapmaking/destriper.py | 32 ++++++++++++++---------------
 litebird_sim/mpi.py                 |  6 +++---
 litebird_sim/observations.py        |  4 ++--
 litebird_sim/simulations.py         | 14 +++++++------
 5 files changed, 31 insertions(+), 29 deletions(-)

diff --git a/litebird_sim/__init__.py b/litebird_sim/__init__.py
index ab720c05..44285c14 100644
--- a/litebird_sim/__init__.py
+++ b/litebird_sim/__init__.py
@@ -72,7 +72,7 @@
 )
 from .madam import save_simulation_for_madam
 from .mbs.mbs import Mbs, MbsParameters, MbsSavedMapInfo
-from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION, comm_grid
+from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION, MPI_COMM_GRID
 from .noise import (
     add_white_noise,
     add_one_over_f_noise,
@@ -218,7 +218,7 @@ def destripe_with_toast2(*args, **kwargs):
     "MPI_COMM_WORLD",
     "MPI_ENABLED",
     "MPI_CONFIGURATION",
-    "comm_grid",
+    "MPI_COMM_GRID",
     # observations.py
     "Observation",
     "TodDescription",
diff --git a/litebird_sim/mapmaking/destriper.py b/litebird_sim/mapmaking/destriper.py
index 16e023da..71e323af 100644
--- a/litebird_sim/mapmaking/destriper.py
+++ b/litebird_sim/mapmaking/destriper.py
@@ -20,7 +20,7 @@
 from numba import njit, prange
 import healpy as hp
 
-from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, comm_grid
+from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID
 from typing import Callable, Union, List, Optional, Tuple, Any, Dict
 from litebird_sim.hwp import HWP
 from litebird_sim.observations import Observation
@@ -44,7 +44,7 @@
 
 
 __DESTRIPER_RESULTS_FILE_NAME = "destriper_results.fits"
-__BASELINES_FILE_NAME = f"baselines_mpi{comm_grid.COMM_OBS_GRID.rank:04d}.fits"
+__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_GRID.COMM_OBS_GRID.rank:04d}.fits"
 
 
 def _split_items_into_n_segments(n: int, num_of_segments: int) -> List[int]:
@@ -498,8 +498,8 @@ def _build_nobs_matrix(
         )
 
     # Now we must accumulate the result of every MPI process
-    if MPI_ENABLED and comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL:
-        comm_grid.COMM_OBS_GRID.Allreduce(
+    if MPI_ENABLED and MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
+        MPI_COMM_GRID.COMM_OBS_GRID.Allreduce(
             mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM
         )
 
@@ -748,10 +748,10 @@ def _compute_binned_map(
     )
 
     if MPI_ENABLED:
-        comm_grid.COMM_OBS_GRID.Allreduce(
+        MPI_COMM_GRID.COMM_OBS_GRID.Allreduce(
             mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM
         )
-        comm_grid.COMM_OBS_GRID.Allreduce(
+        MPI_COMM_GRID.COMM_OBS_GRID.Allreduce(
             mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM
         )
 
@@ -993,7 +993,7 @@ def _mpi_dot(a: List[npt.ArrayLike], b: List[npt.ArrayLike]) -> float:
     # the dot product
     local_result = sum([np.dot(x1.flatten(), x2.flatten()) for (x1, x2) in zip(a, b)])
     if MPI_ENABLED:
-        return comm_grid.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM)
+        return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM)
     else:
         return local_result
 
@@ -1010,7 +1010,7 @@ def _get_stopping_factor(residual: List[npt.ArrayLike]) -> float:
     """
     local_result = np.max(np.abs(residual))
     if MPI_ENABLED:
-        return comm_grid.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.MAX)
+        return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.MAX)
     else:
         return local_result
 
@@ -1424,7 +1424,7 @@ def _run_destriper(
         bytes_in_temporary_buffers += mask.nbytes
 
     if MPI_ENABLED:
-        bytes_in_temporary_buffers = comm_grid.COMM_OBS_GRID.allreduce(
+        bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(
             bytes_in_temporary_buffers,
             op=mpi4py.MPI.SUM,
         )
@@ -1619,9 +1619,9 @@ def my_gui_callback(
     binned_map = np.empty((3, number_of_pixels))
     hit_map = np.empty(number_of_pixels)
 
-    if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL:
+    if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
         # perform the following operations when MPI is not being used
-        # OR when the comm_grid.COMM_OBS_GRID is not a NULL communicator
+        # OR when the MPI_COMM_GRID.COMM_OBS_GRID is not a NULL communicator
         if do_destriping:
             try:
                 # This will fail if the parameter is a scalar
@@ -1686,7 +1686,7 @@ def my_gui_callback(
         )
 
     if MPI_ENABLED:
-        bytes_in_temporary_buffers = comm_grid.COMM_OBS_GRID.allreduce(
+        bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(
             bytes_in_temporary_buffers,
             op=mpi4py.MPI.SUM,
         )
@@ -2014,11 +2014,11 @@ def _save_baselines(results: DestriperResult, output_file: Path) -> None:
     primary_hdu = fits.PrimaryHDU()
 
     primary_hdu.header["MPIRANK"] = (
-        comm_grid.COMM_OBS_GRID.rank,
+        MPI_COMM_GRID.COMM_OBS_GRID.rank,
         "The rank of the MPI process that wrote this file",
     )
     primary_hdu.header["MPISIZE"] = (
-        comm_grid.COMM_OBS_GRID.size,
+        MPI_COMM_GRID.COMM_OBS_GRID.size,
         "The number of MPI processes used in the computation",
     )
 
@@ -2234,11 +2234,11 @@ def load_destriper_results(
     baselines_file_name = folder / __BASELINES_FILE_NAME
 
     with fits.open(baselines_file_name) as inpf:
-        assert comm_grid.COMM_OBS_GRID.rank == inpf[0].header["MPIRANK"], (
+        assert MPI_COMM_GRID.COMM_OBS_GRID.rank == inpf[0].header["MPIRANK"], (
             "You must call load_destriper_results using the "
             "same MPI layout that was used for save_destriper_results "
         )
-        assert comm_grid.COMM_OBS_GRID.size == inpf[0].header["MPISIZE"], (
+        assert MPI_COMM_GRID.COMM_OBS_GRID.size == inpf[0].header["MPISIZE"], (
             "You must call load_destriper_results using the "
             "same MPI layout that was used for save_destriper_results"
         )
diff --git a/litebird_sim/mpi.py b/litebird_sim/mpi.py
index d6c31c4d..80bb2dbc 100644
--- a/litebird_sim/mpi.py
+++ b/litebird_sim/mpi.py
@@ -61,7 +61,7 @@ def _set_null_comm(self, comm_null):
 #: that defines the member variables `rank = 0` and `size = 1`.
 MPI_COMM_WORLD = _SerialMpiCommunicator()
 
-comm_grid = _GridCommClass()
+MPI_COMM_GRID = _GridCommClass()
 
 #: `True` if MPI should be used by the application. The value of this
 #: variable is set according to the following rules:
@@ -90,8 +90,8 @@ def _set_null_comm(self, comm_null):
         from mpi4py import MPI
 
         MPI_COMM_WORLD = MPI.COMM_WORLD
-        comm_grid._set_comm_obs_grid(comm_obs_grid=MPI.COMM_WORLD)
-        comm_grid._set_null_comm(comm_null=MPI.COMM_NULL)
+        MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=MPI.COMM_WORLD)
+        MPI_COMM_GRID._set_null_comm(comm_null=MPI.COMM_NULL)
         MPI_ENABLED = True
         MPI_CONFIGURATION = mpi4py.get_config()
     except ImportError:
diff --git a/litebird_sim/observations.py b/litebird_sim/observations.py
index 38582ef1..fa5ab634 100644
--- a/litebird_sim/observations.py
+++ b/litebird_sim/observations.py
@@ -13,7 +13,7 @@
 from .coordinates import DEFAULT_TIME_SCALE
 from .distribute import distribute_evenly, distribute_detector_blocks
 from .detectors import DetectorInfo
-from .mpi import comm_grid
+from .mpi import MPI_COMM_GRID
 
 
 @dataclass
@@ -175,7 +175,7 @@ def __init__(
                 matrix_color = MPI.UNDEFINED
 
             comm_obs_grid = comm.Split(matrix_color, comm.rank)
-            comm_grid._set_comm_obs_grid(comm_obs_grid=comm_obs_grid)
+            MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=comm_obs_grid)
 
         self.tod_list = tods
         for cur_tod in self.tod_list:
diff --git a/litebird_sim/simulations.py b/litebird_sim/simulations.py
index 44148fe0..51dde0cf 100644
--- a/litebird_sim/simulations.py
+++ b/litebird_sim/simulations.py
@@ -44,7 +44,7 @@
     DestriperResult,
     destriper_log_callback,
 )
-from .mpi import MPI_ENABLED, MPI_COMM_WORLD, comm_grid
+from .mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID
 from .noise import add_noise_to_observations
 from .observations import Observation, TodDescription
 from .pointings_in_obs import prepare_pointings, precompute_pointings
@@ -1221,8 +1221,8 @@ def set_scanning_strategy(
 
         num_of_obs = len(self.observations)
         if append_to_report and MPI_ENABLED:
-            if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL:
-                num_of_obs = comm_grid.COMM_OBS_GRID.allreduce(num_of_obs)
+            if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
+                num_of_obs = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(num_of_obs)
 
         if append_to_report and MPI_COMM_WORLD.rank == 0:
             template_file_path = get_template_file_path("report_quaternions.md")
@@ -1319,9 +1319,11 @@ def prepare_pointings(
         memory_occupation = pointing_provider.bore2ecliptic_quats.quats.nbytes
         num_of_obs = len(self.observations)
         if append_to_report and MPI_ENABLED:
-            if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL:
-                memory_occupation = comm_grid.COMM_OBS_GRID.allreduce(memory_occupation)
-                num_of_obs = comm_grid.COMM_OBS_GRID.allreduce(num_of_obs)
+            if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
+                memory_occupation = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(
+                    memory_occupation
+                )
+                num_of_obs = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(num_of_obs)
 
         if append_to_report and MPI_COMM_WORLD.rank == 0:
             template_file_path = get_template_file_path("report_pointings.md")
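
For downstream code the rename is a pure namespace change: whatever previously went through `comm_grid` (imported from `litebird_sim` or `litebird_sim.mpi`) now goes through `MPI_COMM_GRID`, and its `COMM_OBS_GRID` and `COMM_NULL` members keep their meaning. The sketch below is illustrative only and is not part of the patch; it assumes a litebird_sim installation that includes this change and a script launched under MPI (e.g. `mpirun -n 4 python example.py`).

    # Minimal usage sketch after the rename (hypothetical example, not from the patch)
    from litebird_sim import MPI_COMM_GRID, MPI_ENABLED

    # Before this patch the same object was reached as `comm_grid`.
    # COMM_OBS_GRID starts out as the world communicator and is replaced by a
    # sub-communicator (or by COMM_NULL) once Observations split the ranks
    # into the observation grid.
    if MPI_ENABLED and MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
        # Only ranks that belong to the observation grid take part in
        # collective operations on this sub-communicator.
        rank = MPI_COMM_GRID.COMM_OBS_GRID.rank
        num_ranks = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(1)  # default op is SUM
        print(f"rank {rank}: {num_ranks} ranks in the observation grid")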