Skip to content
32 changes: 32 additions & 0 deletions mellea/core/requirement.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,38 @@ async def validate(
context=val_ctx,
)

async def stream_validate(
self, chunk: str, *, backend: Backend, ctx: Context
) -> PartialValidationResult:
"""Hook for per-chunk streaming validation.

The default implementation returns ``PartialValidationResult("unknown")``
— meaning insufficient data to decide yet. Subclasses override this method
to inspect the accumulated chunk and return ``"pass"`` or ``"fail"`` early.

Implementations may accumulate state on ``self`` across calls within a
single attempt. The orchestrator clones the requirement (``copy(req)``)
before each attempt, so state does not bleed across retries.

Shallow-copy caveat: mutable container fields (e.g. ``self._buffer = []``)
are shared by reference under ``copy()``. Reassign rather than mutate in
place (``self._buffer = self._buffer + [chunk]``, not
``self._buffer.append(chunk)``), or override ``__copy__`` for proper
isolation.

Args:
chunk: The accumulated model output so far (not just the latest token).
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This worries me. @nrfulton, your initial proposal was to have the requirement only see new chunks. This would show all accumulated chunks so far.

This forces all streaming requirements to be stateful; all requirements must now keep track of what chunks they have processed. The alternative would be to only provide new chunks to requirements; then streaming validation would be stateless except when needed. Requirements can choose whether they need to store and process multiple chunks, or just check each chunk independently.

I guess checking each chunk independently is unlikely to be helpful, so I can see why accumulating and forcing requirements to track their own progress doesn't actually add much complexity.

If so, I think we should actually pre-define functions to help with this (either as a new class of requirements or functions that implementors can draw on).

Also, if we are passing the accumulated chunk through, I almost think we should just pass in a point-in-time copy of the model output thunk, i.e., one that does not receive newly streamed chunks but has all the data fields as of the moment it was copied.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest reverting to the simplicity of the original proposal. The change was made incorrectly when reviewing the initial PR. The initial approach is simple - and requirements can maintain state if they need to work on accumulated chunks. If we need to support accumulated content generally we can consider it in a later phase.

Which works best will vary by use case. Per-chunk likely works better for:

  • checking for forbidden words/phrases
  • ensuring a paragraph or sentence is coherent
  • structural checks - code fencing
  • format validation

especially when the MoT manages the semantic chunking (later)

There will be cases where accumulation is better -- these are probably more complex checks - does a story line flow, are we taking the response in an unexpected direction, do we have a complete enough response

Importantly if we stick to per-chunk we could still implement this second approach - albeit not as cleanly.

So in summary - I'll revert to original -- but if you now think that's wrong we can adjust?

backend: The inference backend, available for backend-assisted checks.
ctx: The current generation context.

Returns:
PartialValidationResult: ``"unknown"`` by default. Subclasses may return
``"pass"`` (constraint satisfied so far) or ``"fail"`` (constraint violated,
streaming should be aborted). ``"pass"`` does not short-circuit the final
``validate()`` call; the orchestrator decides whether to skip it.
"""
return PartialValidationResult("unknown")

def parts(self) -> list[Component | CBlock]:
"""Returns all of the constituent parts of a Requirement.

Expand Down
166 changes: 166 additions & 0 deletions test/core/test_stream_validate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""Unit tests for Requirement.stream_validate() hook."""

import inspect
from copy import copy

import pytest

from mellea.core import Backend, Context, PartialValidationResult, Requirement


@pytest.mark.asyncio
async def test_default_returns_unknown():
    """The un-overridden hook reports ``"unknown"`` for a chunk."""
    requirement = Requirement(description="some requirement")
    outcome = await requirement.stream_validate(
        "some chunk",
        backend=None,  # type: ignore[arg-type]
        ctx=None,  # type: ignore[arg-type]
    )
    assert outcome.success == "unknown"


@pytest.mark.asyncio
async def test_default_returns_partial_validation_result_instance():
    """The un-overridden hook hands back a ``PartialValidationResult`` object."""
    requirement = Requirement()
    outcome = await requirement.stream_validate(
        "chunk",
        backend=None,  # type: ignore[arg-type]
        ctx=None,  # type: ignore[arg-type]
    )
    assert isinstance(outcome, PartialValidationResult)


def test_stream_validate_is_coroutine():
    """``stream_validate`` on an instance is recognised as a coroutine function."""
    hook = Requirement().stream_validate
    assert inspect.iscoroutinefunction(hook)


@pytest.mark.asyncio
async def test_subclass_can_return_pass():
    """A subclass override may report ``"pass"`` from the streaming hook."""

    class PassRequirement(Requirement):
        # Unconditionally reports success, whatever the chunk contains.
        async def stream_validate(
            self, chunk: str, *, backend: Backend, ctx: Context
        ) -> PartialValidationResult:
            verdict = PartialValidationResult("pass")
            return verdict

    requirement = PassRequirement(description="always passes")
    outcome = await requirement.stream_validate(
        "any chunk",
        backend=None,  # type: ignore[arg-type]
        ctx=None,  # type: ignore[arg-type]
    )
    assert outcome.success == "pass"


@pytest.mark.asyncio
async def test_subclass_can_return_fail():
    """A subclass override may report ``"fail"`` with a reason, or stay ``"unknown"``."""

    class FailRequirement(Requirement):
        # Fails as soon as the substring "bad" shows up in the chunk.
        async def stream_validate(
            self, chunk: str, *, backend: Backend, ctx: Context
        ) -> PartialValidationResult:
            if "bad" not in chunk:
                return PartialValidationResult("unknown")
            return PartialValidationResult("fail", reason="bad word detected")

    requirement = FailRequirement(description="no bad words")

    failing = await requirement.stream_validate(
        "this is bad content",
        backend=None,  # type: ignore[arg-type]
        ctx=None,  # type: ignore[arg-type]
    )
    assert failing.success == "fail"
    assert failing.reason == "bad word detected"

    undecided = await requirement.stream_validate(
        "good content",
        backend=None,  # type: ignore[arg-type]
        ctx=None,  # type: ignore[arg-type]
    )
    assert undecided.success == "unknown"


@pytest.mark.asyncio
async def test_does_not_mutate_requirement():
    """Calling the default hook leaves the inspected requirement fields unchanged."""
    requirement = Requirement(description="original description")

    # Snapshot the fields of interest before the call, in a fixed order.
    before = {
        "description": requirement.description,
        "_output": requirement._output,
        "validation_fn": requirement.validation_fn,
    }

    await requirement.stream_validate(
        "some chunk",
        backend=None,  # type: ignore[arg-type]
        ctx=None,  # type: ignore[arg-type]
    )

    for field_name, expected in before.items():
        assert getattr(requirement, field_name) == expected


@pytest.mark.asyncio
async def test_stream_validate_idempotent():
    """Repeated default-hook calls each return ``"unknown"`` and leave ``_output`` unset."""
    requirement = Requirement(description="repeated calls")

    first = await requirement.stream_validate(
        "chunk one",
        backend=None,  # type: ignore[arg-type]
        ctx=None,  # type: ignore[arg-type]
    )
    second = await requirement.stream_validate(
        "chunk two",
        backend=None,  # type: ignore[arg-type]
        ctx=None,  # type: ignore[arg-type]
    )

    assert first.success == "unknown"
    assert second.success == "unknown"
    assert requirement._output is None


@pytest.mark.asyncio
async def test_stateful_subclass_accumulates_state():
    """A stateful subclass carries its counters from one stream_validate call to the next.

    The helper class slices off the text it has not yet seen (tracked through
    ``_seen_len``) and counts bullet markers only in that new portion, so the
    running total is only correct if state survives between calls.
    """

    class BulletCounter(Requirement):
        def __init__(self) -> None:
            super().__init__(description="no more than 3 bullets")
            self._seen_len = 0
            self._bullet_count = 0

        async def stream_validate(
            self, chunk: str, *, backend: Backend, ctx: Context
        ) -> PartialValidationResult:
            # Only the suffix beyond the previously seen length is new.
            unseen = chunk[self._seen_len :]
            self._seen_len = len(chunk)
            self._bullet_count = self._bullet_count + unseen.count("\n-")
            if self._bullet_count <= 3:
                return PartialValidationResult("unknown")
            return PartialValidationResult(
                "fail", reason=f"{self._bullet_count} bullets exceeds limit"
            )

    counter = BulletCounter()
    assert counter._bullet_count == 0

    await counter.stream_validate(
        "intro text",
        backend=None,  # type: ignore[arg-type]
        ctx=None,  # type: ignore[arg-type]
    )
    assert counter._bullet_count == 0

    await counter.stream_validate(
        "intro text\n- one\n- two",
        backend=None,  # type: ignore[arg-type]
        ctx=None,  # type: ignore[arg-type]
    )
    assert counter._bullet_count == 2  # the new suffix contributed two bullets

    outcome = await counter.stream_validate(
        "intro text\n- one\n- two\n- three\n- four",
        backend=None,  # type: ignore[arg-type]
        ctx=None,  # type: ignore[arg-type]
    )
    assert counter._bullet_count == 4  # two further bullets in the new suffix
    assert outcome.success == "fail"
    assert outcome.reason is not None and "4" in outcome.reason


@pytest.mark.asyncio
async def test_stateful_subclass_clone_isolation():
    """Each ``copy()`` of a stateful requirement starts from the original's state.

    Mirrors the clone-per-attempt pattern: the original instance is never
    called directly; every attempt works on its own shallow copy, and the
    counters of one copy do not show up on another copy or on the original.
    """

    class CallCounter(Requirement):
        def __init__(self) -> None:
            super().__init__(description="call counter")
            self._calls = 0

        async def stream_validate(
            self, chunk: str, *, backend: Backend, ctx: Context
        ) -> PartialValidationResult:
            self._calls = self._calls + 1
            return PartialValidationResult("unknown")

    template = CallCounter()  # the original; stream_validate is never called on it

    # First attempt: two chunks go through the first clone.
    first_try = copy(template)
    assert first_try._calls == 0
    for text in ("a", "b"):
        await first_try.stream_validate(
            text,
            backend=None,  # type: ignore[arg-type]
            ctx=None,  # type: ignore[arg-type]
        )
    assert first_try._calls == 2

    # Second attempt (retry): a new clone taken from the same original.
    second_try = copy(template)
    assert second_try._calls == 0  # does not inherit the first clone's count
    await second_try.stream_validate(
        "c",
        backend=None,  # type: ignore[arg-type]
        ctx=None,  # type: ignore[arg-type]
    )
    assert second_try._calls == 1

    assert template._calls == 0  # the original's counter was never touched
Loading