diff --git a/docs/notebooks/04_differential_expression.ipynb b/docs/notebooks/04_differential_expression.ipynb index b1541c77..178e9e1f 100644 --- a/docs/notebooks/04_differential_expression.ipynb +++ b/docs/notebooks/04_differential_expression.ipynb @@ -63,7 +63,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "/var/folders/2l/hhd_z4hx3070zw8rlj4c3l940000gn/T/tmp5j3sii7k/report.parquet does not yet exist\n" + "/var/folders/2l/hhd_z4hx3070zw8rlj4c3l940000gn/T/tmp2016cn1_/report.parquet does not yet exist\n" ] }, { @@ -77,8 +77,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "/var/folders/2l/hhd_z4hx3070zw8rlj4c3l940000gn/T/tmp5j3sii7k/report.parquet successfully downloaded (70.8467435836792 MB)\n", - "/var/folders/2l/hhd_z4hx3070zw8rlj4c3l940000gn/T/tmpr1853dya/iteration_L_literature_reprocessing_PELSA_samplemap.csv does not yet exist\n" + "/var/folders/2l/hhd_z4hx3070zw8rlj4c3l940000gn/T/tmp2016cn1_/report.parquet successfully downloaded (70.8467435836792 MB)\n", + "/var/folders/2l/hhd_z4hx3070zw8rlj4c3l940000gn/T/tmp2nq17bfx/iteration_L_literature_reprocessing_PELSA_samplemap.csv does not yet exist\n" ] }, { @@ -92,7 +92,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "/var/folders/2l/hhd_z4hx3070zw8rlj4c3l940000gn/T/tmpr1853dya/iteration_L_literature_reprocessing_PELSA_samplemap.csv successfully downloaded (0.000179290771484375 MB)\n" + "/var/folders/2l/hhd_z4hx3070zw8rlj4c3l940000gn/T/tmp2nq17bfx/iteration_L_literature_reprocessing_PELSA_samplemap.csv successfully downloaded (0.000179290771484375 MB)\n" ] } ], diff --git a/src/alphapepttools/pp/impute.py b/src/alphapepttools/pp/impute.py index 39739524..62714111 100644 --- a/src/alphapepttools/pp/impute.py +++ b/src/alphapepttools/pp/impute.py @@ -8,16 +8,137 @@ import pandas as pd from sklearn.impute import KNNImputer -# logging configuration -logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def _check_for_complete_data( + data: np.ndarray, 
+) -> bool:
+    """Check if data contains any missing values
+
+    Parameters
+    ----------
+    data
+        Samples x Features array
+
+    Returns
+    -------
+    bool
+        True if data contains no missing values, False otherwise
+    """
+    return not np.any(np.isnan(data))
+
+
+def _raise_on_all_nan_values(data: np.ndarray) -> None:
+    """Check if a feature contains all nan
+
+    Parameters
+    ----------
+    data
+        Samples x Features array
+
+    Raises
+    ------
+    ValueError
+        If any feature contains only NaNs
+    """
+    all_nan_features = np.isnan(data).all(axis=0)
+    if any(all_nan_features):
+        raise ValueError(
+            f"Features with index {list(np.where(all_nan_features)[0])} contain all nan values. Drop these features beforehand."
+        )
+
+
+def _warn_too_many_missing(
+    data: np.ndarray,
+    threshold: float = 0.5,
+) -> None:
+    """Warn users if they are about to impute data with too many missing values.
+
+    While publications differ on the exact threshold of permissible % missing data for imputation,
+    we decide to inform users if the majority of values for a given feature are missing,
+    which could lead to a situation where imputed values dominate the actual data.
+
+    Parameters
+    ----------
+    data : np.ndarray
+        Data array containing the data to be checked.
+    threshold : float
+        Proportion of missing values above which a warning is raised.
+
+    """
+    # Check features (columns)
+    missing_proportions_features = np.isnan(data).mean(axis=0)
+    too_many_missing_features = np.where(missing_proportions_features > threshold)[0]
+
+    if len(too_many_missing_features) > 0:
+        logger.warning(
+            f" impute: Warning - {len(too_many_missing_features)} features have more than {threshold * 100}% missing values. "
+            "Proceed with caution when interpreting imputed results."
+        )
+
+
+def _impute_gaussian(
+    data: np.ndarray,
+    std_offset: float = 1.8,
+    std_factor: float = 0.3,
+    random_state: int = 42,
+) -> np.ndarray:
+    """Impute missing values in each column by random sampling from a gaussian distribution. 
+ + The distribution is centered at std_offset * feature standard deviation below the + feature mean and has a standard deviation of std_factor * feature standard deviation. + + The default values are set to mirror Perseus-style imputation: multiply the feature's + standard deviation by 0.3 and shift the mean down by 1.8 standard deviations, then sample + from the resulting distribution. + + Parameters + ---------- + data + Samples x Features array + std_offset + Number of standard deviations below the mean to center the + gaussian distribution. + std_factor + Factor to multiply the feature's standard deviation with to + get the standard deviation of the gaussian distribution. + random_state + Random seed for reproducibility + + Returns + ------- + np.ndarray + Imputed data array + """ + if _check_for_complete_data(data): + logger.info("Data contains no missing values. Skipping imputation.") + return data + + rng = np.random.default_rng(random_state) + + # generate corresponding downshifted features + stds = np.nanstd(data, axis=0) + means = np.nanmean(data, axis=0) + shifted_means = means - std_offset * stds + shifted_stds = stds * std_factor + + # iterate over nan-containing columns and impute from corresponding gaussian + na_col_idxs = np.where(np.isnan(data).sum(axis=0) > 0)[0] + for i in na_col_idxs: + na_row_idxs = np.where(np.isnan(data[:, i]))[0] + data[na_row_idxs, i] = rng.normal(shifted_means[i], shifted_stds[i], len(na_row_idxs)) + + return data def impute_gaussian( adata: ad.AnnData, - std_offset: float = 3, + group_column: str | None = None, + layer: str | None = None, + std_offset: float = 1.8, std_factor: float = 0.3, random_state: int = 42, - layer: str | None = None, *, copy: bool = False, ) -> ad.AnnData: @@ -25,20 +146,29 @@ def impute_gaussian( The distribution is centered at std_offset * feature standard deviation below the feature mean and has a standard deviation of std_factor * feature standard deviation. 
- The function returns a copy of the AnnData object with imputed values in place of NaNs. + Can perform global imputation using all samples or group-wise imputation + using subsets of samples defined by a categorical variable. Parameters ---------- - adata : anndata.AnnData + adata AnnData object containing the data to be imputed. - std_offset : float + group_column + Column name in `adata.obs` defining groups for group-wise imputation. + If `None` (default), computes statistics across all samples. + If specified, computes statistics separately for each group and imputes + missing values using the group-specific gaussian distribution. + If `group_column` contains NaNs, the respective observations are ignored. + layer + Name of the layer to impute. If None (default), the data matrix X is used. + std_offset Number of standard deviations below the mean to center the gaussian distribution. - std_factor : float + std_factor Factor to multiply the feature's standard deviation with to get the standard deviation of the gaussian distribution. - layer - Name of the layer to impute. If None (default), the data matrix X is used. + random_state + Random seed for reproducibility copy Whether to return a modified copy (True) of the anndata object. If False (default) modifies the object inplace @@ -50,72 +180,69 @@ def impute_gaussian( If `copy=False` modifies the anndata object at layer inplace and returns None. If `copy=True`, returns a modified copy. - """ - # always copy for now, implement inplace later if needed - adata = adata.copy() if copy else adata + Raises + ------ + ValueError + If `group_column` contains NaNs + ValueError + If a feature contains only NaNs - X = adata.X if layer is None else adata.layers[layer] - input_X_shape = X.shape + Notes + ----- + Features that are fully missing will not be imputed. Appropriate filtering of features with + :func:`at.pp.filter_data_completeness` is critical. 
- # All columns must be either int or float - if not np.issubdtype(X.dtype, np.number): - raise ValueError("adata.X must be numeric.") + Example + ------- + Impute the values in the `.X` matrix - nan_count = np.isnan(X).sum() + .. code-block:: python - # Get the indices of those columns that have missing values: we are going to need downshifted Gaussian's for those - rng = np.random.default_rng(random_state) - na_col_idxs = np.where(np.isnan(X).sum(axis=0) > 0)[0] + adata = at.pp.impute_gaussian(adata) + assert np.sum(np.isnan(adata.X)) == 0 - if len(na_col_idxs) == 0: - logging.info(" impute_gaussian: No NaN values found, no imputation performed.") - return adata + Impute data in a specific layer - # generate corresponding downshifted features - stds = np.nanstd(X, axis=0) - means = np.nanmean(X, axis=0) - shifted_means = means - std_offset * stds - shifted_stds = stds * std_factor + .. code-block:: python - # iterate over nan-containing columns and impute from corresponding gaussian - for i in na_col_idxs: - na_row_idxs = np.where(np.isnan(X[:, i]))[0] - X[na_row_idxs, i] = rng.normal(shifted_means[i], shifted_stds[i], len(na_row_idxs)) + adata = at.pp.impute_gaussian(adata, layer="layer2") + assert np.sum(np.isnan(adata.layers["layer2"])) == 0 + + Impute groupwise based on a categorical column: - if not X.shape == input_X_shape: - raise ValueError(" impute_gaussian: Imputed data shape does not match original data shape.") + .. code-block:: python - if np.isnan(X).any(): - raise ValueError(" impute_gaussian: Imputation failed, data retained NaN values.") + adata = at.pp.impute_gaussian(adata, group_column="cell_type") + # Imputes group-wise gaussian distributions + """ + adata = adata.copy() if copy else adata - logging.info(f" impute_gaussian: Imputation complete. 
Imputed {nan_count} NaN values with Gaussian distribution.") + data = adata.X if layer is None else adata.layers[layer] - if layer is None: - adata.X = X + if group_column is None: + _raise_on_all_nan_values(data) + data = _impute_gaussian(data, std_offset=std_offset, std_factor=std_factor, random_state=random_state) else: - adata.layers[layer] = X - - return adata if copy else None + if pd.isna(adata.obs[group_column]).any(): + raise ValueError( + f"`group_column` {group_column} contains nans. Cannot impute groups with missing values.", + ) + groups = adata.obs.groupby(group_column, dropna=True).indices -def _check_all_nan(data: np.ndarray) -> None: - """Check if a feature contains all nan + for group_indices in groups.values(): + group = data[group_indices] + _raise_on_all_nan_values(group) + data[group_indices, :] = _impute_gaussian( + group, std_offset=std_offset, std_factor=std_factor, random_state=random_state + ) - Parameters - ---------- - data - Samples x Features array + if layer is None: + adata.X = data + else: + adata.layers[layer] = data - Raises - ------ - ValueError - If any feature contains only NaNs - """ - all_nan_features = np.isnan(data).all(axis=0) - if any(all_nan_features): - raise ValueError( - f"Features with index {list(np.where(all_nan_features)[0])} contain all nan values. Drop these features beforehand." - ) + return adata if copy else None def _impute_nanmedian(data: np.ndarray) -> np.ndarray: @@ -126,16 +253,15 @@ def _impute_nanmedian(data: np.ndarray) -> np.ndarray: data Samples x Features array """ - return np.where(np.isnan(data), np.nanmedian(data, axis=0), data) - + if _check_for_complete_data(data): + logger.info("Data contains no missing values. 
Skipping imputation.") + return data -def _impute_knn(data: np.ndarray, **kwargs) -> np.ndarray: - imputer = KNNImputer(**kwargs) - return imputer.fit_transform(data) + return np.where(np.isnan(data), np.nanmedian(data, axis=0), data) def impute_median( - adata: ad.AnnData, group_column: str | None = None, *, layer: str | None = None, copy: bool = True + adata: ad.AnnData, group_column: str | None = None, layer: str | None = None, *, copy: bool = True ) -> ad.AnnData: """Impute missing values using median imputation @@ -205,7 +331,8 @@ def impute_median( data = adata.X if layer is None else adata.layers[layer] if group_column is None: - _check_all_nan(data) + _warn_too_many_missing(data) + _raise_on_all_nan_values(data) data = _impute_nanmedian(data) else: if pd.isna(adata.obs[group_column]).any(): @@ -217,7 +344,8 @@ def impute_median( for group_indices in groups.values(): group = data[group_indices] - _check_all_nan(group) + _warn_too_many_missing(group) + _raise_on_all_nan_values(group) data[group_indices, :] = _impute_nanmedian(group) if layer is None: @@ -228,6 +356,16 @@ def impute_median( return adata if copy else None +def _impute_knn(data: np.ndarray, **kwargs) -> np.ndarray: + """Impute missing values using kNN imputation""" + if _check_for_complete_data(data): + logger.info("Data contains no missing values. 
Skipping imputation.") + return data + + imputer = KNNImputer(**kwargs) + return imputer.fit_transform(data) + + def _validate_knn_grouping(groups: dict, n_neighbors: int) -> None: """Validate that knn grouping is valid""" if any(pd.isna(key) for key in groups): @@ -243,9 +381,9 @@ def impute_knn( adata: ad.AnnData, group_column: str | None = None, layer: str | None = None, - *, n_neighbors: int = 2, weights: Literal["distance", "uniform"] = "distance", + *, copy: bool = False, **kwargs, ) -> ad.AnnData: @@ -330,7 +468,8 @@ def impute_knn( data = adata.X if layer is None else adata.layers[layer] if group_column is None: - _check_all_nan(data) + _warn_too_many_missing(data) + _raise_on_all_nan_values(data) data = _impute_knn(data, n_neighbors=n_neighbors, weights=weights, **kwargs) else: groups = adata.obs.groupby(group_column, dropna=True).indices @@ -338,7 +477,8 @@ def impute_knn( for group_indices in groups.values(): group = data[group_indices] - _check_all_nan(group) + _warn_too_many_missing(group) + _raise_on_all_nan_values(group) data[group_indices, :] = _impute_knn(group, n_neighbors=n_neighbors, weights=weights, **kwargs) if layer is None: diff --git a/tests/pp/test_impute.py b/tests/pp/test_impute.py index d757624f..449aa813 100644 --- a/tests/pp/test_impute.py +++ b/tests/pp/test_impute.py @@ -4,7 +4,13 @@ import pytest from alphapepttools.pp import impute_gaussian, impute_knn, impute_median -from alphapepttools.pp.impute import _check_all_nan, _impute_knn, _impute_nanmedian +from alphapepttools.pp.impute import ( + _impute_gaussian, + _impute_knn, + _impute_nanmedian, + _raise_on_all_nan_values, + _warn_too_many_missing, +) @pytest.fixture @@ -67,9 +73,105 @@ def knn_imputation_dummy_data(imputation_dummy_data) -> tuple[np.ndarray, np.nda return imputation_dummy_data, X_ref, kwargs +@pytest.fixture +def gaussian_imputation_dummy_data(imputation_dummy_data) -> tuple[np.ndarray, np.ndarray]: + """Test data and reference for gaussian imputation""" + 
RANDOM_STATE = 42 + STD_FACTOR = 0.3 + STD_OFFSET = 1.8 + + X = imputation_dummy_data.copy() + rng = np.random.default_rng(RANDOM_STATE) + + # Iterate over each column and impute NaNs + for col_idx in range(X.shape[1]): + col = X[:, col_idx] + nan_mask = np.isnan(col) + + if nan_mask.any(): + # Get non-NaN values for this column + non_nan_vals = col[~nan_mask] + + # Calculate gaussian parameters + mean_val = np.nanmean(non_nan_vals) + std_val = np.nanstd(non_nan_vals) + shifted_mean = mean_val - STD_OFFSET * std_val + shifted_std = std_val * STD_FACTOR + + # Impute each NaN in this column + nan_indices = np.where(nan_mask)[0] + for idx in nan_indices: + X[idx, col_idx] = rng.normal(loc=shifted_mean, scale=shifted_std, size=1)[0] + + return imputation_dummy_data, X + + +class TestWarnTooManyMissing: + """Test the _warn_too_many_missing function""" + + @pytest.fixture + def missing_data_fixture(self) -> ad.AnnData: + """Create a 4x4 AnnData with features having 0, 1, 2, and 3 missing values""" + data = np.array( + [[1.0, 2.0, np.nan, np.nan], [3.0, np.nan, 4.0, np.nan], [5.0, 6.0, np.nan, np.nan], [7.0, 8.0, 9.0, 10.0]] + ) + + return ad.AnnData(data, layers={"layer": data}) + + @pytest.mark.parametrize( + ("threshold", "expected_warning_count"), + [ + (0.25, 2), # Features 3 and 4 should trigger warnings + (0.5, 1), # Only feature 4 should trigger warning + (0.75, 0), # No features should trigger warnings + ], + ) + def test_warn_too_many_missing_x(self, missing_data_fixture, threshold, expected_warning_count, caplog): + """Test warning triggers for different thresholds on X matrix""" + import logging + + adata = missing_data_fixture + + with caplog.at_level(logging.WARNING): + _warn_too_many_missing(adata.X, threshold=threshold) + + warning_records = [r for r in caplog.records if r.levelname == "WARNING"] + + if expected_warning_count > 0: + assert len(warning_records) == 1 + assert f"{expected_warning_count} features have more than {threshold * 100}%" in 
warning_records[0].message + else: + assert len(warning_records) == 0 + + @pytest.mark.parametrize( + ("threshold", "expected_warning_count"), + [ + (0.25, 2), # Features 3 and 4 should trigger warnings + (0.5, 1), # Only feature 4 should trigger warning + (0.75, 0), # No features should trigger warnings + ], + ) + def test_warn_too_many_missing_layer(self, missing_data_fixture, threshold, expected_warning_count, caplog): + """Test warning triggers for different thresholds on layer""" + import logging + + adata = missing_data_fixture + + with caplog.at_level(logging.WARNING): + _warn_too_many_missing(adata.layers["layer"], threshold=threshold) + + warning_records = [r for r in caplog.records if r.levelname == "WARNING"] + + if expected_warning_count > 0: + assert len(warning_records) == 1 + assert f"{expected_warning_count} features have more than {threshold * 100}%" in warning_records[0].message + else: + assert len(warning_records) == 0 + + def test___check_all_nan(dummy_data_all_nan) -> None: with pytest.raises(ValueError, match=r"Features with index \[4\]"): - _check_all_nan(dummy_data_all_nan) + _raise_on_all_nan_values(dummy_data_all_nan) def test__impute_nanmedian(median_imputation_dummy_data) -> None: @@ -90,61 +192,145 @@ def test__impute_knn(knn_imputation_dummy_data) -> None: assert np.all(np.isclose(X_imputed, X_ref, equal_nan=True)) -class TestImputeGaussian: - @pytest.fixture - def gaussian_imputation_dummy_data(self): - def create_data(): - data = pd.DataFrame( - { - "A": [1.0, 2.0, np.nan, 4.0, 5.0], - "B": [10.0, np.nan, 30.0, 40.0, 50.0], - }, - index=["s1", "s2", "s3", "s4", "s5"], - ) - return ad.AnnData(data, layers={"new_layer": data}) +def test__impute_gaussian(gaussian_imputation_dummy_data) -> None: + """Test gaussian imputation for data with nan values""" + X, X_ref = gaussian_imputation_dummy_data - return create_data() + X_imputed = _impute_gaussian(X.copy()) - @pytest.mark.parametrize("copy", [False, True]) - 
@pytest.mark.parametrize("layer", [None, "new_layer"]) - def test_impute_gaussian(self, gaussian_imputation_dummy_data: ad.AnnData, layer: str, *, copy: bool) -> None: - """Test that imputation with fixed random state produces reproducible results.""" + assert np.all(np.isclose(X_imputed, X_ref, equal_nan=True)) + +class TestImputeGaussianAnnData: + @pytest.fixture + def gaussian_imputation_dummy_anndata( + self, + gaussian_imputation_dummy_data, + ) -> tuple[ad.AnnData, np.ndarray, np.ndarray]: + """Test data for gaussian imputation""" + obs = pd.DataFrame( + { + "sample_id": ["A", "B", "C", "D"], + "sample_group": ["A", "A", "B", "B"], + "sample_group_with_nan": ["A", "A", np.nan, np.nan], + } + ) + + X, X_ref = gaussian_imputation_dummy_data + + # Generate grouped reference data RANDOM_STATE = 42 STD_FACTOR = 0.3 - STD_OFFSET = 3 - A_VALS = [1, 2, 4, 5] - B_VALS = [10, 30, 40, 50] - - result = impute_gaussian( - gaussian_imputation_dummy_data, - std_offset=STD_OFFSET, - std_factor=STD_FACTOR, - random_state=RANDOM_STATE, - layer=layer, - copy=copy, - ) + STD_OFFSET = 1.8 + X_ref_grouped = X.copy() rng = np.random.default_rng(RANDOM_STATE) - expected_A3 = rng.normal( - loc=np.nanmean(A_VALS) - STD_OFFSET * np.nanstd(A_VALS), scale=np.nanstd(A_VALS) * STD_FACTOR, size=1 - )[0] + # Group A: rows 0, 1 + # Group B: rows 2, 3 + groups = {"A": [0, 1], "B": [2, 3]} + + for group_indices in groups.values(): + group_data = X_ref_grouped[group_indices, :] + + for col_idx in range(group_data.shape[1]): + col = group_data[:, col_idx] + nan_mask = np.isnan(col) - expected_B2 = rng.normal( - loc=np.nanmean(B_VALS) - STD_OFFSET * np.nanstd(B_VALS), - scale=np.nanstd(B_VALS) * STD_FACTOR, - size=1, - )[0] + # Basically recap what _impute_gaussian does, but only for this group and explicitly written out + if nan_mask.any(): + non_nan_vals = col[~nan_mask] + mean_val = np.nanmean(non_nan_vals) + std_val = np.nanstd(non_nan_vals) + shifted_mean = mean_val - STD_OFFSET * std_val + 
shifted_std = std_val * STD_FACTOR - adata_imputed = result if copy else gaussian_imputation_dummy_data + nan_indices = np.where(nan_mask)[0] + for idx in nan_indices: + group_data[idx, col_idx] = rng.normal(loc=shifted_mean, scale=shifted_std, size=1)[0] - imputed = adata_imputed.to_df(layer=layer) + X_ref_grouped[group_indices, :] = group_data - assert np.allclose(imputed.loc["s3", "A"], expected_A3) - assert np.allclose(imputed.loc["s2", "B"], expected_B2) - assert not np.isnan(imputed.loc["s3", "A"]) - assert not np.isnan(imputed.loc["s2", "B"]) + return ad.AnnData(X, obs=obs, layers={"layer2": X}), X_ref, X_ref_grouped + + @pytest.fixture + def gaussian_imputation_dummy_anndata_all_nan(self, dummy_data_all_nan: np.ndarray) -> ad.AnnData: + """AnnData object with a feature that contains only NaNs""" + + obs = pd.DataFrame( + { + "sample_id": ["A", "B", "C", "D"], + "sample_group": ["A", "A", "B", "B"], + "sample_group_with_nan": ["A", "A", np.nan, np.nan], + } + ) + + return ad.AnnData(X=dummy_data_all_nan, obs=obs) + + @pytest.mark.parametrize("copy", [False, True]) + @pytest.mark.parametrize("layer", [None, "layer2"]) + @pytest.mark.parametrize("group_column", [None, "sample_group"]) + def test_impute_gaussian( + self, gaussian_imputation_dummy_anndata, layer: str, group_column: str, *, copy: bool + ) -> None: + """Test gaussian imputation for data with nan values""" + adata, X_ref, X_ref_grouped = gaussian_imputation_dummy_anndata + + result = impute_gaussian(adata, layer=layer, group_column=group_column, copy=copy) + + if copy: + assert isinstance(result, ad.AnnData) + adata_imputed = result + else: + assert result is None + adata_imputed = adata + + X_imputed = adata_imputed.X if layer is None else adata_imputed.layers[layer] + + if group_column is None: + assert np.all(np.isclose(X_imputed, X_ref, equal_nan=True)) + elif group_column == "sample_group": + assert np.all(np.isclose(X_imputed, X_ref_grouped, equal_nan=True)) + else: + pytest.fail("Unexpected 
group column passed") + + @pytest.mark.parametrize("group_column", [None, "sample_group"]) + def test_impute_gaussian__feature_all_nan( + self, gaussian_imputation_dummy_anndata_all_nan, group_column: str + ) -> None: + """Test gaussian imputation raises if a feature contains all nan""" + adata = gaussian_imputation_dummy_anndata_all_nan + + with pytest.raises(ValueError, match=r"Features with index \[4\]"): + _ = impute_gaussian(adata, group_column=group_column) + + def test_impute_gaussian__raises_if_group_column_contains_nan(self, gaussian_imputation_dummy_anndata) -> None: + """Test that gaussian imputation raises error if group_column contains nan""" + + adata, _, _ = gaussian_imputation_dummy_anndata + + with pytest.raises(ValueError, match="`group_column`"): + _ = impute_gaussian(adata, layer=None, group_column="sample_group_with_nan") + + def test_impute_gaussian__missing_group_column( + self, + gaussian_imputation_dummy_anndata, + ) -> None: + """Test that KeyError is raised if `group_column` does not exist in `adata.obs`""" + adata, _, _ = gaussian_imputation_dummy_anndata + + with pytest.raises(KeyError): + impute_gaussian(adata, group_column="non_existent_column") + + def test_impute_gaussian__missing_layer( + self, + gaussian_imputation_dummy_anndata, + ) -> None: + """Test that KeyError is raised if `layer` does not exist in `adata`""" + adata, _, _ = gaussian_imputation_dummy_anndata + + with pytest.raises(KeyError): + impute_gaussian(adata, layer="non_existent_layer") class TestImputeMedianAnnData: