diff --git a/suite2p/__main__.py b/suite2p/__main__.py index 09b4ce1b..006f6d05 100644 --- a/suite2p/__main__.py +++ b/suite2p/__main__.py @@ -51,16 +51,7 @@ def main(): logging.exception(f'fatal error in {"run_plane" if args.single_plane else "run_s2p"}:') raise - else: - # Check if the OS is macOS and the machine is Apple Silicon (ARM-based) - if platform.system() == "Darwin" and 'arm' in platform.processor().lower(): - # Set the number of threads for OpenMP and OpenBLAS - os.environ["OMP_NUM_THREADS"] = "1" - os.environ["OPENBLAS_NUM_THREADS"] = "1" - print("Environment set to use 1 thread for OpenMP and OpenBLAS (Apple Silicon macOS).") - else: - print("Not macOS on Apple Silicon, proceeding without limiting threads.") - + else: from suite2p import gui gui.run()#statfile="C:/DATA/exs2p/suite2p/plane0/stat.npy") diff --git a/suite2p/extraction/extract.py b/suite2p/extraction/extract.py index 71041fea..4b3832d2 100644 --- a/suite2p/extraction/extract.py +++ b/suite2p/extraction/extract.py @@ -57,21 +57,22 @@ def extract_traces(f_in, cell_masks, neuropil_masks, batch_size=500, if device.type == 'mps': device = torch.device('cpu') - npix_neuropil = torch.Tensor([len(nm) for nm in neuropil_masks]).to(device) - # create coo tensor of neuropil and cell masks - ccol_indices = [m for nm in neuropil_masks for m in nm] - row_indices = [k for k in range(len(neuropil_masks)) for m in neuropil_masks[k]] - inds = torch.Tensor([ccol_indices, row_indices]).to(device) - # convert to csc (tried creating csc directly but it was slow) - nmasks = torch.sparse_coo_tensor(inds, torch.ones(len(row_indices), device=device), - size=(Ly*Lx, ncells)) - nmasks = nmasks.to_sparse_csc() + if neuropil_masks is not None: + npix_neuropil = torch.Tensor([len(nm) for nm in neuropil_masks]).to(device) + # create coo tensor of neuropil masks + ccol_indices = [m for nm in neuropil_masks for m in nm] + row_indices = [k for k in range(len(neuropil_masks)) for m in neuropil_masks[k]] + inds = torch.Tensor([ccol_indices, row_indices]).to(device) + # convert to csc (tried creating csc directly but it was slow) + nmasks = torch.sparse_coo_tensor(inds, torch.ones(len(row_indices), device=device), + size=(Ly*Lx, ncells)) + nmasks = nmasks.to_sparse_csc() ccol_indices = [m for cm in cell_masks for m in cm[0]] row_indices = [k for k in range(len(cell_masks)) for m in cell_masks[k][0]] cell_lam = torch.Tensor([l for cm in cell_masks for l in cm[1]]).to(device) inds = torch.Tensor([ccol_indices, row_indices]).to(device) - cmasks = torch.sparse_coo_tensor(inds, cell_lam, + cmasks = torch.sparse_coo_tensor(inds, cell_lam, size=(Ly*Lx, ncells)) cmasks = cmasks.to_sparse_csc() @@ -87,11 +88,12 @@ def extract_traces(f_in, cell_masks, neuropil_masks, batch_size=500, tstart, tend = n * batch_size, min((n+1) * batch_size, n_frames) data = torch.from_numpy(f_in[tstart : tend]).to(device) data = data.reshape(-1, Ly*Lx).float() - - Fneu_batch = (data @ nmasks) / npix_neuropil - Fneu[:, tstart : tend] = Fneu_batch.T.cpu().numpy() - - F_batch = data @ cmasks + + if neuropil_masks is not None: + Fneu_batch = (data @ nmasks) / npix_neuropil + Fneu[:, tstart : tend] = Fneu_batch.T.cpu().numpy() + + F_batch = data @ cmasks F[:, tstart : tend] = F_batch.T.cpu().numpy() return F, Fneu diff --git a/suite2p/gui/menus.py b/suite2p/gui/menus.py index 29c8fe60..53f9aa62 100644 --- a/suite2p/gui/menus.py +++ b/suite2p/gui/menus.py @@ -3,7 +3,7 @@ """ from qtpy import QtGui from qtpy.QtWidgets import QAction, QMenu -from pkg_resources import iter_entry_points +from importlib.metadata import entry_points from . import reggui, drawroi, merge, io, rungui, visualize, classgui from suite2p.io.nwb import save_nwb @@ -166,7 +166,13 @@ def plugins(parent): main_menu = parent.menuBar() parent.plugins = {} plugin_menu = main_menu.addMenu("&Plugins") - for entry_pt in iter_entry_points(group="suite2p.plugin", name=None): + try: + # Works for python 3.12+ + suite2p_plugins = entry_points(group="suite2p.plugin") + except TypeError: + # works for Python 3.9-3.11 + suite2p_plugins = entry_points().get("suite2p.plugin", []) + for entry_pt in suite2p_plugins: plugin_obj = entry_pt.load() # load the advertised class from entry_points parent.plugins[entry_pt.name] = plugin_obj( parent diff --git a/suite2p/gui/rungui_utils.py b/suite2p/gui/rungui_utils.py index f67124dd..5ec1f448 100644 --- a/suite2p/gui/rungui_utils.py +++ b/suite2p/gui/rungui_utils.py @@ -115,25 +115,67 @@ def setup_logger(name): return logger -class Suite2pWorker(QtCore.QThread): +class Suite2pWorker(QtCore.QObject): + """Worker that runs suite2p in a separate process to avoid QThread stack limitations on macOS.""" finished = QtCore.Signal(str) - + def __init__(self, parent, db_file, settings_file): super(Suite2pWorker, self).__init__() self.db_file = db_file self.settings_file = settings_file self.parent = parent - # self.logHandler = ThreadLogger() - - def run(self): - db = np.load(self.db_file, allow_pickle=True).item() - settings = np.load(self.settings_file, allow_pickle=True).item() - - try: - logger_setup(get_save_folder(db)) - run_s2p(db=db, settings=settings) + self.process = None + + def start(self): + """Start suite2p in a separate process using QProcess.""" + self.process = QtCore.QProcess() + self.process.setProcessChannelMode(QtCore.QProcess.MergedChannels) + self.process.readyReadStandardOutput.connect(self._on_output) + self.process.finished.connect(self._on_finished) + + # Create a Python script to run suite2p + script = f''' +import numpy as np +from suite2p.run_s2p import logger_setup, run_s2p, get_save_folder + +db = np.load("{self.db_file}", allow_pickle=True).item() +settings = np.load("{self.settings_file}", allow_pickle=True).item() + +logger_setup(get_save_folder(db)) +run_s2p(db=db, settings=settings) +''' + self.process.start(sys.executable, ["-c", script]) + + def _on_output(self): + """Handle output from the subprocess.""" + if self.process: + data = self.process.readAllStandardOutput() + text = bytes(data).decode("utf-8", errors="replace") + print(text, end="") + + def _on_finished(self, exit_code, exit_status): + """Handle process completion.""" + if exit_code == 0: self.finished.emit("finished") - except Exception as e: - print("ERROR:", e) - traceback.print_exc() - self.finished.emit("error") \ No newline at end of file + else: + self.finished.emit("error") + + def terminate(self): + """Terminate the subprocess if running.""" + if self.process and self.process.state() != QtCore.QProcess.NotRunning: + self.process.terminate() + + def quit(self): + """Stop the subprocess (alias for terminate, for QThread compatibility).""" + self.terminate() + + def wait(self): + """Wait for the process to finish (for compatibility).""" + if self.process: + self.process.waitForFinished(-1) + + def isRunning(self): + """Check if the process is still running (for QThread compatibility).""" + if self.process: + return self.process.state() == QtCore.QProcess.Running + return False \ No newline at end of file diff --git a/suite2p/io/nwb.py b/suite2p/io/nwb.py index ab4857ab..514fd2e8 100644 --- a/suite2p/io/nwb.py +++ b/suite2p/io/nwb.py @@ -54,6 +54,8 @@ def _load_npy_cross_platform(path): NWB = True except ModuleNotFoundError: NWB = False + logger.warning("pynwb not installed, save_nwb, read_nwb, and nwb_to_binary " + "will not work. Install with: pip install pynwb") def nwb_to_binary(settings): diff --git a/suite2p/registration/nonrigid.py b/suite2p/registration/nonrigid.py index e1e935d7..90e6eb91 100644 --- a/suite2p/registration/nonrigid.py +++ b/suite2p/registration/nonrigid.py @@ -443,11 +443,19 @@ def transform_data(data, nblocks, xblock, yblock, ymax1, xmax1, yxup = yxup.permute(0, 2, 3, 1) if device.type == "mps": - # Manually pad the input tensor with the border values + # Manually pad the input tensor with the border values. data_padded = F.pad(data.float().unsqueeze(1), (1, 1, 1, 1), mode="replicate") - height, width = data.shape[-2:] # Get the height and width of the original data tensor - # Adjust the grid to account for the padding - adjusted_yxup = yxup + torch.tensor([[[[1 / width, 1 / height]]]]).to(yxup.device) # Adjust grid + # Get the height and width of the original data tensor + height, width = data.shape[-2:] + # Scale the grid to account for the padding. Padded data is now of shape (width + 2) x (height + 2). + # Scale_x and scale_y adjust so we exclude the padding. Align_corner is set to true so original image width is width -1. Same for the height. + scale_x = (width - 1) / (width + 1) + scale_y = (height - 1) / (height + 1) + # Scale the padded image to be within the right coordinates for sampling + adjusted_yxup = yxup * torch.tensor([[[[scale_x, scale_y]]]]).to(yxup.device) + # Clamp the grid before subsampling as all coordinate values must lie between [-1,1]. + # Sampling should always be along the image (not include padding coordinates, which will exceed [-1,1] range). + adjusted_yxup = torch.clamp(adjusted_yxup, -1, 1) # Perform grid sampling on the padded tensor fr_shift = F.grid_sample( data_padded, @@ -460,5 +468,5 @@ def transform_data(data, nblocks, xblock, yblock, ymax1, xmax1, fr_shift = F.grid_sample(data.float().unsqueeze(1), yxup[:,:,:,[1,0]], mode="bilinear", padding_mode="border", align_corners=True) - + return fr_shift.squeeze().short()#.cpu().numpy() diff --git a/suite2p/registration/register.py b/suite2p/registration/register.py index 2711ba8e..ce1ace57 100644 --- a/suite2p/registration/register.py +++ b/suite2p/registration/register.py @@ -275,6 +275,10 @@ def compute_filters_and_norm(refImg, norm_frames=True, spatial_smooth=1.15, spat maskMul, maskOffset, cfRefImg = rigid.compute_masks_ref_smooth_fft(refImg=rimg, maskSlope=spatial_taper, smooth_sigma=spatial_smooth) Ly, Lx = refImg.shape + # MPS backend does not support float64, convert to float32 + if device.type == "mps": + maskMul, maskOffset = maskMul.to(torch.float32), maskOffset.to(torch.float32) + cfRefImg = cfRefImg.to(torch.complex64) maskMul, maskOffset = maskMul.to(device), maskOffset.to(device) cfRefImg = cfRefImg.to(device) blocks = [] @@ -282,9 +286,13 @@ def compute_filters_and_norm(refImg, norm_frames=True, spatial_smooth=1.15, spat blocks = nonrigid.make_blocks(Ly=Ly, Lx=Lx, block_size=block_size, lpad=lpad, subpixel=subpixel) maskMulNR, maskOffsetNR, cfRefImgNR = nonrigid.compute_masks_ref_smooth_fft( - refImg0=rimg, maskSlope=spatial_taper, smooth_sigma=spatial_smooth, + refImg0=rimg, maskSlope=spatial_taper, smooth_sigma=spatial_smooth, yblock=blocks[0], xblock=blocks[1], ) + # MPS backend does not support float64, convert to float32 + if device.type == "mps": + maskMulNR, maskOffsetNR = maskMulNR.to(torch.float32), maskOffsetNR.to(torch.float32) + cfRefImgNR = cfRefImgNR.to(torch.complex64) maskMulNR, maskOffsetNR = maskMulNR.to(device), maskOffsetNR.to(device) cfRefImgNR = cfRefImgNR.to(device) @@ -440,6 +448,10 @@ def shift_frames(fr_torch, yoff, xoff, yoff1=None, xoff1=None, blocks=None, if fr_torch.device.type == "cuda": yoff1 = torch.from_numpy(yoff1).pin_memory().to(device) xoff1 = torch.from_numpy(xoff1).pin_memory().to(device) + elif device.type == "mps": + # MPS backend does not support float64 + yoff1 = torch.from_numpy(yoff1).to(torch.float32).to(device) + xoff1 = torch.from_numpy(xoff1).to(torch.float32).to(device) else: yoff1 = torch.from_numpy(yoff1).to(device) xoff1 = torch.from_numpy(xoff1).to(device) @@ -580,7 +592,9 @@ def register_frames(f_align_in, refImg, f_align_out=None, batch_size=100, if upsample_meanImg: if not isinstance(upsample_meanImg, (np.ndarray, list, tuple)): upsample_meanImg = [upsample_meanImg, upsample_meanImg] - mean_img_ups = torch.zeros((int(Ly*upsample_meanImg[0]), int(Lx*upsample_meanImg[1])), dtype=torch.double, device=device) + # MPS backend does not support float64 + ups_dtype = torch.float32 if device.type == "mps" else torch.double + mean_img_ups = torch.zeros((int(Ly*upsample_meanImg[0]), int(Lx*upsample_meanImg[1])), dtype=ups_dtype, device=device) counts_ups = torch.zeros((int(Ly*upsample_meanImg[0]), int(Lx*upsample_meanImg[1])), dtype=torch.int, device=device) else: mean_img_ups, counts_ups, meanImg_ups = None, None, None @@ -890,7 +904,10 @@ def registration_wrapper(f_reg, f_raw=None, f_reg_chan2=None, f_raw_chan2=None, nchannels = 2 if f_alt_in is not None else 1 logger.info(f"registering {nchannels} channels") - + if device.type == "mps": + logger.warning("MPS device does not support float64, using float32 for registration. " + "If you encounter registration issues, try using cuda or cpu instead.") + ### ----- compute reference image and bidiphase shift -------------- ### n_frames, Ly, Lx = f_align_in.shape badframes0 = np.zeros(n_frames, "bool") if badframes is None else badframes.copy() diff --git a/tests/test_registration.py b/tests/test_registration.py index 592eeb4c..0300e122 100644 --- a/tests/test_registration.py +++ b/tests/test_registration.py @@ -1,5 +1,8 @@ import numpy as np +import pytest +import torch from suite2p.registration import bidiphase +from suite2p.registration.nonrigid import transform_data def test_positive_bidiphase_shift_shifts_every_other_line(): @@ -41,4 +44,34 @@ def test_negative_bidiphase_shift_shifts_every_other_line(): shifted = orig.copy() bidiphase.shift(shifted, -2) - assert np.allclose(shifted, expected) \ No newline at end of file + assert np.allclose(shifted, expected) + + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_transform_data_mps_cpu_consistency(): + """Test that MPS and CPU code paths in transform_data produce similar results.""" + from suite2p.registration.nonrigid import make_blocks + + np.random.seed(42) + torch.manual_seed(42) + Ly, Lx, n_frames = 128, 128, 2 + yblock, xblock, nblocks, *_ = make_blocks(Ly, Lx, (32, 32)) + data_np = np.random.rand(n_frames, Ly, Lx).astype(np.float32) * 100 + ymax1 = torch.randn(nblocks[0] * nblocks[1], n_frames) * 2 + xmax1 = torch.randn(nblocks[0] * nblocks[1], n_frames) * 2 + + result_cpu = transform_data( + torch.from_numpy(data_np), nblocks, xblock, yblock, ymax1.clone(), xmax1.clone() + ) + result_mps = transform_data( + torch.from_numpy(data_np).to("mps"), nblocks, xblock, yblock, + ymax1.clone().to("mps"), xmax1.clone().to("mps") + ) + + cpu_np = result_cpu.numpy().astype(np.float32) + mps_np = result_mps.cpu().numpy().astype(np.float32) + correlation = np.corrcoef(cpu_np.flatten(), mps_np.flatten())[0, 1] + max_diff = np.abs(cpu_np - mps_np).max() + + assert correlation > 0.99, f"Correlation: {correlation}" + assert max_diff < 2, f"Max diff: {max_diff}" \ No newline at end of file