Skip to content
Open
11 changes: 1 addition & 10 deletions suite2p/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,7 @@ def main():
logging.exception(f'fatal error in {"run_plane" if args.single_plane else "run_s2p"}:')
raise

else:
# Check if the OS is macOS and the machine is Apple Silicon (ARM-based)
if platform.system() == "Darwin" and 'arm' in platform.processor().lower():
# Set the number of threads for OpenMP and OpenBLAS
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
print("Environment set to use 1 thread for OpenMP and OpenBLAS (Apple Silicon macOS).")
else:
print("Not macOS on Apple Silicon, proceeding without limiting threads.")

else:
from suite2p import gui
gui.run()#statfile="C:/DATA/exs2p/suite2p/plane0/stat.npy")

Expand Down
32 changes: 17 additions & 15 deletions suite2p/extraction/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,22 @@ def extract_traces(f_in, cell_masks, neuropil_masks, batch_size=500,
if device.type == 'mps':
device = torch.device('cpu')

npix_neuropil = torch.Tensor([len(nm) for nm in neuropil_masks]).to(device)
# create coo tensor of neuropil and cell masks
ccol_indices = [m for nm in neuropil_masks for m in nm]
row_indices = [k for k in range(len(neuropil_masks)) for m in neuropil_masks[k]]
inds = torch.Tensor([ccol_indices, row_indices]).to(device)
# convert to csc (tried creating csc directly but it was slow)
nmasks = torch.sparse_coo_tensor(inds, torch.ones(len(row_indices), device=device),
size=(Ly*Lx, ncells))
nmasks = nmasks.to_sparse_csc()
if neuropil_masks is not None:
npix_neuropil = torch.Tensor([len(nm) for nm in neuropil_masks]).to(device)
# create coo tensor of neuropil masks
ccol_indices = [m for nm in neuropil_masks for m in nm]
row_indices = [k for k in range(len(neuropil_masks)) for m in neuropil_masks[k]]
inds = torch.Tensor([ccol_indices, row_indices]).to(device)
# convert to csc (tried creating csc directly but it was slow)
nmasks = torch.sparse_coo_tensor(inds, torch.ones(len(row_indices), device=device),
size=(Ly*Lx, ncells))
nmasks = nmasks.to_sparse_csc()

ccol_indices = [m for cm in cell_masks for m in cm[0]]
row_indices = [k for k in range(len(cell_masks)) for m in cell_masks[k][0]]
cell_lam = torch.Tensor([l for cm in cell_masks for l in cm[1]]).to(device)
inds = torch.Tensor([ccol_indices, row_indices]).to(device)
cmasks = torch.sparse_coo_tensor(inds, cell_lam,
cmasks = torch.sparse_coo_tensor(inds, cell_lam,
size=(Ly*Lx, ncells))
cmasks = cmasks.to_sparse_csc()

Expand All @@ -87,11 +88,12 @@ def extract_traces(f_in, cell_masks, neuropil_masks, batch_size=500,
tstart, tend = n * batch_size, min((n+1) * batch_size, n_frames)
data = torch.from_numpy(f_in[tstart : tend]).to(device)
data = data.reshape(-1, Ly*Lx).float()

Fneu_batch = (data @ nmasks) / npix_neuropil
Fneu[:, tstart : tend] = Fneu_batch.T.cpu().numpy()

F_batch = data @ cmasks

if neuropil_masks is not None:
Fneu_batch = (data @ nmasks) / npix_neuropil
Fneu[:, tstart : tend] = Fneu_batch.T.cpu().numpy()

F_batch = data @ cmasks
F[:, tstart : tend] = F_batch.T.cpu().numpy()

return F, Fneu
Expand Down
10 changes: 8 additions & 2 deletions suite2p/gui/menus.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
from qtpy import QtGui
from qtpy.QtWidgets import QAction, QMenu
from pkg_resources import iter_entry_points
from importlib.metadata import entry_points

from . import reggui, drawroi, merge, io, rungui, visualize, classgui
from suite2p.io.nwb import save_nwb
Expand Down Expand Up @@ -166,7 +166,13 @@ def plugins(parent):
main_menu = parent.menuBar()
parent.plugins = {}
plugin_menu = main_menu.addMenu("&Plugins")
for entry_pt in iter_entry_points(group="suite2p.plugin", name=None):
try:
# Works for python 3.12+
suite2p_plugins = entry_points(group="suite2p.plugin")
except TypeError:
# works for Python 3.9-3.11
suite2p_plugins = entry_points().get("suite2p.plugin", [])
for entry_pt in suite2p_plugins:
plugin_obj = entry_pt.load() # load the advertised class from entry_points
parent.plugins[entry_pt.name] = plugin_obj(
parent
Expand Down
72 changes: 57 additions & 15 deletions suite2p/gui/rungui_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,25 +115,67 @@ def setup_logger(name):
return logger


class Suite2pWorker(QtCore.QThread):
class Suite2pWorker(QtCore.QObject):
    """Worker that runs suite2p in a separate process to avoid QThread stack limitations on macOS.

    Exposes a QThread-like interface (start/terminate/quit/wait/isRunning) so
    existing callers written against the old QThread worker keep working, but
    delegates execution to a QProcess running ``python -c <script>``.
    """
    # Emitted with "finished" on a zero exit code, "error" otherwise.
    finished = QtCore.Signal(str)

    def __init__(self, parent, db_file, settings_file):
        super(Suite2pWorker, self).__init__()
        # Paths to .npy files holding the pickled db/settings dicts that the
        # subprocess will load and pass to run_s2p.
        self.db_file = db_file
        self.settings_file = settings_file
        self.parent = parent
        # QProcess handle; None until start() is called.
        self.process = None

    def start(self):
        """Start suite2p in a separate process using QProcess."""
        self.process = QtCore.QProcess()
        # Merge stderr into stdout so _on_output sees everything.
        self.process.setProcessChannelMode(QtCore.QProcess.MergedChannels)
        self.process.readyReadStandardOutput.connect(self._on_output)
        self.process.finished.connect(self._on_finished)

        # Create a Python script to run suite2p. Paths are embedded with !r so
        # backslashes (Windows) and quotes in paths cannot break the script.
        script = f'''
import numpy as np
from suite2p.run_s2p import logger_setup, run_s2p, get_save_folder

db = np.load({self.db_file!r}, allow_pickle=True).item()
settings = np.load({self.settings_file!r}, allow_pickle=True).item()

logger_setup(get_save_folder(db))
run_s2p(db=db, settings=settings)
'''
        self.process.start(sys.executable, ["-c", script])

    def _on_output(self):
        """Forward subprocess output to this process's stdout."""
        if self.process:
            data = self.process.readAllStandardOutput()
            text = bytes(data).decode("utf-8", errors="replace")
            # Output already contains its own newlines; don't add more.
            print(text, end="")

    def _on_finished(self, exit_code, exit_status):
        """Handle process completion: emit "finished" on success, "error" otherwise."""
        if exit_code == 0:
            self.finished.emit("finished")
        else:
            self.finished.emit("error")

    def terminate(self):
        """Terminate the subprocess if running."""
        if self.process and self.process.state() != QtCore.QProcess.NotRunning:
            self.process.terminate()

    def quit(self):
        """Stop the subprocess (alias for terminate, for QThread compatibility)."""
        self.terminate()

    def wait(self):
        """Block until the process finishes (for QThread compatibility)."""
        if self.process:
            self.process.waitForFinished(-1)

    def isRunning(self):
        """Check if the process is still running (for QThread compatibility)."""
        if self.process:
            return self.process.state() == QtCore.QProcess.Running
        return False
2 changes: 2 additions & 0 deletions suite2p/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def _load_npy_cross_platform(path):
NWB = True
except ModuleNotFoundError:
NWB = False
logger.warning("pynwb not installed, save_nwb, read_nwb, and nwb_to_binary "
"will not work. Install with: pip install pynwb")


def nwb_to_binary(settings):
Expand Down
18 changes: 13 additions & 5 deletions suite2p/registration/nonrigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,11 +443,19 @@ def transform_data(data, nblocks, xblock, yblock, ymax1, xmax1,
yxup = yxup.permute(0, 2, 3, 1)

if device.type == "mps":
# Manually pad the input tensor with the border values
# Manually pad the input tensor with the border values.
data_padded = F.pad(data.float().unsqueeze(1), (1, 1, 1, 1), mode="replicate")
height, width = data.shape[-2:] # Get the height and width of the original data tensor
# Adjust the grid to account for the padding
adjusted_yxup = yxup + torch.tensor([[[[1 / width, 1 / height]]]]).to(yxup.device) # Adjust grid
# Get the height and width of the original data tensor
height, width = data.shape[-2:]
# Scale the grid to account for the padding. Padded data is now of shape (width + 2) x (height + 2).
# Scale_x and scale_y adjust so we exclude the padding. Align_corner is set to true so original image width is width -1. Same for the height.
scale_x = (width - 1) / (width + 1)
scale_y = (height - 1) / (height + 1)
# Scale the padded image to be within the right coordinates for sampling
adjusted_yxup = yxup * torch.tensor([[[[scale_x, scale_y]]]]).to(yxup.device)
# Clamp the grid before subsampling as all coordinate values must lie between [-1,1].
# Sampling should always be along the image (not include padding coordinates, which will exceed [-1,1] range).
adjusted_yxup = torch.clamp(adjusted_yxup, -1, 1)
# Perform grid sampling on the padded tensor
fr_shift = F.grid_sample(
data_padded,
Expand All @@ -460,5 +468,5 @@ def transform_data(data, nblocks, xblock, yblock, ymax1, xmax1,
fr_shift = F.grid_sample(data.float().unsqueeze(1), yxup[:,:,:,[1,0]],
mode="bilinear", padding_mode="border", align_corners=True)


return fr_shift.squeeze().short()#.cpu().numpy()
23 changes: 20 additions & 3 deletions suite2p/registration/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,16 +275,24 @@ def compute_filters_and_norm(refImg, norm_frames=True, spatial_smooth=1.15, spat
maskMul, maskOffset, cfRefImg = rigid.compute_masks_ref_smooth_fft(refImg=rimg, maskSlope=spatial_taper,
smooth_sigma=spatial_smooth)
Ly, Lx = refImg.shape
# MPS backend does not support float64, convert to float32
if device.type == "mps":
maskMul, maskOffset = maskMul.to(torch.float32), maskOffset.to(torch.float32)
cfRefImg = cfRefImg.to(torch.complex64)
maskMul, maskOffset = maskMul.to(device), maskOffset.to(device)
cfRefImg = cfRefImg.to(device)
blocks = []
if block_size is not None:
blocks = nonrigid.make_blocks(Ly=Ly, Lx=Lx, block_size=block_size,
lpad=lpad, subpixel=subpixel)
maskMulNR, maskOffsetNR, cfRefImgNR = nonrigid.compute_masks_ref_smooth_fft(
refImg0=rimg, maskSlope=spatial_taper, smooth_sigma=spatial_smooth,
refImg0=rimg, maskSlope=spatial_taper, smooth_sigma=spatial_smooth,
yblock=blocks[0], xblock=blocks[1],
)
# MPS backend does not support float64, convert to float32
if device.type == "mps":
maskMulNR, maskOffsetNR = maskMulNR.to(torch.float32), maskOffsetNR.to(torch.float32)
cfRefImgNR = cfRefImgNR.to(torch.complex64)
maskMulNR, maskOffsetNR = maskMulNR.to(device), maskOffsetNR.to(device)
cfRefImgNR = cfRefImgNR.to(device)

Expand Down Expand Up @@ -440,6 +448,10 @@ def shift_frames(fr_torch, yoff, xoff, yoff1=None, xoff1=None, blocks=None,
if fr_torch.device.type == "cuda":
yoff1 = torch.from_numpy(yoff1).pin_memory().to(device)
xoff1 = torch.from_numpy(xoff1).pin_memory().to(device)
elif device.type == "mps":
# MPS backend does not support float64
yoff1 = torch.from_numpy(yoff1).to(torch.float32).to(device)
xoff1 = torch.from_numpy(xoff1).to(torch.float32).to(device)
else:
yoff1 = torch.from_numpy(yoff1).to(device)
xoff1 = torch.from_numpy(xoff1).to(device)
Expand Down Expand Up @@ -580,7 +592,9 @@ def register_frames(f_align_in, refImg, f_align_out=None, batch_size=100,
if upsample_meanImg:
if not isinstance(upsample_meanImg, (np.ndarray, list, tuple)):
upsample_meanImg = [upsample_meanImg, upsample_meanImg]
mean_img_ups = torch.zeros((int(Ly*upsample_meanImg[0]), int(Lx*upsample_meanImg[1])), dtype=torch.double, device=device)
# MPS backend does not support float64
ups_dtype = torch.float32 if device.type == "mps" else torch.double
mean_img_ups = torch.zeros((int(Ly*upsample_meanImg[0]), int(Lx*upsample_meanImg[1])), dtype=ups_dtype, device=device)
counts_ups = torch.zeros((int(Ly*upsample_meanImg[0]), int(Lx*upsample_meanImg[1])), dtype=torch.int, device=device)
else:
mean_img_ups, counts_ups, meanImg_ups = None, None, None
Expand Down Expand Up @@ -890,7 +904,10 @@ def registration_wrapper(f_reg, f_raw=None, f_reg_chan2=None, f_raw_chan2=None,

nchannels = 2 if f_alt_in is not None else 1
logger.info(f"registering {nchannels} channels")

if device.type == "mps":
logger.warning("MPS device does not support float64, using float32 for registration. "
"If you encounter registration issues, try using cuda or cpu instead.")

### ----- compute reference image and bidiphase shift -------------- ###
n_frames, Ly, Lx = f_align_in.shape
badframes0 = np.zeros(n_frames, "bool") if badframes is None else badframes.copy()
Expand Down
35 changes: 34 additions & 1 deletion tests/test_registration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
import pytest
import torch
from suite2p.registration import bidiphase
from suite2p.registration.nonrigid import transform_data


def test_positive_bidiphase_shift_shifts_every_other_line():
Expand Down Expand Up @@ -41,4 +44,34 @@ def test_negative_bidiphase_shift_shifts_every_other_line():

shifted = orig.copy()
bidiphase.shift(shifted, -2)
assert np.allclose(shifted, expected)
assert np.allclose(shifted, expected)


@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
def test_transform_data_mps_cpu_consistency():
    """Check that the MPS code path of transform_data closely matches the CPU path."""
    from suite2p.registration.nonrigid import make_blocks

    # Fixed seeds so the random frames and block shifts are reproducible.
    np.random.seed(42)
    torch.manual_seed(42)
    n_frames, height, width = 2, 128, 128
    yblock, xblock, nblocks, *_ = make_blocks(height, width, (32, 32))
    total_blocks = nblocks[0] * nblocks[1]
    frames = np.random.rand(n_frames, height, width).astype(np.float32) * 100
    # Small subpixel shifts per block, one column per frame.
    y_shifts = torch.randn(total_blocks, n_frames) * 2
    x_shifts = torch.randn(total_blocks, n_frames) * 2

    out_cpu = transform_data(
        torch.from_numpy(frames), nblocks, xblock, yblock,
        y_shifts.clone(), x_shifts.clone(),
    )
    out_mps = transform_data(
        torch.from_numpy(frames).to("mps"), nblocks, xblock, yblock,
        y_shifts.clone().to("mps"), x_shifts.clone().to("mps"),
    )

    ref = out_cpu.numpy().astype(np.float32)
    got = out_mps.cpu().numpy().astype(np.float32)
    correlation = np.corrcoef(ref.flatten(), got.flatten())[0, 1]
    assert correlation > 0.99, f"Correlation: {correlation}"
    max_diff = np.abs(ref - got).max()
    assert max_diff < 2, f"Max diff: {max_diff}"
Loading