Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion marimo/_runtime/reload/autoreload.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import modulefinder
import os
import sys
import sysconfig
import threading
import traceback
import types
Expand Down Expand Up @@ -73,6 +74,31 @@ def safe_hasattr(obj: M, attr: str) -> bool:
return False


def _non_user_module_roots() -> tuple[str, ...]:
    """Return filesystem prefixes that hold stdlib + site-packages modules.

    The roots are collected from `sysconfig` (the `stdlib`, `platstdlib`,
    `purelib` and `platlib` paths) plus the directory containing the `os`
    module, when that location is known.

    Each entry is normalized and terminated with a separator so that a raw
    prefix check on a normalized path matches whole directory boundaries
    (e.g. `/usr/lib/python3.13/` does not match `/usr/lib/python3.13-mine/`).

    Returns:
        A de-duplicated tuple of normalized directory prefixes, suitable as
        the argument to `str.startswith`. The order is unspecified.
    """
    roots: set[str] = set()
    for key in ("stdlib", "platstdlib", "purelib", "platlib"):
        p = sysconfig.get_path(key)
        if p:
            roots.add(p)

    # Fallback for builds where sysconfig's stdlib path is missing or
    # differs from the runtime location of the stdlib. Those same builds
    # may not set `os.__file__` at all, so read it defensively instead of
    # raising AttributeError.
    os_file = getattr(os, "__file__", None)
    if os_file:
        os_dir = os.path.dirname(os_file)
        # An empty dirname (relative `__file__`) would be resolved by
        # `realpath` to the current working directory, which would wrongly
        # classify everything under the cwd as non-user code.
        if os_dir:
            roots.add(os_dir)

    normalized: set[str] = set()
    for r in roots:
        n = os.path.normcase(os.path.realpath(r))
        if not n.endswith(os.sep):
            n += os.sep
        normalized.add(n)
    return tuple(normalized)


def modules_imported_by_cell(
cell: CellImpl, sys_modules: dict[str, types.ModuleType]
) -> set[str]:
Expand Down Expand Up @@ -160,10 +186,29 @@ def __init__(self) -> None:
# for thread-safety
self.lock = threading.Lock()
self._module_dependency_finder = ModuleDependencyFinder()
# Names known to live in stdlib/site-packages. Populated lazily by
# callers that pass `skip_non_user_modules=True` (the hot per-cell
# path), and then consulted by every `check()` call. Entries are
# never evicted: a module whose `__file__` moves between roots at
# runtime would not be re-evaluated.
self._skip: set[str] = set()
self._non_user_roots = _non_user_module_roots()
Comment thread
mscolnick marked this conversation as resolved.
Outdated

# Timestamp existing modules
self.check(modules=sys.modules, reload=False)

def _is_user_module(self, module: types.ModuleType) -> bool:
    """True for modules whose source lives outside stdlib/site-packages.

    Editable installs (e.g. `pip install -e .`) point `__file__` at the
    source tree, so they are correctly classified as user code.

    Modules without a usable `__file__` are reported as non-user, since
    there is no path to compare against the non-user roots. This covers
    modules with no `__file__` at all (e.g. `sys`, `builtins`) as well as
    modules whose `__file__` cannot be resolved to a text path.

    Args:
        module: the module to classify.

    Returns:
        True if the module's resolved `__file__` is not under any of
        `self._non_user_roots`; False otherwise.
    """
    f = safe_getattr(module, "__file__", None)
    if not f:
        return False
    try:
        path = os.path.normcase(os.path.realpath(f))
        return not path.startswith(self._non_user_roots)
    except (TypeError, ValueError, OSError):
        # `__file__` is not guaranteed to be a well-formed `str` path: a
        # non-path object or bytes value raises TypeError, and an embedded
        # NUL can raise ValueError. One odd module must not abort the scan
        # of every other module, so treat it like a module with no file.
        return False

def filename_and_mtime(
self, module: types.ModuleType
) -> ModuleMTime | None:
Expand Down Expand Up @@ -206,12 +251,23 @@ def cell_uses_stale_modules(self, cell: CellImpl) -> bool:
)

def check(
self, modules: dict[str, types.ModuleType], reload: bool
self,
modules: dict[str, types.ModuleType],
reload: bool,
*,
skip_non_user_modules: bool = False,
) -> set[types.ModuleType]:
"""Check timestamps of modules, optionally reload them.

Also patches existing objects with hot-reloaded ones.

When `skip_non_user_modules` is True, modules whose `__file__` is
under stdlib/site-packages are added to a persistent skip set and
won't be stat-ed on this or future calls. Intended for the per-cell
hot path. Once populated, the skip set short-circuits every caller
— including the background `ModuleWatcher` — so edits inside
installed packages are not hot-reloaded.

Returns a set of modules that were found to have been modified.
"""

Expand All @@ -225,9 +281,14 @@ def check(
# materialize the module keys, since we'll be reloading while
# iterating
for modname in list(modules.keys()):
if modname in self._skip:
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
Outdated
continue
m = modules.get(modname, None)
if m is None:
continue
if skip_non_user_modules and not self._is_user_module(m):
Comment thread
mscolnick marked this conversation as resolved.
Outdated
self._skip.add(modname)
continue

module_mtime = self.filename_and_mtime(m)
if module_mtime is None:
Expand Down
10 changes: 9 additions & 1 deletion marimo/_runtime/reload/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,18 @@ def cell_scope(self) -> Iterator[None]:
yield
return
snapshot = set(sys.modules)
self._reloader.check(modules=sys.modules, reload=True)
# Entry: skip stdlib/site-packages so cells don't pay for stat-ing
# them. This is the perf-critical call.
self._reloader.check(
modules=sys.modules, reload=True, skip_non_user_modules=True
)
Comment thread
mscolnick marked this conversation as resolved.
try:
yield
finally:
# Exit: record mtimes for modules the cell just imported. Don't
# skip here — `new_modules` is small (typically 0-3) and we need
# an mtime baseline for newly-imported installed packages so the
# next edit isn't silently treated as the initial state.
new_modules = set(sys.modules) - snapshot
self._reloader.check(
modules={m: sys.modules[m] for m in new_modules},
Expand Down
140 changes: 140 additions & 0 deletions tests/_runtime/reload/test_autoreload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import gc
import importlib
import os
import pathlib
import sys
import textwrap
Expand Down Expand Up @@ -425,6 +426,145 @@ def test_check_reload_clears_stale_modules(
assert len(reloader.stale_modules) == 0


class TestSkipCache:
    """Tests for the reloader's persistent stdlib/site-packages skip set."""

    @staticmethod
    def _import_fresh_module(
        tmp_path: pathlib.Path, py_modname: str, source: str = "x = 1"
    ):
        """Write `source` to `<tmp_path>/<py_modname>.py` and import it.

        Returns the path of the written file and the imported module.
        """
        sys.path.append(str(tmp_path))
        module_path = tmp_path / pathlib.Path(py_modname + ".py")
        module_path.write_text(source)
        return module_path, importlib.import_module(py_modname)

    @staticmethod
    def _mark_dir_as_non_user(reloader, directory, monkeypatch) -> None:
        """Prepend `directory` to the reloader's non-user roots."""
        prefix = os.path.normcase(os.path.realpath(str(directory))) + os.sep
        monkeypatch.setattr(
            reloader,
            "_non_user_roots",
            (prefix,) + reloader._non_user_roots,
        )

    def test_is_user_module_stdlib(self):
        reloader = ModuleReloader()
        for name in ("os", "pathlib"):
            assert reloader._is_user_module(sys.modules[name]) is False

    def test_is_user_module_builtin_has_no_file(self):
        reloader = ModuleReloader()
        for name in ("sys", "builtins"):
            assert reloader._is_user_module(sys.modules[name]) is False

    def test_is_user_module_user_code(
        self, tmp_path: pathlib.Path, py_modname: str
    ):
        _, user_mod = self._import_fresh_module(tmp_path, py_modname)
        reloader = ModuleReloader()
        assert reloader._is_user_module(user_mod) is True

    def test_default_check_does_not_populate_skip(self):
        # Constructing the reloader runs the default `check()`; the watcher
        # path must keep scanning everything so that edits inside installed
        # packages remain detectable.
        assert ModuleReloader()._skip == set()

    def test_hot_path_populates_skip_with_stdlib(self):
        reloader = ModuleReloader()
        reloader.check(sys.modules, reload=False, skip_non_user_modules=True)
        for name in ("os", "pathlib"):
            assert name in reloader._skip

    def test_user_module_not_skipped_on_hot_path(
        self, tmp_path: pathlib.Path, py_modname: str
    ):
        self._import_fresh_module(tmp_path, py_modname)
        reloader = ModuleReloader()
        reloader.check(sys.modules, reload=False, skip_non_user_modules=True)
        assert py_modname not in reloader._skip

    def test_skipped_modules_are_not_restated(self, monkeypatch):
        reloader = ModuleReloader()
        reloader.check(sys.modules, reload=False, skip_non_user_modules=True)
        assert "os" in reloader._skip

        # Record every module that reaches the mtime lookup on a second
        # hot-path pass.
        seen: list[str] = []
        real_lookup = reloader.filename_and_mtime

        def recording_lookup(module):
            seen.append(getattr(module, "__name__", "?"))
            return real_lookup(module)

        monkeypatch.setattr(reloader, "filename_and_mtime", recording_lookup)
        reloader.check(sys.modules, reload=False, skip_non_user_modules=True)
        for name in ("os", "pathlib"):
            assert name not in seen

    def test_default_check_does_not_populate_skip_on_its_own(
        self, tmp_path: pathlib.Path, py_modname: str, monkeypatch
    ):
        # The watcher uses `skip_non_user_modules=False`. Even when a module
        # sits under a non-user root, the default path must leave `_skip`
        # alone: populating the cache is the hot path's responsibility.
        self._import_fresh_module(tmp_path, py_modname)

        reloader = ModuleReloader()
        self._mark_dir_as_non_user(reloader, tmp_path, monkeypatch)
        reloader.check(sys.modules, reload=False)
        assert py_modname not in reloader._skip

    def test_hot_path_population_short_circuits_subsequent_calls(
        self, tmp_path: pathlib.Path, py_modname: str, monkeypatch
    ):
        # After the hot path classifies a module as non-user, all later
        # `check()` calls skip it, the default (watcher) path included.
        # This is a documented tradeoff: edits inside an installed package
        # are not hot-reloaded once a cell has run.
        module_path, mod = self._import_fresh_module(tmp_path, py_modname)

        reloader = ModuleReloader()
        self._mark_dir_as_non_user(reloader, tmp_path, monkeypatch)
        assert reloader._is_user_module(mod) is False

        reloader.check(sys.modules, reload=False, skip_non_user_modules=True)
        assert py_modname in reloader._skip

        update_file(module_path, "x = 2")
        modified = reloader.check(sys.modules, reload=False)
        assert all(m is not mod for m in modified)

    def test_user_module_reload_still_works(
        self, tmp_path: pathlib.Path, py_modname: str
    ):
        initial_source = textwrap.dedent(
            """
            def foo():
                return 1
            """
        )
        module_path, mod = self._import_fresh_module(
            tmp_path, py_modname, initial_source
        )
        reloader = ModuleReloader()
        reloader.check(sys.modules, reload=False)
        assert mod.foo() == 1

        update_file(
            module_path,
            """
            def foo():
                return 2
            """,
        )
        reloader.check(sys.modules, reload=True)
        assert mod.foo() == 2


class TestUpdateFunctions:
"""Tests for update_* functions"""

Expand Down
Loading