Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/api/src/cell_explorer_api/services/zarr_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ async def obs_column(self, name: str) -> ObsColumn:
raw = await self.store.obs_column(name)
return _decode_col_to_obs_column(name, raw)

async def var_column(self, name: str) -> ObsColumn:
raw = await self.store.var_column(name)
return _decode_col_to_obs_column(name, raw)

async def gene_index(self, gene: str) -> int:
await self._ensure_gene_map()
assert self._gene_to_row_cache is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
describe_obs_column_tool,
)
from cell_explorer_agent.tools.data.schema import get_dataset_schema_tool
from cell_explorer_agent.tools.data.var import describe_var_column_tool
from cell_explorer_agent.tools.registry import Tool, ToolCatalog, ToolKind
from cell_explorer_agent.tools.ui_action.color import (
clear_color_by_tool,
Expand Down Expand Up @@ -69,6 +70,7 @@ def build_v1_catalog(
# data
cat.register(get_dataset_schema_tool(z, limit_bytes=lim))
cat.register(describe_obs_column_tool(z, limit_bytes=lim))
cat.register(describe_var_column_tool(z, limit_bytes=lim))
cat.register(cluster_stats_tool(z, limit_bytes=lim))
cat.register(search_genes_tool(z, limit_bytes=lim))
cat.register(gene_expression_summary_tool(z, limit_bytes=lim))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def get_dataset_schema_tool(z: ZarrAccess, *, limit_bytes: int) -> Tool:
async def run() -> dict[str, Any]:
n_obs, n_var = await z.shape()
cols = await z.obs_columns()
var_cols = await z.var_columns()
obsm = await z.obsm_keys()
payload = {
"n_obs": n_obs,
Expand All @@ -20,6 +21,7 @@ async def run() -> dict[str, Any]:
{"name": c.name, "dtype": c.dtype, "cardinality": c.cardinality}
for c in cols
],
"var_columns": list(var_cols),
"embeddings": list(obsm),
}
return cap_json_bytes(payload, limit_bytes=limit_bytes)
Expand All @@ -29,7 +31,8 @@ async def run() -> dict[str, Any]:
kind="data",
description=(
"Return dataset shape, obs column names/dtypes/cardinalities, "
"and available embedding keys. Call this first to learn the schema."
"var (per-gene) column names, and available embedding keys. Call "
"this first to learn the schema."
),
args_schema={"type": "object", "properties": {}, "additionalProperties": False},
func=run,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""describe_var_column tool — mirror of describe_obs_column but for var columns.

Var columns are the per-gene metadata columns in the AnnData var dataframe.
Typical examples: `gene_symbol`, `feature_id`, `highly_variable`,
`mean_counts`, `dispersions_norm`. They have the same categorical/numeric
shape as obs columns, so the tool output mirrors describe_obs_column exactly.
"""

from collections import Counter
from typing import Any

import numpy as np

from cell_explorer_agent.tools.caps import cap_json_bytes
from cell_explorer_agent.tools.registry import Tool
from cell_explorer_agent.tools.zarr_protocol import ZarrAccess

TOP_N = 50


def describe_var_column_tool(z: ZarrAccess, *, limit_bytes: int) -> Tool:
async def run(name: str) -> dict[str, Any]:
try:
col = await z.var_column(name)
except KeyError:
return {"error": f"var column {name!r} not found"}

total = int(len(col.values))
if col.dtype == "categorical":
assert col.categories is not None
counts = Counter(col.values.tolist())
items = counts.most_common(TOP_N)
other = total - sum(c for _, c in items)
return cap_json_bytes(
{
"dtype": "categorical",
"total": total,
"top_categories": [
{"value": col.categories[code], "count": int(count)}
for code, count in items
],
"other_count": int(other),
},
limit_bytes=limit_bytes,
)

vals = np.asarray(col.values, dtype="float64")
return cap_json_bytes(
{
"dtype": "numeric",
"total": total,
"stats": {
"min": float(np.min(vals)),
"max": float(np.max(vals)),
"mean": float(np.mean(vals)),
"median": float(np.median(vals)),
"q1": float(np.quantile(vals, 0.25)),
"q3": float(np.quantile(vals, 0.75)),
"stddev": float(np.std(vals)),
},
},
limit_bytes=limit_bytes,
)

return Tool(
name="describe_var_column",
kind="data",
description=(
"Describe one var (per-gene) column. Use to inspect gene-metadata "
"columns like 'gene_symbol' or 'feature_id'. For categorical "
"columns, returns the top 50 values with counts plus the remainder. "
"For numeric columns, returns min/max/mean/median/quartiles/stddev. "
"Call get_dataset_schema first to learn which var columns are "
"available."
),
args_schema={
"type": "object",
"properties": {"name": {"type": "string"}},
"required": ["name"],
"additionalProperties": False,
},
func=run,
)
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ async def var_columns(self) -> list[str]:
Distinct from var_names which is the gene-name index.
"""
...

async def var_column(self, name: str) -> ObsColumn:
"""Single var column — efficient for targeted queries.

Returns the same ObsColumn shape as obs_column (categorical/numeric/string
with values + categories). The type name reflects historical naming;
the structure is axis-agnostic.
"""
...
async def obsm_keys(self) -> list[str]: ...
async def gene_index(self, gene: str) -> int: ...
async def gene_column(self, gene: str) -> np.ndarray: ...
Expand Down
9 changes: 9 additions & 0 deletions packages/cell-explorer-agent/tests/fakes/fake_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ class FakeZarrAccess:
obs: dict[str, ObsColumn] = field(default_factory=dict)
var: list[str] = field(default_factory=list)
var_columns_data: list[str] = field(default_factory=lambda: ["feature_id", "gene_symbol"])
# Per-name var column payloads, keyed by var_columns_data name. Tests that
# exercise describe_var_column populate this; defaults left empty so older
# tests stay unaffected.
var_data: dict[str, ObsColumn] = field(default_factory=dict)
obsm: list[str] = field(default_factory=lambda: ["X_umap", "X_pca"])
expression: dict[str, np.ndarray] = field(default_factory=dict)
_attrs: dict = field(
Expand Down Expand Up @@ -95,6 +99,11 @@ async def var_names(self) -> list[str]:
async def var_columns(self) -> list[str]:
return list(self.var_columns_data)

async def var_column(self, name: str) -> ObsColumn:
if name not in self.var_data:
raise KeyError(name)
return self.var_data[name]

async def obsm_keys(self) -> list[str]:
return list(self.obsm)

Expand Down
1 change: 1 addition & 0 deletions packages/cell-explorer-agent/tests/test_tool_catalog_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ async def test_v1_catalog_includes_all_tools(fake_zarr):
# data
"get_dataset_schema",
"describe_obs_column",
"describe_var_column",
"cluster_stats",
"search_genes",
"gene_expression_summary",
Expand Down
60 changes: 60 additions & 0 deletions packages/cell-explorer-agent/tests/test_tools_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ async def test_get_dataset_schema(fake_zarr):
assert cell_type["cardinality"] == 3
assert "X_umap" in result["embeddings"]
assert result["var_count"] == 50
# var_columns lists names from the var dataframe (e.g. gene_symbol, feature_id)
assert "var_columns" in result
assert "feature_id" in result["var_columns"]
assert "gene_symbol" in result["var_columns"]


async def test_get_dataset_schema_is_data_kind(fake_zarr):
Expand Down Expand Up @@ -905,3 +909,59 @@ async def test_xscan_group_sums_zero_floored_backward_compat():

np.testing.assert_allclose(sum_x[0], 3.0, atol=1e-6)
assert nnz[0] == 2, f"expected 2, got {nnz[0]}" # matches count_nonzero for zero-floored data


# ---------------------------------------------------------------------------
# describe_var_column tests — mirror the obs side
# ---------------------------------------------------------------------------

from cell_explorer_agent.tools.data.var import describe_var_column_tool
from cell_explorer_agent.tools.zarr_protocol import ObsColumn


async def test_describe_var_column_categorical(fake_zarr):
import numpy as np

# Seed the fake with a categorical var column (e.g., feature_type with two values)
n_var = 50
cats = ["protein_coding", "lncRNA"]
codes = np.array([0] * 40 + [1] * 10, dtype=np.int32)
fake_zarr.var_data["feature_type"] = ObsColumn(
name="feature_type",
dtype="categorical",
values=codes,
categories=list(cats),
)
fake_zarr.var_columns_data = ["feature_id", "gene_symbol", "feature_type"]

tool = describe_var_column_tool(fake_zarr, limit_bytes=32_768)
result = await tool.func(name="feature_type")
assert result["dtype"] == "categorical"
assert result["total"] == n_var
names = {c["value"] for c in result["top_categories"]}
assert names == {"protein_coding", "lncRNA"}


async def test_describe_var_column_numeric(fake_zarr):
import numpy as np

n_var = 50
fake_zarr.var_data["mean_counts"] = ObsColumn(
name="mean_counts",
dtype="numeric",
values=np.linspace(0.1, 10.0, n_var, dtype=np.float32),
)

tool = describe_var_column_tool(fake_zarr, limit_bytes=32_768)
result = await tool.func(name="mean_counts")
assert result["dtype"] == "numeric"
assert result["total"] == n_var
assert set(result["stats"]) == {"min", "max", "mean", "median", "q1", "q3", "stddev"}
assert result["stats"]["min"] < result["stats"]["max"]


async def test_describe_var_column_unknown_returns_error(fake_zarr):
tool = describe_var_column_tool(fake_zarr, limit_bytes=32_768)
result = await tool.func(name="nonexistent_column")
assert "error" in result
assert "nonexistent_column" in result["error"]
6 changes: 6 additions & 0 deletions packages/zarr-access/src/zarr_access/anndata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ async def obs_column(self, name: str):
node = await obs.getitem(name)
return await decode_column(node)

async def var_column(self, name: str):
"""Single var column — efficient for targeted queries."""
var = await self._zarr.get_group("var")
node = await var.getitem(name)
return await decode_column(node)

async def obsm(self, key: str) -> np.ndarray:
"""Embedding array (e.g., X_umap)."""
obsm = await self._zarr.get_group("obsm")
Expand Down
Loading