-
Notifications
You must be signed in to change notification settings - Fork 117
feat(stdlib): add stream_with_chunking() with per-chunk validation (#901) #942
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 18 commits
8128dfa
f26cce7
93e7587
a5d358c
39f18a4
36173cb
ea6bdb0
35df77f
61448a9
def10b6
da41a06
74c009d
3fb501e
5850f92
4f508fd
5075a47
f0f93b3
18bfe02
7fc40a4
d8018dd
bf9a62b
9a715d6
f3e3501
2f2e352
66260fe
2cac22c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| # pytest: ollama, e2e | ||
|
|
||
| """Streaming generation with per-chunk validation using stream_with_chunking(). | ||
|
|
||
| Demonstrates: | ||
| - Subclassing Requirement to override stream_validate() for early-exit checks | ||
| - Calling stream_with_chunking() with sentence-level chunking | ||
| - Consuming validated chunks via astream() as they arrive | ||
| - Awaiting full completion with acomplete() to access final_validations and full_text | ||
| """ | ||
|
|
||
| import asyncio | ||
|
|
||
| from mellea.core.backend import Backend | ||
| from mellea.core.base import Context | ||
| from mellea.core.requirement import ( | ||
| PartialValidationResult, | ||
| Requirement, | ||
| ValidationResult, | ||
| ) | ||
| from mellea.stdlib.components import Instruction | ||
| from mellea.stdlib.streaming import stream_with_chunking | ||
|
|
||
|
|
||
| class MaxSentencesReq(Requirement): | ||
| """Fails if the model generates more than *limit* sentences mid-stream. | ||
|
|
||
| Each ``stream_validate`` call receives one complete sentence from the | ||
| :class:`~mellea.stdlib.chunking.SentenceChunker`. The running count is | ||
| maintained on ``self`` — this is the standard pattern for requirements | ||
| that need context beyond a single chunk. | ||
| """ | ||
|
|
||
| def __init__(self, limit: int) -> None: | ||
| super().__init__() | ||
| self._limit = limit | ||
| self._count = 0 | ||
|
|
||
| def format_for_llm(self) -> str: | ||
| return f"The response must be at most {self._limit} sentences long." | ||
|
|
||
| async def stream_validate( | ||
| self, chunk: str, *, backend: Backend, ctx: Context | ||
| ) -> PartialValidationResult: | ||
| self._count += 1 | ||
| if self._count > self._limit: | ||
| return PartialValidationResult( | ||
| "fail", | ||
| reason=f"Response exceeded {self._limit} sentence limit mid-stream", | ||
| ) | ||
| return PartialValidationResult("unknown") | ||
|
jakelorocco marked this conversation as resolved.
|
||
|
|
||
| async def validate( | ||
| self, | ||
| backend: Backend, | ||
| ctx: Context, | ||
| *, | ||
| format: type | None = None, | ||
| model_options: dict | None = None, | ||
| ) -> ValidationResult: | ||
| return ValidationResult(result=True) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think validate and stream_validate should return equivalent results for most requirements.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good call — the always- |
||
|
|
||
|
|
||
| async def main() -> None: | ||
| from mellea.stdlib.session import start_session | ||
|
|
||
| m = start_session() | ||
| backend = m.backend | ||
| ctx = m.ctx | ||
|
|
||
| action = Instruction( | ||
| "Write a short paragraph about the water cycle in exactly two sentences." | ||
| ) | ||
| req = MaxSentencesReq(limit=3) | ||
|
|
||
| result = await stream_with_chunking( | ||
| action, backend, ctx, quick_check_requirements=[req], chunking="sentence" | ||
| ) | ||
|
|
||
| print("Streaming chunks as they arrive:") | ||
| async for chunk in result.astream(): | ||
| print(f" CHUNK: {chunk!r}") | ||
|
|
||
| await result.acomplete() | ||
|
|
||
| print(f"\nCompleted normally: {result.completed}") | ||
| print(f"Full text: {result.full_text!r}") | ||
|
|
||
| if result.streaming_failures: | ||
| for _req, pvr in result.streaming_failures: | ||
| print(f"Streaming failure: {pvr.reason}") | ||
|
|
||
| if result.final_validations: | ||
| for vr in result.final_validations: | ||
| print(f"Final validation: {'PASS' if vr.as_bool() else 'FAIL'}") | ||
|
|
||
|
|
||
| asyncio.run(main()) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -320,6 +320,7 @@ def __init__( | |
|
|
||
| # Set computed to True if a value is passed in. | ||
| self._computed: bool = True if value is not None else False | ||
| self._cancelled: bool = False | ||
|
|
||
| # Additional fields that should be standardized across apis. | ||
| self.tool_calls = tool_calls | ||
|
|
@@ -364,6 +365,86 @@ def _record_ttfb(self) -> None: | |
| ).total_seconds() * 1000 | ||
| self._first_chunk_received = True | ||
|
|
||
| async def cancel_generation(self, error: Exception | None = None) -> None: | ||
| """Cancel an in-progress streaming generation, drain the queue, and close any open telemetry span. | ||
|
|
||
| Safe to call at any point during streaming. After this method returns, | ||
| ``is_computed()`` is ``True`` and ``value`` contains whatever text was | ||
| accumulated before cancellation. Calling on an already-computed MOT | ||
| is a no-op. | ||
|
planetf1 marked this conversation as resolved.
|
||
|
|
||
| Draining the internal queue after cancellation is necessary to release | ||
| any ``asyncio.Queue.put()`` call that the generation task was blocked on | ||
| (queue maxsize=20). | ||
|
|
||
| Args: | ||
| error: Optional cause attributed to the open telemetry span. When | ||
| provided, this exception is recorded via ``set_span_error`` so | ||
| the span reflects the actual reason for cancellation (e.g. the | ||
| requirement failure or an unhandled exception from a streaming | ||
| validator). When ``None``, a generic | ||
| ``RuntimeError("Generation cancelled")`` is recorded. | ||
| """ | ||
| if self._computed: | ||
| return | ||
|
|
||
| def _drain() -> None: | ||
| while not self._async_queue.empty(): | ||
| try: | ||
| self._async_queue.get_nowait() | ||
| except asyncio.QueueEmpty: | ||
| break | ||
|
|
||
| if self._generate is not None and not self._generate.done(): | ||
| self._generate.cancel() | ||
|
|
||
| if self._generate_extra is not None and not self._generate_extra.done(): | ||
| self._generate_extra.cancel() | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This cancels the asyncio task wrapper, but for the HF backend Can we add a cooperative cancellation hook for backend producers, at least for HF? For example, HF could install a
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Confirmed. Addressed in d8018dd by adding a generic cooperative-cancel hook on
Chose the generic hook rather than an HF-specific cancel event because a future litellm or vLLM streaming path has the same shape (blocking work wrapped in Tests in
|
||
|
|
||
| # Drain before awaiting — unblocks any put() the task is stuck on. | ||
| _drain() | ||
|
|
||
| if self._generate is not None: | ||
| try: | ||
| await self._generate | ||
| except (asyncio.CancelledError, Exception): | ||
| pass | ||
|
|
||
| if self._generate_extra is not None: | ||
| try: | ||
| await self._generate_extra | ||
| except (asyncio.CancelledError, Exception): | ||
| pass | ||
|
|
||
| # Drain again for any final item the task put before terminating. | ||
| _drain() | ||
|
|
||
| span = self._meta.pop("_telemetry_span", None) | ||
| if span is not None: | ||
| from ..telemetry import end_backend_span, set_span_error | ||
|
|
||
| recorded: Exception = ( | ||
| error if error is not None else RuntimeError("Generation cancelled") | ||
| ) | ||
| set_span_error(span, recorded) | ||
| end_backend_span(span) | ||
|
|
||
| if self._underlying_value is None: | ||
| self._underlying_value = "" | ||
| self._cancelled = True | ||
| self._computed = True | ||
|
|
||
| @property | ||
| def cancelled(self) -> bool: | ||
| """``True`` if :meth:`cancel_generation` ran to completion on this MOT. | ||
|
|
||
| A normally-completed MOT leaves this ``False``; only an actual | ||
| cancellation via :meth:`cancel_generation` flips it. Consumers holding | ||
| a computed MOT can use this to distinguish a genuine result from one | ||
| cut short (for example by a streaming requirement failure). | ||
| """ | ||
| return self._cancelled | ||
|
|
||
| def _copy_from(self, other: ModelOutputThunk) -> None: | ||
| """Copy computed-output fields from *other* into *self*. | ||
|
|
||
|
|
@@ -378,6 +459,7 @@ def _copy_from(self, other: ModelOutputThunk) -> None: | |
| self._thinking = other._thinking | ||
| self.generation = other.generation | ||
| self._generate_log = other._generate_log | ||
| self._cancelled = other._cancelled | ||
|
|
||
| def is_computed(self) -> bool: | ||
| """Returns true only if this Thunk has already been filled. | ||
|
|
@@ -606,6 +688,7 @@ def __copy__(self) -> ModelOutputThunk: | |
| copied.parsed_repr = copied # type: ignore | ||
|
|
||
| copied._computed = self._computed | ||
| copied._cancelled = self._cancelled | ||
| copied._thinking = self._thinking | ||
| copied._action = self._action | ||
| copied._context = self._context | ||
|
|
@@ -634,6 +717,7 @@ def __deepcopy__(self, memo: dict) -> ModelOutputThunk: | |
| deepcopied._meta = deepcopy(self._meta) | ||
| deepcopied.tool_calls = deepcopy(self.tool_calls) | ||
| deepcopied._computed = self._computed | ||
| deepcopied._cancelled = self._cancelled | ||
| deepcopied._thinking = self._thinking | ||
| deepcopied._action = deepcopy(self._action) | ||
| deepcopied._context = copy( | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.