Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions mesmerize_core/algorithms/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from contextlib import contextmanager
import os
import psutil
from typing import (Optional, Union, Generator, Protocol,
Callable, TypeVar, Sequence, Iterable, runtime_checkable)

import caiman as cm
from caiman.cluster import setup_cluster
from ipyparallel import DirectView
from multiprocessing.pool import Pool


RetVal = TypeVar("RetVal")
@runtime_checkable
class CustomCluster(Protocol):
"""
Protocol for a cluster that is not a multiprocessing pool
(including ipyparallel.DirectView)
"""

def map_sync(
self, fn: Callable[..., RetVal], args: Iterable
) -> Sequence[RetVal]: ...

def __len__(self) -> int:
"""return number of workers"""
...


Cluster = Union[Pool, CustomCluster, DirectView]


def get_n_processes(dview: Optional[Cluster]) -> int:
"""Infer number of processes in a multiprocessing or ipyparallel cluster"""
if isinstance(dview, Pool):
assert hasattr(dview, '_processes'), "Pool not keeping track of # of processes?"
return dview._processes # type: ignore
elif dview is not None:
return len(dview)
else:
return 1


@contextmanager
def ensure_server(dview: Optional[Cluster]) -> Generator[tuple[Cluster, int], None, None]:
"""
Context manager that passes through an existing 'dview' or
opens up a multiprocessing server if none is passed in.
If a server was opened, closes it upon exit.
Usage: `with ensure_server(dview) as (dview, n_processes):`
"""
if dview is not None:
yield dview, get_n_processes(dview)
else:
# no cluster passed in, so open one
procs_available = psutil.cpu_count()
if procs_available is None:
raise RuntimeError('Cannot determine number of processes')

if "MESMERIZE_N_PROCESSES" in os.environ.keys():
try:
n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
except:
n_processes = procs_available - 1
else:
n_processes = procs_available - 1

# Start cluster for parallel processing
_, dview, n_processes = setup_cluster(
backend="multiprocessing", n_processes=n_processes, single_thread=False
)
assert isinstance(dview, Pool) and isinstance(n_processes, int), 'setup_cluster with multiprocessing did not return a Pool'
try:
yield dview, n_processes
finally:
cm.stop_server(dview=dview)
166 changes: 74 additions & 92 deletions mesmerize_core/algorithms/cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,24 @@
import caiman as cm
from caiman.source_extraction.cnmf import cnmf as cnmf
from caiman.source_extraction.cnmf.params import CNMFParams
import psutil
import numpy as np
import traceback
from pathlib import Path, PurePosixPath
from shutil import move as move_file
import os
import time

# prevent circular import
if __name__ in ["__main__", "__mp_main__"]: # when running in subprocess
from mesmerize_core import set_parent_raw_data_path, load_batch
from mesmerize_core.utils import IS_WINDOWS
from mesmerize_core.algorithms._utils import ensure_server
else: # when running with local backend
from ..batch_utils import set_parent_raw_data_path, load_batch
from ..utils import IS_WINDOWS
from ._utils import ensure_server


def run_algo(batch_path, uuid, data_path: str = None):
def run_algo(batch_path, uuid, data_path: str = None, dview=None):
algo_start = time.time()
set_parent_raw_data_path(data_path)

Expand All @@ -42,103 +42,85 @@ def run_algo(batch_path, uuid, data_path: str = None):
f"Starting CNMF item:\n{item}\nWith params:{params}"
)

# adapted from current demo notebook
if "MESMERIZE_N_PROCESSES" in os.environ.keys():
with ensure_server(dview) as (dview, n_processes):

# merge cnmf and eval kwargs into one dict
cnmf_params = CNMFParams(params_dict=params["main"])
# Run CNMF, denote boolean 'success' if CNMF completes w/out error
try:
n_processes = int(os.environ["MESMERIZE_N_PROCESSES"])
except:
n_processes = psutil.cpu_count() - 1
else:
n_processes = psutil.cpu_count() - 1
# Start cluster for parallel processing
c, dview, n_processes = cm.cluster.setup_cluster(
backend="local", n_processes=n_processes, single_thread=False
)
fname_new = cm.save_memmap(
[input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
)

# merge cnmf and eval kwargs into one dict
cnmf_params = CNMFParams(params_dict=params["main"])
# Run CNMF, denote boolean 'success' if CNMF completes w/out error
try:
fname_new = cm.save_memmap(
[input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
)
print("making memmap")

print("making memmap")
Yr, dims, T = cm.load_memmap(fname_new)

Yr, dims, T = cm.load_memmap(fname_new)
images = np.reshape(Yr.T, [T] + list(dims), order="F")
images = np.reshape(Yr.T, [T] + list(dims), order="F")

proj_paths = dict()
for proj_type in ["mean", "std", "max"]:
p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
proj_paths[proj_type] = output_dir.joinpath(
f"{uuid}_{proj_type}_projection.npy"
)
np.save(str(proj_paths[proj_type]), p_img)

# in fname new load in memmap order C
cm.stop_server(dview=dview)
c, dview, n_processes = cm.cluster.setup_cluster(
backend="local", n_processes=None, single_thread=False
)
Comment on lines -79 to -83
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not 100% sure what this part is for; it doesn't really work with the changes here, so I just deleted it, but we can try to do something else if it's important.


print("performing CNMF")
cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)

print("fitting images")
cnm.fit(images)
#
if "refit" in params.keys():
if params["refit"] is True:
print("refitting")
cnm = cnm.refit(images, dview=dview)

print("performing eval")
cnm.estimates.evaluate_components(images, cnm.params, dview=dview)

output_path = output_dir.joinpath(f"{uuid}.hdf5")

cnm.save(str(output_path))

Cn = cm.local_correlations(images, swap_dim=False)
Cn[np.isnan(Cn)] = 0

corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy")
np.save(str(corr_img_path), Cn, allow_pickle=False)

# output dict for dataframe row (pd.Series)
d = dict()

cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
if IS_WINDOWS:
Yr._mmap.close() # accessing private attr but windows is annoying otherwise
move_file(fname_new, cnmf_memmap_path)

# save paths as relative path strings with forward slashes
cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent)))
cnmf_memmap_path = str(
PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent))
)
corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent)))
for proj_type in proj_paths.keys():
d[f"{proj_type}-projection-path"] = str(
PurePosixPath(proj_paths[proj_type].relative_to(output_dir.parent))
)
proj_paths = dict()
for proj_type in ["mean", "std", "max"]:
p_img = getattr(np, f"nan{proj_type}")(images, axis=0)
proj_paths[proj_type] = output_dir.joinpath(
f"{uuid}_{proj_type}_projection.npy"
)
np.save(str(proj_paths[proj_type]), p_img)

print("performing CNMF")
cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)

print("fitting images")
cnm.fit(images)
#
if "refit" in params.keys():
if params["refit"] is True:
print("refitting")
cnm = cnm.refit(images, dview=dview)

print("performing eval")
cnm.estimates.evaluate_components(images, cnm.params, dview=dview)

d.update(
{
"cnmf-hdf5-path": cnmf_hdf5_path,
"cnmf-memmap-path": cnmf_memmap_path,
"corr-img-path": corr_img_path,
"success": True,
"traceback": None,
}
)
output_path = output_dir.joinpath(f"{uuid}.hdf5")

except:
d = {"success": False, "traceback": traceback.format_exc()}
cnm.save(str(output_path))

cm.stop_server(dview=dview)
Cn = cm.local_correlations(images, swap_dim=False)
Cn[np.isnan(Cn)] = 0

corr_img_path = output_dir.joinpath(f"{uuid}_cn.npy")
np.save(str(corr_img_path), Cn, allow_pickle=False)

# output dict for dataframe row (pd.Series)
d = dict()

cnmf_memmap_path = output_dir.joinpath(Path(fname_new).name)
if IS_WINDOWS:
Yr._mmap.close() # accessing private attr but windows is annoying otherwise
move_file(fname_new, cnmf_memmap_path)

# save paths as relative path strings with forward slashes
cnmf_hdf5_path = str(PurePosixPath(output_path.relative_to(output_dir.parent)))
cnmf_memmap_path = str(
PurePosixPath(cnmf_memmap_path.relative_to(output_dir.parent))
)
corr_img_path = str(PurePosixPath(corr_img_path.relative_to(output_dir.parent)))
for proj_type in proj_paths.keys():
d[f"{proj_type}-projection-path"] = str(
PurePosixPath(proj_paths[proj_type].relative_to(output_dir.parent))
)

d.update(
{
"cnmf-hdf5-path": cnmf_hdf5_path,
"cnmf-memmap-path": cnmf_memmap_path,
"corr-img-path": corr_img_path,
"success": True,
"traceback": None,
}
)

except:
d = {"success": False, "traceback": traceback.format_exc()}

runtime = round(time.time() - algo_start, 2)
df.caiman.update_item_with_results(uuid, d, runtime)
Expand Down
Loading
Loading