diff --git a/marimo/_save/save.py b/marimo/_save/save.py index 017a470a52d..6a59cb6fe32 100644 --- a/marimo/_save/save.py +++ b/marimo/_save/save.py @@ -1157,7 +1157,7 @@ def factorial(n): *args, pin_modules=pin_modules, loader=MemoryLoader.partial(max_size=maxsize), - _frame_offset=2, + _frame_offset=2 if callable(arg) else 1, **kwargs, ), ) diff --git a/tests/_save/test_cache.py b/tests/_save/test_cache.py index 1406a14e893..8c9db07faa8 100644 --- a/tests/_save/test_cache.py +++ b/tests/_save/test_cache.py @@ -1081,6 +1081,36 @@ def fib(n): ) assert k.globals["b"] == 55 + async def test_lru_cache_with_maxsize_persists_across_cell_reruns( + self, k: Kernel, exec_req: ExecReqProvider + ) -> None: + cached_cell = exec_req.get( + """ + @lru_cache(maxsize=128) + def foo(): + print("ran") + + foo() + foo() + """ + ) + + await k.run( + [ + exec_req.get( + """ + from marimo._save.save import lru_cache + """ + ), + cached_cell, + ] + ) + await k.run([cached_cell]) + + assert not k.stderr.messages, k.stderr + assert k.stdout.messages.count("ran") == 1 + assert k.globals["foo"].hits == 3 + async def test_persistent_cache( self, k: Kernel, exec_req: ExecReqProvider ) -> None: