Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""add dataset prompt_addendum

Revision ID: 27f37d3cf8d2
Revises: 092e2aa153ce
Create Date: 2026-05-21 10:09:30.488887

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '27f37d3cf8d2'
down_revision: Union[str, Sequence[str], None] = '092e2aa153ce'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Upgrade schema."""
op.add_column('datasets', sa.Column('prompt_addendum', sa.Text(), nullable=True))


def downgrade() -> None:
"""Downgrade schema."""
op.drop_column('datasets', 'prompt_addendum')
1 change: 1 addition & 0 deletions packages/api/src/cell_explorer_api/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class Dataset(SQLModel, table=True):
description: str | None = None
is_public: bool = Field(default=False)
required_roles: list[str] = Field(default_factory=list, sa_column=Column(JSON))
prompt_addendum: str | None = Field(default=None)
chat_enabled: bool = Field(default=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
Expand Down
6 changes: 6 additions & 0 deletions packages/api/src/cell_explorer_api/routes/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class DatasetCreate(BaseModel):
description: str | None = None
is_public: bool = False
required_roles: list[str] = []
prompt_addendum: str | None = None
chat_enabled: bool = False


Expand All @@ -67,6 +68,7 @@ class DatasetUpdate(BaseModel):
description: str | None = None
is_public: bool | None = None
required_roles: list[str] | None = None
prompt_addendum: str | None = None
chat_enabled: bool | None = None


Expand All @@ -79,6 +81,7 @@ class DatasetAdminResponse(BaseModel):
description: str | None
is_public: bool
required_roles: list[str]
prompt_addendum: str | None
chat_enabled: bool


Expand Down Expand Up @@ -175,6 +178,7 @@ async def list_datasets_admin(
description=dataset.description,
is_public=dataset.is_public,
required_roles=dataset.required_roles,
prompt_addendum=dataset.prompt_addendum,
chat_enabled=dataset.chat_enabled,
)
for dataset in result.all()
Expand Down Expand Up @@ -209,6 +213,7 @@ async def create_dataset(
description=dataset.description,
is_public=dataset.is_public,
required_roles=dataset.required_roles,
prompt_addendum=dataset.prompt_addendum,
chat_enabled=dataset.chat_enabled,
)

Expand Down Expand Up @@ -239,6 +244,7 @@ async def update_dataset(
description=dataset.description,
is_public=dataset.is_public,
required_roles=dataset.required_roles,
prompt_addendum=dataset.prompt_addendum,
chat_enabled=dataset.chat_enabled,
)

Expand Down
13 changes: 8 additions & 5 deletions packages/api/src/cell_explorer_api/services/chat_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@


# In-process cache of DatasetContext keyed by (slug, updated_at). The Dataset
# row's updated_at bumps on every admin PUT (see routes/admin.py), so this
# self-invalidates without an explicit hook: admin edits change the key,
# subsequent requests miss the cache and rebuild. Process restart clears the
# cache (uvicorn --reload covers dev). Stale entries accumulate on edits but
# the leak is bounded by edit frequency and dataset count.
# row's updated_at bumps unconditionally on every admin PUT (see
# routes/admin.py update_dataset, line ~234), so this self-invalidates for any
# DB-sourced DatasetContext field — including prompt_addendum — without an
# explicit hook: admin edits change the key, subsequent requests miss the
# cache and rebuild. Process restart clears the cache (uvicorn --reload covers
# dev). Stale entries accumulate on edits but the leak is bounded by edit
# frequency and dataset count.
#
# See issue #101.
_dataset_ctx_cache: dict[tuple[str, datetime], DatasetContext] = {}
Expand All @@ -62,6 +64,7 @@ async def _build_dataset_context_cached(
slug=dataset.slug,
name=dataset.name,
description=dataset.description or "",
prompt_addendum=dataset.prompt_addendum,
)
return _dataset_ctx_cache[key]

Expand Down
37 changes: 37 additions & 0 deletions packages/api/tests/services/test_chat_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,43 @@ async def _spy_build(*args, **kwargs):
assert call_count == 2, f"build_dataset_context called {call_count}x; expected 2 after updated_at bump"


@pytest.mark.asyncio
async def test_make_chat_agent_forwards_prompt_addendum_to_dataset_context():
"""When the Dataset row has prompt_addendum set, make_chat_agent forwards
it to build_dataset_context (via _build_dataset_context_cached)."""
dataset = _public_dataset()
dataset.prompt_addendum = "Test curator note: cells were sorted on CD45."
datasource = MagicMock(base_url="https://example.com", type="HTTP_TOKEN", credential_ref=None)
db = await _mk_db_session(_make_db_row(dataset, datasource))

fake_anndata = MagicMock(n_obs=10, n_vars=20, obsm_keys=[], obs_columns=[])

captured: dict = {}
from cell_explorer_agent import build_dataset_context as _real

async def _spy_build(*args, **kwargs):
captured.update(kwargs)
return await _real(*args, **kwargs)

with patch("cell_explorer_api.services.chat_session.ZarrStore") as MockZS, \
patch("cell_explorer_api.services.chat_session.AnnDataStore") as MockADS, \
patch("cell_explorer_api.services.chat_session.StrataStore") as MockSS, \
patch("cell_explorer_api.services.chat_session.build_dataset_context", _spy_build):
MockZS.open = AsyncMock(return_value=MagicMock())
MockADS.open = AsyncMock(return_value=fake_anndata)
MockSS.open = AsyncMock(return_value=MagicMock())

user = _FakeUser(roles=[])
llm = FakeLLMClient(scripts=[])
settings = MagicMock()
agent = await make_chat_agent(
user=user, dataset_slug="pbmc3k", db=db, settings=settings, llm=llm,
)

assert captured.get("prompt_addendum") == "Test curator note: cells were sorted on CD45."
assert agent.dataset_ctx.prompt_addendum == "Test curator note: cells were sorted on CD45."


@pytest.mark.asyncio
async def test_dataset_ctx_cache_is_independent_per_slug():
"""Two datasets with different slugs cache independently."""
Expand Down
99 changes: 99 additions & 0 deletions packages/api/tests/test_admin_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,102 @@ def test_admin_create_dataset_chat_enabled_defaults_false(seeded_app):
)
assert response.status_code == 201
assert response.json()["chat_enabled"] is False


# --- prompt_addendum field ---


def test_create_dataset_with_prompt_addendum(seeded_app):
"""POST /admin/datasets stores prompt_addendum and surfaces it on response."""
ds_id = seeded_app.state.test_datasource_id
client = TestClient(seeded_app)
response = client.post(
"/api/admin/datasets",
json={
"datasource_id": ds_id,
"name": "Test",
"slug": "test-with-addendum",
"path": "test.zarr",
"prompt_addendum": "Important: cells were sorted on CD45 first.",
},
headers=AUTH_HEADER,
)
assert response.status_code == 201, response.text
body = response.json()
assert body["prompt_addendum"] == "Important: cells were sorted on CD45 first."


def test_create_dataset_prompt_addendum_defaults_null(seeded_app):
"""prompt_addendum is null when not supplied on create."""
ds_id = seeded_app.state.test_datasource_id
client = TestClient(seeded_app)
response = client.post(
"/api/admin/datasets",
json={
"datasource_id": ds_id,
"name": "No-Addendum-Default",
"slug": "no-addendum-default",
"path": "no-addendum.zarr",
},
headers=AUTH_HEADER,
)
assert response.status_code == 201, response.text
assert response.json()["prompt_addendum"] is None


def test_update_dataset_prompt_addendum(seeded_app):
"""PUT updates only prompt_addendum; other fields unchanged."""
ds_id = seeded_app.state.test_datasource_id
client = TestClient(seeded_app)
# Create a dataset with no addendum first
create_resp = client.post(
"/api/admin/datasets",
json={
"datasource_id": ds_id,
"name": "No-Addendum",
"slug": "no-addendum",
"path": "no-addendum.zarr",
},
headers=AUTH_HEADER,
)
assert create_resp.status_code == 201
created = create_resp.json()
assert created["prompt_addendum"] is None
# Update only prompt_addendum
response = client.put(
"/api/admin/datasets/no-addendum",
json={"prompt_addendum": "Curator notes here."},
headers=AUTH_HEADER,
)
assert response.status_code == 200, response.text
body = response.json()
assert body["prompt_addendum"] == "Curator notes here."
# Other fields kept their previous values
assert body["name"] == created["name"]
assert body["path"] == created["path"]


def test_get_datasets_includes_prompt_addendum(seeded_app):
"""GET /admin/datasets surfaces prompt_addendum on each row."""
ds_id = seeded_app.state.test_datasource_id
client = TestClient(seeded_app)
# Create a dataset with an addendum
create_resp = client.post(
"/api/admin/datasets",
json={
"datasource_id": ds_id,
"name": "Has-Addendum",
"slug": "has-addendum",
"path": "has-addendum.zarr",
"prompt_addendum": "Sample curator note.",
},
headers=AUTH_HEADER,
)
assert create_resp.status_code == 201
response = client.get("/api/admin/datasets", headers=AUTH_HEADER)
assert response.status_code == 200
body = response.json()
rows = {d["slug"]: d for d in body["datasets"]}
assert "has-addendum" in rows
assert "prompt_addendum" in rows["has-addendum"]
assert rows["has-addendum"]["prompt_addendum"] == "Sample curator note."
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,16 @@ class DatasetContext:
n_var: int
obs_columns: list[ObsColumnInfo]
embedding_keys: list[str]
prompt_addendum: str | None = None


async def build_dataset_context(
z: ZarrAccess, *, slug: str, name: str, description: str
z: ZarrAccess,
*,
slug: str,
name: str,
description: str,
prompt_addendum: str | None = None,
) -> DatasetContext:
n_obs, n_var = await z.shape()
obs = await z.obs_columns()
Expand All @@ -47,4 +53,5 @@ async def build_dataset_context(
for c in obs
],
embedding_keys=list(emb),
prompt_addendum=prompt_addendum,
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def build_system_prompt(ctx: DatasetContext) -> str:
lines.append(f"Dataset: {ctx.slug} — {ctx.name}")
if ctx.description:
lines.append(f"Description: {ctx.description}")
if ctx.prompt_addendum and ctx.prompt_addendum.strip():
lines.append("")
lines.append("=== Curator notes (authoritative dataset context) ===")
lines.append(ctx.prompt_addendum)
lines.append("=== end curator notes ===")
lines.append("")
lines.append(f"Shape: {ctx.n_obs} cells × {ctx.n_var} genes.")
lines.append("")
lines.append("Obs columns:")
Expand Down
60 changes: 60 additions & 0 deletions packages/cell-explorer-agent/tests/test_prompt_system.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest


def test_system_prompt_mentions_chart_awareness():
"""The agent is told to write concise summaries when a chart is present
in the tool result, instead of re-enumerating every row."""
Expand All @@ -13,3 +16,60 @@ def test_system_prompt_mentions_chart_awareness():
assert "chart" in prompt.lower()
# gene_panel_by_obs is registered in the tool-use policy.
assert "gene_panel_by_obs" in prompt


def _minimal_ctx(**kwargs):
from cell_explorer_agent.prompt.dataset_context import DatasetContext, ObsColumnInfo
return DatasetContext(
slug="test-ds",
name="Test Dataset",
description="A test description.",
n_obs=1000,
n_var=200,
obs_columns=[ObsColumnInfo(name="cell_type", dtype="categorical", cardinality=5)],
embedding_keys=["X_umap"],
**kwargs,
)


def test_build_system_prompt_includes_curator_notes_when_present():
from cell_explorer_agent.prompt.system import build_system_prompt
addendum = "There is a subtle but important consideration: cells were sorted on CD45."
ctx = _minimal_ctx(prompt_addendum=addendum)
prompt = build_system_prompt(ctx)

assert "=== Curator notes (authoritative dataset context) ===" in prompt
assert addendum in prompt
assert "=== end curator notes ===" in prompt

# Opening fence appears before the shape line.
idx_fence = prompt.index("=== Curator notes (authoritative dataset context) ===")
idx_shape = prompt.index("Shape:")
assert idx_fence < idx_shape

# Opening fence appears after description.
idx_desc = prompt.index("A test description.")
assert idx_desc < idx_fence

# The block is paragraph-isolated: a blank line follows the closing
# fence so it doesn't visually bleed into the shape metadata.
assert "=== end curator notes ===\n\nShape:" in prompt


def test_build_system_prompt_omits_curator_notes_when_none():
from cell_explorer_agent.prompt.system import build_system_prompt
ctx = _minimal_ctx(prompt_addendum=None)
prompt = build_system_prompt(ctx)
assert "Curator notes" not in prompt

# Byte-identical to prompt built from a context without the field set (default None).
ctx2 = _minimal_ctx()
assert prompt == build_system_prompt(ctx2)


@pytest.mark.parametrize("addendum", ["", " \n \t "])
def test_build_system_prompt_omits_curator_notes_when_empty(addendum):
from cell_explorer_agent.prompt.system import build_system_prompt
ctx = _minimal_ctx(prompt_addendum=addendum)
prompt = build_system_prompt(ctx)
assert "Curator notes" not in prompt
Loading