From 822270747b5d5c084f44b81144341062d9e1b431 Mon Sep 17 00:00:00 2001 From: Alex M Date: Fri, 1 May 2026 18:56:52 +0000 Subject: [PATCH 1/6] feat: cross-loader tool-calling toolkit + transformers wiring Adds modelship.openai.tool_calling, a small package that turns raw chat-completion text into OpenAI-shape tool_calls. Loaders whose engines already emit structured calls (vLLM, llama.cpp via a function-calling chat handler) keep their native path; loaders that emit raw text (Transformers today, plugin-wrapped engines later) call into the toolkit. Includes: - ToolCallParser ABC + ParsedToolCalls result type - Hermes-style {...} parser (Hermes-2, Qwen2.5, many community fine-tunes) - name -> parser registry with register_parser hook for plugin code - resolve_tools_for_request applying OpenAI tool_choice semantics (none / auto / required / specific function) Wires the Transformers chat path to it: when tools are active, pre-renders the prompt via apply_chat_template(tools=...) and parses output through the configured parser, setting finish_reason="tool_calls" and populating ChatMessage.tool_calls. Streaming buffers tokens while tools are active and emits a single resolved delta at the end so we never stream a fragment of a tool-call marker as if it were prose. Also fixes ChatCompletionRequest.tool_choice default from "none" to None: per the OpenAI spec, "auto" is the default when tools are present. The previous default suppressed tools whenever a client omitted tool_choice, including via the llama.cpp passthrough. 
Tests: - 28 unit tests covering parser shape, registry behavior, tool_choice resolution, and the serving_chat tool path against a faked HF pipeline - Integration test deploying Qwen/Qwen2.5-0.5B-Instruct via the transformers loader and round-tripping a get_weather tool call --- .../infer/transformers/openai/serving_chat.py | 143 +++++++++++++-- modelship/openai/protocol.py | 2 +- modelship/openai/tool_calling/__init__.py | 21 +++ modelship/openai/tool_calling/input.py | 63 +++++++ .../openai/tool_calling/parsers/__init__.py | 4 + modelship/openai/tool_calling/parsers/base.py | 42 +++++ .../openai/tool_calling/parsers/hermes.py | 49 ++++++ modelship/openai/tool_calling/registry.py | 34 ++++ tests/test_integration.py | 60 ++++++- tests/test_tool_calling.py | 153 ++++++++++++++++ tests/test_transformers_chat_tools.py | 163 ++++++++++++++++++ 11 files changed, 717 insertions(+), 17 deletions(-) create mode 100644 modelship/openai/tool_calling/__init__.py create mode 100644 modelship/openai/tool_calling/input.py create mode 100644 modelship/openai/tool_calling/parsers/__init__.py create mode 100644 modelship/openai/tool_calling/parsers/base.py create mode 100644 modelship/openai/tool_calling/parsers/hermes.py create mode 100644 modelship/openai/tool_calling/registry.py create mode 100644 tests/test_tool_calling.py create mode 100644 tests/test_transformers_chat_tools.py diff --git a/modelship/infer/transformers/openai/serving_chat.py b/modelship/infer/transformers/openai/serving_chat.py index 35f7d26..333e2d5 100644 --- a/modelship/infer/transformers/openai/serving_chat.py +++ b/modelship/infer/transformers/openai/serving_chat.py @@ -3,6 +3,7 @@ import time from collections.abc import AsyncGenerator from threading import Thread +from typing import Any from transformers import Pipeline, PreTrainedTokenizerBase, TextIteratorStreamer @@ -23,6 +24,7 @@ UsageInfo, create_error_response, ) +from modelship.openai.tool_calling import ParsedToolCalls, get_parser, 
resolve_tools_for_request from modelship.utils import base_request_id logger = get_logger("infer.transformers.chat") @@ -45,12 +47,16 @@ def __init__( assert pipeline.tokenizer is not None, "text-generation pipeline must have a tokenizer" self.tokenizer: PreTrainedTokenizerBase = pipeline.tokenizer self._lock = asyncio.Lock() + # Validate the configured parser at startup so misconfiguration surfaces + # before the first request rather than mid-generation. + get_parser(self.config.tool_call_parser) async def warmup(self) -> None: logger.info("Warming up chat model: %s", self.model_name) await self.run_in_executor( self._run, [{"role": "user", "content": "warmup"}], + None, 1, ) logger.info("Warmup chat done for %s", self.model_name) @@ -74,26 +80,29 @@ async def create_chat_completion( logger.warning("chat request %s rejected: %s", request_id, e) return create_error_response(e) + tools = resolve_tools_for_request(request.tools, request.tool_choice) + max_tokens = request.max_tokens if max_tokens is None and request.max_completion_tokens is not None: max_tokens = request.max_completion_tokens if request.stream: include_usage = bool(request.stream_options and request.stream_options.include_usage) - return self._locked_stream(request_id, messages, max_tokens, include_usage=include_usage) + return self._locked_stream(request_id, messages, tools, max_tokens, include_usage=include_usage) async with self._lock: try: - result = await self.run_in_executor(self._run, messages, max_tokens) + result = await self.run_in_executor(self._run, messages, tools, max_tokens) except Exception: logger.exception("chat completion inference failed for %s", request_id) return create_error_response("chat completion inference failed") - prompt_tokens = self._count_prompt_tokens(messages) - generated = result[0]["generated_text"] - completion_text = generated[-1]["content"] if isinstance(generated, list) else generated + prompt_tokens = self._count_prompt_tokens(messages, tools) + 
completion_text = self._extract_completion_text(result) completion_tokens = len(self.tokenizer.encode(completion_text, add_special_tokens=False)) - finish_reason = "length" if (max_tokens is not None and completion_tokens >= max_tokens) else "stop" + + parsed = self._parse_tool_calls(completion_text) if tools else ParsedToolCalls(completion_text, []) + finish_reason = self._finish_reason(parsed, completion_tokens, max_tokens) response = ChatCompletionResponse( id=request_id, @@ -101,7 +110,11 @@ async def create_chat_completion( choices=[ ChatCompletionResponseChoice( index=0, - message=ChatMessage(role="assistant", content=completion_text), + message=ChatMessage( + role="assistant", + content=parsed.content, + tool_calls=parsed.tool_calls, + ), finish_reason=finish_reason, ) ], @@ -114,41 +127,81 @@ async def create_chat_completion( ) logger.log( TRACE, - "chat response %s: text=%r, prompt_tokens=%d, completion_tokens=%d", + "chat response %s: text=%r, tool_calls=%d, prompt_tokens=%d, completion_tokens=%d", request_id, completion_text, + len(parsed.tool_calls), prompt_tokens, completion_tokens, ) return response - def _count_prompt_tokens(self, messages: list[dict]) -> int: + def _parse_tool_calls(self, text: str) -> ParsedToolCalls: + return get_parser(self.config.tool_call_parser).parse(text) + + @staticmethod + def _finish_reason(parsed: ParsedToolCalls, completion_tokens: int, max_tokens: int | None) -> str: + if parsed.has_tool_calls: + return "tool_calls" + if max_tokens is not None and completion_tokens >= max_tokens: + return "length" + return "stop" + + @staticmethod + def _extract_completion_text(result: list) -> str: + generated = result[0]["generated_text"] + if isinstance(generated, list): + return generated[-1]["content"] + return generated + + def _count_prompt_tokens(self, messages: list[dict], tools: list[dict[str, Any]] | None) -> int: # apply_chat_template returns a string by default (character count!) — force tokenize=True. 
- token_ids = self.tokenizer.apply_chat_template(messages, tokenize=True) + kwargs: dict[str, Any] = {"tokenize": True} + if tools: + kwargs["tools"] = tools + token_ids = self.tokenizer.apply_chat_template(messages, **kwargs) return len(token_ids) - def _run(self, messages: list[dict], max_tokens: int | None) -> list: + def _render_prompt(self, messages: list[dict], tools: list[dict[str, Any]]) -> str: + rendered = self.tokenizer.apply_chat_template( + messages, + tools=tools, # type: ignore[arg-type] + tokenize=False, + add_generation_prompt=True, + ) + assert isinstance(rendered, str), "apply_chat_template(tokenize=False) must return str" + return rendered + + def _run(self, messages: list[dict], tools: list[dict[str, Any]] | None, max_tokens: int | None) -> list: kwargs = {**self.config.pipeline_kwargs} if max_tokens is not None: kwargs["max_new_tokens"] = max_tokens + if tools: + # The standard text-generation pipeline does not forward `tools` to + # `apply_chat_template`, so we render the prompt ourselves and feed + # it as a plain string. 
+ prompt = self._render_prompt(messages, tools) + return self.pipeline(prompt, return_full_text=False, **kwargs) # type: ignore[return-value] return self.pipeline(messages, return_full_text=False, **kwargs) # type: ignore[return-value] async def _locked_stream( self, request_id: str, messages: list[dict], + tools: list[dict[str, Any]] | None, max_tokens: int | None, *, include_usage: bool, ) -> AsyncGenerator[str, None]: async with self._lock: - async for chunk in self._stream(request_id, messages, max_tokens, include_usage=include_usage): + async for chunk in self._stream(request_id, messages, tools, max_tokens, include_usage=include_usage): yield chunk async def _stream( self, request_id: str, messages: list[dict], + tools: list[dict[str, Any]] | None, max_tokens: int | None, *, include_usage: bool, @@ -159,9 +212,14 @@ async def _stream( if max_tokens is not None: kwargs["max_new_tokens"] = max_tokens + if tools: + prompt: Any = self._render_prompt(messages, tools) + else: + prompt = messages + thread = Thread( target=self.pipeline, - args=(messages,), + args=(prompt,), kwargs={ "streamer": streamer, "return_full_text": False, @@ -191,6 +249,12 @@ async def _stream( if not text_chunk: continue accumulated.append(text_chunk) + # When tools are in play we cannot stream content incrementally + # without risking emitting fragments of a `` block as + # if they were assistant prose. Buffer until generation is done + # and emit the resolved shape as a single delta below. 
+ if tools: + continue yield self._encode_chunk( ChatCompletionStreamResponse( id=request_id, @@ -208,7 +272,15 @@ async def _stream( completion_text = "".join(accumulated) completion_tokens = len(self.tokenizer.encode(completion_text, add_special_tokens=False)) - finish_reason = "length" if (max_tokens is not None and completion_tokens >= max_tokens) else "stop" + + if tools: + parsed = self._parse_tool_calls(completion_text) + async for chunk in self._emit_buffered_final_delta(request_id, parsed): + yield chunk + else: + parsed = ParsedToolCalls(completion_text, []) + + finish_reason = self._finish_reason(parsed, completion_tokens, max_tokens) yield self._encode_chunk( ChatCompletionStreamResponse( @@ -226,7 +298,7 @@ async def _stream( ) if include_usage: - prompt_tokens = self._count_prompt_tokens(messages) + prompt_tokens = self._count_prompt_tokens(messages, tools) yield self._encode_chunk( ChatCompletionStreamResponse( id=request_id, @@ -245,6 +317,47 @@ async def _stream( finally: thread.join() + async def _emit_buffered_final_delta(self, request_id: str, parsed: ParsedToolCalls) -> AsyncGenerator[str, None]: + if parsed.has_tool_calls: + from modelship.openai.protocol import DeltaFunctionCall, DeltaToolCall + + deltas = [ + DeltaToolCall( + index=i, + id=tc.id, + type="function", + function=DeltaFunctionCall(name=tc.function.name, arguments=tc.function.arguments), + ) + for i, tc in enumerate(parsed.tool_calls) + ] + yield self._encode_chunk( + ChatCompletionStreamResponse( + id=request_id, + model=self.model_name, + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(tool_calls=deltas), + ) + ], + created=int(time.time()), + ) + ) + elif parsed.content: + yield self._encode_chunk( + ChatCompletionStreamResponse( + id=request_id, + model=self.model_name, + choices=[ + ChatCompletionResponseStreamChoice( + index=0, + delta=DeltaMessage(content=parsed.content), + ) + ], + created=int(time.time()), + ) + ) + @staticmethod def 
_encode_chunk(chunk: ChatCompletionStreamResponse) -> str: return f"data: {json.dumps(chunk.model_dump(mode='json'))}\n\n" diff --git a/modelship/openai/protocol.py b/modelship/openai/protocol.py index d09afe7..d4142b1 100644 --- a/modelship/openai/protocol.py +++ b/modelship/openai/protocol.py @@ -180,7 +180,7 @@ class ChatCompletionRequest(OpenAIBaseModel): temperature: float | None = None top_p: float | None = None tools: list[dict[str, Any]] | None = None - tool_choice: str | dict[str, Any] | None = "none" + tool_choice: str | dict[str, Any] | None = None reasoning_effort: Literal["none", "low", "medium", "high"] | None = None parallel_tool_calls: bool | None = True user: str | None = None diff --git a/modelship/openai/tool_calling/__init__.py b/modelship/openai/tool_calling/__init__.py new file mode 100644 index 0000000..3a132fe --- /dev/null +++ b/modelship/openai/tool_calling/__init__.py @@ -0,0 +1,21 @@ +"""Cross-loader tool-calling toolkit. + +Loaders without native tool-call support (Transformers today, plugin-wrapped +raw-text engines tomorrow) use the parsers and helpers in this package to +turn raw model output into OpenAI-shape ``tool_calls``. Loaders whose engines +already emit structured tool calls (vLLM, llama.cpp via a function-calling +chat handler) bypass it. +""" + +from modelship.openai.tool_calling.input import resolve_tools_for_request +from modelship.openai.tool_calling.parsers import ParsedToolCalls, ToolCallParser +from modelship.openai.tool_calling.registry import available_parsers, get_parser, register_parser + +__all__ = [ + "ParsedToolCalls", + "ToolCallParser", + "available_parsers", + "get_parser", + "register_parser", + "resolve_tools_for_request", +] diff --git a/modelship/openai/tool_calling/input.py b/modelship/openai/tool_calling/input.py new file mode 100644 index 0000000..a245c0c --- /dev/null +++ b/modelship/openai/tool_calling/input.py @@ -0,0 +1,63 @@ +"""Input-side helpers for tool calling. 
+ +Loaders that hand a chat template a list of OpenAI messages plus a list of +tool schemas use these helpers to interpret the request's ``tool_choice`` +(``"none"`` suppresses tools, ``"required"`` / specific-function downgrade +to ``"auto"`` with a warning) and to validate that a parser exists for the +configured family before generation starts. +""" + +from __future__ import annotations + +from typing import Any + +from modelship.logging import get_logger + +logger = get_logger("openai.tool_calling.input") + + +def resolve_tools_for_request( + tools: list[dict[str, Any]] | None, + tool_choice: str | dict[str, Any] | None, +) -> list[dict[str, Any]] | None: + """Apply OpenAI ``tool_choice`` semantics to the request's ``tools`` list. + + Returns the list of tools to render into the prompt, or ``None`` when + the request should be served without any tool-calling affordance. + + - ``tool_choice == "none"`` — suppress tools entirely. + - ``tool_choice == "auto"`` (or unset) — pass all tools through. + - ``tool_choice == "required"`` — pass tools through, log that we cannot + strictly enforce a tool call without constrained decoding. + - ``tool_choice == {"type": "function", "function": {"name": "X"}}`` — + filter ``tools`` to that single function and warn that the call cannot + be strictly enforced. 
+ """ + if not tools: + return None + if tool_choice in (None, "auto"): + return tools + if tool_choice == "none": + return None + if tool_choice == "required": + logger.warning( + "tool_choice='required' requested but this loader cannot enforce a tool call; " + "passing all tools to the model and trusting it to call one" + ) + return tools + if isinstance(tool_choice, dict): + fn = tool_choice.get("function") or {} + name = fn.get("name") if isinstance(fn, dict) else None + if isinstance(name, str) and name: + filtered = [t for t in tools if (t.get("function") or {}).get("name") == name] + if not filtered: + logger.warning("tool_choice names function %r which is not in the request's tools list", name) + return tools + logger.warning( + "tool_choice forcing function %r is not strictly enforced by this loader; " + "passing only that tool to the model", + name, + ) + return filtered + logger.warning("unrecognized tool_choice value %r; falling back to 'auto' semantics", tool_choice) + return tools diff --git a/modelship/openai/tool_calling/parsers/__init__.py b/modelship/openai/tool_calling/parsers/__init__.py new file mode 100644 index 0000000..ae63d47 --- /dev/null +++ b/modelship/openai/tool_calling/parsers/__init__.py @@ -0,0 +1,4 @@ +from modelship.openai.tool_calling.parsers.base import ParsedToolCalls, ToolCallParser +from modelship.openai.tool_calling.parsers.hermes import HermesToolCallParser + +__all__ = ["HermesToolCallParser", "ParsedToolCalls", "ToolCallParser"] diff --git a/modelship/openai/tool_calling/parsers/base.py b/modelship/openai/tool_calling/parsers/base.py new file mode 100644 index 0000000..bb26cb6 --- /dev/null +++ b/modelship/openai/tool_calling/parsers/base.py @@ -0,0 +1,42 @@ +"""Base class for model-family-specific tool-call output parsers.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from modelship.openai.protocol import ToolCall + + +@dataclass(frozen=True) +class 
ParsedToolCalls: + """Result of running a parser over a model's raw chat-completion text. + + ``content`` carries the residual non-tool-call text once any tool-call + markers are stripped. It is ``None`` when tool calls were extracted *and* + the residual is empty, matching OpenAI's behavior of nulling ``content`` + alongside ``tool_calls``. + """ + + content: str | None + tool_calls: list[ToolCall] + + @property + def has_tool_calls(self) -> bool: + return bool(self.tool_calls) + + +class ToolCallParser(ABC): + """Convert a raw text generation into OpenAI-shape tool calls. + + Each subclass targets one model-family output convention (Hermes XML tags, + Llama 3.1 ``<|python_tag|>``, Mistral ``[TOOL_CALLS]``, …). Implementations + must be pure functions of the input text — no model state, no side effects + — so the same parser can be reused across loaders, deployments, and tests. + """ + + name: str + + @abstractmethod + def parse(self, text: str) -> ParsedToolCalls: + """Extract tool calls and residual content from ``text``.""" diff --git a/modelship/openai/tool_calling/parsers/hermes.py b/modelship/openai/tool_calling/parsers/hermes.py new file mode 100644 index 0000000..342722a --- /dev/null +++ b/modelship/openai/tool_calling/parsers/hermes.py @@ -0,0 +1,49 @@ +"""Hermes-style ``{json}`` parser. + +Used by Hermes-2-Pro, Qwen2.5-Instruct, and a large family of NousResearch / +community fine-tunes whose chat templates wrap each tool call in the literal +tags ```` / ```` around a JSON object of the shape +``{"name": "...", "arguments": {...}}``. 
+""" + +from __future__ import annotations + +import json +import re + +from modelship.openai.protocol import FunctionCall, ToolCall +from modelship.openai.tool_calling.parsers.base import ParsedToolCalls, ToolCallParser + +_TOOL_CALL_RE = re.compile(r"(.*?)", re.DOTALL) + + +class HermesToolCallParser(ToolCallParser): + name = "hermes" + + def parse(self, text: str) -> ParsedToolCalls: + matches = list(_TOOL_CALL_RE.finditer(text)) + if not matches: + return ParsedToolCalls(content=text, tool_calls=[]) + + tool_calls: list[ToolCall] = [] + for m in matches: + try: + payload = json.loads(m.group(1).strip()) + except json.JSONDecodeError: + # Malformed block — leave the text in residual, skip this call. + continue + if not isinstance(payload, dict): + continue + name = payload.get("name") + if not isinstance(name, str) or not name: + continue + arguments = payload.get("arguments", {}) + if not isinstance(arguments, str): + arguments = json.dumps(arguments) + tool_calls.append(ToolCall(function=FunctionCall(name=name, arguments=arguments))) + + if not tool_calls: + return ParsedToolCalls(content=text, tool_calls=[]) + + residual = _TOOL_CALL_RE.sub("", text).strip() + return ParsedToolCalls(content=residual or None, tool_calls=tool_calls) diff --git a/modelship/openai/tool_calling/registry.py b/modelship/openai/tool_calling/registry.py new file mode 100644 index 0000000..68b8e28 --- /dev/null +++ b/modelship/openai/tool_calling/registry.py @@ -0,0 +1,34 @@ +"""Registry of named tool-call parsers, dispatched by configuration string. + +The registry is the single seam between a loader and the per-family parsers. +Loaders that emit raw text (Transformers, plugin-wrapped engines) look up a +parser by the name configured on the deployment and feed it the model's +output. Loaders with native tool-call support (vLLM, llama.cpp via a +function-calling chat handler) bypass the registry entirely. 
+""" + +from __future__ import annotations + +from modelship.openai.tool_calling.parsers import HermesToolCallParser, ToolCallParser + +_PARSERS: dict[str, ToolCallParser] = { + HermesToolCallParser.name: HermesToolCallParser(), +} + + +def get_parser(name: str) -> ToolCallParser: + """Return the parser registered under ``name`` or raise ``ValueError``.""" + try: + return _PARSERS[name] + except KeyError: + available = ", ".join(sorted(_PARSERS)) or "(none)" + raise ValueError(f"unknown tool_call_parser {name!r}; available: {available}") from None + + +def available_parsers() -> list[str]: + return sorted(_PARSERS) + + +def register_parser(parser: ToolCallParser) -> None: + """Register an additional parser. Intended for tests and plugin code.""" + _PARSERS[parser.name] = parser diff --git a/tests/test_integration.py b/tests/test_integration.py index 47f276c..9433136 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -9,7 +9,14 @@ OPENAI_API_BASE = "http://localhost:8000/v1" -EXPECTED_MODELS = {"chat-capable", "chat-limited", "embed-model", "stt-model", "tts-model"} +EXPECTED_MODELS = { + "chat-capable", + "chat-limited", + "chat-transformers", + "embed-model", + "stt-model", + "tts-model", +} @pytest.fixture(scope="session") @@ -45,6 +52,20 @@ def mship_cluster(tmp_path_factory): "loader": "llama_cpp", "num_cpus": 1, }, + { + # Same Qwen2.5-Instruct family as `chat-capable` so we exercise + # the transformers loader against a model trained to emit + # Hermes-style `{...}` markers. 
+ "name": "chat-transformers", + "model": "Qwen/Qwen2.5-0.5B-Instruct", + "usecase": "generate", + "loader": "transformers", + "num_cpus": 2, + "transformers_config": { + "device": "cpu", + "torch_dtype": "float32", + }, + }, { "name": "embed-model", "model": "nomic-ai/nomic-embed-text-v1.5", @@ -140,6 +161,7 @@ def test_list_models(client): model_ids = [m.id for m in models.data] assert "chat-capable" in model_ids assert "chat-limited" in model_ids + assert "chat-transformers" in model_ids assert "embed-model" in model_ids assert "stt-model" in model_ids assert "tts-model" in model_ids @@ -191,6 +213,42 @@ def test_tool_calling_success(client): assert completion.choices[0].message.tool_calls[0].function.name == "get_weather" +@pytest.mark.integration +def test_tool_calling_transformers_loader(client): + """Round-trip a Hermes-style tool call through the transformers loader. + + Uses the same Qwen2.5-0.5B-Instruct weights as the vLLM `chat-capable` + deployment but goes through the modelship-side tool-calling toolkit + (apply_chat_template(tools=...) on input, hermes parser on output). 
+ """ + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ] + completion = client.chat.completions.create( + model="chat-transformers", + messages=[{"role": "user", "content": "What is the weather in Paris?"}], + tools=tools, + tool_choice="auto", + max_tokens=128, + ) + tool_calls = completion.choices[0].message.tool_calls + assert tool_calls, f"expected a tool call, got content={completion.choices[0].message.content!r}" + assert tool_calls[0].function.name == "get_weather" + assert "Paris" in tool_calls[0].function.arguments + assert completion.choices[0].finish_reason == "tool_calls" + + @pytest.mark.integration def test_tool_calling_unsupported_loader(client): """Verifies that loaders without tool support (like llama_cpp) don't return tool calls.""" diff --git a/tests/test_tool_calling.py b/tests/test_tool_calling.py new file mode 100644 index 0000000..e5aa61f --- /dev/null +++ b/tests/test_tool_calling.py @@ -0,0 +1,153 @@ +"""Tests for the cross-loader tool-calling toolkit.""" + +from __future__ import annotations + +import json +from typing import ClassVar + +import pytest + +from modelship.openai.tool_calling import ( + available_parsers, + get_parser, + register_parser, + resolve_tools_for_request, +) +from modelship.openai.tool_calling.parsers import HermesToolCallParser, ParsedToolCalls, ToolCallParser + + +class TestRegistry: + def test_default_registry_includes_hermes(self): + assert "hermes" in available_parsers() + + def test_get_parser_returns_singleton(self): + a = get_parser("hermes") + b = get_parser("hermes") + assert a is b + + def test_unknown_parser_raises_with_available_list(self): + with pytest.raises(ValueError, match="hermes"): + get_parser("does-not-exist") + + def test_register_parser_makes_it_findable(self): + class Stub(ToolCallParser): + name 
= "stub-test-parser" + + def parse(self, text: str) -> ParsedToolCalls: + return ParsedToolCalls(content=text, tool_calls=[]) + + register_parser(Stub()) + try: + assert get_parser("stub-test-parser").name == "stub-test-parser" + finally: + # Clean up so other tests don't see the stub. + from modelship.openai.tool_calling import registry + + registry._PARSERS.pop("stub-test-parser", None) + + +class TestHermesParser: + parser = HermesToolCallParser() + + def test_no_tool_calls_returns_text_unchanged(self): + result = self.parser.parse("just a regular response") + assert result.tool_calls == [] + assert result.content == "just a regular response" + assert result.has_tool_calls is False + + def test_single_tool_call(self): + text = '{"name": "get_weather", "arguments": {"city": "Paris"}}' + result = self.parser.parse(text) + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].function.name == "get_weather" + assert json.loads(result.tool_calls[0].function.arguments) == {"city": "Paris"} + assert result.content is None + + def test_multiple_tool_calls(self): + text = ( + '{"name": "a", "arguments": {"x": 1}}' + '{"name": "b", "arguments": {"y": 2}}' + ) + result = self.parser.parse(text) + assert [tc.function.name for tc in result.tool_calls] == ["a", "b"] + + def test_tool_call_with_residual_text(self): + text = 'Sure, calling that.\n{"name": "ping", "arguments": {}}' + result = self.parser.parse(text) + assert len(result.tool_calls) == 1 + assert result.content == "Sure, calling that." 
+ + def test_string_arguments_passed_through(self): + text = '{"name": "x", "arguments": "{\\"a\\": 1}"}' + result = self.parser.parse(text) + assert result.tool_calls[0].function.arguments == '{"a": 1}' + + def test_object_arguments_serialized_to_json(self): + text = '{"name": "x", "arguments": {"a": 1, "b": [2, 3]}}' + result = self.parser.parse(text) + assert json.loads(result.tool_calls[0].function.arguments) == {"a": 1, "b": [2, 3]} + + def test_malformed_json_block_drops_call_and_falls_back_to_content(self): + text = "{not valid json}" + result = self.parser.parse(text) + assert result.tool_calls == [] + # Malformed block stays in content as-is. + assert result.content == text + + def test_missing_name_drops_call(self): + text = '{"arguments": {}}' + result = self.parser.parse(text) + assert result.tool_calls == [] + + def test_empty_name_drops_call(self): + text = '{"name": "", "arguments": {}}' + result = self.parser.parse(text) + assert result.tool_calls == [] + + def test_each_tool_call_gets_unique_id(self): + text = ( + '{"name": "a", "arguments": {}}{"name": "b", "arguments": {}}' + ) + result = self.parser.parse(text) + assert result.tool_calls[0].id != result.tool_calls[1].id + + +class TestResolveToolsForRequest: + tools: ClassVar = [ + {"type": "function", "function": {"name": "alpha"}}, + {"type": "function", "function": {"name": "beta"}}, + ] + + def test_no_tools_returns_none(self): + assert resolve_tools_for_request(None, "auto") is None + assert resolve_tools_for_request([], "auto") is None + + def test_auto_passes_through(self): + assert resolve_tools_for_request(self.tools, "auto") == self.tools + + def test_unset_tool_choice_passes_through(self): + assert resolve_tools_for_request(self.tools, None) == self.tools + + def test_none_suppresses_tools(self): + assert resolve_tools_for_request(self.tools, "none") is None + + def test_required_passes_through(self): + # We cannot strictly enforce a tool call without constrained decoding, + # so 
"required" downgrades to "auto" semantics (with a logged warning). + assert resolve_tools_for_request(self.tools, "required") == self.tools + + def test_specific_function_filters_to_that_tool(self): + result = resolve_tools_for_request(self.tools, {"type": "function", "function": {"name": "beta"}}) + assert result is not None + assert len(result) == 1 + assert result[0]["function"]["name"] == "beta" + + def test_unknown_function_falls_back_to_all(self): + # If the named function isn't in the tools list, fall back to passing + # them all through rather than emitting an empty list. + result = resolve_tools_for_request(self.tools, {"type": "function", "function": {"name": "missing"}}) + assert result == self.tools + + def test_unrecognized_choice_falls_back_to_all(self): + result = resolve_tools_for_request(self.tools, "weird-mode") + assert result == self.tools diff --git a/tests/test_transformers_chat_tools.py b/tests/test_transformers_chat_tools.py new file mode 100644 index 0000000..003a909 --- /dev/null +++ b/tests/test_transformers_chat_tools.py @@ -0,0 +1,163 @@ +"""Tests for the Transformers chat path's tool-call handling. + +These tests bypass the real HF pipeline by injecting a callable that returns +a canned generation, so they run offline and do not touch any model weights. 
+""" + +from __future__ import annotations + +import json +from typing import Any + +import pytest + +from modelship.infer.infer_config import RawRequestProxy, TransformersConfig +from modelship.infer.transformers.capabilities import TransformersCapabilities +from modelship.infer.transformers.openai.serving_chat import OpenAIServingChat +from modelship.openai.protocol import ChatCompletionRequest, ChatCompletionResponse + + +class _FakeTokenizer: + def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: + return [0] * len(text.split()) + + def apply_chat_template(self, messages: list[dict], **kwargs: Any) -> Any: + prompt = "\n".join(f"{m['role']}: {m.get('content', '')}" for m in messages) + if "tools" in kwargs: + prompt = f"[TOOLS:{len(kwargs['tools'])}]\n" + prompt + if kwargs.get("tokenize"): + return [0] * len(prompt.split()) + return prompt + + +class _FakePipeline: + """Stand-in for ``transformers.Pipeline`` that records calls and replays canned output.""" + + def __init__(self, generated_text: str): + self.tokenizer = _FakeTokenizer() + self.task = "text-generation" + self.generated_text = generated_text + self.last_input: Any = None + self.last_kwargs: dict[str, Any] = {} + + def __call__(self, inputs: Any, **kwargs: Any) -> list[dict]: + self.last_input = inputs + self.last_kwargs = kwargs + return [{"generated_text": self.generated_text}] + + +def _make_serving(generated: str) -> tuple[OpenAIServingChat, _FakePipeline]: + pipe = _FakePipeline(generated) + serving = OpenAIServingChat( + pipeline=pipe, # type: ignore[arg-type] + model_name="test-model", + config=TransformersConfig(), + capabilities=TransformersCapabilities(supports_image=False, supports_audio=False), + ) + return serving, pipe + + +def _raw_request() -> RawRequestProxy: + return RawRequestProxy(None, {}) + + +@pytest.mark.asyncio +async def test_response_without_tools_carries_content_only(): + serving, _ = _make_serving("hello there") + req = 
@pytest.mark.asyncio
async def test_response_without_tools_carries_content_only():
    serving, _ = _make_serving("hello there")
    req = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], stream=False)
    resp = await serving.create_chat_completion(req, _raw_request())

    assert isinstance(resp, ChatCompletionResponse)
    msg = resp.choices[0].message
    assert msg.content == "hello there"
    assert msg.tool_calls == []
    assert resp.choices[0].finish_reason == "stop"


@pytest.mark.asyncio
async def test_tools_only_render_when_requested():
    serving, pipe = _make_serving("hello")
    req = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], stream=False)
    await serving.create_chat_completion(req, _raw_request())

    # Without `tools` in the request, the pipeline receives the message list
    # directly — no pre-rendered prompt.
    assert isinstance(pipe.last_input, list)


@pytest.mark.asyncio
async def test_tools_in_request_pre_renders_prompt_and_parses_tool_call():
    # The canned generation must be wrapped in the Hermes markers — the
    # bare-JSON form (without <tool_call> tags) never parses as a tool call.
    raw = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
    serving, pipe = _make_serving(raw)
    req = ChatCompletionRequest(
        messages=[{"role": "user", "content": "weather in paris?"}],
        tools=[
            {
                "type": "function",
                "function": {"name": "get_weather", "parameters": {"type": "object"}},
            }
        ],
        tool_choice="auto",
        stream=False,
    )
    resp = await serving.create_chat_completion(req, _raw_request())

    assert isinstance(resp, ChatCompletionResponse)
    # Pre-rendered prompt is a string carrying the tool marker injected by our fake template.
    assert isinstance(pipe.last_input, str)
    assert pipe.last_input.startswith("[TOOLS:1]")

    msg = resp.choices[0].message
    assert msg.content is None
    assert len(msg.tool_calls) == 1
    assert msg.tool_calls[0].function.name == "get_weather"
    assert json.loads(msg.tool_calls[0].function.arguments) == {"city": "Paris"}
    assert resp.choices[0].finish_reason == "tool_calls"


@pytest.mark.asyncio
async def test_tool_choice_none_skips_tool_rendering():
    serving, pipe = _make_serving("regular reply")
    req = ChatCompletionRequest(
        messages=[{"role": "user", "content": "hi"}],
        tools=[{"type": "function", "function": {"name": "noop"}}],
        tool_choice="none",
        stream=False,
    )
    resp = await serving.create_chat_completion(req, _raw_request())

    # tool_choice="none" — pipeline should receive messages, not a rendered prompt.
    assert isinstance(pipe.last_input, list)
    assert isinstance(resp, ChatCompletionResponse)
    assert resp.choices[0].message.content == "regular reply"
    assert resp.choices[0].message.tool_calls == []
    assert resp.choices[0].finish_reason == "stop"


@pytest.mark.asyncio
async def test_tool_call_with_trailing_text_preserves_content():
    # Prose before a marked-up tool call must survive as `content`.
    raw = 'Calling now.\n<tool_call>{"name": "ping", "arguments": {}}</tool_call>'
    serving, _ = _make_serving(raw)
    req = ChatCompletionRequest(
        messages=[{"role": "user", "content": "ping?"}],
        tools=[{"type": "function", "function": {"name": "ping"}}],
        stream=False,
    )
    resp = await serving.create_chat_completion(req, _raw_request())

    assert isinstance(resp, ChatCompletionResponse)
    msg = resp.choices[0].message
    assert msg.content == "Calling now."
    assert len(msg.tool_calls) == 1
@pytest.mark.asyncio
async def test_unknown_parser_at_init_raises():
    """A bad ``tool_call_parser`` name must fail fast at construction time."""
    fake_pipe = _FakePipeline("anything")
    with pytest.raises(ValueError, match="unknown tool_call_parser"):
        OpenAIServingChat(
            pipeline=fake_pipe,  # type: ignore[arg-type]
            model_name="test-model",
            config=TransformersConfig(tool_call_parser="not-a-real-parser"),
            capabilities=TransformersCapabilities(supports_image=False, supports_audio=False),
        )
- Two existing parser tests updated where vLLM-style semantics differ from the old block-level parser (raw-bytes args passthrough; blocks with no extractable name silently dropped). - Integration: ``test_tool_calling_streaming_transformers_loader`` and ``test_tool_calling_streaming_vllm_loader`` exercise streaming + tool calling end to end through the gateway. The transformers test asserts the function name arrives in exactly one delta, arguments arrive in >= 2 deltas (the key invariant proving the diff loop is actually diffing rather than buffering), and the rebuilt args parse as JSON. --- docs/model-configuration.md | 20 ++ .../infer/transformers/openai/serving_chat.py | 110 +++----- modelship/openai/tool_calling/__init__.py | 3 +- .../openai/tool_calling/parsers/__init__.py | 4 +- modelship/openai/tool_calling/parsers/base.py | 261 +++++++++++++++++- .../openai/tool_calling/parsers/hermes.py | 61 ++-- tests/test_integration.py | 141 ++++++++++ tests/test_tool_calling.py | 141 +++++++++- 8 files changed, 608 insertions(+), 133 deletions(-) diff --git a/docs/model-configuration.md b/docs/model-configuration.md index 6ef64dd..3024ee7 100644 --- a/docs/model-configuration.md +++ b/docs/model-configuration.md @@ -219,6 +219,7 @@ The `transformers` loader uses PyTorch with HuggingFace Transformers. Supports c | `trust_remote_code` | bool | `false` | Allow remote code execution | | `model_kwargs` | object | `{}` | Extra keyword arguments passed to the model constructor | | `pipeline_kwargs` | object | `{}` | Extra keyword arguments passed to the pipeline at inference time | +| `tool_call_parser` | string | `hermes` | Parser used to turn raw model output into OpenAI `tool_calls`. Currently supported: `hermes` (Hermes-2-Pro / Qwen2.5-Instruct / many community fine-tunes that emit `{...}` markers). 
| ### Chat / Text Generation (CPU) @@ -233,6 +234,25 @@ models: device: "cpu" ``` +### Chat with Tool Calling (CPU) + +The transformers loader renders `tools` into the prompt via the model's chat +template and parses the output back into OpenAI `tool_calls`. The model must +have been trained on a Hermes-style tool format (Qwen2.5-Instruct, Hermes-2, +many community fine-tunes); the parser is selected via `tool_call_parser`. + +```yaml +models: + - name: qwen-tools + model: Qwen/Qwen2.5-0.5B-Instruct + usecase: generate + loader: transformers + num_cpus: 2 + transformers_config: + device: "cpu" + tool_call_parser: hermes # this is the default; shown for clarity +``` + ### Speech-to-Text (CPU) Audio is automatically decoded and resampled to the model's expected sample rate (e.g. 16kHz for Whisper). diff --git a/modelship/infer/transformers/openai/serving_chat.py b/modelship/infer/transformers/openai/serving_chat.py index 333e2d5..6cf4c94 100644 --- a/modelship/infer/transformers/openai/serving_chat.py +++ b/modelship/infer/transformers/openai/serving_chat.py @@ -24,7 +24,7 @@ UsageInfo, create_error_response, ) -from modelship.openai.tool_calling import ParsedToolCalls, get_parser, resolve_tools_for_request +from modelship.openai.tool_calling import ParsedToolCalls, ToolCallStreamer, get_parser, resolve_tools_for_request from modelship.utils import base_request_id logger = get_logger("infer.transformers.chat") @@ -214,8 +214,10 @@ async def _stream( if tools: prompt: Any = self._render_prompt(messages, tools) + tool_call_streamer: ToolCallStreamer | None = ToolCallStreamer(get_parser(self.config.tool_call_parser)) else: prompt = messages + tool_call_streamer = None thread = Thread( target=self.pipeline, @@ -231,55 +233,34 @@ async def _stream( accumulated: list[str] = [] try: # Per OpenAI spec, the first delta carries `role` only. 
- yield self._encode_chunk( - ChatCompletionStreamResponse( - id=request_id, - model=self.model_name, - choices=[ - ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(role="assistant"), - ) - ], - created=int(time.time()), - ) - ) + yield self._delta_chunk(request_id, DeltaMessage(role="assistant")) for text_chunk in streamer: if not text_chunk: continue accumulated.append(text_chunk) - # When tools are in play we cannot stream content incrementally - # without risking emitting fragments of a `` block as - # if they were assistant prose. Buffer until generation is done - # and emit the resolved shape as a single delta below. - if tools: + if tool_call_streamer is None: + yield self._delta_chunk(request_id, DeltaMessage(content=text_chunk)) + await asyncio.sleep(0) continue - yield self._encode_chunk( - ChatCompletionStreamResponse( - id=request_id, - model=self.model_name, - choices=[ - ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(content=text_chunk), - ) - ], - created=int(time.time()), - ) - ) + # Re-parse the cumulative text and emit any new content / tool-call + # fragments. Bounded held-back tail (length of the start marker) so + # the client never sees a half-formed `` opening tag. 
+ delta = tool_call_streamer.extract_streaming("".join(accumulated)) + if delta is not None: + yield self._delta_chunk(request_id, delta) await asyncio.sleep(0) - completion_text = "".join(accumulated) - completion_tokens = len(self.tokenizer.encode(completion_text, add_special_tokens=False)) - - if tools: - parsed = self._parse_tool_calls(completion_text) - async for chunk in self._emit_buffered_final_delta(request_id, parsed): - yield chunk + if tool_call_streamer is not None: + final = tool_call_streamer.finalize() + if final is not None: + yield self._delta_chunk(request_id, final) + parsed = tool_call_streamer.result else: - parsed = ParsedToolCalls(completion_text, []) + parsed = ParsedToolCalls("".join(accumulated), []) + completion_text = "".join(accumulated) + completion_tokens = len(self.tokenizer.encode(completion_text, add_special_tokens=False)) finish_reason = self._finish_reason(parsed, completion_tokens, max_tokens) yield self._encode_chunk( @@ -317,46 +298,17 @@ async def _stream( finally: thread.join() - async def _emit_buffered_final_delta(self, request_id: str, parsed: ParsedToolCalls) -> AsyncGenerator[str, None]: - if parsed.has_tool_calls: - from modelship.openai.protocol import DeltaFunctionCall, DeltaToolCall - - deltas = [ - DeltaToolCall( - index=i, - id=tc.id, - type="function", - function=DeltaFunctionCall(name=tc.function.name, arguments=tc.function.arguments), - ) - for i, tc in enumerate(parsed.tool_calls) - ] - yield self._encode_chunk( - ChatCompletionStreamResponse( - id=request_id, - model=self.model_name, - choices=[ - ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(tool_calls=deltas), - ) - ], - created=int(time.time()), - ) - ) - elif parsed.content: - yield self._encode_chunk( - ChatCompletionStreamResponse( - id=request_id, - model=self.model_name, - choices=[ - ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(content=parsed.content), - ) - ], - created=int(time.time()), - ) + def 
_delta_chunk(self, request_id: str, delta: DeltaMessage) -> str: + return self._encode_chunk( + ChatCompletionStreamResponse( + id=request_id, + model=self.model_name, + choices=[ + ChatCompletionResponseStreamChoice(index=0, delta=delta), + ], + created=int(time.time()), ) + ) @staticmethod def _encode_chunk(chunk: ChatCompletionStreamResponse) -> str: diff --git a/modelship/openai/tool_calling/__init__.py b/modelship/openai/tool_calling/__init__.py index 3a132fe..2e448ec 100644 --- a/modelship/openai/tool_calling/__init__.py +++ b/modelship/openai/tool_calling/__init__.py @@ -8,12 +8,13 @@ """ from modelship.openai.tool_calling.input import resolve_tools_for_request -from modelship.openai.tool_calling.parsers import ParsedToolCalls, ToolCallParser +from modelship.openai.tool_calling.parsers import ParsedToolCalls, ToolCallParser, ToolCallStreamer from modelship.openai.tool_calling.registry import available_parsers, get_parser, register_parser __all__ = [ "ParsedToolCalls", "ToolCallParser", + "ToolCallStreamer", "available_parsers", "get_parser", "register_parser", diff --git a/modelship/openai/tool_calling/parsers/__init__.py b/modelship/openai/tool_calling/parsers/__init__.py index ae63d47..1df7236 100644 --- a/modelship/openai/tool_calling/parsers/__init__.py +++ b/modelship/openai/tool_calling/parsers/__init__.py @@ -1,4 +1,4 @@ -from modelship.openai.tool_calling.parsers.base import ParsedToolCalls, ToolCallParser +from modelship.openai.tool_calling.parsers.base import ParsedToolCalls, ToolCallParser, ToolCallStreamer from modelship.openai.tool_calling.parsers.hermes import HermesToolCallParser -__all__ = ["HermesToolCallParser", "ParsedToolCalls", "ToolCallParser"] +__all__ = ["HermesToolCallParser", "ParsedToolCalls", "ToolCallParser", "ToolCallStreamer"] diff --git a/modelship/openai/tool_calling/parsers/base.py b/modelship/openai/tool_calling/parsers/base.py index bb26cb6..678ff23 100644 --- a/modelship/openai/tool_calling/parsers/base.py +++ 
class ToolCallParser(ABC):
    """Declarative description of one model family's tool-call wire format.

    A concrete parser supplies the literal ``start_marker`` / ``end_marker``
    pair that brackets each tool-call payload, plus two extractors that pick
    a function name and an arguments substring out of a (possibly truncated)
    payload. All streaming state lives in :class:`ToolCallStreamer`, so
    parser instances stay stateless and can be shared across requests,
    loaders, and tests.
    """

    # Registry key (e.g. "hermes") and the family's literal marker pair.
    name: str
    start_marker: str
    end_marker: str

    @abstractmethod
    def extract_partial_name(self, partial_payload: str) -> str | None:
        """Return the function name once it is fully visible, else ``None``.

        The streamer polls this on every delta until it returns a value; once
        the name has been shipped to the client it is not consulted again for
        that tool-call block and argument bytes start flowing.
        """

    @abstractmethod
    def extract_partial_args(self, partial_payload: str) -> str | None:
        """Return the arguments text the client may safely see so far.

        Successive returns are length-diffed by the streamer, which forwards
        only the new suffix. Implementations must hold back any trailing
        bytes that could belong to the payload envelope rather than the
        arguments value — otherwise those bytes leak into the argument
        stream and the client receives malformed JSON.
        """

    def parse(self, text: str) -> ParsedToolCalls:
        """Non-streaming entry point: drive the streamer over the full text.

        Running the exact same code path as streaming guarantees the two
        modes cannot drift apart.
        """
        run = ToolCallStreamer(self)
        run.extract_streaming(text)
        run.finalize()
        return run.result
+ - ``_sent_args`` per tool index — number of arguments chars already + shipped (the suffix-diff cursor). + - ``_finalized_calls`` — :class:`ToolCall` objects accumulated for blocks + that have closed, used to populate :attr:`result` for the non-streaming + path and for the final ``finish_reason``. + """ + + def __init__(self, parser: ToolCallParser): + self._parser = parser + self._start = parser.start_marker + self._end = parser.end_marker + self._sent_content_idx = 0 + self._sent_name: list[bool] = [] + self._sent_id: list[str] = [] + self._sent_args: list[str] = [] + self._finalized_calls: list[ToolCall] = [] + self._content_parts_len = 0 # bookkeeping for `result.content` + self._last_text = "" + + def extract_streaming(self, current_text: str) -> DeltaMessage | None: + """Run one diff pass against ``current_text`` and return any new deltas.""" + self._last_text = current_text + content_delta = self._emit_new_content(current_text, hold_marker_tail=True) + tool_call_deltas = self._emit_new_tool_call_fragments(current_text) + if not content_delta and not tool_call_deltas: + return None + return DeltaMessage(content=content_delta, tool_calls=tool_call_deltas) + + def finalize(self) -> DeltaMessage | None: + """Flush any held-back content tail once no more text is coming.""" + content_delta = self._emit_new_content(self._last_text, hold_marker_tail=False) + if content_delta is None: + return None + return DeltaMessage(content=content_delta) + + @property + def result(self) -> ParsedToolCalls: + """Final view, suitable for the non-streaming response shape.""" + # The content-view we accumulated as we streamed; reconstruct from the cursor. 
+ view = self._build_content_view(self._last_text, hold_marker_tail=False) + content = (view.strip() or None) if self._finalized_calls else (view or None) + return ParsedToolCalls(content=content, tool_calls=list(self._finalized_calls)) + + # ------------------------------------------------------------------ + # Content stream + # ------------------------------------------------------------------ + + def _emit_new_content(self, current_text: str, *, hold_marker_tail: bool) -> str | None: + view = self._build_content_view(current_text, hold_marker_tail=hold_marker_tail) + if len(view) <= self._sent_content_idx: + return None + new = view[self._sent_content_idx :] + self._sent_content_idx = len(view) + return new or None + + def _build_content_view(self, text: str, *, hold_marker_tail: bool) -> str: + """Build the content-stream view: the original text with tool-call regions excised. + + When ``hold_marker_tail`` is true (mid-stream), withhold a trailing + suffix that could be the start of a new ``start_marker``, so the + client never sees half of an opening tag. At finalize time we know + no more text is coming, so the held-back tail is safe to flush. + """ + parts: list[str] = [] + pos = 0 + while pos < len(text): + start = text.find(self._start, pos) + if start < 0: + remainder = text[pos:] + if hold_marker_tail: + safe = _safe_outside_flush_index(remainder, self._start) + parts.append(remainder[:safe]) + else: + parts.append(remainder) + break + parts.append(text[pos:start]) + payload_start = start + len(self._start) + end = text.find(self._end, payload_start) + if end < 0: + # Open block, not yet closed — nothing more to append to content. 
+ break + pos = end + len(self._end) + return "".join(parts) + + # ------------------------------------------------------------------ + # Tool-call fragments + # ------------------------------------------------------------------ + + def _emit_new_tool_call_fragments(self, current_text: str) -> list[DeltaToolCall]: + deltas: list[DeltaToolCall] = [] + for i, (payload, is_complete) in enumerate(self._iter_tool_call_blocks(current_text)): + self._ensure_slot(i) + + if not self._sent_name[i]: + name = self._parser.extract_partial_name(payload) + if name is None: + # Per OpenAI streaming convention the name is sent first; + # don't advance to a later block until this one has one. + break + tool_id = f"chatcmpl-tool-{random_uuid()}" + self._sent_name[i] = True + self._sent_id[i] = tool_id + deltas.append( + DeltaToolCall( + index=i, + id=tool_id, + type="function", + function=DeltaFunctionCall(name=name), + ) + ) + + args = self._parser.extract_partial_args(payload) + if args is not None and len(args) > len(self._sent_args[i]): + diff = args[len(self._sent_args[i]) :] + self._sent_args[i] = args + deltas.append( + DeltaToolCall( + index=i, + function=DeltaFunctionCall(arguments=diff), + ) + ) + + if is_complete and i == len(self._finalized_calls) and self._sent_name[i]: + self._finalized_calls.append( + ToolCall( + id=self._sent_id[i], + type="function", + function=FunctionCall( + name=self._extract_committed_name(i), + arguments=self._sent_args[i], + ), + ) + ) + return deltas + + def _ensure_slot(self, i: int) -> None: + while len(self._sent_name) <= i: + self._sent_name.append(False) + self._sent_id.append("") + self._sent_args.append("") + + def _extract_committed_name(self, i: int) -> str: + # We only stash the bool that the name was sent, not the value, so + # re-derive from the current payload (cheap, the regex is small). 
+ for j, (payload, _) in enumerate(self._iter_tool_call_blocks(self._last_text)): + if j == i: + name = self._parser.extract_partial_name(payload) + return name or "" + return "" + + def _iter_tool_call_blocks(self, text: str): + """Yield ``(partial_payload, is_complete)`` for each tool-call region in order. + + For the still-open final block, the partial payload has any tail + suffix that could be the start of ``end_marker`` withheld, so the + client never sees a fragment of the closing tag forwarded as + argument bytes. + """ + pos = 0 + while True: + start = text.find(self._start, pos) + if start < 0: + return + payload_start = start + len(self._start) + end = text.find(self._end, payload_start) + if end < 0: + partial = text[payload_start:] + safe = _safe_outside_flush_index(partial, self._end) + yield partial[:safe], False + return + yield text[payload_start:end], True + pos = end + len(self._end) + + +def _safe_outside_flush_index(buf: str, start_marker: str) -> int: + """Index up to which ``buf`` can be flushed without risking a split marker. + + ``buf`` is known not to contain the full marker. The unsafe tail is the + longest proper-prefix overlap between ``buf`` and ``start_marker``: if + the next chunk completes that prefix, we'd have to retract bytes we + already streamed. Holding them back avoids the retraction. 
+ """ + max_overlap = min(len(buf), len(start_marker) - 1) + for k in range(max_overlap, 0, -1): + if buf.endswith(start_marker[:k]): + return len(buf) - k + return len(buf) diff --git a/modelship/openai/tool_calling/parsers/hermes.py b/modelship/openai/tool_calling/parsers/hermes.py index 342722a..97487db 100644 --- a/modelship/openai/tool_calling/parsers/hermes.py +++ b/modelship/openai/tool_calling/parsers/hermes.py @@ -8,42 +8,37 @@ from __future__ import annotations -import json import re -from modelship.openai.protocol import FunctionCall, ToolCall -from modelship.openai.tool_calling.parsers.base import ParsedToolCalls, ToolCallParser - -_TOOL_CALL_RE = re.compile(r"(.*?)", re.DOTALL) +from modelship.openai.tool_calling.parsers.base import ToolCallParser class HermesToolCallParser(ToolCallParser): name = "hermes" - - def parse(self, text: str) -> ParsedToolCalls: - matches = list(_TOOL_CALL_RE.finditer(text)) - if not matches: - return ParsedToolCalls(content=text, tool_calls=[]) - - tool_calls: list[ToolCall] = [] - for m in matches: - try: - payload = json.loads(m.group(1).strip()) - except json.JSONDecodeError: - # Malformed block — leave the text in residual, skip this call. 
class HermesToolCallParser(ToolCallParser):
    """Hermes-family format: a JSON envelope between literal ``<tool_call>`` tags.

    Covers Hermes-2-Pro, Qwen2.5-Instruct, and community fine-tunes trained
    on the same convention. NOTE: the markers must be the literal tags —
    empty marker strings make every ``str.find`` match at position 0 and
    break block detection entirely.
    """

    name = "hermes"
    start_marker = "<tool_call>"
    end_marker = "</tool_call>"

    # Envelope shape: {"name": "<fn>", "arguments": <json>}
    _NAME_RE = re.compile(r'"name"\s*:\s*"([^"]+)"')
    _ARGS_RE = re.compile(r'"arguments"\s*:\s*')

    def extract_partial_name(self, partial_payload: str) -> str | None:
        """Return the quoted function name once fully visible, else ``None``."""
        m = self._NAME_RE.search(partial_payload)
        return m.group(1) if m else None

    def extract_partial_args(self, partial_payload: str) -> str | None:
        """Return the raw arguments bytes the client may safely see so far."""
        m = self._ARGS_RE.search(partial_payload)
        if m is None:
            return None
        args = partial_payload[m.end() :].rstrip()
        if args.endswith("}"):
            # The envelope closes with its own `}` *before* the </tool_call>
            # marker arrives, so any trailing `}` is ambiguous: it may close
            # the args object or the envelope. Withholding one trailing `}`
            # keeps the args stream well-formed: if more args bytes follow,
            # the held brace is recovered on the next pass; if the end marker
            # follows instead, the held brace was the envelope closer and
            # discarding it was correct.
            args = args[:-1].rstrip()
        return args or None
+ args = args[:-1].rstrip() + return args or None diff --git a/tests/test_integration.py b/tests/test_integration.py index 9433136..dc317bc 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,3 +1,4 @@ +import json import subprocess import time @@ -249,6 +250,146 @@ def test_tool_calling_transformers_loader(client): assert completion.choices[0].finish_reason == "tool_calls" +def _collect_streaming_tool_call(stream) -> dict: + """Drain an OpenAI streaming response and rebuild the assistant message. + + Returns a dict with: ``content`` (concatenated content deltas), + ``tool_calls`` (per-index dict of ``{id, name, arguments}`` — arguments + concatenated across all fragments), ``finish_reason``, ``name_deltas`` + and ``args_deltas`` (counts, used to assert that streaming was actually + incremental rather than a single buffered emission). + """ + content_parts: list[str] = [] + tool_calls: dict[int, dict] = {} + finish_reason: str | None = None + name_deltas = 0 + args_deltas = 0 + chunks_with_tool_calls = 0 + + for chunk in stream: + choice = chunk.choices[0] + delta = choice.delta + if delta.content: + content_parts.append(delta.content) + if delta.tool_calls: + chunks_with_tool_calls += 1 + for tc in delta.tool_calls: + slot = tool_calls.setdefault(tc.index, {"id": None, "name": None, "arguments": ""}) + if tc.id is not None: + slot["id"] = tc.id + if tc.function and tc.function.name: + slot["name"] = tc.function.name + name_deltas += 1 + if tc.function and tc.function.arguments: + slot["arguments"] += tc.function.arguments + args_deltas += 1 + if choice.finish_reason is not None: + finish_reason = choice.finish_reason + + return { + "content": "".join(content_parts), + "tool_calls": tool_calls, + "finish_reason": finish_reason, + "name_deltas": name_deltas, + "args_deltas": args_deltas, + "chunks_with_tool_calls": chunks_with_tool_calls, + } + + +@pytest.mark.integration +def test_tool_calling_streaming_transformers_loader(client): + 
"""Stream a tool call through the transformers loader and verify the + delta sequence matches the OpenAI streaming contract. + + Asserts: + - the function name arrives in exactly one delta; + - arguments arrive across multiple deltas (incremental, not buffered); + - concatenated arguments form valid JSON containing the expected key; + - the final delta carries ``finish_reason="tool_calls"``. + """ + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ] + stream = client.chat.completions.create( + model="chat-transformers", + messages=[{"role": "user", "content": "What is the weather in Paris?"}], + tools=tools, + tool_choice="auto", + max_tokens=128, + stream=True, + ) + + collected = _collect_streaming_tool_call(stream) + + assert collected["tool_calls"], f"expected at least one streamed tool call; got content={collected['content']!r}" + call_0 = collected["tool_calls"][0] + assert call_0["id"], "expected an id on the first tool-call delta" + assert call_0["name"] == "get_weather" + # Name must be sent exactly once (not on every delta). + assert collected["name_deltas"] == 1, f"expected one name delta, got {collected['name_deltas']}" + # Arguments must arrive incrementally across multiple deltas — that's the + # whole point of switching from block-level buffering to vLLM-style + # diff streaming. Exact count depends on the model, but it must be > 1. + assert collected["args_deltas"] >= 2, ( + f"expected arguments to stream incrementally, got {collected['args_deltas']} args delta(s)" + ) + # Concatenated args must form valid JSON containing the city. 
@pytest.mark.integration
def test_tool_calling_streaming_vllm_loader(client):
    """Smoke-test that vLLM streaming + tool calling still works through
    the gateway. vLLM emits its own per-token deltas; we only verify that
    the gateway forwards them and that the final assistant message rebuilds
    correctly.
    """
    weather_tool = {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
    stream = client.chat.completions.create(
        model="chat-capable",
        messages=[{"role": "user", "content": "What is the weather in Paris?"}],
        tools=[weather_tool],
        tool_choice="required",
        stream=True,
    )

    collected = _collect_streaming_tool_call(stream)

    assert collected["tool_calls"], "vLLM should have streamed at least one tool call"
    first_call = collected["tool_calls"][0]
    assert first_call["name"] == "get_weather"
    rebuilt_args = json.loads(first_call["arguments"])
    assert "Paris" in rebuilt_args.get("city", "")
    assert collected["finish_reason"] == "tool_calls"
ToolCallParser class TestRegistry: @@ -32,9 +33,14 @@ def test_unknown_parser_raises_with_available_list(self): def test_register_parser_makes_it_findable(self): class Stub(ToolCallParser): name = "stub-test-parser" + start_marker = "<<" + end_marker = ">>" - def parse(self, text: str) -> ParsedToolCalls: - return ParsedToolCalls(content=text, tool_calls=[]) + def extract_partial_name(self, partial_payload: str) -> str | None: + return None + + def extract_partial_args(self, partial_payload: str) -> str | None: + return None register_parser(Stub()) try: @@ -77,22 +83,27 @@ def test_tool_call_with_residual_text(self): assert len(result.tool_calls) == 1 assert result.content == "Sure, calling that." - def test_string_arguments_passed_through(self): + def test_string_arguments_forwarded_verbatim(self): + # vLLM-style: the streamer forwards the raw bytes of the arguments + # value as the model emitted them, including any surrounding quotes + # if the model wrapped its arguments in a JSON string literal. The + # OpenAI streaming contract treats `arguments` as an opaque string + # the client concatenates and parses. text = '{"name": "x", "arguments": "{\\"a\\": 1}"}' result = self.parser.parse(text) - assert result.tool_calls[0].function.arguments == '{"a": 1}' + assert result.tool_calls[0].function.arguments == '"{\\"a\\": 1}"' - def test_object_arguments_serialized_to_json(self): + def test_object_arguments_passed_through(self): text = '{"name": "x", "arguments": {"a": 1, "b": [2, 3]}}' result = self.parser.parse(text) assert json.loads(result.tool_calls[0].function.arguments) == {"a": 1, "b": [2, 3]} - def test_malformed_json_block_drops_call_and_falls_back_to_content(self): + def test_block_without_extractable_name_is_dropped(self): + # When the block contains nothing the name regex can hook onto we + # silently drop it — there is nothing to tell the client about. 
text = "{not valid json}" result = self.parser.parse(text) assert result.tool_calls == [] - # Malformed block stays in content as-is. - assert result.content == text def test_missing_name_drops_call(self): text = '{"arguments": {}}' @@ -112,6 +123,118 @@ def test_each_tool_call_gets_unique_id(self): assert result.tool_calls[0].id != result.tool_calls[1].id +class TestToolCallStreamer: + """Drive the streamer one chunk at a time and verify the deltas. + + The cumulative-text protocol matches what serving_chat does in production: + every fed string is the *full* generated text so far, not just the latest + delta. + """ + + def _feed(self, chunks: list[str]) -> tuple[ToolCallStreamer, list]: + streamer = ToolCallStreamer(HermesToolCallParser()) + deltas = [] + cumulative = "" + for chunk in chunks: + cumulative += chunk + d = streamer.extract_streaming(cumulative) + if d is not None: + deltas.append(d) + final = streamer.finalize() + if final is not None: + deltas.append(final) + return streamer, deltas + + def test_pure_content_streams_immediately(self): + _, deltas = self._feed(["Hello", " ", "world"]) + assert "".join(d.content or "" for d in deltas) == "Hello world" + assert all(not d.tool_calls for d in deltas) + + def test_holds_back_marker_prefix_in_content_until_finalize(self): + # `<` could be the first char of ``; the streamer must not + # ship it mid-stream. Once finalize() runs (no more text coming) the + # held tail is safe to flush as content. + streamer = ToolCallStreamer(HermesToolCallParser()) + mid = streamer.extract_streaming("before <") + assert mid is not None and mid.content == "before " + final = streamer.finalize() + assert final is not None and final.content == "<" + + def test_held_tail_flushes_when_disambiguated(self): + # `{"name": "get_weather", "arguments": {"city": "Paris"}}') + _, deltas = self._feed(chunks) + + tool_deltas = [tc for d in deltas for tc in d.tool_calls] + # First tool delta: name + id, no arguments. 
+ assert tool_deltas[0].function is not None + assert tool_deltas[0].function.name == "get_weather" + assert tool_deltas[0].id is not None + assert tool_deltas[0].function.arguments is None + # Subsequent deltas carry arguments fragments only (no name). + arg_deltas = tool_deltas[1:] + assert all(d.function and d.function.name is None for d in arg_deltas) + # Concatenated arguments form valid JSON. + joined_args = "".join(d.function.arguments or "" for d in arg_deltas) + assert json.loads(joined_args) == {"city": "Paris"} + + def test_arguments_stream_incrementally(self): + # Feed the args char-by-char; each char-after-name should generate + # an args delta of length 1 (or close to it). + prefix = '{"name": "ping", "arguments": ' + suffix = '{"x": 42}}' + _, deltas = self._feed([prefix, *list(suffix)]) + + arg_deltas = [tc for d in deltas for tc in d.tool_calls if tc.function and tc.function.arguments is not None] + assert len(arg_deltas) >= 3 # incremental, not one big shot + joined = "".join(d.function.arguments or "" for d in arg_deltas) + assert json.loads(joined) == {"x": 42} + + def test_multiple_tool_calls_get_distinct_indices(self): + text = ( + '{"name": "a", "arguments": {"x": 1}}' + '{"name": "b", "arguments": {"y": 2}}' + ) + _, deltas = self._feed([text]) + + tool_deltas = [tc for d in deltas for tc in d.tool_calls] + indices = {d.index for d in tool_deltas} + assert indices == {0, 1} + names = [d.function.name for d in tool_deltas if d.function and d.function.name] + assert names == ["a", "b"] + + def test_content_after_tool_call_resumes_streaming(self): + text = '{"name": "p", "arguments": {}} ok' + _, deltas = self._feed([text]) + joined_content = "".join(d.content or "" for d in deltas) + assert "ok" in joined_content + + def test_partial_name_held_until_closing_quote(self): + # While the model is mid-name (``"name": "get_wea`` so far), the + # streamer must NOT send a partial name — wait for the closing quote. 
+ streamer = ToolCallStreamer(HermesToolCallParser()) + partial = '{"name": "get_wea' + d = streamer.extract_streaming(partial) + # No name yet → no tool delta. + assert d is None or all(not tc.function or not tc.function.name for tc in d.tool_calls) + + def test_unterminated_block_is_finalized_without_crash(self): + # Model stops mid tool-call. The streamer should not raise on finalize + # and should not claim a finalized ToolCall (the call wasn't closed). + text = '{"name": "incomplete", "arguments": {"a": 1' + streamer, _ = self._feed([text]) + assert streamer.result.tool_calls == [] + + class TestResolveToolsForRequest: tools: ClassVar = [ {"type": "function", "function": {"name": "alpha"}}, From 68531f38789f0ad84403979ec65e4dd1ec593f1e Mon Sep 17 00:00:00 2001 From: Alex M Date: Mon, 4 May 2026 17:01:18 +0000 Subject: [PATCH 3/6] fix: make tool-call finalization robust to skipped blocks Replaces the fragile index-equals-length condition for finalizing tool-call blocks with a dedicated `_finalized_indices` set. This ensures valid blocks are correctly finalized even if preceding blocks are malformed and skipped by the streaming parser. Also allows the parser to continue processing subsequent blocks when a malformed complete block (missing a valid function name) is encountered. --- modelship/openai/tool_calling/parsers/base.py | 13 ++++++++++--- tests/test_tool_calling.py | 10 ++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/modelship/openai/tool_calling/parsers/base.py b/modelship/openai/tool_calling/parsers/base.py index 678ff23..c9bd771 100644 --- a/modelship/openai/tool_calling/parsers/base.py +++ b/modelship/openai/tool_calling/parsers/base.py @@ -98,6 +98,7 @@ class ToolCallStreamer: and id deltas have been emitted. - ``_sent_args`` per tool index — number of arguments chars already shipped (the suffix-diff cursor). + - ``_finalized_indices`` — set of block indices that have been finalized. 
- ``_finalized_calls`` — :class:`ToolCall` objects accumulated for blocks that have closed, used to populate :attr:`result` for the non-streaming path and for the final ``finish_reason``. @@ -111,6 +112,7 @@ def __init__(self, parser: ToolCallParser): self._sent_name: list[bool] = [] self._sent_id: list[str] = [] self._sent_args: list[str] = [] + self._finalized_indices: set[int] = set() self._finalized_calls: list[ToolCall] = [] self._content_parts_len = 0 # bookkeeping for `result.content` self._last_text = "" @@ -192,8 +194,12 @@ def _emit_new_tool_call_fragments(self, current_text: str) -> list[DeltaToolCall if not self._sent_name[i]: name = self._parser.extract_partial_name(payload) if name is None: - # Per OpenAI streaming convention the name is sent first; - # don't advance to a later block until this one has one. + if is_complete: + # This block is finished but has no extractable name (malformed). + # Skip it so we can potentially process later blocks. + continue + # Mid-stream and no name yet; per OpenAI convention we don't + # advance to later blocks until the current one has a name. 
break tool_id = f"chatcmpl-tool-{random_uuid()}" self._sent_name[i] = True @@ -218,7 +224,8 @@ def _emit_new_tool_call_fragments(self, current_text: str) -> list[DeltaToolCall ) ) - if is_complete and i == len(self._finalized_calls) and self._sent_name[i]: + if is_complete and i not in self._finalized_indices and self._sent_name[i]: + self._finalized_indices.add(i) self._finalized_calls.append( ToolCall( id=self._sent_id[i], diff --git a/tests/test_tool_calling.py b/tests/test_tool_calling.py index 6677e3e..98f50b1 100644 --- a/tests/test_tool_calling.py +++ b/tests/test_tool_calling.py @@ -234,6 +234,16 @@ def test_unterminated_block_is_finalized_without_crash(self): streamer, _ = self._feed([text]) assert streamer.result.tool_calls == [] + def test_skipped_malformed_block_does_not_block_subsequent_finalization(self): + # If the first tool-call block is malformed (missing name), it should be + # skipped, but the SECOND block (if valid) should still be finalized. + text = ( + '{"arguments": {"skipped": true}}' + '{"name": "valid", "arguments": {"x": 1}}' + ) + streamer, _ = self._feed([text]) + assert [tc.function.name for tc in streamer.result.tool_calls] == ["valid"] + class TestResolveToolsForRequest: tools: ClassVar = [ From 565a3a77d6a4021098bc990d6a6efada45c19370 Mon Sep 17 00:00:00 2001 From: Alex M Date: Mon, 4 May 2026 17:06:35 +0000 Subject: [PATCH 4/6] fix: maintain consistent created timestamp in chat streaming The OpenAI specification requires the `created` timestamp to remain consistent across all chunks in a streaming response. Previously, the transformers loader recalculated the timestamp for each chunk using `int(time.time())` inside `_delta_chunk` and for the final finish/usage chunks. Now, the timestamp is calculated once at the start of `_stream` and explicitly passed to all chunk generation functions. 
--- .../infer/transformers/openai/serving_chat.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/modelship/infer/transformers/openai/serving_chat.py b/modelship/infer/transformers/openai/serving_chat.py index 6cf4c94..21cf923 100644 --- a/modelship/infer/transformers/openai/serving_chat.py +++ b/modelship/infer/transformers/openai/serving_chat.py @@ -206,6 +206,7 @@ async def _stream( *, include_usage: bool, ) -> AsyncGenerator[str, None]: + created_time = int(time.time()) streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) # type: ignore[arg-type] kwargs = {**self.config.pipeline_kwargs} @@ -233,14 +234,14 @@ async def _stream( accumulated: list[str] = [] try: # Per OpenAI spec, the first delta carries `role` only. - yield self._delta_chunk(request_id, DeltaMessage(role="assistant")) + yield self._delta_chunk(request_id, DeltaMessage(role="assistant"), created_time) for text_chunk in streamer: if not text_chunk: continue accumulated.append(text_chunk) if tool_call_streamer is None: - yield self._delta_chunk(request_id, DeltaMessage(content=text_chunk)) + yield self._delta_chunk(request_id, DeltaMessage(content=text_chunk), created_time) await asyncio.sleep(0) continue # Re-parse the cumulative text and emit any new content / tool-call @@ -248,13 +249,13 @@ async def _stream( # the client never sees a half-formed `` opening tag. 
delta = tool_call_streamer.extract_streaming("".join(accumulated)) if delta is not None: - yield self._delta_chunk(request_id, delta) + yield self._delta_chunk(request_id, delta, created_time) await asyncio.sleep(0) if tool_call_streamer is not None: final = tool_call_streamer.finalize() if final is not None: - yield self._delta_chunk(request_id, final) + yield self._delta_chunk(request_id, final, created_time) parsed = tool_call_streamer.result else: parsed = ParsedToolCalls("".join(accumulated), []) @@ -274,7 +275,7 @@ async def _stream( finish_reason=finish_reason, ) ], - created=int(time.time()), + created=created_time, ) ) @@ -290,7 +291,7 @@ async def _stream( completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ), - created=int(time.time()), + created=created_time, ) ) @@ -298,7 +299,7 @@ async def _stream( finally: thread.join() - def _delta_chunk(self, request_id: str, delta: DeltaMessage) -> str: + def _delta_chunk(self, request_id: str, delta: DeltaMessage, created_time: int) -> str: return self._encode_chunk( ChatCompletionStreamResponse( id=request_id, @@ -306,7 +307,7 @@ def _delta_chunk(self, request_id: str, delta: DeltaMessage) -> str: choices=[ ChatCompletionResponseStreamChoice(index=0, delta=delta), ], - created=int(time.time()), + created=created_time, ) ) From 6f72596c96f318a096b6e24ce6dca971537a154d Mon Sep 17 00:00:00 2001 From: Alex M Date: Mon, 4 May 2026 17:07:58 +0000 Subject: [PATCH 5/6] chore: remove unused `_content_parts_len` attribute from ToolCallStreamer --- modelship/openai/tool_calling/parsers/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/modelship/openai/tool_calling/parsers/base.py b/modelship/openai/tool_calling/parsers/base.py index c9bd771..e8197ed 100644 --- a/modelship/openai/tool_calling/parsers/base.py +++ b/modelship/openai/tool_calling/parsers/base.py @@ -114,7 +114,6 @@ def __init__(self, parser: ToolCallParser): self._sent_args: list[str] = [] self._finalized_indices: 
set[int] = set() self._finalized_calls: list[ToolCall] = [] - self._content_parts_len = 0 # bookkeeping for `result.content` self._last_text = "" def extract_streaming(self, current_text: str) -> DeltaMessage | None: From d7e9d24b0bf05ccbae18c8885d905d13f6c378a6 Mon Sep 17 00:00:00 2001 From: Alex M Date: Mon, 4 May 2026 17:22:53 +0000 Subject: [PATCH 6/6] perf: optimize transformers stream complexity Eliminate the O(N^2) complexity caused by calling `"".join(accumulated)` inside the chunk-by-chunk stream loop. Now we maintain a running cumulative string `accumulated_str` that is built via fast appends rather than allocating and joining the entire array of previously yielded tokens on every new token. --- modelship/infer/transformers/openai/serving_chat.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/modelship/infer/transformers/openai/serving_chat.py b/modelship/infer/transformers/openai/serving_chat.py index 21cf923..81fd0d4 100644 --- a/modelship/infer/transformers/openai/serving_chat.py +++ b/modelship/infer/transformers/openai/serving_chat.py @@ -232,6 +232,7 @@ async def _stream( thread.start() accumulated: list[str] = [] + accumulated_str = "" try: # Per OpenAI spec, the first delta carries `role` only. yield self._delta_chunk(request_id, DeltaMessage(role="assistant"), created_time) @@ -240,6 +241,7 @@ async def _stream( if not text_chunk: continue accumulated.append(text_chunk) + accumulated_str += text_chunk if tool_call_streamer is None: yield self._delta_chunk(request_id, DeltaMessage(content=text_chunk), created_time) await asyncio.sleep(0) @@ -247,7 +249,7 @@ async def _stream( # Re-parse the cumulative text and emit any new content / tool-call # fragments. Bounded held-back tail (length of the start marker) so # the client never sees a half-formed `` opening tag. 
- delta = tool_call_streamer.extract_streaming("".join(accumulated)) + delta = tool_call_streamer.extract_streaming(accumulated_str) if delta is not None: yield self._delta_chunk(request_id, delta, created_time) await asyncio.sleep(0) @@ -258,10 +260,9 @@ async def _stream( yield self._delta_chunk(request_id, final, created_time) parsed = tool_call_streamer.result else: - parsed = ParsedToolCalls("".join(accumulated), []) + parsed = ParsedToolCalls(accumulated_str, []) - completion_text = "".join(accumulated) - completion_tokens = len(self.tokenizer.encode(completion_text, add_special_tokens=False)) + completion_tokens = len(self.tokenizer.encode(accumulated_str, add_special_tokens=False)) finish_reason = self._finish_reason(parsed, completion_tokens, max_tokens) yield self._encode_chunk(