diff --git a/marimo/_runtime/reload/autoreload.py b/marimo/_runtime/reload/autoreload.py index ada9c1f819c..5292a86a60d 100644 --- a/marimo/_runtime/reload/autoreload.py +++ b/marimo/_runtime/reload/autoreload.py @@ -9,11 +9,13 @@ from __future__ import annotations +import functools import gc import io import modulefinder import os import sys +import sysconfig import threading import traceback import types @@ -73,6 +75,32 @@ def safe_hasattr(obj: M, attr: str) -> bool: return False +@functools.cache +def _non_user_module_roots() -> tuple[str, ...]: + """Filesystem prefixes that hold stdlib + site-packages modules. + + Each entry is normalized and terminated with a separator so that a raw + prefix check on a normalized path matches whole directory boundaries + (e.g. `/usr/lib/python3.13/` does not match `/usr/lib/python3.13-mine/`). + """ + roots: set[str] = set() + for key in ("stdlib", "platstdlib", "purelib", "platlib"): + p = sysconfig.get_path(key) + if p: + roots.add(p) + # Fallback for builds where sysconfig's stdlib path is missing or + # differs from the runtime location of the stdlib. + roots.add(os.path.dirname(os.__file__)) + + normalized: set[str] = set() + for r in roots: + n = os.path.normcase(os.path.realpath(r)) + if not n.endswith(os.sep): + n += os.sep + normalized.add(n) + return tuple(normalized) + + def modules_imported_by_cell( cell: CellImpl, sys_modules: dict[str, types.ModuleType] ) -> set[str]: @@ -160,10 +188,29 @@ def __init__(self) -> None: # for thread-safety self.lock = threading.Lock() self._module_dependency_finder = ModuleDependencyFinder() + # modname -> cached `__file__` for modules classified as non-user. + # Populated by every `check()` call (memoizing `_is_user_module`); + # consumed only when `skip_non_user_modules=True`. Stored value is + # used to invalidate the entry if `sys.modules[modname]` is later + # rebound to a module with a different `__file__` (e.g. a user + # module shadowing an installed package). 
+ self._skip: dict[str, str | None] = {} # Timestamp existing modules self.check(modules=sys.modules, reload=False) + def _is_user_module(self, module: types.ModuleType) -> bool: + """True for modules whose source lives outside stdlib/site-packages. + + Modules without `__file__` (e.g. builtins) are non-user; editable + installs point `__file__` at the source tree, so they are user code. + """ + f = safe_getattr(module, "__file__", None) + if not f: + return False + path = os.path.normcase(os.path.realpath(f)) + return not path.startswith(_non_user_module_roots()) + def filename_and_mtime( self, module: types.ModuleType ) -> ModuleMTime | None: @@ -206,12 +253,24 @@ def cell_uses_stale_modules(self, cell: CellImpl) -> bool: ) def check( - self, modules: dict[str, types.ModuleType], reload: bool + self, + modules: dict[str, types.ModuleType], + reload: bool, + *, + skip_non_user_modules: bool = False, ) -> set[types.ModuleType]: """Check timestamps of modules, optionally reload them. Also patches existing objects with hot-reloaded ones. + When `skip_non_user_modules` is True, modules whose `__file__` is + under stdlib/site-packages are skipped — intended for the per-cell + hot path. The background `ModuleWatcher` leaves it False so it still + stats every module on its 1s loop, which is what keeps edits inside + installed packages detectable. Both paths populate the same skip + cache, so the hot path benefits from classifications the watcher + has already done. + Returns a set of modules that were found to have been modified. """ @@ -228,6 +287,30 @@ def check( m = modules.get(modname, None) if m is None: continue + # Classify (memoized via `_skip`, which stores `__file__` so a + # rebound module is reclassified). The hot path skips non-user + # modules; the watcher still stats them so edits to installed + # packages are seen. NOTE(review): the skip also applies to names + # in `stale_modules`; confirm those still get reloaded. 
+ current_file = safe_getattr(m, "__file__", None) + if modname in self._skip: + if self._skip[modname] == current_file: + is_non_user = True + else: + # Rebound to a different file — drop all cached + # state for this name so the new module starts + # from a clean mtime baseline. + del self._skip[modname] + self.modules_mtimes.pop(modname, None) + self.stale_modules.discard(modname) + is_non_user = False + else: + is_non_user = False + if not is_non_user and not self._is_user_module(m): + self._skip[modname] = current_file + is_non_user = True + if is_non_user and skip_non_user_modules: + continue module_mtime = self.filename_and_mtime(m) if module_mtime is None: diff --git a/marimo/_runtime/reload/manager.py b/marimo/_runtime/reload/manager.py index 7d99638608b..ebb72e0b901 100644 --- a/marimo/_runtime/reload/manager.py +++ b/marimo/_runtime/reload/manager.py @@ -85,10 +85,18 @@ def cell_scope(self) -> Iterator[None]: yield return snapshot = set(sys.modules) - self._reloader.check(modules=sys.modules, reload=True) + # Entry: skip stdlib/site-packages so cells don't pay for stat-ing + # them. This is the perf-critical call. + self._reloader.check( + modules=sys.modules, reload=True, skip_non_user_modules=True + ) try: yield finally: + # Exit: record mtimes for modules the cell just imported. Don't + # skip here — `new_modules` is small (typically 0-3) and we need + # an mtime baseline for newly-imported installed packages so the + # next edit isn't silently treated as the initial state. 
+ new_modules = set(sys.modules) - snapshot self._reloader.check( modules={m: sys.modules[m] for m in new_modules}, diff --git a/tests/_runtime/reload/test_autoreload.py b/tests/_runtime/reload/test_autoreload.py index eabb9425131..d30c43cec56 100644 --- a/tests/_runtime/reload/test_autoreload.py +++ b/tests/_runtime/reload/test_autoreload.py @@ -2,6 +2,7 @@ import gc import importlib +import os import pathlib import sys import textwrap @@ -425,6 +426,163 @@ def test_check_reload_clears_stale_modules( assert len(reloader.stale_modules) == 0 +class TestSkipCache: + def test_is_user_module_stdlib(self): + reloader = ModuleReloader() + assert reloader._is_user_module(sys.modules["os"]) is False + assert reloader._is_user_module(sys.modules["pathlib"]) is False + + def test_is_user_module_builtin_has_no_file(self): + reloader = ModuleReloader() + assert reloader._is_user_module(sys.modules["sys"]) is False + assert reloader._is_user_module(sys.modules["builtins"]) is False + + def test_is_user_module_user_code( + self, tmp_path: pathlib.Path, py_modname: str + ): + sys.path.append(str(tmp_path)) + py_file = tmp_path / pathlib.Path(py_modname + ".py") + py_file.write_text("x = 1") + mod = importlib.import_module(py_modname) + reloader = ModuleReloader() + assert reloader._is_user_module(mod) is True + + def test_both_paths_populate_skip(self): + # The cache is shared memoization for the classification step; + # whichever path sees a module first records the verdict. 
+ reloader = ModuleReloader() + reloader.check(sys.modules, reload=False) + assert "os" in reloader._skip + + reloader2 = ModuleReloader() + reloader2.check(sys.modules, reload=False, skip_non_user_modules=True) + assert "os" in reloader2._skip + + def test_user_module_not_skipped_on_hot_path( + self, tmp_path: pathlib.Path, py_modname: str + ): + sys.path.append(str(tmp_path)) + py_file = tmp_path / pathlib.Path(py_modname + ".py") + py_file.write_text("x = 1") + importlib.import_module(py_modname) + reloader = ModuleReloader() + reloader.check(sys.modules, reload=False, skip_non_user_modules=True) + assert py_modname not in reloader._skip + + def test_skipped_modules_are_not_restated(self, monkeypatch): + reloader = ModuleReloader() + reloader.check(sys.modules, reload=False, skip_non_user_modules=True) + assert "os" in reloader._skip + + calls: list[str] = [] + orig = reloader.filename_and_mtime + + def spy(module): + calls.append(getattr(module, "__name__", "?")) + return orig(module) + + monkeypatch.setattr(reloader, "filename_and_mtime", spy) + reloader.check(sys.modules, reload=False, skip_non_user_modules=True) + assert "os" not in calls + assert "pathlib" not in calls + + def test_watcher_path_still_sees_installed_packages( + self, tmp_path: pathlib.Path, py_modname: str, monkeypatch + ): + # Regression guard: even after the hot path has classified a module + # as non-user (and cached it in `_skip`), the watcher's + # `skip_non_user_modules=False` call must still stat it and detect + # edits. Without this, `auto_reload` users editing files inside an + # installed package would silently stop getting hot reloads. 
+ import marimo._runtime.reload.autoreload as autoreload_mod + + sys.path.append(str(tmp_path)) + py_file = tmp_path / pathlib.Path(py_modname + ".py") + py_file.write_text("x = 1") + mod = importlib.import_module(py_modname) + + real_roots = autoreload_mod._non_user_module_roots() + tmp_root = os.path.normcase(os.path.realpath(str(tmp_path))) + os.sep + monkeypatch.setattr( + autoreload_mod, + "_non_user_module_roots", + lambda: (tmp_root,) + real_roots, + ) + + reloader = ModuleReloader() + assert reloader._is_user_module(mod) is False + + # Hot path classifies and caches. + reloader.check(sys.modules, reload=False, skip_non_user_modules=True) + assert py_modname in reloader._skip + + # Watcher path falls through to the stat check despite the cache. + update_file(py_file, "x = 2") + assert any(m is mod for m in reloader.check(sys.modules, reload=False)) + + def test_skip_cache_invalidates_when_module_rebound( + self, tmp_path: pathlib.Path, py_modname: str + ): + # If `sys.modules[modname]` is rebound to a module with a different + # `__file__` (e.g. a user file shadows an installed package), the + # cached non-user verdict must not stick — and any cached mtime + # from the old module must be cleared too, otherwise an older user + # file would be silently treated as unchanged. + sys.path.append(str(tmp_path)) + user_file = tmp_path / pathlib.Path(py_modname + ".py") + user_file.write_text("x = 1") + user_mod = importlib.import_module(py_modname) + + # Plant a fake "installed" version under the same name first. + fake_installed = types.ModuleType(py_modname) + fake_installed.__file__ = os.path.join( + os.path.dirname(os.__file__), py_modname + ".py" + ) + sys.modules[py_modname] = fake_installed + + reloader = ModuleReloader() + # Watcher-style call classifies the fake module as non-user; the + # test then plants a far-future mtime to detect stale-cache leakage. 
+ reloader.check(sys.modules, reload=False) + assert py_modname in reloader._skip + reloader.modules_mtimes[py_modname] = 1e12 + + # Rebind to the real user module. + sys.modules[py_modname] = user_mod + reloader.check(sys.modules, reload=False, skip_non_user_modules=True) + assert py_modname not in reloader._skip + # Stale mtime is cleared so the next edit isn't masked by it. + assert reloader.modules_mtimes.get(py_modname, 0) < 1e12 + + def test_user_module_reload_still_works( + self, tmp_path: pathlib.Path, py_modname: str + ): + sys.path.append(str(tmp_path)) + py_file = tmp_path / pathlib.Path(py_modname + ".py") + py_file.write_text( + textwrap.dedent( + """ + def foo(): + return 1 + """ + ) + ) + mod = importlib.import_module(py_modname) + reloader = ModuleReloader() + reloader.check(sys.modules, reload=False) + assert mod.foo() == 1 + + update_file( + py_file, + """ + def foo(): + return 2 + """, + ) + reloader.check(sys.modules, reload=True) + assert mod.foo() == 2 + + class TestUpdateFunctions: """Tests for update_* functions"""