diff --git a/docs/model-configuration.md b/docs/model-configuration.md
index 6ef64dd..3024ee7 100644
--- a/docs/model-configuration.md
+++ b/docs/model-configuration.md
@@ -219,6 +219,7 @@ The `transformers` loader uses PyTorch with HuggingFace Transformers. Supports c
| `trust_remote_code` | bool | `false` | Allow remote code execution |
| `model_kwargs` | object | `{}` | Extra keyword arguments passed to the model constructor |
| `pipeline_kwargs` | object | `{}` | Extra keyword arguments passed to the pipeline at inference time |
+| `tool_call_parser` | string | `hermes` | Parser used to turn raw model output into OpenAI `tool_calls`. Currently supported: `hermes` (Hermes-2-Pro / Qwen2.5-Instruct / many community fine-tunes that emit `<tool_call>{...}</tool_call>` markers). |
### Chat / Text Generation (CPU)
@@ -233,6 +234,25 @@ models:
device: "cpu"
```
+### Chat with Tool Calling (CPU)
+
+The transformers loader renders `tools` into the prompt via the model's chat
+template and parses the output back into OpenAI `tool_calls`. The model must
+have been trained on a Hermes-style tool format (Qwen2.5-Instruct, Hermes-2,
+many community fine-tunes); the parser is selected via `tool_call_parser`.
+
+```yaml
+models:
+ - name: qwen-tools
+ model: Qwen/Qwen2.5-0.5B-Instruct
+ usecase: generate
+ loader: transformers
+ num_cpus: 2
+ transformers_config:
+ device: "cpu"
+ tool_call_parser: hermes # this is the default; shown for clarity
+```
+
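+Once deployed, requests round-trip through any OpenAI-compatible client. A
+minimal sketch (the `http://localhost:8000/v1` base URL matches the project's
+integration tests; the `api_key` value is arbitrary when no auth is configured):
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
+completion = client.chat.completions.create(
+    model="qwen-tools",
+    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
+    tools=[{
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get weather for a city",
+            "parameters": {
+                "type": "object",
+                "properties": {"city": {"type": "string"}},
+                "required": ["city"],
+            },
+        },
+    }],
+)
+print(completion.choices[0].message.tool_calls)
+```
+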
### Speech-to-Text (CPU)
Audio is automatically decoded and resampled to the model's expected sample rate (e.g. 16kHz for Whisper).
diff --git a/modelship/infer/transformers/openai/serving_chat.py b/modelship/infer/transformers/openai/serving_chat.py
index 35f7d26..81fd0d4 100644
--- a/modelship/infer/transformers/openai/serving_chat.py
+++ b/modelship/infer/transformers/openai/serving_chat.py
@@ -3,6 +3,7 @@
import time
from collections.abc import AsyncGenerator
from threading import Thread
+from typing import Any
from transformers import Pipeline, PreTrainedTokenizerBase, TextIteratorStreamer
@@ -23,6 +24,7 @@
UsageInfo,
create_error_response,
)
+from modelship.openai.tool_calling import ParsedToolCalls, ToolCallStreamer, get_parser, resolve_tools_for_request
from modelship.utils import base_request_id
logger = get_logger("infer.transformers.chat")
@@ -45,12 +47,16 @@ def __init__(
assert pipeline.tokenizer is not None, "text-generation pipeline must have a tokenizer"
self.tokenizer: PreTrainedTokenizerBase = pipeline.tokenizer
self._lock = asyncio.Lock()
+ # Validate the configured parser at startup so misconfiguration surfaces
+ # before the first request rather than mid-generation.
+ get_parser(self.config.tool_call_parser)
async def warmup(self) -> None:
logger.info("Warming up chat model: %s", self.model_name)
await self.run_in_executor(
self._run,
[{"role": "user", "content": "warmup"}],
+ None,
1,
)
logger.info("Warmup chat done for %s", self.model_name)
@@ -74,26 +80,29 @@ async def create_chat_completion(
logger.warning("chat request %s rejected: %s", request_id, e)
return create_error_response(e)
+ tools = resolve_tools_for_request(request.tools, request.tool_choice)
+
max_tokens = request.max_tokens
if max_tokens is None and request.max_completion_tokens is not None:
max_tokens = request.max_completion_tokens
if request.stream:
include_usage = bool(request.stream_options and request.stream_options.include_usage)
- return self._locked_stream(request_id, messages, max_tokens, include_usage=include_usage)
+ return self._locked_stream(request_id, messages, tools, max_tokens, include_usage=include_usage)
async with self._lock:
try:
- result = await self.run_in_executor(self._run, messages, max_tokens)
+ result = await self.run_in_executor(self._run, messages, tools, max_tokens)
except Exception:
logger.exception("chat completion inference failed for %s", request_id)
return create_error_response("chat completion inference failed")
- prompt_tokens = self._count_prompt_tokens(messages)
- generated = result[0]["generated_text"]
- completion_text = generated[-1]["content"] if isinstance(generated, list) else generated
+ prompt_tokens = self._count_prompt_tokens(messages, tools)
+ completion_text = self._extract_completion_text(result)
completion_tokens = len(self.tokenizer.encode(completion_text, add_special_tokens=False))
- finish_reason = "length" if (max_tokens is not None and completion_tokens >= max_tokens) else "stop"
+
+ parsed = self._parse_tool_calls(completion_text) if tools else ParsedToolCalls(completion_text, [])
+ finish_reason = self._finish_reason(parsed, completion_tokens, max_tokens)
response = ChatCompletionResponse(
id=request_id,
@@ -101,7 +110,11 @@ async def create_chat_completion(
choices=[
ChatCompletionResponseChoice(
index=0,
- message=ChatMessage(role="assistant", content=completion_text),
+ message=ChatMessage(
+ role="assistant",
+ content=parsed.content,
+ tool_calls=parsed.tool_calls,
+ ),
finish_reason=finish_reason,
)
],
@@ -114,54 +127,102 @@ async def create_chat_completion(
)
logger.log(
TRACE,
- "chat response %s: text=%r, prompt_tokens=%d, completion_tokens=%d",
+ "chat response %s: text=%r, tool_calls=%d, prompt_tokens=%d, completion_tokens=%d",
request_id,
completion_text,
+ len(parsed.tool_calls),
prompt_tokens,
completion_tokens,
)
return response
- def _count_prompt_tokens(self, messages: list[dict]) -> int:
+ def _parse_tool_calls(self, text: str) -> ParsedToolCalls:
+ return get_parser(self.config.tool_call_parser).parse(text)
+
+ @staticmethod
+ def _finish_reason(parsed: ParsedToolCalls, completion_tokens: int, max_tokens: int | None) -> str:
+ if parsed.has_tool_calls:
+ return "tool_calls"
+ if max_tokens is not None and completion_tokens >= max_tokens:
+ return "length"
+ return "stop"
+
+ @staticmethod
+ def _extract_completion_text(result: list) -> str:
+ generated = result[0]["generated_text"]
+ if isinstance(generated, list):
+ return generated[-1]["content"]
+ return generated
+
+ def _count_prompt_tokens(self, messages: list[dict], tools: list[dict[str, Any]] | None) -> int:
# apply_chat_template returns a string by default (character count!) — force tokenize=True.
- token_ids = self.tokenizer.apply_chat_template(messages, tokenize=True)
+ kwargs: dict[str, Any] = {"tokenize": True}
+ if tools:
+ kwargs["tools"] = tools
+ token_ids = self.tokenizer.apply_chat_template(messages, **kwargs)
return len(token_ids)
- def _run(self, messages: list[dict], max_tokens: int | None) -> list:
+ def _render_prompt(self, messages: list[dict], tools: list[dict[str, Any]]) -> str:
+ rendered = self.tokenizer.apply_chat_template(
+ messages,
+ tools=tools, # type: ignore[arg-type]
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ assert isinstance(rendered, str), "apply_chat_template(tokenize=False) must return str"
+ return rendered
+
+ def _run(self, messages: list[dict], tools: list[dict[str, Any]] | None, max_tokens: int | None) -> list:
kwargs = {**self.config.pipeline_kwargs}
if max_tokens is not None:
kwargs["max_new_tokens"] = max_tokens
+ if tools:
+ # The standard text-generation pipeline does not forward `tools` to
+ # `apply_chat_template`, so we render the prompt ourselves and feed
+ # it as a plain string.
+ prompt = self._render_prompt(messages, tools)
+ return self.pipeline(prompt, return_full_text=False, **kwargs) # type: ignore[return-value]
return self.pipeline(messages, return_full_text=False, **kwargs) # type: ignore[return-value]
async def _locked_stream(
self,
request_id: str,
messages: list[dict],
+ tools: list[dict[str, Any]] | None,
max_tokens: int | None,
*,
include_usage: bool,
) -> AsyncGenerator[str, None]:
async with self._lock:
- async for chunk in self._stream(request_id, messages, max_tokens, include_usage=include_usage):
+ async for chunk in self._stream(request_id, messages, tools, max_tokens, include_usage=include_usage):
yield chunk
async def _stream(
self,
request_id: str,
messages: list[dict],
+ tools: list[dict[str, Any]] | None,
max_tokens: int | None,
*,
include_usage: bool,
) -> AsyncGenerator[str, None]:
+ created_time = int(time.time())
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) # type: ignore[arg-type]
kwargs = {**self.config.pipeline_kwargs}
if max_tokens is not None:
kwargs["max_new_tokens"] = max_tokens
+ if tools:
+ prompt: Any = self._render_prompt(messages, tools)
+ tool_call_streamer: ToolCallStreamer | None = ToolCallStreamer(get_parser(self.config.tool_call_parser))
+ else:
+ prompt = messages
+ tool_call_streamer = None
+
thread = Thread(
target=self.pipeline,
- args=(messages,),
+ args=(prompt,),
kwargs={
"streamer": streamer,
"return_full_text": False,
@@ -171,44 +232,38 @@ async def _stream(
thread.start()
accumulated: list[str] = []
+ accumulated_str = ""
try:
# Per OpenAI spec, the first delta carries `role` only.
- yield self._encode_chunk(
- ChatCompletionStreamResponse(
- id=request_id,
- model=self.model_name,
- choices=[
- ChatCompletionResponseStreamChoice(
- index=0,
- delta=DeltaMessage(role="assistant"),
- )
- ],
- created=int(time.time()),
- )
- )
+ yield self._delta_chunk(request_id, DeltaMessage(role="assistant"), created_time)
for text_chunk in streamer:
if not text_chunk:
continue
accumulated.append(text_chunk)
- yield self._encode_chunk(
- ChatCompletionStreamResponse(
- id=request_id,
- model=self.model_name,
- choices=[
- ChatCompletionResponseStreamChoice(
- index=0,
- delta=DeltaMessage(content=text_chunk),
- )
- ],
- created=int(time.time()),
- )
- )
+ accumulated_str += text_chunk
+ if tool_call_streamer is None:
+ yield self._delta_chunk(request_id, DeltaMessage(content=text_chunk), created_time)
+ await asyncio.sleep(0)
+ continue
+ # Re-parse the cumulative text and emit any new content / tool-call
+            # fragments. A bounded tail (at most the start marker's length) is
+            # held back so the client never sees a half-formed `<tool_call>`
+            # opening tag.
+ delta = tool_call_streamer.extract_streaming(accumulated_str)
+ if delta is not None:
+ yield self._delta_chunk(request_id, delta, created_time)
await asyncio.sleep(0)
- completion_text = "".join(accumulated)
- completion_tokens = len(self.tokenizer.encode(completion_text, add_special_tokens=False))
- finish_reason = "length" if (max_tokens is not None and completion_tokens >= max_tokens) else "stop"
+ if tool_call_streamer is not None:
+ final = tool_call_streamer.finalize()
+ if final is not None:
+ yield self._delta_chunk(request_id, final, created_time)
+ parsed = tool_call_streamer.result
+ else:
+ parsed = ParsedToolCalls(accumulated_str, [])
+
+ completion_tokens = len(self.tokenizer.encode(accumulated_str, add_special_tokens=False))
+ finish_reason = self._finish_reason(parsed, completion_tokens, max_tokens)
yield self._encode_chunk(
ChatCompletionStreamResponse(
@@ -221,12 +276,12 @@ async def _stream(
finish_reason=finish_reason,
)
],
- created=int(time.time()),
+ created=created_time,
)
)
if include_usage:
- prompt_tokens = self._count_prompt_tokens(messages)
+ prompt_tokens = self._count_prompt_tokens(messages, tools)
yield self._encode_chunk(
ChatCompletionStreamResponse(
id=request_id,
@@ -237,7 +292,7 @@ async def _stream(
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
- created=int(time.time()),
+ created=created_time,
)
)
@@ -245,6 +300,18 @@ async def _stream(
finally:
thread.join()
+ def _delta_chunk(self, request_id: str, delta: DeltaMessage, created_time: int) -> str:
+ return self._encode_chunk(
+ ChatCompletionStreamResponse(
+ id=request_id,
+ model=self.model_name,
+ choices=[
+ ChatCompletionResponseStreamChoice(index=0, delta=delta),
+ ],
+ created=created_time,
+ )
+ )
+
@staticmethod
def _encode_chunk(chunk: ChatCompletionStreamResponse) -> str:
return f"data: {json.dumps(chunk.model_dump(mode='json'))}\n\n"
diff --git a/modelship/openai/protocol.py b/modelship/openai/protocol.py
index d09afe7..d4142b1 100644
--- a/modelship/openai/protocol.py
+++ b/modelship/openai/protocol.py
@@ -180,7 +180,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature: float | None = None
top_p: float | None = None
tools: list[dict[str, Any]] | None = None
- tool_choice: str | dict[str, Any] | None = "none"
+ tool_choice: str | dict[str, Any] | None = None
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None
parallel_tool_calls: bool | None = True
user: str | None = None
diff --git a/modelship/openai/tool_calling/__init__.py b/modelship/openai/tool_calling/__init__.py
new file mode 100644
index 0000000..2e448ec
--- /dev/null
+++ b/modelship/openai/tool_calling/__init__.py
@@ -0,0 +1,22 @@
+"""Cross-loader tool-calling toolkit.
+
+Loaders without native tool-call support (Transformers today, plugin-wrapped
+raw-text engines tomorrow) use the parsers and helpers in this package to
+turn raw model output into OpenAI-shape ``tool_calls``. Loaders whose engines
+already emit structured tool calls (vLLM, llama.cpp via a function-calling
+chat handler) bypass it.
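+
+A sketch of the non-streaming path (everything here is defined in this
+package)::
+
+    from modelship.openai.tool_calling import get_parser
+
+    parsed = get_parser("hermes").parse(
+        '<tool_call>{"name": "ping", "arguments": {}}</tool_call>'
+    )
+    assert parsed.has_tool_calls
+    assert parsed.tool_calls[0].function.name == "ping"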
+"""
+
+from modelship.openai.tool_calling.input import resolve_tools_for_request
+from modelship.openai.tool_calling.parsers import ParsedToolCalls, ToolCallParser, ToolCallStreamer
+from modelship.openai.tool_calling.registry import available_parsers, get_parser, register_parser
+
+__all__ = [
+ "ParsedToolCalls",
+ "ToolCallParser",
+ "ToolCallStreamer",
+ "available_parsers",
+ "get_parser",
+ "register_parser",
+ "resolve_tools_for_request",
+]
diff --git a/modelship/openai/tool_calling/input.py b/modelship/openai/tool_calling/input.py
new file mode 100644
index 0000000..a245c0c
--- /dev/null
+++ b/modelship/openai/tool_calling/input.py
@@ -0,0 +1,63 @@
+"""Input-side helpers for tool calling.
+
+Loaders that hand a chat template a list of OpenAI messages plus a list of
+tool schemas use these helpers to interpret the request's ``tool_choice``
+(``"none"`` suppresses tools, ``"required"`` / specific-function downgrade
+to ``"auto"`` with a warning) and to validate that a parser exists for the
+configured family before generation starts.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+from modelship.logging import get_logger
+
+logger = get_logger("openai.tool_calling.input")
+
+
+def resolve_tools_for_request(
+ tools: list[dict[str, Any]] | None,
+ tool_choice: str | dict[str, Any] | None,
+) -> list[dict[str, Any]] | None:
+ """Apply OpenAI ``tool_choice`` semantics to the request's ``tools`` list.
+
+ Returns the list of tools to render into the prompt, or ``None`` when
+ the request should be served without any tool-calling affordance.
+
+ - ``tool_choice == "none"`` — suppress tools entirely.
+ - ``tool_choice == "auto"`` (or unset) — pass all tools through.
+ - ``tool_choice == "required"`` — pass tools through, log that we cannot
+ strictly enforce a tool call without constrained decoding.
+ - ``tool_choice == {"type": "function", "function": {"name": "X"}}`` —
+ filter ``tools`` to that single function and warn that the call cannot
+ be strictly enforced.
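+
+    A sketch of the mapping::
+
+        >>> tools = [{"type": "function", "function": {"name": "f"}}]
+        >>> resolve_tools_for_request(tools, "none") is None
+        True
+        >>> resolve_tools_for_request(tools, "auto") == tools
+        True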
+ """
+ if not tools:
+ return None
+ if tool_choice in (None, "auto"):
+ return tools
+ if tool_choice == "none":
+ return None
+ if tool_choice == "required":
+ logger.warning(
+ "tool_choice='required' requested but this loader cannot enforce a tool call; "
+ "passing all tools to the model and trusting it to call one"
+ )
+ return tools
+ if isinstance(tool_choice, dict):
+ fn = tool_choice.get("function") or {}
+ name = fn.get("name") if isinstance(fn, dict) else None
+ if isinstance(name, str) and name:
+ filtered = [t for t in tools if (t.get("function") or {}).get("name") == name]
+ if not filtered:
+ logger.warning("tool_choice names function %r which is not in the request's tools list", name)
+ return tools
+ logger.warning(
+ "tool_choice forcing function %r is not strictly enforced by this loader; "
+ "passing only that tool to the model",
+ name,
+ )
+ return filtered
+ logger.warning("unrecognized tool_choice value %r; falling back to 'auto' semantics", tool_choice)
+ return tools
diff --git a/modelship/openai/tool_calling/parsers/__init__.py b/modelship/openai/tool_calling/parsers/__init__.py
new file mode 100644
index 0000000..1df7236
--- /dev/null
+++ b/modelship/openai/tool_calling/parsers/__init__.py
@@ -0,0 +1,4 @@
+from modelship.openai.tool_calling.parsers.base import ParsedToolCalls, ToolCallParser, ToolCallStreamer
+from modelship.openai.tool_calling.parsers.hermes import HermesToolCallParser
+
+__all__ = ["HermesToolCallParser", "ParsedToolCalls", "ToolCallParser", "ToolCallStreamer"]
diff --git a/modelship/openai/tool_calling/parsers/base.py b/modelship/openai/tool_calling/parsers/base.py
new file mode 100644
index 0000000..e8197ed
--- /dev/null
+++ b/modelship/openai/tool_calling/parsers/base.py
@@ -0,0 +1,291 @@
+"""Base class for model-family-specific tool-call output parsers.
+
+Parsers in this codebase are *marker-based*: each family wraps tool-call
+JSON in a fixed pair of literal strings (``<tool_call>`` / ``</tool_call>``
+for Hermes, ``[TOOL_CALLS]`` / a closing token for Mistral, ...). A subclass
+declares the marker pair plus two small extractors that pick a function
+name and an arguments substring out of the (possibly partial) JSON between
+the markers. Both the streaming and non-streaming paths run the same
+:class:`ToolCallStreamer` so behavior cannot drift between them.
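+
+For orientation, a complete Hermes-family generation interleaving content and
+one call looks like::
+
+    Sure, checking.
+    <tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>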
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+
+from modelship.openai.protocol import (
+ DeltaFunctionCall,
+ DeltaMessage,
+ DeltaToolCall,
+ FunctionCall,
+ ToolCall,
+ random_uuid,
+)
+
+
+@dataclass(frozen=True)
+class ParsedToolCalls:
+ """Aggregate result of parsing a model's full chat-completion text.
+
+ ``content`` carries the residual non-tool-call text once any tool-call
+ markers are stripped. It is ``None`` when tool calls were extracted *and*
+ the residual is empty, matching OpenAI's behavior of nulling ``content``
+ alongside ``tool_calls``.
+ """
+
+ content: str | None
+ tool_calls: list[ToolCall]
+
+ @property
+ def has_tool_calls(self) -> bool:
+ return bool(self.tool_calls)
+
+
+class ToolCallParser(ABC):
+ """Family-specific knobs the streamer needs to drive its diff loop.
+
+ Subclasses set ``start_marker`` / ``end_marker`` and implement the two
+ extractors. They never touch streaming state — that lives on
+ :class:`ToolCallStreamer`, which can be instantiated once per request.
+ """
+
+ name: str
+ start_marker: str
+ end_marker: str
+
+ @abstractmethod
+ def extract_partial_name(self, partial_payload: str) -> str | None:
+ """Return the function name if a complete quoted name is visible yet, else ``None``.
+
+ Called every delta until it yields a non-``None`` result; once a name
+ has been emitted to the client, the streamer stops asking and starts
+ forwarding arguments bytes for that tool.
+ """
+
+ @abstractmethod
+ def extract_partial_args(self, partial_payload: str) -> str | None:
+ """Return the arguments substring as the client should see it so far.
+
+ The streamer takes a length-diff of successive returns and forwards
+ only the new bytes. Implementations must withhold any trailing bytes
+ that could plausibly be the envelope closer landing ahead of the
+ family's end marker — otherwise those bytes leak into the args
+ stream and the client receives malformed JSON.
+ """
+
+ def parse(self, text: str) -> ParsedToolCalls:
+ streamer = ToolCallStreamer(self)
+ streamer.extract_streaming(text)
+ streamer.finalize()
+ return streamer.result
+
+
+class ToolCallStreamer:
+ """Per-request, stateful tool-call extractor.
+
+ Mirrors vLLM's approach: hold a small amount of "what we've already sent"
+ state, re-parse the cumulative ``current_text`` on every delta, and diff.
+ Returns either a :class:`DeltaMessage` carrying any newly-emittable
+ content / tool-call fragments or ``None`` if there is nothing to send yet.
+
+ State held per request:
+
+ - ``_sent_content_idx`` — number of content-stream chars already shipped.
+ The "content stream" view is the original text with every (complete or
+ open) tool-call region excised.
+ - ``_sent_name`` / ``_sent_id`` per tool index — whether the function name
+ and id deltas have been emitted.
+ - ``_sent_args`` per tool index — number of arguments chars already
+ shipped (the suffix-diff cursor).
+ - ``_finalized_indices`` — set of block indices that have been finalized.
+ - ``_finalized_calls`` — :class:`ToolCall` objects accumulated for blocks
+ that have closed, used to populate :attr:`result` for the non-streaming
+ path and for the final ``finish_reason``.
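+
+    Driving it (a sketch; ``chunks`` stands in for a token-chunk iterator)::
+
+        streamer = ToolCallStreamer(parser)
+        text = ""
+        for chunk in chunks:
+            text += chunk
+            delta = streamer.extract_streaming(text)  # always the full text so far
+            if delta is not None:
+                ...  # forward the DeltaMessage to the client
+        tail = streamer.finalize()  # flush any held-back content tail
+        parsed = streamer.result    # ParsedToolCalls for the final finish_reason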
+ """
+
+ def __init__(self, parser: ToolCallParser):
+ self._parser = parser
+ self._start = parser.start_marker
+ self._end = parser.end_marker
+ self._sent_content_idx = 0
+ self._sent_name: list[bool] = []
+ self._sent_id: list[str] = []
+ self._sent_args: list[str] = []
+ self._finalized_indices: set[int] = set()
+ self._finalized_calls: list[ToolCall] = []
+ self._last_text = ""
+
+ def extract_streaming(self, current_text: str) -> DeltaMessage | None:
+ """Run one diff pass against ``current_text`` and return any new deltas."""
+ self._last_text = current_text
+ content_delta = self._emit_new_content(current_text, hold_marker_tail=True)
+ tool_call_deltas = self._emit_new_tool_call_fragments(current_text)
+ if not content_delta and not tool_call_deltas:
+ return None
+ return DeltaMessage(content=content_delta, tool_calls=tool_call_deltas)
+
+ def finalize(self) -> DeltaMessage | None:
+ """Flush any held-back content tail once no more text is coming."""
+ content_delta = self._emit_new_content(self._last_text, hold_marker_tail=False)
+ if content_delta is None:
+ return None
+ return DeltaMessage(content=content_delta)
+
+ @property
+ def result(self) -> ParsedToolCalls:
+ """Final view, suitable for the non-streaming response shape."""
+ # The content-view we accumulated as we streamed; reconstruct from the cursor.
+ view = self._build_content_view(self._last_text, hold_marker_tail=False)
+ content = (view.strip() or None) if self._finalized_calls else (view or None)
+ return ParsedToolCalls(content=content, tool_calls=list(self._finalized_calls))
+
+ # ------------------------------------------------------------------
+ # Content stream
+ # ------------------------------------------------------------------
+
+ def _emit_new_content(self, current_text: str, *, hold_marker_tail: bool) -> str | None:
+ view = self._build_content_view(current_text, hold_marker_tail=hold_marker_tail)
+ if len(view) <= self._sent_content_idx:
+ return None
+ new = view[self._sent_content_idx :]
+ self._sent_content_idx = len(view)
+ return new or None
+
+ def _build_content_view(self, text: str, *, hold_marker_tail: bool) -> str:
+ """Build the content-stream view: the original text with tool-call regions excised.
+
+ When ``hold_marker_tail`` is true (mid-stream), withhold a trailing
+ suffix that could be the start of a new ``start_marker``, so the
+ client never sees half of an opening tag. At finalize time we know
+ no more text is coming, so the held-back tail is safe to flush.
+ """
+ parts: list[str] = []
+ pos = 0
+ while pos < len(text):
+ start = text.find(self._start, pos)
+ if start < 0:
+ remainder = text[pos:]
+ if hold_marker_tail:
+ safe = _safe_outside_flush_index(remainder, self._start)
+ parts.append(remainder[:safe])
+ else:
+ parts.append(remainder)
+ break
+ parts.append(text[pos:start])
+ payload_start = start + len(self._start)
+ end = text.find(self._end, payload_start)
+ if end < 0:
+ # Open block, not yet closed — nothing more to append to content.
+ break
+ pos = end + len(self._end)
+ return "".join(parts)
+
+ # ------------------------------------------------------------------
+ # Tool-call fragments
+ # ------------------------------------------------------------------
+
+ def _emit_new_tool_call_fragments(self, current_text: str) -> list[DeltaToolCall]:
+ deltas: list[DeltaToolCall] = []
+ for i, (payload, is_complete) in enumerate(self._iter_tool_call_blocks(current_text)):
+ self._ensure_slot(i)
+
+ if not self._sent_name[i]:
+ name = self._parser.extract_partial_name(payload)
+ if name is None:
+ if is_complete:
+ # This block is finished but has no extractable name (malformed).
+ # Skip it so we can potentially process later blocks.
+ continue
+ # Mid-stream and no name yet; per OpenAI convention we don't
+ # advance to later blocks until the current one has a name.
+ break
+ tool_id = f"chatcmpl-tool-{random_uuid()}"
+ self._sent_name[i] = True
+ self._sent_id[i] = tool_id
+ deltas.append(
+ DeltaToolCall(
+ index=i,
+ id=tool_id,
+ type="function",
+ function=DeltaFunctionCall(name=name),
+ )
+ )
+
+ args = self._parser.extract_partial_args(payload)
+ if args is not None and len(args) > len(self._sent_args[i]):
+ diff = args[len(self._sent_args[i]) :]
+ self._sent_args[i] = args
+ deltas.append(
+ DeltaToolCall(
+ index=i,
+ function=DeltaFunctionCall(arguments=diff),
+ )
+ )
+
+ if is_complete and i not in self._finalized_indices and self._sent_name[i]:
+ self._finalized_indices.add(i)
+ self._finalized_calls.append(
+ ToolCall(
+ id=self._sent_id[i],
+ type="function",
+ function=FunctionCall(
+ name=self._extract_committed_name(i),
+ arguments=self._sent_args[i],
+ ),
+ )
+ )
+ return deltas
+
+ def _ensure_slot(self, i: int) -> None:
+ while len(self._sent_name) <= i:
+ self._sent_name.append(False)
+ self._sent_id.append("")
+ self._sent_args.append("")
+
+ def _extract_committed_name(self, i: int) -> str:
+ # We only stash the bool that the name was sent, not the value, so
+ # re-derive from the current payload (cheap, the regex is small).
+ for j, (payload, _) in enumerate(self._iter_tool_call_blocks(self._last_text)):
+ if j == i:
+ name = self._parser.extract_partial_name(payload)
+ return name or ""
+ return ""
+
+ def _iter_tool_call_blocks(self, text: str):
+ """Yield ``(partial_payload, is_complete)`` for each tool-call region in order.
+
+ For the still-open final block, the partial payload has any tail
+ suffix that could be the start of ``end_marker`` withheld, so the
+ client never sees a fragment of the closing tag forwarded as
+ argument bytes.
+ """
+ pos = 0
+ while True:
+ start = text.find(self._start, pos)
+ if start < 0:
+ return
+ payload_start = start + len(self._start)
+ end = text.find(self._end, payload_start)
+ if end < 0:
+ partial = text[payload_start:]
+ safe = _safe_outside_flush_index(partial, self._end)
+ yield partial[:safe], False
+ return
+ yield text[payload_start:end], True
+ pos = end + len(self._end)
+
+
+def _safe_outside_flush_index(buf: str, marker: str) -> int:
+    """Index up to which ``buf`` can be flushed without risking a split marker.
+
+    ``buf`` is known not to contain the full marker. The unsafe tail is the
+    longest proper-prefix overlap between ``buf`` and ``marker``: if the next
+    chunk completes that prefix, we'd have to retract bytes we already
+    streamed. Holding them back avoids the retraction. For example, with
+    ``buf = 'text <tool'`` and ``marker = '<tool_call>'`` the trailing
+    ``'<tool'`` is held back and the flush index is 5.
+
+    Called with the start marker for the content stream and with the end
+    marker for the still-open final tool-call block.
+    """
+    max_overlap = min(len(buf), len(marker) - 1)
+    for k in range(max_overlap, 0, -1):
+        if buf.endswith(marker[:k]):
+            return len(buf) - k
+    return len(buf)
diff --git a/modelship/openai/tool_calling/parsers/hermes.py b/modelship/openai/tool_calling/parsers/hermes.py
new file mode 100644
index 0000000..97487db
--- /dev/null
+++ b/modelship/openai/tool_calling/parsers/hermes.py
@@ -0,0 +1,44 @@
+"""Hermes-style ``{json}`` parser.
+
+Used by Hermes-2-Pro, Qwen2.5-Instruct, and a large family of NousResearch /
+community fine-tunes whose chat templates wrap each tool call in the literal
+tags ``<tool_call>`` / ``</tool_call>`` around a JSON object of the shape
+``{"name": "...", "arguments": {...}}``.
+"""
+
+from __future__ import annotations
+
+import re
+
+from modelship.openai.tool_calling.parsers.base import ToolCallParser
+
+
+class HermesToolCallParser(ToolCallParser):
+ name = "hermes"
+ start_marker = ""
+ end_marker = ""
+
+ _NAME_RE = re.compile(r'"name"\s*:\s*"([^"]+)"')
+ _ARGS_RE = re.compile(r'"arguments"\s*:\s*')
+
+ def extract_partial_name(self, partial_payload: str) -> str | None:
+ m = self._NAME_RE.search(partial_payload)
+ return m.group(1) if m else None
+
+ def extract_partial_args(self, partial_payload: str) -> str | None:
+ m = self._ARGS_RE.search(partial_payload)
+ if m is None:
+ return None
+ args = partial_payload[m.end() :].rstrip()
+ if args.endswith("}"):
+            # The block envelope is `{"name":"x","arguments":<args>}`. The
+            # closing brace of the envelope arrives in the byte stream before
+            # `</tool_call>` does, so we cannot tell whether any given
+            # trailing `}` belongs to the args object or to the envelope.
+            # Withholding one trailing `}` keeps the args stream well-formed:
+            # if the model goes on to emit more args bytes, the held brace is
+            # recovered on the next pass; if instead it goes on to emit
+            # `</tool_call>`, the held brace was the envelope closer and
+ # discarding it was correct.
+ args = args[:-1].rstrip()
+ return args or None
diff --git a/modelship/openai/tool_calling/registry.py b/modelship/openai/tool_calling/registry.py
new file mode 100644
index 0000000..68b8e28
--- /dev/null
+++ b/modelship/openai/tool_calling/registry.py
@@ -0,0 +1,34 @@
+"""Registry of named tool-call parsers, dispatched by configuration string.
+
+The registry is the single seam between a loader and the per-family parsers.
+Loaders that emit raw text (Transformers, plugin-wrapped engines) look up a
+parser by the name configured on the deployment and feed it the model's
+output. Loaders with native tool-call support (vLLM, llama.cpp via a
+function-calling chat handler) bypass the registry entirely.
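+
+Registering an additional family from plugin code (a sketch; ``MyParser`` is
+hypothetical)::
+
+    class MyParser(ToolCallParser):
+        name = "my-family"
+        start_marker = "[TOOL]"
+        end_marker = "[/TOOL]"
+
+        def extract_partial_name(self, partial_payload): ...
+        def extract_partial_args(self, partial_payload): ...
+
+    register_parser(MyParser())
+    assert "my-family" in available_parsers()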
+"""
+
+from __future__ import annotations
+
+from modelship.openai.tool_calling.parsers import HermesToolCallParser, ToolCallParser
+
+_PARSERS: dict[str, ToolCallParser] = {
+ HermesToolCallParser.name: HermesToolCallParser(),
+}
+
+
+def get_parser(name: str) -> ToolCallParser:
+ """Return the parser registered under ``name`` or raise ``ValueError``."""
+ try:
+ return _PARSERS[name]
+ except KeyError:
+ available = ", ".join(sorted(_PARSERS)) or "(none)"
+ raise ValueError(f"unknown tool_call_parser {name!r}; available: {available}") from None
+
+
+def available_parsers() -> list[str]:
+ return sorted(_PARSERS)
+
+
+def register_parser(parser: ToolCallParser) -> None:
+ """Register an additional parser. Intended for tests and plugin code."""
+ _PARSERS[parser.name] = parser
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 47f276c..dc317bc 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -1,3 +1,4 @@
+import json
import subprocess
import time
@@ -9,7 +10,14 @@
OPENAI_API_BASE = "http://localhost:8000/v1"
-EXPECTED_MODELS = {"chat-capable", "chat-limited", "embed-model", "stt-model", "tts-model"}
+EXPECTED_MODELS = {
+ "chat-capable",
+ "chat-limited",
+ "chat-transformers",
+ "embed-model",
+ "stt-model",
+ "tts-model",
+}
@pytest.fixture(scope="session")
@@ -45,6 +53,20 @@ def mship_cluster(tmp_path_factory):
"loader": "llama_cpp",
"num_cpus": 1,
},
+ {
+ # Same Qwen2.5-Instruct family as `chat-capable` so we exercise
+ # the transformers loader against a model trained to emit
+                # Hermes-style `<tool_call>{...}</tool_call>` markers.
+ "name": "chat-transformers",
+ "model": "Qwen/Qwen2.5-0.5B-Instruct",
+ "usecase": "generate",
+ "loader": "transformers",
+ "num_cpus": 2,
+ "transformers_config": {
+ "device": "cpu",
+ "torch_dtype": "float32",
+ },
+ },
{
"name": "embed-model",
"model": "nomic-ai/nomic-embed-text-v1.5",
@@ -140,6 +162,7 @@ def test_list_models(client):
model_ids = [m.id for m in models.data]
assert "chat-capable" in model_ids
assert "chat-limited" in model_ids
+ assert "chat-transformers" in model_ids
assert "embed-model" in model_ids
assert "stt-model" in model_ids
assert "tts-model" in model_ids
@@ -191,6 +214,182 @@ def test_tool_calling_success(client):
assert completion.choices[0].message.tool_calls[0].function.name == "get_weather"
+@pytest.mark.integration
+def test_tool_calling_transformers_loader(client):
+ """Round-trip a Hermes-style tool call through the transformers loader.
+
+ Uses the same Qwen2.5-0.5B-Instruct weights as the vLLM `chat-capable`
+ deployment but goes through the modelship-side tool-calling toolkit
+ (apply_chat_template(tools=...) on input, hermes parser on output).
+ """
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get weather for a city",
+ "parameters": {
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
+ },
+ }
+ ]
+ completion = client.chat.completions.create(
+ model="chat-transformers",
+ messages=[{"role": "user", "content": "What is the weather in Paris?"}],
+ tools=tools,
+ tool_choice="auto",
+ max_tokens=128,
+ )
+ tool_calls = completion.choices[0].message.tool_calls
+ assert tool_calls, f"expected a tool call, got content={completion.choices[0].message.content!r}"
+ assert tool_calls[0].function.name == "get_weather"
+ assert "Paris" in tool_calls[0].function.arguments
+ assert completion.choices[0].finish_reason == "tool_calls"
+
+
+def _collect_streaming_tool_call(stream) -> dict:
+ """Drain an OpenAI streaming response and rebuild the assistant message.
+
+ Returns a dict with: ``content`` (concatenated content deltas),
+ ``tool_calls`` (per-index dict of ``{id, name, arguments}`` — arguments
+ concatenated across all fragments), ``finish_reason``, ``name_deltas``
+ and ``args_deltas`` (counts, used to assert that streaming was actually
+ incremental rather than a single buffered emission).
+ """
+ content_parts: list[str] = []
+ tool_calls: dict[int, dict] = {}
+ finish_reason: str | None = None
+ name_deltas = 0
+ args_deltas = 0
+ chunks_with_tool_calls = 0
+
+ for chunk in stream:
+ choice = chunk.choices[0]
+ delta = choice.delta
+ if delta.content:
+ content_parts.append(delta.content)
+ if delta.tool_calls:
+ chunks_with_tool_calls += 1
+ for tc in delta.tool_calls:
+ slot = tool_calls.setdefault(tc.index, {"id": None, "name": None, "arguments": ""})
+ if tc.id is not None:
+ slot["id"] = tc.id
+ if tc.function and tc.function.name:
+ slot["name"] = tc.function.name
+ name_deltas += 1
+ if tc.function and tc.function.arguments:
+ slot["arguments"] += tc.function.arguments
+ args_deltas += 1
+ if choice.finish_reason is not None:
+ finish_reason = choice.finish_reason
+
+ return {
+ "content": "".join(content_parts),
+ "tool_calls": tool_calls,
+ "finish_reason": finish_reason,
+ "name_deltas": name_deltas,
+ "args_deltas": args_deltas,
+ "chunks_with_tool_calls": chunks_with_tool_calls,
+ }
+
+
+@pytest.mark.integration
+def test_tool_calling_streaming_transformers_loader(client):
+ """Stream a tool call through the transformers loader and verify the
+ delta sequence matches the OpenAI streaming contract.
+
+ Asserts:
+ - the function name arrives in exactly one delta;
+ - arguments arrive across multiple deltas (incremental, not buffered);
+ - concatenated arguments form valid JSON containing the expected key;
+ - the final delta carries ``finish_reason="tool_calls"``.
+ """
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get weather for a city",
+ "parameters": {
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
+ },
+ }
+ ]
+ stream = client.chat.completions.create(
+ model="chat-transformers",
+ messages=[{"role": "user", "content": "What is the weather in Paris?"}],
+ tools=tools,
+ tool_choice="auto",
+ max_tokens=128,
+ stream=True,
+ )
+
+ collected = _collect_streaming_tool_call(stream)
+
+ assert collected["tool_calls"], f"expected at least one streamed tool call; got content={collected['content']!r}"
+ call_0 = collected["tool_calls"][0]
+ assert call_0["id"], "expected an id on the first tool-call delta"
+ assert call_0["name"] == "get_weather"
+ # Name must be sent exactly once (not on every delta).
+ assert collected["name_deltas"] == 1, f"expected one name delta, got {collected['name_deltas']}"
+ # Arguments must arrive incrementally across multiple deltas — that's the
+ # whole point of switching from block-level buffering to vLLM-style
+ # diff streaming. Exact count depends on the model, but it must be > 1.
+ assert collected["args_deltas"] >= 2, (
+ f"expected arguments to stream incrementally, got {collected['args_deltas']} args delta(s)"
+ )
+ # Concatenated args must form valid JSON containing the city.
+ parsed_args = json.loads(call_0["arguments"])
+ assert parsed_args.get("city")
+ assert "Paris" in parsed_args["city"]
+ assert collected["finish_reason"] == "tool_calls"
+
+
+@pytest.mark.integration
+def test_tool_calling_streaming_vllm_loader(client):
+ """Smoke-test that vLLM streaming + tool calling still works through
+ the gateway. vLLM emits its own per-token deltas; we only verify that
+ the gateway forwards them and that the final assistant message rebuilds
+ correctly.
+ """
+ tools = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "description": "Get weather for a city",
+ "parameters": {
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
+ },
+ }
+ ]
+ stream = client.chat.completions.create(
+ model="chat-capable",
+ messages=[{"role": "user", "content": "What is the weather in Paris?"}],
+ tools=tools,
+ tool_choice="required",
+ stream=True,
+ )
+
+ collected = _collect_streaming_tool_call(stream)
+
+ assert collected["tool_calls"], "vLLM should have streamed at least one tool call"
+ call_0 = collected["tool_calls"][0]
+ assert call_0["name"] == "get_weather"
+ parsed_args = json.loads(call_0["arguments"])
+ assert "Paris" in parsed_args.get("city", "")
+ assert collected["finish_reason"] == "tool_calls"
+
+
@pytest.mark.integration
def test_tool_calling_unsupported_loader(client):
"""Verifies that loaders without tool support (like llama_cpp) don't return tool calls."""
diff --git a/tests/test_tool_calling.py b/tests/test_tool_calling.py
new file mode 100644
index 0000000..98f50b1
--- /dev/null
+++ b/tests/test_tool_calling.py
@@ -0,0 +1,286 @@
+"""Tests for the cross-loader tool-calling toolkit."""
+
+from __future__ import annotations
+
+import json
+from typing import ClassVar
+
+import pytest
+
+from modelship.openai.tool_calling import (
+ ToolCallStreamer,
+ available_parsers,
+ get_parser,
+ register_parser,
+ resolve_tools_for_request,
+)
+from modelship.openai.tool_calling.parsers import HermesToolCallParser, ToolCallParser
+
+
+class TestRegistry:
+ def test_default_registry_includes_hermes(self):
+ assert "hermes" in available_parsers()
+
+ def test_get_parser_returns_singleton(self):
+ a = get_parser("hermes")
+ b = get_parser("hermes")
+ assert a is b
+
+ def test_unknown_parser_raises_with_available_list(self):
+ with pytest.raises(ValueError, match="hermes"):
+ get_parser("does-not-exist")
+
+ def test_register_parser_makes_it_findable(self):
+ class Stub(ToolCallParser):
+ name = "stub-test-parser"
+ start_marker = "<<"
+ end_marker = ">>"
+
+ def extract_partial_name(self, partial_payload: str) -> str | None:
+ return None
+
+ def extract_partial_args(self, partial_payload: str) -> str | None:
+ return None
+
+ register_parser(Stub())
+ try:
+ assert get_parser("stub-test-parser").name == "stub-test-parser"
+ finally:
+ # Clean up so other tests don't see the stub.
+ from modelship.openai.tool_calling import registry
+
+ registry._PARSERS.pop("stub-test-parser", None)
+
+
+class TestHermesParser:
+ parser = HermesToolCallParser()
+
+ def test_no_tool_calls_returns_text_unchanged(self):
+ result = self.parser.parse("just a regular response")
+ assert result.tool_calls == []
+ assert result.content == "just a regular response"
+ assert result.has_tool_calls is False
+
+ def test_single_tool_call(self):
+ text = '{"name": "get_weather", "arguments": {"city": "Paris"}}'
+ result = self.parser.parse(text)
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "get_weather"
+ assert json.loads(result.tool_calls[0].function.arguments) == {"city": "Paris"}
+ assert result.content is None
+
+ def test_multiple_tool_calls(self):
+ text = (
+ '{"name": "a", "arguments": {"x": 1}}'
+ '{"name": "b", "arguments": {"y": 2}}'
+ )
+ result = self.parser.parse(text)
+ assert [tc.function.name for tc in result.tool_calls] == ["a", "b"]
+
+ def test_tool_call_with_residual_text(self):
+        text = 'Sure, calling that.\n<tool_call>{"name": "ping", "arguments": {}}</tool_call>'
+ result = self.parser.parse(text)
+ assert len(result.tool_calls) == 1
+ assert result.content == "Sure, calling that."
+
+ def test_string_arguments_forwarded_verbatim(self):
+ # vLLM-style: the streamer forwards the raw bytes of the arguments
+ # value as the model emitted them, including any surrounding quotes
+ # if the model wrapped its arguments in a JSON string literal. The
+ # OpenAI streaming contract treats `arguments` as an opaque string
+ # the client concatenates and parses.
+ text = '{"name": "x", "arguments": "{\\"a\\": 1}"}'
+ result = self.parser.parse(text)
+ assert result.tool_calls[0].function.arguments == '"{\\"a\\": 1}"'
+
+ def test_object_arguments_passed_through(self):
+ text = '{"name": "x", "arguments": {"a": 1, "b": [2, 3]}}'
+ result = self.parser.parse(text)
+ assert json.loads(result.tool_calls[0].function.arguments) == {"a": 1, "b": [2, 3]}
+
+ def test_block_without_extractable_name_is_dropped(self):
+        # When the block contains nothing the name regex can hook onto, we
+        # silently drop it — there is nothing to tell the client about.
+        text = "<tool_call>{not valid json}</tool_call>"
+ result = self.parser.parse(text)
+ assert result.tool_calls == []
+
+ def test_missing_name_drops_call(self):
+ text = '{"arguments": {}}'
+ result = self.parser.parse(text)
+ assert result.tool_calls == []
+
+ def test_empty_name_drops_call(self):
+ text = '{"name": "", "arguments": {}}'
+ result = self.parser.parse(text)
+ assert result.tool_calls == []
+
+ def test_each_tool_call_gets_unique_id(self):
+ text = (
+ '{"name": "a", "arguments": {}}{"name": "b", "arguments": {}}'
+ )
+ result = self.parser.parse(text)
+ assert result.tool_calls[0].id != result.tool_calls[1].id
+
+
+class TestToolCallStreamer:
+ """Drive the streamer one chunk at a time and verify the deltas.
+
+ The cumulative-text protocol matches what serving_chat does in production:
+ every fed string is the *full* generated text so far, not just the latest
+ delta.
+ """
+
+ def _feed(self, chunks: list[str]) -> tuple[ToolCallStreamer, list]:
+ streamer = ToolCallStreamer(HermesToolCallParser())
+ deltas = []
+ cumulative = ""
+ for chunk in chunks:
+ cumulative += chunk
+ d = streamer.extract_streaming(cumulative)
+ if d is not None:
+ deltas.append(d)
+ final = streamer.finalize()
+ if final is not None:
+ deltas.append(final)
+ return streamer, deltas
+
+ def test_pure_content_streams_immediately(self):
+ _, deltas = self._feed(["Hello", " ", "world"])
+ assert "".join(d.content or "" for d in deltas) == "Hello world"
+ assert all(not d.tool_calls for d in deltas)
+
+ def test_holds_back_marker_prefix_in_content_until_finalize(self):
+        # `<` could be the first char of `<tool_call>`; the streamer must not
+ # ship it mid-stream. Once finalize() runs (no more text coming) the
+ # held tail is safe to flush as content.
+ streamer = ToolCallStreamer(HermesToolCallParser())
+ mid = streamer.extract_streaming("before <")
+ assert mid is not None and mid.content == "before "
+ final = streamer.finalize()
+ assert final is not None and final.content == "<"
+
+    def test_held_tail_flushes_when_disambiguated(self):
+        # `<` is held back as a possible `<tool_call>` prefix; once a later
+        # chunk rules the marker out, the held tail flushes as content.
+        _, deltas = self._feed(["before <", "x"])
+        assert "".join(d.content or "" for d in deltas) == "before <x"
+
+    def test_name_and_id_stream_before_arguments(self):
+        chunks = list('<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>')
+        _, deltas = self._feed(chunks)
+
+ tool_deltas = [tc for d in deltas for tc in d.tool_calls]
+ # First tool delta: name + id, no arguments.
+ assert tool_deltas[0].function is not None
+ assert tool_deltas[0].function.name == "get_weather"
+ assert tool_deltas[0].id is not None
+ assert tool_deltas[0].function.arguments is None
+ # Subsequent deltas carry arguments fragments only (no name).
+ arg_deltas = tool_deltas[1:]
+ assert all(d.function and d.function.name is None for d in arg_deltas)
+ # Concatenated arguments form valid JSON.
+ joined_args = "".join(d.function.arguments or "" for d in arg_deltas)
+ assert json.loads(joined_args) == {"city": "Paris"}
+
+ def test_arguments_stream_incrementally(self):
+ # Feed the args char-by-char; each char-after-name should generate
+ # an args delta of length 1 (or close to it).
+ prefix = '{"name": "ping", "arguments": '
+ suffix = '{"x": 42}}'
+ _, deltas = self._feed([prefix, *list(suffix)])
+
+ arg_deltas = [tc for d in deltas for tc in d.tool_calls if tc.function and tc.function.arguments is not None]
+ assert len(arg_deltas) >= 3 # incremental, not one big shot
+ joined = "".join(d.function.arguments or "" for d in arg_deltas)
+ assert json.loads(joined) == {"x": 42}
+
+ def test_multiple_tool_calls_get_distinct_indices(self):
+ text = (
+ '{"name": "a", "arguments": {"x": 1}}'
+ '{"name": "b", "arguments": {"y": 2}}'
+ )
+ _, deltas = self._feed([text])
+
+ tool_deltas = [tc for d in deltas for tc in d.tool_calls]
+ indices = {d.index for d in tool_deltas}
+ assert indices == {0, 1}
+ names = [d.function.name for d in tool_deltas if d.function and d.function.name]
+ assert names == ["a", "b"]
+
+ def test_content_after_tool_call_resumes_streaming(self):
+ text = '{"name": "p", "arguments": {}} ok'
+ _, deltas = self._feed([text])
+ joined_content = "".join(d.content or "" for d in deltas)
+ assert "ok" in joined_content
+
+ def test_partial_name_held_until_closing_quote(self):
+ # While the model is mid-name (``"name": "get_wea`` so far), the
+ # streamer must NOT send a partial name — wait for the closing quote.
+ streamer = ToolCallStreamer(HermesToolCallParser())
+ partial = '{"name": "get_wea'
+ d = streamer.extract_streaming(partial)
+ # No name yet → no tool delta.
+ assert d is None or all(not tc.function or not tc.function.name for tc in d.tool_calls)
+
+ def test_unterminated_block_is_finalized_without_crash(self):
+ # Model stops mid tool-call. The streamer should not raise on finalize
+ # and should not claim a finalized ToolCall (the call wasn't closed).
+ text = '{"name": "incomplete", "arguments": {"a": 1'
+ streamer, _ = self._feed([text])
+ assert streamer.result.tool_calls == []
+
+ def test_skipped_malformed_block_does_not_block_subsequent_finalization(self):
+ # If the first tool-call block is malformed (missing name), it should be
+ # skipped, but the SECOND block (if valid) should still be finalized.
+ text = (
+ '{"arguments": {"skipped": true}}'
+ '{"name": "valid", "arguments": {"x": 1}}'
+ )
+ streamer, _ = self._feed([text])
+ assert [tc.function.name for tc in streamer.result.tool_calls] == ["valid"]
+
+
+class TestResolveToolsForRequest:
+ tools: ClassVar = [
+ {"type": "function", "function": {"name": "alpha"}},
+ {"type": "function", "function": {"name": "beta"}},
+ ]
+
+ def test_no_tools_returns_none(self):
+ assert resolve_tools_for_request(None, "auto") is None
+ assert resolve_tools_for_request([], "auto") is None
+
+ def test_auto_passes_through(self):
+ assert resolve_tools_for_request(self.tools, "auto") == self.tools
+
+ def test_unset_tool_choice_passes_through(self):
+ assert resolve_tools_for_request(self.tools, None) == self.tools
+
+ def test_none_suppresses_tools(self):
+ assert resolve_tools_for_request(self.tools, "none") is None
+
+ def test_required_passes_through(self):
+ # We cannot strictly enforce a tool call without constrained decoding,
+ # so "required" downgrades to "auto" semantics (with a logged warning).
+ assert resolve_tools_for_request(self.tools, "required") == self.tools
+
+ def test_specific_function_filters_to_that_tool(self):
+ result = resolve_tools_for_request(self.tools, {"type": "function", "function": {"name": "beta"}})
+ assert result is not None
+ assert len(result) == 1
+ assert result[0]["function"]["name"] == "beta"
+
+ def test_unknown_function_falls_back_to_all(self):
+ # If the named function isn't in the tools list, fall back to passing
+ # them all through rather than emitting an empty list.
+ result = resolve_tools_for_request(self.tools, {"type": "function", "function": {"name": "missing"}})
+ assert result == self.tools
+
+ def test_unrecognized_choice_falls_back_to_all(self):
+ result = resolve_tools_for_request(self.tools, "weird-mode")
+ assert result == self.tools
diff --git a/tests/test_transformers_chat_tools.py b/tests/test_transformers_chat_tools.py
new file mode 100644
index 0000000..003a909
--- /dev/null
+++ b/tests/test_transformers_chat_tools.py
@@ -0,0 +1,163 @@
+"""Tests for the Transformers chat path's tool-call handling.
+
+These tests bypass the real HF pipeline by injecting a callable that returns
+a canned generation, so they run offline and do not touch any model weights.
+"""
+
+from __future__ import annotations
+
+import json
+from typing import Any
+
+import pytest
+
+from modelship.infer.infer_config import RawRequestProxy, TransformersConfig
+from modelship.infer.transformers.capabilities import TransformersCapabilities
+from modelship.infer.transformers.openai.serving_chat import OpenAIServingChat
+from modelship.openai.protocol import ChatCompletionRequest, ChatCompletionResponse
+
+
+class _FakeTokenizer:
+ def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
+ return [0] * len(text.split())
+
+ def apply_chat_template(self, messages: list[dict], **kwargs: Any) -> Any:
+ prompt = "\n".join(f"{m['role']}: {m.get('content', '')}" for m in messages)
+ if "tools" in kwargs:
+ prompt = f"[TOOLS:{len(kwargs['tools'])}]\n" + prompt
+ if kwargs.get("tokenize"):
+ return [0] * len(prompt.split())
+ return prompt
+
+
+class _FakePipeline:
+ """Stand-in for ``transformers.Pipeline`` that records calls and replays canned output."""
+
+ def __init__(self, generated_text: str):
+ self.tokenizer = _FakeTokenizer()
+ self.task = "text-generation"
+ self.generated_text = generated_text
+ self.last_input: Any = None
+ self.last_kwargs: dict[str, Any] = {}
+
+ def __call__(self, inputs: Any, **kwargs: Any) -> list[dict]:
+ self.last_input = inputs
+ self.last_kwargs = kwargs
+ return [{"generated_text": self.generated_text}]
+
+
+def _make_serving(generated: str) -> tuple[OpenAIServingChat, _FakePipeline]:
+ pipe = _FakePipeline(generated)
+ serving = OpenAIServingChat(
+ pipeline=pipe, # type: ignore[arg-type]
+ model_name="test-model",
+ config=TransformersConfig(),
+ capabilities=TransformersCapabilities(supports_image=False, supports_audio=False),
+ )
+ return serving, pipe
+
+
+def _raw_request() -> RawRequestProxy:
+ return RawRequestProxy(None, {})
+
+
+@pytest.mark.asyncio
+async def test_response_without_tools_carries_content_only():
+ serving, _ = _make_serving("hello there")
+ req = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], stream=False)
+ resp = await serving.create_chat_completion(req, _raw_request())
+
+ assert isinstance(resp, ChatCompletionResponse)
+ msg = resp.choices[0].message
+ assert msg.content == "hello there"
+ assert msg.tool_calls == []
+ assert resp.choices[0].finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_tools_only_render_when_requested():
+ serving, pipe = _make_serving("hello")
+ req = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], stream=False)
+ await serving.create_chat_completion(req, _raw_request())
+
+ # Without `tools` in the request, the pipeline receives the message list
+ # directly — no pre-rendered prompt.
+ assert isinstance(pipe.last_input, list)
+
+
+@pytest.mark.asyncio
+async def test_tools_in_request_pre_renders_prompt_and_parses_tool_call():
+ raw = '{"name": "get_weather", "arguments": {"city": "Paris"}}'
+ serving, pipe = _make_serving(raw)
+ req = ChatCompletionRequest(
+ messages=[{"role": "user", "content": "weather in paris?"}],
+ tools=[
+ {
+ "type": "function",
+ "function": {"name": "get_weather", "parameters": {"type": "object"}},
+ }
+ ],
+ tool_choice="auto",
+ stream=False,
+ )
+ resp = await serving.create_chat_completion(req, _raw_request())
+
+ assert isinstance(resp, ChatCompletionResponse)
+ # Pre-rendered prompt is a string carrying the tool marker injected by our fake template.
+ assert isinstance(pipe.last_input, str)
+ assert pipe.last_input.startswith("[TOOLS:1]")
+
+ msg = resp.choices[0].message
+ assert msg.content is None
+ assert len(msg.tool_calls) == 1
+ assert msg.tool_calls[0].function.name == "get_weather"
+ assert json.loads(msg.tool_calls[0].function.arguments) == {"city": "Paris"}
+ assert resp.choices[0].finish_reason == "tool_calls"
+
+
+@pytest.mark.asyncio
+async def test_tool_choice_none_skips_tool_rendering():
+ serving, pipe = _make_serving("regular reply")
+ req = ChatCompletionRequest(
+ messages=[{"role": "user", "content": "hi"}],
+ tools=[{"type": "function", "function": {"name": "noop"}}],
+ tool_choice="none",
+ stream=False,
+ )
+ resp = await serving.create_chat_completion(req, _raw_request())
+
+ # tool_choice="none" — pipeline should receive messages, not a rendered prompt.
+ assert isinstance(pipe.last_input, list)
+ assert isinstance(resp, ChatCompletionResponse)
+ assert resp.choices[0].message.content == "regular reply"
+ assert resp.choices[0].message.tool_calls == []
+ assert resp.choices[0].finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_tool_call_with_trailing_text_preserves_content():
+    raw = 'Calling now.\n<tool_call>{"name": "ping", "arguments": {}}</tool_call>'
+ serving, _ = _make_serving(raw)
+ req = ChatCompletionRequest(
+ messages=[{"role": "user", "content": "ping?"}],
+ tools=[{"type": "function", "function": {"name": "ping"}}],
+ stream=False,
+ )
+ resp = await serving.create_chat_completion(req, _raw_request())
+
+ assert isinstance(resp, ChatCompletionResponse)
+ msg = resp.choices[0].message
+ assert msg.content == "Calling now."
+ assert len(msg.tool_calls) == 1
+
+
+@pytest.mark.asyncio
+async def test_unknown_parser_at_init_raises():
+ pipe = _FakePipeline("anything")
+ with pytest.raises(ValueError, match="unknown tool_call_parser"):
+ OpenAIServingChat(
+ pipeline=pipe, # type: ignore[arg-type]
+ model_name="test-model",
+ config=TransformersConfig(tool_call_parser="not-a-real-parser"),
+ capabilities=TransformersCapabilities(supports_image=False, supports_audio=False),
+ )