diff --git a/AGENTS.md b/AGENTS.md index cc068f107..170dbca22 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -126,6 +126,11 @@ Use the tool's common name (e.g., GitHub Copilot, Cursor, etc.). 3. New functions typed with concise docstrings? 4. Unit tests added for new functionality? 5. Avoided over-engineering? +6. If the diff adds `raise` statements to library code (`mellea/` but not `test/`), run the docstring quality gate before pushing: + ```bash + uv run python tooling/docs-autogen/audit_coverage.py --docs-dir docs/docs/api --quality --fail-on-quality --threshold 100 --orphans + ``` + Every new `raise` in a public function requires a matching `Raises:` entry — the `build-and-validate` CI job enforces this with `--fail-on-quality`. ## 11. Writing Tests diff --git a/docs/examples/streaming/streaming_chunking.py b/docs/examples/streaming/streaming_chunking.py new file mode 100644 index 000000000..81056abaf --- /dev/null +++ b/docs/examples/streaming/streaming_chunking.py @@ -0,0 +1,110 @@ +# pytest: ollama, e2e + +"""Streaming generation with per-chunk validation using stream_with_chunking(). + +Demonstrates: +- Subclassing Requirement to override stream_validate() for early-exit checks +- Calling stream_with_chunking() with sentence-level chunking +- Consuming validated chunks via astream() as they arrive +- Awaiting full completion with acomplete() to access final_validations and full_text +""" + +import asyncio +import re + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import ( + PartialValidationResult, + Requirement, + ValidationResult, +) +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import stream_with_chunking + +# Crude sentence-terminator detector. A run of ``.``/``!``/``?`` counts once +# (so "..." and "!!!" are a single terminator). Good enough for an example; +# production code might use spaCy/NLTK for proper sentence segmentation. 
+_SENTENCE_END = re.compile(r"[.!?]+") + + +class MaxSentencesReq(Requirement): + """Fails if the model generates more than *limit* sentences mid-stream. + + Counts sentence terminators in the chunk *text* rather than counting + ``stream_validate`` calls. This makes the requirement **chunker-agnostic**: + the same instance behaves correctly with sentence, word, or paragraph + chunking, because the semantics depend on content, not on the chunker's + structural decisions. + + When writing your own streaming requirements, prefer this content-driven + pattern over coupling the requirement to a specific chunker. Reach for + chunker-coupled logic only when the requirement is genuinely a property + of chunk boundaries (e.g. "no chunk longer than N tokens"). + """ + + def __init__(self, limit: int) -> None: + super().__init__() + self._limit = limit + self._count = 0 + + def format_for_llm(self) -> str: + return f"The response must be at most {self._limit} sentences long." + + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + self._count += len(_SENTENCE_END.findall(chunk)) + if self._count > self._limit: + return PartialValidationResult( + "fail", + reason=f"Response exceeded {self._limit} sentence limit mid-stream", + ) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Backend, + ctx: Context, + *, + format: type | None = None, + model_options: dict | None = None, + ) -> ValidationResult: + return ValidationResult(result=self._count <= self._limit) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + backend = m.backend + ctx = m.ctx + + action = Instruction( + "Write a short paragraph about the water cycle in exactly two sentences." 
+ ) + req = MaxSentencesReq(limit=3) + + result = await stream_with_chunking( + action, backend, ctx, requirements=[req], chunking="sentence" + ) + + print("Streaming chunks as they arrive:") + async for chunk in result.astream(): + print(f" CHUNK: {chunk!r}") + + await result.acomplete() + + print(f"\nCompleted normally: {result.completed}") + print(f"Full text: {result.full_text!r}") + + if result.streaming_failures: + for _req, pvr in result.streaming_failures: + print(f"Streaming failure: {pvr.reason}") + + if result.final_validations: + for vr in result.final_validations: + print(f"Final validation: {'PASS' if vr.as_bool() else 'FAIL'}") + + +asyncio.run(main()) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 1ccb9abf7..337c27266 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -22,6 +22,10 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.cache_utils import DynamicCache from transformers.generation.logits_process import LogitsProcessorList + from transformers.generation.stopping_criteria import ( + StoppingCriteria, + StoppingCriteriaList, + ) from transformers.generation.streamers import AsyncTextIteratorStreamer from transformers.generation.utils import GenerateDecoderOnlyOutput from transformers.modeling_utils import PreTrainedModel @@ -74,6 +78,45 @@ ) from .utils import to_chat, to_tool_calls + +class _EventStoppingCriteria(StoppingCriteria): + """StoppingCriteria that signals the model to stop when a threading.Event is set. + + Used by LocalHFBackend to implement cooperative cancellation: when + ``cancel_generation`` is called, it sets the backing event via + ``_cancel_hook`` before cancelling the asyncio task, giving the HF + ``model.generate`` thread a chance to exit cleanly rather than running + to completion. 
+ """ + + def __init__(self, event: threading.Event) -> None: + self._event = event + + def __call__(self, input_ids: Any, scores: Any, **kwargs: Any) -> bool: # type: ignore[override] + return self._event.is_set() + + +def _install_cancel_stopping_criteria( + generate_options: dict[str, Any], streaming_kwargs: dict[str, Any] +) -> threading.Event: + """Wire a cooperative-cancel event into the generate call's stopping criteria. + + Pops any caller-supplied ``stopping_criteria`` from *generate_options* (to + avoid passing it twice via both ``**generate_options`` and + ``**streaming_kwargs``), prepends an :class:`_EventStoppingCriteria` backed + by a fresh ``threading.Event``, and stores the merged list in + *streaming_kwargs*. Returns the event so the caller can arm + ``output._cancel_hook = event.set``. + """ + cancel_event = threading.Event() + user_sc = generate_options.pop("stopping_criteria", None) + streaming_kwargs["stopping_criteria"] = StoppingCriteriaList( + [_EventStoppingCriteria(cancel_event)] + + (list(user_sc) if user_sc is not None else []) + ) + return cancel_event + + """A configuration type for the unhappy path: Tokenizer * Model * torch device string Huggingface backends can initialize themselves from a model string if the transformers `Auto*` classes can be used. Therefore, a TransformersTorchConfig usually isn't required. However, sometimes a model needs special care to instantiate properly, or a custom device type needs to bse used. Instead of trying to do a lot of partial magic, we basically have two modaliites: either the constructor can figure out everything from the model_id, or the user has to provide an entire config. @@ -839,6 +882,15 @@ async def _generate_from_context_with_kv_cache( # Filter out chat template-only options before passing to generate() generate_options = self._filter_chat_template_only_options(model_options) + # Only install cooperative-cancel plumbing on the streaming path. 
+ # Non-streaming calls have no orchestrator calling cancel_generation(), + # so the hook would be dead code and the StoppingCriteria would silently + # wrap any user-supplied stopping_criteria on every decode step. + if stream: + _cancel_event = _install_cancel_stopping_criteria( + generate_options, streaming_kwargs + ) + linearized_ctx = ctx.view_for_generation() assert linearized_ctx is not None _input_text, input_ids, merged_cache, attention_mask = ( @@ -867,6 +919,10 @@ async def _generate_from_context_with_kv_cache( ) output = ModelOutputThunk(None) + # Arm the cancel hook before creating tasks so a cancel racing + # task creation still finds the hook set. + if stream: + output._cancel_hook = _cancel_event.set output._start = datetime.datetime.now() output._context = ctx.view_for_generation() output._action = action @@ -1002,6 +1058,15 @@ async def _generate_from_context_standard( # Filter out chat template-only options before passing to generate() generate_options = self._filter_chat_template_only_options(model_options) + # Only install cooperative-cancel plumbing on the streaming path. + # Non-streaming calls have no orchestrator calling cancel_generation(), + # so the hook would be dead code and the StoppingCriteria would silently + # wrap any user-supplied stopping_criteria on every decode step. + if stream: + _cancel_event = _install_cancel_stopping_criteria( + generate_options, streaming_kwargs + ) + chat_response = asyncio.to_thread( self._generate_with_adapter_lock, "", # Empty for no adapters. @@ -1016,6 +1081,10 @@ async def _generate_from_context_standard( ) output = ModelOutputThunk(None) + # Arm the cancel hook before creating tasks so a cancel racing + # task creation still finds the hook set. 
+ if stream: + output._cancel_hook = _cancel_event.set output._start = datetime.datetime.now() output._context = ctx.view_for_generation() output._action = action diff --git a/mellea/core/base.py b/mellea/core/base.py index 2028008d9..3ea17e088 100644 --- a/mellea/core/base.py +++ b/mellea/core/base.py @@ -17,6 +17,7 @@ import binascii import datetime import enum +import logging from collections.abc import Callable, Coroutine, Iterable, Mapping from copy import copy, deepcopy from dataclasses import dataclass @@ -320,6 +321,7 @@ def __init__( # Set computed to True if a value is passed in. self._computed: bool = True if value is not None else False + self._cancelled: bool = False # Additional fields that should be standardized across apis. self.tool_calls = tool_calls @@ -344,6 +346,14 @@ def __init__( self._generate_extra: asyncio.Task[Any] | None = ( None # Currently only used by hf. ) + # Optional cooperative-cancel hook called before asyncio task cancellation. + # Backends that run generation in a thread (e.g. HuggingFace via + # asyncio.to_thread) set this to a non-blocking callable (e.g. + # threading.Event.set) so the thread receives a stop signal before the + # task wrapper is cancelled. Must be non-blocking; exceptions are logged + # and suppressed. Copied MOTs reset this to None — each computation owns + # its own thread signal. + self._cancel_hook: Callable[[], None] | None = None self._process: Callable[[ModelOutputThunk, Any], Coroutine] | None = None self._post_process: Callable[[ModelOutputThunk], Coroutine] | None = None self._on_computed: Callable[[ModelOutputThunk], Coroutine] | None = None @@ -364,6 +374,115 @@ def _record_ttfb(self) -> None: ).total_seconds() * 1000 self._first_chunk_received = True + async def cancel_generation(self, error: Exception | None = None) -> None: + """Cancel an in-progress streaming generation, drain the queue, and close any open telemetry span. + + Safe to call at any point during streaming. 
After this method returns, + ``is_computed()`` is ``True`` and ``value`` contains whatever text was + accumulated before cancellation. Calling on an already-computed MOT + is a no-op. + + Draining the internal queue after cancellation is necessary to release + any ``asyncio.Queue.put()`` call that the generation task was blocked on + (queue maxsize=20). + + Args: + error: Optional cause attributed to the open telemetry span. When + provided, this exception is recorded via ``set_span_error`` so + the span reflects the actual reason for cancellation (e.g. the + requirement failure or an unhandled exception from a streaming + validator). When ``None``, a generic + ``RuntimeError("Generation cancelled")`` is recorded. + + Raises: + asyncio.CancelledError: Re-raised when the *calling* task itself is + being cancelled (``asyncio.current_task().cancelling() > 0``). + This prevents external cancellation (e.g. ``asyncio.wait_for`` + timeout) from being silently absorbed while awaiting the inner + generation task. + """ + if self._computed: + return + + def _drain() -> None: + while not self._async_queue.empty(): + try: + self._async_queue.get_nowait() + except asyncio.QueueEmpty: + break + + # Signal any backend thread before cancelling the asyncio task wrapper + # so the thread can stop cooperatively instead of running to completion. + if self._cancel_hook is not None: + try: + self._cancel_hook() + except Exception as hook_exc: + logging.getLogger(__name__).warning( + "cancel_generation: _cancel_hook raised (suppressed): %r", hook_exc + ) + + if self._generate is not None and not self._generate.done(): + self._generate.cancel() + + if self._generate_extra is not None and not self._generate_extra.done(): + self._generate_extra.cancel() + + # Drain before awaiting — unblocks any put() the task is stuck on. 
+ _drain() + + if self._generate is not None: + try: + await self._generate + except asyncio.CancelledError: + # Re-raise if the *outer* task is being cancelled (Python 3.11+ + # task.cancelling() > 0) so we don't silently absorb external + # cancellation. For the inner task's own CancelledError (the + # expected result of .cancel() above), cancelling() is 0. + cur = asyncio.current_task() + if cur is not None and cur.cancelling() > 0: + raise + except Exception: + pass + + if self._generate_extra is not None: + try: + await self._generate_extra + except asyncio.CancelledError: + cur = asyncio.current_task() + if cur is not None and cur.cancelling() > 0: + raise + except Exception: + pass + + # Drain again for any final item the task put before terminating. + _drain() + + span = self._meta.pop("_telemetry_span", None) + if span is not None: + from ..telemetry import end_backend_span, set_span_error + + recorded: Exception = ( + error if error is not None else RuntimeError("Generation cancelled") + ) + set_span_error(span, recorded) + end_backend_span(span) + + if self._underlying_value is None: + self._underlying_value = "" + self._cancelled = True + self._computed = True + + @property + def cancelled(self) -> bool: + """``True`` if :meth:`cancel_generation` ran to completion on this MOT. + + A normally-completed MOT leaves this ``False``; only an actual + cancellation via :meth:`cancel_generation` flips it. Consumers holding + a computed MOT can use this to distinguish a genuine result from one + cut short (for example by a streaming requirement failure). + """ + return self._cancelled + def _copy_from(self, other: ModelOutputThunk) -> None: """Copy computed-output fields from *other* into *self*. 
@@ -378,6 +497,10 @@ def _copy_from(self, other: ModelOutputThunk) -> None: self._thinking = other._thinking self.generation = other.generation self._generate_log = other._generate_log + self._cancelled = other._cancelled + # _cancel_hook is deliberately not copied: _copy_from swaps output state, + # not backend-thread plumbing, which is tied to the original computation. + self._cancel_hook = None def is_computed(self) -> bool: """Returns true only if this Thunk has already been filled. @@ -606,6 +729,10 @@ def __copy__(self) -> ModelOutputThunk: copied.parsed_repr = copied # type: ignore copied._computed = self._computed + copied._cancelled = self._cancelled + # _cancel_hook is not forwarded: a copied MOT is a distinct computation + # and must not share the original's backend thread signal. + copied._cancel_hook = None copied._thinking = self._thinking copied._action = self._action copied._context = self._context @@ -634,6 +761,10 @@ def __deepcopy__(self, memo: dict) -> ModelOutputThunk: deepcopied._meta = deepcopy(self._meta) deepcopied.tool_calls = deepcopy(self.tool_calls) deepcopied._computed = self._computed + deepcopied._cancelled = self._cancelled + # _cancel_hook is not forwarded: a deepcopied MOT is a distinct computation + # and must not share the original's backend thread signal. + deepcopied._cancel_hook = None deepcopied._thinking = self._thinking deepcopied._action = deepcopy(self._action) deepcopied._context = copy( diff --git a/mellea/core/requirement.py b/mellea/core/requirement.py index e8c55564b..141af22b0 100644 --- a/mellea/core/requirement.py +++ b/mellea/core/requirement.py @@ -300,7 +300,12 @@ async def stream_validate( are shared by reference under ``copy()``. Reassign rather than mutate in place (``self._buffer = self._buffer + [chunk]``, not ``self._buffer.append(chunk)``), or override ``__copy__`` for proper - isolation. + isolation. 
If an override raises, the enclosing + :func:`~mellea.stdlib.streaming.stream_with_chunking` call aborts before + any backend generation starts and the exception propagates unchanged. + Overrides with externally visible side effects (file writes, network + calls) should perform them only after any logic that could raise, since + the framework cannot roll them back. Implementations must not call ``mot.astream()`` or otherwise read the underlying stream; the orchestrator is the single consumer of the MOT diff --git a/mellea/stdlib/__init__.py b/mellea/stdlib/__init__.py index e4f32941b..7a30fdd53 100644 --- a/mellea/stdlib/__init__.py +++ b/mellea/stdlib/__init__.py @@ -10,9 +10,20 @@ ``mellea.stdlib.session`` — for day-to-day use. Streaming chunking strategies (for use with streaming validation) are available at -``mellea.stdlib.chunking`` and re-exported here for convenience. +``mellea.stdlib.chunking`` and re-exported here for convenience. The core streaming +orchestration primitive :func:`~mellea.stdlib.streaming.stream_with_chunking` and +its result type :class:`~mellea.stdlib.streaming.StreamChunkingResult` are also +re-exported here. """ from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker +from .streaming import StreamChunkingResult, stream_with_chunking -__all__ = ["ChunkingStrategy", "ParagraphChunker", "SentenceChunker", "WordChunker"] +__all__ = [ + "ChunkingStrategy", + "ParagraphChunker", + "SentenceChunker", + "StreamChunkingResult", + "WordChunker", + "stream_with_chunking", +] diff --git a/mellea/stdlib/chunking.py b/mellea/stdlib/chunking.py index 6b9091780..462af81a8 100644 --- a/mellea/stdlib/chunking.py +++ b/mellea/stdlib/chunking.py @@ -19,6 +19,10 @@ class ChunkingStrategy(ABC): take the full accumulated text, identify everything after the last returned chunk boundary, and handle it appropriately (e.g. pass to a final validator or discard). + + Note: this ABC operates on text streams only. 
Multi-modal output (audio + segments, image regions) is not supported — the ``accumulated_text: str`` + signatures on ``split`` and ``flush`` preclude it. """ @abstractmethod @@ -35,6 +39,27 @@ def split(self, accumulated_text: str) -> list[str]: """ ... + def flush(self, accumulated_text: str) -> list[str]: + """Return any trailing fragment that ``split`` withheld. + + Called once by the orchestrator after the stream has ended naturally + (not on early-exit cancellation). Gives the chunker a chance to + release the final fragment that did not reach a terminator. + + The default implementation returns an empty list — the trailing + fragment is discarded. Built-in chunkers override this to return + the withheld fragment as a single-element list when non-empty. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + The trailing fragment as ``[fragment]`` if it should be treated + as a final chunk, or an empty list to discard it. + """ + _ = accumulated_text + return [] + # Sentence boundary: sentence-ending punctuation, optionally followed by a closing # quote or paren, then whitespace. @@ -94,6 +119,36 @@ def split(self, accumulated_text: str) -> list[str]: return chunks + def flush(self, accumulated_text: str) -> list[str]: + """Return the trailing sentence fragment (if any) as a final chunk. + + Trailing whitespace on the fragment is non-semantic for sentence + boundaries and is dropped via ``rstrip``. Leading whitespace is + already removed by the loop's ``lstrip`` on each advance, so no + ``lstrip`` is needed here. The result is the fragment's content + only, consistent with how :meth:`split` returns sentences without + trailing whitespace. + + Args: + accumulated_text: The full accumulated text at stream end. 
+ + Returns: + A single-element list containing the trailing sentence fragment + with leading and trailing whitespace stripped, or an empty list + when there is no fragment (all content ended in a sentence + boundary or the input is empty/whitespace-only). + """ + if not accumulated_text: + return [] + remaining = accumulated_text + while True: + match = _SENTENCE_BOUNDARY.search(remaining) + if match is None: + break + remaining = remaining[match.end() :].lstrip() + trailing = remaining.rstrip() + return [trailing] if trailing else [] + class WordChunker(ChunkingStrategy): """Splits accumulated text on whitespace boundaries. @@ -134,6 +189,32 @@ def split(self, accumulated_text: str) -> list[str]: return parts + def flush(self, accumulated_text: str) -> list[str]: + """Return the trailing word fragment (if any) as a final chunk. + + The trailing fragment is the text after the last whitespace run when + the accumulated text does not end with whitespace. When it does end + with whitespace, every word is already complete and no fragment is + released. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + A single-element list containing the trailing word fragment, or + an empty list when the input ends with whitespace (every word + already complete) or is empty. + """ + if not accumulated_text: + return [] + if accumulated_text[-1].isspace(): + return [] + parts = _WHITESPACE.split(accumulated_text) + for part in reversed(parts): + if part: + return [part] + return [] + class ParagraphChunker(ChunkingStrategy): r"""Splits accumulated text on double-newline paragraph boundaries. @@ -168,3 +249,29 @@ def split(self, accumulated_text: str) -> list[str]: # _PARA_BOUNDARY.split on leading \n\n produces an empty first element. return [p for p in parts if p] + + def flush(self, accumulated_text: str) -> list[str]: + r"""Return the trailing paragraph fragment (if any) as a final chunk. 
+ + Unlike :class:`SentenceChunker.flush`, the fragment is returned + byte-for-byte without stripping. Internal whitespace — including + a trailing single ``\n`` — can be semantically meaningful inside + a paragraph (e.g. a list item or a deliberate line break), and a + consumer validating paragraph content should see the fragment as + it was withheld. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + A single-element list containing the trailing paragraph fragment + byte-for-byte, or an empty list when the input ends with a + paragraph boundary (``\n\n`` or more) or is empty. + """ + if not accumulated_text: + return [] + if _PARA_BOUNDARY_END.search(accumulated_text): + return [] + parts = _PARA_BOUNDARY.split(accumulated_text) + trailing = parts[-1] if parts else "" + return [trailing] if trailing else [] diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py new file mode 100644 index 000000000..dfdbcb232 --- /dev/null +++ b/mellea/stdlib/streaming.py @@ -0,0 +1,494 @@ +"""Streaming generation with per-chunk validation. + +Provides :func:`stream_with_chunking`, the core orchestration primitive that +consumes a streaming :class:`~mellea.core.base.ModelOutputThunk`, applies a +:class:`~mellea.stdlib.chunking.ChunkingStrategy` to produce semantic chunks, +and runs :meth:`~mellea.core.requirement.Requirement.stream_validate` on each +chunk in parallel. Higher-level streaming APIs build on this function. 
+""" + +import asyncio +from collections.abc import AsyncIterator, Sequence +from copy import copy +from typing import Any + +from ..backends.model_options import ModelOption +from ..core.backend import Backend +from ..core.base import CBlock, Component, Context, ModelOutputThunk +from ..core.requirement import PartialValidationResult, Requirement, ValidationResult +from ..core.utils import MelleaLogger +from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker + +_CHUNKING_ALIASES: dict[str, type[ChunkingStrategy]] = { + "sentence": SentenceChunker, + "word": WordChunker, + "paragraph": ParagraphChunker, +} + + +class StreamChunkingResult: + """Result of a :func:`stream_with_chunking` operation. + + Provides async iteration over validated text chunks as they complete + (:meth:`astream`), a blocking :meth:`acomplete` for awaiting the full + result including final validation, and :attr:`as_thunk` for wrapping the + output as a :class:`~mellea.core.base.ModelOutputThunk`. + + Instances are created by :func:`stream_with_chunking`; do not instantiate + directly. + + Args: + mot: The :class:`~mellea.core.base.ModelOutputThunk` from the backend + generation call. + ctx: The generation context returned alongside the MOT. + + Attributes: + completed: ``False`` if the stream exited early because a requirement + returned ``"fail"`` during streaming; ``True`` otherwise. + full_text: The complete generated text accumulated during streaming. + Available after :meth:`acomplete` returns. + final_validations: :class:`~mellea.core.requirement.ValidationResult` + objects from the final :meth:`~mellea.core.requirement.Requirement.validate` + calls on all non-failed requirements. Available after + :meth:`acomplete` returns. + streaming_failures: ``(Requirement, PartialValidationResult)`` pairs + for every requirement that returned ``"fail"`` during streaming. 
+ """ + + def __init__(self, mot: ModelOutputThunk, ctx: Context) -> None: + """Initialise with the MOT and context from the backend call.""" + self._mot = mot + self._ctx = ctx + self._chunk_queue: asyncio.Queue[str | None | Exception] = asyncio.Queue() + self._orchestration_task: asyncio.Task[None] | None = None + self._done = asyncio.Event() + # Stashed so acomplete() surfaces orchestrator failures even when the + # consumer never iterates astream(). Cleared once consumed by + # whichever of the two reads it first. + self._orchestration_exception: BaseException | None = None + # Tracks whether the exception has already been surfaced to the caller + # (by astream OR acomplete). A separate flag rather than reusing the + # stash slot avoids the race where acomplete() clears the stash, a + # subsequent astream() dequeues the exception item, sees the stash is + # None, and silently skips it — leaving the caller with zero chunks + # and no error. + self._exception_surfaced: bool = False + + self.completed: bool = True + self.full_text: str = "" + self.final_validations: list[ValidationResult] = [] + self.streaming_failures: list[tuple[Requirement, PartialValidationResult]] = [] + + async def astream(self) -> AsyncIterator[str]: + """Yield validated text chunks as they complete. + + Each yielded string is a chunk that has passed per-chunk streaming + validation (or the stream had no requirements). Iteration ends when + all chunks have been yielded, whether the stream completed normally or + was cancelled early on a ``"fail"`` result. + + **Single-consumer.** Chunks are delivered via an + :class:`asyncio.Queue` that this method drains; calling + ``astream()`` a second time on the same result blocks indefinitely + because the queue is empty and the terminating ``None`` sentinel + has already been consumed. If you need the chunks after + iteration, capture them into a list during the first pass or use + :attr:`full_text` after :meth:`acomplete`. 
+ + Yields: + str: A validated text chunk from the chunking strategy. + + Raises: + Exception: Propagates any error from the background orchestration + task. + """ + while True: + item = await self._chunk_queue.get() + if item is None: + return + if isinstance(item, Exception): + if self._exception_surfaced: + # Already surfaced by acomplete(); don't raise twice. + continue + self._exception_surfaced = True + self._orchestration_exception = None + raise item + yield item + + async def acomplete(self) -> None: + """Await full completion, including final validation. + + After this method returns, :attr:`full_text`, :attr:`completed`, + :attr:`final_validations`, and :attr:`streaming_failures` are all + populated. If :meth:`astream` has already been consumed to + exhaustion, this call is effectively a no-op. + + Raises: + Exception: Propagates any error from the orchestration task. + """ + await self._done.wait() + # Raise-once: if astream() already surfaced the exception, skip. + exc = self._orchestration_exception + if exc is not None and not self._exception_surfaced: + self._exception_surfaced = True + self._orchestration_exception = None + raise exc + if self._orchestration_task is not None and self._orchestration_task.done(): + # Raise-once: a prior call already surfaced the exception. + if self._exception_surfaced: + return + # ``task.exception()`` raises CancelledError on a cancelled task + # (rather than returning it), so check cancelled status first. + # This branch covers BaseException paths that bypass the + # ``except Exception`` handler in ``_orchestrate_streaming``. + if self._orchestration_task.cancelled(): + self._exception_surfaced = True + raise asyncio.CancelledError() + task_exc = self._orchestration_task.exception() + if task_exc is not None: + self._exception_surfaced = True + raise task_exc + + @property + def as_thunk(self) -> ModelOutputThunk[str]: + """Wrap the output as a computed :class:`~mellea.core.base.ModelOutputThunk`. 
+ + Returns a new thunk with ``value`` set to :attr:`full_text` and + generation metadata copied from the original MOT. Safe to call on + early-exit results; ``value`` will reflect whatever was accumulated + before cancellation. + + Note: + On early exit, ``cancel_generation()`` forces the MOT into a + computed state without running the backend's + ``post_processing()``. ``value`` and ``streaming`` are + reliable. ``parsed_repr`` is set to the raw text (same as + ``value``) — consistent with normal completion for plain-text + outputs, but for typed outputs the backend-parsed representation + will not be available. Telemetry fields (``generation.usage``, + ``generation.ttfb_ms``, etc.) may be ``None`` or reflect the + partial state at cancellation time; usage totals are not + recoverable. + + Returns: + ModelOutputThunk[str]: A computed thunk containing the streamed output. + + Raises: + RuntimeError: If called before :meth:`acomplete` has returned. + """ + if not self._done.is_set(): + raise RuntimeError( + "as_thunk accessed before acomplete() — await acomplete() first" + ) + thunk = ModelOutputThunk(value=self.full_text) + thunk._cancelled = self._mot._cancelled + thunk.generation = copy(self._mot.generation) + thunk.parsed_repr = thunk.value # type: ignore[assignment] + return thunk + + +async def _orchestrate_streaming( + result: StreamChunkingResult, + mot: ModelOutputThunk, + ctx: Context, + cloned_reqs: list[Requirement], + chunking: ChunkingStrategy, + val_backend: Backend, +) -> None: + accumulated = "" + emitted_end = 0 # byte offset into accumulated after the last emitted chunk + prev_chunk_count = 0 + failed_indices: set[int] = set() + early_exit = False + + async def _validate_and_emit(c: str) -> bool: + """Run stream_validate on chunk c across active requirements. + + Returns True if a failure was recorded (caller should early-exit), + False otherwise (chunk was emitted to the consumer queue). 
+ """ + active = [ + (i, req) for i, req in enumerate(cloned_reqs) if i not in failed_indices + ] + if active: + async with asyncio.TaskGroup() as tg: + _tasks = [ + tg.create_task(req.stream_validate(c, backend=val_backend, ctx=ctx)) + for _, req in active + ] + pvrs: list[PartialValidationResult] = [t.result() for t in _tasks] + for (idx, req), pvr in zip(active, pvrs): + if pvr.success == "fail": + failed_indices.add(idx) + result.streaming_failures.append((req, pvr)) + + if failed_indices: + return True + + await result._chunk_queue.put(c) + return False + + try: + while not mot.is_computed(): + try: + delta = await mot.astream() + except RuntimeError: + # Expected race: mot.is_computed() was False at the top of the + # loop but the stream finished before we re-entered astream(). + # Any other RuntimeError is a real bug and must propagate. + if mot.is_computed(): + break + raise + + accumulated += delta + chunks = chunking.split(accumulated) + new_chunks = chunks[prev_chunk_count:] + prev_chunk_count = len(chunks) + + for c in new_chunks: + failed = await _validate_and_emit(c) + if failed: + early_exit = True + result.completed = False + await mot.cancel_generation() + break + pos = accumulated.find(c, emitted_end) + if pos >= 0: + emitted_end = pos + len(c) + + if early_exit: + break # break the while loop; cancel_generation() already set _computed=True + + # Stream ended naturally: flush any withheld trailing fragment and + # run stream_validate on it. Skipped on early exit — the generation + # was cancelled, the trailing fragment is incomplete. 
+ if not early_exit: + for c in chunking.flush(accumulated): + failed = await _validate_and_emit(c) + if failed: + early_exit = True + result.completed = False + break + pos = accumulated.find(c, emitted_end) + if pos >= 0: + emitted_end = pos + len(c) + + # On early exit, full_text is the prefix of accumulated up to and + # including the last emitted chunk — preserving original inter-chunk + # spacing from the token stream (chunk concatenation would strip it). + # On natural completion, accumulated is used directly. + result.full_text = accumulated[:emitted_end] if early_exit else accumulated + + non_failed = [ + req for i, req in enumerate(cloned_reqs) if i not in failed_indices + ] + if non_failed and not early_exit: + async with asyncio.TaskGroup() as tg: + _final_tasks = [ + tg.create_task(req.validate(val_backend, ctx)) for req in non_failed + ] + result.final_validations = [t.result() for t in _final_tasks] + + except Exception as exc: + # Orchestrator is leaving — we must stop the backend producer too, + # otherwise mot._async_queue (maxsize=20) fills and the feeder task + # blocks indefinitely. The spec (#891, #901) calls this out for the + # "fail" path; the same reasoning applies to any unplanned exit. + # Pass `exc` so the backend telemetry span records the real cause + # rather than a generic "Generation cancelled". + # TaskGroup wraps failures in ExceptionGroup; unwrap so telemetry and + # the chunk queue see the original exception, not the wrapper. + # ExceptionGroup (not BaseExceptionGroup) guarantees Exception elements. 
+ if isinstance(exc, ExceptionGroup) and exc.exceptions: + reported_exc: Exception = exc.exceptions[0] + if len(exc.exceptions) > 1: + MelleaLogger.get_logger().warning( + "stream_with_chunking: %d validator(s) failed simultaneously; " + "reporting first, suppressing rest: %r", + len(exc.exceptions) - 1, + exc.exceptions[1:], + ) + else: + reported_exc = exc + try: + await mot.cancel_generation(error=reported_exc) + except Exception as cleanup_exc: + # Never let cleanup mask the original exception: log loudly and + # continue to surface `exc` to the consumer. + # TODO(#902): replace this log with an ErrorEvent emission. + MelleaLogger.get_logger().warning( + "stream_with_chunking: cancel_generation() raised during " + "exception cleanup (original: %r, cleanup: %r)", + reported_exc, + cleanup_exc, + ) + result.completed = False + result._orchestration_exception = reported_exc + await result._chunk_queue.put(reported_exc) + finally: + # CancelledError (BaseException, not Exception) bypasses the except + # block above, so cancel_generation() may not have been called. + # Catch only Exception here so CancelledError / KeyboardInterrupt / + # SystemExit still propagate to the caller. + if not mot.is_computed(): + try: + await mot.cancel_generation() + except Exception: + pass + # put_nowait + set() are synchronous — no await point, so they cannot + # be interrupted by task cancellation. Consumers waiting on + # _done.wait() are always released, even if the task was cancelled + # mid-cleanup. The queue is unbounded, so QueueFull cannot occur. 
+ try: + result._chunk_queue.put_nowait(None) + except asyncio.QueueFull: + pass + result._done.set() + + +async def stream_with_chunking( + action: Component[Any] | CBlock, + backend: Backend, + ctx: Context, + *, + requirements: Sequence[Requirement] | None = None, + chunking: str | ChunkingStrategy = "sentence", + validation_backend: Backend | None = None, +) -> StreamChunkingResult: + """Generate a streaming response with per-chunk validation. + + Starts a backend generation with streaming enabled, consumes the + :class:`~mellea.core.base.ModelOutputThunk`'s async stream in a single + background task, splits the accumulated text using *chunking*, and runs + :meth:`~mellea.core.requirement.Requirement.stream_validate` on each new + chunk in parallel across all requirements. + + For each new complete chunk produced by the chunking strategy, + ``stream_validate`` is called once per active requirement (in parallel + via :class:`asyncio.TaskGroup`), receiving that single chunk. Multiple + chunks produced from one ``astream()`` iteration are validated + sequentially in order, so early exit on a ``"fail"`` result prevents + later chunks in the same batch from being validated or emitted to the + consumer. + + If any requirement returns ``"fail"``, the generation is cancelled + immediately (via + :meth:`~mellea.core.base.ModelOutputThunk.cancel_generation`) and + :attr:`StreamChunkingResult.completed` is set to ``False``. The + failing chunk is not emitted to the consumer; use + :attr:`StreamChunkingResult.streaming_failures` to inspect what failed. + + When the stream ends naturally, any trailing fragment withheld by the + chunking strategy (see :meth:`~mellea.stdlib.chunking.ChunkingStrategy.flush`) + is released as a final chunk and run through ``stream_validate`` on the + same terms as the regular chunks. On early exit, the trailing fragment + is discarded because the generation was cancelled mid-token. 
+ + After the stream ends naturally, ``validate()`` is called on every + requirement that did not return ``"fail"`` — both ``"pass"`` and + ``"unknown"`` trigger final validation. On early exit, no ``validate()`` + call is made; :attr:`StreamChunkingResult.final_validations` remains + empty. Requirements are cloned (``copy(req)``) before backend generation + begins, so the originals are never mutated and a raising ``__copy__`` + cannot leak an in-flight backend task. + + Requirements that need context beyond the current chunk should + accumulate it themselves across ``stream_validate`` calls (e.g. + ``self._seen = self._seen + chunk``). They must not read ``mot.astream()`` + directly — this orchestrator is the single consumer of the MOT stream. + + Note: + Chunks are emitted to the consumer (via + :meth:`StreamChunkingResult.astream`) only after every requirement's + ``stream_validate`` has returned for that chunk. A slow validator + (for example, one that invokes an LLM) therefore adds latency to + every chunk — the consumer sees a chunk at most as quickly as the + slowest active validator. This trade is deliberate in v1: it + preserves the invariant that the consumer never sees content that + has not been validated, which matters for UIs displaying generated + text live. A future fast-path mode that emits chunks to the + consumer concurrently with validation (at the cost of that + invariant) may be added if a concrete use case calls for it. + + Note: + v1 retry is simple re-invocation of this function. Plugin hooks + (``SAMPLING_LOOP_START``, ``SAMPLING_REPAIR``, etc.) do not fire + on retries — use the ``#902`` event types for observability instead. + + Args: + action: The component or content block to generate from. + backend: The backend used for generation and final validation. + ctx: The generation context. + requirements: Sequence of requirements to validate against each chunk + during streaming. 
``None`` disables streaming validation (chunks + are still produced; ``validate()`` is not called at stream end). + chunking: Chunking strategy — either a :class:`~mellea.stdlib.chunking.ChunkingStrategy` + instance or one of the string aliases ``"sentence"`` (default), + ``"word"``, or ``"paragraph"``. + validation_backend: Optional alternate backend for both + ``stream_validate`` and final ``validate`` calls. When ``None``, + *backend* is used for validation. + + Returns: + StreamChunkingResult: A result object providing :meth:`~StreamChunkingResult.astream` + for incremental chunk consumption and + :meth:`~StreamChunkingResult.acomplete` for blocking until done. + + Raises: + ValueError: If *chunking* is a string that does not match any known + alias (``"sentence"``, ``"word"``, ``"paragraph"``). + RuntimeError: If the backend returns an already-computed + :class:`~mellea.core.base.ModelOutputThunk` instead of a streaming + one. This indicates the backend is not honouring + ``ModelOption.STREAM``. + + Note: + Any exception raised by ``copy(req)`` on a ``requirements`` entry + propagates to the caller; no backend generation is started in that + case. See :class:`~mellea.core.Requirement` for the ``__copy__`` + override contract. + """ + if isinstance(chunking, str): + cls = _CHUNKING_ALIASES.get(chunking) + if cls is None: + raise ValueError( + f"Unknown chunking alias {chunking!r}. Choose from: {list(_CHUNKING_ALIASES)}" + ) + chunking = cls() + + opts: dict[str, Any] = {ModelOption.STREAM: True} + + # Clone requirements before starting backend generation so that a raising + # __copy__ (an advertised extension point on Requirement) cannot leave the + # backend feeder task wedged against a full _async_queue with no consumer. 
+ cloned_reqs = [copy(req) for req in (requirements or [])] + val_backend = validation_backend if validation_backend is not None else backend + + mot, gen_ctx = await backend.generate_from_context(action, ctx, model_options=opts) + if mot.is_computed(): + raise RuntimeError( + "stream_with_chunking() requires a streaming backend; the backend returned " + "an already-computed MOT. Ensure the backend honours ModelOption.STREAM." + ) + try: + result = StreamChunkingResult(mot, gen_ctx) + coro = _orchestrate_streaming( + result, mot, gen_ctx, cloned_reqs, chunking, val_backend + ) + try: + result._orchestration_task = asyncio.create_task(coro) + except BaseException: + coro.close() # prevent "coroutine was never awaited" RuntimeWarning + raise + except BaseException: + try: + await mot.cancel_generation() + except Exception as cleanup_exc: + MelleaLogger.get_logger().warning( + "stream_with_chunking: cancel_generation() raised during " + "setup-path cleanup (cleanup: %r)", + cleanup_exc, + ) + raise + + return result diff --git a/test/core/test_base.py b/test/core/test_base.py index 32ad9ab10..213a16e6e 100644 --- a/test/core/test_base.py +++ b/test/core/test_base.py @@ -191,3 +191,129 @@ def test_mot_deep_copy_clones_generation(): if __name__ == "__main__": pytest.main([__file__]) + + +# --------------------------------------------------------------------------- +# Fix 2 — cancel_generation invokes _cancel_hook before task cancellation +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_cancel_generation_invokes_cancel_hook_before_task_cancel() -> None: + """Fix 2: _cancel_hook fires and cancel_generation() returns promptly. + + Simulates a backend thread that blocks for 5 s unless the hook sets the + event. 
Without the hook, cancel_generation() would only observe the asyncio + task as CancelledError but the thread would keep running — on a slow box + that can mean the task wrapper hangs past the 1 s timeout here. With the + hook, the event is set first, the thread unblocks, and the whole path + completes within the timeout. + """ + import asyncio + import threading + + hook_called = threading.Event() + thread_released = threading.Event() + + def hook() -> None: + hook_called.set() + thread_released.set() + + mot = ModelOutputThunk(value=None) + mot._cancel_hook = hook # type: ignore[attr-defined] + + # Task that blocks in a thread until thread_released is set. + async def spin() -> None: + await asyncio.to_thread(thread_released.wait, 5.0) + + mot._generate = asyncio.create_task(spin()) # type: ignore[attr-defined] + await asyncio.sleep(0) # let the task reach to_thread + + # Must return within 1 s; without the hook it would hang ~5 s. + await asyncio.wait_for(mot.cancel_generation(), timeout=1.0) # type: ignore[attr-defined] + + assert hook_called.is_set(), "_cancel_hook was never called" + assert mot._cancelled is True # type: ignore[attr-defined] + + +def test_cancel_hook_not_forwarded_by_copy_methods() -> None: + """Fix 2: copied MOTs must not inherit _cancel_hook (distinct computation).""" + import copy as copy_mod + + def _hook() -> None: + pass + + mot = ModelOutputThunk(value="x") + mot._cancel_hook = _hook # type: ignore[attr-defined] + + shallow = copy_mod.copy(mot) + assert shallow._cancel_hook is None, "__copy__ must reset _cancel_hook to None" # type: ignore[attr-defined] + + deep = copy_mod.deepcopy(mot) + assert deep._cancel_hook is None, "__deepcopy__ must reset _cancel_hook to None" # type: ignore[attr-defined] + + target = ModelOutputThunk(value="original") + target._copy_from(mot) # type: ignore[attr-defined] + assert target._cancel_hook is None, "_copy_from must reset _cancel_hook to None" # type: ignore[attr-defined] + + +@pytest.mark.asyncio 
+async def test_cancel_generation_hook_exception_is_suppressed() -> None: + """Fix 2: a faulty _cancel_hook must not mask cancel_generation itself.""" + import asyncio + + def _bad_hook() -> None: + raise RuntimeError("hook exploded") + + mot = ModelOutputThunk(value=None) + mot._cancel_hook = _bad_hook # type: ignore[attr-defined] + + # No _generate task — cancel_generation still runs the hook path. + # The hook raises, but cancel_generation must complete without propagating. + await mot.cancel_generation() # type: ignore[attr-defined] + assert mot._cancelled is True # type: ignore[attr-defined] + + +@pytest.mark.asyncio +async def test_cancel_generation_propagates_outer_cancellation() -> None: + """Outer cancellation of the cancel_generation() task must re-raise CancelledError. + + When cancel_generation() is awaiting self._generate and the *cancel_generation* + task is itself cancelled from outside, cur.cancelling() > 0 and the + CancelledError must propagate — not be swallowed by the bare ``pass`` path. + """ + import asyncio + + inner_cancelled = asyncio.Event() + + async def _absorbs_first_cancel() -> None: + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + # Signal that cancel_generation() has called .cancel() and is + # now blocked at ``await self._generate``. + inner_cancelled.set() + # Absorb this cancel so cancel_generation() stays at the await. + await asyncio.sleep(60) + + mot = ModelOutputThunk(value=None) + mot._generate = asyncio.create_task(_absorbs_first_cancel()) # type: ignore[attr-defined] + await asyncio.sleep(0) + + cg_task = asyncio.create_task(mot.cancel_generation()) # type: ignore[attr-defined] + # Wait until _generate has absorbed cancel_generation()'s .cancel() call — + # at that point cg_task is blocked at ``await self._generate``. 
+ await asyncio.wait_for(inner_cancelled.wait(), timeout=2.0) + + # Cancel cancel_generation() from outside (simulates asyncio.wait_for timeout + # or an outer TaskGroup cancelling this coroutine). + cg_task.cancel() + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(cg_task, timeout=2.0) + + # Cleanup: stop the still-running _generate task. + mot._generate.cancel() # type: ignore[attr-defined] + try: + await asyncio.wait_for(mot._generate, timeout=1.0) # type: ignore[attr-defined] + except (TimeoutError, asyncio.CancelledError): + pass diff --git a/test/stdlib/test_chunking.py b/test/stdlib/test_chunking.py index fbaf727a2..7b965350f 100644 --- a/test/stdlib/test_chunking.py +++ b/test/stdlib/test_chunking.py @@ -242,3 +242,79 @@ def test_paragraph_chunker_incremental_simulation(): "First paragraph.", "Second paragraph.", ] + + +# --------------------------------------------------------------------------- +# flush() — trailing-fragment release at end of stream +# --------------------------------------------------------------------------- + + +def test_default_flush_returns_empty_list(): + """The ABC default discards the trailing fragment.""" + + class Minimal(ChunkingStrategy): + def split(self, accumulated_text: str) -> list[str]: + _ = accumulated_text + return [] + + assert Minimal().flush("anything at all") == [] + assert Minimal().flush("") == [] + + +def test_sentence_chunker_flush_empty(): + assert SentenceChunker().flush("") == [] + + +def test_sentence_chunker_flush_only_complete(): + """All text ends in a complete sentence with trailing whitespace → no fragment.""" + assert SentenceChunker().flush("One. Two. ") == [] + + +def test_sentence_chunker_flush_trailing_fragment(): + """Final sentence without trailing whitespace is released by flush.""" + assert SentenceChunker().flush("One. 
Two without period") == ["Two without period"] + + +def test_sentence_chunker_flush_terminated_no_trailing_space(): + """Final sentence with terminator but no trailing whitespace is a fragment + under split() semantics and gets released by flush().""" + assert SentenceChunker().flush("One. Two.") == ["Two."] + + +def test_sentence_chunker_flush_single_sentence_no_terminator(): + assert SentenceChunker().flush("Incomplete sentence") == ["Incomplete sentence"] + + +def test_word_chunker_flush_empty(): + assert WordChunker().flush("") == [] + + +def test_word_chunker_flush_trailing_whitespace(): + """Trailing whitespace means all words are complete → no fragment.""" + assert WordChunker().flush("one two three ") == [] + + +def test_word_chunker_flush_trailing_fragment(): + assert WordChunker().flush("one two three") == ["three"] + + +def test_word_chunker_flush_single_word(): + assert WordChunker().flush("solo") == ["solo"] + + +def test_paragraph_chunker_flush_empty(): + assert ParagraphChunker().flush("") == [] + + +def test_paragraph_chunker_flush_only_complete(): + assert ParagraphChunker().flush("Para one.\n\nPara two.\n\n") == [] + + +def test_paragraph_chunker_flush_trailing_fragment(): + assert ParagraphChunker().flush("Para one.\n\nPara two (no sep)") == [ + "Para two (no sep)" + ] + + +def test_paragraph_chunker_flush_single_paragraph_no_separator(): + assert ParagraphChunker().flush("Only paragraph") == ["Only paragraph"] diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py new file mode 100644 index 000000000..ad042df56 --- /dev/null +++ b/test/stdlib/test_streaming.py @@ -0,0 +1,1299 @@ +"""Tests for stream_with_chunking() and StreamChunkingResult. + +Uses StreamingMockBackend — a deterministic test double that feeds tokens from a +fixed response string into a MOT queue without network or LLM calls. + +All tests are unit tests (no @pytest.mark.ollama needed). 
+""" + +import asyncio +from typing import Any + +import pytest + +from mellea.core.backend import Backend +from mellea.core.base import CBlock, Context, GenerateType, ModelOutputThunk +from mellea.core.requirement import ( + PartialValidationResult, + Requirement, + ValidationResult, +) +from mellea.stdlib.context import SimpleContext +from mellea.stdlib.streaming import stream_with_chunking + +# --------------------------------------------------------------------------- +# StreamingMockBackend +# --------------------------------------------------------------------------- + + +async def _mock_process(mot: ModelOutputThunk, chunk: Any) -> None: + if mot._underlying_value is None: + mot._underlying_value = "" + if chunk is not None: + mot._underlying_value += chunk + + +async def _mock_post_process(_mot: ModelOutputThunk) -> None: + pass + + +def _make_mot() -> ModelOutputThunk: + mot = ModelOutputThunk(value=None) + mot._action = CBlock("mock_action") + mot._generate_type = GenerateType.ASYNC + mot._process = _mock_process + mot._post_process = _mock_post_process + mot._chunk_size = 0 + return mot + + +async def _feed_tokens(mot: ModelOutputThunk, response: str, token_size: int) -> None: + i = 0 + while i < len(response): + token = response[i : i + token_size] + await mot._async_queue.put(token) + await asyncio.sleep(0) + i += token_size + await mot._async_queue.put(None) + + +class StreamingMockBackend(Backend): + """Test double that streams a fixed response one token at a time. + + ``token_size`` controls how many characters constitute one token. + Validation calls (via ``stream_validate`` / ``validate``) are delegated + to the requirements themselves — this backend does not perform any real + inference. 
+ """ + + def __init__(self, response: str, token_size: int = 1) -> None: + self._response = response + self._token_size = token_size + + async def _generate_from_context( + self, + action: Any, + ctx: Context, + *, + format: Any = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk, Context]: + _ = format, model_options, tool_calls + mot = _make_mot() + task = asyncio.create_task(_feed_tokens(mot, self._response, self._token_size)) + _ = task + new_ctx = ctx.add(action).add(mot) + return mot, new_ctx + + async def generate_from_raw( + self, actions: Any, ctx: Any, **kwargs: Any + ) -> list[ModelOutputThunk]: + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Requirement test doubles +# --------------------------------------------------------------------------- + + +class AlwaysUnknownReq(Requirement): + """stream_validate always returns 'unknown'; validate returns True.""" + + def format_for_llm(self) -> str: + return "always unknown" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +class FailAfterWordsReq(Requirement): + """Returns 'fail' once the cumulative word count reaches *threshold*. + + Each call to ``stream_validate`` receives a single chunk (delta) from the + chunking strategy; the running total is maintained on the instance. 
+ """ + + def __init__(self, threshold: int) -> None: + super().__init__() + self._threshold = threshold + self._word_count = 0 + + def format_for_llm(self) -> str: + return f"fail after {self._threshold} words" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + self._word_count += len(chunk.split()) + if self._word_count >= self._threshold: + return PartialValidationResult("fail", reason="too many words") + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +class BackendRecordingReq(Requirement): + """Records which backend was passed to stream_validate and validate.""" + + def __init__(self) -> None: + super().__init__() + self.seen_backends: list[Any] = [] + + def __copy__(self) -> "BackendRecordingReq": + clone = BackendRecordingReq() + clone.seen_backends = [] # fresh list — do not share with original + return clone + + def format_for_llm(self) -> str: + return "backend recorder" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk + self.seen_backends.append(backend) + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + self.seen_backends.append(backend) + return ValidationResult(result=True) + + +class MutationDetectorReq(Requirement): + """Tracks how many times stream_validate was called on this instance.""" + + def __init__(self) -> None: + super().__init__() + self._call_count = 0 + + def format_for_llm(self) -> str: + return "mutation detector" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + self._call_count += 1 + return PartialValidationResult("unknown") + + 
 async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _ctx() -> SimpleContext: + return SimpleContext() + + +def _action() -> CBlock: + return CBlock("prompt") + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_normal_completion_calls_validate_at_stream_end() -> None: + """All 'unknown' requirements → validate() called at stream end; completed=True.""" + response = "Hello world. How are you. " + backend = StreamingMockBackend(response, token_size=3) + req = AlwaysUnknownReq() + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="sentence" + ) + await result.acomplete() + + assert result.completed is True + assert result.full_text == response + assert len(result.final_validations) == 1 + assert result.final_validations[0].as_bool() is True + assert result.streaming_failures == [] + + +@pytest.mark.asyncio +async def test_early_exit_on_fail() -> None: + """Requirement fails mid-stream → completed=False, streaming_failures populated.""" + # 8 words; failure fires once the running word count reaches 4 (threshold) + response = "one two three four five six seven eight. 
" + backend = StreamingMockBackend(response, token_size=2) + req = FailAfterWordsReq(threshold=4) + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="word" + ) + await result.acomplete() + + assert result.completed is False + assert len(result.streaming_failures) == 1 + _req, pvr = result.streaming_failures[0] + assert pvr.success == "fail" + assert pvr.reason == "too many words" + # final_validations should be empty — final validate() skipped on early exit + assert result.final_validations == [] + + +@pytest.mark.asyncio +async def test_clone_isolation_across_retries() -> None: + """Originals must not be mutated; two invocations are independent.""" + response = "Sentence one. Sentence two. " + req = MutationDetectorReq() + original_reqs = [req] + + backend = StreamingMockBackend(response, token_size=4) + + r1 = await stream_with_chunking( + _action(), backend, _ctx(), requirements=original_reqs, chunking="sentence" + ) + await r1.acomplete() + + r2 = await stream_with_chunking( + _action(), backend, _ctx(), requirements=original_reqs, chunking="sentence" + ) + await r2.acomplete() + + # Original requirement must never have been called — only clones are used + assert req._call_count == 0 + + +@pytest.mark.asyncio +async def test_validation_backend_routing() -> None: + """stream_validate and validate receive validation_backend, not the main backend.""" + response = "One sentence. Two sentences. " + main_backend = StreamingMockBackend(response, token_size=3) + val_backend = StreamingMockBackend("unused", token_size=1) + + req = BackendRecordingReq() + + # Capture the cloned requirement so we can inspect which backends it saw. 
+ captured: list[BackendRecordingReq] = [] + original_copy = BackendRecordingReq.__copy__ + + def _capturing_copy(self: BackendRecordingReq) -> BackendRecordingReq: + clone = original_copy(self) + captured.append(clone) + return clone + + BackendRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + main_backend, + _ctx(), + requirements=[req], + chunking="sentence", + validation_backend=val_backend, + ) + await result.acomplete() + finally: + BackendRecordingReq.__copy__ = original_copy # type: ignore[method-assign] + + assert result.completed is True + # The original was never called — only clones are used. + assert req.seen_backends == [] + # The clone must have seen val_backend for every call (stream_validate + validate), + # never main_backend. This is the actual routing assertion. + assert len(captured) == 1 + assert len(captured[0].seen_backends) > 0 + assert all(b is val_backend for b in captured[0].seen_backends) + + +@pytest.mark.asyncio +async def test_early_exit_does_not_deadlock() -> None: + """Early failure with a high-throughput stream must not hang.""" + long_response = "word " * 200 + backend = StreamingMockBackend(long_response, token_size=5) + req = FailAfterWordsReq(threshold=3) + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="word" + ) + # 5-second timeout — should complete in milliseconds on success + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + assert result.completed is False + + +@pytest.mark.asyncio +async def test_as_thunk_correctness() -> None: + """as_thunk is computed, value matches full_text, generation metadata preserved.""" + response = "This is a test sentence. 
" + backend = StreamingMockBackend(response, token_size=4) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + await result.acomplete() + + thunk = result.as_thunk + assert thunk.is_computed() + assert thunk.value == result.full_text == response + + +@pytest.mark.asyncio +async def test_as_thunk_raises_before_acomplete() -> None: + """as_thunk raises RuntimeError if accessed before acomplete().""" + response = "Some text. " + backend = StreamingMockBackend(response, token_size=2) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + + with pytest.raises(RuntimeError, match="acomplete"): + _ = result.as_thunk + + +@pytest.mark.asyncio +async def test_astream_yields_individual_chunks() -> None: + """Consumer via astream() receives individual chunks, not accumulated text.""" + response = "First sentence. Second sentence. Third sentence. " + backend = StreamingMockBackend(response, token_size=5) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + + chunks: list[str] = [] + async for chunk in result.astream(): + chunks.append(chunk) + + await result.acomplete() + + # Each chunk must be a complete sentence (not the accumulated text) + assert len(chunks) == 3 + for chunk in chunks: + assert chunk.endswith(".") + # Chunks don't include inter-sentence spaces; joined with a space they appear in full_text + assert " ".join(chunks) in result.full_text + + +@pytest.mark.asyncio +async def test_stream_validate_receives_individual_chunks() -> None: + """stream_validate is called once per chunk with the chunk itself, not accumulated text.""" + + class ChunkRecordingReq(Requirement): + def __init__(self) -> None: + self.seen_chunks: list[str] = [] + + def __copy__(self) -> "ChunkRecordingReq": + clone = ChunkRecordingReq() + clone.seen_chunks = [] + return clone + + def format_for_llm(self) -> str: + return "chunk recorder" + + async def stream_validate( + self, chunk: 
str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + self.seen_chunks.append(chunk) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + response = "First sentence. Second sentence. Third sentence. " + backend = StreamingMockBackend(response, token_size=4) + req = ChunkRecordingReq() + + # Capture the cloned requirement used by the orchestrator via a side channel. + captured: list[ChunkRecordingReq] = [] + original_copy = ChunkRecordingReq.__copy__ + + def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: + clone = original_copy(self) + captured.append(clone) + return clone + + ChunkRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="sentence" + ) + await result.acomplete() + finally: + ChunkRecordingReq.__copy__ = original_copy # type: ignore[method-assign] + + assert len(captured) == 1 + seen = captured[0].seen_chunks + # Exact match: three separate calls, one per complete sentence, + # each call receiving that sentence and nothing more. Under the old + # accumulated-text semantics, seen would have been + # ["First sentence.", "First sentence. Second sentence.", ...] — + # exact match against the per-chunk list is the direct regression guard. 
+ assert seen == ["First sentence.", "Second sentence.", "Third sentence."] + + +@pytest.mark.asyncio +async def test_trailing_fragment_is_flushed_to_consumer() -> None: + """Response without trailing whitespace: final sentence reaches astream() and stream_validate.""" + + class ChunkRecordingReq(Requirement): + def __init__(self) -> None: + self.seen_chunks: list[str] = [] + + def __copy__(self) -> "ChunkRecordingReq": + clone = ChunkRecordingReq() + clone.seen_chunks = [] + return clone + + def format_for_llm(self) -> str: + return "chunk recorder" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + self.seen_chunks.append(chunk) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + # No trailing whitespace after the final sentence — SentenceChunker withholds it. + response = "First sentence. Second sentence." + backend = StreamingMockBackend(response, token_size=4) + req = ChunkRecordingReq() + + captured: list[ChunkRecordingReq] = [] + original_copy = ChunkRecordingReq.__copy__ + + def _capturing_copy(self: ChunkRecordingReq) -> ChunkRecordingReq: + clone = original_copy(self) + captured.append(clone) + return clone + + ChunkRecordingReq.__copy__ = _capturing_copy # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="sentence" + ) + yielded: list[str] = [] + async for chunk in result.astream(): + yielded.append(chunk) + await result.acomplete() + finally: + ChunkRecordingReq.__copy__ = original_copy # type: ignore[method-assign] + + # Both sentences reach the consumer, including the terminating one without trailing whitespace. 
+ assert yielded == ["First sentence.", "Second sentence."] + # stream_validate was called on both — the flush path is not a shortcut. + assert captured[0].seen_chunks == ["First sentence.", "Second sentence."] + assert result.completed is True + + +@pytest.mark.asyncio +async def test_early_exit_on_trailing_fragment() -> None: + """A fail on the flushed fragment records a streaming failure and skips final validate().""" + + class FailOnSecondSentence(Requirement): + def __init__(self) -> None: + self._count = 0 + + def format_for_llm(self) -> str: + return "fail on second sentence" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + self._count += 1 + if self._count >= 2: + return PartialValidationResult("fail", reason="second sentence hit") + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + response = "First sentence. Second sentence." + backend = StreamingMockBackend(response, token_size=4) + req = FailOnSecondSentence() + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="sentence" + ) + yielded: list[str] = [] + async for chunk in result.astream(): + yielded.append(chunk) + await result.acomplete() + + assert result.completed is False + assert len(result.streaming_failures) == 1 + # First sentence was emitted; second (the flushed fragment) failed and wasn't emitted. + assert yielded == ["First sentence."] + # Early exit on fail skips final validate(). + assert result.final_validations == [] + + +@pytest.mark.asyncio +async def test_no_requirements_streams_without_validation() -> None: + """requirements=None → chunks produced, no validate() called.""" + response = "Chunk one. Chunk two. Chunk three. 
" + backend = StreamingMockBackend(response, token_size=3) + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=None, chunking="sentence" + ) + await result.acomplete() + + assert result.completed is True + assert result.full_text == response + assert result.final_validations == [] + assert result.streaming_failures == [] + + +@pytest.mark.asyncio +async def test_multiple_chunks_in_one_batch_with_mid_batch_fail() -> None: + """When one astream() delta produces several complete chunks and one in + the middle fails, earlier chunks emit, failing chunk is recorded, later + chunks are neither validated nor emitted.""" + + captured: list[Any] = [] + + class FailOnNthChunk(Requirement): + def __init__(self, n: int) -> None: + self._n = n + self._calls = 0 + self.seen: list[str] = [] + + def __copy__(self) -> "FailOnNthChunk": + clone = FailOnNthChunk(self._n) + captured.append(clone) + return clone + + def format_for_llm(self) -> str: + return f"fail on chunk {self._n}" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = backend, ctx + self._calls += 1 + self.seen.append(chunk) + if self._calls == self._n: + return PartialValidationResult("fail", reason=f"n={self._n}") + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + # token_size larger than the whole response → one astream() delta delivers + # the full text, so chunking.split produces 4 sentences in a single batch. + response = "One. Two. Three. Four. 
" + backend = StreamingMockBackend(response, token_size=100) + req = FailOnNthChunk(n=2) + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="sentence" + ) + yielded: list[str] = [] + async for c in result.astream(): + yielded.append(c) + await result.acomplete() + + assert result.completed is False + assert len(result.streaming_failures) == 1 + # Chunk 1 was validated and emitted; chunk 2 was validated and failed + # (NOT emitted); chunks 3 and 4 were NEITHER validated NOR emitted. + assert yielded == ["One."] + assert len(captured) == 1 + assert captured[0].seen == ["One.", "Two."] + assert captured[0]._calls == 2 + + +@pytest.mark.asyncio +async def test_cancel_generation_invoked_on_fail() -> None: + """Early exit on 'fail' must call mot.cancel_generation() — the spec reason + is that asyncio.Queue(maxsize=20) will block the producer if the consumer + stops without cancelling.""" + + from mellea.core.base import ModelOutputThunk + + response = "word " * 50 + backend = StreamingMockBackend(response, token_size=3) + + class FailOnFirstChunk(Requirement): + def format_for_llm(self) -> str: + return "fail immediately" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + return PartialValidationResult("fail", reason="nope") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + call_count = 0 + real_cancel = ModelOutputThunk.cancel_generation + + async def spy_cancel( + self: ModelOutputThunk, error: Exception | None = None + ) -> None: + nonlocal call_count + call_count += 1 + await real_cancel(self, error) + + ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] + try: + result = await stream_with_chunking( + _action(), + backend, + _ctx(), + 
requirements=[FailOnFirstChunk()], + chunking="word", + ) + await asyncio.wait_for(result.acomplete(), timeout=5.0) + finally: + ModelOutputThunk.cancel_generation = real_cancel # type: ignore[method-assign] + + assert result.completed is False + assert call_count >= 1 + + +@pytest.mark.asyncio +async def test_cancelled_flag_reflects_cancellation_state() -> None: + """The ``cancelled`` property on ModelOutputThunk distinguishes an early-exit + cancellation from a normal completion and propagates through ``as_thunk``.""" + + # Early exit → cancelled is True, is_computed True, propagates through as_thunk. + fail_response = "word " * 50 + fail_backend = StreamingMockBackend(fail_response, token_size=3) + + class FailImmediately(Requirement): + def format_for_llm(self) -> str: + return "fail immediately" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + return PartialValidationResult("fail", reason="nope") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + fail_result = await stream_with_chunking( + _action(), + fail_backend, + _ctx(), + requirements=[FailImmediately()], + chunking="word", + ) + await asyncio.wait_for(fail_result.acomplete(), timeout=5.0) + + assert fail_result.completed is False + assert fail_result.as_thunk.cancelled is True + assert fail_result.as_thunk.is_computed() is True + + # Normal completion → cancelled is False. + ok_response = "Hello world. How are you. 
" + ok_backend = StreamingMockBackend(ok_response, token_size=3) + + ok_result = await stream_with_chunking( + _action(), + ok_backend, + _ctx(), + requirements=[AlwaysUnknownReq()], + chunking="sentence", + ) + await ok_result.acomplete() + + assert ok_result.completed is True + assert ok_result.as_thunk.cancelled is False + assert ok_result.as_thunk.is_computed() is True + + +@pytest.mark.asyncio +async def test_exception_in_stream_validate_cancels_generation() -> None: + """Verifies the orchestrator's exception-path cleanup: if stream_validate + raises, cancel_generation() is called and the exception surfaces to the + consumer via astream()/acomplete() without hanging. + + This covers the cancel-on-exception path and the no-hang guarantee. + It does not directly exercise the worst-case "producer already blocked on + full queue" scenario (here the fail happens on chunk 1 so the queue never + fills); the cancel_generation drain logic is covered by its own tests in + test/core/. + """ + + from mellea.core.base import ModelOutputThunk + + class RaisingReq(Requirement): + def format_for_llm(self) -> str: + return "raises" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + raise ValueError("boom") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + response = "word " * 50 # enough to fill maxsize=20 queue without cleanup + backend = StreamingMockBackend(response, token_size=3) + + call_count = 0 + real_cancel = ModelOutputThunk.cancel_generation + + async def spy_cancel( + self: ModelOutputThunk, error: Exception | None = None + ) -> None: + nonlocal call_count + call_count += 1 + await real_cancel(self, error) + + ModelOutputThunk.cancel_generation = spy_cancel # type: ignore[method-assign] + try: + result = 
await stream_with_chunking( + _action(), backend, _ctx(), requirements=[RaisingReq()], chunking="word" + ) + with pytest.raises(ValueError, match="boom"): + async for _chunk in result.astream(): + pass + # acomplete must complete (not hang) even though the orchestration + # task raised, because cancel_generation was called in the except path. + await asyncio.wait_for(result.acomplete(), timeout=5.0) + finally: + ModelOutputThunk.cancel_generation = real_cancel # type: ignore[method-assign] + + assert result.completed is False + assert call_count >= 1 + + +@pytest.mark.asyncio +async def test_acomplete_surfaces_exception_without_astream() -> None: + """acomplete() must surface orchestrator exceptions even when the + consumer never iterates astream(). + + The alternative — only delivering the exception through the chunk queue + — silently swallows validator failures for callers who skip astream(). + """ + + class RaisingReq(Requirement): + def format_for_llm(self) -> str: + return "raises" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + _ = chunk, backend, ctx + raise ValueError("surfaced-without-astream") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + _ = backend, ctx, format, model_options + return ValidationResult(result=True) + + response = "word " * 50 + backend = StreamingMockBackend(response, token_size=3) + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[RaisingReq()], chunking="word" + ) + # Deliberately skip astream(). wait_for bounds any hang. + with pytest.raises(ValueError, match="surfaced-without-astream"): + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + assert result.completed is False + # Raise-once: a second acomplete() must not re-raise. 
+ await asyncio.wait_for(result.acomplete(), timeout=5.0) + + +@pytest.mark.asyncio +async def test_external_task_cancellation_releases_consumers() -> None: + """External cancellation of the orchestration task must still set _done. + + If the finally cleanup itself contains an ``await`` (e.g. awaiting a + terminator put into the chunk queue), CancelledError re-raises at that + await and ``_done.set()`` never runs — any consumer blocked on + ``acomplete()`` hangs forever. The cleanup must therefore end with + synchronous operations only. + """ + response = "word " * 200 # long enough that streaming is still in progress + backend = StreamingMockBackend(response, token_size=2) + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[AlwaysUnknownReq()], chunking="word" + ) + + assert result._orchestration_task is not None + # Yield once so the orchestration task enters its main loop before we + # cancel it. + await asyncio.sleep(0.01) + + # Same mechanism asyncio.wait_for uses on timeout. + result._orchestration_task.cancel() + + # _done must be set by the finally cleanup. A hang would time out here. + await asyncio.wait_for(result._done.wait(), timeout=2.0) + assert result._done.is_set() + + # acomplete() surfaces the CancelledError via task.exception() and must + # not hang. + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(result.acomplete(), timeout=2.0) + + +@pytest.mark.asyncio +async def test_external_cancellation_acomplete_raise_once() -> None: + """Raise-once contract holds for the task-fallback path on external cancel. + + CancelledError bypasses the orchestrator's ``except Exception`` handler, + so ``_orchestration_exception`` is never set. ``acomplete()`` surfaces the + cancel via ``self._orchestration_task.exception()`` instead — and that + branch must also flip ``_exception_surfaced`` so a second ``acomplete()`` + call does not raise the same exception twice. 
+ """ + response = "word " * 200 + backend = StreamingMockBackend(response, token_size=2) + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[AlwaysUnknownReq()], chunking="word" + ) + + assert result._orchestration_task is not None + await asyncio.sleep(0.01) + result._orchestration_task.cancel() + await asyncio.wait_for(result._done.wait(), timeout=2.0) + + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(result.acomplete(), timeout=2.0) + + # Second call must NOT re-raise — raise-once contract. + await asyncio.wait_for(result.acomplete(), timeout=2.0) + + +@pytest.mark.asyncio +async def test_raise_once_acomplete_then_astream() -> None: + """Regression for the raise-once stash bug: acomplete() first, astream() second. + + Prior to the fix, acomplete() cleared _orchestration_exception, so a + subsequent astream() call dequeued the exception item, saw the stash was + None, silently skipped it, and returned zero chunks with no error. + """ + + class RaisingReq(Requirement): + def format_for_llm(self) -> str: + return "raises" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + raise ValueError("raise-once-regression") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + response = "word " * 10 + backend = StreamingMockBackend(response, token_size=3) + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[RaisingReq()], chunking="word" + ) + + # acomplete() sees the exception first and raises it. + with pytest.raises(ValueError, match="raise-once-regression"): + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + # astream() must NOT re-raise (raise-once semantics). Because the + # exception fired before any chunk was emitted, the queue contains + # [exc, None]. 
With the separate _exception_surfaced flag, astream() + # correctly skips the exception item and terminates cleanly. Without + # the flag the behaviour is the same, but the guard conflates + # "already surfaced" with "stash was never set" — the flag makes the + # intent unambiguous. + chunks: list[str] = [] + async for chunk in result.astream(): + chunks.append(chunk) + assert chunks == [] # no partial chunks before the exception + + +@pytest.mark.asyncio +async def test_full_text_contains_only_validated_chunks_on_early_exit() -> None: + """full_text must equal exactly what was emitted to the consumer on early exit. + + When one astream() delta produces N chunks and chunk K fails, full_text + must contain chunks 0..K-1 only — not the failed chunk or any unvalidated + chunks after it in the same delta. + """ + + class FailOnNthChunkText(Requirement): + def __init__(self, n: int) -> None: + self._n = n + self._calls = 0 + + def __copy__(self) -> "FailOnNthChunkText": + return FailOnNthChunkText(self._n) + + def format_for_llm(self) -> str: + return f"fail on chunk {self._n}" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + self._calls += 1 + if self._calls == self._n: + return PartialValidationResult("fail") + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + # token_size > full response → single delta with 4 sentences; fail on chunk 2. + response = "One. Two. Three. Four. 
" + backend = StreamingMockBackend(response, token_size=100) + req = FailOnNthChunkText(n=2) + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="sentence" + ) + yielded: list[str] = [] + async for chunk in result.astream(): + yielded.append(chunk) + await result.acomplete() + + assert result.completed is False + # Consumer received only chunk 1. + assert yielded == ["One."] + # full_text must match what the consumer received — not the raw delta. + assert result.full_text == "One." + # as_thunk.value must agree with full_text. + assert result.as_thunk.value == result.full_text + + # Fail on chunk 3: two chunks emitted before early exit. full_text must + # preserve the original inter-sentence spacing from the token stream, not + # the stripped chunk concatenation ("One.Two." would be wrong). + backend2 = StreamingMockBackend(response, token_size=100) + req2 = FailOnNthChunkText(n=3) + result2 = await stream_with_chunking( + _action(), backend2, _ctx(), requirements=[req2], chunking="sentence" + ) + yielded2: list[str] = [] + async for chunk in result2.astream(): + yielded2.append(chunk) + await result2.acomplete() + + assert result2.completed is False + assert yielded2 == ["One.", "Two."] + assert result2.full_text == "One. Two." 
+ assert result2.as_thunk.value == result2.full_text + + +@pytest.mark.asyncio +async def test_cancelled_flag_propagates_through_copy_methods() -> None: + """_cancelled must survive __copy__, __deepcopy__, and _copy_from.""" + from copy import deepcopy + + mot = ModelOutputThunk(value="result") + mot._cancelled = True + + # __copy__ + shallow = mot.__copy__() + assert shallow._cancelled is True, "__copy__ must propagate _cancelled" + + # __deepcopy__ + deep = deepcopy(mot) + assert deep._cancelled is True, "__deepcopy__ must propagate _cancelled" + + # _copy_from + target = ModelOutputThunk(value="original") + assert target._cancelled is False + target._copy_from(mot) + assert target._cancelled is True, "_copy_from must propagate _cancelled" + + # Sanity: default-constructed MOT has _cancelled=False. + fresh = ModelOutputThunk(value="x") + assert fresh._cancelled is False + + +# --------------------------------------------------------------------------- +# Fix 1 — setup-path backend leak: copy(req) before generate_from_context +# --------------------------------------------------------------------------- + + +class _PlainReq(Requirement): + """Default shallow copy — cannot raise.""" + + def format_for_llm(self) -> str: + return "plain" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + return PartialValidationResult("unknown") + + async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +class _RaisingCopyReq(Requirement): + """__copy__ raises — simulates a user-defined Requirement with a faulty override.""" + + def __copy__(self) -> "_RaisingCopyReq": + raise ValueError("copy boom") + + def format_for_llm(self) -> str: + return "raising copy" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + return PartialValidationResult("unknown") + 
+ async def validate( + self, backend: Any, ctx: Any, *, format: Any = None, model_options: Any = None + ) -> ValidationResult: + return ValidationResult(result=True) + + +class _InstrumentedBackend(StreamingMockBackend): + """Counts generate_from_context calls and exposes the last MOT produced.""" + + def __init__(self, response: str, token_size: int = 1) -> None: + super().__init__(response, token_size) + self.generate_from_context_call_count = 0 + self.last_mot: ModelOutputThunk | None = None + + async def _generate_from_context( + self, + action: Any, + ctx: Any, + *, + format: Any = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk, Any]: + self.generate_from_context_call_count += 1 + mot, new_ctx = await super()._generate_from_context( + action, + ctx, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + self.last_mot = mot + return mot, new_ctx + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "req_cls,expect_raise", [(_PlainReq, False), (_RaisingCopyReq, True)] +) +async def test_stream_with_chunking_requirement_copy_contract( + req_cls: type, expect_raise: bool +) -> None: + """Fix 1: copy(req) runs before generate_from_context. + + On __copy__ failure the backend is never started (call_count == 0). + On success the backend is called exactly once. + """ + backend = _InstrumentedBackend("Hello world. ", token_size=2) + req = req_cls() + if expect_raise: + with pytest.raises(ValueError, match="copy boom"): + await stream_with_chunking(_action(), backend, _ctx(), requirements=[req]) + # Hard invariant: reorder ensures backend never starts on copy failure. 
+ assert backend.generate_from_context_call_count == 0 + assert backend.last_mot is None + else: + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req] + ) + await result.acomplete() + assert backend.generate_from_context_call_count == 1 + assert backend.last_mot is not None + + +# --------------------------------------------------------------------------- +# Fix 3 — TaskGroup cancels peer validators on first failure +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stream_with_chunking_cancels_peer_validators() -> None: + """Fix 3: a failing stream_validate causes TaskGroup to cancel peer validators. + + One requirement raises immediately in stream_validate; the second sleeps + for 5 s and sets a flag on completion. Without TaskGroup the slow sibling + runs detached; with it the cancellation is observed. + """ + reached_final_stage = asyncio.Event() + + class _RaisingReq(Requirement): + def format_for_llm(self) -> str: + return "raiser" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + raise RuntimeError("validator failed") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=False) + + class _SlowReq(Requirement): + def format_for_llm(self) -> str: + return "slow" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + try: + await asyncio.sleep(5.0) + reached_final_stage.set() + return PartialValidationResult("pass") + except asyncio.CancelledError: + raise # propagate so TaskGroup knows we were cancelled + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + backend = StreamingMockBackend("Hello 
world. ", token_size=2) + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[_RaisingReq(), _SlowReq()] + ) + with pytest.raises(RuntimeError, match="validator failed"): + await result.acomplete() + + # Give the loop a tick; the slow sibling must NOT have run to completion. + await asyncio.sleep(0.05) + assert not reached_final_stage.is_set(), ( + "slow sibling was not cancelled by TaskGroup" + ) + + +@pytest.mark.asyncio +async def test_stream_with_chunking_rejects_precomputed_mot() -> None: + """Backend returning an already-computed MOT raises RuntimeError immediately. + + stream_with_chunking() requires streaming; a pre-computed MOT would cause + the orchestrator loop to skip entirely, producing empty output and silently + passing all final validators against an empty string. + """ + + class PrecomputedBackend(Backend): + async def _generate_from_context( + self, + action: Any, + ctx: Any, + *, + format: Any = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> tuple[ModelOutputThunk, Any]: + return ModelOutputThunk(value="already done"), ctx + + async def generate_from_raw( + self, actions: Any, ctx: Any, **kwargs: Any + ) -> list[ModelOutputThunk]: + raise NotImplementedError + + with pytest.raises(RuntimeError, match="already-computed MOT"): + await stream_with_chunking(_action(), PrecomputedBackend(), _ctx())