Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions marimo/_code_mode/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@
AsyncCodeModeContext,
CellStatusType,
NotebookCell,
StaleCellError,
get_context,
)

__all__ = [
"AsyncCodeModeContext",
"CellStatusType",
"NotebookCell",
"StaleCellError",
"get_context",
]
82 changes: 81 additions & 1 deletion marimo/_code_mode/_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,11 @@ def exception(self) -> Exception | None: ...
# ------------------------------------------------------------------


def get_context(*, skip_validation: bool = False) -> AsyncCodeModeContext:
def get_context(
*,
skip_validation: bool = False,
skip_staleness_check: bool = False,
) -> AsyncCodeModeContext:
"""Return an ``AsyncCodeModeContext`` for the running kernel.

Use as an async context manager::
Expand All @@ -206,6 +210,12 @@ def get_context(*, skip_validation: bool = False) -> AsyncCodeModeContext:
and should almost never be disabled. Only set to True when you
intentionally need to insert code that would fail validation
(e.g. incomplete stubs the user plans to fix by hand).
skip_staleness_check : bool, default False
When False (the default), ``edit_cell`` raises
:class:`StaleCellError` if the agent tries to overwrite a cell
whose code has changed since the agent last read it (e.g. via
``ctx.cells[cell_id]``). Set to True to overwrite blindly —
useful when the agent intentionally discards prior content.
"""
runtime_ctx = _get_runtime_context()
if not isinstance(runtime_ctx, KernelRuntimeContext):
Expand All @@ -215,9 +225,38 @@ def get_context(*, skip_validation: bool = False) -> AsyncCodeModeContext:
runtime_ctx._kernel,
cell_manager=cell_manager,
skip_validation=skip_validation,
skip_staleness_check=skip_staleness_check,
)


class StaleCellError(RuntimeError):
    """Raised when ``edit_cell`` targets a cell the agent has not read at
    its current version.

    ``stale_cells`` is the full set of cells in the same boat, so the agent
    can re-read all of them in one pass.

    Attributes:
        cell_id: The cell whose edit was rejected.
        stale_cells: Every stale cell id passed by the caller; the rejected
            cell is filtered out of the "Other stale cells" hint only.
    """

    def __init__(
        self,
        cell_id: CellId_t,
        stale_cells: frozenset[CellId_t],
    ) -> None:
        """Store both arguments and build the human-readable message.

        Args:
            cell_id: Id of the cell the caller tried to overwrite.
            stale_cells: Ids of all cells currently considered stale.
        """
        self.cell_id = cell_id
        self.stale_cells = stale_cells
        # List the *other* stale cells in a stable (sorted) order; the
        # rejected cell is already named on the first line of the message.
        others = sorted(stale_cells - {cell_id})
        other_hint = (
            f"\nOther stale cells: {', '.join(others)}." if others else ""
        )
        super().__init__(
            f"Cell {cell_id!r} was modified since the agent last read it.\n"
            f"Read it first (e.g. `ctx.cells[{cell_id!r}].code`) before "
            f"editing.{other_hint}\n"
            f"To override and overwrite without re-reading, pass "
            f"skip_staleness_check=True to cm.get_context()."
        )

    def __reduce__(
        self,
    ) -> "tuple[type[StaleCellError], tuple[CellId_t, frozenset[CellId_t]]]":
        """Rebuild from the constructor arguments when pickled or copied.

        The inherited implementation re-invokes the class with ``self.args``,
        which here holds only the formatted message; that call would fail
        because ``__init__`` requires two arguments.
        """
        return (type(self), (self.cell_id, self.stale_cells))


@helpable
class NotebookCell:
"""Read-only view of a single cell with runtime status.
Expand Down Expand Up @@ -411,6 +450,7 @@ def _doc(self) -> NotebookDocument:

def _cell_view(self, cell: _NotebookCell) -> NotebookCell:
"""Wrap a document cell with runtime state from the graph."""
self._ctx._note_read(cell.id, cell.version)
try:
graph = self._ctx.graph
impl = graph.cells.get(cell.id)
Expand Down Expand Up @@ -589,6 +629,7 @@ def __init__(
cell_manager: CellManager | None = None,
*,
skip_validation: bool = False,
skip_staleness_check: bool = False,
) -> None:
from marimo._messaging.notebook.document import get_current_document

Expand All @@ -603,6 +644,7 @@ def __init__(
self._document = document
self._cell_manager = cell_manager
self._skip_validation = skip_validation
self._skip_staleness_check = skip_staleness_check
self._ops: list[_Op] = []
# Track cell IDs added during this batch so subsequent ops
# can reference them before they exist in the graph.
Expand All @@ -629,6 +671,9 @@ def _require_entered(self) -> None:
"Without 'async with', operations are silently lost."
)

def _note_read(self, cell_id: CellId_t, version: int) -> None:
    """Record that the agent has observed ``cell_id`` at ``version``.

    Forwards to ``record_read`` on the read tracker held by the kernel's
    ``agent`` object.
    """
    self._kernel.agent.read_tracker.record_read(cell_id, version)

def __getattr__(self, name: str) -> Any:
# Legacy alias: `ctx.install_packages(...)` was the pre-namespace
# API. Kept as a hidden shim for in-flight skills / examples;
Expand Down Expand Up @@ -1076,6 +1121,27 @@ def edit_cell(
# Setup is identified by cell_id alone — don't store a name.
name = None

# Check after the setup-migration block so an implicit code-fill
# from the graph still trips the read-before-write guard.
if (
code is not None
and not self._skip_staleness_check
and cell_id not in self._pending_adds
):
cell = self._document.get(cell_id)
tracker = self._kernel.agent.read_tracker
# Empty cells have no prior content to clobber. The agent's own
# writes record reads at __aexit__, so a follow-up edit in a
# later context passes the check normally.
if (
cell is not None
and cell.code.strip()
and not tracker.has_read(cell_id, cell.version)
):
raise StaleCellError(
cell_id, tracker.get_stale_cells(self._document)
)

# Build config only if any config kwarg was explicitly set.
config: CellConfig | None = None
if hide_code is not None or disabled is not None or column is not None:
Expand Down Expand Up @@ -1594,6 +1660,20 @@ async def _apply_ops(
NotebookDocumentTransactionNotification(transaction=tx)
)

# The agent wrote these cells in this batch — its effective view is
# the post-write version. Without this, a cell created in call N
# can't be edited in call N+1 without re-reading.
for op in ops:
if isinstance(op, _AddOp):
written_id = op.cell_id
elif isinstance(op, _UpdateOp) and op.code is not None:
written_id = op.new_cell_id or op.cell_id
else:
continue
current = self._document.get_cell_version(written_id)
if current is not None:
self._note_read(written_id, current)

# Run queued cells (explicit run_cell + autorun descendants),
# filtered to cells that still exist after structural ops.
_run_set = explicit_run or set()
Expand Down
34 changes: 32 additions & 2 deletions marimo/_messaging/notebook/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,17 @@


class NotebookCell(msgspec.Struct):
"""A single cell in the document. Mutable — owned by the document."""
"""A single cell in the document. Mutable — owned by the document.

``version`` increments on each ``SetCode`` that actually changes
``code``. Other property changes don't bump it.
"""

id: CellId_t
code: str
name: str
config: CellConfig
version: int = 0

def __repr__(self) -> str:
first_line = self.code.split("\n", 1)[0]
Expand Down Expand Up @@ -123,6 +128,11 @@ def get(self, cell_id: CellId_t) -> NotebookCell | None:
return cell
return None

def get_cell_version(self, cell_id: CellId_t) -> int | None:
    """Look up ``cell_id`` and report its ``version``; ``None`` if absent."""
    found = self.get(cell_id)
    if found is None:
        return None
    return found.version

def __contains__(self, cell_id: object) -> bool:
return any(c.id == cell_id for c in self._cells)

Expand Down Expand Up @@ -203,7 +213,12 @@ def _apply_change(self, change: DocumentChange) -> None:
self._cells = reordered

elif isinstance(change, SetCode):
self._find_cell(change.cell_id).code = change.code
cell = self._find_cell(change.cell_id)
# No-op SetCode keeps version stable so format-on-save round
# trips don't invalidate the agent's prior read.
if cell.code != change.code:
cell.code = change.code
cell.version += 1

elif isinstance(change, SetName):
self._find_cell(change.cell_id).name = change.name
Expand Down Expand Up @@ -232,7 +247,22 @@ def _replace_cells(self, cells: list[NotebookCell]) -> None:
pre-rebuild state — useful for diff comparison. Bumps
``version`` so observers see the state change like they would
after ``apply()``.

Per-cell ``version`` is reconciled against the prior state: an id
whose code matches inherits the prior version (agent reads stay
valid); an id whose code changed gets ``prior + 1`` (invalidates
stale agent reads); brand-new ids keep the version they were
constructed with.
"""
prior_by_id = {c.id: c for c in self._cells}
for i, c in enumerate(cells):
prior = prior_by_id.get(c.id)
if prior is None:
continue
if prior.code == c.code:
cells[i] = structs_replace(c, version=prior.version)
else:
cells[i] = structs_replace(c, version=prior.version + 1)
self._cells = cells
self._version += 1

Expand Down
46 changes: 46 additions & 0 deletions marimo/_runtime/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright 2026 Marimo. All rights reserved.
"""Per-kernel agent state."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from marimo._messaging.notebook.document import NotebookDocument
from marimo._types.ids import CellId_t


class AgentReadTracker:
    """Per-cell record of the newest cell version the agent has seen."""

    def __init__(self) -> None:
        # cell id -> highest version passed to ``record_read`` so far.
        self._read_versions: dict[CellId_t, int] = {}

    def record_read(self, cell_id: CellId_t, version: int) -> None:
        """Store ``version`` for ``cell_id`` only if it beats the stored one."""
        seen = self._read_versions
        # Unrecorded cells compare as -1, so a first read at version 0 sticks.
        if seen.get(cell_id, -1) < version:
            seen[cell_id] = version

    def has_read(self, cell_id: CellId_t, current_version: int) -> bool:
        """Return True when a recorded read is at or past ``current_version``."""
        recorded = self._read_versions.get(cell_id)
        if recorded is None:
            return False
        return recorded >= current_version

    def get_stale_cells(self, doc: NotebookDocument) -> frozenset[CellId_t]:
        """Collect ids of non-blank cells in ``doc`` lacking an up-to-date read.

        Cells whose code is empty or whitespace-only are skipped: there is
        nothing in them to clobber, so they are never reported as stale.
        """
        return frozenset(
            cell.id
            for cell in doc.cells
            if cell.code.strip()
            and not self.has_read(cell.id, cell.version)
        )


@dataclass
class Agent:
    """One per ``Kernel`` — long-lived across scratchpad executions."""

    # Tracks which cell versions the agent has read; ``default_factory``
    # gives every ``Agent`` instance its own fresh tracker.
    read_tracker: AgentReadTracker = field(default_factory=AgentReadTracker)
2 changes: 2 additions & 0 deletions marimo/_runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
from marimo._plugins.ui._core.ui_element import MarimoConvertValueException
from marimo._plugins.ui._impl.anywidget.init import WIDGET_COMM_MANAGER
from marimo._runtime import dataflow, handlers, marimo_pdb, patches
from marimo._runtime.agent import Agent
from marimo._runtime.app_meta import AppMeta
from marimo._runtime.commands import (
AppMetadata,
Expand Down Expand Up @@ -592,6 +593,7 @@ def __init__(
sys.path.insert(0, "")

self.graph = dataflow.DirectedGraph()
self.agent = Agent()
# When autorun on startup is disabled, this holds cells that have
# not yet been run; these cells are removed when they or their
# descendants are run
Expand Down
6 changes: 5 additions & 1 deletion tests/_code_mode/test_cells_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,11 @@ def _cell(cell_id: str, code: str, name: str = "") -> NotebookCell:

def _view(cells: list[NotebookCell]) -> _CellsView:
    """Build a ``_CellsView`` over ``cells`` backed by a minimal stub context.

    The stub exposes only ``_document`` and a no-op ``_note_read`` that
    accepts any arguments and returns ``None``.
    """
    doc = NotebookDocument(cells)
    # A single stub is built; the earlier ``_document``-only stub was
    # immediately overwritten and has been dropped as a dead store.
    ctx = type(
        "_MockCtx",
        (),
        {"_document": doc, "_note_read": lambda *_args, **_kwargs: None},
    )()
    return _CellsView(ctx)


Expand Down
3 changes: 2 additions & 1 deletion tests/_code_mode/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def _ctx(
cells.extend(extra_doc_cells)
doc = NotebookDocument(cells)
with notebook_document_context(doc):
ctx = AsyncCodeModeContext(k)
# Staleness check coverage lives in test_staleness.py.
ctx = AsyncCodeModeContext(k, skip_staleness_check=True)
# Use a deterministic seed in tests for snapshot stability.
ctx._id_generator = CellIdGenerator(seed=7)
ctx._id_generator.seen_ids = set(doc.cell_ids)
Expand Down
2 changes: 1 addition & 1 deletion tests/_code_mode/test_context_autosave.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _ctx(k: Kernel) -> Generator[AsyncCodeModeContext, None, None]:
]
doc = NotebookDocument(cells)
with notebook_document_context(doc):
ctx = AsyncCodeModeContext(k)
ctx = AsyncCodeModeContext(k, skip_staleness_check=True)
ctx._id_generator = CellIdGenerator(seed=7)
ctx._id_generator.seen_ids = set(doc.cell_ids)
yield ctx
Expand Down
Loading
Loading