diff --git a/suite2p/__main__.py b/suite2p/__main__.py index 09b4ce1b..006f6d05 100644 --- a/suite2p/__main__.py +++ b/suite2p/__main__.py @@ -51,16 +51,7 @@ def main(): logging.exception(f'fatal error in {"run_plane" if args.single_plane else "run_s2p"}:') raise - else: - # Check if the OS is macOS and the machine is Apple Silicon (ARM-based) - if platform.system() == "Darwin" and 'arm' in platform.processor().lower(): - # Set the number of threads for OpenMP and OpenBLAS - os.environ["OMP_NUM_THREADS"] = "1" - os.environ["OPENBLAS_NUM_THREADS"] = "1" - print("Environment set to use 1 thread for OpenMP and OpenBLAS (Apple Silicon macOS).") - else: - print("Not macOS on Apple Silicon, proceeding without limiting threads.") - + else: from suite2p import gui gui.run()#statfile="C:/DATA/exs2p/suite2p/plane0/stat.npy") diff --git a/suite2p/extraction/extract.py b/suite2p/extraction/extract.py index 71041fea..4b3832d2 100644 --- a/suite2p/extraction/extract.py +++ b/suite2p/extraction/extract.py @@ -57,21 +57,22 @@ def extract_traces(f_in, cell_masks, neuropil_masks, batch_size=500, if device.type == 'mps': device = torch.device('cpu') - npix_neuropil = torch.Tensor([len(nm) for nm in neuropil_masks]).to(device) - # create coo tensor of neuropil and cell masks - ccol_indices = [m for nm in neuropil_masks for m in nm] - row_indices = [k for k in range(len(neuropil_masks)) for m in neuropil_masks[k]] - inds = torch.Tensor([ccol_indices, row_indices]).to(device) - # convert to csc (tried creating csc directly but it was slow) - nmasks = torch.sparse_coo_tensor(inds, torch.ones(len(row_indices), device=device), - size=(Ly*Lx, ncells)) - nmasks = nmasks.to_sparse_csc() + if neuropil_masks is not None: + npix_neuropil = torch.Tensor([len(nm) for nm in neuropil_masks]).to(device) + # create coo tensor of neuropil masks + ccol_indices = [m for nm in neuropil_masks for m in nm] + row_indices = [k for k in range(len(neuropil_masks)) for m in neuropil_masks[k]] + inds = torch.Tensor([ccol_indices, row_indices]).to(device) + # convert to csc (tried creating csc directly but it was slow) + nmasks = torch.sparse_coo_tensor(inds, torch.ones(len(row_indices), device=device), + size=(Ly*Lx, ncells)) + nmasks = nmasks.to_sparse_csc() ccol_indices = [m for cm in cell_masks for m in cm[0]] row_indices = [k for k in range(len(cell_masks)) for m in cell_masks[k][0]] cell_lam = torch.Tensor([l for cm in cell_masks for l in cm[1]]).to(device) inds = torch.Tensor([ccol_indices, row_indices]).to(device) - cmasks = torch.sparse_coo_tensor(inds, cell_lam, + cmasks = torch.sparse_coo_tensor(inds, cell_lam, size=(Ly*Lx, ncells)) cmasks = cmasks.to_sparse_csc() @@ -87,11 +88,12 @@ def extract_traces(f_in, cell_masks, neuropil_masks, batch_size=500, tstart, tend = n * batch_size, min((n+1) * batch_size, n_frames) data = torch.from_numpy(f_in[tstart : tend]).to(device) data = data.reshape(-1, Ly*Lx).float() - - Fneu_batch = (data @ nmasks) / npix_neuropil - Fneu[:, tstart : tend] = Fneu_batch.T.cpu().numpy() - - F_batch = data @ cmasks + + if neuropil_masks is not None: + Fneu_batch = (data @ nmasks) / npix_neuropil + Fneu[:, tstart : tend] = Fneu_batch.T.cpu().numpy() + + F_batch = data @ cmasks F[:, tstart : tend] = F_batch.T.cpu().numpy() return F, Fneu diff --git a/suite2p/gui/drawroi.py b/suite2p/gui/drawroi.py index 99fa73d4..b32e94d8 100644 --- a/suite2p/gui/drawroi.py +++ b/suite2p/gui/drawroi.py @@ -18,6 +18,10 @@ from ..detection.stats import roi_stats from ..extraction import preprocess from ..extraction.dcnv import oasis +from ..extraction.extract import extract_traces +from ..io.binary import BinaryFile +from ..parameters import default_settings +from ..run_s2p import _assign_torch_device def masks_and_traces(settings, stat_manual, stat_orig): @@ -27,6 +31,8 @@ def masks_and_traces(settings, stat_manual, stat_orig): returns: F (ROIs x time), Fneu (ROIs x time), F_chan2, Fneu_chan2, settings, stat F_chan2 and Fneu_chan2 will be empty if no second channel """ + # Merge with defaults to ensure all required keys are present + settings = {**default_settings(), **settings} t0 = time.time() @@ -35,11 +41,12 @@ def masks_and_traces(settings, stat_manual, stat_orig): for n in range(len(stat_orig)): stat_all.append(stat_orig[n]) - stat_all = roi_stats(stat_all, settings["Ly"], settings["Lx"], aspect=settings.get("aspect", None), + stat_all = np.array(stat_all) + stat_all = roi_stats(stat_all, settings["Ly"], settings["Lx"], diameter=settings["diameter"]) cell_masks = [ masks.create_cell_mask(stat, Ly=settings["Ly"], Lx=settings["Lx"], - allow_overlap=settings["allow_overlap"]) for stat in stat_all + allow_overlap=settings["extraction"]["allow_overlap"]) for stat in stat_all ] cell_pix = masks.create_cell_pix(stat_all, Ly=settings["Ly"], Lx=settings["Lx"]) manual_roi_stats = stat_all[:len(stat_manual)] @@ -48,13 +55,26 @@ def masks_and_traces(settings, stat_manual, stat_orig): ypixs=[stat["ypix"] for stat in manual_roi_stats], xpixs=[stat["xpix"] for stat in manual_roi_stats], cell_pix=cell_pix, - inner_neuropil_radius=settings["inner_neuropil_radius"], - min_neuropil_pixels=settings["min_neuropil_pixels"], + inner_neuropil_radius=settings["extraction"]["inner_neuropil_radius"], + min_neuropil_pixels=settings["extraction"]["min_neuropil_pixels"], ) print("Masks made in %0.2f sec." % (time.time() - t0)) - F, Fneu, F_chan2, Fneu_chan2 = extract_traces_from_masks(settings, manual_cell_masks, - manual_neuropil_masks) + # Extract traces from binary file + Ly, Lx = settings["Ly"], settings["Lx"] + batch_size = settings["extraction"]["batch_size"] + device = _assign_torch_device(settings["torch_device"]) + f_reg = BinaryFile(Ly, Lx, settings["reg_file"]) + F, Fneu = extract_traces(f_reg, manual_cell_masks, manual_neuropil_masks, batch_size=batch_size, device=device) + f_reg.close() + + # Handle chan2 if present + if "reg_file_chan2" in settings and settings["reg_file_chan2"]: + f_reg_chan2 = BinaryFile(Ly, Lx, settings["reg_file_chan2"]) + F_chan2, Fneu_chan2 = extract_traces(f_reg_chan2, manual_cell_masks, manual_neuropil_masks, batch_size=batch_size, device=device) + f_reg_chan2.close() + else: + F_chan2, Fneu_chan2 = None, None # compute activity statistics for classifier npix = np.array([stat_orig[n]["npix"] for n in range(len(stat_orig)) @@ -69,7 +89,7 @@ def masks_and_traces(settings, stat_manual, stat_orig): manual_roi_stats[n]["iplane"] = stat_orig[0]["iplane"] # subtract neuropil and compute skew, std from F - dF = F - settings["neucoeff"] * Fneu + dF = F - settings["extraction"]["neuropil_coefficient"] * Fneu sk = stats.skew(dF, axis=1) sd = np.std(dF, axis=1) @@ -81,10 +101,10 @@ def masks_and_traces(settings, stat_manual, stat_orig): np.mean(manual_roi_stats[n]["xpix"]) ] - dF = preprocess(F=dF, baseline=settings["baseline"], win_baseline=settings["win_baseline"], - sig_baseline=settings["sig_baseline"], fs=settings["fs"], - prctile_baseline=settings["prctile_baseline"]) - spks = oasis(F=dF, batch_size=settings["batch_size"], tau=settings["tau"], fs=settings["fs"]) + dF = preprocess(F=dF, baseline=settings["dcnv_preprocess"]["baseline"], win_baseline=settings["dcnv_preprocess"]["win_baseline"], + sig_baseline=settings["dcnv_preprocess"]["sig_baseline"], fs=settings["fs"], + prctile_baseline=settings["dcnv_preprocess"]["prctile_baseline"], device=device) + spks = oasis(F=dF, batch_size=settings["extraction"]["batch_size"], tau=settings["tau"], fs=settings["fs"]) return F, Fneu, F_chan2, Fneu_chan2, spks, settings, manual_roi_stats @@ -187,7 +207,7 @@ def __init__(self, parent): self.saveGUI = False self.closeGUI = QPushButton("Save and Quit") self.closeGUI.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold)) - self.closeGUI.clicked.connect(self.close_GUI) + self.closeGUI.clicked.connect(lambda: self.close_GUI()) self.closeGUI.setEnabled(False) self.closeGUI.setFixedWidth(100) self.closeGUI.setStyleSheet(self.styleUnpressed) @@ -247,9 +267,7 @@ def close_GUI(self): # Append new stat file with old and save print("Saving new stat") - stat_all = self.new_stat.copy() - for n in range(len(self.parent.stat)): - stat_all.append(self.parent.stat[n]) + stat_all = np.concatenate((self.new_stat, self.parent.stat)) np.save(os.path.join(self.parent.basename, "stat.npy"), stat_all) iscell_prob = np.concatenate( (self.parent.iscell[:, np.newaxis], self.parent.probcell[:, np.newaxis]), @@ -297,30 +315,26 @@ def normalize_img_add_masks(self): if i == 0: mimg = np.zeros((self.Ly, self.Lx), np.float32) mimg[self.parent.ops["yrange"][0]:self.parent.ops["yrange"][1], - self.parent.ops["xrange"][0]:self.parent. - settings["xrange"][1]] = self.parent.ops["meanImg"][ + self.parent.ops["xrange"][0]:self.parent.ops["xrange"][1]] = self.parent.ops["meanImg"][ self.parent.ops["yrange"][0]:self.parent.ops["yrange"][1], self.parent.ops["xrange"][0]:self.parent.ops["xrange"][1]] elif i == 1: mimg = np.zeros((self.Ly, self.Lx), np.float32) mimg[self.parent.ops["yrange"][0]:self.parent.ops["yrange"][1], - self.parent.ops["xrange"][0]:self.parent. - settings["xrange"][1]] = self.parent.ops["meanImgE"][ + self.parent.ops["xrange"][0]:self.parent.ops["xrange"][1]] = self.parent.ops["meanImgE"][ self.parent.ops["yrange"][0]:self.parent.ops["yrange"][1], self.parent.ops["xrange"][0]:self.parent.ops["xrange"][1]] elif i == 2: mimg = np.zeros((self.Ly, self.Lx), np.float32) mimg[self.parent.ops["yrange"][0]:self.parent.ops["yrange"][1], - self.parent.ops["xrange"][0]:self.parent. - settings["xrange"][1]] = self.parent.ops["Vcorr"] + self.parent.ops["xrange"][0]:self.parent.ops["xrange"][1]] = self.parent.ops["Vcorr"] else: mimg = np.zeros((self.Ly, self.Lx), np.float32) if "max_proj" in self.parent.ops: mimg[self.parent.ops["yrange"][0]:self.parent.ops["yrange"][1], - self.parent.ops["xrange"][0]:self.parent. - settings["xrange"][1]] = self.parent.ops["max_proj"] + self.parent.ops["xrange"][0]:self.parent.ops["xrange"][1]] = self.parent.ops["max_proj"] mimg1 = np.percentile(mimg, 1) mimg99 = np.percentile(mimg, 99) diff --git a/suite2p/gui/io.py b/suite2p/gui/io.py index 85c98569..0b622d8a 100644 --- a/suite2p/gui/io.py +++ b/suite2p/gui/io.py @@ -501,6 +501,8 @@ def save_merge(parent): (parent.iscell[:, np.newaxis], parent.probcell[:, np.newaxis]), axis=1) np.save(os.path.join(parent.basename, "iscell.npy"), iscell) + parent.lcell0.setText("%d" % (parent.iscell.sum())) + parent.lcell1.setText("%d" % (parent.iscell.size - parent.iscell.sum())) parent.notmerged = np.ones(parent.iscell.size, "bool") diff --git a/suite2p/gui/menus.py b/suite2p/gui/menus.py index 29c8fe60..53f9aa62 100644 --- a/suite2p/gui/menus.py +++ b/suite2p/gui/menus.py @@ -3,7 +3,7 @@ """ from qtpy import QtGui from qtpy.QtWidgets import QAction, QMenu -from pkg_resources import iter_entry_points +from importlib.metadata import entry_points from . import reggui, drawroi, merge, io, rungui, visualize, classgui from suite2p.io.nwb import save_nwb @@ -166,7 +166,13 @@ def plugins(parent): main_menu = parent.menuBar() parent.plugins = {} plugin_menu = main_menu.addMenu("&Plugins") - for entry_pt in iter_entry_points(group="suite2p.plugin", name=None): + try: + # Works for python 3.12+ + suite2p_plugins = entry_points(group="suite2p.plugin") + except TypeError: + # works for Python 3.9-3.11 + suite2p_plugins = entry_points().get("suite2p.plugin", []) + for entry_pt in suite2p_plugins: plugin_obj = entry_pt.load() # load the advertised class from entry_points parent.plugins[entry_pt.name] = plugin_obj( parent diff --git a/suite2p/gui/merge.py b/suite2p/gui/merge.py index 7b0e8882..fd2a543f 100644 --- a/suite2p/gui/merge.py +++ b/suite2p/gui/merge.py @@ -38,6 +38,7 @@ def do_merge(parent): merge_activity_masks(parent) parent.merged.append(parent.imerge) parent.update_plot() + io.save_merge(parent) print(parent.merged) print("merged ROIs") @@ -125,10 +126,11 @@ def merge_activity_masks(parent): if parent.hasred: F_chan2 = F_chan2.mean(axis=0) Fneu_chan2 = Fneu_chan2.mean(axis=0) - dF = F - parent.ops["neucoeff"] * Fneu + dF = F - parent.ops["extraction"]["neuropil_coefficient"] * Fneu # activity stats stat0["skew"] = stats.skew(dF) stat0["std"] = dF.std() + stat0["snr"] = 1 - 0.5 * np.diff(dF).var() / dF.var() spks = oasis(F=dF[np.newaxis, :], batch_size=parent.ops["batch_size"], tau=parent.ops["tau"], fs=parent.ops["fs"]) @@ -151,9 +153,9 @@ def merge_activity_masks(parent): # add cell to structs parent.stat = np.concatenate((parent.stat, np.array([stat0])), axis=0) parent.stat = roi_stats(parent.stat, parent.Ly, parent.Lx, - aspect=parent.ops.get("aspect", None), - diameter=parent.ops.get("diameter", None), - do_crop=parent.ops.get("soma_crop", 1)) + diameter=parent.ops["diameter"], + do_soma_crop=parent.ops["detection"]["soma_crop"], + max_overlap=None) parent.stat[-1]["lam"] = parent.stat[-1]["lam"] * merged_cells.size parent.Fcell = np.concatenate((parent.Fcell, F[np.newaxis, :]), axis=0) parent.Fneu = np.concatenate((parent.Fneu, Fneu[np.newaxis, :]), axis=0) @@ -182,8 +184,10 @@ def merge_activity_masks(parent): # recompute binned F parent.mode_change(parent.activityMode) + # Remove the maskes for the previous cells that were merged. for n in merged_cells: parent.stat[n]["inmerge"] = len(parent.stat) - 1 + parent.iscell[n] = False masks.remove_roi(parent, n, i0) masks.add_roi(parent, len(parent.stat) - 1, i0) masks.redraw_masks(parent, ypix, xpix) @@ -262,6 +266,7 @@ def do_merge(self, parent): merge_activity_masks(parent) parent.merged.append(parent.imerge) parent.update_plot() + io.save_merge(parent) self.cc_row = np.matmul(parent.Fbin[parent.iscell], parent.Fbin[-1].T) / parent.Fbin.shape[-1] diff --git a/suite2p/gui/rungui_utils.py b/suite2p/gui/rungui_utils.py index f67124dd..5ec1f448 100644 --- a/suite2p/gui/rungui_utils.py +++ b/suite2p/gui/rungui_utils.py @@ -115,25 +115,67 @@ def setup_logger(name): return logger -class Suite2pWorker(QtCore.QThread): +class Suite2pWorker(QtCore.QObject): + """Worker that runs suite2p in a separate process to avoid QThread stack limitations on macOS.""" finished = QtCore.Signal(str) - + def __init__(self, parent, db_file, settings_file): super(Suite2pWorker, self).__init__() self.db_file = db_file self.settings_file = settings_file self.parent = parent - # self.logHandler = ThreadLogger() - - def run(self): - db = np.load(self.db_file, allow_pickle=True).item() - settings = np.load(self.settings_file, allow_pickle=True).item() - - try: - logger_setup(get_save_folder(db)) - run_s2p(db=db, settings=settings) + self.process = None + + def start(self): + """Start suite2p in a separate process using QProcess.""" + self.process = QtCore.QProcess() + self.process.setProcessChannelMode(QtCore.QProcess.MergedChannels) + self.process.readyReadStandardOutput.connect(self._on_output) + self.process.finished.connect(self._on_finished) + + # Create a Python script to run suite2p + script = f''' +import numpy as np +from suite2p.run_s2p import logger_setup, run_s2p, get_save_folder + +db = np.load("{self.db_file}", allow_pickle=True).item() +settings = np.load("{self.settings_file}", allow_pickle=True).item() + +logger_setup(get_save_folder(db)) +run_s2p(db=db, settings=settings) +''' + self.process.start(sys.executable, ["-c", script]) + + def _on_output(self): + """Handle output from the subprocess.""" + if self.process: + data = self.process.readAllStandardOutput() + text = bytes(data).decode("utf-8", errors="replace") + print(text, end="") + + def _on_finished(self, exit_code, exit_status): + """Handle process completion.""" + if exit_code == 0: self.finished.emit("finished") - except Exception as e: - print("ERROR:", e) - traceback.print_exc() - self.finished.emit("error") \ No newline at end of file + else: + self.finished.emit("error") + + def terminate(self): + """Terminate the subprocess if running.""" + if self.process and self.process.state() != QtCore.QProcess.NotRunning: + self.process.terminate() + + def quit(self): + """Stop the subprocess (alias for terminate, for QThread compatibility).""" + self.terminate() + + def wait(self): + """Wait for the process to finish (for compatibility).""" + if self.process: + self.process.waitForFinished(-1) + + def isRunning(self): + """Check if the process is still running (for QThread compatibility).""" + if self.process: + return self.process.state() == QtCore.QProcess.Running + return False \ No newline at end of file diff --git a/suite2p/io/nwb.py b/suite2p/io/nwb.py index ab4857ab..514fd2e8 100644 --- a/suite2p/io/nwb.py +++ b/suite2p/io/nwb.py @@ -54,6 +54,8 @@ def _load_npy_cross_platform(path): NWB = True except ModuleNotFoundError: NWB = False + logger.warning("pynwb not installed, save_nwb, read_nwb, and nwb_to_binary " + "will not work. Install with: pip install pynwb") def nwb_to_binary(settings): diff --git a/suite2p/registration/nonrigid.py b/suite2p/registration/nonrigid.py index e1e935d7..90e6eb91 100644 --- a/suite2p/registration/nonrigid.py +++ b/suite2p/registration/nonrigid.py @@ -443,11 +443,19 @@ def transform_data(data, nblocks, xblock, yblock, ymax1, xmax1, yxup = yxup.permute(0, 2, 3, 1) if device.type == "mps": - # Manually pad the input tensor with the border values + # Manually pad the input tensor with the border values. data_padded = F.pad(data.float().unsqueeze(1), (1, 1, 1, 1), mode="replicate") - height, width = data.shape[-2:] # Get the height and width of the original data tensor - # Adjust the grid to account for the padding - adjusted_yxup = yxup + torch.tensor([[[[1 / width, 1 / height]]]]).to(yxup.device) # Adjust grid + # Get the height and width of the original data tensor + height, width = data.shape[-2:] + # Scale the grid to account for the padding. Padded data is now of shape (width + 2) x (height + 2). + # Scale_x and scale_y adjust so we exclude the padding. Align_corner is set to true so original image width is width -1. Same for the height. + scale_x = (width - 1) / (width + 1) + scale_y = (height - 1) / (height + 1) + # Scale the padded image to be within the right coordinates for sampling + adjusted_yxup = yxup * torch.tensor([[[[scale_x, scale_y]]]]).to(yxup.device) + # Clamp the grid before subsampling as all coordinate values must lie between [-1,1]. + # Sampling should always be along the image (not include padding coordinates, which will exceed [-1,1] range). + adjusted_yxup = torch.clamp(adjusted_yxup, -1, 1) # Perform grid sampling on the padded tensor fr_shift = F.grid_sample( data_padded, @@ -460,5 +468,5 @@ def transform_data(data, nblocks, xblock, yblock, ymax1, xmax1, fr_shift = F.grid_sample(data.float().unsqueeze(1), yxup[:,:,:,[1,0]], mode="bilinear", padding_mode="border", align_corners=True) - + return fr_shift.squeeze().short()#.cpu().numpy() diff --git a/suite2p/registration/register.py b/suite2p/registration/register.py index 2711ba8e..ce1ace57 100644 --- a/suite2p/registration/register.py +++ b/suite2p/registration/register.py @@ -275,6 +275,10 @@ def compute_filters_and_norm(refImg, norm_frames=True, spatial_smooth=1.15, spat maskMul, maskOffset, cfRefImg = rigid.compute_masks_ref_smooth_fft(refImg=rimg, maskSlope=spatial_taper, smooth_sigma=spatial_smooth) Ly, Lx = refImg.shape + # MPS backend does not support float64, convert to float32 + if device.type == "mps": + maskMul, maskOffset = maskMul.to(torch.float32), maskOffset.to(torch.float32) + cfRefImg = cfRefImg.to(torch.complex64) maskMul, maskOffset = maskMul.to(device), maskOffset.to(device) cfRefImg = cfRefImg.to(device) blocks = [] @@ -282,9 +286,13 @@ def compute_filters_and_norm(refImg, norm_frames=True, spatial_smooth=1.15, spat blocks = nonrigid.make_blocks(Ly=Ly, Lx=Lx, block_size=block_size, lpad=lpad, subpixel=subpixel) maskMulNR, maskOffsetNR, cfRefImgNR = nonrigid.compute_masks_ref_smooth_fft( - refImg0=rimg, maskSlope=spatial_taper, smooth_sigma=spatial_smooth, + refImg0=rimg, maskSlope=spatial_taper, smooth_sigma=spatial_smooth, yblock=blocks[0], xblock=blocks[1], ) + # MPS backend does not support float64, convert to float32 + if device.type == "mps": + maskMulNR, maskOffsetNR = maskMulNR.to(torch.float32), maskOffsetNR.to(torch.float32) + cfRefImgNR = cfRefImgNR.to(torch.complex64) maskMulNR, maskOffsetNR = maskMulNR.to(device), maskOffsetNR.to(device) cfRefImgNR = cfRefImgNR.to(device) @@ -440,6 +448,10 @@ def shift_frames(fr_torch, yoff, xoff, yoff1=None, xoff1=None, blocks=None, if fr_torch.device.type == "cuda": yoff1 = torch.from_numpy(yoff1).pin_memory().to(device) xoff1 = torch.from_numpy(xoff1).pin_memory().to(device) + elif device.type == "mps": + # MPS backend does not support float64 + yoff1 = torch.from_numpy(yoff1).to(torch.float32).to(device) + xoff1 = torch.from_numpy(xoff1).to(torch.float32).to(device) else: yoff1 = torch.from_numpy(yoff1).to(device) xoff1 = torch.from_numpy(xoff1).to(device) @@ -580,7 +592,9 @@ def register_frames(f_align_in, refImg, f_align_out=None, batch_size=100, if upsample_meanImg: if not isinstance(upsample_meanImg, (np.ndarray, list, tuple)): upsample_meanImg = [upsample_meanImg, upsample_meanImg] - mean_img_ups = torch.zeros((int(Ly*upsample_meanImg[0]), int(Lx*upsample_meanImg[1])), dtype=torch.double, device=device) + # MPS backend does not support float64 + ups_dtype = torch.float32 if device.type == "mps" else torch.double + mean_img_ups = torch.zeros((int(Ly*upsample_meanImg[0]), int(Lx*upsample_meanImg[1])), dtype=ups_dtype, device=device) counts_ups = torch.zeros((int(Ly*upsample_meanImg[0]), int(Lx*upsample_meanImg[1])), dtype=torch.int, device=device) else: mean_img_ups, counts_ups, meanImg_ups = None, None, None @@ -890,7 +904,10 @@ def registration_wrapper(f_reg, f_raw=None, f_reg_chan2=None, f_raw_chan2=None, nchannels = 2 if f_alt_in is not None else 1 logger.info(f"registering {nchannels} channels") - + if device.type == "mps": + logger.warning("MPS device does not support float64, using float32 for registration. " + "If you encounter registration issues, try using cuda or cpu instead.") + ### ----- compute reference image and bidiphase shift -------------- ### n_frames, Ly, Lx = f_align_in.shape badframes0 = np.zeros(n_frames, "bool") if badframes is None else badframes.copy() diff --git a/tests/test_registration.py b/tests/test_registration.py index 592eeb4c..0300e122 100644 --- a/tests/test_registration.py +++ b/tests/test_registration.py @@ -1,5 +1,8 @@ import numpy as np +import pytest +import torch from suite2p.registration import bidiphase +from suite2p.registration.nonrigid import transform_data def test_positive_bidiphase_shift_shifts_every_other_line(): @@ -41,4 +44,34 @@ def test_negative_bidiphase_shift_shifts_every_other_line(): shifted = orig.copy() bidiphase.shift(shifted, -2) - assert np.allclose(shifted, expected) \ No newline at end of file + assert np.allclose(shifted, expected) + + +@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available") +def test_transform_data_mps_cpu_consistency(): + """Test that MPS and CPU code paths in transform_data produce similar results.""" + from suite2p.registration.nonrigid import make_blocks + + np.random.seed(42) + torch.manual_seed(42) + Ly, Lx, n_frames = 128, 128, 2 + yblock, xblock, nblocks, *_ = make_blocks(Ly, Lx, (32, 32)) + data_np = np.random.rand(n_frames, Ly, Lx).astype(np.float32) * 100 + ymax1 = torch.randn(nblocks[0] * nblocks[1], n_frames) * 2 + xmax1 = torch.randn(nblocks[0] * nblocks[1], n_frames) * 2 + + result_cpu = transform_data( + torch.from_numpy(data_np), nblocks, xblock, yblock, ymax1.clone(), xmax1.clone() + ) + result_mps = transform_data( + torch.from_numpy(data_np).to("mps"), nblocks, xblock, yblock, + ymax1.clone().to("mps"), xmax1.clone().to("mps") + ) + + cpu_np = result_cpu.numpy().astype(np.float32) + mps_np = result_mps.cpu().numpy().astype(np.float32) + correlation = np.corrcoef(cpu_np.flatten(), mps_np.flatten())[0, 1] + max_diff = np.abs(cpu_np - mps_np).max() + + assert correlation > 0.99, f"Correlation: {correlation}" + assert max_diff < 2, f"Max diff: {max_diff}" \ No newline at end of file