diff --git a/marimo/_code_mode/__init__.py b/marimo/_code_mode/__init__.py index b9fafd36ad5..5d954506564 100644 --- a/marimo/_code_mode/__init__.py +++ b/marimo/_code_mode/__init__.py @@ -36,6 +36,7 @@ AsyncCodeModeContext, CellStatusType, NotebookCell, + StaleCellError, get_context, ) @@ -43,5 +44,6 @@ "AsyncCodeModeContext", "CellStatusType", "NotebookCell", + "StaleCellError", "get_context", ] diff --git a/marimo/_code_mode/_context.py b/marimo/_code_mode/_context.py index 767f182e121..a6233ccbbd1 100644 --- a/marimo/_code_mode/_context.py +++ b/marimo/_code_mode/_context.py @@ -189,7 +189,11 @@ def exception(self) -> Exception | None: ... # ------------------------------------------------------------------ -def get_context(*, skip_validation: bool = False) -> AsyncCodeModeContext: +def get_context( + *, + skip_validation: bool = False, + skip_staleness_check: bool = False, +) -> AsyncCodeModeContext: """Return an ``AsyncCodeModeContext`` for the running kernel. Use as an async context manager:: @@ -206,6 +210,12 @@ def get_context(*, skip_validation: bool = False) -> AsyncCodeModeContext: and should almost never be disabled. Only set to True when you intentionally need to insert code that would fail validation (e.g. incomplete stubs the user plans to fix by hand). + skip_staleness_check : bool, default False + When False (the default), ``edit_cell`` raises + :class:`StaleCellError` if the agent tries to overwrite a cell + whose code has changed since the agent last read it (e.g. via + ``ctx.cells[cell_id]``). Set to True to overwrite blindly — + useful when the agent intentionally discards prior content. """ runtime_ctx = _get_runtime_context() if not isinstance(runtime_ctx, KernelRuntimeContext): @@ -215,9 +225,38 @@ def get_context(*, skip_validation: bool = False) -> AsyncCodeModeContext: runtime_ctx._kernel, cell_manager=cell_manager, skip_validation=skip_validation, + skip_staleness_check=skip_staleness_check, ) +class StaleCellError(RuntimeError): + """Raised when ``edit_cell`` targets a cell the agent has not read at + its current version. + + ``stale_cells`` is the full set of cells in the same boat, so the agent + can re-read all of them in one pass. + """ + + def __init__( + self, + cell_id: CellId_t, + stale_cells: frozenset[CellId_t], + ) -> None: + self.cell_id = cell_id + self.stale_cells = stale_cells + others = sorted(stale_cells - {cell_id}) + other_hint = ( + f"\nOther stale cells: {', '.join(others)}." if others else "" + ) + super().__init__( + f"Cell {cell_id!r} was modified since the agent last read it.\n" + f"Read it first (e.g. `ctx.cells[{cell_id!r}].code`) before " + f"editing.{other_hint}\n" + f"To override and overwrite without re-reading, pass " + f"skip_staleness_check=True to cm.get_context()." + ) + + @helpable class NotebookCell: """Read-only view of a single cell with runtime status. @@ -411,6 +450,7 @@ def _doc(self) -> NotebookDocument: def _cell_view(self, cell: _NotebookCell) -> NotebookCell: """Wrap a document cell with runtime state from the graph.""" + self._ctx._note_read(cell.id, cell.version) try: graph = self._ctx.graph impl = graph.cells.get(cell.id) @@ -589,6 +629,7 @@ def __init__( cell_manager: CellManager | None = None, *, skip_validation: bool = False, + skip_staleness_check: bool = False, ) -> None: from marimo._messaging.notebook.document import get_current_document @@ -603,6 +644,7 @@ def __init__( self._document = document self._cell_manager = cell_manager self._skip_validation = skip_validation + self._skip_staleness_check = skip_staleness_check self._ops: list[_Op] = [] # Track cell IDs added during this batch so subsequent ops # can reference them before they exist in the graph. @@ -629,6 +671,9 @@ def _require_entered(self) -> None: "Without 'async with', operations are silently lost." ) + def _note_read(self, cell_id: CellId_t, version: int) -> None: + self._kernel.agent.read_tracker.record_read(cell_id, version) + def __getattr__(self, name: str) -> Any: # Legacy alias: `ctx.install_packages(...)` was the pre-namespace # API. Kept as a hidden shim for in-flight skills / examples; @@ -1076,6 +1121,27 @@ def edit_cell( # Setup is identified by cell_id alone — don't store a name. name = None + # Check after the setup-migration block so an implicit code-fill + # from the graph still trips the read-before-write guard. + if ( + code is not None + and not self._skip_staleness_check + and cell_id not in self._pending_adds + ): + cell = self._document.get(cell_id) + tracker = self._kernel.agent.read_tracker + # Empty cells have no prior content to clobber. The agent's own + # writes record reads at __aexit__, so a follow-up edit in a + # later context passes the check normally. + if ( + cell is not None + and cell.code.strip() + and not tracker.has_read(cell_id, cell.version) + ): + raise StaleCellError( + cell_id, tracker.get_stale_cells(self._document) + ) + # Build config only if any config kwarg was explicitly set. config: CellConfig | None = None if hide_code is not None or disabled is not None or column is not None: @@ -1594,6 +1660,20 @@ async def _apply_ops( NotebookDocumentTransactionNotification(transaction=tx) ) + # The agent wrote these cells in this batch — its effective view is + # the post-write version. Without this, a cell created in call N + # can't be edited in call N+1 without re-reading. + for op in ops: + if isinstance(op, _AddOp): + written_id = op.cell_id + elif isinstance(op, _UpdateOp) and op.code is not None: + written_id = op.new_cell_id or op.cell_id + else: + continue + current = self._document.get_cell_version(written_id) + if current is not None: + self._note_read(written_id, current) + # Run queued cells (explicit run_cell + autorun descendants), # filtered to cells that still exist after structural ops. _run_set = explicit_run or set() diff --git a/marimo/_messaging/notebook/document.py b/marimo/_messaging/notebook/document.py index 947c4d766c3..119a20fad9b 100644 --- a/marimo/_messaging/notebook/document.py +++ b/marimo/_messaging/notebook/document.py @@ -48,12 +48,17 @@ class NotebookCell(msgspec.Struct): - """A single cell in the document. Mutable — owned by the document.""" + """A single cell in the document. Mutable — owned by the document. + + ``version`` increments on each ``SetCode`` that actually changes + ``code``. Other property changes don't bump it. + """ id: CellId_t code: str name: str config: CellConfig + version: int = 0 def __repr__(self) -> str: first_line = self.code.split("\n", 1)[0] @@ -123,6 +128,11 @@ def get(self, cell_id: CellId_t) -> NotebookCell | None: return cell return None + def get_cell_version(self, cell_id: CellId_t) -> int | None: + """Return the cell's version counter, or ``None`` if not found.""" + cell = self.get(cell_id) + return cell.version if cell is not None else None + def __contains__(self, cell_id: object) -> bool: return any(c.id == cell_id for c in self._cells) @@ -203,7 +213,12 @@ def _apply_change(self, change: DocumentChange) -> None: self._cells = reordered elif isinstance(change, SetCode): - self._find_cell(change.cell_id).code = change.code + cell = self._find_cell(change.cell_id) + # No-op SetCode keeps version stable so format-on-save round + # trips don't invalidate the agent's prior read. + if cell.code != change.code: + cell.code = change.code + cell.version += 1 elif isinstance(change, SetName): self._find_cell(change.cell_id).name = change.name @@ -232,7 +247,22 @@ def _replace_cells(self, cells: list[NotebookCell]) -> None: pre-rebuild state — useful for diff comparison. Bumps ``version`` so observers see the state change like they would after ``apply()``. + + Per-cell ``version`` is reconciled against the prior state: an id + whose code matches inherits the prior version (agent reads stay + valid); an id whose code changed gets ``prior + 1`` (invalidates + stale agent reads); brand-new ids keep the version they were + constructed with. """ + prior_by_id = {c.id: c for c in self._cells} + for i, c in enumerate(cells): + prior = prior_by_id.get(c.id) + if prior is None: + continue + if prior.code == c.code: + cells[i] = structs_replace(c, version=prior.version) + else: + cells[i] = structs_replace(c, version=prior.version + 1) self._cells = cells self._version += 1 diff --git a/marimo/_runtime/agent.py b/marimo/_runtime/agent.py new file mode 100644 index 00000000000..eadc9e00d7c --- /dev/null +++ b/marimo/_runtime/agent.py @@ -0,0 +1,46 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Per-kernel agent state.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from marimo._messaging.notebook.document import NotebookDocument + from marimo._types.ids import CellId_t + + +class AgentReadTracker: + """Highest cell version the agent has observed, per cell.""" + + def __init__(self) -> None: + self._read_versions: dict[CellId_t, int] = {} + + def record_read(self, cell_id: CellId_t, version: int) -> None: + prev = self._read_versions.get(cell_id, -1) + if version > prev: + self._read_versions[cell_id] = version + + def has_read(self, cell_id: CellId_t, current_version: int) -> bool: + last = self._read_versions.get(cell_id) + return last is not None and last >= current_version + + def get_stale_cells(self, doc: NotebookDocument) -> frozenset[CellId_t]: + # Empty (or whitespace-only) cells have nothing to clobber, so they + # never count as stale even when the agent hasn't read them. + stale: set[CellId_t] = set() + for cell in doc.cells: + if not cell.code.strip(): + continue + last = self._read_versions.get(cell.id) + if last is None or cell.version > last: + stale.add(cell.id) + return frozenset(stale) + + +@dataclass +class Agent: + """One per ``Kernel`` — long-lived across scratchpad executions.""" + + read_tracker: AgentReadTracker = field(default_factory=AgentReadTracker) diff --git a/marimo/_runtime/runtime.py b/marimo/_runtime/runtime.py index bf0075c6001..49fbf20ad8d 100644 --- a/marimo/_runtime/runtime.py +++ b/marimo/_runtime/runtime.py @@ -120,6 +120,7 @@ from marimo._plugins.ui._core.ui_element import MarimoConvertValueException from marimo._plugins.ui._impl.anywidget.init import WIDGET_COMM_MANAGER from marimo._runtime import dataflow, handlers, marimo_pdb, patches +from marimo._runtime.agent import Agent from marimo._runtime.app_meta import AppMeta from marimo._runtime.commands import ( AppMetadata, @@ -592,6 +593,7 @@ def __init__( sys.path.insert(0, "") self.graph = dataflow.DirectedGraph() + self.agent = Agent() # When autorun on startup is disabled, this holds cells that have # not yet been run; these cells are removed when they or their # descendants are run diff --git a/tests/_code_mode/test_cells_view.py b/tests/_code_mode/test_cells_view.py index 67ea16c55e7..9061fbc0795 100644 --- a/tests/_code_mode/test_cells_view.py +++ b/tests/_code_mode/test_cells_view.py @@ -216,7 +216,11 @@ def _cell(cell_id: str, code: str, name: str = "") -> NotebookCell: def _view(cells: list[NotebookCell]) -> _CellsView: doc = NotebookDocument(cells) - ctx = type("_MockCtx", (), {"_document": doc})() + ctx = type( + "_MockCtx", + (), + {"_document": doc, "_note_read": lambda *_args, **_kwargs: None}, + )() return _CellsView(ctx) diff --git a/tests/_code_mode/test_context.py b/tests/_code_mode/test_context.py index 5cc0a03ef2b..6a4e17908bc 100644 --- a/tests/_code_mode/test_context.py +++ b/tests/_code_mode/test_context.py @@ -47,7 +47,8 @@ def _ctx( cells.extend(extra_doc_cells) doc = NotebookDocument(cells) with notebook_document_context(doc): - ctx = AsyncCodeModeContext(k) + # Staleness check coverage lives in test_staleness.py. + ctx = AsyncCodeModeContext(k, skip_staleness_check=True) # Use a deterministic seed in tests for snapshot stability. ctx._id_generator = CellIdGenerator(seed=7) ctx._id_generator.seen_ids = set(doc.cell_ids) diff --git a/tests/_code_mode/test_context_autosave.py b/tests/_code_mode/test_context_autosave.py index 68a21c6009d..a0379b59eff 100644 --- a/tests/_code_mode/test_context_autosave.py +++ b/tests/_code_mode/test_context_autosave.py @@ -71,7 +71,7 @@ def _ctx(k: Kernel) -> Generator[AsyncCodeModeContext, None, None]: ] doc = NotebookDocument(cells) with notebook_document_context(doc): - ctx = AsyncCodeModeContext(k) + ctx = AsyncCodeModeContext(k, skip_staleness_check=True) ctx._id_generator = CellIdGenerator(seed=7) ctx._id_generator.seen_ids = set(doc.cell_ids) yield ctx diff --git a/tests/_code_mode/test_staleness.py b/tests/_code_mode/test_staleness.py new file mode 100644 index 00000000000..8258a7a3867 --- /dev/null +++ b/tests/_code_mode/test_staleness.py @@ -0,0 +1,345 @@ +# Copyright 2026 Marimo. All rights reserved. +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING + +import pytest + +from marimo._ast.cell_id import CellIdGenerator +from marimo._code_mode._context import ( + AsyncCodeModeContext, + StaleCellError, +) +from marimo._messaging.notebook.changes import SetCode, Transaction +from marimo._messaging.notebook.document import ( + NotebookCell, + NotebookDocument, + notebook_document_context, +) +from marimo._runtime.commands import ExecuteCellCommand +from marimo._types.ids import CellId_t + +if TYPE_CHECKING: + from collections.abc import Generator + + from marimo._runtime.runtime import Kernel + + +@contextmanager +def _ctx( + k: Kernel, + *, + skip_staleness_check: bool = False, +) -> Generator[AsyncCodeModeContext, None, None]: + cells = [ + NotebookCell(id=cid, code=cell.code, name="", config=cell.config) + for cid, cell in k.graph.cells.items() + ] + doc = NotebookDocument(cells) + with notebook_document_context(doc): + ctx = AsyncCodeModeContext( + k, skip_staleness_check=skip_staleness_check + ) + ctx._id_generator = CellIdGenerator(seed=7) + ctx._id_generator.seen_ids = set(doc.cell_ids) + yield ctx + + +async def _seed(k: Kernel, cell_id: str, code: str) -> None: + await k.run([ExecuteCellCommand(cell_id=CellId_t(cell_id), code=code)]) + + +def _bump_cell_version(doc: NotebookDocument, cell_id: str, code: str) -> None: + doc.apply( + Transaction( + changes=(SetCode(cell_id=CellId_t(cell_id), code=code),), + source="frontend", + ) + ) + + +class TestStalenessBlocks: + async def test_edit_without_read_raises(self, k: Kernel) -> None: + await _seed(k, "0", "a = 1") + with _ctx(k) as ctx: + async with ctx as nb: + with pytest.raises(StaleCellError) as exc: + nb.edit_cell("0", "a = 2") + assert exc.value.cell_id == CellId_t("0") + assert CellId_t("0") in exc.value.stale_cells + + async def test_edit_after_read_succeeds(self, k: Kernel) -> None: + await _seed(k, "0", "a = 1") + with _ctx(k) as ctx: + async with ctx as nb: + _ = nb.cells["0"].code # materialize == read + nb.edit_cell("0", "a = 2") + + async def test_read_persists_across_contexts(self, k: Kernel) -> None: + # kernel.agent is long-lived, so a read satisfies a future call's + # check as long as the cell's version hasn't been bumped since. + await _seed(k, "0", "a = 1") + with _ctx(k) as ctx: + async with ctx as nb: + _ = nb.cells["0"].code + + with _ctx(k) as ctx: + async with ctx as nb: + nb.edit_cell("0", "a = 2") + + async def test_frontend_edit_invalidates_prior_read( + self, k: Kernel + ) -> None: + await _seed(k, "0", "a = 1") + with _ctx(k) as ctx: + async with ctx as nb: + _ = nb.cells["0"].code + + with _ctx(k) as ctx: + _bump_cell_version(ctx._document, "0", "a = 99") + async with ctx as nb: + with pytest.raises(StaleCellError): + nb.edit_cell("0", "a = 2") + + +class TestStalenessExemptions: + async def test_pending_add_then_edit(self, k: Kernel) -> None: + with _ctx(k) as ctx: + async with ctx as nb: + cid = nb.create_cell("x = 1") + nb.edit_cell(cid, "x = 2") + + async def test_config_only_edit_does_not_check(self, k: Kernel) -> None: + await _seed(k, "0", "a = 1") + with _ctx(k) as ctx: + async with ctx as nb: + nb.edit_cell("0", hide_code=False) + + async def test_delete_cell_not_protected(self, k: Kernel) -> None: + await _seed(k, "0", "a = 1") + with _ctx(k) as ctx: + async with ctx as nb: + nb.delete_cell("0") + + async def test_move_cell_not_protected(self, k: Kernel) -> None: + await _seed(k, "0", "a = 1") + await _seed(k, "1", "b = 2") + with _ctx(k) as ctx: + async with ctx as nb: + nb.move_cell("0", after="1") + + async def test_skip_staleness_check_opt_out(self, k: Kernel) -> None: + await _seed(k, "0", "a = 1") + with _ctx(k, skip_staleness_check=True) as ctx: + async with ctx as nb: + nb.edit_cell("0", "a = 2") + + +class TestStaleCellErrorMessage: + async def test_lists_other_stale_cells(self, k: Kernel) -> None: + await _seed(k, "0", "a = 1") + await _seed(k, "1", "b = 2") + await _seed(k, "2", "c = 3") + with _ctx(k) as ctx: + async with ctx as nb: + with pytest.raises(StaleCellError) as exc: + nb.edit_cell("0", "a = 999") + assert exc.value.stale_cells == frozenset( + {CellId_t("0"), CellId_t("1"), CellId_t("2")} + ) + msg = str(exc.value) + assert "'0'" in msg + assert "Other stale cells: 1, 2." in msg + assert "skip_staleness_check=True" in msg + + +class TestCrossContextWriteRead: + async def test_create_then_edit_cross_context(self, k: Kernel) -> None: + with _ctx(k) as ctx: + async with ctx as nb: + cid = nb.create_cell("x = 1") + nb.run_cell(cid) + + with _ctx(k) as ctx: + async with ctx as nb: + nb.edit_cell(cid, "x = 2") + + async def test_edit_then_edit_cross_context(self, k: Kernel) -> None: + await _seed(k, "0", "a = 1") + with _ctx(k) as ctx: + async with ctx as nb: + _ = nb.cells["0"].code + nb.edit_cell("0", "a = 2") + + with _ctx(k) as ctx: + async with ctx as nb: + nb.edit_cell("0", "a = 3") + + +class TestFileWatchReload: + async def test_replace_cells_same_code_keeps_reads_valid( + self, k: Kernel + ) -> None: + from marimo._ast.cell import CellConfig + + await _seed(k, "0", "a = 1") + with _ctx(k) as ctx: + async with ctx as nb: + _ = nb.cells["0"].code + + with _ctx(k) as ctx: + ctx._document._replace_cells( + [ + NotebookCell( + id=CellId_t("0"), + code="a = 1", + name="", + config=CellConfig(), + ) + ] + ) + async with ctx as nb: + nb.edit_cell("0", "a = 2") + + async def test_replace_cells_changed_code_invalidates_reads( + self, k: Kernel + ) -> None: + from marimo._ast.cell import CellConfig + + await _seed(k, "0", "a = 1") + with _ctx(k) as ctx: + async with ctx as nb: + _ = nb.cells["0"].code + + with _ctx(k) as ctx: + ctx._document._replace_cells( + [ + NotebookCell( + id=CellId_t("0"), + code="a = 100", + name="", + config=CellConfig(), + ) + ] + ) + async with ctx as nb: + with pytest.raises(StaleCellError): + nb.edit_cell("0", "a = 2") + + +class TestEmptyCellExemption: + async def test_edit_empty_cell_without_read(self, k: Kernel) -> None: + await _seed(k, "0", "") + with _ctx(k) as ctx: + async with ctx as nb: + nb.edit_cell("0", "x = 1") + + async def test_edit_whitespace_cell_without_read(self, k: Kernel) -> None: + await _seed(k, "0", " \n ") + with _ctx(k) as ctx: + async with ctx as nb: + nb.edit_cell("0", "x = 1") + + async def test_non_empty_cell_still_requires_read(self, k: Kernel) -> None: + await _seed(k, "0", "a = 1") + with _ctx(k) as ctx: + async with ctx as nb: + with pytest.raises(StaleCellError): + nb.edit_cell("0", "x = 1") + + async def test_two_edits_on_empty_cell_same_context( + self, k: Kernel + ) -> None: + await _seed(k, "0", "") + with _ctx(k) as ctx: + async with ctx as nb: + nb.edit_cell("0", "x = 1") + nb.edit_cell("0", "x = 2") + + async def test_empty_then_non_empty_edit_cross_context( + self, k: Kernel + ) -> None: + # First edit lands via empty-cell exemption; the agent's own write + # records a read, so the next-context edit passes without a fresh + # materialization even though the cell is now non-empty. + await _seed(k, "0", "") + with _ctx(k) as ctx: + async with ctx as nb: + nb.edit_cell("0", "x = 1") + with _ctx(k) as ctx: + async with ctx as nb: + nb.edit_cell("0", "x = 2") + + async def test_empty_cells_excluded_from_stale_error_list( + self, k: Kernel + ) -> None: + await _seed(k, "empty", "") + await _seed(k, "real", "a = 1") + with _ctx(k) as ctx: + async with ctx as nb: + with pytest.raises(StaleCellError) as exc: + nb.edit_cell("real", "a = 2") + assert exc.value.stale_cells == frozenset({CellId_t("real")}) + + +class TestSetupCellMigration: + async def test_rename_to_setup_without_read_raises( + self, k: Kernel + ) -> None: + # The migration path silently fills `code` from `self.graph.cells` + # when the agent passes `name="setup"` with `code=None`. Without + # the post-migration staleness check, that fill could overwrite + # the doc with stale graph code without ever requiring a read. + await _seed(k, "0", "x = 1") + with _ctx(k) as ctx: + async with ctx as nb: + with pytest.raises(StaleCellError): + nb.edit_cell("0", name="setup") + + async def test_rename_to_setup_after_read_succeeds( + self, k: Kernel + ) -> None: + await _seed(k, "0", "x = 1") + with _ctx(k) as ctx: + async with ctx as nb: + _ = nb.cells["0"].code + nb.edit_cell("0", name="setup") + + +class TestCellsViewReadRecording: + async def test_iteration_records_reads(self, k: Kernel) -> None: + await _seed(k, "0", "a = 1") + await _seed(k, "1", "b = 2") + with _ctx(k) as ctx: + async with ctx as nb: + for _ in nb.cells: + pass + nb.edit_cell("0", "a = 99") + nb.edit_cell("1", "b = 99") + + async def test_find_records_reads_for_matches_only( + self, k: Kernel + ) -> None: + await _seed(k, "0", "alpha = 1") + await _seed(k, "1", "beta = 2") + with _ctx(k) as ctx: + async with ctx as nb: + matches = nb.cells.find("alpha") + assert len(matches) == 1 + nb.edit_cell("0", "alpha = 99") + with pytest.raises(StaleCellError): + nb.edit_cell("1", "beta = 99") + + async def test_grep_records_reads_for_matches_only( + self, k: Kernel + ) -> None: + await _seed(k, "0", "alpha = 1") + await _seed(k, "1", "beta = 2") + with _ctx(k) as ctx: + async with ctx as nb: + matches = nb.cells.grep(r"alpha") + assert len(matches) == 1 + nb.edit_cell("0", "alpha = 99") + with pytest.raises(StaleCellError): + nb.edit_cell("1", "beta = 99") diff --git a/tests/_messaging/notebook/test_document.py b/tests/_messaging/notebook/test_document.py index f3fd14220c3..8a77581ee30 100644 --- a/tests/_messaging/notebook/test_document.py +++ b/tests/_messaging/notebook/test_document.py @@ -259,6 +259,118 @@ def test_not_found(self) -> None: doc.apply(_tx(SetCode(cell_id=CellId_t("missing"), code="x"))) +# ------------------------------------------------------------------ +# Per-cell version +# ------------------------------------------------------------------ + + +class TestCellVersion: + def test_new_cell_starts_at_zero(self) -> None: + doc = _doc() + doc.apply( + _tx( + CreateCell( + cell_id=CellId_t("a"), + code="x", + name="__", + config=CellConfig(), + ) + ) + ) + assert doc.get_cell_version(CellId_t("a")) == 0 + + def test_set_code_bumps_version(self) -> None: + doc = _doc("a") + assert doc.get_cell_version(CellId_t("a")) == 0 + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="x"))) + assert doc.get_cell_version(CellId_t("a")) == 1 + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="y"))) + assert doc.get_cell_version(CellId_t("a")) == 2 + + def test_no_op_set_code_does_not_bump(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="x"))) + assert doc.get_cell_version(CellId_t("a")) == 1 + # Re-applying the same code is a no-op (e.g. format-on-save). + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="x"))) + assert doc.get_cell_version(CellId_t("a")) == 1 + + def test_other_changes_do_not_bump_version(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetName(cell_id=CellId_t("a"), name="my_cell"))) + assert doc.get_cell_version(CellId_t("a")) == 0 + doc.apply( + _tx( + SetConfig( + cell_id=CellId_t("a"), + column=None, + disabled=True, + hide_code=False, + ) + ) + ) + assert doc.get_cell_version(CellId_t("a")) == 0 + + def test_get_cell_version_missing(self) -> None: + doc = _doc("a") + assert doc.get_cell_version(CellId_t("missing")) is None + + def test_rekey_preserves_version(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="x"))) + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="y"))) + assert doc.get_cell_version(CellId_t("a")) == 2 + doc._rekey({CellId_t("a"): CellId_t("x")}) + assert doc.get_cell_version(CellId_t("x")) == 2 + + def test_replace_cells_same_code_preserves_version(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="x"))) + assert doc.get_cell_version(CellId_t("a")) == 1 + doc._replace_cells( + [ + NotebookCell( + id=CellId_t("a"), + code="x", + name="__", + config=CellConfig(), + ) + ] + ) + assert doc.get_cell_version(CellId_t("a")) == 1 + + def test_replace_cells_changed_code_bumps_version(self) -> None: + doc = _doc("a") + doc.apply(_tx(SetCode(cell_id=CellId_t("a"), code="x"))) + assert doc.get_cell_version(CellId_t("a")) == 1 + doc._replace_cells( + [ + NotebookCell( + id=CellId_t("a"), + code="y", + name="__", + config=CellConfig(), + ) + ] + ) + assert doc.get_cell_version(CellId_t("a")) == 2 + + def test_replace_cells_new_id_keeps_constructed_version(self) -> None: + doc = _doc("a") + doc._replace_cells( + [ + NotebookCell( + id=CellId_t("b"), + code="y", + name="__", + config=CellConfig(), + version=42, + ) + ] + ) + assert doc.get_cell_version(CellId_t("b")) == 42 + + # ------------------------------------------------------------------ # SetName # ------------------------------------------------------------------ diff --git a/tests/_runtime/test_agent.py b/tests/_runtime/test_agent.py new file mode 100644 index 00000000000..135f61f617b --- /dev/null +++ b/tests/_runtime/test_agent.py @@ -0,0 +1,92 @@ +# Copyright 2026 Marimo. All rights reserved. +from __future__ import annotations + +from marimo._ast.cell import CellConfig +from marimo._messaging.notebook.changes import SetCode, Transaction +from marimo._messaging.notebook.document import NotebookCell, NotebookDocument +from marimo._runtime.agent import Agent, AgentReadTracker +from marimo._types.ids import CellId_t + + +def _cell( + name: str, *, version: int = 0, code: str | None = None +) -> NotebookCell: + # Default to non-empty code so the cell counts as stale-eligible; tests + # that exercise the empty-cell exemption pass code="" explicitly. + return NotebookCell( + id=CellId_t(name), + code=f"# {name}" if code is None else code, + name="__", + config=CellConfig(), + version=version, + ) + + +def _doc(*cells: NotebookCell) -> NotebookDocument: + return NotebookDocument(list(cells)) + + +class TestAgentReadTracker: + def test_record_and_has_read(self) -> None: + t = AgentReadTracker() + assert not t.has_read(CellId_t("a"), 0) + t.record_read(CellId_t("a"), 0) + assert t.has_read(CellId_t("a"), 0) + assert not t.has_read(CellId_t("a"), 1) + + def test_record_read_max_merges(self) -> None: + t = AgentReadTracker() + t.record_read(CellId_t("a"), 5) + t.record_read(CellId_t("a"), 2) + assert t.has_read(CellId_t("a"), 5) + assert not t.has_read(CellId_t("a"), 6) + + def test_get_stale_cells_never_read(self) -> None: + t = AgentReadTracker() + doc = _doc(_cell("a"), _cell("b")) + assert t.get_stale_cells(doc) == frozenset( + {CellId_t("a"), CellId_t("b")} + ) + + def test_get_stale_cells_bumped_since_read(self) -> None: + t = AgentReadTracker() + doc = _doc(_cell("a", version=0), _cell("b", version=0)) + t.record_read(CellId_t("a"), 0) + t.record_read(CellId_t("b"), 0) + assert t.get_stale_cells(doc) == frozenset() + + doc.apply( + Transaction( + changes=(SetCode(cell_id=CellId_t("a"), code="x"),), + source="frontend", + ) + ) + assert doc.get_cell_version(CellId_t("a")) == 1 + assert t.get_stale_cells(doc) == frozenset({CellId_t("a")}) + + def test_get_stale_cells_ignores_deleted(self) -> None: + t = AgentReadTracker() + doc = _doc(_cell("a")) + t.record_read(CellId_t("ghost"), 7) + assert t.get_stale_cells(doc) == frozenset({CellId_t("a")}) + + def test_get_stale_cells_ignores_empty_cells(self) -> None: + t = AgentReadTracker() + doc = _doc( + _cell("empty", code=""), + _cell("whitespace", code=" \n "), + _cell("real", code="a = 1"), + ) + assert t.get_stale_cells(doc) == frozenset({CellId_t("real")}) + + +class TestAgent: + def test_default_factory_initializes_tracker(self) -> None: + a = Agent() + assert isinstance(a.read_tracker, AgentReadTracker) + + def test_independent_instances(self) -> None: + a1, a2 = Agent(), Agent() + a1.read_tracker.record_read(CellId_t("a"), 1) + assert a1.read_tracker.has_read(CellId_t("a"), 1) + assert not a2.read_tracker.has_read(CellId_t("a"), 1) diff --git a/tests/_server/test_scratchpad_integration.py b/tests/_server/test_scratchpad_integration.py index b948223856c..d9eef38ed9d 100644 --- a/tests/_server/test_scratchpad_integration.py +++ b/tests/_server/test_scratchpad_integration.py @@ -625,6 +625,7 @@ def test_ctx_run_cell_cascade_error(session: _Session) -> None: lines = session.execute( "import marimo._code_mode as cm\n" "async with cm.get_context() as ctx:\n" + ' ctx.cells["cell_a"]\n' ' ctx.edit_cell("cell_a", code="x = 0")\n' ' ctx.run_cell("cell_a")', )