Skip to content

Commit

Permalink
implement MPI queueing
Browse files Browse the repository at this point in the history
  • Loading branch information
segasai committed Nov 19, 2024
1 parent 0bf7759 commit c5d9e32
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 42 deletions.
131 changes: 89 additions & 42 deletions py/rvspecfit/desi/desi_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1357,6 +1357,11 @@ def submit(self, f, *args, **kw):
return FakeFuture(f(*args, **kw))


def get_mpi_rank():
from mpi4py import MPI
return MPI.COMM_WORLD.Get_rank()


def proc_many(files,
output_dir,
output_tab_prefix,
Expand Down Expand Up @@ -1662,6 +1667,11 @@ def main(args):
action='store_true',
default=False)

parser.add_argument('--mpi',
help='Use MPI for communication',
action='store_true',
default=False)

parser.add_argument(
'--no_ccf_continuum_normalize',
help='Do not normalize by the continuum when doing CCF',
Expand Down Expand Up @@ -1694,11 +1704,29 @@ def main(args):

log_level = args.log_level

if args.log is not None:
logging.basicConfig(filename=args.log, level=log_level)
log_filename = args.log
process_status_file = args.process_status_file
if log_filename is not None:
if args.mpi:
if re.match('.*%.*', log_filename) is None:
raise RuntimeError(
'If using MPI and log files you need to '
'allow for the MPI rank to placed in the filename using %d'
)
log_filename = log_filename % (get_mpi_rank())
logging.basicConfig(filename=log_filename, level=log_level)
else:
logging.basicConfig(level=log_level)

if process_status_file is not None:
if args.mpi:
if re.match('.*%.*', process_status_file) is None:
raise RuntimeError(
'If using MPI and log files you need to '
'allow for the MPI rank to placed in the filename using %d'
)
process_status_file = process_status_file % (get_mpi_rank())

input_files = args.input_files
input_file_from = args.input_file_from
output_dir, output_tab_prefix, output_mod_prefix = (args.output_dir,
Expand Down Expand Up @@ -1735,19 +1763,6 @@ def main(args):
'Unknown param_init value; only known ones are CCF and bruteforce')
ccf_continuum_normalize = args.ccf_continuum_normalize

if input_files != [] and input_file_from is not None:
raise RuntimeError(
'''You can only specify --input_files OR --input_file_from options
but not both of them simulatenously''')
elif input_file_from is None and input_files == []:
parser.print_help()
raise RuntimeError('You need to specify the spectra you want to fit')
if input_files == []:
input_files = None
files = FileQueue(file_list=input_files,
file_from=input_file_from,
queue=queue_file)

fit_targetid = None
if targetid_file_from is not None and targetid is not None:
raise RuntimeError(
Expand All @@ -1770,33 +1785,65 @@ def main(args):
if args.overwrite is not None:
logging.warning('overwrite keyword is meaningless now')

proc_many(
files,
output_dir,
output_tab_prefix,
output_mod_prefix,
figure_dir=figure_dir,
figure_prefix=args.figure_prefix,
nthreads=nthreads,
config_fname=config_fname,
fit_targetid=fit_targetid,
objtypes=objtypes,
doplot=doplot,
subdirs=args.subdirs,
minsn=minsn,
process_status_file=args.process_status_file,
expid_range=(minexpid, maxexpid),
skipexisting=args.skipexisting,
fitarm=fitarm,
cmdline=cmdline,
zbest_select=zbest_select,
zbest_include=zbest_include,
ccf_continuum_normalize=ccf_continuum_normalize,
use_resolution_matrix=args.resolution_matrix,
ccf_init=ccf_init,
npoly=npoly,
throw_exceptions=args.throw_exceptions,
)
# Dealing with input
if args.mpi:
rank = get_mpi_rank()
else:
rank = 0
if (not args.mpi) or (args.mpi and rank == 0):
# if root process in mpi mode or
# or anything non-mpi
if input_files != [] and input_file_from is not None:
raise RuntimeError('You can only specify --input_files OR '
'--input_file_from options '
'but not both of them simultaneously')
elif input_file_from is None and input_files == []:
parser.print_help()
raise RuntimeError(
'You need to specify the spectra you want to fit')
if input_files == []:
input_files = None
else:
# only in the case of mpi and rank>=1
input_files = None
if not args.mpi:
files = FileQueue(file_list=input_files,
file_from=input_file_from,
queue=queue_file)
else:
files = utils.MPIFileQueue(file_list=input_files)

if (not args.mpi) or (args.mpi and rank != 0):
# anything but mpi and rank=0 should go here
proc_many(
files,
output_dir,
output_tab_prefix,
output_mod_prefix,
figure_dir=figure_dir,
figure_prefix=args.figure_prefix,
nthreads=nthreads,
config_fname=config_fname,
fit_targetid=fit_targetid,
objtypes=objtypes,
doplot=doplot,
subdirs=args.subdirs,
minsn=minsn,
process_status_file=process_status_file,
expid_range=(minexpid, maxexpid),
skipexisting=args.skipexisting,
fitarm=fitarm,
cmdline=cmdline,
zbest_select=zbest_select,
zbest_include=zbest_include,
ccf_continuum_normalize=ccf_continuum_normalize,
use_resolution_matrix=args.resolution_matrix,
ccf_init=ccf_init,
npoly=npoly,
throw_exceptions=args.throw_exceptions,
)
else:
files.distribute_files()


if __name__ == '__main__':
Expand Down
68 changes: 68 additions & 0 deletions py/rvspecfit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,71 @@ def freezeDict(d):
return tuple(d)
else:
return d


class MPIFileQueue:

def __init__(self, file_list=None):
from mpi4py import MPI
self.MPI = MPI
self.comm = MPI.COMM_WORLD
self.rank = self.comm.Get_rank()
self.size = self.comm.Get_size()
self.file_list = file_list if self.rank == 0 else None
self.SEND_STATE = 1
self.STOP_STATE = 2
self.timeout = 3600
self.REQUEST_CMD = 'file'

def distribute_files(self):
if self.rank != 0:
# noop
return
# Only rank 0 manages the file distribution
index = 0
num_files = len(self.file_list)
mode = self.SEND_STATE
# We first iterate over file_list
# then over .size to send stop messages

while True:
# Receive request for next file
status = self.MPI.Status()

self.comm.probe(source=self.MPI.ANY_SOURCE,
tag=self.MPI.ANY_TAG,
status=status)
request = self.comm.recv(source=status.source,
tag=self.MPI.ANY_TAG)
if request == self.REQUEST_CMD and mode == self.SEND_STATE:
if index < num_files:
self.comm.send(self.file_list[index], dest=status.source)
index += 1
if index == num_files:
# we sent out the last file
# now we plan to send the termination command to
# every rank > 0
index = 1
mode = self.STOP_STATE
elif request == self.REQUEST_CMD and mode == self.STOP_STATE:
self.comm.send(None,
dest=status.source) # Send a termination signal
index += 1
if index == self.size:
break
else:
raise RuntimeError('Unsupported message')

def __next__(self):
if self.rank == 0:
# rank 0 does not work. he is the boss
raise StopIteration
# Other ranks request and receive files
self.comm.send(self.REQUEST_CMD, dest=0)
file_name = self.comm.recv(source=0, tag=self.MPI.ANY_TAG)
if file_name is None:
raise StopIteration # No more files, terminate
return file_name

def __iter__(self):
return self

0 comments on commit c5d9e32

Please sign in to comment.