Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions suite2p/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,7 @@ def main():
logging.exception(f'fatal error in {"run_plane" if args.single_plane else "run_s2p"}:')
raise

else:
# Check if the OS is macOS and the machine is Apple Silicon (ARM-based)
if platform.system() == "Darwin" and 'arm' in platform.processor().lower():
# Set the number of threads for OpenMP and OpenBLAS
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
print("Environment set to use 1 thread for OpenMP and OpenBLAS (Apple Silicon macOS).")
else:
print("Not macOS on Apple Silicon, proceeding without limiting threads.")

else:
from suite2p import gui
gui.run()#statfile="C:/DATA/exs2p/suite2p/plane0/stat.npy")

Expand Down
32 changes: 17 additions & 15 deletions suite2p/extraction/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,22 @@ def extract_traces(f_in, cell_masks, neuropil_masks, batch_size=500,
if device.type == 'mps':
device = torch.device('cpu')

npix_neuropil = torch.Tensor([len(nm) for nm in neuropil_masks]).to(device)
# create coo tensor of neuropil and cell masks
ccol_indices = [m for nm in neuropil_masks for m in nm]
row_indices = [k for k in range(len(neuropil_masks)) for m in neuropil_masks[k]]
inds = torch.Tensor([ccol_indices, row_indices]).to(device)
# convert to csc (tried creating csc directly but it was slow)
nmasks = torch.sparse_coo_tensor(inds, torch.ones(len(row_indices), device=device),
size=(Ly*Lx, ncells))
nmasks = nmasks.to_sparse_csc()
if neuropil_masks is not None:
npix_neuropil = torch.Tensor([len(nm) for nm in neuropil_masks]).to(device)
# create coo tensor of neuropil masks
ccol_indices = [m for nm in neuropil_masks for m in nm]
row_indices = [k for k in range(len(neuropil_masks)) for m in neuropil_masks[k]]
inds = torch.Tensor([ccol_indices, row_indices]).to(device)
# convert to csc (tried creating csc directly but it was slow)
nmasks = torch.sparse_coo_tensor(inds, torch.ones(len(row_indices), device=device),
size=(Ly*Lx, ncells))
nmasks = nmasks.to_sparse_csc()

ccol_indices = [m for cm in cell_masks for m in cm[0]]
row_indices = [k for k in range(len(cell_masks)) for m in cell_masks[k][0]]
cell_lam = torch.Tensor([l for cm in cell_masks for l in cm[1]]).to(device)
inds = torch.Tensor([ccol_indices, row_indices]).to(device)
cmasks = torch.sparse_coo_tensor(inds, cell_lam,
cmasks = torch.sparse_coo_tensor(inds, cell_lam,
size=(Ly*Lx, ncells))
cmasks = cmasks.to_sparse_csc()

Expand All @@ -87,11 +88,12 @@ def extract_traces(f_in, cell_masks, neuropil_masks, batch_size=500,
tstart, tend = n * batch_size, min((n+1) * batch_size, n_frames)
data = torch.from_numpy(f_in[tstart : tend]).to(device)
data = data.reshape(-1, Ly*Lx).float()

Fneu_batch = (data @ nmasks) / npix_neuropil
Fneu[:, tstart : tend] = Fneu_batch.T.cpu().numpy()

F_batch = data @ cmasks

if neuropil_masks is not None:
Fneu_batch = (data @ nmasks) / npix_neuropil
Fneu[:, tstart : tend] = Fneu_batch.T.cpu().numpy()

F_batch = data @ cmasks
F[:, tstart : tend] = F_batch.T.cpu().numpy()

return F, Fneu
Expand Down
60 changes: 37 additions & 23 deletions suite2p/gui/drawroi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
from ..detection.stats import roi_stats
from ..extraction import preprocess
from ..extraction.dcnv import oasis
from ..extraction.extract import extract_traces
from ..io.binary import BinaryFile
from ..parameters import default_settings
from ..run_s2p import _assign_torch_device


def masks_and_traces(settings, stat_manual, stat_orig):
Expand All @@ -27,6 +31,8 @@ def masks_and_traces(settings, stat_manual, stat_orig):
returns: F (ROIs x time), Fneu (ROIs x time), F_chan2, Fneu_chan2, settings, stat
F_chan2 and Fneu_chan2 will be empty if no second channel
"""
# Merge with defaults to ensure all required keys are present
settings = {**default_settings(), **settings}

t0 = time.time()

Expand All @@ -35,11 +41,12 @@ def masks_and_traces(settings, stat_manual, stat_orig):
for n in range(len(stat_orig)):
stat_all.append(stat_orig[n])

stat_all = roi_stats(stat_all, settings["Ly"], settings["Lx"], aspect=settings.get("aspect", None),
stat_all = np.array(stat_all)
stat_all = roi_stats(stat_all, settings["Ly"], settings["Lx"],
diameter=settings["diameter"])
cell_masks = [
masks.create_cell_mask(stat, Ly=settings["Ly"], Lx=settings["Lx"],
allow_overlap=settings["allow_overlap"]) for stat in stat_all
allow_overlap=settings["extraction"]["allow_overlap"]) for stat in stat_all
]
cell_pix = masks.create_cell_pix(stat_all, Ly=settings["Ly"], Lx=settings["Lx"])
manual_roi_stats = stat_all[:len(stat_manual)]
Expand All @@ -48,13 +55,26 @@ def masks_and_traces(settings, stat_manual, stat_orig):
ypixs=[stat["ypix"] for stat in manual_roi_stats],
xpixs=[stat["xpix"] for stat in manual_roi_stats],
cell_pix=cell_pix,
inner_neuropil_radius=settings["inner_neuropil_radius"],
min_neuropil_pixels=settings["min_neuropil_pixels"],
inner_neuropil_radius=settings["extraction"]["inner_neuropil_radius"],
min_neuropil_pixels=settings["extraction"]["min_neuropil_pixels"],
)
print("Masks made in %0.2f sec." % (time.time() - t0))

F, Fneu, F_chan2, Fneu_chan2 = extract_traces_from_masks(settings, manual_cell_masks,
manual_neuropil_masks)
# Extract traces from binary file
Ly, Lx = settings["Ly"], settings["Lx"]
batch_size = settings["extraction"]["batch_size"]
device = _assign_torch_device(settings["torch_device"])
f_reg = BinaryFile(Ly, Lx, settings["reg_file"])
F, Fneu = extract_traces(f_reg, manual_cell_masks, manual_neuropil_masks, batch_size=batch_size, device=device)
f_reg.close()

# Handle chan2 if present
if "reg_file_chan2" in settings and settings["reg_file_chan2"]:
f_reg_chan2 = BinaryFile(Ly, Lx, settings["reg_file_chan2"])
F_chan2, Fneu_chan2 = extract_traces(f_reg_chan2, manual_cell_masks, manual_neuropil_masks, batch_size=batch_size, device=device)
f_reg_chan2.close()
else:
F_chan2, Fneu_chan2 = None, None

# compute activity statistics for classifier
npix = np.array([stat_orig[n]["npix"] for n in range(len(stat_orig))
Expand All @@ -69,7 +89,7 @@ def masks_and_traces(settings, stat_manual, stat_orig):
manual_roi_stats[n]["iplane"] = stat_orig[0]["iplane"]

# subtract neuropil and compute skew, std from F
dF = F - settings["neucoeff"] * Fneu
dF = F - settings["extraction"]["neuropil_coefficient"] * Fneu
sk = stats.skew(dF, axis=1)
sd = np.std(dF, axis=1)

Expand All @@ -81,10 +101,10 @@ def masks_and_traces(settings, stat_manual, stat_orig):
np.mean(manual_roi_stats[n]["xpix"])
]

dF = preprocess(F=dF, baseline=settings["baseline"], win_baseline=settings["win_baseline"],
sig_baseline=settings["sig_baseline"], fs=settings["fs"],
prctile_baseline=settings["prctile_baseline"])
spks = oasis(F=dF, batch_size=settings["batch_size"], tau=settings["tau"], fs=settings["fs"])
dF = preprocess(F=dF, baseline=settings["dcnv_preprocess"]["baseline"], win_baseline=settings["dcnv_preprocess"]["win_baseline"],
sig_baseline=settings["dcnv_preprocess"]["sig_baseline"], fs=settings["fs"],
prctile_baseline=settings["dcnv_preprocess"]["prctile_baseline"], device=device)
spks = oasis(F=dF, batch_size=settings["extraction"]["batch_size"], tau=settings["tau"], fs=settings["fs"])

return F, Fneu, F_chan2, Fneu_chan2, spks, settings, manual_roi_stats

Expand Down Expand Up @@ -187,7 +207,7 @@ def __init__(self, parent):
self.saveGUI = False
self.closeGUI = QPushButton("Save and Quit")
self.closeGUI.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
self.closeGUI.clicked.connect(self.close_GUI)
self.closeGUI.clicked.connect(lambda: self.close_GUI())
self.closeGUI.setEnabled(False)
self.closeGUI.setFixedWidth(100)
self.closeGUI.setStyleSheet(self.styleUnpressed)
Expand Down Expand Up @@ -247,9 +267,7 @@ def close_GUI(self):

# Append new stat file with old and save
print("Saving new stat")
stat_all = self.new_stat.copy()
for n in range(len(self.parent.stat)):
stat_all.append(self.parent.stat[n])
stat_all = np.concatenate((self.new_stat, self.parent.stat))
np.save(os.path.join(self.parent.basename, "stat.npy"), stat_all)
iscell_prob = np.concatenate(
(self.parent.iscell[:, np.newaxis], self.parent.probcell[:, np.newaxis]),
Expand Down Expand Up @@ -297,30 +315,26 @@ def normalize_img_add_masks(self):
if i == 0:
mimg = np.zeros((self.Ly, self.Lx), np.float32)
mimg[self.parent.ops["yrange"][0]:self.parent.ops["yrange"][1],
self.parent.ops["xrange"][0]:self.parent.
settings["xrange"][1]] = self.parent.ops["meanImg"][
self.parent.ops["xrange"][0]:self.parent.ops["xrange"][1]] = self.parent.ops["meanImg"][
self.parent.ops["yrange"][0]:self.parent.ops["yrange"][1],
self.parent.ops["xrange"][0]:self.parent.ops["xrange"][1]]

elif i == 1:
mimg = np.zeros((self.Ly, self.Lx), np.float32)
mimg[self.parent.ops["yrange"][0]:self.parent.ops["yrange"][1],
self.parent.ops["xrange"][0]:self.parent.
settings["xrange"][1]] = self.parent.ops["meanImgE"][
self.parent.ops["xrange"][0]:self.parent.ops["xrange"][1]] = self.parent.ops["meanImgE"][
self.parent.ops["yrange"][0]:self.parent.ops["yrange"][1],
self.parent.ops["xrange"][0]:self.parent.ops["xrange"][1]]
elif i == 2:
mimg = np.zeros((self.Ly, self.Lx), np.float32)
mimg[self.parent.ops["yrange"][0]:self.parent.ops["yrange"][1],
self.parent.ops["xrange"][0]:self.parent.
settings["xrange"][1]] = self.parent.ops["Vcorr"]
self.parent.ops["xrange"][0]:self.parent.ops["xrange"][1]] = self.parent.ops["Vcorr"]

else:
mimg = np.zeros((self.Ly, self.Lx), np.float32)
if "max_proj" in self.parent.ops:
mimg[self.parent.ops["yrange"][0]:self.parent.ops["yrange"][1],
self.parent.ops["xrange"][0]:self.parent.
settings["xrange"][1]] = self.parent.ops["max_proj"]
self.parent.ops["xrange"][0]:self.parent.ops["xrange"][1]] = self.parent.ops["max_proj"]

mimg1 = np.percentile(mimg, 1)
mimg99 = np.percentile(mimg, 99)
Expand Down
2 changes: 2 additions & 0 deletions suite2p/gui/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,8 @@ def save_merge(parent):
(parent.iscell[:, np.newaxis], parent.probcell[:, np.newaxis]), axis=1)
np.save(os.path.join(parent.basename, "iscell.npy"), iscell)

parent.lcell0.setText("%d" % (parent.iscell.sum()))
parent.lcell1.setText("%d" % (parent.iscell.size - parent.iscell.sum()))
parent.notmerged = np.ones(parent.iscell.size, "bool")


Expand Down
10 changes: 8 additions & 2 deletions suite2p/gui/menus.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""
from qtpy import QtGui
from qtpy.QtWidgets import QAction, QMenu
from pkg_resources import iter_entry_points
from importlib.metadata import entry_points

from . import reggui, drawroi, merge, io, rungui, visualize, classgui
from suite2p.io.nwb import save_nwb
Expand Down Expand Up @@ -166,7 +166,13 @@ def plugins(parent):
main_menu = parent.menuBar()
parent.plugins = {}
plugin_menu = main_menu.addMenu("&Plugins")
for entry_pt in iter_entry_points(group="suite2p.plugin", name=None):
try:
# Works for python 3.12+
suite2p_plugins = entry_points(group="suite2p.plugin")
except TypeError:
# works for Python 3.9-3.11
suite2p_plugins = entry_points().get("suite2p.plugin", [])
for entry_pt in suite2p_plugins:
plugin_obj = entry_pt.load() # load the advertised class from entry_points
parent.plugins[entry_pt.name] = plugin_obj(
parent
Expand Down
13 changes: 9 additions & 4 deletions suite2p/gui/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def do_merge(parent):
merge_activity_masks(parent)
parent.merged.append(parent.imerge)
parent.update_plot()
io.save_merge(parent)
print(parent.merged)
print("merged ROIs")

Expand Down Expand Up @@ -125,10 +126,11 @@ def merge_activity_masks(parent):
if parent.hasred:
F_chan2 = F_chan2.mean(axis=0)
Fneu_chan2 = Fneu_chan2.mean(axis=0)
dF = F - parent.ops["neucoeff"] * Fneu
dF = F - parent.ops["extraction"]["neuropil_coefficient"] * Fneu
# activity stats
stat0["skew"] = stats.skew(dF)
stat0["std"] = dF.std()
stat0["snr"] = 1 - 0.5 * np.diff(dF).var() / dF.var()

spks = oasis(F=dF[np.newaxis, :], batch_size=parent.ops["batch_size"],
tau=parent.ops["tau"], fs=parent.ops["fs"])
Expand All @@ -151,9 +153,9 @@ def merge_activity_masks(parent):
# add cell to structs
parent.stat = np.concatenate((parent.stat, np.array([stat0])), axis=0)
parent.stat = roi_stats(parent.stat, parent.Ly, parent.Lx,
aspect=parent.ops.get("aspect", None),
diameter=parent.ops.get("diameter", None),
do_crop=parent.ops.get("soma_crop", 1))
diameter=parent.ops["diameter"],
do_soma_crop=parent.ops["detection"]["soma_crop"],
max_overlap=None)
parent.stat[-1]["lam"] = parent.stat[-1]["lam"] * merged_cells.size
parent.Fcell = np.concatenate((parent.Fcell, F[np.newaxis, :]), axis=0)
parent.Fneu = np.concatenate((parent.Fneu, Fneu[np.newaxis, :]), axis=0)
Expand Down Expand Up @@ -182,8 +184,10 @@ def merge_activity_masks(parent):
# recompute binned F
parent.mode_change(parent.activityMode)

# Remove the maskes for the previous cells that were merged.
for n in merged_cells:
parent.stat[n]["inmerge"] = len(parent.stat) - 1
parent.iscell[n] = False
masks.remove_roi(parent, n, i0)
masks.add_roi(parent, len(parent.stat) - 1, i0)
masks.redraw_masks(parent, ypix, xpix)
Expand Down Expand Up @@ -262,6 +266,7 @@ def do_merge(self, parent):
merge_activity_masks(parent)
parent.merged.append(parent.imerge)
parent.update_plot()
io.save_merge(parent)

self.cc_row = np.matmul(parent.Fbin[parent.iscell],
parent.Fbin[-1].T) / parent.Fbin.shape[-1]
Expand Down
72 changes: 57 additions & 15 deletions suite2p/gui/rungui_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,25 +115,67 @@ def setup_logger(name):
return logger


class Suite2pWorker(QtCore.QThread):
class Suite2pWorker(QtCore.QObject):
"""Worker that runs suite2p in a separate process to avoid QThread stack limitations on macOS."""
finished = QtCore.Signal(str)

def __init__(self, parent, db_file, settings_file):
super(Suite2pWorker, self).__init__()
self.db_file = db_file
self.settings_file = settings_file
self.parent = parent
# self.logHandler = ThreadLogger()

def run(self):
db = np.load(self.db_file, allow_pickle=True).item()
settings = np.load(self.settings_file, allow_pickle=True).item()

try:
logger_setup(get_save_folder(db))
run_s2p(db=db, settings=settings)
self.process = None

def start(self):
"""Start suite2p in a separate process using QProcess."""
self.process = QtCore.QProcess()
self.process.setProcessChannelMode(QtCore.QProcess.MergedChannels)
self.process.readyReadStandardOutput.connect(self._on_output)
self.process.finished.connect(self._on_finished)

# Create a Python script to run suite2p
script = f'''
import numpy as np
from suite2p.run_s2p import logger_setup, run_s2p, get_save_folder

db = np.load("{self.db_file}", allow_pickle=True).item()
settings = np.load("{self.settings_file}", allow_pickle=True).item()

logger_setup(get_save_folder(db))
run_s2p(db=db, settings=settings)
'''
self.process.start(sys.executable, ["-c", script])

def _on_output(self):
"""Handle output from the subprocess."""
if self.process:
data = self.process.readAllStandardOutput()
text = bytes(data).decode("utf-8", errors="replace")
print(text, end="")

def _on_finished(self, exit_code, exit_status):
"""Handle process completion."""
if exit_code == 0:
self.finished.emit("finished")
except Exception as e:
print("ERROR:", e)
traceback.print_exc()
self.finished.emit("error")
else:
self.finished.emit("error")

def terminate(self):
"""Terminate the subprocess if running."""
if self.process and self.process.state() != QtCore.QProcess.NotRunning:
self.process.terminate()

def quit(self):
"""Stop the subprocess (alias for terminate, for QThread compatibility)."""
self.terminate()

def wait(self):
"""Wait for the process to finish (for compatibility)."""
if self.process:
self.process.waitForFinished(-1)

def isRunning(self):
"""Check if the process is still running (for QThread compatibility)."""
if self.process:
return self.process.state() == QtCore.QProcess.Running
return False
2 changes: 2 additions & 0 deletions suite2p/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def _load_npy_cross_platform(path):
NWB = True
except ModuleNotFoundError:
NWB = False
logger.warning("pynwb not installed, save_nwb, read_nwb, and nwb_to_binary "
"will not work. Install with: pip install pynwb")


def nwb_to_binary(settings):
Expand Down
Loading
Loading