diff --git a/marimo/_ast/app.py b/marimo/_ast/app.py index 40db052bc03..6ea253bc8f3 100644 --- a/marimo/_ast/app.py +++ b/marimo/_ast/app.py @@ -259,7 +259,6 @@ def __init__(self, **kwargs: Any) -> None: self._cell_manager = CellManager(prefix=cell_prefix) self._graph = dataflow.DirectedGraph() self._execution_context: ExecutionContext | None = None - self._runner = dataflow.Runner(self._graph) self._header: str | None = None self._unparsable_code: list[str] = [] @@ -780,17 +779,23 @@ def process_data(pd, batch_size, learning_rate): async def _run_cell_async( self, cell: Cell, kwargs: dict[str, Any] ) -> tuple[Any, _Namespace]: + from marimo._runtime.runner import by_kwargs + self._maybe_initialize() - output, defs = await self._runner.run_cell_async( - cell._cell.cell_id, kwargs + output, defs = await by_kwargs.run_cell_async( + self._graph, cell._cell.cell_id, kwargs ) return output, _Namespace(defs, owner=self) def _run_cell_sync( self, cell: Cell, kwargs: dict[str, Any] ) -> tuple[Any, _Namespace]: + from marimo._runtime.runner import by_kwargs + self._maybe_initialize() - output, defs = self._runner.run_cell_sync(cell._cell.cell_id, kwargs) + output, defs = by_kwargs.run_cell_sync( + self._graph, cell._cell.cell_id, kwargs + ) return output, _Namespace(defs, owner=self) async def _set_ui_element_value( @@ -1000,11 +1005,6 @@ def set_execution_context( ) -> None: self._app._execution_context = execution_context - @property - def runner(self) -> dataflow.Runner: - self._app._maybe_initialize() - return self._app._runner - def update_config(self, updates: dict[str, Any]) -> _AppConfig: return self.config.update(updates) diff --git a/marimo/_ast/cell.py b/marimo/_ast/cell.py index 7257aed1993..999f6c411ab 100644 --- a/marimo/_ast/cell.py +++ b/marimo/_ast/cell.py @@ -481,8 +481,12 @@ def _is_coroutine(self) -> bool: if hasattr(self, "_is_coro_cached"): return self._is_coro_cached assert self._app is not None - self._is_coro_cached: bool = self._app.runner.is_coroutine( - self._cell.cell_id + from marimo._runtime.runner import by_kwargs + + # Currently expensive since `graph` triggers _maybe_initialize on the + # underlying App. + self._is_coro_cached: bool = by_kwargs.is_coroutine( + self._app.graph, self._cell.cell_id ) return self._is_coro_cached @@ -655,8 +659,15 @@ def add(mo, x, y): } refs = {**from_setup, **refs} + from marimo._runtime.runner import by_kwargs + try: - if self._is_coroutine: + # Refresh the async decision with the caller's substitutions — + # an unsubstituted ancestor may have been async but isn't on + # this call's ancestor closure. + if by_kwargs.is_coroutine( + self._app.graph, self._cell.cell_id, refs + ): return self._app.run_cell_async(cell=self, kwargs=refs) else: return self._app.run_cell_sync(cell=self, kwargs=refs) diff --git a/marimo/_messaging/tracebacks.py b/marimo/_messaging/tracebacks.py index cb153407758..29ae9669cdf 100644 --- a/marimo/_messaging/tracebacks.py +++ b/marimo/_messaging/tracebacks.py @@ -49,10 +49,8 @@ def write_traceback(traceback: str) -> None: # In run mode, only forward to the frontend if show_tracebacks is on. if in_run_mode and not _show_tracebacks_enabled(): return - # Strip marimo's internal executor.py frame and highlight for the UI - trimmed = _trim_traceback(traceback) sys.stderr._write_with_mimetype( - _highlight_traceback(trimmed), + _highlight_traceback(traceback), mimetype="application/vnd.marimo+traceback", ) else: @@ -64,16 +62,15 @@ def write_traceback(traceback: str) -> None: if in_run_mode and not _show_tracebacks_enabled(): sys.stderr.write(traceback) return - trimmed = _trim_traceback(traceback) broadcast_notification( CellNotification( cell_id=ctx.cell_id, console=CellOutput( channel=CellChannel.STDERR, mimetype="application/vnd.marimo+traceback", - data=trimmed + data=traceback if code_mode - else _highlight_traceback(trimmed), + else _highlight_traceback(traceback), ), ), ctx.stream, @@ -83,27 +80,5 @@ def write_traceback(traceback: str) -> None: sys.stderr.write(traceback) -def _trim_traceback(traceback: str) -> str: - """ - Skip first DefaultExecutor.execute_cell traceback item which all traces start with. - """ - - lines = traceback.split("\n") - if ( - len(lines) > 2 - and lines[0] == "Traceback (most recent call last):" - and ( - '/marimo/_runtime/executor.py", line ' in lines[1] - or '\\marimo\\_runtime\\executor.py", line ' in lines[1] - ) - and lines[1].endswith(", in execute_cell") - ): - for i in range(2, len(lines)): - if lines[i].startswith(" File "): - return "\n".join(lines[:1] + lines[i:]) - - return traceback - - def is_code_highlighting(value: str) -> bool: return 'class="codehilite"' in value diff --git a/marimo/_runtime/app/script_runner.py b/marimo/_runtime/app/script_runner.py index b8d1a4a388d..73c621f7328 100644 --- a/marimo/_runtime/app/script_runner.py +++ b/marimo/_runtime/app/script_runner.py @@ -2,7 +2,6 @@ from __future__ import annotations import asyncio -from collections import deque from typing import TYPE_CHECKING, Any from marimo._ast.names import SETUP_CELL_NAME @@ -21,12 +20,17 @@ MarimoRuntimeException, unwrap_user_exception, ) -from marimo._runtime.executor import resolve_executor +from marimo._runtime.executor import ( + Evaluator, + resolve_executor, +) from marimo._runtime.patches import ( create_main_module, extract_docstring_from_header, patch_main_module_context, ) +from marimo._runtime.runner.result import RunResult +from marimo._runtime.runner.scheduler import SequentialScheduler from marimo._types.ids import CellId_t if TYPE_CHECKING: @@ -47,7 +51,6 @@ def __init__( self.app = app self.filename = filename self._docstring = extract_docstring_from_header(app._app._header) - self.cells_cancelled: set[CellId_t] = set() self._glbls = glbls if glbls else {} # Setup cell cannot be overridden, and it's possible that some @@ -59,24 +62,21 @@ def __init__( excluded=CellId_t(SETUP_CELL_NAME), ) - self.cells_to_run: deque[CellId_t] = deque( + cells_to_run = [ cid for cid in pruned_execution_order if app.cell_manager.cell_data_at(cid).cell is not None and not self.app.graph.is_disabled(cid) - ) - self._executor = resolve_executor() + ] - def _cancel(self, cell_id: CellId_t) -> None: - cancelled = { - cid - for cid in dataflow.transitive_closure(self.app.graph, {cell_id}) - if cid in self.cells_to_run - } - for cid in cancelled: - self.app.graph.cells[cid].set_run_result_status("cancelled") - self.cells_cancelled |= cancelled + self._scheduler = SequentialScheduler(cells_to_run, self.app.graph) + self._evaluator = Evaluator(executor=resolve_executor(), lifecycles=[]) + # _run_synchronous and _run_asynchronous are deliberate near-twins: + # the only difference is the await on the cell step. Keeping them + # as separate methods (rather than wrapping with asyncio.run + # unconditionally) preserves the no-event-loop guarantee for purely + # synchronous apps. def _run_synchronous( self, post_execute_hooks: list[Callable[[], Any]], @@ -93,39 +93,20 @@ def _run_synchronous( glbls.update(self._glbls) outputs: dict[CellId_t, Any] = {} - while self.cells_to_run: - cid = self.cells_to_run.popleft() - if cid in self.cells_cancelled: + while self._scheduler.pending(): + cid = self._scheduler.pop_cell() + if self._scheduler.cancelled(cid): continue - # Set up has already run in this case. + # Setup has already run by this point. if cid == CellId_t(SETUP_CELL_NAME): for hook in post_execute_hooks: hook() continue - cell = self.app.graph.cells[cid] with get_context().with_cell_id(cid): try: - output = self._executor.execute_cell(cell, glbls) - outputs[cid] = output - except MarimoRuntimeException as e: - unwrapped_exception = unwrap_user_exception( - e, self.app.graph - ) - - if isinstance(unwrapped_exception, MarimoStopError): - self._cancel(cid) - elif isinstance( - unwrapped_exception, MarimoMissingRefError - ): - name_err = unwrapped_exception.name_error - raise ( - name_err - if name_err is not None - else unwrapped_exception - ) from None - else: - raise + result = self._evaluator.evaluate_sync(cell, glbls) + self._handle_run_result(cid, result, outputs) finally: for hook in post_execute_hooks: hook() @@ -147,47 +128,55 @@ async def _run_asynchronous( glbls.update(self._glbls) outputs: dict[CellId_t, Any] = {} - - while self.cells_to_run: - cid = self.cells_to_run.popleft() - if cid in self.cells_cancelled: + while self._scheduler.pending(): + cid = self._scheduler.pop_cell() + if self._scheduler.cancelled(cid): continue - + # Setup has already run by this point. if cid == CellId_t(SETUP_CELL_NAME): for hook in post_execute_hooks: hook() continue - cell = self.app.graph.cells[cid] with get_context().with_cell_id(cid): try: - output = await self._executor.execute_cell_async( - cell, glbls - ) - outputs[cid] = output - except MarimoRuntimeException as e: - unwrapped_exception = unwrap_user_exception( - e, self.app.graph - ) - - if isinstance(unwrapped_exception, MarimoStopError): - self._cancel(cid) - elif isinstance( - unwrapped_exception, MarimoMissingRefError - ): - name_err = unwrapped_exception.name_error - raise ( - name_err - if name_err is not None - else unwrapped_exception - ) from None - else: - raise + result = await self._evaluator.evaluate(cell, glbls) + self._handle_run_result(cid, result, outputs) finally: for hook in post_execute_hooks: hook() return outputs, glbls + def _handle_run_result( + self, + cid: CellId_t, + result: RunResult, + outputs: dict[CellId_t, Any], + ) -> None: + """Classify the Evaluator's RunResult; record output/cancel/raise.""" + exc = result.exception + if exc is None: + outputs[cid] = result.output + return + if not isinstance(exc, BaseException): + # Defensive check descendants, since all exceptions are expected to + # be wrapped. + outputs[cid] = result.output + self._scheduler.cancel(cid) + return + if isinstance(exc, MarimoRuntimeException): + unwrapped = unwrap_user_exception(exc, self.app.graph) + if isinstance(unwrapped, MarimoStopError): + outputs[cid] = unwrapped.output + self._scheduler.cancel(cid) + return + if isinstance(unwrapped, MarimoMissingRefError): + name_err = unwrapped.name_error + raise ( + name_err if name_err is not None else unwrapped + ) from None + raise exc + def run(self) -> RunOutput: from marimo._runtime.context.script_context import ( initialize_script_context, @@ -231,7 +220,7 @@ def run(self) -> RunOutput: theme=get_context().marimo_config["display"]["theme"] ) - post_execute_hooks = [] + post_execute_hooks: list[Callable[[], Any]] = [] if DependencyManager.matplotlib.has(): from marimo._output.mpl import close_figures @@ -249,25 +238,9 @@ def run(self) -> RunOutput: ) return outputs, defs - # Cell runner manages the exception handling for kernel - # runner, but script runner should raise the wrapped - # exception if invoked directly. + # Raise the wrapped user exception from "None" so the stack + # trace points at the failing cell, not the runner. except MarimoRuntimeException as e: - # MarimoMissingRefError, wraps the under lying NameError - # for context, so we raise the NameError directly. - if isinstance(e.__cause__, MarimoMissingRefError): - # For type checking + sanity check - if not isinstance(e.__cause__.name_error, NameError): - raise MarimoRuntimeException( - "Unexpected error occurred while running the app. " - "Improperly wrapped MarimoMissingRefError exception. " - "Please report this issue to " - "https://github.com/marimo-team/marimo/issues" - ) from e.__cause__ - raise e.__cause__.name_error from e.__cause__ - # For all other exceptions, we raise the wrapped exception - # from "None" to indicate this is an Error propagation, and to not - # muddy the stacktrace from the failing cells themselves. raise e.__cause__ from None # type: ignore finally: if installed_script_context: diff --git a/marimo/_runtime/dataflow/__init__.py b/marimo/_runtime/dataflow/__init__.py index ebbe5cdadd2..7b9e304f7eb 100644 --- a/marimo/_runtime/dataflow/__init__.py +++ b/marimo/_runtime/dataflow/__init__.py @@ -7,7 +7,6 @@ from marimo import _loggers from marimo._ast.cell import CellImpl from marimo._runtime.dataflow.graph import DirectedGraph -from marimo._runtime.dataflow.runner import Runner from marimo._runtime.dataflow.topology import GraphTopology from marimo._runtime.dataflow.types import Edge, EdgeWithVar from marimo._types.ids import CellId_t @@ -253,7 +252,6 @@ def import_block_relatives(cid: CellId_t, children: bool) -> set[CellId_t]: "DirectedGraph", "Edge", "EdgeWithVar", - "Runner", "get_cycles", "get_import_block_relatives", "induced_subgraph", diff --git a/marimo/_runtime/dataflow/runner.py b/marimo/_runtime/dataflow/runner.py deleted file mode 100644 index e01a54221e0..00000000000 --- a/marimo/_runtime/dataflow/runner.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2026 Marimo. All rights reserved. -"""Runner utility for executing individual cells in a graph.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -from marimo._runtime.executor import DefaultExecutor - -if TYPE_CHECKING: - from marimo._ast.cell import CellImpl - from marimo._runtime.dataflow.graph import DirectedGraph - from marimo._types.ids import CellId_t - - -class Runner: - """Utility for running individual cells in a graph - - This class provides methods to a run a cell in the graph and obtain its - output (last expression) and the values of its defs. - - If needed, the runner will recursively compute the values of the cell's - refs by executing its ancestors. Refs can also be substituted by the - caller. - - TODO(akshayka): Add an API for caching defs across cell runs. - """ - - def __init__(self, graph: DirectedGraph) -> None: - self._graph = graph - self._executor = DefaultExecutor() - - @staticmethod - def _returns(cell_impl: CellImpl, glbls: dict[str, Any]) -> dict[str, Any]: - return {name: glbls[name] for name in cell_impl.defs if name in glbls} - - @staticmethod - def _substitute_refs( - cell_impl: CellImpl, - glbls: dict[str, Any], - kwargs: dict[str, Any], - ) -> None: - for argname, argvalue in kwargs.items(): - if argname in cell_impl.refs: - glbls[argname] = argvalue - else: - raise ValueError( - f"Cell got unexpected argument {argname}" - f"The allowed arguments are {cell_impl.refs}." - ) - - def _get_ancestors( - self, cell_impl: CellImpl, kwargs: dict[str, Any] - ) -> set[CellId_t]: - from marimo._runtime.dataflow import transitive_closure - - # Get the transitive closure of parents defining unsubstituted refs - graph = self._graph - substitutions = set(kwargs.keys()) - unsubstituted_refs = cell_impl.refs - substitutions - parent_ids = { - parent_id - for parent_id in graph.parents[cell_impl.cell_id] - if graph.cells[parent_id].defs.intersection(unsubstituted_refs) - } - return transitive_closure(graph, parent_ids, children=False) - - @staticmethod - def _validate_kwargs(cell_impl: CellImpl, kwargs: dict[str, Any]) -> None: - for argname in kwargs: - if argname not in cell_impl.refs: - raise ValueError( - f"Cell got unexpected argument {argname}; " - f"The allowed arguments are {cell_impl.refs}." - ) - - def is_coroutine(self, cell_id: CellId_t) -> bool: - return self._graph.cells[cell_id].is_coroutine() or any( - self._graph.cells[cid].is_coroutine() - for cid in self._get_ancestors( - self._graph.cells[cell_id], kwargs={} - ) - ) - - async def run_cell_async( - self, cell_id: CellId_t, kwargs: dict[str, Any] - ) -> tuple[Any, dict[str, Any]]: - """Run a possibly async cell and its ancestors - - Substitutes kwargs as refs for the cell, omitting ancestors that - whose refs are substituted. - """ - from marimo._runtime.dataflow import topological_sort - - graph = self._graph - cell_impl = graph.cells[cell_id] - Runner._validate_kwargs(cell_impl, kwargs) - ancestor_ids = self._get_ancestors(cell_impl, kwargs) - - glbls: dict[str, Any] = {} - for cid in topological_sort(graph, ancestor_ids): - await self._executor.execute_cell_async(graph.cells[cid], glbls) - - Runner._substitute_refs(cell_impl, glbls, kwargs) - output = await self._executor.execute_cell_async( - graph.cells[cell_impl.cell_id], glbls - ) - defs = Runner._returns(cell_impl, glbls) - return output, defs - - def run_cell_sync( - self, cell_id: CellId_t, kwargs: dict[str, Any] - ) -> tuple[Any, dict[str, Any]]: - """Run a synchronous cell and its ancestors - - Substitutes kwargs as refs for the cell, omitting ancestors that - whose refs are substituted. - - Raises a `RuntimeError` if the cell or any of its unsubstituted - ancestors are coroutine functions. - """ - from marimo._runtime.dataflow import topological_sort - - graph = self._graph - cell_impl = graph.cells[cell_id] - if cell_impl.is_coroutine(): - raise RuntimeError( - "A coroutine function can't be run synchronously. " - "Use `run_async()` instead" - ) - - Runner._validate_kwargs(cell_impl, kwargs) - ancestor_ids = self._get_ancestors(cell_impl, kwargs) - - if any(graph.cells[cid].is_coroutine() for cid in ancestor_ids): - raise RuntimeError( - "Cell has an ancestor that is a " - "coroutine (async) cell. Use `run_async()` instead" - ) - - glbls: dict[str, Any] = {} - for cid in topological_sort(graph, ancestor_ids): - self._executor.execute_cell(graph.cells[cid], glbls) - - self._substitute_refs(cell_impl, glbls, kwargs) - output = self._executor.execute_cell( - graph.cells[cell_impl.cell_id], glbls - ) - defs = Runner._returns(cell_impl, glbls) - return output, defs diff --git a/marimo/_runtime/executor/evaluator.py b/marimo/_runtime/executor/evaluator.py index 7a749b3d6b1..d3fd0781d3b 100644 --- a/marimo/_runtime/executor/evaluator.py +++ b/marimo/_runtime/executor/evaluator.py @@ -3,17 +3,24 @@ from __future__ import annotations +import asyncio +import contextlib +import functools +import signal +import threading from dataclasses import replace from typing import TYPE_CHECKING, Any from marimo import _loggers from marimo._entrypoints.registry import EntryPointRegistry +from marimo._runtime.control_flow import MarimoInterrupt from marimo._runtime.executor.executor import DefaultExecutor, Executor from marimo._runtime.executor.lifecycles import ExecutionLifecycle, Skip from marimo._runtime.runner.result import RunResult +from marimo._types.globals import MutableGlobals if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Iterator from marimo._ast.cell import CellImpl @@ -33,13 +40,70 @@ def __init__( self.lifecycles: list[ExecutionLifecycle] = lifecycles or [] async def evaluate( - self, cell: CellImpl, glbls: dict[str, Any] + self, cell: CellImpl, glbls: MutableGlobals ) -> RunResult: """Setup lifecycles, execute, and teardown lifecycles.""" + completed, skip, body_exc = self._setup_chain(cell, glbls) + + if body_exc is not None: + result: RunResult = RunResult(output=None, exception=body_exc) + elif skip is not None: + # Lifecycle short-circuited — pass its full RunResult through + # so `accumulated_output` and any other field survive. + result = ( + skip.result + if skip.result is not None + else RunResult(output=None, exception=None) + ) + else: + try: + value = await self.executor.execute_cell_async(cell, glbls) + result = RunResult(output=value, exception=None) + except BaseException as e: + result = RunResult(output=None, exception=e) + + return self._teardown_chain(cell, glbls, completed, result) + + def evaluate_sync( + self, cell: CellImpl, glbls: MutableGlobals + ) -> RunResult: + """Sync mirror of `evaluate` — for callers without an event loop.""" + completed, skip, body_exc = self._setup_chain(cell, glbls) + + if body_exc is not None: + result: RunResult = RunResult(output=None, exception=body_exc) + elif skip is not None: + result = ( + skip.result + if skip.result is not None + else RunResult(output=None, exception=None) + ) + else: + try: + value = self.executor.execute_cell(cell, glbls) + result = RunResult(output=value, exception=None) + except BaseException as e: + result = RunResult(output=None, exception=e) + + return self._teardown_chain(cell, glbls, completed, result) + + async def evaluate_interruptible( + self, cell: CellImpl, glbls: MutableGlobals + ) -> RunResult: + """Await `evaluate` with SIGINT capture for coroutine cells.""" + if not cell.is_coroutine(): + return await self.evaluate(cell, glbls) + future = asyncio.ensure_future(self.evaluate(cell, glbls)) + if threading.current_thread() is threading.main_thread(): + with _cancel_on_sigint(future): + return await future + return await future + + def _setup_chain( + self, cell: CellImpl, glbls: MutableGlobals + ) -> tuple[list[ExecutionLifecycle], Skip | None, BaseException | None]: completed: list[ExecutionLifecycle] = [] skip: Skip | None = None - result: RunResult | None = None - try: for life in self.lifecycles: decision = life.setup(cell, glbls) @@ -48,22 +112,16 @@ async def evaluate( skip = decision break except BaseException as e: - result = RunResult(output=None, exception=e) - - if result is None: - if skip is not None and skip.result is not None: - # Lifecycle supplied a complete RunResult — preserve all - # fields (output, accumulated_output, exception). - result = skip.result - elif skip is not None: - result = RunResult(output=None, exception=None) - else: - try: - value = await self.executor.execute_cell_async(cell, glbls) - result = RunResult(output=value, exception=None) - except BaseException as e: - result = RunResult(output=None, exception=e) + return completed, None, e + return completed, skip, None + def _teardown_chain( + self, + cell: CellImpl, + glbls: MutableGlobals, + completed: list[ExecutionLifecycle], + result: RunResult, + ) -> RunResult: teardown_exc: BaseException | None = None for life in reversed(completed): try: @@ -121,3 +179,48 @@ def resolve_executor() -> Executor: e, ) return DefaultExecutor() + + +# Adapted from +# https://github.com/ipython/ipykernel/blob/eddd3e666a82ebec287168b0da7cfa03639a3772/ipykernel/ipkernel.py#L312 +@contextlib.contextmanager +def _cancel_on_sigint(future: asyncio.Future[Any]) -> Iterator[None]: + """Cancel `future` if a SIGINT arrives during evaluation.""" + sigint_future: asyncio.Future[int] = asyncio.Future() + + def cancel_unless_done(f: asyncio.Future[Any], _: Any) -> None: + if f.cancelled() or f.done(): + return + f.cancel() + + sigint_future.add_done_callback( + functools.partial(cancel_unless_done, future) + ) + future.add_done_callback( + functools.partial(cancel_unless_done, sigint_future) + ) + + # Capture the previously-installed SIGINT handler *before* we install + # ours so `handle_sigint` can invoke it for its side effects + # (kernel broadcast, duckdb interrupt). For async cells the actual + # halt comes from cancelling the future, not from a raised + # `MarimoInterrupt` — so we swallow that here. + prior_sigint = signal.getsignal(signal.SIGINT) + + def handle_sigint(signum: int, frame: Any) -> None: + if sigint_future.cancelled() or sigint_future.done(): + return + sigint_future.set_result(1) + if callable(prior_sigint): + try: + prior_sigint(signum, frame) + except MarimoInterrupt: + # The kernel's handler raises MarimoInterrupt for sync + # halt; we cancel the future instead. + pass + + save_sigint = signal.signal(signal.SIGINT, handle_sigint) + try: + yield + finally: + signal.signal(signal.SIGINT, save_sigint) diff --git a/marimo/_runtime/executor/executor.py b/marimo/_runtime/executor/executor.py index fa496ed96c9..a75285feeec 100644 --- a/marimo/_runtime/executor/executor.py +++ b/marimo/_runtime/executor/executor.py @@ -3,31 +3,47 @@ from __future__ import annotations +import asyncio from typing import TYPE_CHECKING, Any, Protocol from marimo._ast.cell import _is_coroutine from marimo._runtime.exceptions import MarimoRuntimeException +from marimo._types.globals import MutableGlobals if TYPE_CHECKING: from marimo._ast.cell import CellImpl +def _strip_frame(e: BaseException, count: int = 1) -> None: + """Drop the top `count` frames from `e.__traceback__`. + + Stops early if the traceback runs out — never strips the last + frame, so we don't lose the only frame we have. + """ + tb = e.__traceback__ + for _ in range(count): + if tb is None or tb.tb_next is None: + break + tb = tb.tb_next + e.__traceback__ = tb + + class Executor(Protocol): """Body strategy: how to run a cell's body.""" name: str - def execute_cell(self, cell: CellImpl, glbls: dict[str, Any]) -> Any: ... + def execute_cell(self, cell: CellImpl, glbls: MutableGlobals) -> Any: ... async def execute_cell_async( - self, cell: CellImpl, glbls: dict[str, Any] + self, cell: CellImpl, glbls: MutableGlobals ) -> Any: ... class DefaultExecutor: name = "default" - def execute_cell(self, cell: CellImpl, glbls: dict[str, Any]) -> Any: + def execute_cell(self, cell: CellImpl, glbls: MutableGlobals) -> Any: if cell.body is None: return None assert cell.last_expr is not None @@ -39,13 +55,16 @@ def execute_cell(self, cell: CellImpl, glbls: dict[str, Any]) -> Any: try: exec(cell.body, glbls) return eval(cell.last_expr, glbls) + except asyncio.CancelledError: + # Cancellation is control flow, not user error — surface bare. + raise except BaseException as e: - # Raising from BaseException folds in the stack trace prior - # to execution. + # Strip our own frame so user-facing tracebacks start at user code. + _strip_frame(e) raise MarimoRuntimeException from e async def execute_cell_async( - self, cell: CellImpl, glbls: dict[str, Any] + self, cell: CellImpl, glbls: MutableGlobals ) -> Any: if cell.body is None: return None @@ -58,5 +77,8 @@ async def execute_cell_async( if _is_coroutine(cell.last_expr): return await eval(cell.last_expr, glbls) return eval(cell.last_expr, glbls) + except asyncio.CancelledError: + raise except BaseException as e: + _strip_frame(e) raise MarimoRuntimeException from e diff --git a/marimo/_runtime/executor/lifecycles/__init__.py b/marimo/_runtime/executor/lifecycles/__init__.py index 68674e9d1bc..24237c7969b 100644 --- a/marimo/_runtime/executor/lifecycles/__init__.py +++ b/marimo/_runtime/executor/lifecycles/__init__.py @@ -4,9 +4,10 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Protocol from marimo._runtime.runner.result import RunResult +from marimo._types.globals import MutableGlobals if TYPE_CHECKING: from marimo._ast.cell import CellImpl @@ -30,11 +31,11 @@ class ExecutionLifecycle(Protocol): name: str - def setup(self, cell: CellImpl, glbls: dict[str, Any]) -> Skip | None: ... + def setup(self, cell: CellImpl, glbls: MutableGlobals) -> Skip | None: ... def teardown( self, cell: CellImpl, - glbls: dict[str, Any], + glbls: MutableGlobals, run_result: RunResult, ) -> None: ... diff --git a/marimo/_runtime/executor/lifecycles/strict.py b/marimo/_runtime/executor/lifecycles/strict.py index a5621c81d91..383fe9ce3f3 100644 --- a/marimo/_runtime/executor/lifecycles/strict.py +++ b/marimo/_runtime/executor/lifecycles/strict.py @@ -22,6 +22,7 @@ is_unclonable_type, ) from marimo._runtime.runner.result import RunResult +from marimo._types.globals import MutableGlobals if TYPE_CHECKING: from marimo._ast.cell import CellImpl @@ -60,7 +61,7 @@ def __init__(self, graph: DirectedGraph) -> None: # Per-cell setup→teardown backup. Keyed by cell_id. self._backups: dict[CellId_t, dict[str, Any]] = {} - def setup(self, cell: CellImpl, glbls: dict[str, Any]) -> Skip | None: + def setup(self, cell: CellImpl, glbls: MutableGlobals) -> Skip | None: refs = self._graph.get_transitive_references( cell.refs, predicate=build_ref_predicate_for_primitives( @@ -136,7 +137,7 @@ def _sanitize_ref(self, name: str, value: Any) -> Any: def teardown( self, cell: CellImpl, - glbls: dict[str, Any], + glbls: MutableGlobals, run_result: RunResult, # noqa: ARG002 ) -> None: backup = self._backups.pop(cell.cell_id, None) diff --git a/marimo/_runtime/runner/by_kwargs.py b/marimo/_runtime/runner/by_kwargs.py new file mode 100644 index 00000000000..90e85b05098 --- /dev/null +++ b/marimo/_runtime/runner/by_kwargs.py @@ -0,0 +1,190 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Lightweight runner functions for use in direct cell evaluation or testing. + +Walks the cell's ancestor closure (minus any ancestor whose defs the +caller substituted via kwargs), runs them with a fresh globals dict, +then runs the target cell. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from marimo._runtime.control_flow import MarimoStopError +from marimo._runtime.exceptions import MarimoRuntimeException +from marimo._runtime.executor import ( + DefaultExecutor, + Evaluator, +) +from marimo._types.globals import Globals, MutableGlobals + +if TYPE_CHECKING: + from marimo._ast.cell import CellImpl + from marimo._runtime.dataflow.topology import GraphTopology + from marimo._runtime.runner.result import RunResult + from marimo._types.ids import CellId_t + + +def _new_evaluator() -> Evaluator: + """A fresh relaxed-mode Evaluator (no lifecycles).""" + return Evaluator(executor=DefaultExecutor(), lifecycles=[]) + + +def _returns(cell_impl: CellImpl, glbls: Globals) -> dict[str, Any]: + return {name: glbls[name] for name in cell_impl.defs if name in glbls} + + +def _substitute_refs( + cell_impl: CellImpl, + glbls: MutableGlobals, + kwargs: dict[str, Any], +) -> None: + for argname, argvalue in kwargs.items(): + if argname in cell_impl.refs: + glbls[argname] = argvalue + else: + raise ValueError( + f"Cell got unexpected argument {argname}" + f"The allowed arguments are {cell_impl.refs}." + ) + + +def _validate_kwargs(cell_impl: CellImpl, kwargs: dict[str, Any]) -> None: + for argname in kwargs: + if argname not in cell_impl.refs: + raise ValueError( + f"Cell got unexpected argument {argname}; " + f"The allowed arguments are {cell_impl.refs}." + ) + + +def _get_ancestors( + graph: GraphTopology, + cell_impl: CellImpl, + kwargs: dict[str, Any], +) -> set[CellId_t]: + from marimo._runtime.dataflow import transitive_closure + + substitutions = set(kwargs.keys()) + unsubstituted_refs = cell_impl.refs - substitutions + parent_ids = { + parent_id + for parent_id in graph.parents[cell_impl.cell_id] + if graph.cells[parent_id].defs.intersection(unsubstituted_refs) + } + return transitive_closure(graph, parent_ids, children=False) + + +def _classify(result: RunResult) -> MarimoStopError | None: + """Inspect a RunResult; raise on real errors, return the stop on mo.stop.""" + exc = result.exception + if exc is None: + return None + if isinstance(exc, MarimoStopError): + # Defensive: any caller bypassing `MarimoRuntimeException` + # wrapping (e.g. a custom Executor that raises directly) still + # gets stop-control-flow handling. + return exc + if isinstance(exc, MarimoRuntimeException) and isinstance( + exc.__cause__, MarimoStopError + ): + return exc.__cause__ + if isinstance(exc, BaseException): + raise exc + return None + + +def is_coroutine( + graph: GraphTopology, + cell_id: CellId_t, + kwargs: dict[str, Any] | None = None, +) -> bool: + """True if the cell or any of its unsubstituted ancestors is async. + + NB. Currently expensive due to calls on graph. + """ + return graph.cells[cell_id].is_coroutine() or any( + graph.cells[cid].is_coroutine() + for cid in _get_ancestors(graph, graph.cells[cell_id], kwargs or {}) + ) + + +async def run_cell_async( + graph: GraphTopology, + cell_id: CellId_t, + kwargs: dict[str, Any], +) -> tuple[Any, MutableGlobals]: + """Run a possibly async cell and its ancestors. + + Substitutes kwargs as refs for the cell, omitting ancestors whose + refs are substituted. + """ + from marimo._runtime.dataflow import topological_sort + + cell_impl = graph.cells[cell_id] + _validate_kwargs(cell_impl, kwargs) + ancestor_ids = _get_ancestors(graph, cell_impl, kwargs) + + evaluator = _new_evaluator() + glbls: MutableGlobals = {} + for cid in topological_sort(graph, ancestor_ids): + stop = _classify(await evaluator.evaluate(graph.cells[cid], glbls)) + if stop is not None: + return stop.output, _returns(cell_impl, glbls) + + _substitute_refs(cell_impl, glbls, kwargs) + target_result = await evaluator.evaluate( + graph.cells[cell_impl.cell_id], glbls + ) + stop = _classify(target_result) + if stop is not None: + return stop.output, _returns(cell_impl, glbls) + return target_result.output, _returns(cell_impl, glbls) + + +def run_cell_sync( + graph: GraphTopology, + cell_id: CellId_t, + kwargs: dict[str, Any], +) -> tuple[Any, MutableGlobals]: + """Run a synchronous cell and its ancestors. + + Substitutes kwargs as refs for the cell, omitting ancestors whose + refs are substituted. + + Raises `RuntimeError` if the cell or any of its unsubstituted + ancestors are coroutine functions. + """ + from marimo._runtime.dataflow import topological_sort + + cell_impl = graph.cells[cell_id] + if cell_impl.is_coroutine(): + raise RuntimeError( + "A coroutine function can't be run synchronously. " + "Use `run_async()` instead" + ) + + _validate_kwargs(cell_impl, kwargs) + ancestor_ids = _get_ancestors(graph, cell_impl, kwargs) + + if any(graph.cells[cid].is_coroutine() for cid in ancestor_ids): + raise RuntimeError( + "Cell has an ancestor that is a " + "coroutine (async) cell. Use `run_async()` instead" + ) + + evaluator = _new_evaluator() + glbls: MutableGlobals = {} + for cid in topological_sort(graph, ancestor_ids): + stop = _classify(evaluator.evaluate_sync(graph.cells[cid], glbls)) + if stop is not None: + return stop.output, _returns(cell_impl, glbls) + + _substitute_refs(cell_impl, glbls, kwargs) + target_result = evaluator.evaluate_sync( + graph.cells[cell_impl.cell_id], glbls + ) + stop = _classify(target_result) + if stop is not None: + return stop.output, _returns(cell_impl, glbls) + return target_result.output, _returns(cell_impl, glbls) diff --git a/marimo/_runtime/runner/cell_runner.py b/marimo/_runtime/runner/cell_runner.py index d2c48ce9c3a..026be93bd03 100644 --- a/marimo/_runtime/runner/cell_runner.py +++ b/marimo/_runtime/runner/cell_runner.py @@ -2,11 +2,7 @@ from __future__ import annotations import asyncio -import contextlib -import functools import io -import signal -import threading import traceback from pathlib import Path from types import TracebackType @@ -55,7 +51,6 @@ if TYPE_CHECKING: from collections import deque - from collections.abc import Iterator from marimo._runtime.runner.hooks import NotebookCellHooks from marimo._runtime.state import State @@ -206,53 +201,6 @@ def compute_cells_to_run( return sorted_cells - # Adapted from - # https://github.com/ipython/ipykernel/blob/eddd3e666a82ebec287168b0da7cfa03639a3772/ipykernel/ipkernel.py#L312 - @staticmethod - @contextlib.contextmanager - def _cancel_on_sigint(future: asyncio.Future[Any]) -> Iterator[None]: - """ContextManager for capturing SIGINT and cancelling a future - - SIGINT raises in the event loop when running async code, - but we want it to halt a coroutine. - - Ideally, it would raise KeyboardInterrupt, but this turns it into a - CancelledError. - """ - sigint_future: asyncio.Future[int] = asyncio.Future() - - # whichever future finishes first, - # cancel the other one - def cancel_unless_done(f: asyncio.Future[Any], _: Any) -> None: - if f.cancelled() or f.done(): - return - f.cancel() - - # when sigint finishes, - # abort the coroutine with CancelledError - sigint_future.add_done_callback( - functools.partial(cancel_unless_done, future) - ) - # when the main future finishes, - # stop watching for SIGINT events - future.add_done_callback( - functools.partial(cancel_unless_done, sigint_future) - ) - - def handle_sigint(*_: Any) -> None: - if sigint_future.cancelled() or sigint_future.done(): - return - # mark as done, to trigger cancellation - sigint_future.set_result(1) - - # set the custom sigint handler during this context - save_sigint = signal.signal(signal.SIGINT, handle_sigint) - try: - yield - finally: - # restore the previous sigint handler - signal.signal(signal.SIGINT, save_sigint) - @property def cells_to_run(self) -> deque[CellId_t]: return self._scheduler.cells_to_run @@ -461,20 +409,9 @@ async def run(self, cell_id: CellId_t) -> RunResult: # returned RunResult; cell_id-specific classification + side # effects are applied below in `_finalize_run_result`. try: - if cell.is_coroutine(): - return_value_future = asyncio.ensure_future( - self._evaluator.evaluate(cell, self.glbls) - ) - if threading.current_thread() == threading.main_thread(): - # edit mode: need to handle user interrupts - with Runner._cancel_on_sigint(return_value_future): - raw_result = await return_value_future - else: - # run mode: can't use signal.signal, not interruptible - # by user anyway. - raw_result = await return_value_future - else: - raw_result = await self._evaluator.evaluate(cell, self.glbls) + raw_result = await self._evaluator.evaluate_interruptible( + cell, self.glbls + ) run_result = self._finalize_run_result(raw_result, cell_id) except BaseException: # Defensive: an unexpected escape from the Evaluator or a bug diff --git a/marimo/_runtime/runner/hooks_post_execution.py b/marimo/_runtime/runner/hooks_post_execution.py index a2911ede32d..25b6ef3f593 100644 --- a/marimo/_runtime/runner/hooks_post_execution.py +++ b/marimo/_runtime/runner/hooks_post_execution.py @@ -35,7 +35,6 @@ ) from marimo._messaging.tracebacks import ( _highlight_traceback, - _trim_traceback, write_traceback, ) from marimo._messaging.variables import create_variable_value @@ -426,9 +425,7 @@ def _broadcast_outputs( and run_result.exception.__traceback__ ): tb_lines = tb.format_exception(run_result.exception) - formatted_traceback = _highlight_traceback( - _trim_traceback("".join(tb_lines)) - ) + formatted_traceback = _highlight_traceback("".join(tb_lines)) CellNotificationUtils.broadcast_error( data=[ diff --git a/marimo/_types/globals.py b/marimo/_types/globals.py new file mode 100644 index 00000000000..ab51826c4f7 --- /dev/null +++ b/marimo/_types/globals.py @@ -0,0 +1,15 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Type aliases for cell globals dicts. + +`MutableGlobals` is the concrete `dict` passed through `exec` / +`eval`; `Globals` is the read-only view for consumers that only +inspect the dict (e.g. collecting a cell's defs after execution). +""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, TypeAlias + +Globals: TypeAlias = Mapping[str, Any] +MutableGlobals: TypeAlias = dict[str, Any] diff --git a/tests/_ast/test_app.py b/tests/_ast/test_app.py index c91e4bb8548..064c8da4df8 100644 --- a/tests/_ast/test_app.py +++ b/tests/_ast/test_app.py @@ -1,15 +1,13 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations -import os + import pathlib import subprocess import sys import textwrap from typing import TYPE_CHECKING, Any -from unittest.mock import patch -import click import pytest from marimo._ast.app import ( @@ -28,21 +26,13 @@ SetupRootError, UnparsableError, ) -from marimo._ast.load import load_app from marimo._ast.names import SETUP_CELL_NAME -from marimo._convert.converters import MarimoConvert from marimo._dependencies.dependencies import DependencyManager from marimo._plugins.stateless.flex import vstack +from marimo._runtime.commands import UpdateUIElementCommand from marimo._runtime.context.types import get_context -from marimo._runtime.commands import UpdateUIElementCommand, ExecuteCellCommand -from marimo._schemas.serialization import ( - AppInstantiation, - CellDef, - NotebookSerializationV1, -) from marimo._types.ids import CellId_t -from tests.conftest import ExecReqProvider, MockedKernel -from tests._messaging.mocks import MockStream +from tests.conftest import ExecReqProvider if TYPE_CHECKING: from marimo._runtime.runtime import Kernel @@ -102,7 +92,7 @@ def test_run_with_docstring() -> None: @app.cell def _() -> tuple[object]: - doc = __doc__ # noqa: F821 + doc = __doc__ return (doc,) _, defs = app.run() @@ -115,7 +105,7 @@ def test_run_with_no_docstring() -> None: @app.cell def _() -> tuple[object]: - doc = __doc__ # noqa: F821 + doc = __doc__ return (doc,) _, defs = app.run() @@ -133,7 +123,9 @@ def config() -> tuple[int, float]: return batch_size, learning_rate @app.cell - def process_data(batch_size: int, learning_rate: float) -> tuple[float]: + def process_data( + batch_size: int, learning_rate: float + ) -> tuple[float]: result = batch_size * learning_rate return (result,) @@ -150,7 +142,9 @@ def other_cell() -> tuple[str]: assert defs["message"] == "independent" # Test 2: Run with overridden values - outputs, defs = app.run(defs={"batch_size": 64, "learning_rate": 0.001}) + outputs, defs = app.run( + defs={"batch_size": 64, "learning_rate": 0.001} + ) assert defs["batch_size"] == 64 assert defs["learning_rate"] == 0.001 assert defs["result"] == 64 * 0.001 @@ -203,7 +197,6 @@ def test_run_with_refs_setup_cell_protection() -> None: app = App() with app.setup: - import os setup_var = "from_setup" @app.cell @@ -232,7 +225,6 @@ def test_run_with_undefined_refs_in_setup_cell() -> None: app = App() with app.setup: - import os a = 1 if a > 2: setup_var = "from_setup" @@ -248,7 +240,6 @@ def use_setup(setup_var: str) -> tuple[str]: app.run() assert "setup_var" in str(exc_info.value) - @staticmethod def test_setup() -> None: app = App() @@ -291,7 +282,6 @@ def __() -> tuple[int, int]: assert (defs["y"], defs["z"]) == (1, 2) assert defs["a"] == 2 - @staticmethod def test_cycle() -> None: app = App() @@ -315,11 +305,11 @@ def test_cycle_missing_args_rets() -> None: @app.cell def one() -> None: - x = y # noqa: F841, F821 + x = y # noqa: F821 @app.cell def two() -> None: - y = x # noqa: F841, F821 + y = x # noqa: F821 with pytest.raises(CycleError): app.run() @@ -347,11 +337,11 @@ def test_multiple_definitions_missing_args_rets() -> None: @app.cell def one() -> None: - x = 0 # noqa: F841 + x = 0 @app.cell def two() -> None: - x = 0 # noqa: F841 + x = 0 with pytest.raises(MultipleDefinitionError): app.run() @@ -362,11 +352,11 @@ def test_delete_nonlocal_ok() -> None: @app.cell def one() -> None: - x = 0 # noqa: F841 + x = 0 @app.cell def two() -> None: - del x # noqa: F841, F821 + del x # noqa: F821 # smoke test, no error raised app.run() @@ -426,7 +416,7 @@ def test_resolve_var_not_local_from_nested_scope() -> None: @app.cell def _() -> tuple[str]: - _x = 10 # noqa: F841 + _x = 10 def _f() -> str: _x = "nested" @@ -454,7 +444,7 @@ def _f() -> str: @app.cell def _() -> None: - _x # type: ignore # noqa: F821 + _x # type: ignore return with pytest.raises(NameError) as e: @@ -468,7 +458,7 @@ def test_locals_dont_leak() -> None: @app.cell def _() -> None: - _x = 0 # noqa: F841 + _x = 0 return @app.cell @@ -503,7 +493,7 @@ def test_dunder_rewritten_as_local() -> None: @app.cell def _() -> None: - __ = 1 # noqa: F841 + __ = 1 return @app.cell @@ -648,6 +638,7 @@ def test_run_mo_stop() -> None: @app.cell def _() -> Any: import marimo as mo + return (mo,) @app.cell @@ -672,6 +663,7 @@ def test_run_mo_stop_descendant() -> None: @app.cell def _() -> Any: import marimo as mo + return (mo,) @app.cell @@ -697,6 +689,7 @@ def test_run_mo_stop_descendant_multiple() -> None: @app.cell def _() -> Any: import marimo as mo + return (mo,) @app.cell @@ -711,7 +704,6 @@ def _(mo) -> tuple[int]: y = 0 return (y,) - @app.cell def _(x) -> tuple[int]: x @@ -724,13 +716,42 @@ def _(y) -> tuple[int]: b = 0 return - _, defs = app.run() assert "x" not in defs assert "y" not in defs assert "a" not in defs assert "b" not in defs + @staticmethod + def test_run_mo_stop_records_output() -> None: + # mo.stop's `output=` arg is shown to the user in edit/kernel mode. + # Script mode used to silently drop it; cell_runner has always + # recorded it. The runner-consolidation refactor aligns these. + app = App() + + @app.cell + def first() -> Any: + import marimo as mo + + return (mo,) + + @app.cell + def stop_cell(mo) -> tuple[int]: + mo.stop(True, output="stopped-output-value") + x = 0 + return (x,) + + @app.cell + def descendant(x) -> tuple[int]: + y = x + 1 + return (y,) + + outputs, defs = app.run() + # stop_cell at index 1 records the stop's output; its descendant + # is cancelled and absent from the flattened output tuple. + assert outputs == (None, "stopped-output-value") + assert "x" not in defs + assert "y" not in defs @staticmethod def test_run_mo_stop_async() -> None: @@ -739,6 +760,7 @@ def test_run_mo_stop_async() -> None: @app.cell def _() -> Any: import marimo as mo + return (mo,) @app.cell @@ -763,6 +785,7 @@ def test_run_mo_stop_descendant_async() -> None: @app.cell def _() -> Any: import marimo as mo + return (mo,) @app.cell @@ -781,7 +804,6 @@ async def _(x) -> tuple[int]: assert "x" not in defs assert "y" not in defs - @pytest.mark.skipif( condition=not DependencyManager.matplotlib.has(), reason="requires matplotlib", @@ -955,11 +977,16 @@ def __(): # Public mutable fields should be deep-copied, not shared assert original_impl.config is not cloned_impl.config - assert original_impl.import_workspace is not cloned_impl.import_workspace + assert ( + original_impl.import_workspace is not cloned_impl.import_workspace + ) # Private mutable runtime state fields should also be independent assert original_impl._status is not cloned_impl._status - assert original_impl._run_result_status is not cloned_impl._run_result_status + assert ( + original_impl._run_result_status + is not cloned_impl._run_result_status + ) assert original_impl._stale is not cloned_impl._stale assert original_impl._output is not cloned_impl._output @@ -1030,8 +1057,7 @@ class TestInvalidSetup: @staticmethod def test_initial_setup() -> None: app = App() - app._unparsable_cell(";", - name="setup") + app._unparsable_cell(";", name="setup") assert app._cell_manager.has_cell("setup") assert app._cell_manager.cell_name("setup") == "setup" @@ -1039,22 +1065,21 @@ def test_initial_setup() -> None: @staticmethod def test_not_initial_setup() -> None: app = App() - app._unparsable_cell(";", - name="other") - app._unparsable_cell(";", - name="setup") + app._unparsable_cell(";", name="other") + app._unparsable_cell(";", name="setup") assert not app._cell_manager.has_cell("setup") @staticmethod def test_not_initial_setup_cell() -> None: app = App() + @app.cell def _(): def B() -> float: return 1.0 - app._unparsable_cell(";", - name="setup") + + app._unparsable_cell(";", name="setup") assert not app._cell_manager.has_cell("setup") @@ -1205,7 +1230,9 @@ def __() -> tuple[int]: with pytest.raises(ValueError) as excinfo: await app.embed(defs={"x": mo.ui.slider(1, 10)}) - assert "Substituting UI Elements for variables is not allowed" in str(excinfo.value) + assert "Substituting UI Elements for variables is not allowed" in str( + excinfo.value + ) async def test_app_embed_with_defs_multiple_vars(self) -> None: """Test embed() with defs overriding a cell that defines multiple variables.""" @@ -1532,8 +1559,6 @@ def test_app_not_changed() -> None: with app.setup: app = 1 - - @staticmethod def test_setup_not_exposed() -> None: app = App() @@ -1545,7 +1570,6 @@ def test_setup_not_exposed() -> None: except NameError: x = False - @staticmethod def test_setup_in_memory() -> None: app = App() @@ -1598,11 +1622,8 @@ def test_setup_hide_code() -> None: assert setup_cell is not None assert setup_cell.config.hide_code is False - @staticmethod - async def test_app_embed_preserves_file_path( - app: App - ) -> None: + async def test_app_embed_preserves_file_path(app: App) -> None: with app.setup: from tests._ast.app_data import notebook_filename @@ -1624,7 +1645,6 @@ def _(cloned: AppEmbedResult, filename: str, directory: str) -> None: assert cloned.defs.get("this_is_foo_file").endswith(filename) assert cloned.defs.get("this_is_foo_path").stem == directory - @staticmethod async def test_app_embed_in_kernel( k: Kernel, exec_req: ExecReqProvider @@ -1648,10 +1668,13 @@ async def test_app_embed_in_kernel( filename = "notebook_filename.py" directory = "app_data" assert k.globals["app"].defs.get("this_is_foo_file").endswith(filename) - assert k.globals["cloned"].defs.get("this_is_foo_file").endswith(filename) + assert ( + k.globals["cloned"].defs.get("this_is_foo_file").endswith(filename) + ) assert k.globals["app"].defs.get("this_is_foo_path").stem == directory - assert k.globals["cloned"].defs.get("this_is_foo_path").stem == directory - + assert ( + k.globals["cloned"].defs.get("this_is_foo_path").stem == directory + ) @staticmethod async def test_app_embed_same_cell_in_kernel( @@ -1693,8 +1716,9 @@ async def test_imported_app_has_prefixed_setup_cell( This tests the fix where setup cells get the prefix like other cells. """ - await k.run([ - exec_req.get(""" + await k.run( + [ + exec_req.get(""" # Import in kernel context; the prefix the app gets # depends on whether it was first imported in a kernel context, # so we reload it in case notebook_filename was loaded elsewhere @@ -1705,11 +1729,14 @@ async def test_imported_app_has_prefixed_setup_cell( importlib.reload(mod) app = mod.app """) - ]) + ] + ) assert not k.errors nb_app = k.globals["app"] cell_ids = list(InternalApp(nb_app).cell_manager.cell_ids()) - setup_cell_ids = [cid for cid in cell_ids if cid.endswith(SETUP_CELL_NAME)] + setup_cell_ids = [ + cid for cid in cell_ids if cid.endswith(SETUP_CELL_NAME) + ] assert len(setup_cell_ids) == 1 assert is_external_cell_id(setup_cell_ids[0]) @@ -1780,9 +1807,7 @@ def _(): internal_app = InternalApp(app) cell_id = next(iter(internal_app.cell_manager.cell_ids())) - original_compiled = internal_app.cell_manager._compiled_cells[ - cell_id - ] + original_compiled = internal_app.cell_manager._compiled_cells[cell_id] assert original_compiled is not None internal_app.with_data( @@ -1901,7 +1926,6 @@ def __(x: int) -> tuple[int]: assert not k.errors assert k.globals["overrides"] == {"x": 100} - @pytest.mark.xfail( True, reason="Flaky in CI, can't repro locally", strict=False ) diff --git a/tests/_messaging/test_tracebacks.py b/tests/_messaging/test_tracebacks.py index ddd4fe4350e..00fc5d906d8 100644 --- a/tests/_messaging/test_tracebacks.py +++ b/tests/_messaging/test_tracebacks.py @@ -6,7 +6,6 @@ from marimo._messaging.context import HTTP_REQUEST_CTX, is_code_mode_request from marimo._messaging.tracebacks import ( _highlight_traceback, - _trim_traceback, is_code_highlighting, write_traceback, ) @@ -221,11 +220,3 @@ def test_empty_url(self) -> None: assert is_code_mode_request() is False finally: HTTP_REQUEST_CTX.reset(token) - - def test_trim(self) -> None: - prefix = "Traceback (most recent call last):\n" - head = ' File ".../marimo/_runtime/executor.py", line 139, in execute_cell\n return eval(cell.last_expr, glbls)\n ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n' - rest = ( - ' File ".../__marimo__cell_Hbol_.py", line 2, in \n...\n' - ) - assert _trim_traceback(f"{prefix}{head}{rest}") == f"{prefix}{rest}" diff --git a/tests/_runtime/runner/test_cell_runner.py b/tests/_runtime/runner/test_cell_runner.py index 9da65671c89..4ab2f601517 100644 --- a/tests/_runtime/runner/test_cell_runner.py +++ b/tests/_runtime/runner/test_cell_runner.py @@ -276,3 +276,132 @@ async def test_converging_runs_when_all_branches_trigger( assert "b" in k.globals assert "result" in k.globals assert k.graph.cells["res"].run_result_status == "success" + + +# --- Surface 3: registered plugin Executor runs via Runner ------------------ + + +async def test_runner_dispatches_to_registered_plugin_executor( + execution_kernel: Kernel, + exec_req: ExecReqProvider, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A factory registered against `marimo.cell.executor` is the one + the kernel `Runner` dispatches through.""" + from typing import Any + + from marimo._runtime.executor.evaluator import _EXECUTOR_REGISTRY + + recorded: list[str] = [] + sentinel_output = object() + + class _SentinelExecutor: + name = "sentinel" + + def execute_cell(self, cell: Any, glbls: dict[str, Any]) -> object: + del glbls + recorded.append(cell.cell_id) + return sentinel_output + + async def execute_cell_async( + self, cell: Any, glbls: dict[str, Any] + ) -> object: + del glbls + recorded.append(cell.cell_id) + return sentinel_output + + def factory() -> _SentinelExecutor: + return _SentinelExecutor() + + # Populate the kernel first (uses the real DefaultExecutor — the + # registry isn't patched yet). + k = execution_kernel + await k.run([er := exec_req.get("'hello'; 123")]) + + # Fully isolate the registry: replace both `_plugins` and + # `names` so installed third-party entry points can't shadow the + # sentinel. monkeypatch restores both on teardown. + monkeypatch.setattr(_EXECUTOR_REGISTRY, "_plugins", {"sentinel": factory}) + monkeypatch.setattr(_EXECUTOR_REGISTRY, "names", lambda: ["sentinel"]) + + runner = Runner( + roots=set(k.graph.cells.keys()), + graph=k.graph, + glbls=k.globals, + debugger=k.debugger, + hooks=NotebookCellHooks(), + ) + run_result = await runner.run(er.cell_id) + + assert recorded == [er.cell_id] + assert run_result.output is sentinel_output + + +# --- Surface 4: Runner.interrupted flips on cancellation -------------------- + + +async def test_runner_interrupted_flag_flips_on_sync_marimo_interrupt( + execution_kernel: Kernel, exec_req: ExecReqProvider +) -> None: + """Sync cell body raising `MarimoInterrupt` (== `KeyboardInterrupt`) + surfaces as a bare `MarimoInterrupt` in the run result and flips + `runner.interrupted`.""" + k = execution_kernel + await k.run([er := exec_req.get("raise KeyboardInterrupt")]) + + runner = Runner( + roots=set(k.graph.cells.keys()), + graph=k.graph, + glbls=k.globals, + debugger=k.debugger, + hooks=NotebookCellHooks(), + ) + with capture_stderr(): + await runner.run(er.cell_id) + + assert runner.interrupted is True + + +async def test_runner_interrupted_flag_flips_on_async_cell_cancellation( + execution_kernel: Kernel, + exec_req: ExecReqProvider, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """An async cell cancelled mid-await flips `runner.interrupted`. + + A bare `asyncio.CancelledError` arriving in `RunResult.exception` is + converted to `MarimoInterrupt` by the bare-`CancelledError` branch + of `_finalize_run_result`, which `run()` recognises to flip the + flag. + + Simulates the evaluator output directly (a bare `CancelledError` in + the `RunResult`) so this test is independent of the executor's + coroutine compilation. + """ + import asyncio + + from marimo._runtime.runner.result import RunResult + + k = execution_kernel + await k.run([er := exec_req.get("123")]) + + runner = Runner( + roots=set(k.graph.cells.keys()), + graph=k.graph, + glbls=k.globals, + debugger=k.debugger, + hooks=NotebookCellHooks(), + ) + + async def fake_evaluate(cell, glbls): # type: ignore[no-untyped-def] + del cell, glbls + return RunResult(output=None, exception=asyncio.CancelledError()) + + monkeypatch.setattr( + runner._evaluator, "evaluate_interruptible", fake_evaluate + ) + + with capture_stderr(): + await runner.run(er.cell_id) + + assert runner.interrupted is True diff --git a/tests/_runtime/test_dataflow.py b/tests/_runtime/test_dataflow.py index e220e576821..a6087a183f2 100644 --- a/tests/_runtime/test_dataflow.py +++ b/tests/_runtime/test_dataflow.py @@ -10,6 +10,8 @@ from marimo._ast.visitor import Name, VariableData from marimo._dependencies.dependencies import DependencyManager from marimo._runtime import dataflow +from marimo._runtime.runner import by_kwargs +from marimo._runtime.runner.by_kwargs import _get_ancestors parse_cell = partial(compiler.compile_cell, cell_id="0") @@ -944,7 +946,7 @@ def test_is_disabled() -> None: def test_runner_sync() -> None: - """Test the Runner class for synchronous execution.""" + """Synchronous Cell.run(**kwargs) path. Must work without an event loop.""" graph = dataflow.DirectedGraph() # Create a chain of cells: 0 -> 1 -> 2 @@ -960,18 +962,15 @@ def test_runner_sync() -> None: third_cell = compiler.compile_cell(code, cell_id="2") graph.register_cell("2", third_cell) - # Create a runner - runner = dataflow.Runner(graph) - # Run the last cell - output, defs = runner.run_cell_sync("2", {}) + output, defs = by_kwargs.run_cell_sync(graph, "2", {}) # Check output and definitions assert output == 25 # 10 * 2 + 5 assert defs == {"z": 25} # Run the last cell with substituted values - output, defs = runner.run_cell_sync("2", {"y": 50}) + output, defs = by_kwargs.run_cell_sync(graph, "2", {"y": 50}) # Check output and definitions with substituted value assert output == 55 # 50 + 5 @@ -979,14 +978,14 @@ def test_runner_sync() -> None: # Try to run with an invalid argument try: - runner.run_cell_sync("2", {"invalid": 100}) + by_kwargs.run_cell_sync(graph, "2", {"invalid": 100}) raise AssertionError("Should have raised an exception") except ValueError: pass # Expected def test_runner_ancestors() -> None: - """Test that the Runner correctly identifies ancestors based on refs.""" + """Ancestor pruning based on substituted refs.""" graph = dataflow.DirectedGraph() # Create cells with different refs/defs patterns @@ -1002,19 +1001,16 @@ def test_runner_ancestors() -> None: third_cell = compiler.compile_cell(code, cell_id="2") graph.register_cell("2", third_cell) - # Create a runner - runner = dataflow.Runner(graph) - # Get ancestors of the third cell - ancestors = runner._get_ancestors(graph.cells["2"], {}) + ancestors = _get_ancestors(graph, graph.cells["2"], {}) assert ancestors == {"0", "1"} # When substituting y, only cell 0 should be an ancestor - ancestors = runner._get_ancestors(graph.cells["2"], {"y": 30}) + ancestors = _get_ancestors(graph, graph.cells["2"], {"y": 30}) assert ancestors == {"0"} # When substituting both x and y, there should be no ancestors - ancestors = runner._get_ancestors(graph.cells["2"], {"x": 40, "y": 30}) + ancestors = _get_ancestors(graph, graph.cells["2"], {"x": 40, "y": 30}) assert ancestors == set() diff --git a/tests/_runtime/test_exceptions.py b/tests/_runtime/test_exceptions.py new file mode 100644 index 00000000000..ad68a2fc76b --- /dev/null +++ b/tests/_runtime/test_exceptions.py @@ -0,0 +1,85 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Unit tests for `marimo._runtime.exceptions.unwrap_user_exception`.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +from marimo._runtime.exceptions import ( + MarimoMissingRefError, + MarimoRuntimeException, + unwrap_user_exception, +) + + +def _wrap(cause: BaseException) -> MarimoRuntimeException: + """Build a `MarimoRuntimeException` with `__cause__` set.""" + try: + raise MarimoRuntimeException from cause + except MarimoRuntimeException as exc: + return exc + + +def _graph(definitions: set[str]) -> Any: + """Minimal graph stub with only the attribute `unwrap` reads.""" + return SimpleNamespace(definitions=definitions) + + +def test_unwrap_no_graph_returns_raw_cause() -> None: + cause = ValueError("boom") + wrapped = _wrap(cause) + + assert unwrap_user_exception(wrapped) is cause + + +def test_unwrap_nameerror_without_graph_unchanged() -> None: + """No graph → upgrade never fires, even for NameError.""" + cause = NameError("name 'x' is not defined") + cause.name = "x" # set explicitly; constructor doesn't. + wrapped = _wrap(cause) + + assert unwrap_user_exception(wrapped) is cause + + +def test_unwrap_nameerror_with_graph_upgrades_when_in_definitions() -> None: + cause = NameError("name 'x' is not defined") + cause.name = "x" + wrapped = _wrap(cause) + + unwrapped = unwrap_user_exception(wrapped, graph=_graph({"x"})) + + assert isinstance(unwrapped, MarimoMissingRefError) + assert unwrapped.ref == "x" + assert unwrapped.name_error is cause + + +def test_unwrap_nameerror_with_graph_passthrough_when_not_in_definitions() -> ( + None +): + """`.name` is set but the graph doesn't define it — no upgrade.""" + cause = NameError("name 'x' is not defined") + cause.name = "x" + wrapped = _wrap(cause) + + assert unwrap_user_exception(wrapped, graph=_graph(set())) is cause + + +def test_unwrap_nameerror_with_none_name_returns_raw() -> None: + """`NameError.name is None` (the constructor default) → upgrade + short-circuits via the `if name and …` guard.""" + cause = NameError("name 'x' is not defined") + # Don't set `.name` — leave it as the constructor's default + # (None on most CPython versions). The guard must not upgrade. + assert getattr(cause, "name", None) is None + wrapped = _wrap(cause) + + # Even with `x` in graph.definitions, the guard prevents upgrade. + assert unwrap_user_exception(wrapped, graph=_graph({"x"})) is cause + + +def test_unwrap_no_cause_returns_none() -> None: + """`MarimoRuntimeException` raised without `from …` has no cause.""" + wrapped = MarimoRuntimeException() + + assert unwrap_user_exception(wrapped) is None diff --git a/tests/_runtime/test_executor_evaluator.py b/tests/_runtime/test_executor_evaluator.py index fb24fe5148d..6e05053b917 100644 --- a/tests/_runtime/test_executor_evaluator.py +++ b/tests/_runtime/test_executor_evaluator.py @@ -15,6 +15,8 @@ import asyncio from typing import Any +import pytest + from marimo._runtime.exceptions import MarimoRuntimeException from marimo._runtime.executor import ( DefaultExecutor, @@ -201,6 +203,53 @@ def is_coroutine(self) -> bool: assert isinstance(a.last_run_result.exception, MarimoRuntimeException) +def _cause_traceback_filenames(exc: BaseException) -> list[str]: + cause = exc.__cause__ + assert cause is not None + tb = cause.__traceback__ + files: list[str] = [] + while tb is not None: + files.append(tb.tb_frame.f_code.co_filename) + tb = tb.tb_next + return files + + +def test_default_executor_strips_own_frame_from_cause_sync() -> None: + """`DefaultExecutor.execute_cell` must not leave its own frame on + the cause's `__traceback__` — user-facing tracebacks should begin + at user code (the compiled `` source).""" + + class _FakeCell: + cell_id = "0" + body = compile("raise ValueError('user bomb')", "", "exec") + last_expr = compile("None", "", "eval") + + with pytest.raises(MarimoRuntimeException) as exc_info: + DefaultExecutor().execute_cell(_FakeCell(), {}) # type: ignore[arg-type] + + files = _cause_traceback_filenames(exc_info.value) + assert files, "cause traceback unexpectedly empty" + assert not any("executor/executor.py" in f for f in files), files + assert files[0] == "" + + +async def test_default_executor_strips_own_frame_from_cause_async() -> None: + """Same as the sync variant, for `execute_cell_async`.""" + + class _FakeCell: + cell_id = "0" + body = compile("raise ValueError('user bomb')", "", "exec") + last_expr = compile("None", "", "eval") + + with pytest.raises(MarimoRuntimeException) as exc_info: + await DefaultExecutor().execute_cell_async(_FakeCell(), {}) # type: ignore[arg-type] + + files = _cause_traceback_filenames(exc_info.value) + assert files, "cause traceback unexpectedly empty" + assert not any("executor/executor.py" in f for f in files), files + assert files[0] == "" + + async def test_teardown_runs_for_completed_setups_when_later_setup_raises() -> ( None ): @@ -298,6 +347,158 @@ def get_transitive_references( assert glbls["y"] == pre["y"] +class _StrictGraph: + """`_FakeGraph` for `StrictLifecycle` setup-path tests. + + `transitive_refs` controls what `get_transitive_references` returns + so the test can drive `setup` past sanitization into the + error-construction branch. `defining_cells` maps refs to defining + cell IDs; refs absent from the map raise `KeyError` to exercise + the `unmangle_local` fallback. + """ + + def __init__( + self, + transitive_refs: set[str], + defining_cells: dict[str, list[str]] | None = None, + ) -> None: + self._transitive_refs = transitive_refs + self._defining_cells = defining_cells or {} + + def get_transitive_references( + self, refs: set[str], predicate: Any + ) -> set[str]: + return set(self._transitive_refs) + + def get_defining_cells(self, ref: str) -> list[str]: + return self._defining_cells[ref] + + +class _StrictCell: + def __init__(self, refs: set[str], defs: set[str] | None = None) -> None: + self.cell_id = "c0" + self.refs = refs + self.defs = defs or set() + + +def test_strict_setup_skip_on_undefined_ref() -> None: + """Unresolved ref → `Skip(result=RunResult(output=err, exception=err))` + where `err` is a `MarimoStrictExecutionError` with no blamed cell + (graph has no defining cell and the ref is not a private var).""" + from marimo._messaging.errors import MarimoStrictExecutionError + from marimo._runtime.executor.lifecycles.strict import StrictLifecycle + + lifecycle = StrictLifecycle( + graph=_StrictGraph(transitive_refs={"x"}) # type: ignore[arg-type] + ) + glbls: dict[str, Any] = {"__builtins__": {}} + + skip = lifecycle.setup(_StrictCell(refs={"x"}), glbls) # type: ignore[arg-type] + + assert skip is not None + assert skip.result is not None + err = skip.result.exception + assert isinstance(err, MarimoStrictExecutionError) + assert err.ref == "x" + assert err.blamed_cell is None + assert skip.result.output is err + + +def test_strict_setup_skip_on_ref_before_def() -> None: + """Ref appears in the cell's own `defs` → ref-before-def branch.""" + from marimo._messaging.errors import MarimoStrictExecutionError + from marimo._runtime.executor.lifecycles.strict import StrictLifecycle + + lifecycle = StrictLifecycle( + graph=_StrictGraph(transitive_refs={"x"}) # type: ignore[arg-type] + ) + glbls: dict[str, Any] = {"__builtins__": {}} + + skip = lifecycle.setup( + _StrictCell(refs={"x"}, defs={"x"}), # type: ignore[arg-type] + glbls, + ) + + assert skip is not None + assert skip.result is not None + err = skip.result.exception + assert isinstance(err, MarimoStrictExecutionError) + assert err.ref == "x" + assert err.blamed_cell is None + + +def test_strict_setup_skip_resolves_blamed_cell_via_graph() -> None: + """`get_defining_cells` returns the owning cell → blamed_cell.""" + from marimo._messaging.errors import MarimoStrictExecutionError + from marimo._runtime.executor.lifecycles.strict import StrictLifecycle + + lifecycle = StrictLifecycle( + graph=_StrictGraph( # type: ignore[arg-type] + transitive_refs={"x"}, + defining_cells={"x": ["other"]}, + ) + ) + glbls: dict[str, Any] = {"__builtins__": {}} + + skip = lifecycle.setup(_StrictCell(refs={"x"}), glbls) # type: ignore[arg-type] + + assert skip is not None + assert skip.result is not None + err = skip.result.exception + assert isinstance(err, MarimoStrictExecutionError) + assert err.blamed_cell == "other" + + +def test_strict_setup_skip_falls_back_to_private_var_owner() -> None: + """`KeyError` from the graph → `unmangle_local` resolves the + owning cell for mangled private vars.""" + from marimo._messaging.errors import MarimoStrictExecutionError + from marimo._runtime.executor.lifecycles.strict import StrictLifecycle + + # `_cell_ZZZ_priv` unmangles to (name="_priv", cell="ZZZ"). + private_ref = "_cell_ZZZ_priv" + lifecycle = StrictLifecycle( + graph=_StrictGraph( # type: ignore[arg-type] + transitive_refs={private_ref}, + ) + ) + glbls: dict[str, Any] = {"__builtins__": {}} + + skip = lifecycle.setup( + _StrictCell(refs={private_ref}), # type: ignore[arg-type] + glbls, + ) + + assert skip is not None + assert skip.result is not None + err = skip.result.exception + assert isinstance(err, MarimoStrictExecutionError) + assert err.blamed_cell == "ZZZ" + + +def test_strict_setup_skip_does_not_mutate_globals_or_stash_backup() -> None: + """The Skip early-return must happen before globals are cleared and + before the backup is stashed. `teardown` must then be a no-op.""" + from marimo._runtime.executor.lifecycles.strict import StrictLifecycle + + lifecycle = StrictLifecycle( + graph=_StrictGraph(transitive_refs={"x"}) # type: ignore[arg-type] + ) + glbls: dict[str, Any] = { + "preserve_me": 42, + "__builtins__": {}, + } + pre = dict(glbls) + + skip = lifecycle.setup(_StrictCell(refs={"x"}), glbls) # type: ignore[arg-type] + assert skip is not None + assert glbls == pre, "Skip path must not mutate globals" + assert lifecycle._backups == {}, "Skip path must not stash a backup" + + lifecycle.teardown(_StrictCell(refs={"x"}), glbls, skip.result) # type: ignore[arg-type] + assert glbls == pre, "teardown after Skip must be a no-op" + + def test_execution_lifecycle_protocol_conformance() -> None: """A Protocol-conforming class without inheriting works as a lifecycle.""" @@ -320,3 +521,167 @@ def teardown( # here, not at runtime. lifecycle: ExecutionLifecycle = _MyLifecycle() assert lifecycle.name == "mine" + + +# --- Surface 4: _cancel_on_sigint + evaluate_interruptible ------------------ + + +def _async_body(src: str) -> Any: + """Compile `src` with top-level-await support; returns a code object + whose `co_flags` carry `CO_COROUTINE` so `_is_coroutine` is True.""" + import ast + + return compile(src, "", "exec", flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT) + + +async def test_cancel_on_sigint_installs_and_restores_handler( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """`_cancel_on_sigint` swaps in its own handler on enter and + restores the previously-installed one on exit.""" + import signal + + from marimo._runtime.executor.evaluator import _cancel_on_sigint + + def prior(signum: int, frame: Any) -> None: + del signum, frame + + signal_calls: list[tuple[int, Any]] = [] + + def fake_signal(signum: int, handler: Any) -> Any: + signal_calls.append((signum, handler)) + return prior + + monkeypatch.setattr(signal, "signal", fake_signal) + monkeypatch.setattr(signal, "getsignal", lambda _signum: prior) + + fut: asyncio.Future[Any] = asyncio.Future() + with _cancel_on_sigint(fut): + # On enter: a new handler installed (not the prior). + assert signal_calls, "no signal.signal call recorded on enter" + assert signal_calls[0][0] == signal.SIGINT + assert signal_calls[0][1] is not prior + + # On exit: prior handler restored as the last call. + assert signal_calls[-1] == (signal.SIGINT, prior) + + +async def test_cancel_on_sigint_handler_cancels_future_and_chains_prior( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """The installed handler must cancel the wrapped future and invoke + the previously-installed handler for its side effects.""" + import signal + + from marimo._runtime.executor.evaluator import _cancel_on_sigint + + prior_calls: list[tuple[int, Any]] = [] + + def prior(signum: int, frame: Any) -> None: + prior_calls.append((signum, frame)) + + captured: list[Any] = [] + + def fake_signal(signum: int, handler: Any) -> Any: + captured.append(handler) + return prior + + monkeypatch.setattr(signal, "signal", fake_signal) + monkeypatch.setattr(signal, "getsignal", lambda _signum: prior) + + fut: asyncio.Future[Any] = asyncio.Future() + with _cancel_on_sigint(fut): + marimo_handler = captured[0] + marimo_handler(signal.SIGINT, None) + # Cancellation propagates through done-callbacks asynchronously; + # yield to the loop so they fire. + await asyncio.sleep(0) + + assert fut.cancelled() + assert prior_calls == [(signal.SIGINT, None)] + + +async def test_cancel_on_sigint_swallows_marimo_interrupt_from_prior_handler( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Prior handler raising `MarimoInterrupt` must not escape — the + kernel's sync-mode raise is irrelevant for async cells, where the + halt comes from cancelling the future.""" + import signal + + from marimo._runtime.control_flow import MarimoInterrupt + from marimo._runtime.executor.evaluator import _cancel_on_sigint + + def prior(signum: int, frame: Any) -> None: + raise MarimoInterrupt + + captured: list[Any] = [] + + def fake_signal(signum: int, handler: Any) -> Any: + captured.append(handler) + return prior + + monkeypatch.setattr(signal, "signal", fake_signal) + monkeypatch.setattr(signal, "getsignal", lambda _signum: prior) + + fut: asyncio.Future[Any] = asyncio.Future() + with _cancel_on_sigint(fut): + marimo_handler = captured[0] + # No exception escapes — the wrapper catches MarimoInterrupt + # from the prior handler. + marimo_handler(signal.SIGINT, None) + await asyncio.sleep(0) + assert fut.cancelled() + + +async def test_executor_async_cancellation_propagates_unwrapped() -> None: + """`asyncio.CancelledError` must propagate unwrapped through + `DefaultExecutor.execute_cell_async` — wrapping it as + `MarimoRuntimeException` would mask the cancellation.""" + + class _AsyncCell: + cell_id = "0" + body = _async_body("import asyncio\nawait asyncio.sleep(100)") + last_expr = compile("None", "", "eval") + + def is_coroutine(self) -> bool: + return True + + task = asyncio.create_task( + DefaultExecutor().execute_cell_async(_AsyncCell(), {}) # type: ignore[arg-type] + ) + # Yield so the task enters the awaited sleep before we cancel. + await asyncio.sleep(0) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + +async def test_evaluate_interruptible_no_op_for_sync_cell() -> None: + """Sync cells: `evaluate_interruptible` returns the same shape as a + direct `evaluate()` call. The SIGINT-handler wrap is for async only.""" + + class _SyncCell: + cell_id = "0" + body = compile("x = 1", "", "exec") + last_expr = compile("x", "", "eval") + + def is_coroutine(self) -> bool: + return False + + ev = Evaluator(executor=DefaultExecutor(), lifecycles=[]) + + sync_glbls: dict[str, Any] = {} + interruptible_glbls: dict[str, Any] = {} + + direct = await ev.evaluate(_SyncCell(), sync_glbls) # type: ignore[arg-type] + interruptible = await ev.evaluate_interruptible( + _SyncCell(), # type: ignore[arg-type] + interruptible_glbls, + ) + + assert direct.output == interruptible.output == 1 + assert direct.exception is None + assert interruptible.exception is None + assert direct.accumulated_output == interruptible.accumulated_output diff --git a/tests/_server/test_scratchpad_integration.py b/tests/_server/test_scratchpad_integration.py index b948223856c..86281316f9b 100644 --- a/tests/_server/test_scratchpad_integration.py +++ b/tests/_server/test_scratchpad_integration.py @@ -671,7 +671,7 @@ def test_ctx_create_cell_multiply_defined(session: _Session) -> None: assert lines == snapshot( [ "event: stderr", - 'data: {"data": "Traceback (most recent call last):\\n File \\"/marimo/_runtime/executor.py\\", line N, in execute_cell_async\\n await eval(cell.body, glbls)\\n File \\"\\", line 2, in \\n async with cm.get_context() as ctx:\\n File \\"/marimo/_code_mode/_context.py\\", line N, in __aexit__\\n self._dry_run_compile(ops)\\n File \\"/marimo/_code_mode/_context.py\\", line N, in _dry_run_compile\\n raise RuntimeError(\\nRuntimeError: Multiply-defined names:\\n - \'x\' is already defined in cell \'cell_a\' (cell_a)\\n\\nTo skip validation, use: async with cm.get_context(skip_validation=True) as ctx\\n"}', + 'data: {"data": "Traceback (most recent call last):\\n File \\"\\", line 2, in \\n async with cm.get_context() as ctx:\\n File \\"/marimo/_code_mode/_context.py\\", line N, in __aexit__\\n self._dry_run_compile(ops)\\n File \\"/marimo/_code_mode/_context.py\\", line N, in _dry_run_compile\\n raise RuntimeError(\\nRuntimeError: Multiply-defined names:\\n - \'x\' is already defined in cell \'cell_a\' (cell_a)\\n\\nTo skip validation, use: async with cm.get_context(skip_validation=True) as ctx\\n"}', "", "event: done", 'data: {"success": false, "output": {"mimetype": "text/plain", "data": ""}}',