Skip to content

Commit

Permalink
renamed comm_grid to MPI_COMM_GRID
Browse files Browse the repository at this point in the history
  • Loading branch information
anand-avinash committed Nov 27, 2024
1 parent 719ac8e commit bee1e5f
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 29 deletions.
4 changes: 2 additions & 2 deletions litebird_sim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
)
from .madam import save_simulation_for_madam
from .mbs.mbs import Mbs, MbsParameters, MbsSavedMapInfo
from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION, comm_grid
from .mpi import MPI_COMM_WORLD, MPI_ENABLED, MPI_CONFIGURATION, MPI_COMM_GRID
from .noise import (
add_white_noise,
add_one_over_f_noise,
Expand Down Expand Up @@ -218,7 +218,7 @@ def destripe_with_toast2(*args, **kwargs):
"MPI_COMM_WORLD",
"MPI_ENABLED",
"MPI_CONFIGURATION",
"comm_grid",
"MPI_COMM_GRID",
# observations.py
"Observation",
"TodDescription",
Expand Down
32 changes: 16 additions & 16 deletions litebird_sim/mapmaking/destriper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from numba import njit, prange
import healpy as hp

from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, comm_grid
from litebird_sim.mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID
from typing import Callable, Union, List, Optional, Tuple, Any, Dict
from litebird_sim.hwp import HWP
from litebird_sim.observations import Observation
Expand All @@ -44,7 +44,7 @@


__DESTRIPER_RESULTS_FILE_NAME = "destriper_results.fits"
__BASELINES_FILE_NAME = f"baselines_mpi{comm_grid.COMM_OBS_GRID.rank:04d}.fits"
__BASELINES_FILE_NAME = f"baselines_mpi{MPI_COMM_GRID.COMM_OBS_GRID.rank:04d}.fits"


def _split_items_into_n_segments(n: int, num_of_segments: int) -> List[int]:
Expand Down Expand Up @@ -498,8 +498,8 @@ def _build_nobs_matrix(
)

# Now we must accumulate the result of every MPI process
if MPI_ENABLED and comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL:
comm_grid.COMM_OBS_GRID.Allreduce(
if MPI_ENABLED and MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
MPI_COMM_GRID.COMM_OBS_GRID.Allreduce(
mpi4py.MPI.IN_PLACE, nobs_matrix, op=mpi4py.MPI.SUM
)

Expand Down Expand Up @@ -748,10 +748,10 @@ def _compute_binned_map(
)

if MPI_ENABLED:
comm_grid.COMM_OBS_GRID.Allreduce(
MPI_COMM_GRID.COMM_OBS_GRID.Allreduce(
mpi4py.MPI.IN_PLACE, output_sky_map, op=mpi4py.MPI.SUM
)
comm_grid.COMM_OBS_GRID.Allreduce(
MPI_COMM_GRID.COMM_OBS_GRID.Allreduce(
mpi4py.MPI.IN_PLACE, output_hit_map, op=mpi4py.MPI.SUM
)

Expand Down Expand Up @@ -993,7 +993,7 @@ def _mpi_dot(a: List[npt.ArrayLike], b: List[npt.ArrayLike]) -> float:
# the dot product
local_result = sum([np.dot(x1.flatten(), x2.flatten()) for (x1, x2) in zip(a, b)])
if MPI_ENABLED:
return comm_grid.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM)
return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.SUM)
else:
return local_result

Expand All @@ -1010,7 +1010,7 @@ def _get_stopping_factor(residual: List[npt.ArrayLike]) -> float:
"""
local_result = np.max(np.abs(residual))
if MPI_ENABLED:
return comm_grid.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.MAX)
return MPI_COMM_GRID.COMM_OBS_GRID.allreduce(local_result, op=mpi4py.MPI.MAX)
else:
return local_result

Expand Down Expand Up @@ -1424,7 +1424,7 @@ def _run_destriper(
bytes_in_temporary_buffers += mask.nbytes

if MPI_ENABLED:
bytes_in_temporary_buffers = comm_grid.COMM_OBS_GRID.allreduce(
bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(
bytes_in_temporary_buffers,
op=mpi4py.MPI.SUM,
)
Expand Down Expand Up @@ -1619,9 +1619,9 @@ def my_gui_callback(
binned_map = np.empty((3, number_of_pixels))
hit_map = np.empty(number_of_pixels)

if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL:
if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
# perform the following operations when MPI is not being used
# OR when the comm_grid.COMM_OBS_GRID is not a NULL communicator
# OR when the MPI_COMM_GRID.COMM_OBS_GRID is not a NULL communicator
if do_destriping:
try:
# This will fail if the parameter is a scalar
Expand Down Expand Up @@ -1686,7 +1686,7 @@ def my_gui_callback(
)

if MPI_ENABLED:
bytes_in_temporary_buffers = comm_grid.COMM_OBS_GRID.allreduce(
bytes_in_temporary_buffers = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(
bytes_in_temporary_buffers,
op=mpi4py.MPI.SUM,
)
Expand Down Expand Up @@ -2014,11 +2014,11 @@ def _save_baselines(results: DestriperResult, output_file: Path) -> None:

primary_hdu = fits.PrimaryHDU()
primary_hdu.header["MPIRANK"] = (
comm_grid.COMM_OBS_GRID.rank,
MPI_COMM_GRID.COMM_OBS_GRID.rank,
"The rank of the MPI process that wrote this file",
)
primary_hdu.header["MPISIZE"] = (
comm_grid.COMM_OBS_GRID.size,
MPI_COMM_GRID.COMM_OBS_GRID.size,
"The number of MPI processes used in the computation",
)

Expand Down Expand Up @@ -2234,11 +2234,11 @@ def load_destriper_results(
baselines_file_name = folder / __BASELINES_FILE_NAME

with fits.open(baselines_file_name) as inpf:
assert comm_grid.COMM_OBS_GRID.rank == inpf[0].header["MPIRANK"], (
assert MPI_COMM_GRID.COMM_OBS_GRID.rank == inpf[0].header["MPIRANK"], (
"You must call load_destriper_results using the "
"same MPI layout that was used for save_destriper_results "
)
assert comm_grid.COMM_OBS_GRID.size == inpf[0].header["MPISIZE"], (
assert MPI_COMM_GRID.COMM_OBS_GRID.size == inpf[0].header["MPISIZE"], (
"You must call load_destriper_results using the "
"same MPI layout that was used for save_destriper_results"
)
Expand Down
6 changes: 3 additions & 3 deletions litebird_sim/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _set_null_comm(self, comm_null):
#: that defines the member variables `rank = 0` and `size = 1`.
MPI_COMM_WORLD = _SerialMpiCommunicator()

comm_grid = _GridCommClass()
MPI_COMM_GRID = _GridCommClass()

#: `True` if MPI should be used by the application. The value of this
#: variable is set according to the following rules:
Expand Down Expand Up @@ -90,8 +90,8 @@ def _set_null_comm(self, comm_null):
from mpi4py import MPI

MPI_COMM_WORLD = MPI.COMM_WORLD
comm_grid._set_comm_obs_grid(comm_obs_grid=MPI.COMM_WORLD)
comm_grid._set_null_comm(comm_null=MPI.COMM_NULL)
MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=MPI.COMM_WORLD)
MPI_COMM_GRID._set_null_comm(comm_null=MPI.COMM_NULL)
MPI_ENABLED = True
MPI_CONFIGURATION = mpi4py.get_config()
except ImportError:
Expand Down
4 changes: 2 additions & 2 deletions litebird_sim/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .coordinates import DEFAULT_TIME_SCALE
from .distribute import distribute_evenly, distribute_detector_blocks
from .detectors import DetectorInfo
from .mpi import comm_grid
from .mpi import MPI_COMM_GRID


@dataclass
Expand Down Expand Up @@ -175,7 +175,7 @@ def __init__(
matrix_color = MPI.UNDEFINED

comm_obs_grid = comm.Split(matrix_color, comm.rank)
comm_grid._set_comm_obs_grid(comm_obs_grid=comm_obs_grid)
MPI_COMM_GRID._set_comm_obs_grid(comm_obs_grid=comm_obs_grid)

self.tod_list = tods
for cur_tod in self.tod_list:
Expand Down
14 changes: 8 additions & 6 deletions litebird_sim/simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
DestriperResult,
destriper_log_callback,
)
from .mpi import MPI_ENABLED, MPI_COMM_WORLD, comm_grid
from .mpi import MPI_ENABLED, MPI_COMM_WORLD, MPI_COMM_GRID
from .noise import add_noise_to_observations
from .observations import Observation, TodDescription
from .pointings_in_obs import prepare_pointings, precompute_pointings
Expand Down Expand Up @@ -1221,8 +1221,8 @@ def set_scanning_strategy(

num_of_obs = len(self.observations)
if append_to_report and MPI_ENABLED:
if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL:
num_of_obs = comm_grid.COMM_OBS_GRID.allreduce(num_of_obs)
if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
num_of_obs = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(num_of_obs)

if append_to_report and MPI_COMM_WORLD.rank == 0:
template_file_path = get_template_file_path("report_quaternions.md")
Expand Down Expand Up @@ -1319,9 +1319,11 @@ def prepare_pointings(
memory_occupation = pointing_provider.bore2ecliptic_quats.quats.nbytes
num_of_obs = len(self.observations)
if append_to_report and MPI_ENABLED:
if comm_grid.COMM_OBS_GRID != comm_grid.COMM_NULL:
memory_occupation = comm_grid.COMM_OBS_GRID.allreduce(memory_occupation)
num_of_obs = comm_grid.COMM_OBS_GRID.allreduce(num_of_obs)
if MPI_COMM_GRID.COMM_OBS_GRID != MPI_COMM_GRID.COMM_NULL:
memory_occupation = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(
memory_occupation
)
num_of_obs = MPI_COMM_GRID.COMM_OBS_GRID.allreduce(num_of_obs)

if append_to_report and MPI_COMM_WORLD.rank == 0:
template_file_path = get_template_file_path("report_pointings.md")
Expand Down

0 comments on commit bee1e5f

Please sign in to comment.