-
Notifications
You must be signed in to change notification settings - Fork 1.1k
feat(code-mode): read-before-write protection for cell edits #9585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -189,7 +189,11 @@ def exception(self) -> Exception | None: ... | |||||||||||
| # ------------------------------------------------------------------ | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def get_context(*, skip_validation: bool = False) -> AsyncCodeModeContext: | ||||||||||||
| def get_context( | ||||||||||||
| *, | ||||||||||||
| skip_validation: bool = False, | ||||||||||||
| skip_staleness_check: bool = False, | ||||||||||||
| ) -> AsyncCodeModeContext: | ||||||||||||
| """Return an ``AsyncCodeModeContext`` for the running kernel. | ||||||||||||
|
|
||||||||||||
| Use as an async context manager:: | ||||||||||||
|
|
@@ -206,6 +210,12 @@ def get_context(*, skip_validation: bool = False) -> AsyncCodeModeContext: | |||||||||||
| and should almost never be disabled. Only set to True when you | ||||||||||||
| intentionally need to insert code that would fail validation | ||||||||||||
| (e.g. incomplete stubs the user plans to fix by hand). | ||||||||||||
| skip_staleness_check : bool, default False | ||||||||||||
| When False (the default), ``edit_cell`` raises | ||||||||||||
| :class:`StaleCellError` if the agent tries to overwrite a cell | ||||||||||||
| whose code has changed since the agent last read it (e.g. via | ||||||||||||
| ``ctx.cells[cell_id]``). Set to True to overwrite blindly — | ||||||||||||
| useful when the agent intentionally discards prior content. | ||||||||||||
| """ | ||||||||||||
| runtime_ctx = _get_runtime_context() | ||||||||||||
| if not isinstance(runtime_ctx, KernelRuntimeContext): | ||||||||||||
|
|
@@ -215,9 +225,38 @@ def get_context(*, skip_validation: bool = False) -> AsyncCodeModeContext: | |||||||||||
| runtime_ctx._kernel, | ||||||||||||
| cell_manager=cell_manager, | ||||||||||||
| skip_validation=skip_validation, | ||||||||||||
| skip_staleness_check=skip_staleness_check, | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class StaleCellError(RuntimeError): | ||||||||||||
| """Raised when ``edit_cell`` targets a cell the agent has not read at | ||||||||||||
| its current version. | ||||||||||||
|
|
||||||||||||
| ``stale_cells`` is the full set of cells in the same boat, so the agent | ||||||||||||
| can re-read all of them in one pass. | ||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| def __init__( | ||||||||||||
| self, | ||||||||||||
| cell_id: CellId_t, | ||||||||||||
| stale_cells: frozenset[CellId_t], | ||||||||||||
| ) -> None: | ||||||||||||
| self.cell_id = cell_id | ||||||||||||
| self.stale_cells = stale_cells | ||||||||||||
| others = sorted(stale_cells - {cell_id}) | ||||||||||||
| other_hint = ( | ||||||||||||
| f"\nOther stale cells: {', '.join(others)}." if others else "" | ||||||||||||
| ) | ||||||||||||
| super().__init__( | ||||||||||||
| f"Cell {cell_id!r} was modified since the agent last read it.\n" | ||||||||||||
| f"Read it first (e.g. `ctx.cells[{cell_id!r}].code`) before " | ||||||||||||
| f"editing.{other_hint}\n" | ||||||||||||
| f"To override and overwrite without re-reading, pass " | ||||||||||||
| f"skip_staleness_check=True to cm.get_context()." | ||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice, this was going to be my main comment. |
||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| @helpable | ||||||||||||
| class NotebookCell: | ||||||||||||
| """Read-only view of a single cell with runtime status. | ||||||||||||
|
|
@@ -411,6 +450,7 @@ def _doc(self) -> NotebookDocument: | |||||||||||
|
|
||||||||||||
| def _cell_view(self, cell: _NotebookCell) -> NotebookCell: | ||||||||||||
| """Wrap a document cell with runtime state from the graph.""" | ||||||||||||
| self._ctx._note_read(cell.id, cell.version) | ||||||||||||
| try: | ||||||||||||
| graph = self._ctx.graph | ||||||||||||
| impl = graph.cells.get(cell.id) | ||||||||||||
|
|
@@ -589,6 +629,7 @@ def __init__( | |||||||||||
| cell_manager: CellManager | None = None, | ||||||||||||
| *, | ||||||||||||
| skip_validation: bool = False, | ||||||||||||
| skip_staleness_check: bool = False, | ||||||||||||
| ) -> None: | ||||||||||||
| from marimo._messaging.notebook.document import get_current_document | ||||||||||||
|
|
||||||||||||
|
|
@@ -603,6 +644,7 @@ def __init__( | |||||||||||
| self._document = document | ||||||||||||
| self._cell_manager = cell_manager | ||||||||||||
| self._skip_validation = skip_validation | ||||||||||||
| self._skip_staleness_check = skip_staleness_check | ||||||||||||
| self._ops: list[_Op] = [] | ||||||||||||
| # Track cell IDs added during this batch so subsequent ops | ||||||||||||
| # can reference them before they exist in the graph. | ||||||||||||
|
|
@@ -629,6 +671,9 @@ def _require_entered(self) -> None: | |||||||||||
| "Without 'async with', operations are silently lost." | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| def _note_read(self, cell_id: CellId_t, version: int) -> None: | ||||||||||||
| self._kernel.agent.read_tracker.record_read(cell_id, version) | ||||||||||||
|
|
||||||||||||
| def __getattr__(self, name: str) -> Any: | ||||||||||||
| # Legacy alias: `ctx.install_packages(...)` was the pre-namespace | ||||||||||||
| # API. Kept as a hidden shim for in-flight skills / examples; | ||||||||||||
|
|
@@ -1076,6 +1121,27 @@ def edit_cell( | |||||||||||
| # Setup is identified by cell_id alone — don't store a name. | ||||||||||||
| name = None | ||||||||||||
|
|
||||||||||||
| # Check after the setup-migration block so an implicit code-fill | ||||||||||||
| # from the graph still trips the read-before-write guard. | ||||||||||||
| if ( | ||||||||||||
| code is not None | ||||||||||||
| and not self._skip_staleness_check | ||||||||||||
| and cell_id not in self._pending_adds | ||||||||||||
| ): | ||||||||||||
| cell = self._document.get(cell_id) | ||||||||||||
| tracker = self._kernel.agent.read_tracker | ||||||||||||
| # Empty cells have no prior content to clobber. The agent's own | ||||||||||||
| # writes record reads at __aexit__, so a follow-up edit in a | ||||||||||||
| # later context passes the check normally. | ||||||||||||
| if ( | ||||||||||||
| cell is not None | ||||||||||||
| and cell.code.strip() | ||||||||||||
| and not tracker.has_read(cell_id, cell.version) | ||||||||||||
| ): | ||||||||||||
| raise StaleCellError( | ||||||||||||
| cell_id, tracker.get_stale_cells(self._document) | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| # Build config only if any config kwarg was explicitly set. | ||||||||||||
| config: CellConfig | None = None | ||||||||||||
| if hide_code is not None or disabled is not None or column is not None: | ||||||||||||
|
|
@@ -1594,6 +1660,20 @@ async def _apply_ops( | |||||||||||
| NotebookDocumentTransactionNotification(transaction=tx) | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| # The agent wrote these cells in this batch — its effective view is | ||||||||||||
| # the post-write version. Without this, a cell created in call N | ||||||||||||
| # can't be edited in call N+1 without re-reading. | ||||||||||||
| for op in ops: | ||||||||||||
| if isinstance(op, _AddOp): | ||||||||||||
| written_id = op.cell_id | ||||||||||||
| elif isinstance(op, _UpdateOp) and op.code is not None: | ||||||||||||
| written_id = op.new_cell_id or op.cell_id | ||||||||||||
| else: | ||||||||||||
| continue | ||||||||||||
| current = self._document.get_cell_version(written_id) | ||||||||||||
| if current is not None: | ||||||||||||
| self._note_read(written_id, current) | ||||||||||||
|
Comment on lines
+1673
to
+1675
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit, but we probably don't need a special method for this:
Suggested change
|
||||||||||||
|
|
||||||||||||
| # Run queued cells (explicit run_cell + autorun descendants), | ||||||||||||
| # filtered to cells that still exist after structural ops. | ||||||||||||
| _run_set = explicit_run or set() | ||||||||||||
|
|
||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| # Copyright 2026 Marimo. All rights reserved. | ||
| """Per-kernel agent state.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass, field | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| from marimo._messaging.notebook.document import NotebookDocument | ||
| from marimo._types.ids import CellId_t | ||
|
|
||
|
|
||
class AgentReadTracker:
    """Highest cell version the agent has observed, per cell.

    Reads are recorded with :meth:`record_read`; :meth:`has_read` and
    :meth:`get_stale_cells` compare those recorded versions against the
    versions currently reported for the cells.
    """

    def __init__(self) -> None:
        # Maps a cell id to the highest version recorded for that cell.
        self._read_versions: dict[CellId_t, int] = {}

    def record_read(self, cell_id: CellId_t, version: int) -> None:
        """Record that the agent observed ``cell_id`` at ``version``.

        The stored value only ever increases: recording a version that is
        not greater than the one already stored leaves the tracker
        unchanged.
        """
        prev = self._read_versions.get(cell_id, -1)
        if version > prev:
            self._read_versions[cell_id] = version

    def has_read(self, cell_id: CellId_t, current_version: int) -> bool:
        """Return True if a read of ``cell_id`` at ``current_version`` or
        newer has been recorded; False if the cell was never recorded."""
        last = self._read_versions.get(cell_id)
        return last is not None and last >= current_version

    def get_stale_cells(self, doc: NotebookDocument) -> frozenset[CellId_t]:
        """Return the ids of cells in ``doc`` that have code and whose
        current version has not been recorded as read."""
        # Empty (or whitespace-only) cells have nothing to clobber, so they
        # never count as stale even when the agent hasn't read them.
        # Staleness goes through ``has_read`` so this method and single-cell
        # checks cannot disagree about what "stale" means.
        return frozenset(
            cell.id
            for cell in doc.cells
            if cell.code.strip() and not self.has_read(cell.id, cell.version)
        )
|
|
||
|
|
||
| @dataclass | ||
| class Agent: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice. I really like that we have some carved out space for this kind of state now. |
||
| """One per ``Kernel`` — long-lived across scratchpad executions.""" | ||
|
|
||
| read_tracker: AgentReadTracker = field(default_factory=AgentReadTracker) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe for a follow-up, but could/should we emit (over stderr) the cells that changed externally, as a kind of indicator to the model ("hey, these have changed")? That might bloat the context, though.