70 changes: 70 additions & 0 deletions mesmerize_core/algorithms/_utils.py
@@ -0,0 +1,70 @@
from pathlib import Path
import psutil
from typing import Optional, Union

import caiman as cm
from ipyparallel import DirectView
from multiprocessing.pool import Pool
import numpy as np

Cluster = Union[Pool, DirectView]

def get_n_processes(dview: Optional[Cluster]) -> int:
"""Infer number of processes in a multiprocessing or ipyparallel cluster"""
if isinstance(dview, Pool) and hasattr(dview, '_processes'):
return dview._processes
elif isinstance(dview, DirectView):
return len(dview)
else:
return 1


def estimate_n_pixels_per_process(n_processes: int, T: int, dims: tuple[int, ...]) -> int:
"""
Estimate a safe number of pixels to allocate to each parallel process at a time
Taken from CNMF.fit (TODO factor this out in caiman and just import it)
"""
avail_memory_per_process = psutil.virtual_memory()[
1] / 2.**30 / n_processes
mem_per_pix = 3.6977678498329843e-09
npx_per_proc = int(avail_memory_per_process / 8. / mem_per_pix / T)
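    # illustrative arithmetic (hypothetical numbers): with ~32 GiB available split over
    # 16 processes and T = 10,000 frames, this gives roughly
    # int(2 / 8 / 3.7e-9 / 1e4) ≈ 6,700 pixels per chunk, before the cap applied below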
    # never request more than an even share of the total number of pixels
    npx_per_proc = int(np.minimum(npx_per_proc, np.prod(dims) // n_processes))
    return npx_per_proc


def make_chunk_projection(Yr_chunk: np.ndarray, proj_type: str):
    # Yr_chunk has shape (n_pixels, n_frames); reduce over the time axis
    return getattr(np, proj_type)(Yr_chunk, axis=1)

def make_chunk_projection_helper(args: tuple[str, slice, str]):
    # worker entry point: re-open the memmap by path so only the path is shipped to workers
    Yr_name, chunk_slice, proj_type = args
    Yr, _, _ = cm.load_memmap(Yr_name)
    return make_chunk_projection(Yr[chunk_slice], proj_type)


def make_projection_parallel(movie_path: str, proj_type: str, dview: Optional[Cluster]) -> np.ndarray:
    Yr, dims, T = cm.load_memmap(movie_path)
    if dview is None:
        p_img_flat = make_chunk_projection(Yr, proj_type)
    else:
        # use n_pixels_per_process from CNMF to avoid running out of memory
        n_pix = Yr.shape[0]
        chunk_size = estimate_n_pixels_per_process(get_n_processes(dview), T, dims)
        chunk_starts = range(0, n_pix, chunk_size)
        chunk_slices = [slice(start, min(start + chunk_size, n_pix)) for start in chunk_starts]
        args = [(movie_path, chunk_slice, proj_type) for chunk_slice in chunk_slices]
        # multiprocessing.Pool exposes `map`; ipyparallel's DirectView exposes `map_sync`
        map_fn = dview.map if isinstance(dview, Pool) else dview.map_sync
        chunk_projs = map_fn(make_chunk_projection_helper, args)
        p_img_flat = np.concatenate(chunk_projs, axis=0)
    return np.reshape(p_img_flat, dims, order='F')


def save_projections_parallel(uuid, movie_path: Union[str, Path], output_dir: Path, dview: Optional[Cluster]
                              ) -> dict[str, Path]:
    proj_paths = dict()
    for proj_type in ["mean", "std", "max"]:
        p_img = make_projection_parallel(str(movie_path), "nan" + proj_type, dview=dview)
        proj_paths[proj_type] = output_dir.joinpath(
            f"{uuid}_{proj_type}_projection.npy"
        )
        np.save(str(proj_paths[proj_type]), p_img)
    return proj_paths
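
A minimal usage sketch for the new helper (not part of the diff): it assumes a CaImAn C-order memmap already exists on disk, and the file name, pool size, and output directory below are hypothetical.

from multiprocessing import Pool
from pathlib import Path

from mesmerize_core.algorithms._utils import save_projections_parallel

if __name__ == "__main__":
    output_dir = Path("./projections")
    output_dir.mkdir(exist_ok=True)
    with Pool(4) as pool:
        proj_paths = save_projections_parallel(
            uuid="example-uuid",  # hypothetical item UUID used in the output file names
            movie_path="memmap__d1_512_d2_512_d3_1_order_C_frames_1000_.mmap",
            output_dir=output_dir,
            dview=pool,  # also accepts an ipyparallel DirectView, or None for serial
        )
    # proj_paths maps "mean" / "std" / "max" to the saved .npy files
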
12 changes: 5 additions & 7 deletions mesmerize_core/algorithms/cnmf.py
@@ -15,9 +15,11 @@
 if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
     from mesmerize_core import set_parent_raw_data_path, load_batch
     from mesmerize_core.utils import IS_WINDOWS
+    from mesmerize_core.algorithms._utils import save_projections_parallel
 else:  # when running with local backend
     from ..batch_utils import set_parent_raw_data_path, load_batch
     from ..utils import IS_WINDOWS
+    from ._utils import save_projections_parallel


 def run_algo(batch_path, uuid, data_path: str = None):
@@ -67,13 +69,9 @@ def run_algo(batch_path, uuid, data_path: str = None):
     Yr, dims, T = cm.load_memmap(fname_new)
     images = np.reshape(Yr.T, [T] + list(dims), order="F")

-    proj_paths = dict()
-    for proj_type in ["mean", "std", "max"]:
-        p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
-        proj_paths[proj_type] = output_dir.joinpath(
-            f"{uuid}_{proj_type}_projection.npy"
-        )
-        np.save(str(proj_paths[proj_type]), p_img)
+    proj_paths = save_projections_parallel(
+        uuid=uuid, movie_path=fname_new, output_dir=output_dir, dview=dview
+    )

     # in fname new load in memmap order C
     cm.stop_server(dview=dview)
12 changes: 5 additions & 7 deletions mesmerize_core/algorithms/cnmfe.py
@@ -13,9 +13,11 @@
 if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
     from mesmerize_core import set_parent_raw_data_path, load_batch
     from mesmerize_core.utils import IS_WINDOWS
+    from mesmerize_core.algorithms._utils import save_projections_parallel
 else:  # when running with local backend
     from ..batch_utils import set_parent_raw_data_path, load_batch
     from ..utils import IS_WINDOWS
+    from ._utils import save_projections_parallel


 def run_algo(batch_path, uuid, data_path: str = None):
@@ -59,13 +61,9 @@ def run_algo(batch_path, uuid, data_path: str = None):

     # TODO: if projections already exist from mcorr we don't
     # need to waste compute time re-computing them here
-    proj_paths = dict()
-    for proj_type in ["mean", "std", "max"]:
-        p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
-        proj_paths[proj_type] = output_dir.joinpath(
-            f"{uuid}_{proj_type}_projection.npy"
-        )
-        np.save(str(proj_paths[proj_type]), p_img)
+    proj_paths = save_projections_parallel(
+        uuid=uuid, movie_path=fname_new, output_dir=output_dir, dview=dview
+    )

     d = dict()  # for output

15 changes: 5 additions & 10 deletions mesmerize_core/algorithms/mcorr.py
@@ -14,8 +14,10 @@
 # prevent circular import
 if __name__ in ["__main__", "__mp_main__"]:  # when running in subprocess
     from mesmerize_core import set_parent_raw_data_path, load_batch
+    from mesmerize_core.algorithms._utils import save_projections_parallel
 else:  # when running with local backend
     from ..batch_utils import set_parent_raw_data_path, load_batch
+    from ._utils import save_projections_parallel


 def run_algo(batch_path, uuid, data_path: str = None):
@@ -77,16 +79,9 @@ def run_algo(batch_path, uuid, data_path: str = None):
print("mc finished successfully!")

print("computing projections")
Yr, dims, T = cm.load_memmap(str(mcorr_memmap_path))
images = np.reshape(Yr.T, [T] + list(dims), order="F")

proj_paths = dict()
for proj_type in ["mean", "std", "max"]:
p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
proj_paths[proj_type] = output_dir.joinpath(
f"{uuid}_{proj_type}_projection.npy"
)
np.save(str(proj_paths[proj_type]), p_img)
proj_paths = save_projections_parallel(
uuid=uuid, movie_path=mcorr_memmap_path, output_dir=output_dir, dview=dview
)

print("Computing correlation image")
Cns = local_correlations_movie_offline(
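
As a quick sanity check (not part of the PR), the parallel projection can be compared against the previous serial computation on a small memmap; the file name below is hypothetical.

from multiprocessing import Pool

import caiman as cm
import numpy as np

from mesmerize_core.algorithms._utils import make_projection_parallel

if __name__ == "__main__":
    fname = "memmap__d1_64_d2_64_d3_1_order_C_frames_100_.mmap"  # small test movie
    Yr, dims, T = cm.load_memmap(fname)
    images = np.reshape(Yr.T, [T] + list(dims), order="F")

    # old approach: project over the time axis of the (T, *dims) array
    serial = np.nanmean(images, axis=0)

    # new approach: chunked projection distributed over a small pool
    with Pool(2) as pool:
        parallel = make_projection_parallel(fname, "nanmean", dview=pool)

    np.testing.assert_allclose(serial, parallel)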