diff --git a/mesmerize_core/algorithms/_utils.py b/mesmerize_core/algorithms/_utils.py
new file mode 100644
index 0000000..e057487
--- /dev/null
+++ b/mesmerize_core/algorithms/_utils.py
@@ -0,0 +1,128 @@
+from pathlib import Path
+import psutil
+from typing import Optional, Union
+
+import caiman as cm
+from ipyparallel import DirectView
+from multiprocessing.pool import Pool
+import numpy as np
+
+# cluster handle types accepted as ``dview`` by the functions below
+Cluster = Union[Pool, DirectView]
+
+
+def get_n_processes(dview: Optional[Cluster]) -> int:
+    """
+    Infer number of processes in a multiprocessing or ipyparallel cluster
+
+    Returns 1 if ``dview`` is None or any other unrecognized object.
+    """
+    if isinstance(dview, Pool) and hasattr(dview, '_processes'):
+        # Pool has no public size attribute, so read the private one
+        return dview._processes
+    elif isinstance(dview, DirectView):
+        return len(dview)
+    else:
+        return 1
+
+
+def estimate_n_pixels_per_process(n_processes: int, T: int, dims: tuple[int, ...]) -> int:
+    """
+    Estimate a safe number of pixels to allocate to each parallel process at a time
+    Taken from CNMF.fit (TODO factor this out in caiman and just import it)
+
+    Parameters
+    ----------
+    n_processes: int
+        number of parallel processes; values below 1 are treated as 1
+    T: int
+        number of frames in the movie
+    dims: tuple[int, ...]
+        spatial dimensions of the movie
+
+    Returns
+    -------
+    int
+        pixels per process; always at least 1 so it is usable as a chunk size
+    """
+    # a cluster reporting zero workers would otherwise cause a ZeroDivisionError
+    n_processes = max(int(n_processes), 1)
+    avail_memory_per_process = psutil.virtual_memory().available / 2.**30 / n_processes
+    mem_per_pix = 3.6977678498329843e-09
+    npx_per_proc = int(avail_memory_per_process / 8. / mem_per_pix / T)
+    npx_per_proc = int(np.minimum(npx_per_proc, np.prod(dims) // n_processes))
+    # the estimate truncates to 0 when memory is tight or when there are more
+    # processes than pixels; a chunk size of 0 would make range() raise
+    return max(npx_per_proc, 1)
+
+
+def make_chunk_projection(Yr_chunk: np.ndarray, proj_type: str):
+    """Apply the numpy reduction named ``proj_type`` (e.g. "nanmean") along axis 1"""
+    return getattr(np, proj_type)(Yr_chunk, axis=1)
+
+
+def make_chunk_projection_helper(args: tuple[str, slice, str]):
+    """
+    Load a memmap and compute the projection of one slice of its rows
+
+    Takes a single (memmap path, row slice, projection type) tuple so that it
+    can be passed directly to a cluster's map function.
+    """
+    Yr_name, chunk_slice, proj_type = args
+    Yr, _, _ = cm.load_memmap(Yr_name)
+    return make_chunk_projection(Yr[chunk_slice], proj_type)
+
+
+def make_projection_parallel(movie_path: str, proj_type: str, dview: Optional[Cluster]) -> np.ndarray:
+    """
+    Compute a projection of a memmapped movie, in parallel if a cluster is given
+
+    Parameters
+    ----------
+    movie_path: str
+        path to a memmap file that can be opened with ``cm.load_memmap``
+    proj_type: str
+        name of a numpy reduction function, e.g. "nanmean"
+    dview: Optional[Cluster]
+        cluster to distribute the work over; if None, compute in this process
+
+    Returns
+    -------
+    np.ndarray
+        projection image reshaped to the movie's ``dims``
+    """
+    Yr, dims, T = cm.load_memmap(movie_path)
+    if dview is None:
+        p_img_flat = make_chunk_projection(Yr, proj_type)
+    else:
+        # use n_pixels_per_process from CNMF to avoid running out of memory
+        n_pix = Yr.shape[0]
+        chunk_size = estimate_n_pixels_per_process(get_n_processes(dview), T, dims)
+        chunk_starts = range(0, n_pix, chunk_size)
+        chunk_slices = [slice(start, min(start + chunk_size, n_pix)) for start in chunk_starts]
+        # pass the path rather than the array so each worker opens the memmap itself
+        args = [(movie_path, chunk_slice, proj_type) for chunk_slice in chunk_slices]
+        map_fn = dview.map if isinstance(dview, Pool) else dview.map_sync
+        chunk_projs = map_fn(make_chunk_projection_helper, args)
+        p_img_flat = np.concatenate(chunk_projs, axis=0)
+    return np.reshape(p_img_flat, dims, order='F')
+
+
+def save_projections_parallel(
+    uuid, movie_path: Union[str, Path], output_dir: Path, dview: Optional[Cluster]
+) -> dict[str, Path]:
+    """
+    Compute the mean, std and max projections of a movie and save each one to
+    ``<output_dir>/<uuid>_<proj_type>_projection.npy``
+
+    Returns a dict mapping "mean", "std" and "max" to the saved file paths.
+    """
+    proj_paths = dict()
+    for proj_type in ["mean", "std", "max"]:
+        # "nan" prefix selects np.nanmean / np.nanstd / np.nanmax
+        p_img = make_projection_parallel(str(movie_path), "nan" + proj_type, dview=dview)
+        proj_paths[proj_type] = output_dir.joinpath(
+            f"{uuid}_{proj_type}_projection.npy"
+        )
+        np.save(str(proj_paths[proj_type]), p_img)
+    return proj_paths
diff --git a/mesmerize_core/algorithms/cnmf.py b/mesmerize_core/algorithms/cnmf.py
index ca8f599..bb550b7 100644
--- a/mesmerize_core/algorithms/cnmf.py
+++ b/mesmerize_core/algorithms/cnmf.py
@@ -15,9 +15,11 @@
 if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
     from mesmerize_core import set_parent_raw_data_path, load_batch
     from mesmerize_core.utils import IS_WINDOWS
+    from mesmerize_core.algorithms._utils import save_projections_parallel
 else:  # when running with local backend
     from ..batch_utils import set_parent_raw_data_path, load_batch
     from ..utils import IS_WINDOWS
+    from ._utils import save_projections_parallel
 
 
 def run_algo(batch_path, uuid, data_path: str = None):
@@ -67,13 +69,9 @@ def run_algo(batch_path, uuid, data_path: str = None):
         Yr, dims, T = cm.load_memmap(fname_new)
         images = np.reshape(Yr.T, [T] + list(dims), order="F")
 
-        proj_paths = dict()
-        for proj_type in ["mean", "std", "max"]:
-            p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
-            proj_paths[proj_type] = output_dir.joinpath(
-                f"{uuid}_{proj_type}_projection.npy"
-            )
-            np.save(str(proj_paths[proj_type]), p_img)
+        proj_paths = save_projections_parallel(
+            uuid=uuid, movie_path=fname_new, output_dir=output_dir, dview=dview
+        )
 
         # in fname new load in memmap order C
         cm.stop_server(dview=dview)
diff --git a/mesmerize_core/algorithms/cnmfe.py b/mesmerize_core/algorithms/cnmfe.py
index e053869..e168c5d 100644
--- a/mesmerize_core/algorithms/cnmfe.py
+++ b/mesmerize_core/algorithms/cnmfe.py
@@ -13,9 +13,11 @@
 if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
     from mesmerize_core import set_parent_raw_data_path, load_batch
     from mesmerize_core.utils import IS_WINDOWS
+    from mesmerize_core.algorithms._utils import save_projections_parallel
 else:  # when running with local backend
     from ..batch_utils import set_parent_raw_data_path, load_batch
     from ..utils import IS_WINDOWS
+    from ._utils import save_projections_parallel
 
 
 def run_algo(batch_path, uuid, data_path: str = None):
@@ -59,13 +61,9 @@ def run_algo(batch_path, uuid, data_path: str = None):
 
         # TODO: if projections already exist from mcorr we don't
         # need to waste compute time re-computing them here
-        proj_paths = dict()
-        for proj_type in ["mean", "std", "max"]:
-            p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
-            proj_paths[proj_type] = output_dir.joinpath(
-                f"{uuid}_{proj_type}_projection.npy"
-            )
-            np.save(str(proj_paths[proj_type]), p_img)
+        proj_paths = save_projections_parallel(
+            uuid=uuid, movie_path=fname_new, output_dir=output_dir, dview=dview
+        )
 
         d = dict()  # for output
 
diff --git a/mesmerize_core/algorithms/mcorr.py b/mesmerize_core/algorithms/mcorr.py
index 3bac29e..297ea34 100644
--- a/mesmerize_core/algorithms/mcorr.py
+++ b/mesmerize_core/algorithms/mcorr.py
@@ -14,8 +14,10 @@
 # prevent circular import
 if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
     from mesmerize_core import set_parent_raw_data_path, load_batch
+    from mesmerize_core.algorithms._utils import save_projections_parallel
 else:  # when running with local backend
     from ..batch_utils import set_parent_raw_data_path, load_batch
+    from ._utils import save_projections_parallel
 
 
 def run_algo(batch_path, uuid, data_path: str = None):
@@ -77,16 +79,9 @@ def run_algo(batch_path, uuid, data_path: str = None):
         print("mc finished successfully!")
 
         print("computing projections")
-        Yr, dims, T = cm.load_memmap(str(mcorr_memmap_path))
-        images = np.reshape(Yr.T, [T] + list(dims), order="F")
-
-        proj_paths = dict()
-        for proj_type in ["mean", "std", "max"]:
-            p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
-            proj_paths[proj_type] = output_dir.joinpath(
-                f"{uuid}_{proj_type}_projection.npy"
-            )
-            np.save(str(proj_paths[proj_type]), p_img)
+        proj_paths = save_projections_parallel(
+            uuid=uuid, movie_path=mcorr_memmap_path, output_dir=output_dir, dview=dview
+        )
 
         print("Computing correlation image")
         Cns = local_correlations_movie_offline(