diff --git a/py/rvspecfit/desi/desi_fit.py b/py/rvspecfit/desi/desi_fit.py index 392738a..7977932 100644 --- a/py/rvspecfit/desi/desi_fit.py +++ b/py/rvspecfit/desi/desi_fit.py @@ -1357,6 +1357,11 @@ def submit(self, f, *args, **kw): return FakeFuture(f(*args, **kw)) +def get_mpi_rank(): + from mpi4py import MPI + return MPI.COMM_WORLD.Get_rank() + + def proc_many(files, output_dir, output_tab_prefix, @@ -1662,6 +1667,11 @@ def main(args): action='store_true', default=False) + parser.add_argument('--mpi', + help='Use MPI for communication', + action='store_true', + default=False) + parser.add_argument( '--no_ccf_continuum_normalize', help='Do not normalize by the continuum when doing CCF', @@ -1694,11 +1704,29 @@ def main(args): log_level = args.log_level - if args.log is not None: - logging.basicConfig(filename=args.log, level=log_level) + log_filename = args.log + process_status_file = args.process_status_file + if log_filename is not None: + if args.mpi: + if re.match('.*%.*', log_filename) is None: + raise RuntimeError( + 'If using MPI and log files you need to ' + 'allow for the MPI rank to placed in the filename using %d' + ) + log_filename = log_filename % (get_mpi_rank()) + logging.basicConfig(filename=log_filename, level=log_level) else: logging.basicConfig(level=log_level) + if process_status_file is not None: + if args.mpi: + if re.match('.*%.*', process_status_file) is None: + raise RuntimeError( + 'If using MPI and log files you need to ' + 'allow for the MPI rank to placed in the filename using %d' + ) + process_status_file = process_status_file % (get_mpi_rank()) + input_files = args.input_files input_file_from = args.input_file_from output_dir, output_tab_prefix, output_mod_prefix = (args.output_dir, @@ -1735,19 +1763,6 @@ def main(args): 'Unknown param_init value; only known ones are CCF and bruteforce') ccf_continuum_normalize = args.ccf_continuum_normalize - if input_files != [] and input_file_from is not None: - raise RuntimeError( - '''You can only specify --input_files OR --input_file_from options - but not both of them simulatenously''') - elif input_file_from is None and input_files == []: - parser.print_help() - raise RuntimeError('You need to specify the spectra you want to fit') - if input_files == []: - input_files = None - files = FileQueue(file_list=input_files, - file_from=input_file_from, - queue=queue_file) - fit_targetid = None if targetid_file_from is not None and targetid is not None: raise RuntimeError( @@ -1770,33 +1785,65 @@ def main(args): if args.overwrite is not None: logging.warning('overwrite keyword is meaningless now') - proc_many( - files, - output_dir, - output_tab_prefix, - output_mod_prefix, - figure_dir=figure_dir, - figure_prefix=args.figure_prefix, - nthreads=nthreads, - config_fname=config_fname, - fit_targetid=fit_targetid, - objtypes=objtypes, - doplot=doplot, - subdirs=args.subdirs, - minsn=minsn, - process_status_file=args.process_status_file, - expid_range=(minexpid, maxexpid), - skipexisting=args.skipexisting, - fitarm=fitarm, - cmdline=cmdline, - zbest_select=zbest_select, - zbest_include=zbest_include, - ccf_continuum_normalize=ccf_continuum_normalize, - use_resolution_matrix=args.resolution_matrix, - ccf_init=ccf_init, - npoly=npoly, - throw_exceptions=args.throw_exceptions, - ) + # Dealing with input + if args.mpi: + rank = get_mpi_rank() + else: + rank = 0 + if (not args.mpi) or (args.mpi and rank == 0): + # if root process in mpi mode or + # or anything non-mpi + if input_files != [] and input_file_from is not None: + raise RuntimeError('You can only specify --input_files OR ' + '--input_file_from options ' + 'but not both of them simultaneously') + elif input_file_from is None and input_files == []: + parser.print_help() + raise RuntimeError( + 'You need to specify the spectra you want to fit') + if input_files == []: + input_files = None + else: + # only in the case of mpi and rank>=1 + input_files = None + if not args.mpi: + files = FileQueue(file_list=input_files, + file_from=input_file_from, + queue=queue_file) + else: + files = utils.MPIFileQueue(file_list=input_files) + + if (not args.mpi) or (args.mpi and rank != 0): + # anything but mpi and rank=0 should go here + proc_many( + files, + output_dir, + output_tab_prefix, + output_mod_prefix, + figure_dir=figure_dir, + figure_prefix=args.figure_prefix, + nthreads=nthreads, + config_fname=config_fname, + fit_targetid=fit_targetid, + objtypes=objtypes, + doplot=doplot, + subdirs=args.subdirs, + minsn=minsn, + process_status_file=process_status_file, + expid_range=(minexpid, maxexpid), + skipexisting=args.skipexisting, + fitarm=fitarm, + cmdline=cmdline, + zbest_select=zbest_select, + zbest_include=zbest_include, + ccf_continuum_normalize=ccf_continuum_normalize, + use_resolution_matrix=args.resolution_matrix, + ccf_init=ccf_init, + npoly=npoly, + throw_exceptions=args.throw_exceptions, + ) + else: + files.distribute_files() if __name__ == '__main__': diff --git a/py/rvspecfit/utils.py b/py/rvspecfit/utils.py index 94fcf92..ba348e1 100644 --- a/py/rvspecfit/utils.py +++ b/py/rvspecfit/utils.py @@ -105,3 +105,71 @@ def freezeDict(d): return tuple(d) else: return d + + +class MPIFileQueue: + + def __init__(self, file_list=None): + from mpi4py import MPI + self.MPI = MPI + self.comm = MPI.COMM_WORLD + self.rank = self.comm.Get_rank() + self.size = self.comm.Get_size() + self.file_list = file_list if self.rank == 0 else None + self.SEND_STATE = 1 + self.STOP_STATE = 2 + self.timeout = 3600 + self.REQUEST_CMD = 'file' + + def distribute_files(self): + if self.rank != 0: + # noop + return + # Only rank 0 manages the file distribution + index = 0 + num_files = len(self.file_list) + mode = self.SEND_STATE + # We first iterate over file_list + # then over .size to send stop messages + + while True: + # Receive request for next file + status = self.MPI.Status() + + self.comm.probe(source=self.MPI.ANY_SOURCE, + tag=self.MPI.ANY_TAG, + status=status) + request = self.comm.recv(source=status.source, + tag=self.MPI.ANY_TAG) + if request == self.REQUEST_CMD and mode == self.SEND_STATE: + if index < num_files: + self.comm.send(self.file_list[index], dest=status.source) + index += 1 + if index == num_files: + # we sent out the last file + # now we plan to send the termination command to + # every rank > 0 + index = 1 + mode = self.STOP_STATE + elif request == self.REQUEST_CMD and mode == self.STOP_STATE: + self.comm.send(None, + dest=status.source) # Send a termination signal + index += 1 + if index == self.size: + break + else: + raise RuntimeError('Unsupported message') + + def __next__(self): + if self.rank == 0: + # rank 0 does not work. he is the boss + raise StopIteration + # Other ranks request and receive files + self.comm.send(self.REQUEST_CMD, dest=0) + file_name = self.comm.recv(source=0, tag=self.MPI.ANY_TAG) + if file_name is None: + raise StopIteration # No more files, terminate + return file_name + + def __iter__(self): + return self