diff --git a/marimo/_runtime/app/script_runner.py b/marimo/_runtime/app/script_runner.py index d6648ee62cf..b8d1a4a388d 100644 --- a/marimo/_runtime/app/script_runner.py +++ b/marimo/_runtime/app/script_runner.py @@ -19,11 +19,9 @@ from marimo._runtime.exceptions import ( MarimoMissingRefError, MarimoRuntimeException, + unwrap_user_exception, ) -from marimo._runtime.executor import ( - ExecutionConfig, - get_executor, -) +from marimo._runtime.executor import resolve_executor from marimo._runtime.patches import ( create_main_module, extract_docstring_from_header, @@ -67,7 +65,7 @@ def __init__( if app.cell_manager.cell_data_at(cid).cell is not None and not self.app.graph.is_disabled(cid) ) - self._executor = get_executor(ExecutionConfig()) + self._executor = resolve_executor() def _cancel(self, cell_id: CellId_t) -> None: cancelled = { @@ -108,15 +106,24 @@ def _run_synchronous( cell = self.app.graph.cells[cid] with get_context().with_cell_id(cid): try: - output = self._executor.execute_cell( - cell, glbls, self.app.graph - ) + output = self._executor.execute_cell(cell, glbls) outputs[cid] = output except MarimoRuntimeException as e: - unwrapped_exception: BaseException | None = e.__cause__ + unwrapped_exception = unwrap_user_exception( + e, self.app.graph + ) if isinstance(unwrapped_exception, MarimoStopError): self._cancel(cid) + elif isinstance( + unwrapped_exception, MarimoMissingRefError + ): + name_err = unwrapped_exception.name_error + raise ( + name_err + if name_err is not None + else unwrapped_exception + ) from None else: raise finally: @@ -155,14 +162,25 @@ async def _run_asynchronous( with get_context().with_cell_id(cid): try: output = await self._executor.execute_cell_async( - cell, glbls, self.app.graph + cell, glbls ) outputs[cid] = output except MarimoRuntimeException as e: - unwrapped_exception: BaseException | None = e.__cause__ + unwrapped_exception = unwrap_user_exception( + e, self.app.graph + ) if isinstance(unwrapped_exception, MarimoStopError): self._cancel(cid) + elif isinstance( + unwrapped_exception, MarimoMissingRefError + ): + name_err = unwrapped_exception.name_error + raise ( + name_err + if name_err is not None + else unwrapped_exception + ) from None else: raise finally: diff --git a/marimo/_runtime/dataflow/runner.py b/marimo/_runtime/dataflow/runner.py index 0b8bd0f737c..e01a54221e0 100644 --- a/marimo/_runtime/dataflow/runner.py +++ b/marimo/_runtime/dataflow/runner.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any -from marimo._runtime.executor import ExecutionConfig, get_executor +from marimo._runtime.executor import DefaultExecutor if TYPE_CHECKING: from marimo._ast.cell import CellImpl @@ -28,7 +28,7 @@ class Runner: def __init__(self, graph: DirectedGraph) -> None: self._graph = graph - self._executor = get_executor(ExecutionConfig()) + self._executor = DefaultExecutor() @staticmethod def _returns(cell_impl: CellImpl, glbls: dict[str, Any]) -> dict[str, Any]: @@ -99,13 +99,11 @@ async def run_cell_async( glbls: dict[str, Any] = {} for cid in topological_sort(graph, ancestor_ids): - await self._executor.execute_cell_async( - graph.cells[cid], glbls, graph - ) + await self._executor.execute_cell_async(graph.cells[cid], glbls) Runner._substitute_refs(cell_impl, glbls, kwargs) output = await self._executor.execute_cell_async( - graph.cells[cell_impl.cell_id], glbls, graph + graph.cells[cell_impl.cell_id], glbls ) defs = Runner._returns(cell_impl, glbls) return output, defs @@ -142,11 +140,11 @@ def run_cell_sync( glbls: dict[str, Any] = {} for cid in topological_sort(graph, ancestor_ids): - self._executor.execute_cell(graph.cells[cid], glbls, graph) + self._executor.execute_cell(graph.cells[cid], glbls) self._substitute_refs(cell_impl, glbls, kwargs) output = self._executor.execute_cell( - graph.cells[cell_impl.cell_id], glbls, graph + graph.cells[cell_impl.cell_id], glbls ) defs = Runner._returns(cell_impl, glbls) return output, defs diff --git a/marimo/_runtime/exceptions.py b/marimo/_runtime/exceptions.py index 8adf3ba5df9..fcb51f4cea0 100644 --- a/marimo/_runtime/exceptions.py +++ b/marimo/_runtime/exceptions.py @@ -1,6 +1,11 @@ # Copyright 2026 Marimo. All rights reserved. from __future__ import annotations +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from marimo._runtime.dataflow import DirectedGraph + class MarimoRuntimeException(BaseException): """Wrapper for all marimo runtime exceptions.""" @@ -19,3 +24,16 @@ def __init__(self, ref: str, name_error: NameError | None = None) -> None: super().__init__(ref) self.ref = ref self.name_error = name_error + + +def unwrap_user_exception( + exc: MarimoRuntimeException, + graph: DirectedGraph | None = None, +) -> BaseException | None: + """Extract the user exception from a `MarimoRuntimeException`.""" + cause = exc.__cause__ + if graph is not None and isinstance(cause, NameError): + name = getattr(cause, "name", None) + if name and name in graph.definitions: + return MarimoMissingRefError(name, cause) + return cause diff --git a/marimo/_runtime/executor.py b/marimo/_runtime/executor.py deleted file mode 100644 index 776e85201ef..00000000000 --- a/marimo/_runtime/executor.py +++ /dev/null @@ -1,292 +0,0 @@ -# Copyright 2026 Marimo. All rights reserved. -from __future__ import annotations - -import inspect -import re -from abc import ABC, abstractmethod -from copy import deepcopy -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any - -from marimo._ast.cell import CellImpl, _is_coroutine -from marimo._ast.variables import is_mangled_local -from marimo._entrypoints.registry import EntryPointRegistry -from marimo._runtime.copy import ( - CloneError, - ShallowCopy, - ZeroCopy, - shallow_copy, -) -from marimo._runtime.exceptions import ( - MarimoMissingRefError, - MarimoNameError, - MarimoRuntimeException, -) -from marimo._runtime.primitives import ( - CLONE_PRIMITIVES, - build_ref_predicate_for_primitives, - from_unclonable_module, - is_unclonable_type, -) - -if TYPE_CHECKING: - from marimo._runtime.dataflow import DirectedGraph - -_EXECUTOR_REGISTRY = EntryPointRegistry[type["Executor"]]( - "marimo.cell.executor", -) - - -def get_executor( - config: ExecutionConfig, - registry: EntryPointRegistry[type[Executor]] = _EXECUTOR_REGISTRY, -) -> Executor: - """Get a code executor based on the execution configuration.""" - executors = registry.get_all() - - base: Executor = DefaultExecutor() - if config.is_strict: - base = StrictExecutor(base) - - for executor in executors: - base = executor(base) - return base - - -@dataclass -class ExecutionConfig: - """Configuration for cell execution.""" - - is_strict: bool = False - - -def _raise_name_error( - graph: DirectedGraph | None, name_error: NameError -) -> None: - if graph is None: - raise MarimoRuntimeException from name_error - (missing_name,) = re.findall(r"'([^']*)'", str(name_error)) - # Will miss "locals" by default since not in the graph defs. - if missing_name in graph.definitions: - raise MarimoRuntimeException from MarimoMissingRefError( - missing_name, name_error - ) - raise MarimoRuntimeException from name_error - - -class Executor(ABC): - def __init__(self, base: Executor | None = None) -> None: - self.base = base - - @abstractmethod - def execute_cell( - self, - cell: CellImpl, - glbls: dict[str, Any], - graph: DirectedGraph, - ) -> Any: - pass - - @abstractmethod - async def execute_cell_async( - self, - cell: CellImpl, - glbls: dict[str, Any], - graph: DirectedGraph, - ) -> Any: - pass - - -class DefaultExecutor(Executor): - async def execute_cell_async( - self, - cell: CellImpl, - glbls: dict[str, Any], - graph: DirectedGraph | None = None, - ) -> Any: - if cell.body is None: - return None - assert cell.last_expr is not None - try: - if _is_coroutine(cell.body): - await eval(cell.body, glbls) - else: - exec(cell.body, glbls) - - if _is_coroutine(cell.last_expr): - return await eval(cell.last_expr, glbls) - else: - return eval(cell.last_expr, glbls) - except NameError as e: - _raise_name_error(graph, e) - except (BaseException, Exception) as e: - # Raising from a BaseException will fold in the stacktrace prior - # to execution - raise MarimoRuntimeException from e - - def execute_cell( - self, - cell: CellImpl, - glbls: dict[str, Any], - graph: DirectedGraph | None = None, - ) -> Any: - try: - if cell.body is None: - return None - assert cell.last_expr is not None - - exec(cell.body, glbls) - return eval(cell.last_expr, glbls) - except NameError as e: - _raise_name_error(graph, e) - except (BaseException, Exception) as e: - raise MarimoRuntimeException from e - - -class StrictExecutor(Executor): - async def execute_cell_async( - self, - cell: CellImpl, - glbls: dict[str, Any], - graph: DirectedGraph, - ) -> Any: - assert self.base is not None, "Invalid executor composition." - - # Manage globals and references, but refers to the default beyond that. - refs = graph.get_transitive_references( - cell.refs, - predicate=build_ref_predicate_for_primitives( - glbls, CLONE_PRIMITIVES - ), - ) - backup = self._sanitize_inputs(cell, refs, glbls) - try: - response = await self.base.execute_cell_async(cell, glbls, graph) - finally: - # Restore globals from backup and backfill outputs - self._update_outputs(cell, glbls, backup) - return response - - def execute_cell( - self, - cell: CellImpl, - glbls: dict[str, Any], - graph: DirectedGraph, - ) -> Any: - assert self.base is not None, "Invalid executor composition." - - refs = graph.get_transitive_references( - cell.refs, - predicate=build_ref_predicate_for_primitives( - glbls, CLONE_PRIMITIVES - ), - ) - backup = self._sanitize_inputs(cell, refs, glbls) - try: - response = self.base.execute_cell(cell, glbls, graph) - finally: - self._update_outputs(cell, glbls, backup) - return response - - def _sanitize_inputs( - self, - cell: CellImpl, - refs: set[str], - glbls: dict[str, Any], - ) -> dict[str, Any]: - # Some attributes should remain global - lcls = { - key: glbls[key] - for key in [ - "_MicropipFinder", - "_MicropipLoader", - "__builtin__", - "__doc__", - "__file__", - "__marimo__", - "__name__", - "__package__", - "__loader__", - "__spec__", - "input", - ] - if key in glbls - } - - for ref in refs: - if ref in glbls: - if ( - isinstance( - glbls[ref], - (ZeroCopy), - ) - or inspect.ismodule(glbls[ref]) - or inspect.isfunction(glbls[ref]) - or from_unclonable_module(glbls[ref]) - or is_unclonable_type(glbls[ref]) - ): - lcls[ref] = glbls[ref] - elif isinstance(glbls[ref], ShallowCopy): - lcls[ref] = shallow_copy(glbls[ref]) - else: - try: - lcls[ref] = deepcopy(glbls[ref]) - except TypeError as e: - raise CloneError( - f"Could not clone reference `{ref}` of type " - f"{getattr(glbls[ref], '__module__', '')}. " - f"{glbls[ref].__class__.__name__} " - "try wrapping the object in a `zero_copy` " - "call. If this is a common object type, consider " - "making an issue on the marimo GitHub " - "repository to never deepcopy." - ) from e - elif ref not in glbls["__builtins__"]: - if ref in cell.defs: - raise MarimoNameError( - f"name `{ref}` is referenced before definition.", ref - ) - raise MarimoMissingRefError(ref) - - # NOTE: Execution expects the globals dictionary by memory reference, - # so we need to clear it and update it with the sanitized locals, - # returning a backup of the original globals for later restoration. - # This must be performed at the end of the function to ensure valid - # state in case of failure. - backup = {**glbls} - glbls.clear() - glbls.update(lcls) - return backup - - def _update_outputs( - self, - cell: CellImpl, - glbls: dict[str, Any], - backup: dict[str, Any], - ) -> None: - # NOTE: After execution, restore global state and update outputs. - lcls = {**glbls} - glbls.clear() - glbls.update(backup) - - defs = cell.defs - for df in defs: - if df in lcls: - # Overwrite will delete the reference. - # Weak copy holds on with references. - glbls[df] = lcls[df] - # Captures the case where a variable was previously defined by the - # cell but this most recent run did not define it. The value is now - # stale and needs to be flushed. - elif df in glbls: - del glbls[df] - - # Flush all private variables from memory - for df in backup: - if is_mangled_local(df, cell.cell_id): - del glbls[df] - - # Now repopulate all private variables. - for df in lcls: - if is_mangled_local(df, cell.cell_id): - glbls[df] = lcls[df] diff --git a/marimo/_runtime/executor/__init__.py b/marimo/_runtime/executor/__init__.py new file mode 100644 index 00000000000..1428f9a4847 --- /dev/null +++ b/marimo/_runtime/executor/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Cell execution runtime. + +ExecutionLifecycle: Manages global information prior to execution +Executor: Runs the execution +Evaluator: Composes lifecycles and the executor. + +""" + +from __future__ import annotations + +from marimo._runtime.executor.evaluator import ( + _EXECUTOR_REGISTRY, + Evaluator, + resolve_executor, +) +from marimo._runtime.executor.executor import ( + DefaultExecutor, + Executor, +) +from marimo._runtime.executor.lifecycles import ( + ExecutionLifecycle, + Skip, +) +from marimo._runtime.executor.lifecycles.strict import StrictLifecycle + +__all__ = [ + "_EXECUTOR_REGISTRY", + "DefaultExecutor", + "Evaluator", + "ExecutionLifecycle", + "Executor", + "Skip", + "StrictLifecycle", + "resolve_executor", +] diff --git a/marimo/_runtime/executor/evaluator.py b/marimo/_runtime/executor/evaluator.py new file mode 100644 index 00000000000..7a749b3d6b1 --- /dev/null +++ b/marimo/_runtime/executor/evaluator.py @@ -0,0 +1,123 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Evaluator — composes ExecutionLifecycles around an Executor.""" + +from __future__ import annotations + +from dataclasses import replace +from typing import TYPE_CHECKING, Any + +from marimo import _loggers +from marimo._entrypoints.registry import EntryPointRegistry +from marimo._runtime.executor.executor import DefaultExecutor, Executor +from marimo._runtime.executor.lifecycles import ExecutionLifecycle, Skip +from marimo._runtime.runner.result import RunResult + +if TYPE_CHECKING: + from collections.abc import Callable + + from marimo._ast.cell import CellImpl + + +LOGGER = _loggers.marimo_logger() + + +class Evaluator: + """Compose ExecutionLifecycles around an Executor. Owns `evaluate`.""" + + def __init__( + self, + executor: Executor, + lifecycles: list[ExecutionLifecycle] | None = None, + ) -> None: + self.executor = executor + self.lifecycles: list[ExecutionLifecycle] = lifecycles or [] + + async def evaluate( + self, cell: CellImpl, glbls: dict[str, Any] + ) -> RunResult: + """Setup lifecycles, execute, and teardown lifecycles.""" + completed: list[ExecutionLifecycle] = [] + skip: Skip | None = None + result: RunResult | None = None + + try: + for life in self.lifecycles: + decision = life.setup(cell, glbls) + completed.append(life) + if isinstance(decision, Skip): + skip = decision + break + except BaseException as e: + result = RunResult(output=None, exception=e) + + if result is None: + if skip is not None and skip.result is not None: + # Lifecycle supplied a complete RunResult — preserve all + # fields (output, accumulated_output, exception). + result = skip.result + elif skip is not None: + result = RunResult(output=None, exception=None) + else: + try: + value = await self.executor.execute_cell_async(cell, glbls) + result = RunResult(output=value, exception=None) + except BaseException as e: + result = RunResult(output=None, exception=e) + + teardown_exc: BaseException | None = None + for life in reversed(completed): + try: + life.teardown(cell, glbls, result) + except BaseException as e: + if teardown_exc is not None: + LOGGER.error( + "teardown exception overridden by later teardown: %s", + teardown_exc, + ) + teardown_exc = e + + if teardown_exc is not None: + if result.exception is not None: + LOGGER.warning( + "body exception suppressed by teardown raise: %s", + result.exception, + ) + return replace(result, exception=teardown_exc) + return result + + +# Public entry-point registry for plugin-loaded Executors. Registered +# values are **factories** (`Callable[[], Executor]`); the kernel +# calls the factory once to get an instance and hands it to an +# `Evaluator`. +_EXECUTOR_REGISTRY: EntryPointRegistry[Callable[[], Executor]] = ( + EntryPointRegistry("marimo.cell.executor") +) + + +def resolve_executor() -> Executor: + """Return the registered executor factory's product, or `DefaultExecutor`. + + NB. Only one factory is loaded, with others logged for visibility. + """ + names = _EXECUTOR_REGISTRY.names() + if not names: + return DefaultExecutor() + name, *additional = names + if additional: + LOGGER.warning( + "multiple `marimo.cell.executor` factories registered; " + "using %r and ignoring %d other(s)", + name, + len(additional), + ) + try: + return _EXECUTOR_REGISTRY.get(name)() + except Exception as e: + LOGGER.warning( + "marimo.cell.executor factory %r failed to construct: %s; " + "falling back to `DefaultExecutor`.", + name, + e, + ) + return DefaultExecutor() diff --git a/marimo/_runtime/executor/executor.py b/marimo/_runtime/executor/executor.py new file mode 100644 index 00000000000..fa496ed96c9 --- /dev/null +++ b/marimo/_runtime/executor/executor.py @@ -0,0 +1,62 @@ +# Copyright 2026 Marimo. All rights reserved. +"""An Executor executes a single cell's body.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Protocol + +from marimo._ast.cell import _is_coroutine +from marimo._runtime.exceptions import MarimoRuntimeException + +if TYPE_CHECKING: + from marimo._ast.cell import CellImpl + + +class Executor(Protocol): + """Body strategy: how to run a cell's body.""" + + name: str + + def execute_cell(self, cell: CellImpl, glbls: dict[str, Any]) -> Any: ... + + async def execute_cell_async( + self, cell: CellImpl, glbls: dict[str, Any] + ) -> Any: ... + + +class DefaultExecutor: + name = "default" + + def execute_cell(self, cell: CellImpl, glbls: dict[str, Any]) -> Any: + if cell.body is None: + return None + assert cell.last_expr is not None + if _is_coroutine(cell.body) or _is_coroutine(cell.last_expr): + raise RuntimeError( + "A coroutine cell cannot be run synchronously. Use " + "execute_cell_async() instead." + ) + try: + exec(cell.body, glbls) + return eval(cell.last_expr, glbls) + except BaseException as e: + # Raising from BaseException folds in the stack trace prior + # to execution. + raise MarimoRuntimeException from e + + async def execute_cell_async( + self, cell: CellImpl, glbls: dict[str, Any] + ) -> Any: + if cell.body is None: + return None + assert cell.last_expr is not None + try: + if _is_coroutine(cell.body): + await eval(cell.body, glbls) + else: + exec(cell.body, glbls) + if _is_coroutine(cell.last_expr): + return await eval(cell.last_expr, glbls) + return eval(cell.last_expr, glbls) + except BaseException as e: + raise MarimoRuntimeException from e diff --git a/marimo/_runtime/executor/lifecycles/__init__.py b/marimo/_runtime/executor/lifecycles/__init__.py new file mode 100644 index 00000000000..68674e9d1bc --- /dev/null +++ b/marimo/_runtime/executor/lifecycles/__init__.py @@ -0,0 +1,40 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Per-cell setup/teardown lifecycles owned by the Evaluator.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Protocol + +from marimo._runtime.runner.result import RunResult + +if TYPE_CHECKING: + from marimo._ast.cell import CellImpl + + +@dataclass +class Skip: + """Returned from `ExecutionLifecycle.setup` to short-circuit the body. + + `result` is the cell's `RunResult`; lifecycles can use this to + inject a cache hit or a pre-failed result without running the body. + `None` means the lifecycle wants to skip but has no associated run + (output stays `None`, no exception). + """ + + result: RunResult | None = None + + +class ExecutionLifecycle(Protocol): + """Per-cell setup/teardown wrap.""" + + name: str + + def setup(self, cell: CellImpl, glbls: dict[str, Any]) -> Skip | None: ... + + def teardown( + self, + cell: CellImpl, + glbls: dict[str, Any], + run_result: RunResult, + ) -> None: ... diff --git a/marimo/_runtime/executor/lifecycles/strict.py b/marimo/_runtime/executor/lifecycles/strict.py new file mode 100644 index 00000000000..a5621c81d91 --- /dev/null +++ b/marimo/_runtime/executor/lifecycles/strict.py @@ -0,0 +1,173 @@ +# Copyright 2026 Marimo. All rights reserved. +"""StrictLifecycle provides globals sanitization around the body.""" + +from __future__ import annotations + +import inspect +from copy import deepcopy +from typing import TYPE_CHECKING, Any + +from marimo._ast.variables import is_mangled_local, unmangle_local +from marimo._runtime.copy import ( + CloneError, + ShallowCopy, + ZeroCopy, + shallow_copy, +) +from marimo._runtime.executor.lifecycles import Skip +from marimo._runtime.primitives import ( + CLONE_PRIMITIVES, + build_ref_predicate_for_primitives, + from_unclonable_module, + is_unclonable_type, +) +from marimo._runtime.runner.result import RunResult + +if TYPE_CHECKING: + from marimo._ast.cell import CellImpl + from marimo._messaging.errors import MarimoStrictExecutionError + from marimo._runtime.dataflow import DirectedGraph + from marimo._types.ids import CellId_t + + +# Attributes that should remain visible inside a strict-mode cell body +# even when the rest of the globals dict is replaced by the sanitized +# transitive references. +_PRESERVED_GLOBALS: frozenset[str] = frozenset( + { + "_MicropipFinder", + "_MicropipLoader", + "__builtin__", + "__doc__", + "__file__", + "__marimo__", + "__name__", + "__package__", + "__loader__", + "__spec__", + "input", + } +) + + +class StrictLifecycle: + """Sanitize globals before exec; restore them in teardown.""" + + name = "strict" + + def __init__(self, graph: DirectedGraph) -> None: + self._graph = graph + # Per-cell setup→teardown backup. Keyed by cell_id. + self._backups: dict[CellId_t, dict[str, Any]] = {} + + def setup(self, cell: CellImpl, glbls: dict[str, Any]) -> Skip | None: + refs = self._graph.get_transitive_references( + cell.refs, + predicate=build_ref_predicate_for_primitives( + glbls, CLONE_PRIMITIVES + ), + ) + + lcls = {key: glbls[key] for key in _PRESERVED_GLOBALS if key in glbls} + + for ref in refs: + if ref in glbls: + lcls[ref] = self._sanitize_ref(ref, glbls[ref]) + elif ref not in glbls["__builtins__"]: + err = self._build_strict_error(cell, ref) + return Skip(result=RunResult(output=err, exception=err)) + + # Execution expects the globals dictionary by memory reference, + # so clear it and update with the sanitized locals, stashing a + # backup for teardown. + backup = {**glbls} + glbls.clear() + glbls.update(lcls) + self._backups[cell.cell_id] = backup + return None + + def _build_strict_error( + self, cell: CellImpl, ref: str + ) -> MarimoStrictExecutionError: + """Produce the user-facing error for an unresolved ref in setup.""" + from marimo._messaging.errors import MarimoStrictExecutionError + + if ref in cell.defs: + return MarimoStrictExecutionError( + f"name `{ref}` is referenced before definition.", ref, None + ) + blamed_cell: CellId_t | None = None + try: + (blamed_cell, *_) = self._graph.get_defining_cells(ref) + except (KeyError, ValueError): + ref, var_cell_id = unmangle_local(ref) + if var_cell_id: + blamed_cell = var_cell_id + return MarimoStrictExecutionError( + f"marimo was unable to resolve a reference to `{ref}` in cell : ", + ref, + blamed_cell, + ) + + def _sanitize_ref(self, name: str, value: Any) -> Any: + if ( + isinstance(value, ZeroCopy) + or inspect.ismodule(value) + or inspect.isfunction(value) + or from_unclonable_module(value) + or is_unclonable_type(value) + ): + return value + if isinstance(value, ShallowCopy): + return shallow_copy(value) + try: + return deepcopy(value) + except TypeError as e: + raise CloneError( + f"Could not clone reference `{name}` of type " + f"{getattr(value, '__module__', '')}. " + f"{value.__class__.__name__} " + "try wrapping the object in a `zero_copy` " + "call. If this is a common object type, consider " + "making an issue on the marimo GitHub " + "repository to never deepcopy." + ) from e + + def teardown( + self, + cell: CellImpl, + glbls: dict[str, Any], + run_result: RunResult, # noqa: ARG002 + ) -> None: + backup = self._backups.pop(cell.cell_id, None) + if backup is None: + # Setup didn't complete for this cell (raised before stashing + # the backup, or a Skip earlier in the chain meant setup + # never ran). Nothing to restore. + return + + # Restore the pre-execution globals, then re-apply the cell's + # new defs over top. + lcls = {**glbls} + glbls.clear() + glbls.update(backup) + + defs = cell.defs + for df in defs: + if df in lcls: + glbls[df] = lcls[df] + elif df in glbls: + # Previously defined by this cell, not redefined this + # run — stale, flush it. + del glbls[df] + + # Flush all private variables for this cell from the restored + # backup. + for df in backup: + if is_mangled_local(df, cell.cell_id): + del glbls[df] + + # Repopulate this cell's private variables. + for df in lcls: + if is_mangled_local(df, cell.cell_id): + glbls[df] = lcls[df] diff --git a/marimo/_runtime/kernel_request_handlers.py b/marimo/_runtime/kernel_request_handlers.py index d271b75c217..3e5ec6473e6 100644 --- a/marimo/_runtime/kernel_request_handlers.py +++ b/marimo/_runtime/kernel_request_handlers.py @@ -107,8 +107,8 @@ async def _handle_execute_scratchpad( ): await self._kernel.run_scratchpad(request.code) finally: - # Always emit completion so a waiting ``ScratchCellListener`` - # doesn't block forever if ``run_scratchpad`` raises. + # Always emit completion so a waiting `ScratchCellListener` + # doesn't block forever if `run_scratchpad` raises. broadcast_notification( CompletedRunNotification(run_id=request.run_id) ) diff --git a/marimo/_runtime/runner/cell_runner.py b/marimo/_runtime/runner/cell_runner.py index c878a890462..d2c48ce9c3a 100644 --- a/marimo/_runtime/runner/cell_runner.py +++ b/marimo/_runtime/runner/cell_runner.py @@ -8,8 +8,6 @@ import signal import threading import traceback -from collections import deque -from dataclasses import dataclass from pathlib import Path from types import TracebackType from typing import TYPE_CHECKING, Any @@ -22,7 +20,6 @@ from marimo._messaging.errors import ( MarimoExceptionRaisedError, MarimoSQLError, - MarimoStrictExecutionError, UnknownError, ) from marimo._messaging.tracebacks import write_traceback @@ -31,12 +28,14 @@ from marimo._runtime.control_flow import MarimoInterrupt, MarimoStopError from marimo._runtime.exceptions import ( MarimoMissingRefError, - MarimoNameError, MarimoRuntimeException, + unwrap_user_exception, ) from marimo._runtime.executor import ( - ExecutionConfig, - get_executor, + Evaluator, + ExecutionLifecycle, + StrictLifecycle, + resolve_executor, ) from marimo._runtime.marimo_pdb import MarimoPdb from marimo._runtime.runner.hook_context import ( @@ -44,6 +43,8 @@ ExceptionOrError, ExecutionContextManager, ) +from marimo._runtime.runner.result import RunResult +from marimo._runtime.runner.scheduler import SequentialScheduler from marimo._sql.error_utils import ( create_sql_error_from_exception, is_sql_parse_error, @@ -53,6 +54,7 @@ LOGGER = marimo_logger() if TYPE_CHECKING: + from collections import deque from collections.abc import Iterator from marimo._runtime.runner.hooks import NotebookCellHooks @@ -83,22 +85,7 @@ def cell_filename(cell_id: CellId_t) -> str: return f"" -@dataclass -class RunResult: - # Raw output of cell: last expression - output: Any - # Exception raised by cell, if any - # - # TODO(akshayka): Exceptions and "Errors" (most of which are at parse time - # and can't be encountered by the runner) shouldn't be packed into a single - # field. - exception: ExceptionOrError | None - # Accumulated output: via imperative mo.output.append() - accumulated_output: Any = None - - def success(self) -> bool: - """Whether the cell expected successfully""" - return self.exception is None +__all__ = ["RunResult", "Runner", "cell_filename", "should_show_traceback"] def should_show_traceback( @@ -137,9 +124,6 @@ def __init__( self.graph = graph self.debugger = debugger self.excluded_cells = excluded_cells or set() - self._executor = get_executor( - ExecutionConfig(is_strict=execution_type == "strict") - ) self.execution_context = execution_context self._hooks = hooks self.user_config = user_config @@ -155,27 +139,30 @@ def __init__( # so that they can be transitioned out of error if a future # run request repairs the graph self.roots = roots - self.cells_to_run: deque[CellId_t] = deque( - Runner.compute_cells_to_run( - self.graph, - self.roots, - self.excluded_cells, - self.execution_mode, - ) + cells_to_run_list = Runner.compute_cells_to_run( + self.graph, + self.roots, + self.excluded_cells, + self.execution_mode, ) - # tracks cancelled cells: raising cell -> descendants, with O(1) lookup - self.cancelled_cells = CancelledCells() - # whether the runner has been interrupted - self.interrupted = False + self._scheduler = SequentialScheduler(cells_to_run_list, self.graph) + # mapping from cell_id to exception it raised self.exceptions: dict[CellId_t, ExceptionOrError] = {} - # each cell's position in the run queue + # each cell's position in the original run queue self._run_position = { - cell_id: index for index, cell_id in enumerate(self.cells_to_run) + cell_id: index for index, cell_id in enumerate(cells_to_run_list) } + lifecycles: list[ExecutionLifecycle] = [] + if execution_type == "strict": + lifecycles.append(StrictLifecycle(self.graph)) + self._evaluator = Evaluator( + executor=resolve_executor(), lifecycles=lifecycles + ) + @staticmethod def compute_cells_to_run( graph: dataflow.DirectedGraph, @@ -266,24 +253,33 @@ def handle_sigint(*_: Any) -> None: # restore the previous sigint handler signal.signal(signal.SIGINT, save_sigint) + @property + def cells_to_run(self) -> deque[CellId_t]: + return self._scheduler.cells_to_run + + @property + def cancelled_cells(self) -> CancelledCells: + return self._scheduler.cancelled_cells + + @property + def interrupted(self) -> bool: + return self._scheduler.interrupted + + @interrupted.setter + def interrupted(self, value: bool) -> None: + self._scheduler.interrupted = value + def cancel(self, cell_id: CellId_t) -> None: """Mark a cell (and its descendants) as cancelled.""" - descendants = { - cid - for cid in dataflow.transitive_closure(self.graph, {cell_id}) - if cid in self.cells_to_run - } - self.cancelled_cells.add(cell_id, descendants) - for cid in descendants: - self.graph.cells[cid].set_run_result_status("cancelled") + self._scheduler.cancel(cell_id) def cancelled(self, cell_id: CellId_t) -> bool: """Return whether a cell has been cancelled.""" - return cell_id in self.cancelled_cells + return self._scheduler.cancelled(cell_id) def pending(self) -> bool: """Whether there are more cells to run.""" - return not self.interrupted and len(self.cells_to_run) > 0 + return self._scheduler.pending() def _get_run_position(self, cell_id: CellId_t) -> int | None: """Position in the original run queue""" @@ -357,7 +353,7 @@ def resolve_state_updates( def pop_cell(self) -> CellId_t: """Get the next cell to run.""" - return self.cells_to_run.popleft() + return self._scheduler.pop_cell() def _run_result_from_exception( self, @@ -461,65 +457,87 @@ async def run(self, cell_id: CellId_t) -> RunResult: self.debugger._last_traceback = None cell = self.graph.cells[cell_id] - run_result = None + # The Evaluator captures all body/lifecycle exceptions into the + # returned RunResult; cell_id-specific classification + side + # effects are applied below in `_finalize_run_result`. try: if cell.is_coroutine(): return_value_future = asyncio.ensure_future( - self._executor.execute_cell_async( - cell, - self.glbls, - self.graph, - ) + self._evaluator.evaluate(cell, self.glbls) ) if threading.current_thread() == threading.main_thread(): # edit mode: need to handle user interrupts with Runner._cancel_on_sigint(return_value_future): - return_value = await return_value_future + raw_result = await return_value_future else: # run mode: can't use signal.signal, not interruptible # by user anyway. - return_value = await return_value_future + raw_result = await return_value_future else: - return_value = self._executor.execute_cell( - cell, - self.glbls, - self.graph, - ) - run_result = RunResult(output=return_value, exception=None) - except asyncio.exceptions.CancelledError: - # User interrupt - # interrupt the entire runner - # Async cells can only be cancelled via a user interrupt - run_result = RunResult(output=None, exception=MarimoInterrupt()) - # Still provide a general traceback. - tmpio = io.StringIO() - traceback.print_exc(file=tmpio) - tmpio.seek(0) - write_traceback(tmpio.read()) - # Strict mode errors may also raise errors outside of execution. - except MarimoNameError as e: - self.cancel(cell_id) - strict_exception = MarimoStrictExecutionError(str(e), e.ref, None) - run_result = RunResult( - output=strict_exception, exception=strict_exception + raw_result = await self._evaluator.evaluate(cell, self.glbls) + run_result = self._finalize_run_result(raw_result, cell_id) + except BaseException: + # Defensive: an unexpected escape from the Evaluator or a bug + # in `_finalize_run_result` would otherwise tear down the + # runner loop. Degrade gracefully with an empty RunResult. + LOGGER.error( + """marimo encountered an internal error. + + marimo finished executing a cell, but did not produce + a run result. + + Please copy this message and paste it in a GitHub issue: + + https://github.com/marimo-team/marimo/issues + + Any additional context of what caused this error, such + as sample code to reproduce, will help us debug. + """ ) - except MarimoMissingRefError as e: - # In strict mode, marimo refuses to evaluate a cell if there are - # missing definitions. Since the cell hasn't run, this is a pre - # check error, but still mark descendants as cancelled. + run_result = RunResult(output=None, exception=None) + + # Mark as interrupted if the cell raised a MarimoInterrupt + # Set here since failed async can also trigger an Interrupt. + if isinstance(run_result.exception, MarimoInterrupt): + self.interrupted = True + + self._update_debugger_state(run_result, cell_id) + + if run_result.exception is not None: + self.exceptions[cell_id] = run_result.exception + + return run_result + + def _finalize_run_result( + self, raw_result: RunResult, cell_id: CellId_t + ) -> RunResult: + """Classify the Evaluator's RunResult and apply Runner side effects.""" + exc = raw_result.exception + if exc is None: + return raw_result + if not isinstance(exc, BaseException): + # No exception to handle. Cancel descendants and surface the payload + # as-is. self.cancel(cell_id) - ref, blamed_cell = self._get_blamed_cell(e) - name_output = MarimoStrictExecutionError( - "marimo was unable to resolve " - f"a reference to `{ref}` in cell : ", - ref, - blamed_cell, + return raw_result + + if isinstance(exc, asyncio.exceptions.CancelledError): + # User interrupt — async cells can only be cancelled via SIGINT. + # Surface as MarimoInterrupt so `run` flips `self.interrupted`. + tmpio = io.StringIO() + traceback.print_exception( + type(exc), exc, exc.__traceback__, file=tmpio ) - run_result = RunResult(output=name_output, exception=name_output) + tmpio.seek(0) + write_traceback(tmpio.read()) + return RunResult(output=None, exception=MarimoInterrupt()) + # Should cover all cell runtime exceptions. - except MarimoRuntimeException as e: - output: Any = None - unwrapped_exception: BaseException | None = e.__cause__ + if isinstance(exc, MarimoRuntimeException): + # Unwrap the user exception and upgrade a raw NameError to + # MarimoMissingRefError when the missing name is defined + # elsewhere in the graph. + unwrapped_exception = unwrap_user_exception(exc, self.graph) # Interrupts are sometimes sent multiple times; in particular, # it appears that polars forwards interrupts, so interrupting @@ -531,7 +549,7 @@ async def run(self, cell_id: CellId_t) -> RunResult: try: run_result, unwrapped_exception = ( self._run_result_from_exception( - output, unwrapped_exception, cell_id + None, unwrapped_exception, cell_id ) ) except KeyboardInterrupt: @@ -541,8 +559,8 @@ async def run(self, cell_id: CellId_t) -> RunResult: # Exceptions trigger cancellation of descendants. # - # TODO(akshayka): Another interrupt will end up interrupting - # this call as well, so this should be lifted out of `run`. + # TODO(akshayka): A SIGINT during cancel() can interrupt this + # call, so this should be lifted to a non-interruptible path. self.cancel(cell_id) if should_show_traceback(run_result.exception): @@ -567,81 +585,54 @@ async def run(self, cell_id: CellId_t) -> RunResult: ) tmpio.seek(0) write_traceback(tmpio.read()) - except BaseException as e: - # Check that MarimoRuntimeException has't already handled the - # error, since exceptions fall through except blocks. - # If not, then this is an unexpected error. - if not isinstance(e, MarimoRuntimeException): - LOGGER.error(f"Unexpected error type: {e}") - self.cancel(cell_id) - unknown_error = UnknownError(f"{e}") - run_result = RunResult(output=None, exception=unknown_error) - tmpio = io.StringIO() - traceback.print_exc(file=tmpio) - tmpio.seek(0) - write_traceback(tmpio.read()) - finally: - # TODO(akshayka): some of this logic should be lifted out - # of `run`, (in particular to where execution context is not set) - # so that it is not interruptible - if run_result is None: - LOGGER.error( - """marimo encountered an internal error. - - marimo finished executing a cell, but did not produce - a run result. - - Please copy this message and paste it in a GitHub issue: - - https://github.com/marimo-team/marimo/issues - - Any additional context of what caused this error, such - as sample code to reproduce, will help us debug. - """ - ) - run_result = RunResult(output=None, exception=None) - - # Mark as interrupted if the cell raised a MarimoInterrupt - # Set here since failed async can also trigger an Interrupt. - if isinstance(run_result.exception, MarimoInterrupt): - self.interrupted = True - - # if a debugger is active, force it to skip past marimo code. - try: - # Bdb defines the botframe attribute and sets it to non-None - # when it starts up - if self.debugger is not None: - if ( - hasattr(self.debugger, "botframe") - and self.debugger.botframe is not None - ): - self.debugger.set_continue() - # Hold on to this information for debugging postmortem etc. - if run_result.exception is not None and hasattr( - run_result.exception, "__traceback__" - ): - tb = run_result.exception.__traceback__ - if isinstance(tb, TracebackType): - self.debugger._last_traceback = tb - self.debugger._last_tracebacks[cell_id] = tb - except Exception as debugger_error: - # This has never been hit, but just in case -- don't want - # to crash the kernel. - LOGGER.error( - """Internal marimo error. Please copy this message and - paste it in a GitHub issue: - - https://github.com/marimo-team/marimo/issues - - An exception raised attempting to continue debugger (%s). - """, - str(debugger_error), - ) - - if run_result.exception is not None: - self.exceptions[cell_id] = run_result.exception - - return run_result + return run_result + + # Anything else escaping the Evaluator is unexpected. + LOGGER.error(f"Unexpected error type: {exc}") + self.cancel(cell_id) + tmpio = io.StringIO() + traceback.print_exception( + type(exc), exc, exc.__traceback__, file=tmpio + ) + tmpio.seek(0) + write_traceback(tmpio.read()) + return RunResult(output=None, exception=UnknownError(f"{exc}")) + + def _update_debugger_state( + self, run_result: RunResult, cell_id: CellId_t + ) -> None: + """Skip marimo frames in the debugger and stash the cell's traceback.""" + # if a debugger is active, force it to skip past marimo code. + try: + # Bdb defines the botframe attribute and sets it to non-None + # when it starts up + if self.debugger is not None: + if ( + hasattr(self.debugger, "botframe") + and self.debugger.botframe is not None + ): + self.debugger.set_continue() + # Hold on to this information for debugging postmortem etc. + if run_result.exception is not None and hasattr( + run_result.exception, "__traceback__" + ): + tb = run_result.exception.__traceback__ + if isinstance(tb, TracebackType): + self.debugger._last_traceback = tb + self.debugger._last_tracebacks[cell_id] = tb + except Exception as debugger_error: + # This has never been hit, but just in case -- don't want + # to crash the kernel. + LOGGER.error( + """Internal marimo error. Please copy this message and + paste it in a GitHub issue: + + https://github.com/marimo-team/marimo/issues + + An exception raised attempting to continue debugger (%s). + """, + str(debugger_error), + ) def _get_blamed_cell( self, e: MarimoMissingRefError diff --git a/marimo/_runtime/runner/result.py b/marimo/_runtime/runner/result.py new file mode 100644 index 00000000000..0be0f17a3b3 --- /dev/null +++ b/marimo/_runtime/runner/result.py @@ -0,0 +1,28 @@ +# Copyright 2026 Marimo. All rights reserved. +"""The value type a cell produces when it runs.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from marimo._runtime.runner.hook_context import ExceptionOrError + + +@dataclass +class RunResult: + # Raw output of cell: last expression + output: Any + # Exception raised by cell, if any + # + # TODO(akshayka): Exceptions and "Errors" (most of which are at parse time + # and can't be encountered by the runner) shouldn't be packed into a single + # field. + exception: ExceptionOrError | None + # Accumulated output: via imperative mo.output.append() + accumulated_output: Any = None + + def success(self) -> bool: + """Whether the cell executed successfully""" + return self.exception is None diff --git a/marimo/_runtime/runner/scheduler.py b/marimo/_runtime/runner/scheduler.py new file mode 100644 index 00000000000..6205d2670dc --- /dev/null +++ b/marimo/_runtime/runner/scheduler.py @@ -0,0 +1,89 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Scheduler owns the cell queue and cancellation state""" + +from __future__ import annotations + +from collections import deque +from typing import TYPE_CHECKING, Protocol + +from marimo._runtime import dataflow +from marimo._runtime.runner.hook_context import CancelledCells + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Sequence + + from marimo._runtime.dataflow import DirectedGraph + from marimo._types.ids import CellId_t + + +class Scheduler(Protocol): + """Cell queue + cancellation. Surface for future scheduler types.""" + + def pending(self) -> bool: ... + def pop_cell(self) -> CellId_t: ... + def cancel(self, cell_id: CellId_t) -> None: ... + def cancelled(self, cell_id: CellId_t) -> bool: ... + def batch( + self, cell_ids: Iterable[CellId_t] + ) -> Iterator[list[CellId_t]]: ... + + +class SequentialScheduler: + """Single-threaded FIFO queue + cancellation.""" + + def __init__( + self, + cells_to_run: Sequence[CellId_t], + graph: DirectedGraph, + ) -> None: + self._cells_to_run: deque[CellId_t] = deque(cells_to_run) + self._cancelled = CancelledCells() + self._graph = graph + self._interrupted = False + + def pending(self) -> bool: + return not self._interrupted and len(self._cells_to_run) > 0 + + def pop_cell(self) -> CellId_t: + return self._cells_to_run.popleft() + + def batch(self, cell_ids: Iterable[CellId_t]) -> Iterator[list[CellId_t]]: + """Yield batches of cells to execute. + + Sequential default: one cell per batch. + """ + self._cells_to_run.clear() + self._cells_to_run.extend(cell_ids) + while self._cells_to_run and not self._interrupted: + yield [self._cells_to_run.popleft()] + + def cancel(self, cell_id: CellId_t) -> None: + """Mark a cell and its descendants as cancelled.""" + descendants = { + cid + for cid in dataflow.transitive_closure(self._graph, {cell_id}) + if cid in self._cells_to_run + } + self._cancelled.add(cell_id, descendants) + for cid in descendants: + self._graph.cells[cid].set_run_result_status("cancelled") + + def cancelled(self, cell_id: CellId_t) -> bool: + return cell_id in self._cancelled + + @property + def interrupted(self) -> bool: + return self._interrupted + + @interrupted.setter + def interrupted(self, value: bool) -> None: + self._interrupted = value + + @property + def cancelled_cells(self) -> CancelledCells: + return self._cancelled + + @property + def cells_to_run(self) -> deque[CellId_t]: + """The live queue. Mutates as cells are popped.""" + return self._cells_to_run diff --git a/marimo/_runtime/runtime.py b/marimo/_runtime/runtime.py index 3fc225da51e..9baef25d88e 100644 --- a/marimo/_runtime/runtime.py +++ b/marimo/_runtime/runtime.py @@ -479,8 +479,6 @@ def __init__( # timestamp, to save the user from having to spam the interrupt button self.last_interrupt_timestamp: float | None = None - # Named attributes exist because internal kernel paths (run hooks, - # script metadata) and tests reach into specific callbacks directly. self.secrets_callbacks = SecretsCallbacks(self) self.datasets_callbacks = DatasetCallbacks(self) self.packages_callbacks = PackagesCallbacks(self) diff --git a/tests/_runtime/test_executor_evaluator.py b/tests/_runtime/test_executor_evaluator.py new file mode 100644 index 00000000000..fb24fe5148d --- /dev/null +++ b/tests/_runtime/test_executor_evaluator.py @@ -0,0 +1,322 @@ +# Copyright 2026 Marimo. All rights reserved. +# Stub classes here conform to the ExecutionLifecycle / Executor +# Protocols, so they take `cell` / `glbls` even when the test body +# doesn't use them. +# ruff: noqa: ARG001, ARG002 +"""Tests for the Evaluator + ExecutionLifecycle composition. + +Covers setup chain order, Skip termination, teardown reverse order, +teardown visibility of body exceptions, teardown-wins semantics on +double raise, and KeyboardInterrupt propagation through teardown. +""" + +from __future__ import annotations + +import asyncio +from typing import Any + +from marimo._runtime.exceptions import MarimoRuntimeException +from marimo._runtime.executor import ( + DefaultExecutor, + Evaluator, + ExecutionLifecycle, + Skip, +) +from marimo._runtime.runner.result import RunResult + + +class _Recorder: + """Lifecycle that records setup/teardown calls into a shared log.""" + + def __init__( + self, + log: list[str], + tag: str, + skip: Skip | None = None, + setup_raises: BaseException | None = None, + teardown_raises: BaseException | None = None, + ) -> None: + self.name = f"recorder-{tag}" + self._log = log + self._tag = tag + self._skip = skip + self._setup_raises = setup_raises + self._teardown_raises = teardown_raises + self.last_run_result: Any = None + + def setup(self, cell: Any, glbls: dict[str, Any]) -> Skip | None: + self._log.append(f"setup:{self._tag}") + if self._setup_raises is not None: + raise self._setup_raises + return self._skip + + def teardown( + self, cell: Any, glbls: dict[str, Any], run_result: Any + ) -> None: + self._log.append(f"teardown:{self._tag}") + self.last_run_result = run_result + if self._teardown_raises is not None: + raise self._teardown_raises + + +class _StubExecutor: + """Executor that runs a caller-provided body, no exec/eval.""" + + name = "stub" + + def __init__(self, body: Any) -> None: + self._body = body + + def execute_cell(self, cell: Any, glbls: dict[str, Any]) -> Any: + return self._body(cell, glbls) + + async def execute_cell_async( + self, cell: Any, glbls: dict[str, Any] + ) -> Any: + result = self._body(cell, glbls) + if asyncio.iscoroutine(result): + return await result + return result + + +async def test_skip_terminates_setup_chain_but_runs_completed_teardowns() -> ( + None +): + log: list[str] = [] + a = _Recorder( + log, "A", skip=Skip(result=RunResult(output=42, exception=None)) + ) + b = _Recorder(log, "B") + + body_ran = [False] + + def body(cell: Any, glbls: dict[str, Any]) -> Any: + body_ran[0] = True + return "should-not-see-this" + + ev = Evaluator(executor=_StubExecutor(body), lifecycles=[a, b]) + result = await ev.evaluate(cell=None, glbls={}) + + assert result.output == 42 + assert result.exception is None + assert body_ran[0] is False + # A setup ran, A teardown ran. B setup did NOT run, B teardown did + # NOT run. + assert log == ["setup:A", "teardown:A"] + + +async def test_skip_result_preserves_accumulated_output() -> None: + """`Skip(result=RunResult(...))` threads the entire RunResult + through teardown — `output`, `exception`, and + `accumulated_output` all survive, including any future fields + added to `RunResult`.""" + log: list[str] = [] + skip_result = RunResult( + output="cached", exception=None, accumulated_output="streamed" + ) + a = _Recorder(log, "A", skip=Skip(result=skip_result)) + + ev = Evaluator(executor=_StubExecutor(lambda *_: "unused"), lifecycles=[a]) + result = await ev.evaluate(cell=None, glbls={}) + + assert result.output == "cached" + assert result.accumulated_output == "streamed" + assert result.exception is None + # Teardown saw the same RunResult that came back out. + assert a.last_run_result is result + + +async def test_teardowns_fire_in_reverse_order_on_success() -> None: + log: list[str] = [] + a = _Recorder(log, "A") + b = _Recorder(log, "B") + c = _Recorder(log, "C") + + ev = Evaluator( + executor=_StubExecutor(lambda *_: "ok"), + lifecycles=[a, b, c], + ) + result = await ev.evaluate(cell=None, glbls={}) + + assert result.output == "ok" + assert result.exception is None + assert log == [ + "setup:A", + "setup:B", + "setup:C", + "teardown:C", + "teardown:B", + "teardown:A", + ] + + +async def test_teardown_sees_body_exception_via_run_result() -> None: + log: list[str] = [] + a = _Recorder(log, "A") + + def boom(cell: Any, glbls: dict[str, Any]) -> Any: + raise ValueError("body bomb") + + ev = Evaluator(executor=_StubExecutor(boom), lifecycles=[a]) + # The _StubExecutor doesn't wrap user exceptions; the body's + # ValueError lands directly in result.exception, and the teardown + # sees that same exception via run_result. + result = await ev.evaluate(cell=None, glbls={}) + + assert isinstance(result.exception, ValueError) + assert str(result.exception) == "body bomb" + assert a.last_run_result is not None + assert isinstance(a.last_run_result.exception, ValueError) + + +async def test_default_executor_wraps_user_exception_in_marimo_runtime() -> ( + None +): + """DefaultExecutor turns user exceptions into MarimoRuntimeException + with the user exception as __cause__. The teardown sees the wrapped + form, and the returned RunResult carries it as its exception.""" + from marimo._ast.cell import CellImpl + + log: list[str] = [] + a = _Recorder(log, "A") + + body_src = "raise ValueError('user bomb')" + + class _FakeCell: + cell_id = "0" + body = compile(body_src, "", "exec") + last_expr = compile("None", "", "eval") + + def is_coroutine(self) -> bool: + return False + + del CellImpl # silence unused-import + ev = Evaluator(executor=DefaultExecutor(), lifecycles=[a]) + result = await ev.evaluate(_FakeCell(), {}) # type: ignore[arg-type] + + assert isinstance(result.exception, MarimoRuntimeException) + assert isinstance(result.exception.__cause__, ValueError) + assert a.last_run_result is not None + # Teardown saw the wrapped exception, not the raw ValueError. + assert isinstance(a.last_run_result.exception, MarimoRuntimeException) + + +async def test_teardown_runs_for_completed_setups_when_later_setup_raises() -> ( + None +): + log: list[str] = [] + a = _Recorder(log, "A") + b = _Recorder(log, "B", setup_raises=RuntimeError("setup-B raised")) + c = _Recorder(log, "C") # never reached + + ev = Evaluator( + executor=_StubExecutor(lambda *_: "ok"), + lifecycles=[a, b, c], + ) + result = await ev.evaluate(cell=None, glbls={}) + + assert isinstance(result.exception, RuntimeError) + assert str(result.exception) == "setup-B raised" + # A.setup ran (completed), B.setup ran and raised, C.setup did not + # run. Teardowns run only for lifecycles whose setup *completed* + # without raising — so only A. B is not teardowned because its + # state was never established. + assert log == [ + "setup:A", + "setup:B", + "teardown:A", + ] + + +async def test_teardown_wins_on_double_raise() -> None: + log: list[str] = [] + a = _Recorder(log, "A", teardown_raises=RuntimeError("teardown wins")) + + def body(cell: Any, glbls: dict[str, Any]) -> Any: + raise ValueError("body loses") + + ev = Evaluator(executor=_StubExecutor(body), lifecycles=[a]) + result = await ev.evaluate(cell=None, glbls={}) + + # Teardown exception replaces body exception in the final RunResult. + assert isinstance(result.exception, RuntimeError) + assert str(result.exception) == "teardown wins" + + +async def test_keyboard_interrupt_captured_into_run_result() -> None: + log: list[str] = [] + a = _Recorder(log, "A") + + def body(cell: Any, glbls: dict[str, Any]) -> Any: + raise KeyboardInterrupt + + ev = Evaluator(executor=_StubExecutor(body), lifecycles=[a]) + result = await ev.evaluate(cell=None, glbls={}) + + # Teardown ran (state still cleaned up) even though body raised + # BaseException, and the interrupt is captured in the RunResult + # rather than propagating out of evaluate(). + assert log == ["setup:A", "teardown:A"] + assert isinstance(result.exception, KeyboardInterrupt) + assert isinstance(a.last_run_result.exception, KeyboardInterrupt) + + +def test_strict_lifecycle_round_trip() -> None: + """Globals restored to pre-state after StrictLifecycle setup + + teardown.""" + from marimo._runtime.executor.lifecycles.strict import StrictLifecycle + + class _FakeCell: + cell_id = "c0" + refs: set[str] = set() + defs: set[str] = set() + + class _FakeGraph: + def get_transitive_references( + self, refs: set[str], predicate: Any + ) -> set[str]: + return set() + + lifecycle = StrictLifecycle(graph=_FakeGraph()) # type: ignore[arg-type] + glbls: dict[str, Any] = { + "x": 1, + "y": [1, 2, 3], + "__builtins__": __builtins__, + } + pre = {k: v for k, v in glbls.items()} + + skip = lifecycle.setup(_FakeCell(), glbls) # type: ignore[arg-type] + assert skip is None + + # During setup, glbls should be the sanitized scope (subset). + assert "x" not in glbls # No refs declared → x is not in scope. + + lifecycle.teardown(_FakeCell(), glbls, run_result=None) # type: ignore[arg-type] + + # Globals restored — same values for unchanged keys. + assert glbls["x"] == pre["x"] + assert glbls["y"] == pre["y"] + + +def test_execution_lifecycle_protocol_conformance() -> None: + """A Protocol-conforming class without inheriting works as a + lifecycle.""" + log: list[str] = [] + + class _MyLifecycle: + name = "mine" + + def setup(self, cell: Any, glbls: dict[str, Any]) -> Skip | None: + log.append("setup") + return None + + def teardown( + self, cell: Any, glbls: dict[str, Any], run_result: Any + ) -> None: + log.append("teardown") + + # Static type check via assignment to a ExecutionLifecycle-typed + # variable. If the Protocol is misshaped, mypy/pyright complains + # here, not at runtime. + lifecycle: ExecutionLifecycle = _MyLifecycle() + assert lifecycle.name == "mine" diff --git a/tests/_runtime/test_scheduler.py b/tests/_runtime/test_scheduler.py new file mode 100644 index 00000000000..58a0a517ea6 --- /dev/null +++ b/tests/_runtime/test_scheduler.py @@ -0,0 +1,82 @@ +# Copyright 2026 Marimo. All rights reserved. +"""Queue + cancellation invariants for SequentialScheduler.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import MagicMock + +from marimo._runtime.runner.scheduler import SequentialScheduler +from marimo._types.ids import CellId_t + +if TYPE_CHECKING: + import pytest + + +def _empty_graph() -> MagicMock: + """A graph whose transitive_closure returns just the input cell.""" + g = MagicMock() + g.cells = {} + return g + + +def test_pending_and_pop_cell_fifo() -> None: + cells = [CellId_t("a"), CellId_t("b"), CellId_t("c")] + sched = SequentialScheduler(cells, graph=_empty_graph()) + + assert sched.pending() is True + assert sched.pop_cell() == "a" + assert sched.pop_cell() == "b" + assert sched.pop_cell() == "c" + assert sched.pending() is False + + +def test_interrupted_blocks_pending() -> None: + sched = SequentialScheduler([CellId_t("a")], graph=_empty_graph()) + + assert sched.pending() is True + sched.interrupted = True + assert sched.pending() is False + + +def test_cancel_marks_cancelled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Mock graph: transitive_closure returns just the cell itself, no + # descendants. Cell registered in graph.cells so set_run_result_status + # has a target. + g = MagicMock() + cid = CellId_t("a") + cell_mock = MagicMock() + g.cells = {cid: cell_mock} + + def fake_closure(graph: object, roots: set[CellId_t]) -> set[CellId_t]: + del graph + return set(roots) + + monkeypatch.setattr( + "marimo._runtime.dataflow.transitive_closure", fake_closure + ) + sched = SequentialScheduler([cid], graph=g) + assert sched.cancelled(cid) is False + sched.cancel(cid) + assert sched.cancelled(cid) is True + cell_mock.set_run_result_status.assert_called_with("cancelled") + + +def test_batch_yields_singletons() -> None: + sched = SequentialScheduler([], graph=_empty_graph()) + cells = [CellId_t("a"), CellId_t("b"), CellId_t("c")] + batches = list(sched.batch(cells)) + assert batches == [["a"], ["b"], ["c"]] + + +def test_batch_respects_interrupt() -> None: + sched = SequentialScheduler([], graph=_empty_graph()) + cells = [CellId_t("a"), CellId_t("b"), CellId_t("c")] + iterator = sched.batch(cells) + assert next(iterator) == ["a"] + sched.interrupted = True + # Generator stops once interrupted is set. + remaining = list(iterator) + assert remaining == [] diff --git a/tests/test_entrypoints.py b/tests/test_entrypoints.py index c39a29748a2..135301cc57d 100644 --- a/tests/test_entrypoints.py +++ b/tests/test_entrypoints.py @@ -1,12 +1,19 @@ +from __future__ import annotations + import os -from typing import cast +from typing import TYPE_CHECKING, Any, cast from unittest.mock import MagicMock, patch import pytest from marimo._entrypoints.ids import KnownEntryPoint from marimo._entrypoints.registry import EntryPointRegistry, get_entry_points -from marimo._runtime.executor import ExecutionConfig, Executor, get_executor +from marimo._runtime.executor import Executor + +if TYPE_CHECKING: + from collections.abc import Callable + + from marimo._ast.cell import CellImpl class TestEntryPointRegistry: @@ -166,35 +173,77 @@ def test_get_all_with_entry_points( assert set(result) == {"value1", "ep_value1", "ep_value2"} -class CustomExecutor(Executor): +class CustomExecutor: + """Protocol-conforming Executor (no ABC inheritance).""" + + name = "custom" + def execute_cell( self, - cell: str, - glbls: dict[str, str], - graph: str, - ) -> str: - return f"Executed {cell} with {glbls} in {graph}" + cell: CellImpl, + glbls: dict[str, Any], + ) -> Any: + return f"Executed {cell} with {glbls}" async def execute_cell_async( self, - cell: str, - glbls: dict[str, str], - graph: str, - ) -> str: - return f"Executed {cell} with {glbls} in {graph}" + cell: CellImpl, + glbls: dict[str, Any], + ) -> Any: + return f"Executed {cell} with {glbls}" -class TestExecutorEntryPoint: - @pytest.fixture - def registry(self) -> EntryPointRegistry[Executor]: - reg = EntryPointRegistry[Executor]("marimo.cell.executor") - reg.register("custom", CustomExecutor) - return reg +def _custom_executor_factory() -> Executor: + return CustomExecutor() - def test_get_entry_points_modern( - self, registry: EntryPointRegistry[Executor] - ) -> None: - executor = get_executor( - ExecutionConfig(is_strict=False), registry=registry + +class TestExecutorEntryPoint: + def test_factory_registers_and_resolves(self) -> None: + # Registry holds factories (Callable[[], Executor]); the kernel + # calls the factory once to get an instance. + reg: EntryPointRegistry[Callable[[], Executor]] = EntryPointRegistry( + "marimo.cell.executor" ) + reg.register("custom", _custom_executor_factory) + + factory = reg.get("custom") + executor = factory() assert isinstance(executor, CustomExecutor) + assert executor.execute_cell("c", {"x": "1"}) == ( # type: ignore[arg-type] + "Executed c with {'x': '1'}" + ) + + def test_resolve_executor_only_loads_first_factory(self) -> None: + """`resolve_executor` must not import factories beyond the first. + + A broken or slow third-party plugin can't take down the kernel + if it never gets loaded. + """ + from marimo._runtime.executor.evaluator import ( + _EXECUTOR_REGISTRY, + resolve_executor, + ) + + loaded: list[str] = [] + + def working_factory() -> Executor: + loaded.append("working") + return CustomExecutor() + + def broken_factory() -> Executor: + loaded.append("broken") + raise RuntimeError("third-party plugin is broken") + + # Restore the registry's plugins on exit so we don't leak + # registrations into other tests. + before = dict(_EXECUTOR_REGISTRY._plugins) + _EXECUTOR_REGISTRY._plugins.clear() + try: + _EXECUTOR_REGISTRY.register("aaa-working", working_factory) + _EXECUTOR_REGISTRY.register("zzz-broken", broken_factory) + executor = resolve_executor() + assert isinstance(executor, CustomExecutor) + assert loaded == ["working"] + finally: + _EXECUTOR_REGISTRY._plugins.clear() + _EXECUTOR_REGISTRY._plugins.update(before)