From aea1f65fd0874e5ea9ef27480f790b434f0eb2c3 Mon Sep 17 00:00:00 2001 From: Stefan Date: Sun, 12 Apr 2026 21:31:40 +0200 Subject: [PATCH 1/2] feat: add mixscale continuous perturbation scoring Implements the Mixscale scoring method (Jiang et al., Nat Cell Biol 2025) as a new method on the Mixscape class. Unlike the binary KO/NP classification in mixscape(), mixscale() computes a continuous perturbation efficiency score per cell via scalar projection onto the estimated perturbation direction vector. Reuses existing _get_perturbation_markers() pipeline for DE gene detection and follows the same code patterns as mixscape() for split handling, layer access & scaling. Closes #921 (partial - continuous scoring component) --- pertpy/tools/_mixscape.py | 199 +++++++++++++++++++++++++++++++++++ tests/tools/test_mixscale.py | 179 +++++++++++++++++++++++++++++++ 2 files changed, 378 insertions(+) create mode 100644 tests/tools/test_mixscale.py diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 5eb6acb7..00fd6a05 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -390,6 +390,205 @@ def mixscape( if copy: return adata + def mixscale( + self, + adata: AnnData, + pert_key: str, + control: str, + *, + new_class_name: str = "mixscale_score", + layer: str | None = None, + min_de_genes: int = 5, + max_de_genes: int = 100, + logfc_threshold: float = 0.25, + de_layer: str | None = None, + test_method: str = "wilcoxon", + scale: bool = True, + split_by: str | None = None, + pval_cutoff: float = 5e-2, + perturbation_type: str = "KO", + copy: bool = False, + ): + """Calculate continuous perturbation scores using the Mixscale method. + + Unlike :meth:`mixscape` which performs binary KO/NP classification via + Gaussian Mixture Models, this method assigns a continuous perturbation + efficiency score to each cell. 
The score is the scalar projection of + each cell's perturbation signature onto the estimated perturbation + direction vector, standardized relative to non-targeting controls. + + This is particularly useful for CRISPRi/CRISPRa screens where cells + exhibit a gradient of perturbation responses rather than binary + knockouts. + + The implementation follows Jiang, Dalgarno et al., "Systematic + reconstruction of molecular pathway signatures using scalable + single-cell perturbation screens", Nature Cell Biology (2025). + + Args: + adata: The annotated data object. + pert_key: The column of `.obs` with target gene labels. + control: Control category from the `pert_key` column. + new_class_name: Name of the score column to be stored in `.obs`. + layer: Key from `adata.layers` whose value will be used for scoring. + Default is using `.layers["X_pert"]`. + min_de_genes: Required number of DE genes for scoring a perturbation. + Perturbations with fewer DE genes are skipped. + max_de_genes: Maximum number of DE genes to use for scoring. + logfc_threshold: Minimum log fold-change threshold for DE gene selection. + de_layer: Layer to use for identifying differentially expressed genes. + If `None`, `adata.X` is used. + test_method: Method to use for differential expression testing. + scale: Whether to scale the perturbation data before computing scores. + split_by: Provide `.obs` column with experimental condition/cell type + annotation, if perturbations are condition/cell type-specific. + pval_cutoff: P-value cut-off for selection of significantly DE genes. + perturbation_type: Type of CRISPR perturbation for labeling. + copy: Determines whether a copy of the `adata` is returned. + + Returns: + If `copy=True`, returns the copy of `adata` with the scores in `.obs`. + Otherwise, writes the scores directly to `.obs` of the provided `adata`. + + The following fields are added to `adata.obs`: + + - `adata.obs[new_class_name]`: Continuous perturbation score per cell. 
+ Higher absolute values indicate stronger perturbation effect. + Non-targeting control cells receive a score of 0. + Scores are z-score standardized relative to the control distribution. + + Examples: + Compute continuous perturbation scores: + + >>> import pertpy as pt + >>> mdata = pt.dt.papalexi_2021() + >>> ms_pt = pt.tl.Mixscape() + >>> ms_pt.perturbation_signature(mdata["rna"], "perturbation", "NT", split_by="replicate") + >>> ms_pt.mixscale(mdata["rna"], "gene_target", "NT", layer="X_pert") + """ + if copy: + adata = adata.copy() + + if split_by is None: + split_masks = [np.full(adata.n_obs, True, dtype=bool)] + categories = ["all"] + else: + split_obs = adata.obs[split_by] + categories = split_obs.unique() + split_masks = [split_obs == category for category in categories] + + # Reuse the existing DE gene detection pipeline + perturbation_markers = self._get_perturbation_markers( + adata=adata, + split_masks=split_masks, + categories=categories, + pert_key=pert_key, + control=control, + layer=de_layer, + pval_cutoff=pval_cutoff, + min_de_genes=min_de_genes, + logfc_threshold=logfc_threshold, + test_method=test_method, + ) + + # Get perturbation signature matrix + if layer is not None: + X = adata.layers[layer] + else: + try: + X = adata.layers["X_pert"] + except KeyError: + raise KeyError( + "No 'X_pert' found in .layers! " + "Please run perturbation_signature first." 
+ ) from None + + # Initialize scores to 0 (NT control default) + adata.obs[new_class_name] = 0.0 + + for split, split_mask in enumerate(split_masks): + category = categories[split] + gene_targets = list( + set(adata[split_mask].obs[pert_key]).difference([control]) + ) + nt_cells = (adata.obs[pert_key] == control) & split_mask + + for gene in gene_targets: + guide_cells = (adata.obs[pert_key] == gene) & split_mask + all_cells = guide_cells | nt_cells + + if len(perturbation_markers[(category, gene)]) == 0: + continue + + de_genes = perturbation_markers[(category, gene)] + # Limit to max_de_genes + if len(de_genes) > max_de_genes: + de_genes = de_genes[:max_de_genes] + + de_genes_indices = np.where( + np.isin(adata.var_names, list(de_genes)) + )[0] + + if len(de_genes_indices) == 0: + continue + + # Subset to DE genes for all relevant cells + dat = X[np.asarray(all_cells)][:, de_genes_indices] + if scale: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="zero-centering a sparse array/matrix densifies it.", + ) + dat = sc.pp.scale(dat) + + # Compute indices within the subsetted data + nt_cells_dat_idx = ( + all_cells[all_cells] + .index.get_indexer(nt_cells[nt_cells].index) + ) + guide_cells_dat_idx = ( + all_cells[all_cells] + .index.get_indexer(guide_cells[guide_cells].index) + ) + + # Compute perturbation direction vector + # (mean of perturbed cells minus mean of control cells) + guide_cells_mean = np.mean(dat[guide_cells_dat_idx], axis=0) + nt_cells_mean = np.mean(dat[nt_cells_dat_idx], axis=0) + vec = guide_cells_mean - nt_cells_mean + + # Scalar projection onto the perturbation direction + vec_norm_sq = np.dot(vec, vec) + if vec_norm_sq == 0: + continue + + if isinstance(dat, spmatrix): + pvec = dat.dot(vec) / vec_norm_sq + else: + pvec = np.dot(dat, vec) / vec_norm_sq + pvec = np.asarray(pvec).flatten() + + # Extract scores for guide and NT cells + guide_scores = pvec[guide_cells_dat_idx] + nt_scores = pvec[nt_cells_dat_idx] 
+ + # Z-score standardization relative to NT controls + nt_mean = np.mean(nt_scores) + nt_std = np.std(nt_scores) + if nt_std == 0: + nt_std = 1.0 + + standardized_scores = (guide_scores - nt_mean) / nt_std + + # Store scores for perturbed cells + guide_cell_indices = guide_cells[guide_cells].index + adata.obs.loc[guide_cell_indices, new_class_name] = ( + standardized_scores + ) + + if copy: + return adata def lda( self, adata: AnnData, diff --git a/tests/tools/test_mixscale.py b/tests/tools/test_mixscale.py new file mode 100644 index 00000000..3a8070f2 --- /dev/null +++ b/tests/tools/test_mixscale.py @@ -0,0 +1,179 @@ +"""Tests for Mixscape.mixscale continuous perturbation scoring.""" + +import numpy as np +import pytest +import scanpy as sc +from anndata import AnnData +from scipy.sparse import csr_matrix + +import pertpy as pt + + +@pytest.fixture +def synthetic_perturbation_adata(): + """Create synthetic perturbation data with known strong/weak effects.""" + np.random.seed(42) + n_genes = 200 + + # 100 NT controls, 100 strong KO, 100 weak KO for GeneA + # 50 cells for GeneB (moderate effect) + n_cells = 350 + + X = np.random.randn(n_cells, n_genes).astype(np.float32) + + # GeneA strong KO: large effect on first 20 genes + X[100:200, :20] -= 3.0 + # GeneA weak KO: small effect on first 20 genes + X[200:300, :20] -= 1.0 + # GeneB moderate: moderate effect on genes 20-40 + X[300:350, 20:40] -= 2.0 + + adata = AnnData(X=X) + adata.var_names = [f"Gene_{i}" for i in range(n_genes)] + adata.obs_names = [f"Cell_{i}" for i in range(n_cells)] + + labels = ( + ["NT"] * 100 + + ["GeneA"] * 100 + + ["GeneA"] * 100 + + ["GeneB"] * 50 + ) + adata.obs["gene_target"] = labels + adata.obs["perturbation"] = [ + "NT" if x == "NT" else "targeting" for x in labels + ] + + sc.pp.pca(adata) + + return adata + + +class TestMixscale: + """Tests for the mixscale method.""" + + def test_basic_scoring(self, synthetic_perturbation_adata): + """Test that mixscale runs and produces scores.""" 
+ adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale(adata, "gene_target", "NT", layer="X_pert") + + assert "mixscale_score" in adata.obs.columns + assert adata.obs["mixscale_score"].dtype == float + + def test_control_cells_score_zero(self, synthetic_perturbation_adata): + """Control cells should have score 0.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale(adata, "gene_target", "NT", layer="X_pert") + + nt_scores = adata.obs.loc[ + adata.obs["gene_target"] == "NT", "mixscale_score" + ] + assert (nt_scores == 0).all() + + def test_perturbed_cells_nonzero(self, synthetic_perturbation_adata): + """Perturbed cells should have non-zero scores.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale(adata, "gene_target", "NT", layer="X_pert") + + ko_scores = adata.obs.loc[ + adata.obs["gene_target"] == "GeneA", "mixscale_score" + ] + assert ko_scores.abs().mean() > 0 + + def test_strong_vs_weak_perturbation(self, synthetic_perturbation_adata): + """Strongly perturbed cells should have higher absolute scores.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale(adata, "gene_target", "NT", layer="X_pert") + + scores = adata.obs["mixscale_score"].values + # Cells 100-199 are strong KO, 200-299 are weak KO + strong_mean = np.abs(scores[100:200]).mean() + weak_mean = np.abs(scores[200:300]).mean() + + assert strong_mean > weak_mean, ( + f"Strong KO mean ({strong_mean:.2f}) should exceed " + f"weak KO mean ({weak_mean:.2f})" + ) + + def test_custom_column_name(self, synthetic_perturbation_adata): + """Test custom output column name.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", 
"NT") + ms.mixscale( + adata, + "gene_target", + "NT", + layer="X_pert", + new_class_name="my_score", + ) + + assert "my_score" in adata.obs.columns + assert "mixscale_score" not in adata.obs.columns + + def test_copy_mode(self, synthetic_perturbation_adata): + """Test that copy=True returns a new object.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + result = ms.mixscale( + adata, + "gene_target", + "NT", + layer="X_pert", + copy=True, + ) + + assert result is not None + assert result is not adata + assert "mixscale_score" in result.obs.columns + + def test_no_perturbation_signature_raises( + self, synthetic_perturbation_adata + ): + """Should raise KeyError if perturbation_signature hasn't been run.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + + with pytest.raises(KeyError, match="X_pert"): + ms.mixscale(adata, "gene_target", "NT") + + def test_multiple_perturbations(self, synthetic_perturbation_adata): + """Test scoring with multiple perturbation groups.""" + adata = synthetic_perturbation_adata + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale(adata, "gene_target", "NT", layer="X_pert") + + # Both GeneA and GeneB should have scores + gene_a_scores = adata.obs.loc[ + adata.obs["gene_target"] == "GeneA", "mixscale_score" + ] + gene_b_scores = adata.obs.loc[ + adata.obs["gene_target"] == "GeneB", "mixscale_score" + ] + + assert gene_a_scores.abs().mean() > 0 + assert gene_b_scores.abs().mean() > 0 + + def test_sparse_input(self, synthetic_perturbation_adata): + """Test that mixscale works with sparse matrices.""" + adata = synthetic_perturbation_adata + adata.X = csr_matrix(adata.X) + ms = pt.tl.Mixscape() + ms.perturbation_signature(adata, "gene_target", "NT") + ms.mixscale(adata, "gene_target", "NT", layer="X_pert") + + assert "mixscale_score" in adata.obs.columns + assert not np.isnan( + adata.obs.loc[ + 
adata.obs["gene_target"] != "NT", "mixscale_score" + ] + ).any() From c0d36c87139b94272507f166d64c25b7ba156a0b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 12 Apr 2026 19:37:34 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pertpy/tools/_mixscape.py | 33 ++++++++--------------------- tests/tools/test_mixscale.py | 40 ++++++++---------------------------- 2 files changed, 17 insertions(+), 56 deletions(-) diff --git a/pertpy/tools/_mixscape.py b/pertpy/tools/_mixscape.py index 00fd6a05..674ce280 100644 --- a/pertpy/tools/_mixscape.py +++ b/pertpy/tools/_mixscape.py @@ -498,19 +498,14 @@ def mixscale( try: X = adata.layers["X_pert"] except KeyError: - raise KeyError( - "No 'X_pert' found in .layers! " - "Please run perturbation_signature first." - ) from None + raise KeyError("No 'X_pert' found in .layers! Please run perturbation_signature first.") from None # Initialize scores to 0 (NT control default) adata.obs[new_class_name] = 0.0 for split, split_mask in enumerate(split_masks): category = categories[split] - gene_targets = list( - set(adata[split_mask].obs[pert_key]).difference([control]) - ) + gene_targets = list(set(adata[split_mask].obs[pert_key]).difference([control])) nt_cells = (adata.obs[pert_key] == control) & split_mask for gene in gene_targets: @@ -525,9 +520,7 @@ def mixscale( if len(de_genes) > max_de_genes: de_genes = de_genes[:max_de_genes] - de_genes_indices = np.where( - np.isin(adata.var_names, list(de_genes)) - )[0] + de_genes_indices = np.where(np.isin(adata.var_names, list(de_genes)))[0] if len(de_genes_indices) == 0: continue @@ -543,14 +536,8 @@ def mixscale( dat = sc.pp.scale(dat) # Compute indices within the subsetted data - nt_cells_dat_idx = ( - all_cells[all_cells] - .index.get_indexer(nt_cells[nt_cells].index) - ) - guide_cells_dat_idx = ( - all_cells[all_cells] - 
.index.get_indexer(guide_cells[guide_cells].index) - ) + nt_cells_dat_idx = all_cells[all_cells].index.get_indexer(nt_cells[nt_cells].index) + guide_cells_dat_idx = all_cells[all_cells].index.get_indexer(guide_cells[guide_cells].index) # Compute perturbation direction vector # (mean of perturbed cells minus mean of control cells) @@ -563,10 +550,7 @@ def mixscale( if vec_norm_sq == 0: continue - if isinstance(dat, spmatrix): - pvec = dat.dot(vec) / vec_norm_sq - else: - pvec = np.dot(dat, vec) / vec_norm_sq + pvec = dat.dot(vec) / vec_norm_sq if isinstance(dat, spmatrix) else np.dot(dat, vec) / vec_norm_sq pvec = np.asarray(pvec).flatten() # Extract scores for guide and NT cells @@ -583,12 +567,11 @@ def mixscale( # Store scores for perturbed cells guide_cell_indices = guide_cells[guide_cells].index - adata.obs.loc[guide_cell_indices, new_class_name] = ( - standardized_scores - ) + adata.obs.loc[guide_cell_indices, new_class_name] = standardized_scores if copy: return adata + def lda( self, adata: AnnData, diff --git a/tests/tools/test_mixscale.py b/tests/tools/test_mixscale.py index 3a8070f2..3f7a51cc 100644 --- a/tests/tools/test_mixscale.py +++ b/tests/tools/test_mixscale.py @@ -32,16 +32,9 @@ def synthetic_perturbation_adata(): adata.var_names = [f"Gene_{i}" for i in range(n_genes)] adata.obs_names = [f"Cell_{i}" for i in range(n_cells)] - labels = ( - ["NT"] * 100 - + ["GeneA"] * 100 - + ["GeneA"] * 100 - + ["GeneB"] * 50 - ) + labels = ["NT"] * 100 + ["GeneA"] * 100 + ["GeneA"] * 100 + ["GeneB"] * 50 adata.obs["gene_target"] = labels - adata.obs["perturbation"] = [ - "NT" if x == "NT" else "targeting" for x in labels - ] + adata.obs["perturbation"] = ["NT" if x == "NT" else "targeting" for x in labels] sc.pp.pca(adata) @@ -68,9 +61,7 @@ def test_control_cells_score_zero(self, synthetic_perturbation_adata): ms.perturbation_signature(adata, "gene_target", "NT") ms.mixscale(adata, "gene_target", "NT", layer="X_pert") - nt_scores = adata.obs.loc[ - 
adata.obs["gene_target"] == "NT", "mixscale_score" - ] + nt_scores = adata.obs.loc[adata.obs["gene_target"] == "NT", "mixscale_score"] assert (nt_scores == 0).all() def test_perturbed_cells_nonzero(self, synthetic_perturbation_adata): @@ -80,9 +71,7 @@ def test_perturbed_cells_nonzero(self, synthetic_perturbation_adata): ms.perturbation_signature(adata, "gene_target", "NT") ms.mixscale(adata, "gene_target", "NT", layer="X_pert") - ko_scores = adata.obs.loc[ - adata.obs["gene_target"] == "GeneA", "mixscale_score" - ] + ko_scores = adata.obs.loc[adata.obs["gene_target"] == "GeneA", "mixscale_score"] assert ko_scores.abs().mean() > 0 def test_strong_vs_weak_perturbation(self, synthetic_perturbation_adata): @@ -98,8 +87,7 @@ def test_strong_vs_weak_perturbation(self, synthetic_perturbation_adata): weak_mean = np.abs(scores[200:300]).mean() assert strong_mean > weak_mean, ( - f"Strong KO mean ({strong_mean:.2f}) should exceed " - f"weak KO mean ({weak_mean:.2f})" + f"Strong KO mean ({strong_mean:.2f}) should exceed weak KO mean ({weak_mean:.2f})" ) def test_custom_column_name(self, synthetic_perturbation_adata): @@ -135,9 +123,7 @@ def test_copy_mode(self, synthetic_perturbation_adata): assert result is not adata assert "mixscale_score" in result.obs.columns - def test_no_perturbation_signature_raises( - self, synthetic_perturbation_adata - ): + def test_no_perturbation_signature_raises(self, synthetic_perturbation_adata): """Should raise KeyError if perturbation_signature hasn't been run.""" adata = synthetic_perturbation_adata ms = pt.tl.Mixscape() @@ -153,12 +139,8 @@ def test_multiple_perturbations(self, synthetic_perturbation_adata): ms.mixscale(adata, "gene_target", "NT", layer="X_pert") # Both GeneA and GeneB should have scores - gene_a_scores = adata.obs.loc[ - adata.obs["gene_target"] == "GeneA", "mixscale_score" - ] - gene_b_scores = adata.obs.loc[ - adata.obs["gene_target"] == "GeneB", "mixscale_score" - ] + gene_a_scores = 
adata.obs.loc[adata.obs["gene_target"] == "GeneA", "mixscale_score"] + gene_b_scores = adata.obs.loc[adata.obs["gene_target"] == "GeneB", "mixscale_score"] assert gene_a_scores.abs().mean() > 0 assert gene_b_scores.abs().mean() > 0 @@ -172,8 +154,4 @@ def test_sparse_input(self, synthetic_perturbation_adata): ms.mixscale(adata, "gene_target", "NT", layer="X_pert") assert "mixscale_score" in adata.obs.columns - assert not np.isnan( - adata.obs.loc[ - adata.obs["gene_target"] != "NT", "mixscale_score" - ] - ).any() + assert not np.isnan(adata.obs.loc[adata.obs["gene_target"] != "NT", "mixscale_score"]).any()