diff --git a/pyproject.toml b/pyproject.toml index 1d8f8ede69..912fba73e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,6 +168,26 @@ widgets = [ "distinctipy", ] +lupin = [ + "scipy", + "numba", + "scikit-learn", + "torch" +] + +spykingcircus2 = [ + "scipy", + "hdbscan", + "numba", +] + +tridesclous2= [ + "scipy", + "numba", + "scikit-learn", + "torch", +] + # `full` installs every module's optional feature deps. Defined as the union of # per-module extras so adding a dep to a module propagates here automatically. full = [ @@ -272,8 +292,11 @@ test-comparison = [ test-sorters-internal = [ {include-group = "test-common"}, - "torch", # spyking_circus2 template matching - "hdbscan>=0.8.33", # simplesorter / tridesclous2 + "scipy", + "torch", + "hdbscan>=0.8.33", + "numba", + "scikit-learn", ] test-sorters = [{include-group = "test-sorters-internal"}] test-curation = [{include-group = "test-common"}] diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py index 54c8ad5fcf..d21b4d3d50 100644 --- a/src/spikeinterface/sorters/internal/lupin.py +++ b/src/spikeinterface/sorters/internal/lupin.py @@ -102,8 +102,22 @@ class LupinSorter(ComponentsBasedSorter): "debug": "Save debug files", } + installation_mesg = "\tpip install 'spikeinterface[lupin]'\nOr, if you have cloned SpikeInterface locally, using:\n\tpip install '.[lupin]'" + handle_multi_segment = True + @classmethod + def is_installed(cls): + import importlib.util + + lupin_deps = ["scipy", "numba", "sklearn", "torch"] + + for package_name in lupin_deps: + if not importlib.util.find_spec(package_name): + return False + + return True + @classmethod def get_sorter_version(cls): return "2026.01" diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index b1cfcb0843..85e0da7220 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -86,6 +86,20 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): In addition, it also uses a full Orthogonal Matching Pursuit engine to reconstruct the traces, leading to more spikes being discovered. The code is much faster and memory efficient, inheriting from all the preprocessing possibilities of spikeinterface""" + installation_mesg = "\tpip install 'spikeinterface[spykingcircus2]'\nOr, if you have cloned SpikeInterface locally, using:\n\tpip install '.[spykingcircus2]'" + + @classmethod + def is_installed(cls): + import importlib.util + + spykingcircus2_deps = ["scipy", "numba", "hdbscan"] + + for package_name in spykingcircus2_deps: + if not importlib.util.find_spec(package_name): + return False + + return True + @classmethod def get_sorter_version(cls): return "2025.12" diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index c042d8ee56..23fa98b2cb 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -97,6 +97,20 @@ class Tridesclous2Sorter(ComponentsBasedSorter): handle_multi_segment = True + installation_mesg = "\tpip install 'spikeinterface[tridesclous2]'\nOr, if you have cloned SpikeInterface locally, using:\n\tpip install '.[tridesclous2]'" + + @classmethod + def is_installed(cls): + import importlib.util + + tridesclous2_deps = ["scipy", "numba", "hdbscan"] + + for package_name in tridesclous2_deps: + if not importlib.util.find_spec(package_name): + return False + + return True + @classmethod def get_sorter_version(cls): return "2026.01"