diff --git a/docs/docs/concepts/requirements-system.md b/docs/docs/concepts/requirements-system.md index 41ff5e692..494f709da 100644 --- a/docs/docs/concepts/requirements-system.md +++ b/docs/docs/concepts/requirements-system.md @@ -301,3 +301,30 @@ requirements = [ All requirements are validated after each generation attempt. The repair request lists every requirement that failed, not just the first one, so the model can address all issues in a single repair pass. + +## Streaming validation + +`stream_validate()` is the streaming counterpart to `validate()`. It is called +once per semantic chunk as tokens arrive from the model, before the full output +is available. Requirements that need to detect problems early — too many +sentences, a prohibited keyword in the first paragraph, unexpected JSON +structure mid-output — override `stream_validate()` to express that logic. + +`stream_validate()` returns a `PartialValidationResult` with a tri-state `success` +field: + +- `"unknown"` — no conclusion yet; the chunk is passed to the consumer and + `validate()` will be called at stream end. +- `"pass"` — the chunk looks valid so far; it is passed to the consumer and + `validate()` is still called at stream end (a streaming pass is informational, + not final). +- `"fail"` — the stream is cancelled immediately; no further chunks reach the + consumer; `validate()` is skipped for this requirement. + +State isolation is per-clone: `stream_with_chunking()` copies each requirement +with `copy()` before starting the orchestrator, so the original objects are never +mutated. Requirements that accumulate state across chunks (e.g. a running word +count) should reassign mutable containers rather than mutate in place, since +clones share the original's `__dict__` values at copy time. 
+ +> **See also:** [Streaming with per-chunk validation](../how-to/use-async-and-streaming#streaming-with-per-chunk-validation) diff --git a/docs/docs/docs.json b/docs/docs/docs.json index 3a4465615..e3ae4f0e8 100644 --- a/docs/docs/docs.json +++ b/docs/docs/docs.json @@ -35,7 +35,8 @@ "tutorials/02-streaming-and-async", "tutorials/03-using-generative-stubs", "tutorials/04-making-agents-reliable", - "tutorials/05-mifying-legacy-code" + "tutorials/05-mifying-legacy-code", + "tutorials/06-streaming-validation" ] }, { diff --git a/docs/docs/examples/index.md b/docs/docs/examples/index.md index 7b03a85fa..a4c0128f7 100644 --- a/docs/docs/examples/index.md +++ b/docs/docs/examples/index.md @@ -33,6 +33,7 @@ to run. | `context/` | Context inspection, sampling with context trees, parallel context branches | | `sessions/` | Custom session types and backend selection | | `async/` | How to utilize basic async capabilities | +| `streaming/` | `stream_with_chunking()` with per-chunk validation, typed event vocabulary, early-exit on fail | ### Data and documents diff --git a/docs/docs/how-to/use-async-and-streaming.md b/docs/docs/how-to/use-async-and-streaming.md index 084a41269..853e99f65 100644 --- a/docs/docs/how-to/use-async-and-streaming.md +++ b/docs/docs/how-to/use-async-and-streaming.md @@ -175,6 +175,128 @@ asyncio.run(sequential_chat()) For parallel generation, use `SimpleContext`. +## Streaming with per-chunk validation + +`stream_with_chunking()` adds per-chunk validation to a streaming generation. +It splits the accumulated text into semantic units (sentences, words, or +paragraphs), calls `stream_validate()` on each chunk in parallel, and can +exit early if any requirement returns `"fail"` — preventing the consumer from +seeing invalid content mid-stream. 
+ +The primary way to observe a `stream_with_chunking()` run is via typed +`StreamEvent` objects from `result.events()`: + +```python +# Requires: mellea +# Returns: None +import asyncio + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import PartialValidationResult, Requirement, ValidationResult +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import ( + ChunkEvent, + CompletedEvent, + FullValidationEvent, + QuickCheckEvent, + StreamingDoneEvent, + stream_with_chunking, +) + + +class MaxSentencesReq(Requirement): + """Fails if the model generates more than *limit* sentences.""" + + def __init__(self, limit: int) -> None: + super().__init__() + self._limit = limit + self._count = 0 + + def format_for_llm(self) -> str: + return f"The response must be at most {self._limit} sentences." + + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + self._count += 1 + if self._count > self._limit: + return PartialValidationResult("fail", reason="Too many sentences") + return PartialValidationResult("unknown") + + async def validate( + self, backend: Backend, ctx: Context, *, format=None, model_options=None + ) -> ValidationResult: + return ValidationResult(result=True) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + action = Instruction("Write a two-sentence summary of the water cycle.") + req = MaxSentencesReq(limit=3) + + result = await stream_with_chunking( + action, m.backend, m.ctx, requirements=[req], chunking="sentence" + ) + + async for event in result.events(): + match event: + case ChunkEvent(): + print(f" chunk[{event.chunk_index}]: {event.text!r}") + case QuickCheckEvent(passed=False): + print(f" FAIL at chunk {event.chunk_index}: {event.results}") + case StreamingDoneEvent(): + print(f" stream done — {len(event.full_text)} chars") + case 
FullValidationEvent(): + print(f" final: {'pass' if event.passed else 'fail'}") + case CompletedEvent(): + print(f" completed — success={event.success}") + case _: + pass # ErrorEvent and other future types + + await result.acomplete() + print(f"completed={result.completed}, failures={len(result.streaming_failures)}") + + +asyncio.run(main()) +``` + +If you only need the raw validated text without event metadata, use +`result.astream()` instead: + +```python +result = await stream_with_chunking( + action, m.backend, m.ctx, requirements=[req], chunking="sentence" +) +async for chunk in result.astream(): + print(chunk) +await result.acomplete() +``` + +Both `astream()` (raw chunks) and `events()` are available on the same result +object. They use independent queues, so you can run them concurrently with +`asyncio.gather`. Both are **single-consumer** — a second iteration on either +will block indefinitely. + +### The `stream_validate` tri-state + +Each call to `stream_validate` returns a `PartialValidationResult` with one of +three values: + +| Value | Meaning | +| ----- | ------- | +| `"unknown"` | No conclusion yet — wait for the full output before judging. | +| `"pass"` | This chunk is valid so far (informational; does not skip final `validate()`). | +| `"fail"` | Invalid — cancel the stream immediately and record a streaming failure. | + +After a natural stream end, `validate()` is called on every non-`"fail"` +requirement (both `"pass"` and `"unknown"`). This means `"pass"` from +`stream_validate` does **not** replace the final `validate()` call. 
+ +> **See also:** [The Requirements System — Streaming validation](../concepts/requirements-system#streaming-validation) + --- **See also:** [Tutorial 02: Streaming and Async](../tutorials/02-streaming-and-async) | [act() and aact()](../how-to/act-and-aact) diff --git a/docs/docs/tutorials/06-streaming-validation.md b/docs/docs/tutorials/06-streaming-validation.md new file mode 100644 index 000000000..753731ea5 --- /dev/null +++ b/docs/docs/tutorials/06-streaming-validation.md @@ -0,0 +1,616 @@ +--- +canonical: "https://docs.mellea.ai/tutorials/06-streaming-validation" +title: "Tutorial: Streaming Validation" +description: "Validate LLM output chunk by chunk as it streams — detect policy violations the moment they appear and cancel generation before invalid content reaches your users." +# diataxis: tutorial +--- + +Post-generation validation waits until the model has finished writing before +checking the output. That is fine for short responses, but wastes time and +compute when a violation appears in the first sentence of a ten-paragraph +reply. Streaming validation moves the check into the generation loop: each +chunk is validated as soon as it arrives, and generation is cancelled the +moment a requirement fails. + +By the end you will have covered: + +- `stream_with_chunking()` — the streaming validation entry point +- The typed event vocabulary (`ChunkEvent`, `QuickCheckEvent`, …) from `result.events()` +- Early-exit cancellation and reading `streaming_failures` +- Choosing between `"word"`, `"sentence"`, and `"paragraph"` chunking +- Subclassing `ChunkingStrategy` to define a custom split boundary +- `result.astream()` for consumers that only need the validated chunks + +**Prerequisites:** [Tutorial 02](./02-streaming-and-async) (async and streaming), +[Tutorial 04](./04-making-agents-reliable) (requirements and validation), +`pip install mellea`, Ollama running locally with `granite4.1:3b` downloaded. 
+ +--- + +## Step 1: Your first streaming validation call + +`stream_with_chunking()` returns a `StreamChunkingResult` immediately. The +orchestrator runs in the background, splitting accumulated text into chunks and +calling `stream_validate()` on each one. Consume events with `result.events()`, +then call `result.acomplete()` to wait for the orchestrator to finish and raise +any exception it stored. + +```python +# Requires: mellea +# Returns: None +import asyncio +import re + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import PartialValidationResult, Requirement, ValidationResult +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import ( + ChunkEvent, + CompletedEvent, + FullValidationEvent, + QuickCheckEvent, + StreamingDoneEvent, + stream_with_chunking, +) + +_SENTENCE_END = re.compile(r"[.!?]+") + + +class MaxSentencesReq(Requirement): + """Fails the stream if the model writes more sentences than *limit*.""" + + def __init__(self, limit: int) -> None: + super().__init__() + self._limit = limit + self._count = 0 + + def format_for_llm(self) -> str: + return f"The response must be at most {self._limit} sentences." 
+ + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + self._count += len(_SENTENCE_END.findall(chunk)) + if self._count > self._limit: + return PartialValidationResult( + "fail", reason=f"Exceeded {self._limit}-sentence limit" + ) + return PartialValidationResult("unknown") + + async def validate( + self, backend: Backend, ctx: Context, *, format=None, model_options=None + ) -> ValidationResult: + return ValidationResult(result=self._count <= self._limit) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + + result = await stream_with_chunking( + Instruction("Write a two-sentence summary of how photosynthesis works."), + m.backend, + m.ctx, + requirements=[MaxSentencesReq(limit=3)], + chunking="sentence", + ) + + async for event in result.events(): + match event: + case ChunkEvent(): + print(f" chunk[{event.chunk_index}]: {event.text!r}") + case QuickCheckEvent(passed=False): + print(f" FAIL at chunk {event.chunk_index}: {event.results[0].reason}") + case StreamingDoneEvent(): + print(f" stream done — {len(event.full_text)} chars") + case FullValidationEvent(): + print(f" final validation: {'pass' if event.passed else 'fail'}") + case CompletedEvent(): + print(f" completed — success={event.success}") + case _: + pass + + await result.acomplete() + print(f"\nFull text: {result.full_text!r}") + + +asyncio.run(main()) +``` + +```text Sample output + chunk[0]: 'Photosynthesis is the process by which plants use sunlight, water, and carbon dioxide to produce glucose and oxygen.' + chunk[1]: 'This reaction takes place in the chloroplasts and is essential to nearly all life on Earth.' + stream done — 222 chars + final validation: pass + completed — success=True + +Full text: 'Photosynthesis is the process by which plants use sunlight, water, and carbon dioxide to produce glucose and oxygen. 
This reaction takes place in the chloroplasts and is essential to nearly all life on Earth.' +``` + +> **Note:** LLM output is non-deterministic. Your result will vary in wording. + +Three things to notice: + +- `stream_with_chunking()` is called with `await` but returns immediately — the + orchestrator runs as a background task. +- `result.events()` is an async iterator of typed events: one `ChunkEvent` per semantic + unit, plus validation and lifecycle events. The loop ends when the `CompletedEvent` is delivered. +- `result.acomplete()` must be called after the `events()` iteration finishes to propagate + any orchestrator exception and to ensure the background task has fully settled. + +--- + +## Step 2: Early exit on failure + +When `stream_validate()` returns `"fail"`, the orchestrator cancels the backend +immediately and stops the stream. No further chunks are delivered, and the +failure is recorded in `result.streaming_failures`. + +Lower the sentence limit so the model is likely to exceed it: + +```python +# Requires: mellea +# Returns: None +import asyncio +import re + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import PartialValidationResult, Requirement, ValidationResult +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import ChunkEvent, CompletedEvent, QuickCheckEvent, stream_with_chunking + +_SENTENCE_END = re.compile(r"[.!?]+") + + +class MaxSentencesReq(Requirement): + def __init__(self, limit: int) -> None: + super().__init__() + self._limit = limit + self._count = 0 + + def format_for_llm(self) -> str: + return f"The response must be at most {self._limit} sentences." 
+ + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + self._count += len(_SENTENCE_END.findall(chunk)) + if self._count > self._limit: + return PartialValidationResult( + "fail", reason=f"Exceeded {self._limit}-sentence limit" + ) + return PartialValidationResult("unknown") + + async def validate( + self, backend: Backend, ctx: Context, *, format=None, model_options=None + ) -> ValidationResult: + return ValidationResult(result=self._count <= self._limit) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + + # Ask for five sentences but cap the requirement at two. + # The stream should be cancelled after the third sentence arrives. + result = await stream_with_chunking( + Instruction("Write five sentences about the history of the internet."), + m.backend, + m.ctx, + requirements=[MaxSentencesReq(limit=2)], + chunking="sentence", + ) + + async for event in result.events(): + match event: + case ChunkEvent(): + print(f" chunk[{event.chunk_index}]: {event.text[:60]!r}...") + case QuickCheckEvent(passed=False): + print(f" CANCELLED at chunk {event.chunk_index}") + case CompletedEvent(): + print(f" completed — success={event.success}") + case _: + pass + + await result.acomplete() + + if result.streaming_failures: + req, pvr = result.streaming_failures[0] + print(f"\nStreaming failure: {pvr.reason}") + print(f"Text at cancellation:\n{result.full_text!r}") + else: + print(f"\nFull text: {result.full_text!r}") + + +asyncio.run(main()) +``` + +```text Sample output + chunk[0]: 'The internet began as ARPANET, a U.S. Defense Department pr'... + chunk[1]: 'In the 1980s, the network expanded beyond government use and'... + chunk[2]: 'Tim Berners-Lee invented the World Wide Web in 1989, transfo'... + CANCELLED at chunk 2 + completed — success=False + +Streaming failure: Exceeded 2-sentence limit +Text at cancellation: +'The internet began as ARPANET, a U.S. 
Defense Department project in the late 1960s. In the 1980s, the network expanded beyond government use and began connecting universities and research centres. Tim Berners-Lee invented the World Wide Web in 1989...' +``` + +> **Note:** Whether the stream is cancelled depends on whether the model +> exceeds the limit. If the model happens to comply, `streaming_failures` will +> be empty and `result.completed` will be `True`. + +`result.full_text` always contains the text accumulated up to the point where +generation stopped — useful for debugging what the model produced before the +requirement failed. + +--- + +## Step 3: Choosing a chunking strategy + +The built-in strategies cover a coarse-to-fine spectrum: + +| Alias | Splits on | Good for | +| --- | --- | --- | +| `"word"` | Whitespace | Token-local checks: forbidden words, numeric limits | +| `"sentence"` | `.`, `!`, `?` followed by whitespace | Grammar, coherence, per-sentence content rules | +| `"paragraph"` | Two or more consecutive newlines | Topic coherence, citation presence, heading structure | + +The trade-off is **latency vs context**. Word chunking fires after every token — +maximum reaction speed, but each chunk carries only a single word. Paragraph +chunking waits for blank lines — full paragraph context for the validator, but +detection is later and may happen after the model has produced a large amount +of invalid content. 
+ +To see the granularity difference concretely, switch to word chunking and print +every fifth word — so you can count how many more validation events fire compared +to Step 1's two sentences: + +```python +# Requires: mellea +# Returns: None +import asyncio + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import PartialValidationResult, Requirement, ValidationResult +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import ChunkEvent, CompletedEvent, QuickCheckEvent, stream_with_chunking + +_FORBIDDEN = {"deprecated", "legacy", "obsolete"} + + +class ForbiddenWordReq(Requirement): + """Cancels the stream the moment any forbidden word appears.""" + + def format_for_llm(self) -> str: + return f"Do not use any of the following words: {', '.join(sorted(_FORBIDDEN))}." + + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + word = chunk.strip().lower().strip(".,!?;:\"'") + if word in _FORBIDDEN: + return PartialValidationResult("fail", reason=f"Forbidden word: {chunk.strip()!r}") + return PartialValidationResult("unknown") + + async def validate( + self, backend: Backend, ctx: Context, *, format=None, model_options=None + ) -> ValidationResult: + return ValidationResult(result=True) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + + result = await stream_with_chunking( + Instruction( + "Describe three advantages of cloud-native development in two sentences." + ), + m.backend, + m.ctx, + requirements=[ForbiddenWordReq()], + chunking="word", + ) + + word_count = 0 + async for event in result.events(): + match event: + case ChunkEvent(): + word_count += 1 + # Print every fifth word to show how many events fire. 
+ if word_count % 5 == 1: + print(f" word {word_count:>3}: {event.text!r}") + case QuickCheckEvent(passed=False): + print(f" CANCELLED at word {event.chunk_index}: {event.results[0].reason}") + case CompletedEvent(): + status = "CANCELLED" if not event.success else "ok" + print(f" {status} — {word_count} word events total") + case _: + pass + + await result.acomplete() + + if result.streaming_failures: + print(f"Failure: {result.streaming_failures[0][1].reason}") + else: + print(f"Full text: {result.full_text!r}") + + +asyncio.run(main()) +``` + +```text Sample output + word 1: 'Cloud-native' + word 6: 'resilient' + word 11: 'and' + word 16: 'allows' + word 21: 'horizontally,' + word 26: 'costs,' + word 31: 'deployments,' + word 36: 'services.' + ok — 38 word events total +Full text: 'Cloud-native development enables scalable, resilient ...' +``` + +> **Note:** LLM output is non-deterministic. Your result will vary in wording. + +A two-sentence response — the kind that produced **2** `ChunkEvent` items with sentence +chunking in Step 1 — produces **38** here. The validator fires on every word — maximum reaction +speed at the cost of per-chunk context. + +If a forbidden word appears, the stream stops at that word and no further +`ChunkEvent` items are emitted. To see early exit in action, change `_FORBIDDEN` +to include a common English word like `"and"` or `"the"`. + +--- + +## Step 4: Raw chunk access with `astream()` + +If you only need the validated chunks and do not want event metadata, use +`result.astream()` instead of `result.events()`. 
It yields the text of each +validated chunk as a plain string — useful for streaming output directly to a +UI buffer or building the response incrementally without a `match` dispatch: + +```python +# Requires: mellea +# Returns: None +import asyncio +import re + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import PartialValidationResult, Requirement, ValidationResult +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import stream_with_chunking + +_SENTENCE_END = re.compile(r"[.!?]+") + + +class MaxSentencesReq(Requirement): + def __init__(self, limit: int) -> None: + super().__init__() + self._limit = limit + self._count = 0 + + def format_for_llm(self) -> str: + return f"The response must be at most {self._limit} sentences." + + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + self._count += len(_SENTENCE_END.findall(chunk)) + if self._count > self._limit: + return PartialValidationResult("fail", reason=f"Exceeded {self._limit}-sentence limit") + return PartialValidationResult("unknown") + + async def validate( + self, backend: Backend, ctx: Context, *, format=None, model_options=None + ) -> ValidationResult: + return ValidationResult(result=self._count <= self._limit) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + + result = await stream_with_chunking( + Instruction("Write a two-sentence summary of the water cycle."), + m.backend, + m.ctx, + requirements=[MaxSentencesReq(limit=3)], + chunking="sentence", + ) + + # astream() yields only validated chunk text — no event wrapper. 
+ async for chunk in result.astream(): + print(chunk, end=" ", flush=True) + print() + + await result.acomplete() + print(f"completed={result.completed}") + + +asyncio.run(main()) +``` + +```text Sample output +Water evaporates from oceans and lakes, rises into the atmosphere, and +condenses into clouds. Precipitation falls back to Earth as rain or +snow, replenishing rivers, lakes, and groundwater. +completed=True +``` + +> **Note:** LLM output is non-deterministic. Your result will vary in wording. + +`astream()` and `events()` are independent — both are available on the same +result object and can even be consumed concurrently with `asyncio.gather`. Each +is **single-consumer**: calling either iterator a second time raises +`RuntimeError`. If you need chunks after the fact, capture them to a list +during iteration or read `result.full_text` after `acomplete()`. + +--- + +## Step 5: A custom chunking strategy + +The built-in strategies cover the most common boundaries. For structured output +— numbered lists, code blocks, CSV rows — you can subclass `ChunkingStrategy` +and define your own split boundary. + +Two methods to implement: + +- **`split(accumulated_text)`** — called on every new token delta. Return all + complete chunks found so far; withhold any trailing fragment. Must be + stateless: it receives the full accumulated text each time, not a delta. +- **`flush(accumulated_text)`** — called once at natural end of stream. Release + the withheld trailing fragment, or return `[]` to discard it. 
+ +Here is a `LineChunker` that splits on single newlines — natural for numbered +list output where each line is one item: + +```python +# Requires: mellea +# Returns: None +import asyncio +import re + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import PartialValidationResult, Requirement, ValidationResult +from mellea.stdlib.chunking import ChunkingStrategy +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import ChunkEvent, CompletedEvent, QuickCheckEvent, stream_with_chunking + +_NUMBERED_LINE = re.compile(r"^\s*\d+[\.\)]\s") + + +class LineChunker(ChunkingStrategy): + """Emits one complete line per chunk, splitting on single newlines.""" + + def split(self, accumulated_text: str) -> list[str]: + if "\n" not in accumulated_text: + return [] + last_nl = accumulated_text.rfind("\n") + return [line for line in accumulated_text[:last_nl].split("\n") if line.strip()] + + def flush(self, accumulated_text: str) -> list[str]: + if not accumulated_text: + return [] + last_nl = accumulated_text.rfind("\n") + trailing = ( + accumulated_text if last_nl == -1 else accumulated_text[last_nl + 1 :] + ).strip() + return [trailing] if trailing else [] + + +class NumberedLineReq(Requirement): + """Cancels the stream if any line does not begin with a number.""" + + def format_for_llm(self) -> str: + return "Every line must begin with a number followed by a period (e.g. '1. ')." + + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + if not _NUMBERED_LINE.match(chunk): + return PartialValidationResult( + "fail", reason=f"Line does not start with a number: {chunk.strip()!r}" + ) + return PartialValidationResult("unknown") + + async def validate( + self, backend: Backend, ctx: Context, *, format=None, model_options=None + ) -> ValidationResult: + # All format checking happens during streaming. 
Lines that reach validate() + # are guaranteed to have passed stream_validate() already. + return ValidationResult(result=True) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + + result = await stream_with_chunking( + Instruction( + "List five world capitals, one per line, numbered 1 through 5. " + "Use the format: '1. City'. Output only the numbered list, nothing else." + ), + m.backend, + m.ctx, + requirements=[NumberedLineReq()], + chunking=LineChunker(), + ) + + async for event in result.events(): + match event: + case ChunkEvent(): + print(f" line[{event.chunk_index}]: {event.text.strip()!r}") + case QuickCheckEvent(passed=False): + print(f" FAIL: {event.results[0].reason}") + case CompletedEvent(): + print(f" completed — success={event.success}") + case _: + pass + + await result.acomplete() + + +asyncio.run(main()) +``` + +```text Sample output + line[0]: '1. London' + line[1]: '2. Paris' + line[2]: '3. Tokyo' + line[3]: '4. Ottawa' + line[4]: '5. Canberra' + completed — success=True +``` + +> **Note:** LLM output is non-deterministic. Your result will vary in wording. + +`validate()` on `NumberedLineReq` always returns `True` because all format +checking happens during streaming. If any line fails, the stream is cancelled +before reaching `validate()`. Lines that do reach it have already passed +`stream_validate()`. This pattern — enforce in `stream_validate`, pass in +`validate` — is common for requirements whose invariant is a property of +individual chunks rather than the full output. + +Pass a `ChunkingStrategy` **instance** (not a string alias) to use a custom +chunker. The built-in chunkers (`WordChunker`, `SentenceChunker`, +`ParagraphChunker`) are also available as instances if you need to pass one +explicitly or subclass to override `flush()`. 
+ +> **See also:** [`docs/examples/streaming/custom_chunking.py`](../examples/index) +> for an annotated version of this pattern with a more detailed `split()`/`flush()` +> contract walkthrough. + +--- + +## What you built + +| Concept | What it gives you | +| --- | --- | +| `stream_with_chunking()` + `requirements=` | Per-chunk validation with automatic early exit | +| `result.events()` | Typed event stream — observe every chunk, validation result, and lifecycle signal | +| `QuickCheckEvent(passed=False)` | Detect the moment a requirement fails, mid-stream | +| `result.streaming_failures` | List of `(requirement, PartialValidationResult)` pairs for failed checks | +| `"word"` / `"sentence"` / `"paragraph"` | Built-in chunking strategies trading reaction speed for context | +| `ChunkingStrategy` subclass | Custom split boundaries for structured output (lists, code, CSV) | +| `result.astream()` | Raw validated chunks without event metadata | + +--- + +> **See also:** +> [How-to: Streaming with per-chunk validation](../how-to/use-async-and-streaming#streaming-with-per-chunk-validation) | +> [Concepts: The Requirements System — Streaming validation](../concepts/requirements-system#streaming-validation) | +> [Examples: streaming/](../examples/index) diff --git a/docs/examples/streaming/custom_chunking.py b/docs/examples/streaming/custom_chunking.py new file mode 100644 index 000000000..8ea34ed7b --- /dev/null +++ b/docs/examples/streaming/custom_chunking.py @@ -0,0 +1,183 @@ +# pytest: ollama, e2e + +"""Streaming generation with a custom ChunkingStrategy subclass. 
+ +Demonstrates: +- Subclassing :class:`~mellea.stdlib.chunking.ChunkingStrategy` to define a + new splitting boundary +- Implementing ``split()`` (stateless, idempotent) and ``flush()`` (end-of-stream + release of any withheld trailing fragment) +- Using the custom chunker with ``stream_with_chunking()`` in place of a string alias +- Validating line-by-line output from a numbered-list prompt + +``LineChunker`` splits on single newlines (``\\n``), emitting one line per +``stream_validate`` call. It sits between :class:`~mellea.stdlib.chunking.WordChunker` +(one word) and :class:`~mellea.stdlib.chunking.SentenceChunker` (one sentence) in +granularity, and is a natural fit for list-formatted model output. + +Extension pattern: + 1. Subclass ``ChunkingStrategy``. + 2. Implement ``split(accumulated_text)`` — return all complete chunks found in + the accumulated text so far; withhold any trailing fragment. The method is + called on every new token delta, so it must be stateless and idempotent. + 3. Override ``flush(accumulated_text)`` to release the withheld trailing fragment + when the stream ends naturally. The default base implementation returns ``[]`` + (fragment discarded); override it when the trailing fragment is semantically + significant. +""" + +import asyncio +import re + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import ( + PartialValidationResult, + Requirement, + ValidationResult, +) +from mellea.stdlib.chunking import ChunkingStrategy +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import ( + ChunkEvent, + CompletedEvent, + FullValidationEvent, + QuickCheckEvent, + StreamingDoneEvent, + stream_with_chunking, +) + +# Matches a leading list marker such as "1." or "1)": optional indentation, one +# or more digits, then "." or ")" followed by whitespace. 
+_NUMBERED_LINE = re.compile(r"^\s*\d+[\.\)]\s") + + +class LineChunker(ChunkingStrategy): + """Splits accumulated text on single newlines, emitting one line per chunk. + + The line after the last ``\\n`` is withheld as a trailing fragment until + the stream ends and :meth:`flush` is called. Blank lines are skipped — + they carry no content for a line-level validator. + + This chunker is a good fit for numbered-list output, code listings, and + any structured response where the model uses line breaks as separators + rather than sentence-ending punctuation or double newlines. + """ + + def split(self, accumulated_text: str) -> list[str]: + """Return all complete lines (up to the last newline). + + Args: + accumulated_text: The full text accumulated so far. + + Returns: + Non-empty lines found before the last newline character. + The text after the last newline is withheld as a trailing fragment. + """ + if "\n" not in accumulated_text: + return [] + last_nl = accumulated_text.rfind("\n") + complete_section = accumulated_text[:last_nl] + return [line for line in complete_section.split("\n") if line.strip()] + + def flush(self, accumulated_text: str) -> list[str]: + """Release the trailing line fragment at end of stream. + + Args: + accumulated_text: The full accumulated text at stream end. + + Returns: + The text after the last newline as a single-element list (stripped), + or an empty list if the text ends with a newline or is empty. + """ + if not accumulated_text: + return [] + last_nl = accumulated_text.rfind("\n") + trailing = ( + accumulated_text if last_nl == -1 else accumulated_text[last_nl + 1 :] + ).strip() + return [trailing] if trailing else [] + + +class NumberedLineReq(Requirement): + """Fails the stream if any line does not start with a list number. + + Each ``stream_validate`` call receives one complete line (from + :class:`LineChunker`). This requirement enforces that every line follows + the ``N. 
item`` format, catching unstructured paragraphs or stray headers + that sneak into what should be a clean numbered list. + """ + + def format_for_llm(self) -> str: + return "Every line must begin with a number followed by a period (e.g. '1. ')." + + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + if not _NUMBERED_LINE.match(chunk): + return PartialValidationResult( + "fail", reason=f"Line does not start with a number: {chunk.strip()!r}" + ) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Backend, + ctx: Context, + *, + format: type | None = None, + model_options: dict | None = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + backend = m.backend + ctx = m.ctx + + action = Instruction( + "List five world capitals, one per line, numbered 1 through 5. " + "Use the format: '1. City'. Output only the numbered list, nothing else." 
+ ) + chunker = LineChunker() + req = NumberedLineReq() + + result = await stream_with_chunking( + action, backend, ctx, requirements=[req], chunking=chunker + ) + + print("Streaming events as they arrive (one ChunkEvent per line):") + async for event in result.events(): + match event: + case ChunkEvent(): + print(f" LINE[{event.chunk_index}]: {event.text!r}") + case QuickCheckEvent(passed=False): + print( + f" QUICK_CHECK[line {event.chunk_index}]: FAIL — " + f"{event.results[0].reason if event.results else 'unknown'}" + ) + case QuickCheckEvent(): + print(f" QUICK_CHECK[line {event.chunk_index}]: pass") + case StreamingDoneEvent(): + print(f" STREAMING_DONE: {len(event.full_text)} chars accumulated") + case FullValidationEvent(): + print(f" FULL_VALIDATION: {'PASS' if event.passed else 'FAIL'}") + case CompletedEvent(): + print(f" COMPLETED: success={event.success}") + case _: + pass + + await result.acomplete() + + print(f"\nCompleted normally: {result.completed}") + if result.streaming_failures: + for _req, pvr in result.streaming_failures: + print(f"Streaming failure: {pvr.reason}") + else: + print(f"Full text:\n{result.full_text}") + + +asyncio.run(main()) diff --git a/docs/examples/streaming/paragraph_chunking.py b/docs/examples/streaming/paragraph_chunking.py new file mode 100644 index 000000000..0c223b05c --- /dev/null +++ b/docs/examples/streaming/paragraph_chunking.py @@ -0,0 +1,142 @@ +# pytest: ollama, e2e + +"""Streaming generation with per-paragraph validation using ParagraphChunker. + +Demonstrates: +- Using the ``"paragraph"`` chunking alias for coarse-grained, structure-aware + validation +- A paragraph-length gate that cancels generation if any paragraph is too long +- How ParagraphChunker withholds text until a blank line (``\\n\\n``) is seen, + then emits the entire paragraph as a single chunk +- The latency trade-off vs. 
SentenceChunker: fewer, larger chunks mean lower + validation overhead but later detection + +ParagraphChunker splits on two or more consecutive newlines. Unlike +SentenceChunker, it waits for the model to produce a blank line before +emitting anything — so if the model writes everything as one long paragraph +the stream completes before any chunk is emitted. Use ParagraphChunker when +the validation logic requires full paragraph context: topic coherence, +heading structure, citation presence, or overall paragraph quality. +""" + +import asyncio + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import ( + PartialValidationResult, + Requirement, + ValidationResult, +) +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import ( + ChunkEvent, + CompletedEvent, + FullValidationEvent, + QuickCheckEvent, + StreamingDoneEvent, + stream_with_chunking, +) + +_MAX_PARAGRAPH_WORDS = 60 + + +class ParagraphLengthReq(Requirement): + """Fails the stream if any paragraph exceeds a word-count limit. + + Each ``stream_validate`` call receives one complete paragraph (from + :class:`~mellea.stdlib.chunking.ParagraphChunker`). The validator counts + words and immediately fails the stream if the paragraph is too long. This + lets you enforce a maximum paragraph length at generation time rather than + post-processing. + """ + + def __init__(self, max_words: int) -> None: + super().__init__() + self._max_words = max_words + self._para_index = 0 + + def format_for_llm(self) -> str: + return f"Each paragraph must contain at most {self._max_words} words." 
+ + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + self._para_index += 1 + word_count = len(chunk.split()) + if word_count > self._max_words: + return PartialValidationResult( + "fail", + reason=( + f"Paragraph {self._para_index} has {word_count} words " + f"(limit: {self._max_words})" + ), + ) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Backend, + ctx: Context, + *, + format: type | None = None, + model_options: dict | None = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + backend = m.backend + ctx = m.ctx + + action = Instruction( + "Write a two-paragraph explanation of how the internet works. " + "Separate the two paragraphs with a blank line. " + f"Keep each paragraph to at most {_MAX_PARAGRAPH_WORDS} words." + ) + req = ParagraphLengthReq(max_words=_MAX_PARAGRAPH_WORDS) + + result = await stream_with_chunking( + action, backend, ctx, requirements=[req], chunking="paragraph" + ) + + print("Streaming events as they arrive (one ChunkEvent per paragraph):") + async for event in result.events(): + match event: + case ChunkEvent(): + word_count = len(event.text.split()) + preview = event.text[:80].replace("\n", "↵") + print( + f" PARAGRAPH[{event.chunk_index}]: {word_count} words — " + f"{preview!r}..." 
+ ) + case QuickCheckEvent(passed=False): + print( + f" QUICK_CHECK[para {event.chunk_index}]: FAIL — " + f"{event.results[0].reason if event.results else 'unknown'}" + ) + case QuickCheckEvent(): + print(f" QUICK_CHECK[para {event.chunk_index}]: pass") + case StreamingDoneEvent(): + print(f" STREAMING_DONE: {len(event.full_text)} chars accumulated") + case FullValidationEvent(): + print(f" FULL_VALIDATION: {'PASS' if event.passed else 'FAIL'}") + case CompletedEvent(): + print(f" COMPLETED: success={event.success}") + case _: + pass + + await result.acomplete() + + print(f"\nCompleted normally: {result.completed}") + if result.streaming_failures: + for _req, pvr in result.streaming_failures: + print(f"Streaming failure: {pvr.reason}") + else: + print(f"Full text:\n{result.full_text}") + + +asyncio.run(main()) diff --git a/docs/examples/streaming/streaming_chunking.py b/docs/examples/streaming/streaming_chunking.py index 81056abaf..adc1afd55 100644 --- a/docs/examples/streaming/streaming_chunking.py +++ b/docs/examples/streaming/streaming_chunking.py @@ -5,7 +5,7 @@ Demonstrates: - Subclassing Requirement to override stream_validate() for early-exit checks - Calling stream_with_chunking() with sentence-level chunking -- Consuming validated chunks via astream() as they arrive +- Observing the full event vocabulary via events() as they arrive - Awaiting full completion with acomplete() to access final_validations and full_text """ @@ -20,7 +20,14 @@ ValidationResult, ) from mellea.stdlib.components import Instruction -from mellea.stdlib.streaming import stream_with_chunking +from mellea.stdlib.streaming import ( + ChunkEvent, + CompletedEvent, + FullValidationEvent, + QuickCheckEvent, + StreamingDoneEvent, + stream_with_chunking, +) # Crude sentence-terminator detector. A run of ``.``/``!``/``?`` counts once # (so "..." and "!!!" are a single terminator). 
Good enough for an example; @@ -89,9 +96,26 @@ async def main() -> None: action, backend, ctx, requirements=[req], chunking="sentence" ) - print("Streaming chunks as they arrive:") - async for chunk in result.astream(): - print(f" CHUNK: {chunk!r}") + print("Streaming events as they arrive:") + async for event in result.events(): + match event: + case ChunkEvent(): + print(f" CHUNK[{event.chunk_index}]: {event.text!r}") + case QuickCheckEvent(passed=False): + print( + f" QUICK_CHECK[{event.chunk_index}]: FAIL — " + f"{event.results[0].reason if event.results else 'unknown reason'}" + ) + case QuickCheckEvent(): + print(f" QUICK_CHECK[{event.chunk_index}]: pass") + case StreamingDoneEvent(): + print(f" STREAMING_DONE: {len(event.full_text)} chars accumulated") + case FullValidationEvent(): + print(f" FULL_VALIDATION: {'PASS' if event.passed else 'FAIL'}") + case CompletedEvent(): + print(f" COMPLETED: success={event.success}") + case _: + pass # RetryEvent and any future event types await result.acomplete() diff --git a/docs/examples/streaming/word_chunking.py b/docs/examples/streaming/word_chunking.py new file mode 100644 index 000000000..9eab50344 --- /dev/null +++ b/docs/examples/streaming/word_chunking.py @@ -0,0 +1,135 @@ +# pytest: ollama, e2e + +"""Streaming generation with per-word validation using WordChunker. + +Demonstrates: +- Using the ``"word"`` chunking alias for the finest-grained validation +- Detecting a forbidden word the moment it appears in the stream +- Early-exit cancelling generation before the consumer sees the bad word +- How WordChunker compares to SentenceChunker in reaction time + +WordChunker splits on whitespace, so each ``stream_validate`` call receives +exactly one word. This is the highest-sensitivity strategy: validation fires +before the model has finished even the current clause, letting you catch +prohibited content with minimal output produced. + +The trade-off vs. 
SentenceChunker: validators that need sentence-level context +(grammar, coherence) cannot operate correctly at word granularity because each +chunk carries only a single token. Use WordChunker when the check is +token-local — forbidden words, length budgets, numeric thresholds. +""" + +import asyncio + +from mellea.core.backend import Backend +from mellea.core.base import Context +from mellea.core.requirement import ( + PartialValidationResult, + Requirement, + ValidationResult, +) +from mellea.stdlib.components import Instruction +from mellea.stdlib.streaming import ( + ChunkEvent, + CompletedEvent, + FullValidationEvent, + QuickCheckEvent, + StreamingDoneEvent, + stream_with_chunking, +) + +# Words that must not appear in the model's response. +_FORBIDDEN = {"competitor", "CompetitorX", "legacy", "inferior", "obsolete"} + + +class ForbiddenWordReq(Requirement): + """Fails the stream immediately if a forbidden word appears. + + Each ``stream_validate`` call receives a single word (from + :class:`~mellea.stdlib.chunking.WordChunker`). The check is O(1) + per word — set membership test — so it adds negligible latency. + """ + + def __init__(self, forbidden: set[str]) -> None: + super().__init__() + self._forbidden_display = sorted(forbidden) + self._forbidden = {w.lower() for w in forbidden} + + def format_for_llm(self) -> str: + return f"Do not use any of the following words: {', '.join(self._forbidden_display)}." 
+ + async def stream_validate( + self, chunk: str, *, backend: Backend, ctx: Context + ) -> PartialValidationResult: + word = chunk.strip().lower().strip(".,!?;:\"'") + if word in self._forbidden: + return PartialValidationResult( + "fail", reason=f"Forbidden word detected: {chunk.strip()!r}" + ) + return PartialValidationResult("unknown") + + async def validate( + self, + backend: Backend, + ctx: Context, + *, + format: type | None = None, + model_options: dict | None = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + +async def main() -> None: + from mellea.stdlib.session import start_session + + m = start_session() + backend = m.backend + ctx = m.ctx + + action = Instruction( + "Describe three key advantages of cloud-native software development " + "in two or three sentences." + ) + req = ForbiddenWordReq(forbidden=_FORBIDDEN) + + result = await stream_with_chunking( + action, backend, ctx, requirements=[req], chunking="word" + ) + + print("Streaming events as they arrive (one per word):") + word_count = 0 + async for event in result.events(): + match event: + case ChunkEvent(): + word_count += 1 + # Only print every 5th word to keep output readable + if word_count % 5 == 1: + print(f" ...word {word_count}: {event.text!r}") + case QuickCheckEvent(passed=False): + print( + f" QUICK_CHECK[word {event.chunk_index}]: FAIL — " + f"{event.results[0].reason if event.results else 'unknown'}" + ) + case StreamingDoneEvent(): + print( + f" STREAMING_DONE: {word_count} words, {len(event.full_text)} chars" + ) + case FullValidationEvent(): + print(f" FULL_VALIDATION: {'PASS' if event.passed else 'FAIL'}") + case CompletedEvent(): + print(f" COMPLETED: success={event.success}") + case _: + pass + + await result.acomplete() + + print(f"\nCompleted normally: {result.completed}") + if result.streaming_failures: + for _req, pvr in result.streaming_failures: + print(f"Streaming failure: {pvr.reason}") + print(f"Text at cancellation: {result.full_text!r}") 
+ else: + print(f"Full text: {result.full_text!r}") + + +asyncio.run(main()) diff --git a/mellea/stdlib/__init__.py b/mellea/stdlib/__init__.py index 7a30fdd53..c517fc759 100644 --- a/mellea/stdlib/__init__.py +++ b/mellea/stdlib/__init__.py @@ -13,17 +13,35 @@ ``mellea.stdlib.chunking`` and re-exported here for convenience. The core streaming orchestration primitive :func:`~mellea.stdlib.streaming.stream_with_chunking` and its result type :class:`~mellea.stdlib.streaming.StreamChunkingResult` are also -re-exported here. +re-exported here, alongside the full :class:`~mellea.stdlib.streaming.StreamEvent` +vocabulary for typed event observation. """ from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker -from .streaming import StreamChunkingResult, stream_with_chunking +from .streaming import ( + ChunkEvent, + CompletedEvent, + ErrorEvent, + FullValidationEvent, + QuickCheckEvent, + StreamChunkingResult, + StreamEvent, + StreamingDoneEvent, + stream_with_chunking, +) __all__ = [ + "ChunkEvent", "ChunkingStrategy", + "CompletedEvent", + "ErrorEvent", + "FullValidationEvent", "ParagraphChunker", + "QuickCheckEvent", "SentenceChunker", "StreamChunkingResult", + "StreamEvent", + "StreamingDoneEvent", "WordChunker", "stream_with_chunking", ] diff --git a/mellea/stdlib/chunking.py b/mellea/stdlib/chunking.py index 462af81a8..7f426b8af 100644 --- a/mellea/stdlib/chunking.py +++ b/mellea/stdlib/chunking.py @@ -14,6 +14,15 @@ class ChunkingStrategy(ABC): that has not yet reached a chunk boundary is withheld — it is not included in the returned list. Each call is stateless and idempotent given the same input. + **Performance:** ``split()`` is called on every streaming delta, re-scanning + the full accumulated text each time (O(n) in total accumulated length per + call). The orchestrator tracks ``prev_chunk_count`` to extract only the new + chunks. 
This keeps the chunker stateless and removes the need for ``reset()`` + or deep-copy support, at the cost of re-scanning text already seen. For + typical model outputs (a few KB) the cost is negligible; for very long + streams, a stateful chunker that only processes the new delta would be more + efficient. + End-of-stream contract: ``split()`` always withholds the trailing fragment. When the stream terminates, callers are responsible for processing any remainder: take the full accumulated text, identify everything after the last returned @@ -31,7 +40,12 @@ def split(self, accumulated_text: str) -> list[str]: Args: accumulated_text: The full text accumulated so far, including all - previously seen tokens and the latest delta. + previously seen tokens and the latest delta. Implementations + that scan this string are O(n) in accumulated length per call. + Stateful implementations that only process the new delta are + possible but must never mutate state on ``self`` in place — + use reassignment (``self._buf = self._buf + [x]``) so that + ``copy()``-based cloning in the orchestrator works correctly. Returns: A list of complete chunks. If no chunk boundary has been reached yet, diff --git a/mellea/stdlib/streaming.py b/mellea/stdlib/streaming.py index dfdbcb232..ee3d9b76f 100644 --- a/mellea/stdlib/streaming.py +++ b/mellea/stdlib/streaming.py @@ -5,11 +5,17 @@ :class:`~mellea.stdlib.chunking.ChunkingStrategy` to produce semantic chunks, and runs :meth:`~mellea.core.requirement.Requirement.stream_validate` on each chunk in parallel. Higher-level streaming APIs build on this function. + +The orchestrator emits typed :class:`StreamEvent` objects that consumers can +observe via :meth:`StreamChunkingResult.events`. Raw validated chunks remain +available via :meth:`StreamChunkingResult.astream`. 
""" import asyncio +import time from collections.abc import AsyncIterator, Sequence from copy import copy +from dataclasses import dataclass, field from typing import Any from ..backends.model_options import ModelOption @@ -17,6 +23,14 @@ from ..core.base import CBlock, Component, Context, ModelOutputThunk from ..core.requirement import PartialValidationResult, Requirement, ValidationResult from ..core.utils import MelleaLogger +from ..telemetry.metrics import ( + classify_error, + record_error, + record_requirement_check, + record_requirement_failure, + record_sampling_outcome, +) +from ..telemetry.tracing import set_span_error, set_span_status_error, trace_application from .chunking import ChunkingStrategy, ParagraphChunker, SentenceChunker, WordChunker _CHUNKING_ALIASES: dict[str, type[ChunkingStrategy]] = { @@ -25,14 +39,175 @@ "paragraph": ParagraphChunker, } +# --------------------------------------------------------------------------- +# Streaming event types +# --------------------------------------------------------------------------- + + +@dataclass +class StreamEvent: + """Base class for all streaming events emitted by :func:`stream_with_chunking`. + + The ``timestamp`` field is auto-populated at instantiation time; callers + do not set it. Because ``timestamp`` has ``init=False`` it is never part + of ``__init__``, so subclasses may declare additional fields in any order + without conflict. Any new ``init=False`` fields on subclasses must also + use ``field(..., init=False)``. + + Attributes: + timestamp: Unix timestamp (seconds) at the moment the event was created. + """ + + timestamp: float = field(default_factory=time.time, init=False) + + +@dataclass +class ChunkEvent(StreamEvent): + """Emitted after each validated chunk is delivered to the consumer. + + Fired after all active requirements' ``stream_validate`` calls return + non-``"fail"`` for this chunk and the chunk has been placed on the + consumer queue. 
+ + Args: + text: The chunk text that was validated and emitted. + chunk_index: Zero-based position of this chunk in the stream. + attempt: Sampling attempt number (always ``1`` in v1). + """ + + text: str + chunk_index: int + attempt: int + + +@dataclass +class QuickCheckEvent(StreamEvent): + """Emitted after each per-chunk streaming validation batch. + + One event per chunk, covering all active requirements in parallel. + Not emitted when there are no ``requirements``. + + Args: + chunk_index: Zero-based position of the chunk that was validated. + attempt: Sampling attempt number (always ``1`` in v1). + passed: ``True`` if all active requirements returned non-``"fail"`` + for this chunk. + results: :class:`~mellea.core.requirement.PartialValidationResult` + from each active requirement, in the same order as the active + slice of ``requirements``. + """ + + chunk_index: int + attempt: int + passed: bool + results: list[PartialValidationResult] + + +@dataclass +class StreamingDoneEvent(StreamEvent): + """Emitted after all chunks have been validated and delivered to the consumer. + + Fired after the regular token stream and any trailing fragment released by + :meth:`~mellea.stdlib.chunking.ChunkingStrategy.flush` have both been + processed. Only emitted on natural completion — not on early exit (a + requirement returned ``"fail"``) or on exception. + + Args: + attempt: Sampling attempt number (always ``1`` in v1). + full_text: Complete accumulated text at stream end. + """ + + attempt: int + full_text: str + + +@dataclass +class FullValidationEvent(StreamEvent): + """Emitted after the final :meth:`~mellea.core.requirement.Requirement.validate` calls complete. + + Only emitted when at least one requirement did not fail during streaming + and the stream completed naturally. Not emitted on early exit. + + Args: + attempt: Sampling attempt number (always ``1`` in v1). + passed: ``True`` if all final + :class:`~mellea.core.requirement.ValidationResult` objects passed. 
+ results: :class:`~mellea.core.requirement.ValidationResult` from each + non-failed requirement, in requirement order. + """ + + attempt: int + passed: bool + results: list[ValidationResult] + + +@dataclass +class RetryEvent(StreamEvent): + """Reserved for future use. + + Defined for API completeness — ``RetryEvent`` is not emitted by the + v1 orchestrator because v1 retry is caller-driven re-invocation of + :func:`stream_with_chunking`. When orchestrator-side retry is added, + this event will fire before each re-attempt. + + Args: + attempt: Attempt number being started (1-based). + reason: Human-readable reason for the retry. + """ + + attempt: int + reason: str + + +@dataclass +class CompletedEvent(StreamEvent): + """Emitted when the orchestrator exits, including early-exit cases. + + Always the last event before :meth:`StreamChunkingResult.events` + terminates. ``success`` reflects :attr:`StreamChunkingResult.completed`. + + Args: + success: ``True`` if the stream completed normally (no ``"fail"`` + result and no unhandled exception); ``False`` otherwise. + full_text: Complete accumulated text. On early exit or exception, + reflects whatever was accumulated before cancellation. + attempts_used: Number of orchestrator invocations (always ``1`` in v1). + """ + + success: bool + full_text: str + attempts_used: int + + +@dataclass +class ErrorEvent(StreamEvent): + """Emitted when an unhandled exception occurs in the orchestrator. + + Args: + exception_type: Python class name of the exception + (e.g. ``"ValueError"``). + detail: String representation of the exception. If + ``cancel_generation()`` also raised during cleanup, the cleanup + error is appended. + """ + + exception_type: str + detail: str + + +# --------------------------------------------------------------------------- +# Result container +# --------------------------------------------------------------------------- + class StreamChunkingResult: """Result of a :func:`stream_with_chunking` operation. 
Provides async iteration over validated text chunks as they complete - (:meth:`astream`), a blocking :meth:`acomplete` for awaiting the full - result including final validation, and :attr:`as_thunk` for wrapping the - output as a :class:`~mellea.core.base.ModelOutputThunk`. + (:meth:`astream`), typed :class:`StreamEvent` objects via :meth:`events`, + a blocking :meth:`acomplete` for awaiting the full result including final + validation, and :attr:`as_thunk` for wrapping the output as a + :class:`~mellea.core.base.ModelOutputThunk`. Instances are created by :func:`stream_with_chunking`; do not instantiate directly. @@ -45,7 +220,10 @@ class StreamChunkingResult: Attributes: completed: ``False`` if the stream exited early because a requirement returned ``"fail"`` during streaming; ``True`` otherwise. - full_text: The complete generated text accumulated during streaming. + full_text: The generated text available after streaming completes. + On natural completion, the full accumulated text. On early exit + (a requirement returned ``"fail"``), only the validated and emitted + portion — i.e. what consumers received via :meth:`astream`. Available after :meth:`acomplete` returns. final_validations: :class:`~mellea.core.requirement.ValidationResult` objects from the final :meth:`~mellea.core.requirement.Requirement.validate` @@ -60,6 +238,10 @@ def __init__(self, mot: ModelOutputThunk, ctx: Context) -> None: self._mot = mot self._ctx = ctx self._chunk_queue: asyncio.Queue[str | None | Exception] = asyncio.Queue() + # If no consumer calls events(), events accumulate in this queue until + # the result object is garbage-collected. That is intentional — event + # production is unconditional; consumption is opt-in. 
+ self._event_queue: asyncio.Queue[StreamEvent | None] = asyncio.Queue() self._orchestration_task: asyncio.Task[None] | None = None self._done = asyncio.Event() # Stashed so acomplete() surfaces orchestrator failures even when the @@ -73,6 +255,7 @@ def __init__(self, mot: ModelOutputThunk, ctx: Context) -> None: # None, and silently skips it — leaving the caller with zero chunks # and no error. self._exception_surfaced: bool = False + self._events_consumed: bool = False self.completed: bool = True self.full_text: str = "" @@ -115,6 +298,66 @@ async def astream(self) -> AsyncIterator[str]: raise item yield item + async def events(self) -> AsyncIterator[StreamEvent]: + """Yield typed streaming events as they are emitted by the orchestrator. + + Each yielded object is a :class:`StreamEvent` subclass describing a + point in the orchestration lifecycle. Consumers can dispatch on type: + + .. code-block:: python + + async for event in result.events(): + match event: + case ChunkEvent(): + print(f"chunk {event.chunk_index}: {event.text!r}") + case QuickCheckEvent(passed=False): + print(f"chunk {event.chunk_index} failed validation") + case CompletedEvent(): + print(f"done — success={event.success}") + + Typical event order (natural completion with requirements): + + 1. :class:`QuickCheckEvent` / :class:`ChunkEvent` pairs, one per chunk + (validation fires first; the chunk is released to the consumer only + after passing). Includes any trailing fragment released by the + chunking strategy's ``flush()`` method. + 2. :class:`StreamingDoneEvent` — all chunks (including flush) delivered. + 3. :class:`FullValidationEvent` — final ``validate()`` calls returned. + 4. :class:`CompletedEvent` — orchestrator is exiting. + + On early exit: :class:`QuickCheckEvent` (``passed=False``) is the + last validation event, followed by :class:`CompletedEvent`. No + :class:`StreamingDoneEvent` or :class:`FullValidationEvent` is emitted. 
+ + On exception: :class:`ErrorEvent` followed by :class:`CompletedEvent`. + + **Single-consumer.** Events are delivered via a queue that this method + drains; calling ``events()`` a second time raises :exc:`RuntimeError`. + + Yields: + StreamEvent: A typed event from the orchestrator. + + Raises: + RuntimeError: If called more than once on the same result. + + Note: + ``events()`` itself never raises from the event stream. If the + orchestrator encounters an unhandled exception, an + :class:`ErrorEvent` is emitted and iteration ends normally. + Exceptions surface to the caller via :meth:`astream` (as a + re-raised exception) or :meth:`acomplete`. + """ + if self._events_consumed: + raise RuntimeError( + "events() is single-consumer; this iterator has already been drained" + ) + self._events_consumed = True + while True: + item = await self._event_queue.get() + if item is None: + return + yield item + async def acomplete(self) -> None: """Await full completion, including final validation. @@ -124,7 +367,12 @@ async def acomplete(self) -> None: exhaustion, this call is effectively a no-op. Raises: - Exception: Propagates any error from the orchestration task. + Exception: Propagates the orchestrator exception if :meth:`astream` + has not yet consumed it (raise-once — only one of ``astream`` + or ``acomplete`` raises, whichever drains the failure marker + first). + asyncio.CancelledError: If the orchestration task was externally + cancelled (e.g. via :func:`asyncio.wait_for` timeout). """ await self._done.wait() # Raise-once: if astream() already surfaced the exception, skip. @@ -155,8 +403,8 @@ def as_thunk(self) -> ModelOutputThunk[str]: Returns a new thunk with ``value`` set to :attr:`full_text` and generation metadata copied from the original MOT. Safe to call on - early-exit results; ``value`` will reflect whatever was accumulated - before cancellation. 
+ early-exit results; ``value`` reflects the validated and emitted + portion (same as :attr:`full_text` — see its docstring). Note: On early exit, ``cancel_generation()`` forces the MOT into a @@ -187,6 +435,11 @@ def as_thunk(self) -> ModelOutputThunk[str]: return thunk +# --------------------------------------------------------------------------- +# Orchestrator +# --------------------------------------------------------------------------- + + async def _orchestrate_streaming( result: StreamChunkingResult, mot: ModelOutputThunk, @@ -196,154 +449,257 @@ async def _orchestrate_streaming( val_backend: Backend, ) -> None: accumulated = "" - emitted_end = 0 # byte offset into accumulated after the last emitted chunk + emitted_end = 0 # character offset in accumulated after the last emitted chunk prev_chunk_count = 0 failed_indices: set[int] = set() early_exit = False + chunk_index = 0 + + with trace_application("stream_with_chunking") as span: + + async def _process_chunk(c: str, ci: int) -> bool: + """Validate *c*, emit events, push to consumer queue. + + Returns ``True`` if a ``"fail"`` was recorded (caller should + trigger early exit), ``False`` if the chunk was validated and + emitted successfully. 
+ """ + active = [ + (i, req) for i, req in enumerate(cloned_reqs) if i not in failed_indices + ] + pvrs: list[PartialValidationResult] = [] + if active: + pvrs = list( + await asyncio.gather( + *[ + req.stream_validate(c, backend=val_backend, ctx=ctx) + for _, req in active + ] + ) + ) + for (idx, req), pvr in zip(active, pvrs): + if pvr.success == "fail": + failed_indices.add(idx) + result.streaming_failures.append((req, pvr)) + + any_fail = any(pvr.success == "fail" for pvr in pvrs) + qc_event = QuickCheckEvent( + chunk_index=ci, attempt=1, passed=not any_fail, results=pvrs + ) + await result._event_queue.put(qc_event) + if span is not None: + span.add_event( + "quick_check", + { + "chunk_index": ci, + "passed": not any_fail, + "requirement_count": len(active), + }, + ) + for (_, req), pvr in zip(active, pvrs): + record_requirement_check(type(req).__name__) + if pvr.success == "fail": + record_requirement_failure(type(req).__name__, "") + + if failed_indices: + return True + + await result._chunk_queue.put(c) + chunk_ev = ChunkEvent(text=c, chunk_index=ci, attempt=1) + await result._event_queue.put(chunk_ev) + if span is not None: + span.add_event("chunk", {"chunk_index": ci, "text_length": len(c)}) + return False - async def _validate_and_emit(c: str) -> bool: - """Run stream_validate on chunk c across active requirements. - - Returns True if a failure was recorded (caller should early-exit), - False otherwise (chunk was emitted to the consumer queue). 
- """ - active = [ - (i, req) for i, req in enumerate(cloned_reqs) if i not in failed_indices - ] - if active: - async with asyncio.TaskGroup() as tg: - _tasks = [ - tg.create_task(req.stream_validate(c, backend=val_backend, ctx=ctx)) - for _, req in active - ] - pvrs: list[PartialValidationResult] = [t.result() for t in _tasks] - for (idx, req), pvr in zip(active, pvrs): - if pvr.success == "fail": - failed_indices.add(idx) - result.streaming_failures.append((req, pvr)) - - if failed_indices: - return True - - await result._chunk_queue.put(c) - return False - - try: - while not mot.is_computed(): - try: - delta = await mot.astream() - except RuntimeError: - # Expected race: mot.is_computed() was False at the top of the - # loop but the stream finished before we re-entered astream(). - # Any other RuntimeError is a real bug and must propagate. - if mot.is_computed(): - break - raise - - accumulated += delta - chunks = chunking.split(accumulated) - new_chunks = chunks[prev_chunk_count:] - prev_chunk_count = len(chunks) - - for c in new_chunks: - failed = await _validate_and_emit(c) - if failed: - early_exit = True - result.completed = False - await mot.cancel_generation() - break - pos = accumulated.find(c, emitted_end) - if pos >= 0: - emitted_end = pos + len(c) - - if early_exit: - break # break the while loop; cancel_generation() already set _computed=True - - # Stream ended naturally: flush any withheld trailing fragment and - # run stream_validate on it. Skipped on early exit — the generation - # was cancelled, the trailing fragment is incomplete. - if not early_exit: - for c in chunking.flush(accumulated): - failed = await _validate_and_emit(c) - if failed: - early_exit = True - result.completed = False + try: + while not mot.is_computed(): + try: + delta = await mot.astream() + except RuntimeError: + # Expected race: mot.is_computed() was False at the top of the + # loop but the stream finished before we re-entered astream(). 
+ # Any other RuntimeError is a real bug and must propagate. + if mot.is_computed(): + break + raise + + accumulated += delta + chunks = chunking.split(accumulated) + new_chunks = chunks[prev_chunk_count:] + prev_chunk_count = len(chunks) + + for c in new_chunks: + failed = await _process_chunk(c, chunk_index) + if failed: + early_exit = True + result.completed = False + await mot.cancel_generation() + if span is not None: + reason = result.streaming_failures[-1][1].reason or "" + set_span_status_error( + span, f"Streaming validation failed: {reason}" + ) + break + pos = accumulated.find(c, emitted_end) + if pos >= 0: + emitted_end = pos + len(c) + chunk_index += 1 + + if early_exit: break - pos = accumulated.find(c, emitted_end) - if pos >= 0: - emitted_end = pos + len(c) - - # On early exit, full_text is the prefix of accumulated up to and - # including the last emitted chunk — preserving original inter-chunk - # spacing from the token stream (chunk concatenation would strip it). - # On natural completion, accumulated is used directly. - result.full_text = accumulated[:emitted_end] if early_exit else accumulated - - non_failed = [ - req for i, req in enumerate(cloned_reqs) if i not in failed_indices - ] - if non_failed and not early_exit: - async with asyncio.TaskGroup() as tg: - _final_tasks = [ - tg.create_task(req.validate(val_backend, ctx)) for req in non_failed + + # Stream ended naturally: flush any withheld trailing fragment, then + # emit StreamingDoneEvent once all chunks (regular + flush) have been + # validated and delivered. If a flush chunk fails, early_exit is + # set and StreamingDoneEvent is suppressed (same contract as the + # regular early-exit path). Skipped entirely on early exit. 
+ if not early_exit: + for c in chunking.flush(accumulated): + failed = await _process_chunk(c, chunk_index) + if failed: + early_exit = True + result.completed = False + if span is not None: + reason = result.streaming_failures[-1][1].reason or "" + set_span_status_error( + span, f"Streaming validation failed on flush: {reason}" + ) + break + pos = accumulated.find(c, emitted_end) + if pos >= 0: + emitted_end = pos + len(c) + chunk_index += 1 + + if not early_exit: + streaming_done = StreamingDoneEvent( + attempt=1, full_text=accumulated + ) + await result._event_queue.put(streaming_done) + if span is not None: + span.add_event( + "streaming_done", {"full_text_length": len(accumulated)} + ) + + # On early exit, full_text is the portion of accumulated that was + # actually validated and emitted to the consumer. On natural + # completion, the full accumulated text is used. + result.full_text = accumulated[:emitted_end] if early_exit else accumulated + + if not early_exit: + non_failed = [ + req for i, req in enumerate(cloned_reqs) if i not in failed_indices ] - result.final_validations = [t.result() for t in _final_tasks] - - except Exception as exc: - # Orchestrator is leaving — we must stop the backend producer too, - # otherwise mot._async_queue (maxsize=20) fills and the feeder task - # blocks indefinitely. The spec (#891, #901) calls this out for the - # "fail" path; the same reasoning applies to any unplanned exit. - # Pass `exc` so the backend telemetry span records the real cause - # rather than a generic "Generation cancelled". - # TaskGroup wraps failures in ExceptionGroup; unwrap so telemetry and - # the chunk queue see the original exception, not the wrapper. - # ExceptionGroup (not BaseExceptionGroup) guarantees Exception elements. 
- if isinstance(exc, ExceptionGroup) and exc.exceptions: - reported_exc: Exception = exc.exceptions[0] - if len(exc.exceptions) > 1: - MelleaLogger.get_logger().warning( - "stream_with_chunking: %d validator(s) failed simultaneously; " - "reporting first, suppressing rest: %r", - len(exc.exceptions) - 1, - exc.exceptions[1:], + if non_failed: + vrs: list[ValidationResult] = list( + await asyncio.gather( + *[req.validate(val_backend, ctx) for req in non_failed] + ) + ) + result.final_validations = vrs + all_passed = all(vr.as_bool() for vr in vrs) + full_val_ev = FullValidationEvent( + attempt=1, passed=all_passed, results=list(vrs) + ) + await result._event_queue.put(full_val_ev) + if span is not None: + span.add_event( + "full_validation", + { + "passed": all_passed, + "requirement_count": len(non_failed), + }, + ) + + except Exception as exc: + # Stash the exception before any await so acomplete() can always + # surface it even if a subsequent await is interrupted by an + # external CancelledError. + result._orchestration_exception = exc + # Mark as failed immediately — before any event is enqueued — so + # that CompletedEvent.success and result.completed are consistent + # if the consumer observes them during ErrorEvent processing. + result.completed = False + result.full_text = accumulated # best-effort partial capture + # Only cancel generation if the stream hasn't already completed + # (e.g. an exception from the final validate() call arrives after + # the token stream ended naturally — cancelling an already-computed + # MOT is a no-op at best and misleading in telemetry). + if not mot.is_computed(): + try: + await mot.cancel_generation(error=exc) + error_detail = str(exc) + except Exception as cleanup_exc: + # Never let cleanup mask the original exception. 
+ error_detail = f"{exc!r} (cancel cleanup raised: {cleanup_exc!r})" + MelleaLogger.get_logger().debug( + "stream_with_chunking: cancel_generation() raised during " + "exception cleanup (original: %r, cleanup: %r)", + exc, + cleanup_exc, + ) + else: + error_detail = str(exc) + error_ev = ErrorEvent( + exception_type=type(exc).__name__, detail=error_detail + ) + await result._event_queue.put(error_ev) + if span is not None: + span.add_event( + "error", + { + "exception_type": error_ev.exception_type, + "detail": error_ev.detail, + }, ) - else: - reported_exc = exc - try: - await mot.cancel_generation(error=reported_exc) - except Exception as cleanup_exc: - # Never let cleanup mask the original exception: log loudly and - # continue to surface `exc` to the consumer. - # TODO(#902): replace this log with an ErrorEvent emission. - MelleaLogger.get_logger().warning( - "stream_with_chunking: cancel_generation() raised during " - "exception cleanup (original: %r, cleanup: %r)", - reported_exc, - cleanup_exc, + set_span_error(span, exc) + record_error( + error_type=classify_error(exc), + model=result._mot.generation.model or "unknown", + provider=result._mot.generation.provider or "unknown", + exception_class=type(exc).__name__, ) - result.completed = False - result._orchestration_exception = reported_exc - await result._chunk_queue.put(reported_exc) - finally: - # CancelledError (BaseException, not Exception) bypasses the except - # block above, so cancel_generation() may not have been called. - # Catch only Exception here so CancelledError / KeyboardInterrupt / - # SystemExit still propagate to the caller. - if not mot.is_computed(): - try: - await mot.cancel_generation() - except Exception: - pass - # put_nowait + set() are synchronous — no await point, so they cannot - # be interrupted by task cancellation. Consumers waiting on - # _done.wait() are always released, even if the task was cancelled - # mid-cleanup. The queue is unbounded, so QueueFull cannot occur. 
- try: + await result._chunk_queue.put(exc) + finally: + # CancelledError (BaseException, not Exception) bypasses the except + # block above, so cancel_generation() may not have been called. + # Guard here ensures the backend producer is always stopped, even on + # external task cancellation (e.g. asyncio.wait_for timeout). + # Also mark completion as failed for any BaseException path (e.g. + # CancelledError) that bypassed the except block — otherwise + # result.completed stays True and CompletedEvent / metrics lie. + if not mot.is_computed(): + result.completed = False + try: + await mot.cancel_generation() + except BaseException: + pass + + completed_ev = CompletedEvent( + success=result.completed, full_text=result.full_text, attempts_used=1 + ) + # Use put_nowait for the terminal bookkeeping: both queues are + # unbounded so this can never raise QueueFull, and it eliminates + # the await points that could be interrupted by a pending + # CancelledError before _done.set() runs. + result._event_queue.put_nowait(completed_ev) + if span is not None: + span.add_event( + "completed", + { + "success": result.completed, + "full_text_length": len(result.full_text), + }, + ) + record_sampling_outcome("stream_with_chunking", success=result.completed) + result._chunk_queue.put_nowait(None) - except asyncio.QueueFull: - pass - result._done.set() + result._event_queue.put_nowait(None) + result._done.set() + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- async def stream_with_chunking( @@ -392,6 +748,10 @@ async def stream_with_chunking( begins, so the originals are never mutated and a raising ``__copy__`` cannot leak an in-flight backend task. + The orchestrator emits typed :class:`StreamEvent` objects throughout + execution. Consume them via :meth:`StreamChunkingResult.events` in + parallel with or instead of :meth:`StreamChunkingResult.astream`. 
+ Requirements that need context beyond the current chunk should accumulate it themselves across ``stream_validate`` calls (e.g. ``self._seen = self._seen + chunk``). They must not read ``mot.astream()`` @@ -413,7 +773,8 @@ async def stream_with_chunking( Note: v1 retry is simple re-invocation of this function. Plugin hooks (``SAMPLING_LOOP_START``, ``SAMPLING_REPAIR``, etc.) do not fire - on retries — use the ``#902`` event types for observability instead. + during streaming — use :meth:`StreamChunkingResult.events` for + observability instead. Args: action: The component or content block to generate from. @@ -431,7 +792,8 @@ async def stream_with_chunking( Returns: StreamChunkingResult: A result object providing :meth:`~StreamChunkingResult.astream` - for incremental chunk consumption and + for incremental chunk consumption, :meth:`~StreamChunkingResult.events` for + typed streaming events, and :meth:`~StreamChunkingResult.acomplete` for blocking until done. Raises: diff --git a/mellea/telemetry/tracing.py b/mellea/telemetry/tracing.py index 48a3bf5e6..fc459f3ef 100644 --- a/mellea/telemetry/tracing.py +++ b/mellea/telemetry/tracing.py @@ -280,6 +280,22 @@ def set_span_error(span: Any, exception: Exception) -> None: span.set_status(trace.Status(trace.StatusCode.ERROR, str(exception))) # type: ignore +def set_span_status_error(span: Any, description: str) -> None: + """Mark a span as ERROR without recording a phantom exception event. + + Use this for validation failures and other non-exception error conditions + where the span should be marked failed but no exception was actually raised. + Calling ``set_span_error`` in these cases would create a misleading recorded + exception event in OTEL traces. + + Args: + span: The span object (may be None if tracing is disabled) + description: Human-readable reason for the failure. 
+ """ + if span is not None and _OTEL_AVAILABLE: + span.set_status(trace.Status(trace.StatusCode.ERROR, description)) # type: ignore + + __all__ = [ "add_span_event", "end_backend_span", @@ -288,6 +304,7 @@ def set_span_error(span: Any, exception: Exception) -> None: "is_content_tracing_enabled", "set_span_attribute", "set_span_error", + "set_span_status_error", "start_backend_span", "trace_application", "trace_backend", diff --git a/test/stdlib/test_streaming.py b/test/stdlib/test_streaming.py index ad042df56..f5e3af3c1 100644 --- a/test/stdlib/test_streaming.py +++ b/test/stdlib/test_streaming.py @@ -7,7 +7,9 @@ """ import asyncio +import time from typing import Any +from unittest.mock import patch import pytest @@ -19,7 +21,17 @@ ValidationResult, ) from mellea.stdlib.context import SimpleContext -from mellea.stdlib.streaming import stream_with_chunking +from mellea.stdlib.streaming import ( + ChunkEvent, + CompletedEvent, + ErrorEvent, + FullValidationEvent, + QuickCheckEvent, + RetryEvent, + StreamEvent, + StreamingDoneEvent, + stream_with_chunking, +) # --------------------------------------------------------------------------- # StreamingMockBackend @@ -578,6 +590,27 @@ async def test_no_requirements_streams_without_validation() -> None: assert result.streaming_failures == [] +@pytest.mark.asyncio +async def test_no_requirements_events_omits_full_validation_event() -> None: + """With no requirements, events() emits StreamingDoneEvent but + NOT FullValidationEvent — there is nothing to validate at stream end.""" + response = "Chunk one. Chunk two. 
" + backend = StreamingMockBackend(response, token_size=3) + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=None, chunking="sentence" + ) + await result.acomplete() + + evts = [e async for e in result.events()] + types = [type(e) for e in evts] + + assert StreamingDoneEvent in types + assert FullValidationEvent not in types + assert isinstance(evts[-1], CompletedEvent) + assert evts[-1].success is True + + @pytest.mark.asyncio async def test_multiple_chunks_in_one_batch_with_mid_batch_fail() -> None: """When one astream() delta produces several complete chunks and one in @@ -765,6 +798,14 @@ async def validate( assert ok_result.as_thunk.is_computed() is True +@pytest.mark.asyncio +async def test_unknown_chunking_alias_raises_value_error() -> None: + """An unrecognised chunking alias raises ValueError before any backend call.""" + backend = StreamingMockBackend("hello world") + with pytest.raises(ValueError, match="unknown_alias"): + await stream_with_chunking(_action(), backend, _ctx(), chunking="unknown_alias") + + @pytest.mark.asyncio async def test_exception_in_stream_validate_cancels_generation() -> None: """Verifies the orchestrator's exception-path cleanup: if stream_validate @@ -1200,6 +1241,133 @@ async def test_stream_with_chunking_requirement_copy_contract( # --------------------------------------------------------------------------- # Fix 3 — TaskGroup cancels peer validators on first failure # --------------------------------------------------------------------------- +# Event type construction +# --------------------------------------------------------------------------- + + +def test_stream_event_types_have_auto_timestamp() -> None: + """All seven event types set timestamp automatically; callers do not pass it.""" + before = time.time() + all_events = [ + ChunkEvent(text="hello", chunk_index=0, attempt=1), + QuickCheckEvent( + chunk_index=0, + attempt=1, + passed=True, + results=[PartialValidationResult("unknown")], 
+ ), + StreamingDoneEvent(attempt=1, full_text="hello"), + FullValidationEvent( + attempt=1, passed=True, results=[ValidationResult(result=True)] + ), + RetryEvent(attempt=2, reason="too long"), + CompletedEvent(success=True, full_text="hello", attempts_used=1), + ErrorEvent(exception_type="ValueError", detail="boom"), + ] + after = time.time() + + for ev in all_events: + assert isinstance(ev, StreamEvent) + assert before <= ev.timestamp <= after, ( + f"{type(ev).__name__} timestamp out of range" + ) + + +# --------------------------------------------------------------------------- +# Event emission — happy path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_event_emission_order_happy_path() -> None: + """Happy path: QuickCheckEvent/ChunkEvent pairs, then StreamingDoneEvent, + FullValidationEvent, CompletedEvent(success=True).""" + response = "First sentence. Second sentence. " + backend = StreamingMockBackend(response, token_size=4) + req = AlwaysUnknownReq() + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="sentence" + ) + await result.acomplete() + + evts: list[StreamEvent] = [e async for e in result.events()] + + assert isinstance(evts[-1], CompletedEvent) + assert evts[-1].success is True + assert evts[-1].attempts_used == 1 + + types = [type(e) for e in evts] + assert StreamingDoneEvent in types + assert types.index(StreamingDoneEvent) < types.index(CompletedEvent) + assert FullValidationEvent in types + assert types.index(FullValidationEvent) > types.index(StreamingDoneEvent) + + chunk_events = [e for e in evts if isinstance(e, ChunkEvent)] + qc_events = [e for e in evts if isinstance(e, QuickCheckEvent)] + assert len(chunk_events) == 2 + assert len(qc_events) == 2 + assert [e.chunk_index for e in chunk_events] == [0, 1] + assert [e.chunk_index for e in qc_events] == [0, 1] + assert all(e.passed for e in qc_events) + + # 
QuickCheckEvent fires before ChunkEvent within each pair: validation must + # complete before the chunk is released to the consumer queue. + for ci in range(2): + qc_pos = evts.index(qc_events[ci]) + ch_pos = evts.index(chunk_events[ci]) + assert qc_pos < ch_pos, f"chunk {ci}: QuickCheckEvent must precede ChunkEvent" + + +@pytest.mark.asyncio +async def test_streaming_done_event_carries_full_text() -> None: + """StreamingDoneEvent.full_text matches full_text on the result.""" + response = "One sentence. Two sentences. " + backend = StreamingMockBackend(response, token_size=5) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + await result.acomplete() + + evts = [e async for e in result.events()] + done_events = [e for e in evts if isinstance(e, StreamingDoneEvent)] + assert len(done_events) == 1 + assert done_events[0].full_text == result.full_text + + +# --------------------------------------------------------------------------- +# Event emission — early exit +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_event_emission_on_early_exit() -> None: + """Early exit: QuickCheckEvent(passed=False) present; no StreamingDoneEvent + or FullValidationEvent; CompletedEvent(success=False).""" + response = "word " * 30 + backend = StreamingMockBackend(response, token_size=3) + req = FailAfterWordsReq(threshold=2) + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="word" + ) + await result.acomplete() + + evts = [e async for e in result.events()] + + assert isinstance(evts[-1], CompletedEvent) + assert evts[-1].success is False + + types = [type(e) for e in evts] + assert FullValidationEvent not in types + assert StreamingDoneEvent not in types + + fail_qc = [e for e in evts if isinstance(e, QuickCheckEvent) and not e.passed] + assert len(fail_qc) >= 1 + + +# 
--------------------------------------------------------------------------- +# Event emission — exception path +# --------------------------------------------------------------------------- @pytest.mark.asyncio @@ -1297,3 +1465,255 @@ async def generate_from_raw( with pytest.raises(RuntimeError, match="already-computed MOT"): await stream_with_chunking(_action(), PrecomputedBackend(), _ctx()) + + +@pytest.mark.asyncio +async def test_error_event_on_stream_validate_exception() -> None: + """When stream_validate raises, ErrorEvent is emitted and CompletedEvent follows.""" + + class RaisingReq2(Requirement): + def format_for_llm(self) -> str: + return "raises" + + async def stream_validate( + self, chunk: str, *, backend: Any, ctx: Any + ) -> PartialValidationResult: + raise RuntimeError("test-error") + + async def validate( + self, + backend: Any, + ctx: Any, + *, + format: Any = None, + model_options: Any = None, + ) -> ValidationResult: + return ValidationResult(result=True) + + backend = StreamingMockBackend("hello world", token_size=5) + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[RaisingReq2()], chunking="word" + ) + with pytest.raises(RuntimeError, match="test-error"): + async for _c in result.astream(): + pass + await asyncio.wait_for(result.acomplete(), timeout=5.0) + + evts = [e async for e in result.events()] + + error_events = [e for e in evts if isinstance(e, ErrorEvent)] + assert len(error_events) == 1 + assert error_events[0].exception_type == "RuntimeError" + assert "test-error" in error_events[0].detail + + assert isinstance(evts[-1], CompletedEvent) + assert evts[-1].success is False + + +# --------------------------------------------------------------------------- +# Metric helper calls +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_record_requirement_check_called_per_chunk() -> None: + """record_requirement_check is called once per chunk 
per active requirement.""" + response = "One. Two. " + backend = StreamingMockBackend(response, token_size=3) + req = AlwaysUnknownReq() + + with patch("mellea.stdlib.streaming.record_requirement_check") as mock_check: + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="sentence" + ) + await result.acomplete() + + assert mock_check.call_count == 2 + for call in mock_check.call_args_list: + assert call.args[0] == "AlwaysUnknownReq" + + +@pytest.mark.asyncio +async def test_record_requirement_failure_called_on_fail() -> None: + """record_requirement_failure is called with class name and reason on fail.""" + response = "word " * 10 + backend = StreamingMockBackend(response, token_size=3) + req = FailAfterWordsReq(threshold=2) + + with patch("mellea.stdlib.streaming.record_requirement_failure") as mock_fail: + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="word" + ) + await result.acomplete() + + assert mock_fail.call_count >= 1 + first_call = mock_fail.call_args_list[0] + assert first_call.args[0] == "FailAfterWordsReq" + assert first_call.args[1] == "" # reason not included in metric (cardinality) + + +@pytest.mark.asyncio +async def test_record_sampling_outcome_success() -> None: + """record_sampling_outcome called with success=True on normal completion.""" + response = "One sentence. 
" + backend = StreamingMockBackend(response, token_size=4) + + with patch("mellea.stdlib.streaming.record_sampling_outcome") as mock_outcome: + result = await stream_with_chunking( + _action(), backend, _ctx(), chunking="sentence" + ) + await result.acomplete() + + mock_outcome.assert_called_once_with("stream_with_chunking", success=True) + + +@pytest.mark.asyncio +async def test_record_sampling_outcome_failure_on_early_exit() -> None: + """record_sampling_outcome called with success=False on early exit.""" + response = "word " * 20 + backend = StreamingMockBackend(response, token_size=3) + req = FailAfterWordsReq(threshold=1) + + with patch("mellea.stdlib.streaming.record_sampling_outcome") as mock_outcome: + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="word" + ) + await result.acomplete() + + mock_outcome.assert_called_once_with("stream_with_chunking", success=False) + + +# --------------------------------------------------------------------------- +# Concurrent astream() + events() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_concurrent_astream_and_events() -> None: + """astream() and events() can be consumed concurrently without interference.""" + response = "Alpha. Beta. Gamma. 
" + backend = StreamingMockBackend(response, token_size=4) + req = AlwaysUnknownReq() + + result = await stream_with_chunking( + _action(), backend, _ctx(), requirements=[req], chunking="sentence" + ) + + async def drain_chunks() -> list[str]: + return [c async for c in result.astream()] + + async def drain_events() -> list[StreamEvent]: + return [e async for e in result.events()] + + chunks, evts = await asyncio.gather(drain_chunks(), drain_events()) + await result.acomplete() + + assert len(chunks) == 3 + assert isinstance(evts[-1], CompletedEvent) + assert evts[-1].success is True + + chunk_evts = [e for e in evts if isinstance(e, ChunkEvent)] + assert [e.chunk_index for e in chunk_evts] == list(range(len(chunks))) + + +# --------------------------------------------------------------------------- +# events() single-consumer guard +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_events_single_consumer_guard_raises_on_second_call() -> None: + """events() raises RuntimeError if called a second time on the same result.""" + response = "One sentence. " + backend = StreamingMockBackend(response, token_size=4) + + result = await stream_with_chunking(_action(), backend, _ctx(), chunking="sentence") + await result.acomplete() + + # First drain — OK. + async for _ in result.events(): + pass + + # Second call must raise immediately. + with pytest.raises(RuntimeError, match="single-consumer"): + async for _ in result.events(): + pass + + +# --------------------------------------------------------------------------- +# CancelledError path +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_cancelled_task_sets_completed_false() -> None: + """External task cancellation must leave result.completed=False. 
+ + CancelledError is a BaseException and bypasses except Exception, so + the finally block is responsible for setting result.completed=False. + Regression: without the fix, result.completed stays True and + CompletedEvent / record_sampling_outcome lie to callers. + + Uses a backend whose token feed blocks on an asyncio.Event that is + never set, guaranteeing the orchestrator is suspended at astream() + when the task is cancelled. + + Requires ``await asyncio.sleep(0)`` before ``cancel()`` — see inline + comment. Python 3.12's C Task implementation skips the coroutine body + entirely (including finally blocks) when cancelled before the first + ``coro.send(None)``. + """ + gate = asyncio.Event() # never set — feed task blocks indefinitely + feed_task: asyncio.Task[None] | None = None + + async def _blocking_feed(mot: ModelOutputThunk) -> None: + await gate.wait() + + class BlockingBackend(Backend): + async def _generate_from_context( + self, action: Any, ctx: Any, **kwargs: Any + ) -> tuple[ModelOutputThunk, Any]: + nonlocal feed_task + mot = _make_mot() + feed_task = asyncio.create_task(_blocking_feed(mot)) + return mot, ctx.add(action).add(mot) + + async def generate_from_raw(self, *args: Any, **kwargs: Any) -> Any: + raise NotImplementedError + + result = await stream_with_chunking( + _action(), BlockingBackend(), _ctx(), chunking="word" + ) + assert result._orchestration_task is not None + + # Yield once so the orchestration task starts and reaches its first real + # await (Queue.get inside astream). Without this, the task is cancelled + # before coro.send(None) is ever called, and Python skips the coroutine + # body entirely — the finally block never runs. + await asyncio.sleep(0) + + result._orchestration_task.cancel() + + try: + await result._orchestration_task + except BaseException: + pass + + # Primary assertion: completed must be False after external cancellation. 
+ assert result.completed is False + + # The finally block must have run to completion: _done must be set and + # acomplete() must not hang. This is the actual failure mode the fix + # guards against — if _done is never set, acomplete() blocks forever. + # External cancellation surfaces as CancelledError (raise-once contract). + assert result._done.is_set() + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(result.acomplete(), timeout=2.0) + + # Clean up the blocking feed task to avoid "Task destroyed while pending". + if feed_task is not None: + feed_task.cancel() + try: + await feed_task + except BaseException: + pass