diff --git a/marimo/_plugins/stateless/status/_progress.py b/marimo/_plugins/stateless/status/_progress.py index 82e6ed60ea8..17549ec321e 100644 --- a/marimo/_plugins/stateless/status/_progress.py +++ b/marimo/_plugins/stateless/status/_progress.py @@ -389,7 +389,6 @@ def __init__( self.completion_subtitle = completion_subtitle self.remove_on_exit = remove_on_exit self.disabled = disabled - self.step: int = 1 self.collection = collection self._is_async = isinstance(collection, AsyncIterable) @@ -404,9 +403,6 @@ def __init__( "A `total` must be provided." ) - if isinstance(collection, range): - self.step = cast(range, collection).step - elif total is None: raise ValueError( "`total` is required when using as a context manager" @@ -438,7 +434,7 @@ def __iter__(self) -> Iterator[S]: for item in cast(Iterable[S], self.collection): yield item if not self.disabled: - self.progress.update(increment=self.step) + self.progress.update(increment=1) finally: self._finish() @@ -458,7 +454,7 @@ async def __aiter__(self) -> AsyncIterator[S]: async for item in cast(AsyncIterable[S], self.collection): yield item if not self.disabled: - self.progress.update(increment=self.step) + self.progress.update(increment=1) finally: self._finish() diff --git a/tests/_plugins/stateless/status/test_progress.py b/tests/_plugins/stateless/status/test_progress.py index 2c4978e5d04..dd27b30f88c 100644 --- a/tests/_plugins/stateless/status/test_progress.py +++ b/tests/_plugins/stateless/status/test_progress.py @@ -8,6 +8,7 @@ import pytest from marimo._plugins.stateless.status._progress import ( + ProgressBar, _Progress, progress_bar, spinner, @@ -247,3 +248,88 @@ async def test_progress_async_for_loop_without_collection_error(): ): async for _ in progress_bar(total=1): pass + + +# Bug fix: mo.status.progress_bar should use the step property of 'range' +# https://github.com/marimo-team/marimo/issues/9575 + + +@patch("marimo._runtime.output._output.flush") +@patch("marimo._runtime.output._output.append") +def test_progress_bar_range_with_step(mock_append, mock_flush): + """progress_bar with range(0, 10, 2) should yield 5 items and increment by 1.""" + del mock_flush + + captured = [None] + + def capture_progress(obj): + if isinstance(obj, ProgressBar): + captured[0] = obj + + mock_append.side_effect = capture_progress + + result = list(progress_bar(range(0, 10, 2))) + assert result == [0, 2, 4, 6, 8] + assert captured[0] is not None + # Each iteration should increment by 1, not by the range step + assert captured[0].current == 5 + + +@patch("marimo._runtime.output._output.flush") +@patch("marimo._runtime.output._output.append") +def test_progress_bar_range_no_step(mock_append, mock_flush): + """progress_bar with range(10) should still work correctly.""" + del mock_flush + + captured = [None] + + def capture_progress(obj): + if isinstance(obj, ProgressBar): + captured[0] = obj + + mock_append.side_effect = capture_progress + + result = list(progress_bar(range(10))) + assert result == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + assert captured[0] is not None + assert captured[0].current == 10 + + +@patch("marimo._runtime.output._output.flush") +@patch("marimo._runtime.output._output.append") +def test_progress_bar_range_custom_step(mock_append, mock_flush): + """progress_bar with range(5, 20, 3) should yield 5 items.""" + del mock_flush + + captured = [None] + + def capture_progress(obj): + if isinstance(obj, ProgressBar): + captured[0] = obj + + mock_append.side_effect = capture_progress + + result = list(progress_bar(range(5, 20, 3))) + assert result == [5, 8, 11, 14, 17] + assert captured[0] is not None + assert captured[0].current == 5 + + +@patch("marimo._runtime.output._output.flush") +@patch("marimo._runtime.output._output.append") +def test_progress_bar_range_negative_step(mock_append, mock_flush): + """progress_bar with range(10, 0, -2) should yield 5 items.""" + del mock_flush + + captured = [None] + + def capture_progress(obj): + if isinstance(obj, ProgressBar): + captured[0] = obj + + mock_append.side_effect = capture_progress + + result = list(progress_bar(range(10, 0, -2))) + assert result == [10, 8, 6, 4, 2] + assert captured[0] is not None + assert captured[0].current == 5