diff --git a/packages/api/src/cell_explorer_api/services/zarr_adapter.py b/packages/api/src/cell_explorer_api/services/zarr_adapter.py
index 3c96545..63141d2 100644
--- a/packages/api/src/cell_explorer_api/services/zarr_adapter.py
+++ b/packages/api/src/cell_explorer_api/services/zarr_adapter.py
@@ -101,6 +101,11 @@ async def obs_column(self, name: str) -> ObsColumn:
         raw = await self.store.obs_column(name)
         return _decode_col_to_obs_column(name, raw)
 
+    async def var_column(self, name: str) -> ObsColumn:
+        """Fetch one var column from the store and decode it into an ObsColumn."""
+        raw = await self.store.var_column(name)
+        return _decode_col_to_obs_column(name, raw)
+
     async def gene_index(self, gene: str) -> int:
         await self._ensure_gene_map()
         assert self._gene_to_row_cache is not None
diff --git a/packages/cell-explorer-agent/src/cell_explorer_agent/tools/__init__.py b/packages/cell-explorer-agent/src/cell_explorer_agent/tools/__init__.py
index e47fb56..5dc7b39 100644
--- a/packages/cell-explorer-agent/src/cell_explorer_agent/tools/__init__.py
+++ b/packages/cell-explorer-agent/src/cell_explorer_agent/tools/__init__.py
@@ -15,6 +15,7 @@
     describe_obs_column_tool,
 )
 from cell_explorer_agent.tools.data.schema import get_dataset_schema_tool
+from cell_explorer_agent.tools.data.var import describe_var_column_tool
 from cell_explorer_agent.tools.registry import Tool, ToolCatalog, ToolKind
 from cell_explorer_agent.tools.ui_action.color import (
     clear_color_by_tool,
@@ -69,6 +70,7 @@ def build_v1_catalog(
     # data
     cat.register(get_dataset_schema_tool(z, limit_bytes=lim))
     cat.register(describe_obs_column_tool(z, limit_bytes=lim))
+    cat.register(describe_var_column_tool(z, limit_bytes=lim))
     cat.register(cluster_stats_tool(z, limit_bytes=lim))
     cat.register(search_genes_tool(z, limit_bytes=lim))
     cat.register(gene_expression_summary_tool(z, limit_bytes=lim))
diff --git a/packages/cell-explorer-agent/src/cell_explorer_agent/tools/data/schema.py b/packages/cell-explorer-agent/src/cell_explorer_agent/tools/data/schema.py
index 26ceffd..a409fa5 100644
--- a/packages/cell-explorer-agent/src/cell_explorer_agent/tools/data/schema.py
+++ b/packages/cell-explorer-agent/src/cell_explorer_agent/tools/data/schema.py
@@ -11,6 +11,7 @@ def get_dataset_schema_tool(z: ZarrAccess, *, limit_bytes: int) -> Tool:
     async def run() -> dict[str, Any]:
         n_obs, n_var = await z.shape()
         cols = await z.obs_columns()
+        var_cols = await z.var_columns()
         obsm = await z.obsm_keys()
         payload = {
             "n_obs": n_obs,
@@ -20,6 +21,7 @@ async def run() -> dict[str, Any]:
                 {"name": c.name, "dtype": c.dtype, "cardinality": c.cardinality}
                 for c in cols
             ],
+            "var_columns": list(var_cols),
             "embeddings": list(obsm),
         }
         return cap_json_bytes(payload, limit_bytes=limit_bytes)
@@ -29,7 +31,8 @@ async def run() -> dict[str, Any]:
         kind="data",
         description=(
             "Return dataset shape, obs column names/dtypes/cardinalities, "
-            "and available embedding keys. Call this first to learn the schema."
+            "var (per-gene) column names, and available embedding keys. Call "
+            "this first to learn the schema."
         ),
         args_schema={"type": "object", "properties": {}, "additionalProperties": False},
         func=run,
diff --git a/packages/cell-explorer-agent/src/cell_explorer_agent/tools/data/var.py b/packages/cell-explorer-agent/src/cell_explorer_agent/tools/data/var.py
new file mode 100644
index 0000000..750b78a
--- /dev/null
+++ b/packages/cell-explorer-agent/src/cell_explorer_agent/tools/data/var.py
@@ -0,0 +1,151 @@
+"""describe_var_column tool — var-axis counterpart of describe_obs_column.
+
+Var columns are the per-gene metadata columns in the AnnData var dataframe.
+Typical examples: `gene_symbol`, `feature_id`, `highly_variable`,
+`mean_counts`, `dispersions_norm`. They arrive as ObsColumn objects, and the
+output layout is modelled on describe_obs_column.
+String columns (e.g. `gene_symbol`) get a value-count summary rather than a
+float conversion, and numeric statistics skip NaN entries and report how many
+were skipped.
+"""
+
+from collections import Counter
+from typing import Any
+
+import numpy as np
+
+from cell_explorer_agent.tools.caps import cap_json_bytes
+from cell_explorer_agent.tools.registry import Tool
+from cell_explorer_agent.tools.zarr_protocol import ZarrAccess
+
+TOP_N = 50
+
+_STAT_KEYS = ("min", "max", "mean", "median", "q1", "q3", "stddev")
+
+
+def _category_label(categories: Any, code: int) -> Any:
+    """Return the label for a categorical *code*, or None when out of range."""
+    # Negative codes (-1 marks a missing value in pandas/AnnData categoricals)
+    # would otherwise wrap around and silently report the last category.
+    if 0 <= code < len(categories):
+        return categories[code]
+    return None
+
+
+def _top_counts(values: Any) -> tuple[list[tuple[Any, int]], int, int]:
+    """Return (TOP_N most common values with counts, remainder, distinct count)."""
+    counts = Counter(np.asarray(values).tolist())
+    items = counts.most_common(TOP_N)
+    remainder = sum(counts.values()) - sum(count for _, count in items)
+    return items, remainder, len(counts)
+
+
+def _numeric_stats(vals: np.ndarray) -> dict[str, Any]:
+    """Return summary statistics for *vals*; every value is None when empty."""
+    if vals.size == 0:
+        # NumPy reductions such as np.min raise ValueError on empty input.
+        return dict.fromkeys(_STAT_KEYS)
+    return {
+        "min": float(np.min(vals)),
+        "max": float(np.max(vals)),
+        "mean": float(np.mean(vals)),
+        "median": float(np.median(vals)),
+        "q1": float(np.quantile(vals, 0.25)),
+        "q3": float(np.quantile(vals, 0.75)),
+        "stddev": float(np.std(vals)),
+    }
+
+
+def describe_var_column_tool(z: ZarrAccess, *, limit_bytes: int) -> Tool:
+    """Build the `describe_var_column` data tool backed by *z*.
+
+    The tool takes a var column name and returns a summary dict passed through
+    `cap_json_bytes` with *limit_bytes*:
+
+    * categorical: the TOP_N most frequent values with counts plus
+      `other_count`; codes without a matching category are reported as None.
+    * string: the same layout plus `n_unique`.
+    * any other dtype: treated as numeric, giving min/max/mean/median/q1/q3/
+      stddev over the non-NaN entries plus `nan_count`.
+
+    An unknown column name yields an `{"error": ...}` dict instead of raising.
+    """
+
+    async def run(name: str) -> dict[str, Any]:
+        try:
+            col = await z.var_column(name)
+        except KeyError:
+            return {"error": f"var column {name!r} not found"}
+
+        total = int(len(col.values))
+
+        if col.dtype == "categorical":
+            categories = col.categories
+            if categories is None:
+                # Explicit check rather than `assert`, which -O strips out.
+                return {"error": f"var column {name!r} has no categories"}
+            items, other, _ = _top_counts(col.values)
+            return cap_json_bytes(
+                {
+                    "dtype": "categorical",
+                    "total": total,
+                    "top_categories": [
+                        {"value": _category_label(categories, code), "count": int(count)}
+                        for code, count in items
+                    ],
+                    "other_count": int(other),
+                },
+                limit_bytes=limit_bytes,
+            )
+
+        if col.dtype == "string":
+            # Identifier-like columns cannot be cast to float64, so they are
+            # summarised by value counts instead.
+            items, other, n_unique = _top_counts(col.values)
+            return cap_json_bytes(
+                {
+                    "dtype": "string",
+                    "total": total,
+                    "n_unique": int(n_unique),
+                    "top_categories": [
+                        {"value": value, "count": int(count)} for value, count in items
+                    ],
+                    "other_count": int(other),
+                },
+                limit_bytes=limit_bytes,
+            )
+
+        vals = np.asarray(col.values, dtype="float64")
+        missing = np.isnan(vals)
+        return cap_json_bytes(
+            {
+                "dtype": "numeric",
+                "total": total,
+                # NaN entries would turn every statistic into NaN, so they are
+                # counted separately and left out of the summary.
+                "nan_count": int(missing.sum()),
+                "stats": _numeric_stats(vals[~missing]),
+            },
+            limit_bytes=limit_bytes,
+        )
+
+    return Tool(
+        name="describe_var_column",
+        kind="data",
+        description=(
+            "Describe one var (per-gene) column. Use to inspect gene-metadata "
+            "columns like 'gene_symbol' or 'feature_id'. For categorical and "
+            "string columns, returns the top 50 values with counts plus the "
+            "remainder. For numeric columns, returns min/max/mean/median/"
+            "quartiles/stddev over non-NaN values plus the NaN count. "
+            "Call get_dataset_schema first to learn which var columns are "
+            "available."
+        ),
+        args_schema={
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
+            "required": ["name"],
+            "additionalProperties": False,
+        },
+        func=run,
+    )
diff --git a/packages/cell-explorer-agent/src/cell_explorer_agent/tools/zarr_protocol.py b/packages/cell-explorer-agent/src/cell_explorer_agent/tools/zarr_protocol.py
index d052e53..dc41c81 100644
--- a/packages/cell-explorer-agent/src/cell_explorer_agent/tools/zarr_protocol.py
+++ b/packages/cell-explorer-agent/src/cell_explorer_agent/tools/zarr_protocol.py
@@ -36,6 +36,18 @@ async def var_columns(self) -> list[str]:
         Distinct from var_names which is the gene-name index.
         """
         ...
+
+    async def var_column(self, name: str) -> ObsColumn:
+        """Single var column — efficient for targeted queries.
+
+        Returns the same ObsColumn shape as obs_column (categorical/numeric/string
+        with values + categories). The type name reflects historical naming;
+        the structure is axis-agnostic.
+
+        Implementations are expected to raise KeyError when no var column is
+        called *name*; describe_var_column relies on that to report the miss.
+        """
+        ...
     async def obsm_keys(self) -> list[str]: ...
     async def gene_index(self, gene: str) -> int: ...
     async def gene_column(self, gene: str) -> np.ndarray: ...
diff --git a/packages/cell-explorer-agent/tests/fakes/fake_zarr.py b/packages/cell-explorer-agent/tests/fakes/fake_zarr.py
index a81f396..4bce556 100644
--- a/packages/cell-explorer-agent/tests/fakes/fake_zarr.py
+++ b/packages/cell-explorer-agent/tests/fakes/fake_zarr.py
@@ -20,6 +20,10 @@ class FakeZarrAccess:
     obs: dict[str, ObsColumn] = field(default_factory=dict)
     var: list[str] = field(default_factory=list)
     var_columns_data: list[str] = field(default_factory=lambda: ["feature_id", "gene_symbol"])
+    # Per-name var column payloads, keyed by var_columns_data name. Tests that
+    # exercise describe_var_column populate this; defaults left empty so older
+    # tests stay unaffected.
+    var_data: dict[str, ObsColumn] = field(default_factory=dict)
     obsm: list[str] = field(default_factory=lambda: ["X_umap", "X_pca"])
     expression: dict[str, np.ndarray] = field(default_factory=dict)
     _attrs: dict = field(
@@ -95,6 +99,11 @@ async def var_names(self) -> list[str]:
     async def var_columns(self) -> list[str]:
         return list(self.var_columns_data)
 
+    async def var_column(self, name: str) -> ObsColumn:
+        # Plain dict lookup: a missing name raises KeyError(name), which is what
+        # describe_var_column catches to report an unknown column.
+        return self.var_data[name]
+
     async def obsm_keys(self) -> list[str]:
         return list(self.obsm)
 
diff --git a/packages/cell-explorer-agent/tests/test_tool_catalog_v1.py b/packages/cell-explorer-agent/tests/test_tool_catalog_v1.py
index 5f7e2c1..afecddd 100644
--- a/packages/cell-explorer-agent/tests/test_tool_catalog_v1.py
+++ b/packages/cell-explorer-agent/tests/test_tool_catalog_v1.py
@@ -9,6 +9,7 @@ async def test_v1_catalog_includes_all_tools(fake_zarr):
         # data
         "get_dataset_schema",
         "describe_obs_column",
+        "describe_var_column",
         "cluster_stats",
         "search_genes",
         "gene_expression_summary",
diff --git a/packages/cell-explorer-agent/tests/test_tools_data.py b/packages/cell-explorer-agent/tests/test_tools_data.py
index 2acafb8..ee1b2c9 100644
--- a/packages/cell-explorer-agent/tests/test_tools_data.py
+++ b/packages/cell-explorer-agent/tests/test_tools_data.py
@@ -13,6 +13,10 @@ async def test_get_dataset_schema(fake_zarr):
     assert cell_type["cardinality"] == 3
     assert "X_umap" in result["embeddings"]
     assert result["var_count"] == 50
+    # var_columns lists names from the var dataframe (e.g. gene_symbol, feature_id)
+    assert "var_columns" in result
+    assert "feature_id" in result["var_columns"]
+    assert "gene_symbol" in result["var_columns"]
 
 
 async def test_get_dataset_schema_is_data_kind(fake_zarr):
@@ -905,3 +909,131 @@ async def test_xscan_group_sums_zero_floored_backward_compat():
     np.testing.assert_allclose(sum_x[0], 3.0, atol=1e-6)
     assert nnz[0] == 2, f"expected 2, got {nnz[0]}"
     # matches count_nonzero for zero-floored data
+
+
+# ---------------------------------------------------------------------------
+# describe_var_column tests — mirror the obs side
+# ---------------------------------------------------------------------------
+
+from cell_explorer_agent.tools.data.var import describe_var_column_tool
+from cell_explorer_agent.tools.zarr_protocol import ObsColumn
+
+
+async def test_describe_var_column_categorical(fake_zarr):
+    import numpy as np
+
+    # Seed the fake with a categorical var column (e.g., feature_type with two values)
+    n_var = 50
+    cats = ["protein_coding", "lncRNA"]
+    codes = np.array([0] * 40 + [1] * 10, dtype=np.int32)
+    fake_zarr.var_data["feature_type"] = ObsColumn(
+        name="feature_type",
+        dtype="categorical",
+        values=codes,
+        categories=list(cats),
+    )
+    fake_zarr.var_columns_data = ["feature_id", "gene_symbol", "feature_type"]
+
+    tool = describe_var_column_tool(fake_zarr, limit_bytes=32_768)
+    result = await tool.func(name="feature_type")
+    assert result["dtype"] == "categorical"
+    assert result["total"] == n_var
+    counts = {c["value"]: c["count"] for c in result["top_categories"]}
+    assert counts == {"protein_coding": 40, "lncRNA": 10}
+    assert result["other_count"] == 0
+
+
+async def test_describe_var_column_categorical_missing_code(fake_zarr):
+    import numpy as np
+
+    # Code -1 marks a missing value; it must not be reported as the last category.
+    fake_zarr.var_data["feature_type"] = ObsColumn(
+        name="feature_type",
+        dtype="categorical",
+        values=np.array([0, 0, -1], dtype=np.int32),
+        categories=["protein_coding", "lncRNA"],
+    )
+
+    tool = describe_var_column_tool(fake_zarr, limit_bytes=32_768)
+    result = await tool.func(name="feature_type")
+    counts = {c["value"]: c["count"] for c in result["top_categories"]}
+    assert counts == {"protein_coding": 2, None: 1}
+
+
+async def test_describe_var_column_string(fake_zarr):
+    import numpy as np
+
+    fake_zarr.var_data["gene_symbol"] = ObsColumn(
+        name="gene_symbol",
+        dtype="string",
+        values=np.array(["TP53", "ACTB", "TP53"], dtype=object),
+    )
+
+    tool = describe_var_column_tool(fake_zarr, limit_bytes=32_768)
+    result = await tool.func(name="gene_symbol")
+    assert result["dtype"] == "string"
+    assert result["total"] == 3
+    assert result["n_unique"] == 2
+    assert result["top_categories"][0] == {"value": "TP53", "count": 2}
+    assert result["other_count"] == 0
+
+
+async def test_describe_var_column_numeric(fake_zarr):
+    import numpy as np
+
+    n_var = 50
+    fake_zarr.var_data["mean_counts"] = ObsColumn(
+        name="mean_counts",
+        dtype="numeric",
+        values=np.linspace(0.1, 10.0, n_var, dtype=np.float32),
+    )
+
+    tool = describe_var_column_tool(fake_zarr, limit_bytes=32_768)
+    result = await tool.func(name="mean_counts")
+    assert result["dtype"] == "numeric"
+    assert result["total"] == n_var
+    assert result["nan_count"] == 0
+    assert set(result["stats"]) == {"min", "max", "mean", "median", "q1", "q3", "stddev"}
+    np.testing.assert_allclose(result["stats"]["min"], 0.1, rtol=1e-6)
+    np.testing.assert_allclose(result["stats"]["max"], 10.0, rtol=1e-6)
+
+
+async def test_describe_var_column_numeric_skips_nan(fake_zarr):
+    import numpy as np
+
+    fake_zarr.var_data["dispersions_norm"] = ObsColumn(
+        name="dispersions_norm",
+        dtype="numeric",
+        values=np.array([1.0, np.nan, 3.0]),
+    )
+
+    tool = describe_var_column_tool(fake_zarr, limit_bytes=32_768)
+    result = await tool.func(name="dispersions_norm")
+    assert result["total"] == 3
+    assert result["nan_count"] == 1
+    assert result["stats"]["min"] == 1.0
+    assert result["stats"]["max"] == 3.0
+    assert result["stats"]["mean"] == 2.0
+
+
+async def test_describe_var_column_numeric_empty(fake_zarr):
+    import numpy as np
+
+    fake_zarr.var_data["mean_counts"] = ObsColumn(
+        name="mean_counts",
+        dtype="numeric",
+        values=np.array([], dtype=np.float64),
+    )
+
+    tool = describe_var_column_tool(fake_zarr, limit_bytes=32_768)
+    result = await tool.func(name="mean_counts")
+    assert result["total"] == 0
+    assert result["nan_count"] == 0
+    assert set(result["stats"].values()) == {None}
+
+
+async def test_describe_var_column_unknown_returns_error(fake_zarr):
+    tool = describe_var_column_tool(fake_zarr, limit_bytes=32_768)
+    result = await tool.func(name="nonexistent_column")
+    assert "error" in result
+    assert "nonexistent_column" in result["error"]
diff --git a/packages/zarr-access/src/zarr_access/anndata_store.py b/packages/zarr-access/src/zarr_access/anndata_store.py
index 0efabee..adf9997 100644
--- a/packages/zarr-access/src/zarr_access/anndata_store.py
+++ b/packages/zarr-access/src/zarr_access/anndata_store.py
@@ -67,6 +67,12 @@ async def obs_column(self, name: str):
         node = await obs.getitem(name)
         return await decode_column(node)
 
+    async def var_column(self, name: str):
+        """Single var column — efficient for targeted queries."""
+        var = await self._zarr.get_group("var")
+        node = await var.getitem(name)
+        return await decode_column(node)
+
     async def obsm(self, key: str) -> np.ndarray:
         """Embedding array (e.g., X_umap)."""
         obsm = await self._zarr.get_group("obsm")