Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions marimo/_plugins/stateless/status/_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ def __init__(
self.completion_subtitle = completion_subtitle
self.remove_on_exit = remove_on_exit
self.disabled = disabled
self.step: int = 1
self.collection = collection
self._is_async = isinstance(collection, AsyncIterable)

Expand All @@ -404,9 +403,6 @@ def __init__(
"A `total` must be provided."
)

if isinstance(collection, range):
self.step = cast(range, collection).step

elif total is None:
raise ValueError(
"`total` is required when using as a context manager"
Expand Down Expand Up @@ -438,7 +434,7 @@ def __iter__(self) -> Iterator[S]:
for item in cast(Iterable[S], self.collection):
yield item
if not self.disabled:
self.progress.update(increment=self.step)
self.progress.update(increment=1)
finally:
self._finish()

Expand All @@ -458,7 +454,7 @@ async def __aiter__(self) -> AsyncIterator[S]:
async for item in cast(AsyncIterable[S], self.collection):
yield item
if not self.disabled:
self.progress.update(increment=self.step)
self.progress.update(increment=1)
finally:
self._finish()

Expand Down
87 changes: 87 additions & 0 deletions tests/_plugins/stateless/status/test_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from marimo._plugins.stateless.status._progress import (
_Progress,
ProgressBar,
progress_bar,
spinner,
)
Expand Down Expand Up @@ -247,3 +248,89 @@ async def test_progress_async_for_loop_without_collection_error():
):
async for _ in progress_bar(total=1):
pass


# Bug fix: mo.status.progress_bar should use the step property of 'range'
# https://github.com/marimo-team/marimo/issues/9575


@patch("marimo._runtime.output._output.flush")
@patch("marimo._runtime.output._output.append")
def test_progress_bar_range_with_step(mock_append, mock_flush):
"""progress_bar with range(0, 10, 2) should yield 5 items and increment by 1."""
del mock_flush

captured = [None]

def capture_progress(obj):
if isinstance(obj, ProgressBar):
captured[0] = obj

mock_append.side_effect = capture_progress

result = list(progress_bar(range(0, 10, 2)))
assert result == [0, 2, 4, 6, 8]
assert captured[0] is not None
# Each iteration should increment by 1, not by the range step
assert captured[0].current == 5


@patch("marimo._runtime.output._output.flush")
@patch("marimo._runtime.output._output.append")
def test_progress_bar_range_no_step(mock_append, mock_flush):
"""progress_bar with range(10) should still work correctly."""
del mock_flush

captured = [None]

def capture_progress(obj):
if isinstance(obj, ProgressBar):
captured[0] = obj

mock_append.side_effect = capture_progress

result = list(progress_bar(range(10)))
assert result == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
assert captured[0] is not None
assert captured[0].current == 10


@patch("marimo._runtime.output._output.flush")
@patch("marimo._runtime.output._output.append")
def test_progress_bar_range_custom_step(mock_append, mock_flush):
"""progress_bar with range(5, 20, 3) should yield 5 items."""
del mock_flush

captured = [None]

def capture_progress(obj):
if isinstance(obj, ProgressBar):
captured[0] = obj

mock_append.side_effect = capture_progress

result = list(progress_bar(range(5, 20, 3)))
assert result == [5, 8, 11, 14, 17]
assert captured[0] is not None
assert captured[0].current == 5


@patch("marimo._runtime.output._output.flush")
@patch("marimo._runtime.output._output.append")
def test_progress_bar_range_negative_step(mock_append, mock_flush):
"""progress_bar with range(10, 0, -2) should yield 5 items."""
del mock_flush

captured = [None]

def capture_progress(obj):
if isinstance(obj, ProgressBar):
captured[0] = obj

mock_append.side_effect = capture_progress

result = list(progress_bar(range(10, 0, -2)))
assert result == [10, 8, 6, 4, 2]
assert captured[0] is not None
assert captured[0].current == 5

Loading