diff --git a/docs/model-configuration.md b/docs/model-configuration.md index 6ef64dd..3024ee7 100644 --- a/docs/model-configuration.md +++ b/docs/model-configuration.md @@ -219,6 +219,7 @@ The `transformers` loader uses PyTorch with HuggingFace Transformers. Supports c | `trust_remote_code` | bool | `false` | Allow remote code execution | | `model_kwargs` | object | `{}` | Extra keyword arguments passed to the model constructor | | `pipeline_kwargs` | object | `{}` | Extra keyword arguments passed to the pipeline at inference time | +| `tool_call_parser` | string | `hermes` | Parser used to turn raw model output into OpenAI `tool_calls`. Currently supported: `hermes` (Hermes-2-Pro / Qwen2.5-Instruct / many community fine-tunes that emit `{...}` markers). | ### Chat / Text Generation (CPU) @@ -233,6 +234,25 @@ models: device: "cpu" ``` +### Chat with Tool Calling (CPU) + +The transformers loader renders `tools` into the prompt via the model's chat +template and parses the output back into OpenAI `tool_calls`. The model must +have been trained on a Hermes-style tool format (Qwen2.5-Instruct, Hermes-2, +many community fine-tunes); the parser is selected via `tool_call_parser`. + +```yaml +models: + - name: qwen-tools + model: Qwen/Qwen2.5-0.5B-Instruct + usecase: generate + loader: transformers + num_cpus: 2 + transformers_config: + device: "cpu" + tool_call_parser: hermes # this is the default; shown for clarity +``` + ### Speech-to-Text (CPU) Audio is automatically decoded and resampled to the model's expected sample rate (e.g. 16kHz for Whisper). 
diff --git a/modelship/infer/transformers/openai/serving_chat.py b/modelship/infer/transformers/openai/serving_chat.py index 35f7d26..81fd0d4 100644 --- a/modelship/infer/transformers/openai/serving_chat.py +++ b/modelship/infer/transformers/openai/serving_chat.py @@ -3,6 +3,7 @@ import time from collections.abc import AsyncGenerator from threading import Thread +from typing import Any from transformers import Pipeline, PreTrainedTokenizerBase, TextIteratorStreamer @@ -23,6 +24,7 @@ UsageInfo, create_error_response, ) +from modelship.openai.tool_calling import ParsedToolCalls, ToolCallStreamer, get_parser, resolve_tools_for_request from modelship.utils import base_request_id logger = get_logger("infer.transformers.chat") @@ -45,12 +47,16 @@ def __init__( assert pipeline.tokenizer is not None, "text-generation pipeline must have a tokenizer" self.tokenizer: PreTrainedTokenizerBase = pipeline.tokenizer self._lock = asyncio.Lock() + # Validate the configured parser at startup so misconfiguration surfaces + # before the first request rather than mid-generation. 
+ get_parser(self.config.tool_call_parser) async def warmup(self) -> None: logger.info("Warming up chat model: %s", self.model_name) await self.run_in_executor( self._run, [{"role": "user", "content": "warmup"}], + None, 1, ) logger.info("Warmup chat done for %s", self.model_name) @@ -74,26 +80,29 @@ async def create_chat_completion( logger.warning("chat request %s rejected: %s", request_id, e) return create_error_response(e) + tools = resolve_tools_for_request(request.tools, request.tool_choice) + max_tokens = request.max_tokens if max_tokens is None and request.max_completion_tokens is not None: max_tokens = request.max_completion_tokens if request.stream: include_usage = bool(request.stream_options and request.stream_options.include_usage) - return self._locked_stream(request_id, messages, max_tokens, include_usage=include_usage) + return self._locked_stream(request_id, messages, tools, max_tokens, include_usage=include_usage) async with self._lock: try: - result = await self.run_in_executor(self._run, messages, max_tokens) + result = await self.run_in_executor(self._run, messages, tools, max_tokens) except Exception: logger.exception("chat completion inference failed for %s", request_id) return create_error_response("chat completion inference failed") - prompt_tokens = self._count_prompt_tokens(messages) - generated = result[0]["generated_text"] - completion_text = generated[-1]["content"] if isinstance(generated, list) else generated + prompt_tokens = self._count_prompt_tokens(messages, tools) + completion_text = self._extract_completion_text(result) completion_tokens = len(self.tokenizer.encode(completion_text, add_special_tokens=False)) - finish_reason = "length" if (max_tokens is not None and completion_tokens >= max_tokens) else "stop" + + parsed = self._parse_tool_calls(completion_text) if tools else ParsedToolCalls(completion_text, []) + finish_reason = self._finish_reason(parsed, completion_tokens, max_tokens) response = ChatCompletionResponse( 
id=request_id, @@ -101,7 +110,11 @@ async def create_chat_completion( choices=[ ChatCompletionResponseChoice( index=0, - message=ChatMessage(role="assistant", content=completion_text), + message=ChatMessage( + role="assistant", + content=parsed.content, + tool_calls=parsed.tool_calls, + ), finish_reason=finish_reason, ) ], @@ -114,54 +127,102 @@ async def create_chat_completion( ) logger.log( TRACE, - "chat response %s: text=%r, prompt_tokens=%d, completion_tokens=%d", + "chat response %s: text=%r, tool_calls=%d, prompt_tokens=%d, completion_tokens=%d", request_id, completion_text, + len(parsed.tool_calls), prompt_tokens, completion_tokens, ) return response - def _count_prompt_tokens(self, messages: list[dict]) -> int: + def _parse_tool_calls(self, text: str) -> ParsedToolCalls: + return get_parser(self.config.tool_call_parser).parse(text) + + @staticmethod + def _finish_reason(parsed: ParsedToolCalls, completion_tokens: int, max_tokens: int | None) -> str: + if parsed.has_tool_calls: + return "tool_calls" + if max_tokens is not None and completion_tokens >= max_tokens: + return "length" + return "stop" + + @staticmethod + def _extract_completion_text(result: list) -> str: + generated = result[0]["generated_text"] + if isinstance(generated, list): + return generated[-1]["content"] + return generated + + def _count_prompt_tokens(self, messages: list[dict], tools: list[dict[str, Any]] | None) -> int: # apply_chat_template returns a string by default (character count!) — force tokenize=True. 
- token_ids = self.tokenizer.apply_chat_template(messages, tokenize=True) + kwargs: dict[str, Any] = {"tokenize": True} + if tools: + kwargs["tools"] = tools + token_ids = self.tokenizer.apply_chat_template(messages, **kwargs) return len(token_ids) - def _run(self, messages: list[dict], max_tokens: int | None) -> list: + def _render_prompt(self, messages: list[dict], tools: list[dict[str, Any]]) -> str: + rendered = self.tokenizer.apply_chat_template( + messages, + tools=tools, # type: ignore[arg-type] + tokenize=False, + add_generation_prompt=True, + ) + assert isinstance(rendered, str), "apply_chat_template(tokenize=False) must return str" + return rendered + + def _run(self, messages: list[dict], tools: list[dict[str, Any]] | None, max_tokens: int | None) -> list: kwargs = {**self.config.pipeline_kwargs} if max_tokens is not None: kwargs["max_new_tokens"] = max_tokens + if tools: + # The standard text-generation pipeline does not forward `tools` to + # `apply_chat_template`, so we render the prompt ourselves and feed + # it as a plain string. 
+ prompt = self._render_prompt(messages, tools) + return self.pipeline(prompt, return_full_text=False, **kwargs) # type: ignore[return-value] return self.pipeline(messages, return_full_text=False, **kwargs) # type: ignore[return-value] async def _locked_stream( self, request_id: str, messages: list[dict], + tools: list[dict[str, Any]] | None, max_tokens: int | None, *, include_usage: bool, ) -> AsyncGenerator[str, None]: async with self._lock: - async for chunk in self._stream(request_id, messages, max_tokens, include_usage=include_usage): + async for chunk in self._stream(request_id, messages, tools, max_tokens, include_usage=include_usage): yield chunk async def _stream( self, request_id: str, messages: list[dict], + tools: list[dict[str, Any]] | None, max_tokens: int | None, *, include_usage: bool, ) -> AsyncGenerator[str, None]: + created_time = int(time.time()) streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) # type: ignore[arg-type] kwargs = {**self.config.pipeline_kwargs} if max_tokens is not None: kwargs["max_new_tokens"] = max_tokens + if tools: + prompt: Any = self._render_prompt(messages, tools) + tool_call_streamer: ToolCallStreamer | None = ToolCallStreamer(get_parser(self.config.tool_call_parser)) + else: + prompt = messages + tool_call_streamer = None + thread = Thread( target=self.pipeline, - args=(messages,), + args=(prompt,), kwargs={ "streamer": streamer, "return_full_text": False, @@ -171,44 +232,38 @@ async def _stream( thread.start() accumulated: list[str] = [] + accumulated_str = "" try: # Per OpenAI spec, the first delta carries `role` only. 
- yield self._encode_chunk( - ChatCompletionStreamResponse( - id=request_id, - model=self.model_name, - choices=[ - ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(role="assistant"), - ) - ], - created=int(time.time()), - ) - ) + yield self._delta_chunk(request_id, DeltaMessage(role="assistant"), created_time) for text_chunk in streamer: if not text_chunk: continue accumulated.append(text_chunk) - yield self._encode_chunk( - ChatCompletionStreamResponse( - id=request_id, - model=self.model_name, - choices=[ - ChatCompletionResponseStreamChoice( - index=0, - delta=DeltaMessage(content=text_chunk), - ) - ], - created=int(time.time()), - ) - ) + accumulated_str += text_chunk + if tool_call_streamer is None: + yield self._delta_chunk(request_id, DeltaMessage(content=text_chunk), created_time) + await asyncio.sleep(0) + continue + # Re-parse the cumulative text and emit any new content / tool-call + # fragments. Bounded held-back tail (length of the start marker) so + # the client never sees a half-formed `` opening tag. 
+ delta = tool_call_streamer.extract_streaming(accumulated_str) + if delta is not None: + yield self._delta_chunk(request_id, delta, created_time) await asyncio.sleep(0) - completion_text = "".join(accumulated) - completion_tokens = len(self.tokenizer.encode(completion_text, add_special_tokens=False)) - finish_reason = "length" if (max_tokens is not None and completion_tokens >= max_tokens) else "stop" + if tool_call_streamer is not None: + final = tool_call_streamer.finalize() + if final is not None: + yield self._delta_chunk(request_id, final, created_time) + parsed = tool_call_streamer.result + else: + parsed = ParsedToolCalls(accumulated_str, []) + + completion_tokens = len(self.tokenizer.encode(accumulated_str, add_special_tokens=False)) + finish_reason = self._finish_reason(parsed, completion_tokens, max_tokens) yield self._encode_chunk( ChatCompletionStreamResponse( @@ -221,12 +276,12 @@ async def _stream( finish_reason=finish_reason, ) ], - created=int(time.time()), + created=created_time, ) ) if include_usage: - prompt_tokens = self._count_prompt_tokens(messages) + prompt_tokens = self._count_prompt_tokens(messages, tools) yield self._encode_chunk( ChatCompletionStreamResponse( id=request_id, @@ -237,7 +292,7 @@ async def _stream( completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens, ), - created=int(time.time()), + created=created_time, ) ) @@ -245,6 +300,18 @@ async def _stream( finally: thread.join() + def _delta_chunk(self, request_id: str, delta: DeltaMessage, created_time: int) -> str: + return self._encode_chunk( + ChatCompletionStreamResponse( + id=request_id, + model=self.model_name, + choices=[ + ChatCompletionResponseStreamChoice(index=0, delta=delta), + ], + created=created_time, + ) + ) + @staticmethod def _encode_chunk(chunk: ChatCompletionStreamResponse) -> str: return f"data: {json.dumps(chunk.model_dump(mode='json'))}\n\n" diff --git a/modelship/openai/protocol.py b/modelship/openai/protocol.py index 
d09afe7..d4142b1 100644 --- a/modelship/openai/protocol.py +++ b/modelship/openai/protocol.py @@ -180,7 +180,7 @@ class ChatCompletionRequest(OpenAIBaseModel): temperature: float | None = None top_p: float | None = None tools: list[dict[str, Any]] | None = None - tool_choice: str | dict[str, Any] | None = "none" + tool_choice: str | dict[str, Any] | None = None reasoning_effort: Literal["none", "low", "medium", "high"] | None = None parallel_tool_calls: bool | None = True user: str | None = None diff --git a/modelship/openai/tool_calling/__init__.py b/modelship/openai/tool_calling/__init__.py new file mode 100644 index 0000000..2e448ec --- /dev/null +++ b/modelship/openai/tool_calling/__init__.py @@ -0,0 +1,22 @@ +"""Cross-loader tool-calling toolkit. + +Loaders without native tool-call support (Transformers today, plugin-wrapped +raw-text engines tomorrow) use the parsers and helpers in this package to +turn raw model output into OpenAI-shape ``tool_calls``. Loaders whose engines +already emit structured tool calls (vLLM, llama.cpp via a function-calling +chat handler) bypass it. +""" + +from modelship.openai.tool_calling.input import resolve_tools_for_request +from modelship.openai.tool_calling.parsers import ParsedToolCalls, ToolCallParser, ToolCallStreamer +from modelship.openai.tool_calling.registry import available_parsers, get_parser, register_parser + +__all__ = [ + "ParsedToolCalls", + "ToolCallParser", + "ToolCallStreamer", + "available_parsers", + "get_parser", + "register_parser", + "resolve_tools_for_request", +] diff --git a/modelship/openai/tool_calling/input.py b/modelship/openai/tool_calling/input.py new file mode 100644 index 0000000..a245c0c --- /dev/null +++ b/modelship/openai/tool_calling/input.py @@ -0,0 +1,63 @@ +"""Input-side helpers for tool calling. 
+ +Loaders that hand a chat template a list of OpenAI messages plus a list of +tool schemas use these helpers to interpret the request's ``tool_choice`` +(``"none"`` suppresses tools, ``"required"`` / specific-function downgrade +to ``"auto"`` with a warning) and to validate that a parser exists for the +configured family before generation starts. +""" + +from __future__ import annotations + +from typing import Any + +from modelship.logging import get_logger + +logger = get_logger("openai.tool_calling.input") + + +def resolve_tools_for_request( + tools: list[dict[str, Any]] | None, + tool_choice: str | dict[str, Any] | None, +) -> list[dict[str, Any]] | None: + """Apply OpenAI ``tool_choice`` semantics to the request's ``tools`` list. + + Returns the list of tools to render into the prompt, or ``None`` when + the request should be served without any tool-calling affordance. + + - ``tool_choice == "none"`` — suppress tools entirely. + - ``tool_choice == "auto"`` (or unset) — pass all tools through. + - ``tool_choice == "required"`` — pass tools through, log that we cannot + strictly enforce a tool call without constrained decoding. + - ``tool_choice == {"type": "function", "function": {"name": "X"}}`` — + filter ``tools`` to that single function and warn that the call cannot + be strictly enforced. 
+ """ + if not tools: + return None + if tool_choice in (None, "auto"): + return tools + if tool_choice == "none": + return None + if tool_choice == "required": + logger.warning( + "tool_choice='required' requested but this loader cannot enforce a tool call; " + "passing all tools to the model and trusting it to call one" + ) + return tools + if isinstance(tool_choice, dict): + fn = tool_choice.get("function") or {} + name = fn.get("name") if isinstance(fn, dict) else None + if isinstance(name, str) and name: + filtered = [t for t in tools if (t.get("function") or {}).get("name") == name] + if not filtered: + logger.warning("tool_choice names function %r which is not in the request's tools list", name) + return tools + logger.warning( + "tool_choice forcing function %r is not strictly enforced by this loader; " + "passing only that tool to the model", + name, + ) + return filtered + logger.warning("unrecognized tool_choice value %r; falling back to 'auto' semantics", tool_choice) + return tools diff --git a/modelship/openai/tool_calling/parsers/__init__.py b/modelship/openai/tool_calling/parsers/__init__.py new file mode 100644 index 0000000..1df7236 --- /dev/null +++ b/modelship/openai/tool_calling/parsers/__init__.py @@ -0,0 +1,4 @@ +from modelship.openai.tool_calling.parsers.base import ParsedToolCalls, ToolCallParser, ToolCallStreamer +from modelship.openai.tool_calling.parsers.hermes import HermesToolCallParser + +__all__ = ["HermesToolCallParser", "ParsedToolCalls", "ToolCallParser", "ToolCallStreamer"] diff --git a/modelship/openai/tool_calling/parsers/base.py b/modelship/openai/tool_calling/parsers/base.py new file mode 100644 index 0000000..e8197ed --- /dev/null +++ b/modelship/openai/tool_calling/parsers/base.py @@ -0,0 +1,291 @@ +"""Base class for model-family-specific tool-call output parsers. 
+ +Parsers in this codebase are *marker-based*: each family wraps tool-call +JSON in a fixed pair of literal strings (``<tool_call>`` / ``</tool_call>`` +for Hermes, ``[TOOL_CALLS]`` / closing token for Mistral, ...). A subclass +declares the marker pair plus two small extractors that pick a function +name and an arguments substring out of the (possibly partial) JSON between +the markers. Both the streaming and non-streaming paths run the same +:class:`ToolCallStreamer` so behavior cannot drift between them. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass + +from modelship.openai.protocol import ( + DeltaFunctionCall, + DeltaMessage, + DeltaToolCall, + FunctionCall, + ToolCall, + random_uuid, +) + + +@dataclass(frozen=True) +class ParsedToolCalls: + """Aggregate result of parsing a model's full chat-completion text. + + ``content`` carries the residual non-tool-call text once any tool-call + markers are stripped. It is ``None`` when tool calls were extracted *and* + the residual is empty, matching OpenAI's behavior of nulling ``content`` + alongside ``tool_calls``. + """ + + content: str | None + tool_calls: list[ToolCall] + + @property + def has_tool_calls(self) -> bool: + return bool(self.tool_calls) + + +class ToolCallParser(ABC): + """Family-specific knobs the streamer needs to drive its diff loop. + + Subclasses set ``start_marker`` / ``end_marker`` and implement the two + extractors. They never touch streaming state — that lives on + :class:`ToolCallStreamer`, which can be instantiated once per request. + """ + + name: str + start_marker: str + end_marker: str + + @abstractmethod + def extract_partial_name(self, partial_payload: str) -> str | None: + """Return the function name if a complete quoted name is visible yet, else ``None``. 
+ + Called every delta until it yields a non-``None`` result; once a name + has been emitted to the client, the streamer stops asking and starts + forwarding arguments bytes for that tool. + """ + + @abstractmethod + def extract_partial_args(self, partial_payload: str) -> str | None: + """Return the arguments substring as the client should see it so far. + + The streamer takes a length-diff of successive returns and forwards + only the new bytes. Implementations must withhold any trailing bytes + that could plausibly be the envelope closer landing ahead of the + family's end marker — otherwise those bytes leak into the args + stream and the client receives malformed JSON. + """ + + def parse(self, text: str) -> ParsedToolCalls: + streamer = ToolCallStreamer(self) + streamer.extract_streaming(text) + streamer.finalize() + return streamer.result + + +class ToolCallStreamer: + """Per-request, stateful tool-call extractor. + + Mirrors vLLM's approach: hold a small amount of "what we've already sent" + state, re-parse the cumulative ``current_text`` on every delta, and diff. + Returns either a :class:`DeltaMessage` carrying any newly-emittable + content / tool-call fragments or ``None`` if there is nothing to send yet. + + State held per request: + + - ``_sent_content_idx`` — number of content-stream chars already shipped. + The "content stream" view is the original text with every (complete or + open) tool-call region excised. + - ``_sent_name`` / ``_sent_id`` per tool index — whether the function name + and id deltas have been emitted. + - ``_sent_args`` per tool index — number of arguments chars already + shipped (the suffix-diff cursor). + - ``_finalized_indices`` — set of block indices that have been finalized. + - ``_finalized_calls`` — :class:`ToolCall` objects accumulated for blocks + that have closed, used to populate :attr:`result` for the non-streaming + path and for the final ``finish_reason``. 
+ """ + + def __init__(self, parser: ToolCallParser): + self._parser = parser + self._start = parser.start_marker + self._end = parser.end_marker + self._sent_content_idx = 0 + self._sent_name: list[bool] = [] + self._sent_id: list[str] = [] + self._sent_args: list[str] = [] + self._finalized_indices: set[int] = set() + self._finalized_calls: list[ToolCall] = [] + self._last_text = "" + + def extract_streaming(self, current_text: str) -> DeltaMessage | None: + """Run one diff pass against ``current_text`` and return any new deltas.""" + self._last_text = current_text + content_delta = self._emit_new_content(current_text, hold_marker_tail=True) + tool_call_deltas = self._emit_new_tool_call_fragments(current_text) + if not content_delta and not tool_call_deltas: + return None + return DeltaMessage(content=content_delta, tool_calls=tool_call_deltas) + + def finalize(self) -> DeltaMessage | None: + """Flush any held-back content tail once no more text is coming.""" + content_delta = self._emit_new_content(self._last_text, hold_marker_tail=False) + if content_delta is None: + return None + return DeltaMessage(content=content_delta) + + @property + def result(self) -> ParsedToolCalls: + """Final view, suitable for the non-streaming response shape.""" + # The content-view we accumulated as we streamed; reconstruct from the cursor. 
+ view = self._build_content_view(self._last_text, hold_marker_tail=False) + content = (view.strip() or None) if self._finalized_calls else (view or None) + return ParsedToolCalls(content=content, tool_calls=list(self._finalized_calls)) + + # ------------------------------------------------------------------ + # Content stream + # ------------------------------------------------------------------ + + def _emit_new_content(self, current_text: str, *, hold_marker_tail: bool) -> str | None: + view = self._build_content_view(current_text, hold_marker_tail=hold_marker_tail) + if len(view) <= self._sent_content_idx: + return None + new = view[self._sent_content_idx :] + self._sent_content_idx = len(view) + return new or None + + def _build_content_view(self, text: str, *, hold_marker_tail: bool) -> str: + """Build the content-stream view: the original text with tool-call regions excised. + + When ``hold_marker_tail`` is true (mid-stream), withhold a trailing + suffix that could be the start of a new ``start_marker``, so the + client never sees half of an opening tag. At finalize time we know + no more text is coming, so the held-back tail is safe to flush. + """ + parts: list[str] = [] + pos = 0 + while pos < len(text): + start = text.find(self._start, pos) + if start < 0: + remainder = text[pos:] + if hold_marker_tail: + safe = _safe_outside_flush_index(remainder, self._start) + parts.append(remainder[:safe]) + else: + parts.append(remainder) + break + parts.append(text[pos:start]) + payload_start = start + len(self._start) + end = text.find(self._end, payload_start) + if end < 0: + # Open block, not yet closed — nothing more to append to content. 
+ break + pos = end + len(self._end) + return "".join(parts) + + # ------------------------------------------------------------------ + # Tool-call fragments + # ------------------------------------------------------------------ + + def _emit_new_tool_call_fragments(self, current_text: str) -> list[DeltaToolCall]: + deltas: list[DeltaToolCall] = [] + for i, (payload, is_complete) in enumerate(self._iter_tool_call_blocks(current_text)): + self._ensure_slot(i) + + if not self._sent_name[i]: + name = self._parser.extract_partial_name(payload) + if name is None: + if is_complete: + # This block is finished but has no extractable name (malformed). + # Skip it so we can potentially process later blocks. + continue + # Mid-stream and no name yet; per OpenAI convention we don't + # advance to later blocks until the current one has a name. + break + tool_id = f"chatcmpl-tool-{random_uuid()}" + self._sent_name[i] = True + self._sent_id[i] = tool_id + deltas.append( + DeltaToolCall( + index=i, + id=tool_id, + type="function", + function=DeltaFunctionCall(name=name), + ) + ) + + args = self._parser.extract_partial_args(payload) + if args is not None and len(args) > len(self._sent_args[i]): + diff = args[len(self._sent_args[i]) :] + self._sent_args[i] = args + deltas.append( + DeltaToolCall( + index=i, + function=DeltaFunctionCall(arguments=diff), + ) + ) + + if is_complete and i not in self._finalized_indices and self._sent_name[i]: + self._finalized_indices.add(i) + self._finalized_calls.append( + ToolCall( + id=self._sent_id[i], + type="function", + function=FunctionCall( + name=self._extract_committed_name(i), + arguments=self._sent_args[i], + ), + ) + ) + return deltas + + def _ensure_slot(self, i: int) -> None: + while len(self._sent_name) <= i: + self._sent_name.append(False) + self._sent_id.append("") + self._sent_args.append("") + + def _extract_committed_name(self, i: int) -> str: + # We only stash the bool that the name was sent, not the value, so + # re-derive from 
the current payload (cheap, the regex is small). + for j, (payload, _) in enumerate(self._iter_tool_call_blocks(self._last_text)): + if j == i: + name = self._parser.extract_partial_name(payload) + return name or "" + return "" + + def _iter_tool_call_blocks(self, text: str): + """Yield ``(partial_payload, is_complete)`` for each tool-call region in order. + + For the still-open final block, the partial payload has any tail + suffix that could be the start of ``end_marker`` withheld, so the + client never sees a fragment of the closing tag forwarded as + argument bytes. + """ + pos = 0 + while True: + start = text.find(self._start, pos) + if start < 0: + return + payload_start = start + len(self._start) + end = text.find(self._end, payload_start) + if end < 0: + partial = text[payload_start:] + safe = _safe_outside_flush_index(partial, self._end) + yield partial[:safe], False + return + yield text[payload_start:end], True + pos = end + len(self._end) + + +def _safe_outside_flush_index(buf: str, start_marker: str) -> int: + """Index up to which ``buf`` can be flushed without risking a split marker. + + ``buf`` is known not to contain the full marker. The unsafe tail is the + longest proper-prefix overlap between ``buf`` and ``start_marker``: if + the next chunk completes that prefix, we'd have to retract bytes we + already streamed. Holding them back avoids the retraction. + """ + max_overlap = min(len(buf), len(start_marker) - 1) + for k in range(max_overlap, 0, -1): + if buf.endswith(start_marker[:k]): + return len(buf) - k + return len(buf) diff --git a/modelship/openai/tool_calling/parsers/hermes.py b/modelship/openai/tool_calling/parsers/hermes.py new file mode 100644 index 0000000..97487db --- /dev/null +++ b/modelship/openai/tool_calling/parsers/hermes.py @@ -0,0 +1,44 @@ +"""Hermes-style ``{json}`` parser. 
+ +Used by Hermes-2-Pro, Qwen2.5-Instruct, and a large family of NousResearch / +community fine-tunes whose chat templates wrap each tool call in the literal +tags ``<tool_call>`` / ``</tool_call>`` around a JSON object of the shape +``{"name": "...", "arguments": {...}}``. +""" + +from __future__ import annotations + +import re + +from modelship.openai.tool_calling.parsers.base import ToolCallParser + + +class HermesToolCallParser(ToolCallParser): + name = "hermes" + start_marker = "<tool_call>" + end_marker = "</tool_call>" + + _NAME_RE = re.compile(r'"name"\s*:\s*"([^"]+)"') + _ARGS_RE = re.compile(r'"arguments"\s*:\s*') + + def extract_partial_name(self, partial_payload: str) -> str | None: + m = self._NAME_RE.search(partial_payload) + return m.group(1) if m else None + + def extract_partial_args(self, partial_payload: str) -> str | None: + m = self._ARGS_RE.search(partial_payload) + if m is None: + return None + args = partial_payload[m.end() :].rstrip() + if args.endswith("}"): + # The block envelope is `{"name":"x","arguments":<args>}`. The + # closing brace of the envelope arrives in the byte stream before + # `</tool_call>` does, so we cannot tell whether any given + # trailing `}` belongs to the args object or to the envelope. + # Withholding one trailing `}` keeps the args stream well-formed: + # if the model goes on to emit more args bytes, the held brace is + # recovered on the next pass; if instead it goes on to emit + # `</tool_call>`, the held brace was the envelope closer and + # discarding it was correct. + args = args[:-1].rstrip() + return args or None diff --git a/modelship/openai/tool_calling/registry.py b/modelship/openai/tool_calling/registry.py new file mode 100644 index 0000000..68b8e28 --- /dev/null +++ b/modelship/openai/tool_calling/registry.py @@ -0,0 +1,34 @@ +"""Registry of named tool-call parsers, dispatched by configuration string. + +The registry is the single seam between a loader and the per-family parsers. 
+Loaders that emit raw text (Transformers, plugin-wrapped engines) look up a +parser by the name configured on the deployment and feed it the model's +output. Loaders with native tool-call support (vLLM, llama.cpp via a +function-calling chat handler) bypass the registry entirely. +""" + +from __future__ import annotations + +from modelship.openai.tool_calling.parsers import HermesToolCallParser, ToolCallParser + +_PARSERS: dict[str, ToolCallParser] = { + HermesToolCallParser.name: HermesToolCallParser(), +} + + +def get_parser(name: str) -> ToolCallParser: + """Return the parser registered under ``name`` or raise ``ValueError``.""" + try: + return _PARSERS[name] + except KeyError: + available = ", ".join(sorted(_PARSERS)) or "(none)" + raise ValueError(f"unknown tool_call_parser {name!r}; available: {available}") from None + + +def available_parsers() -> list[str]: + return sorted(_PARSERS) + + +def register_parser(parser: ToolCallParser) -> None: + """Register an additional parser. Intended for tests and plugin code.""" + _PARSERS[parser.name] = parser diff --git a/tests/test_integration.py b/tests/test_integration.py index 47f276c..dc317bc 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,3 +1,4 @@ +import json import subprocess import time @@ -9,7 +10,14 @@ OPENAI_API_BASE = "http://localhost:8000/v1" -EXPECTED_MODELS = {"chat-capable", "chat-limited", "embed-model", "stt-model", "tts-model"} +EXPECTED_MODELS = { + "chat-capable", + "chat-limited", + "chat-transformers", + "embed-model", + "stt-model", + "tts-model", +} @pytest.fixture(scope="session") @@ -45,6 +53,20 @@ def mship_cluster(tmp_path_factory): "loader": "llama_cpp", "num_cpus": 1, }, + { + # Same Qwen2.5-Instruct family as `chat-capable` so we exercise + # the transformers loader against a model trained to emit + # Hermes-style `{...}` markers. 
+ "name": "chat-transformers", + "model": "Qwen/Qwen2.5-0.5B-Instruct", + "usecase": "generate", + "loader": "transformers", + "num_cpus": 2, + "transformers_config": { + "device": "cpu", + "torch_dtype": "float32", + }, + }, { "name": "embed-model", "model": "nomic-ai/nomic-embed-text-v1.5", @@ -140,6 +162,7 @@ def test_list_models(client): model_ids = [m.id for m in models.data] assert "chat-capable" in model_ids assert "chat-limited" in model_ids + assert "chat-transformers" in model_ids assert "embed-model" in model_ids assert "stt-model" in model_ids assert "tts-model" in model_ids @@ -191,6 +214,182 @@ def test_tool_calling_success(client): assert completion.choices[0].message.tool_calls[0].function.name == "get_weather" +@pytest.mark.integration +def test_tool_calling_transformers_loader(client): + """Round-trip a Hermes-style tool call through the transformers loader. + + Uses the same Qwen2.5-0.5B-Instruct weights as the vLLM `chat-capable` + deployment but goes through the modelship-side tool-calling toolkit + (apply_chat_template(tools=...) on input, hermes parser on output). 
+ """ + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ] + completion = client.chat.completions.create( + model="chat-transformers", + messages=[{"role": "user", "content": "What is the weather in Paris?"}], + tools=tools, + tool_choice="auto", + max_tokens=128, + ) + tool_calls = completion.choices[0].message.tool_calls + assert tool_calls, f"expected a tool call, got content={completion.choices[0].message.content!r}" + assert tool_calls[0].function.name == "get_weather" + assert "Paris" in tool_calls[0].function.arguments + assert completion.choices[0].finish_reason == "tool_calls" + + +def _collect_streaming_tool_call(stream) -> dict: + """Drain an OpenAI streaming response and rebuild the assistant message. + + Returns a dict with: ``content`` (concatenated content deltas), + ``tool_calls`` (per-index dict of ``{id, name, arguments}`` — arguments + concatenated across all fragments), ``finish_reason``, ``name_deltas`` + and ``args_deltas`` (counts, used to assert that streaming was actually + incremental rather than a single buffered emission). 
+ """ + content_parts: list[str] = [] + tool_calls: dict[int, dict] = {} + finish_reason: str | None = None + name_deltas = 0 + args_deltas = 0 + chunks_with_tool_calls = 0 + + for chunk in stream: + choice = chunk.choices[0] + delta = choice.delta + if delta.content: + content_parts.append(delta.content) + if delta.tool_calls: + chunks_with_tool_calls += 1 + for tc in delta.tool_calls: + slot = tool_calls.setdefault(tc.index, {"id": None, "name": None, "arguments": ""}) + if tc.id is not None: + slot["id"] = tc.id + if tc.function and tc.function.name: + slot["name"] = tc.function.name + name_deltas += 1 + if tc.function and tc.function.arguments: + slot["arguments"] += tc.function.arguments + args_deltas += 1 + if choice.finish_reason is not None: + finish_reason = choice.finish_reason + + return { + "content": "".join(content_parts), + "tool_calls": tool_calls, + "finish_reason": finish_reason, + "name_deltas": name_deltas, + "args_deltas": args_deltas, + "chunks_with_tool_calls": chunks_with_tool_calls, + } + + +@pytest.mark.integration +def test_tool_calling_streaming_transformers_loader(client): + """Stream a tool call through the transformers loader and verify the + delta sequence matches the OpenAI streaming contract. + + Asserts: + - the function name arrives in exactly one delta; + - arguments arrive across multiple deltas (incremental, not buffered); + - concatenated arguments form valid JSON containing the expected key; + - the final delta carries ``finish_reason="tool_calls"``. 
+ """ + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ] + stream = client.chat.completions.create( + model="chat-transformers", + messages=[{"role": "user", "content": "What is the weather in Paris?"}], + tools=tools, + tool_choice="auto", + max_tokens=128, + stream=True, + ) + + collected = _collect_streaming_tool_call(stream) + + assert collected["tool_calls"], f"expected at least one streamed tool call; got content={collected['content']!r}" + call_0 = collected["tool_calls"][0] + assert call_0["id"], "expected an id on the first tool-call delta" + assert call_0["name"] == "get_weather" + # Name must be sent exactly once (not on every delta). + assert collected["name_deltas"] == 1, f"expected one name delta, got {collected['name_deltas']}" + # Arguments must arrive incrementally across multiple deltas — that's the + # whole point of switching from block-level buffering to vLLM-style + # diff streaming. Exact count depends on the model, but it must be > 1. + assert collected["args_deltas"] >= 2, ( + f"expected arguments to stream incrementally, got {collected['args_deltas']} args delta(s)" + ) + # Concatenated args must form valid JSON containing the city. + parsed_args = json.loads(call_0["arguments"]) + assert parsed_args.get("city") + assert "Paris" in parsed_args["city"] + assert collected["finish_reason"] == "tool_calls" + + +@pytest.mark.integration +def test_tool_calling_streaming_vllm_loader(client): + """Smoke-test that vLLM streaming + tool calling still works through + the gateway. vLLM emits its own per-token deltas; we only verify that + the gateway forwards them and that the final assistant message rebuilds + correctly. 
+ """ + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ] + stream = client.chat.completions.create( + model="chat-capable", + messages=[{"role": "user", "content": "What is the weather in Paris?"}], + tools=tools, + tool_choice="required", + stream=True, + ) + + collected = _collect_streaming_tool_call(stream) + + assert collected["tool_calls"], "vLLM should have streamed at least one tool call" + call_0 = collected["tool_calls"][0] + assert call_0["name"] == "get_weather" + parsed_args = json.loads(call_0["arguments"]) + assert "Paris" in parsed_args.get("city", "") + assert collected["finish_reason"] == "tool_calls" + + @pytest.mark.integration def test_tool_calling_unsupported_loader(client): """Verifies that loaders without tool support (like llama_cpp) don't return tool calls.""" diff --git a/tests/test_tool_calling.py b/tests/test_tool_calling.py new file mode 100644 index 0000000..98f50b1 --- /dev/null +++ b/tests/test_tool_calling.py @@ -0,0 +1,286 @@ +"""Tests for the cross-loader tool-calling toolkit.""" + +from __future__ import annotations + +import json +from typing import ClassVar + +import pytest + +from modelship.openai.tool_calling import ( + ToolCallStreamer, + available_parsers, + get_parser, + register_parser, + resolve_tools_for_request, +) +from modelship.openai.tool_calling.parsers import HermesToolCallParser, ToolCallParser + + +class TestRegistry: + def test_default_registry_includes_hermes(self): + assert "hermes" in available_parsers() + + def test_get_parser_returns_singleton(self): + a = get_parser("hermes") + b = get_parser("hermes") + assert a is b + + def test_unknown_parser_raises_with_available_list(self): + with pytest.raises(ValueError, match="hermes"): + get_parser("does-not-exist") + + def test_register_parser_makes_it_findable(self): 
+        class Stub(ToolCallParser):
+            name = "stub-test-parser"
+            start_marker = "<<"
+            end_marker = ">>"
+
+            def extract_partial_name(self, partial_payload: str) -> str | None:
+                return None
+
+            def extract_partial_args(self, partial_payload: str) -> str | None:
+                return None
+
+        register_parser(Stub())
+        try:
+            assert get_parser("stub-test-parser").name == "stub-test-parser"
+        finally:
+            # Clean up so other tests don't see the stub.
+            from modelship.openai.tool_calling import registry
+
+            registry._PARSERS.pop("stub-test-parser", None)
+
+
+class TestHermesParser:
+    parser = HermesToolCallParser()
+
+    def test_no_tool_calls_returns_text_unchanged(self):
+        result = self.parser.parse("just a regular response")
+        assert result.tool_calls == []
+        assert result.content == "just a regular response"
+        assert result.has_tool_calls is False
+
+    def test_single_tool_call(self):
+        text = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
+        result = self.parser.parse(text)
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0].function.name == "get_weather"
+        assert json.loads(result.tool_calls[0].function.arguments) == {"city": "Paris"}
+        assert result.content is None
+
+    def test_multiple_tool_calls(self):
+        text = (
+            '<tool_call>{"name": "a", "arguments": {"x": 1}}</tool_call>'
+            '<tool_call>{"name": "b", "arguments": {"y": 2}}</tool_call>'
+        )
+        result = self.parser.parse(text)
+        assert [tc.function.name for tc in result.tool_calls] == ["a", "b"]
+
+    def test_tool_call_with_residual_text(self):
+        text = 'Sure, calling that.\n<tool_call>{"name": "ping", "arguments": {}}</tool_call>'
+        result = self.parser.parse(text)
+        assert len(result.tool_calls) == 1
+        assert result.content == "Sure, calling that."
+
+    def test_string_arguments_forwarded_verbatim(self):
+        # vLLM-style: the streamer forwards the raw bytes of the arguments
+        # value as the model emitted them, including any surrounding quotes
+        # if the model wrapped its arguments in a JSON string literal. 
The
+        # OpenAI streaming contract treats `arguments` as an opaque string
+        # the client concatenates and parses.
+        text = '<tool_call>{"name": "x", "arguments": "{\\"a\\": 1}"}</tool_call>'
+        result = self.parser.parse(text)
+        assert result.tool_calls[0].function.arguments == '"{\\"a\\": 1}"'
+
+    def test_object_arguments_passed_through(self):
+        text = '<tool_call>{"name": "x", "arguments": {"a": 1, "b": [2, 3]}}</tool_call>'
+        result = self.parser.parse(text)
+        assert json.loads(result.tool_calls[0].function.arguments) == {"a": 1, "b": [2, 3]}
+
+    def test_block_without_extractable_name_is_dropped(self):
+        # When the block contains nothing the name regex can hook onto we
+        # silently drop it — there is nothing to tell the client about.
+        text = "<tool_call>{not valid json}</tool_call>"
+        result = self.parser.parse(text)
+        assert result.tool_calls == []
+
+    def test_missing_name_drops_call(self):
+        text = '<tool_call>{"arguments": {}}</tool_call>'
+        result = self.parser.parse(text)
+        assert result.tool_calls == []
+
+    def test_empty_name_drops_call(self):
+        text = '<tool_call>{"name": "", "arguments": {}}</tool_call>'
+        result = self.parser.parse(text)
+        assert result.tool_calls == []
+
+    def test_each_tool_call_gets_unique_id(self):
+        text = (
+            '<tool_call>{"name": "a", "arguments": {}}</tool_call>'
+            '<tool_call>{"name": "b", "arguments": {}}</tool_call>'
+        )
+        result = self.parser.parse(text)
+        assert result.tool_calls[0].id != result.tool_calls[1].id
+
+
+class TestToolCallStreamer:
+    """Drive the streamer one chunk at a time and verify the deltas.
+
+    The cumulative-text protocol matches what serving_chat does in production:
+    every fed string is the *full* generated text so far, not just the latest
+    delta. 
+ """ + + def _feed(self, chunks: list[str]) -> tuple[ToolCallStreamer, list]: + streamer = ToolCallStreamer(HermesToolCallParser()) + deltas = [] + cumulative = "" + for chunk in chunks: + cumulative += chunk + d = streamer.extract_streaming(cumulative) + if d is not None: + deltas.append(d) + final = streamer.finalize() + if final is not None: + deltas.append(final) + return streamer, deltas + + def test_pure_content_streams_immediately(self): + _, deltas = self._feed(["Hello", " ", "world"]) + assert "".join(d.content or "" for d in deltas) == "Hello world" + assert all(not d.tool_calls for d in deltas) + + def test_holds_back_marker_prefix_in_content_until_finalize(self): + # `<` could be the first char of ``; the streamer must not + # ship it mid-stream. Once finalize() runs (no more text coming) the + # held tail is safe to flush as content. + streamer = ToolCallStreamer(HermesToolCallParser()) + mid = streamer.extract_streaming("before <") + assert mid is not None and mid.content == "before " + final = streamer.finalize() + assert final is not None and final.content == "<" + + def test_held_tail_flushes_when_disambiguated(self): + # `{"name": "get_weather", "arguments": {"city": "Paris"}}') + _, deltas = self._feed(chunks) + + tool_deltas = [tc for d in deltas for tc in d.tool_calls] + # First tool delta: name + id, no arguments. + assert tool_deltas[0].function is not None + assert tool_deltas[0].function.name == "get_weather" + assert tool_deltas[0].id is not None + assert tool_deltas[0].function.arguments is None + # Subsequent deltas carry arguments fragments only (no name). + arg_deltas = tool_deltas[1:] + assert all(d.function and d.function.name is None for d in arg_deltas) + # Concatenated arguments form valid JSON. 
+ joined_args = "".join(d.function.arguments or "" for d in arg_deltas) + assert json.loads(joined_args) == {"city": "Paris"} + + def test_arguments_stream_incrementally(self): + # Feed the args char-by-char; each char-after-name should generate + # an args delta of length 1 (or close to it). + prefix = '{"name": "ping", "arguments": ' + suffix = '{"x": 42}}' + _, deltas = self._feed([prefix, *list(suffix)]) + + arg_deltas = [tc for d in deltas for tc in d.tool_calls if tc.function and tc.function.arguments is not None] + assert len(arg_deltas) >= 3 # incremental, not one big shot + joined = "".join(d.function.arguments or "" for d in arg_deltas) + assert json.loads(joined) == {"x": 42} + + def test_multiple_tool_calls_get_distinct_indices(self): + text = ( + '{"name": "a", "arguments": {"x": 1}}' + '{"name": "b", "arguments": {"y": 2}}' + ) + _, deltas = self._feed([text]) + + tool_deltas = [tc for d in deltas for tc in d.tool_calls] + indices = {d.index for d in tool_deltas} + assert indices == {0, 1} + names = [d.function.name for d in tool_deltas if d.function and d.function.name] + assert names == ["a", "b"] + + def test_content_after_tool_call_resumes_streaming(self): + text = '{"name": "p", "arguments": {}} ok' + _, deltas = self._feed([text]) + joined_content = "".join(d.content or "" for d in deltas) + assert "ok" in joined_content + + def test_partial_name_held_until_closing_quote(self): + # While the model is mid-name (``"name": "get_wea`` so far), the + # streamer must NOT send a partial name — wait for the closing quote. + streamer = ToolCallStreamer(HermesToolCallParser()) + partial = '{"name": "get_wea' + d = streamer.extract_streaming(partial) + # No name yet → no tool delta. + assert d is None or all(not tc.function or not tc.function.name for tc in d.tool_calls) + + def test_unterminated_block_is_finalized_without_crash(self): + # Model stops mid tool-call. 
The streamer should not raise on finalize
+        # and should not claim a finalized ToolCall (the call wasn't closed).
+        text = '<tool_call>{"name": "incomplete", "arguments": {"a": 1'
+        streamer, _ = self._feed([text])
+        assert streamer.result.tool_calls == []
+
+    def test_skipped_malformed_block_does_not_block_subsequent_finalization(self):
+        # If the first tool-call block is malformed (missing name), it should be
+        # skipped, but the SECOND block (if valid) should still be finalized.
+        text = (
+            '<tool_call>{"arguments": {"skipped": true}}</tool_call>'
+            '<tool_call>{"name": "valid", "arguments": {"x": 1}}</tool_call>'
+        )
+        streamer, _ = self._feed([text])
+        assert [tc.function.name for tc in streamer.result.tool_calls] == ["valid"]
+
+
+class TestResolveToolsForRequest:
+    tools: ClassVar = [
+        {"type": "function", "function": {"name": "alpha"}},
+        {"type": "function", "function": {"name": "beta"}},
+    ]
+
+    def test_no_tools_returns_none(self):
+        assert resolve_tools_for_request(None, "auto") is None
+        assert resolve_tools_for_request([], "auto") is None
+
+    def test_auto_passes_through(self):
+        assert resolve_tools_for_request(self.tools, "auto") == self.tools
+
+    def test_unset_tool_choice_passes_through(self):
+        assert resolve_tools_for_request(self.tools, None) == self.tools
+
+    def test_none_suppresses_tools(self):
+        assert resolve_tools_for_request(self.tools, "none") is None
+
+    def test_required_passes_through(self):
+        # We cannot strictly enforce a tool call without constrained decoding,
+        # so "required" downgrades to "auto" semantics (with a logged warning). 
+ assert resolve_tools_for_request(self.tools, "required") == self.tools + + def test_specific_function_filters_to_that_tool(self): + result = resolve_tools_for_request(self.tools, {"type": "function", "function": {"name": "beta"}}) + assert result is not None + assert len(result) == 1 + assert result[0]["function"]["name"] == "beta" + + def test_unknown_function_falls_back_to_all(self): + # If the named function isn't in the tools list, fall back to passing + # them all through rather than emitting an empty list. + result = resolve_tools_for_request(self.tools, {"type": "function", "function": {"name": "missing"}}) + assert result == self.tools + + def test_unrecognized_choice_falls_back_to_all(self): + result = resolve_tools_for_request(self.tools, "weird-mode") + assert result == self.tools diff --git a/tests/test_transformers_chat_tools.py b/tests/test_transformers_chat_tools.py new file mode 100644 index 0000000..003a909 --- /dev/null +++ b/tests/test_transformers_chat_tools.py @@ -0,0 +1,163 @@ +"""Tests for the Transformers chat path's tool-call handling. + +These tests bypass the real HF pipeline by injecting a callable that returns +a canned generation, so they run offline and do not touch any model weights. 
+""" + +from __future__ import annotations + +import json +from typing import Any + +import pytest + +from modelship.infer.infer_config import RawRequestProxy, TransformersConfig +from modelship.infer.transformers.capabilities import TransformersCapabilities +from modelship.infer.transformers.openai.serving_chat import OpenAIServingChat +from modelship.openai.protocol import ChatCompletionRequest, ChatCompletionResponse + + +class _FakeTokenizer: + def encode(self, text: str, add_special_tokens: bool = False) -> list[int]: + return [0] * len(text.split()) + + def apply_chat_template(self, messages: list[dict], **kwargs: Any) -> Any: + prompt = "\n".join(f"{m['role']}: {m.get('content', '')}" for m in messages) + if "tools" in kwargs: + prompt = f"[TOOLS:{len(kwargs['tools'])}]\n" + prompt + if kwargs.get("tokenize"): + return [0] * len(prompt.split()) + return prompt + + +class _FakePipeline: + """Stand-in for ``transformers.Pipeline`` that records calls and replays canned output.""" + + def __init__(self, generated_text: str): + self.tokenizer = _FakeTokenizer() + self.task = "text-generation" + self.generated_text = generated_text + self.last_input: Any = None + self.last_kwargs: dict[str, Any] = {} + + def __call__(self, inputs: Any, **kwargs: Any) -> list[dict]: + self.last_input = inputs + self.last_kwargs = kwargs + return [{"generated_text": self.generated_text}] + + +def _make_serving(generated: str) -> tuple[OpenAIServingChat, _FakePipeline]: + pipe = _FakePipeline(generated) + serving = OpenAIServingChat( + pipeline=pipe, # type: ignore[arg-type] + model_name="test-model", + config=TransformersConfig(), + capabilities=TransformersCapabilities(supports_image=False, supports_audio=False), + ) + return serving, pipe + + +def _raw_request() -> RawRequestProxy: + return RawRequestProxy(None, {}) + + +@pytest.mark.asyncio +async def test_response_without_tools_carries_content_only(): + serving, _ = _make_serving("hello there") + req = 
ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], stream=False)
+    resp = await serving.create_chat_completion(req, _raw_request())
+
+    assert isinstance(resp, ChatCompletionResponse)
+    msg = resp.choices[0].message
+    assert msg.content == "hello there"
+    assert msg.tool_calls == []
+    assert resp.choices[0].finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_tools_only_render_when_requested():
+    serving, pipe = _make_serving("hello")
+    req = ChatCompletionRequest(messages=[{"role": "user", "content": "hi"}], stream=False)
+    await serving.create_chat_completion(req, _raw_request())
+
+    # Without `tools` in the request, the pipeline receives the message list
+    # directly — no pre-rendered prompt.
+    assert isinstance(pipe.last_input, list)
+
+
+@pytest.mark.asyncio
+async def test_tools_in_request_pre_renders_prompt_and_parses_tool_call():
+    raw = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
+    serving, pipe = _make_serving(raw)
+    req = ChatCompletionRequest(
+        messages=[{"role": "user", "content": "weather in paris?"}],
+        tools=[
+            {
+                "type": "function",
+                "function": {"name": "get_weather", "parameters": {"type": "object"}},
+            }
+        ],
+        tool_choice="auto",
+        stream=False,
+    )
+    resp = await serving.create_chat_completion(req, _raw_request())
+
+    assert isinstance(resp, ChatCompletionResponse)
+    # Pre-rendered prompt is a string carrying the tool marker injected by our fake template. 
+    assert isinstance(pipe.last_input, str)
+    assert pipe.last_input.startswith("[TOOLS:1]")
+
+    msg = resp.choices[0].message
+    assert msg.content is None
+    assert len(msg.tool_calls) == 1
+    assert msg.tool_calls[0].function.name == "get_weather"
+    assert json.loads(msg.tool_calls[0].function.arguments) == {"city": "Paris"}
+    assert resp.choices[0].finish_reason == "tool_calls"
+
+
+@pytest.mark.asyncio
+async def test_tool_choice_none_skips_tool_rendering():
+    serving, pipe = _make_serving("regular reply")
+    req = ChatCompletionRequest(
+        messages=[{"role": "user", "content": "hi"}],
+        tools=[{"type": "function", "function": {"name": "noop"}}],
+        tool_choice="none",
+        stream=False,
+    )
+    resp = await serving.create_chat_completion(req, _raw_request())
+
+    # tool_choice="none" — pipeline should receive messages, not a rendered prompt.
+    assert isinstance(pipe.last_input, list)
+    assert isinstance(resp, ChatCompletionResponse)
+    assert resp.choices[0].message.content == "regular reply"
+    assert resp.choices[0].message.tool_calls == []
+    assert resp.choices[0].finish_reason == "stop"
+
+
+@pytest.mark.asyncio
+async def test_tool_call_with_trailing_text_preserves_content():
+    raw = 'Calling now.\n<tool_call>{"name": "ping", "arguments": {}}</tool_call>'
+    serving, _ = _make_serving(raw)
+    req = ChatCompletionRequest(
+        messages=[{"role": "user", "content": "ping?"}],
+        tools=[{"type": "function", "function": {"name": "ping"}}],
+        stream=False,
+    )
+    resp = await serving.create_chat_completion(req, _raw_request())
+
+    assert isinstance(resp, ChatCompletionResponse)
+    msg = resp.choices[0].message
+    assert msg.content == "Calling now." 
+ assert len(msg.tool_calls) == 1 + + +@pytest.mark.asyncio +async def test_unknown_parser_at_init_raises(): + pipe = _FakePipeline("anything") + with pytest.raises(ValueError, match="unknown tool_call_parser"): + OpenAIServingChat( + pipeline=pipe, # type: ignore[arg-type] + model_name="test-model", + config=TransformersConfig(tool_call_parser="not-a-real-parser"), + capabilities=TransformersCapabilities(supports_image=False, supports_audio=False), + )