diff --git a/packages/api/alembic/versions/27f37d3cf8d2_add_dataset_prompt_addendum.py b/packages/api/alembic/versions/27f37d3cf8d2_add_dataset_prompt_addendum.py new file mode 100644 index 0000000..79d65ad --- /dev/null +++ b/packages/api/alembic/versions/27f37d3cf8d2_add_dataset_prompt_addendum.py @@ -0,0 +1,28 @@ +"""add dataset prompt_addendum + +Revision ID: 27f37d3cf8d2 +Revises: 092e2aa153ce +Create Date: 2026-05-21 10:09:30.488887 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '27f37d3cf8d2' +down_revision: Union[str, Sequence[str], None] = '092e2aa153ce' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.add_column('datasets', sa.Column('prompt_addendum', sa.Text(), nullable=True)) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_column('datasets', 'prompt_addendum') diff --git a/packages/api/src/cell_explorer_api/db/models.py b/packages/api/src/cell_explorer_api/db/models.py index 2e31a51..8cd2d3a 100644 --- a/packages/api/src/cell_explorer_api/db/models.py +++ b/packages/api/src/cell_explorer_api/db/models.py @@ -55,6 +55,7 @@ class Dataset(SQLModel, table=True): description: str | None = None is_public: bool = Field(default=False) required_roles: list[str] = Field(default_factory=list, sa_column=Column(JSON)) + prompt_addendum: str | None = Field(default=None) chat_enabled: bool = Field(default=False) created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) diff --git a/packages/api/src/cell_explorer_api/routes/admin.py b/packages/api/src/cell_explorer_api/routes/admin.py index 308503d..bb3cc6e 100644 --- a/packages/api/src/cell_explorer_api/routes/admin.py +++ b/packages/api/src/cell_explorer_api/routes/admin.py @@ -58,6 +58,7 @@ class DatasetCreate(BaseModel): description: str | None = None is_public: bool = False required_roles: list[str] = [] + prompt_addendum: str | None = None chat_enabled: bool = False @@ -67,6 +68,7 @@ class DatasetUpdate(BaseModel): description: str | None = None is_public: bool | None = None required_roles: list[str] | None = None + prompt_addendum: str | None = None chat_enabled: bool | None = None @@ -79,6 +81,7 @@ class DatasetAdminResponse(BaseModel): description: str | None is_public: bool required_roles: list[str] + prompt_addendum: str | None chat_enabled: bool @@ -175,6 +178,7 @@ async def list_datasets_admin( description=dataset.description, is_public=dataset.is_public, required_roles=dataset.required_roles, + prompt_addendum=dataset.prompt_addendum, chat_enabled=dataset.chat_enabled, ) for dataset in result.all() @@ -209,6 +213,7 @@ async def create_dataset( description=dataset.description, is_public=dataset.is_public, required_roles=dataset.required_roles, + prompt_addendum=dataset.prompt_addendum, chat_enabled=dataset.chat_enabled, ) @@ -239,6 +244,7 @@ async def update_dataset( description=dataset.description, is_public=dataset.is_public, required_roles=dataset.required_roles, + prompt_addendum=dataset.prompt_addendum, chat_enabled=dataset.chat_enabled, ) diff --git a/packages/api/src/cell_explorer_api/services/chat_session.py b/packages/api/src/cell_explorer_api/services/chat_session.py index e737157..4be5f58 100644 --- a/packages/api/src/cell_explorer_api/services/chat_session.py +++ b/packages/api/src/cell_explorer_api/services/chat_session.py @@ -32,11 +32,13 @@ # In-process cache of DatasetContext keyed by (slug, updated_at). The Dataset -# row's updated_at bumps on every admin PUT (see routes/admin.py), so this -# self-invalidates without an explicit hook: admin edits change the key, -# subsequent requests miss the cache and rebuild. Process restart clears the -# cache (uvicorn --reload covers dev). Stale entries accumulate on edits but -# the leak is bounded by edit frequency and dataset count. +# row's updated_at bumps unconditionally on every admin PUT (see +# routes/admin.py update_dataset, line ~234), so this self-invalidates for any +# DB-sourced DatasetContext field — including prompt_addendum — without an +# explicit hook: admin edits change the key, subsequent requests miss the +# cache and rebuild. Process restart clears the cache (uvicorn --reload covers +# dev). Stale entries accumulate on edits but the leak is bounded by edit +# frequency and dataset count. # # See issue #101. _dataset_ctx_cache: dict[tuple[str, datetime], DatasetContext] = {} @@ -62,6 +64,7 @@ async def _build_dataset_context_cached( slug=dataset.slug, name=dataset.name, description=dataset.description or "", + prompt_addendum=dataset.prompt_addendum, ) return _dataset_ctx_cache[key] diff --git a/packages/api/tests/services/test_chat_session.py b/packages/api/tests/services/test_chat_session.py index 41772a7..f0f8801 100644 --- a/packages/api/tests/services/test_chat_session.py +++ b/packages/api/tests/services/test_chat_session.py @@ -358,6 +358,43 @@ async def _spy_build(*args, **kwargs): assert call_count == 2, f"build_dataset_context called {call_count}x; expected 2 after updated_at bump" +@pytest.mark.asyncio +async def test_make_chat_agent_forwards_prompt_addendum_to_dataset_context(): + """When the Dataset row has prompt_addendum set, make_chat_agent forwards + it to build_dataset_context (via _build_dataset_context_cached).""" + dataset = _public_dataset() + dataset.prompt_addendum = "Test curator note: cells were sorted on CD45." + datasource = MagicMock(base_url="https://example.com", type="HTTP_TOKEN", credential_ref=None) + db = await _mk_db_session(_make_db_row(dataset, datasource)) + + fake_anndata = MagicMock(n_obs=10, n_vars=20, obsm_keys=[], obs_columns=[]) + + captured: dict = {} + from cell_explorer_agent import build_dataset_context as _real + + async def _spy_build(*args, **kwargs): + captured.update(kwargs) + return await _real(*args, **kwargs) + + with patch("cell_explorer_api.services.chat_session.ZarrStore") as MockZS, \ + patch("cell_explorer_api.services.chat_session.AnnDataStore") as MockADS, \ + patch("cell_explorer_api.services.chat_session.StrataStore") as MockSS, \ + patch("cell_explorer_api.services.chat_session.build_dataset_context", _spy_build): + MockZS.open = AsyncMock(return_value=MagicMock()) + MockADS.open = AsyncMock(return_value=fake_anndata) + MockSS.open = AsyncMock(return_value=MagicMock()) + + user = _FakeUser(roles=[]) + llm = FakeLLMClient(scripts=[]) + settings = MagicMock() + agent = await make_chat_agent( + user=user, dataset_slug="pbmc3k", db=db, settings=settings, llm=llm, + ) + + assert captured.get("prompt_addendum") == "Test curator note: cells were sorted on CD45." + assert agent.dataset_ctx.prompt_addendum == "Test curator note: cells were sorted on CD45." + + @pytest.mark.asyncio async def test_dataset_ctx_cache_is_independent_per_slug(): """Two datasets with different slugs cache independently.""" diff --git a/packages/api/tests/test_admin_routes.py b/packages/api/tests/test_admin_routes.py index 3a6ed28..d34675c 100644 --- a/packages/api/tests/test_admin_routes.py +++ b/packages/api/tests/test_admin_routes.py @@ -302,3 +302,102 @@ def test_admin_create_dataset_chat_enabled_defaults_false(seeded_app): ) assert response.status_code == 201 assert response.json()["chat_enabled"] is False + + +# --- prompt_addendum field --- + + +def test_create_dataset_with_prompt_addendum(seeded_app): + """POST /admin/datasets stores prompt_addendum and surfaces it on response.""" + ds_id = seeded_app.state.test_datasource_id + client = TestClient(seeded_app) + response = client.post( + "/api/admin/datasets", + json={ + "datasource_id": ds_id, + "name": "Test", + "slug": "test-with-addendum", + "path": "test.zarr", + "prompt_addendum": "Important: cells were sorted on CD45 first.", + }, + headers=AUTH_HEADER, + ) + assert response.status_code == 201, response.text + body = response.json() + assert body["prompt_addendum"] == "Important: cells were sorted on CD45 first." + + +def test_create_dataset_prompt_addendum_defaults_null(seeded_app): + """prompt_addendum is null when not supplied on create.""" + ds_id = seeded_app.state.test_datasource_id + client = TestClient(seeded_app) + response = client.post( + "/api/admin/datasets", + json={ + "datasource_id": ds_id, + "name": "No-Addendum-Default", + "slug": "no-addendum-default", + "path": "no-addendum.zarr", + }, + headers=AUTH_HEADER, + ) + assert response.status_code == 201, response.text + assert response.json()["prompt_addendum"] is None + + +def test_update_dataset_prompt_addendum(seeded_app): + """PUT updates only prompt_addendum; other fields unchanged.""" + ds_id = seeded_app.state.test_datasource_id + client = TestClient(seeded_app) + # Create a dataset with no addendum first + create_resp = client.post( + "/api/admin/datasets", + json={ + "datasource_id": ds_id, + "name": "No-Addendum", + "slug": "no-addendum", + "path": "no-addendum.zarr", + }, + headers=AUTH_HEADER, + ) + assert create_resp.status_code == 201 + created = create_resp.json() + assert created["prompt_addendum"] is None + # Update only prompt_addendum + response = client.put( + "/api/admin/datasets/no-addendum", + json={"prompt_addendum": "Curator notes here."}, + headers=AUTH_HEADER, + ) + assert response.status_code == 200, response.text + body = response.json() + assert body["prompt_addendum"] == "Curator notes here." + # Other fields kept their previous values + assert body["name"] == created["name"] + assert body["path"] == created["path"] + + +def test_get_datasets_includes_prompt_addendum(seeded_app): + """GET /admin/datasets surfaces prompt_addendum on each row.""" + ds_id = seeded_app.state.test_datasource_id + client = TestClient(seeded_app) + # Create a dataset with an addendum + create_resp = client.post( + "/api/admin/datasets", + json={ + "datasource_id": ds_id, + "name": "Has-Addendum", + "slug": "has-addendum", + "path": "has-addendum.zarr", + "prompt_addendum": "Sample curator note.", + }, + headers=AUTH_HEADER, + ) + assert create_resp.status_code == 201 + response = client.get("/api/admin/datasets", headers=AUTH_HEADER) + assert response.status_code == 200 + body = response.json() + rows = {d["slug"]: d for d in body["datasets"]} + assert "has-addendum" in rows + assert "prompt_addendum" in rows["has-addendum"] + assert rows["has-addendum"]["prompt_addendum"] == "Sample curator note." diff --git a/packages/cell-explorer-agent/src/cell_explorer_agent/prompt/dataset_context.py b/packages/cell-explorer-agent/src/cell_explorer_agent/prompt/dataset_context.py index ca945e5..0ad12be 100644 --- a/packages/cell-explorer-agent/src/cell_explorer_agent/prompt/dataset_context.py +++ b/packages/cell-explorer-agent/src/cell_explorer_agent/prompt/dataset_context.py @@ -23,10 +23,16 @@ class DatasetContext: n_var: int obs_columns: list[ObsColumnInfo] embedding_keys: list[str] + prompt_addendum: str | None = None async def build_dataset_context( - z: ZarrAccess, *, slug: str, name: str, description: str + z: ZarrAccess, + *, + slug: str, + name: str, + description: str, + prompt_addendum: str | None = None, ) -> DatasetContext: n_obs, n_var = await z.shape() obs = await z.obs_columns() @@ -47,4 +53,5 @@ async def build_dataset_context( for c in obs ], embedding_keys=list(emb), + prompt_addendum=prompt_addendum, ) diff --git a/packages/cell-explorer-agent/src/cell_explorer_agent/prompt/system.py b/packages/cell-explorer-agent/src/cell_explorer_agent/prompt/system.py index bda9f15..d812ea5 100644 --- a/packages/cell-explorer-agent/src/cell_explorer_agent/prompt/system.py +++ b/packages/cell-explorer-agent/src/cell_explorer_agent/prompt/system.py @@ -24,6 +24,12 @@ def build_system_prompt(ctx: DatasetContext) -> str: lines.append(f"Dataset: {ctx.slug} — {ctx.name}") if ctx.description: lines.append(f"Description: {ctx.description}") + if ctx.prompt_addendum and ctx.prompt_addendum.strip(): + lines.append("") + lines.append("=== Curator notes (authoritative dataset context) ===") + lines.append(ctx.prompt_addendum) + lines.append("=== end curator notes ===") + lines.append("") lines.append(f"Shape: {ctx.n_obs} cells × {ctx.n_var} genes.") lines.append("") lines.append("Obs columns:") diff --git a/packages/cell-explorer-agent/tests/test_prompt_system.py b/packages/cell-explorer-agent/tests/test_prompt_system.py index 20f3b57..c4df9cd 100644 --- a/packages/cell-explorer-agent/tests/test_prompt_system.py +++ b/packages/cell-explorer-agent/tests/test_prompt_system.py @@ -1,3 +1,6 @@ +import pytest + + def test_system_prompt_mentions_chart_awareness(): """The agent is told to write concise summaries when a chart is present in the tool result, instead of re-enumerating every row.""" @@ -13,3 +16,60 @@ def test_system_prompt_mentions_chart_awareness(): assert "chart" in prompt.lower() # gene_panel_by_obs is registered in the tool-use policy. assert "gene_panel_by_obs" in prompt + + +def _minimal_ctx(**kwargs): + from cell_explorer_agent.prompt.dataset_context import DatasetContext, ObsColumnInfo + return DatasetContext( + slug="test-ds", + name="Test Dataset", + description="A test description.", + n_obs=1000, + n_var=200, + obs_columns=[ObsColumnInfo(name="cell_type", dtype="categorical", cardinality=5)], + embedding_keys=["X_umap"], + **kwargs, + ) + + +def test_build_system_prompt_includes_curator_notes_when_present(): + from cell_explorer_agent.prompt.system import build_system_prompt + addendum = "There is a subtle but important consideration: cells were sorted on CD45." + ctx = _minimal_ctx(prompt_addendum=addendum) + prompt = build_system_prompt(ctx) + + assert "=== Curator notes (authoritative dataset context) ===" in prompt + assert addendum in prompt + assert "=== end curator notes ===" in prompt + + # Opening fence appears before the shape line. + idx_fence = prompt.index("=== Curator notes (authoritative dataset context) ===") + idx_shape = prompt.index("Shape:") + assert idx_fence < idx_shape + + # Opening fence appears after description. + idx_desc = prompt.index("A test description.") + assert idx_desc < idx_fence + + # The block is paragraph-isolated: a blank line follows the closing + # fence so it doesn't visually bleed into the shape metadata. + assert "=== end curator notes ===\n\nShape:" in prompt + + +def test_build_system_prompt_omits_curator_notes_when_none(): + from cell_explorer_agent.prompt.system import build_system_prompt + ctx = _minimal_ctx(prompt_addendum=None) + prompt = build_system_prompt(ctx) + assert "Curator notes" not in prompt + + # Byte-identical to prompt built from a context without the field set (default None). + ctx2 = _minimal_ctx() + assert prompt == build_system_prompt(ctx2) + + +@pytest.mark.parametrize("addendum", ["", " \n \t "]) +def test_build_system_prompt_omits_curator_notes_when_empty(addendum): + from cell_explorer_agent.prompt.system import build_system_prompt + ctx = _minimal_ctx(prompt_addendum=addendum) + prompt = build_system_prompt(ctx) + assert "Curator notes" not in prompt