20 changes: 20 additions & 0 deletions docs/model-configuration.md
@@ -219,6 +219,7 @@ The `transformers` loader uses PyTorch with HuggingFace Transformers. Supports c
| `trust_remote_code` | bool | `false` | Allow remote code execution |
| `model_kwargs` | object | `{}` | Extra keyword arguments passed to the model constructor |
| `pipeline_kwargs` | object | `{}` | Extra keyword arguments passed to the pipeline at inference time |
| `tool_call_parser` | string | `hermes` | Parser used to turn raw model output into OpenAI `tool_calls`. Currently supported: `hermes` (Hermes-2-Pro / Qwen2.5-Instruct / many community fine-tunes that emit `<tool_call>{...}</tool_call>` markers). |
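
For reference, the `hermes` parser looks for tool calls emitted as JSON inside `<tool_call>` markers. A raw completion from such a model might look like this (the function name and arguments are illustrative):

```text
<tool_call>
{"name": "get_weather", "arguments": {"location": "Paris"}}
</tool_call>
```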

### Chat / Text Generation (CPU)

@@ -233,6 +234,25 @@ models:
device: "cpu"
```

### Chat with Tool Calling (CPU)

The transformers loader renders `tools` into the prompt via the model's chat
template and parses the output back into OpenAI `tool_calls`. The model must
have been trained on a Hermes-style tool format (as with Qwen2.5-Instruct,
Hermes-2, and many community fine-tunes); the parser is selected via
`tool_call_parser`.

```yaml
models:
- name: qwen-tools
model: Qwen/Qwen2.5-0.5B-Instruct
usecase: generate
loader: transformers
num_cpus: 2
transformers_config:
device: "cpu"
tool_call_parser: hermes # this is the default; shown for clarity
```
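
Once the model is serving, tool definitions are passed in the standard OpenAI request shape. A minimal sketch using the `openai` Python client (the base URL, port, and tool definition are illustrative placeholders, not from this repo):

```python
from openai import OpenAI

# Point the client at the locally served model (URL/port are illustrative).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

response = client.chat.completions.create(
    model="qwen-tools",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=[
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ],
)

# When the model decides to call the tool, finish_reason is "tool_calls"
# and the arguments arrive as a JSON string on the message.
print(response.choices[0].finish_reason)
print(response.choices[0].message.tool_calls)
```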

### Speech-to-Text (CPU)

Audio is automatically decoded and resampled to the model's expected sample rate (e.g. 16kHz for Whisper).
152 changes: 109 additions & 43 deletions modelship/infer/transformers/openai/serving_chat.py
@@ -3,6 +3,7 @@
import time
from collections.abc import AsyncGenerator
from threading import Thread
from typing import Any

from transformers import Pipeline, PreTrainedTokenizerBase, TextIteratorStreamer

@@ -23,6 +24,7 @@
UsageInfo,
create_error_response,
)
from modelship.openai.tool_calling import ParsedToolCalls, ToolCallStreamer, get_parser, resolve_tools_for_request
from modelship.utils import base_request_id

logger = get_logger("infer.transformers.chat")
@@ -45,12 +47,16 @@ def __init__(
assert pipeline.tokenizer is not None, "text-generation pipeline must have a tokenizer"
self.tokenizer: PreTrainedTokenizerBase = pipeline.tokenizer
self._lock = asyncio.Lock()
# Validate the configured parser at startup so misconfiguration surfaces
# before the first request rather than mid-generation.
get_parser(self.config.tool_call_parser)

async def warmup(self) -> None:
logger.info("Warming up chat model: %s", self.model_name)
await self.run_in_executor(
self._run,
[{"role": "user", "content": "warmup"}],
None,
1,
)
logger.info("Warmup chat done for %s", self.model_name)
@@ -74,34 +80,41 @@ async def create_chat_completion(
logger.warning("chat request %s rejected: %s", request_id, e)
return create_error_response(e)

tools = resolve_tools_for_request(request.tools, request.tool_choice)

max_tokens = request.max_tokens
if max_tokens is None and request.max_completion_tokens is not None:
max_tokens = request.max_completion_tokens

if request.stream:
include_usage = bool(request.stream_options and request.stream_options.include_usage)
return self._locked_stream(request_id, messages, max_tokens, include_usage=include_usage)
return self._locked_stream(request_id, messages, tools, max_tokens, include_usage=include_usage)

async with self._lock:
try:
result = await self.run_in_executor(self._run, messages, max_tokens)
result = await self.run_in_executor(self._run, messages, tools, max_tokens)
except Exception:
logger.exception("chat completion inference failed for %s", request_id)
return create_error_response("chat completion inference failed")

prompt_tokens = self._count_prompt_tokens(messages)
generated = result[0]["generated_text"]
completion_text = generated[-1]["content"] if isinstance(generated, list) else generated
prompt_tokens = self._count_prompt_tokens(messages, tools)
completion_text = self._extract_completion_text(result)
completion_tokens = len(self.tokenizer.encode(completion_text, add_special_tokens=False))
finish_reason = "length" if (max_tokens is not None and completion_tokens >= max_tokens) else "stop"

parsed = self._parse_tool_calls(completion_text) if tools else ParsedToolCalls(completion_text, [])
finish_reason = self._finish_reason(parsed, completion_tokens, max_tokens)

response = ChatCompletionResponse(
id=request_id,
model=self.model_name,
choices=[
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=completion_text),
message=ChatMessage(
role="assistant",
content=parsed.content,
tool_calls=parsed.tool_calls,
),
finish_reason=finish_reason,
)
],
@@ -114,54 +127,102 @@ async def create_chat_completion(
)
logger.log(
TRACE,
"chat response %s: text=%r, prompt_tokens=%d, completion_tokens=%d",
"chat response %s: text=%r, tool_calls=%d, prompt_tokens=%d, completion_tokens=%d",
request_id,
completion_text,
len(parsed.tool_calls),
prompt_tokens,
completion_tokens,
)
return response

def _count_prompt_tokens(self, messages: list[dict]) -> int:
def _parse_tool_calls(self, text: str) -> ParsedToolCalls:
return get_parser(self.config.tool_call_parser).parse(text)

@staticmethod
def _finish_reason(parsed: ParsedToolCalls, completion_tokens: int, max_tokens: int | None) -> str:
if parsed.has_tool_calls:
return "tool_calls"
if max_tokens is not None and completion_tokens >= max_tokens:
return "length"
return "stop"

@staticmethod
def _extract_completion_text(result: list) -> str:
generated = result[0]["generated_text"]
if isinstance(generated, list):
return generated[-1]["content"]
return generated

def _count_prompt_tokens(self, messages: list[dict], tools: list[dict[str, Any]] | None) -> int:
# apply_chat_template returns a string by default (character count!) — force tokenize=True.
token_ids = self.tokenizer.apply_chat_template(messages, tokenize=True)
kwargs: dict[str, Any] = {"tokenize": True}
if tools:
kwargs["tools"] = tools
token_ids = self.tokenizer.apply_chat_template(messages, **kwargs)
return len(token_ids)

def _run(self, messages: list[dict], max_tokens: int | None) -> list:
def _render_prompt(self, messages: list[dict], tools: list[dict[str, Any]]) -> str:
rendered = self.tokenizer.apply_chat_template(
messages,
tools=tools, # type: ignore[arg-type]
tokenize=False,
add_generation_prompt=True,
)
assert isinstance(rendered, str), "apply_chat_template(tokenize=False) must return str"
return rendered

def _run(self, messages: list[dict], tools: list[dict[str, Any]] | None, max_tokens: int | None) -> list:
kwargs = {**self.config.pipeline_kwargs}
if max_tokens is not None:
kwargs["max_new_tokens"] = max_tokens
if tools:
# The standard text-generation pipeline does not forward `tools` to
# `apply_chat_template`, so we render the prompt ourselves and feed
# it as a plain string.
prompt = self._render_prompt(messages, tools)
return self.pipeline(prompt, return_full_text=False, **kwargs) # type: ignore[return-value]
return self.pipeline(messages, return_full_text=False, **kwargs) # type: ignore[return-value]

async def _locked_stream(
self,
request_id: str,
messages: list[dict],
tools: list[dict[str, Any]] | None,
max_tokens: int | None,
*,
include_usage: bool,
) -> AsyncGenerator[str, None]:
async with self._lock:
async for chunk in self._stream(request_id, messages, max_tokens, include_usage=include_usage):
async for chunk in self._stream(request_id, messages, tools, max_tokens, include_usage=include_usage):
yield chunk

async def _stream(
self,
request_id: str,
messages: list[dict],
tools: list[dict[str, Any]] | None,
max_tokens: int | None,
*,
include_usage: bool,
) -> AsyncGenerator[str, None]:
created_time = int(time.time())
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True) # type: ignore[arg-type]

kwargs = {**self.config.pipeline_kwargs}
if max_tokens is not None:
kwargs["max_new_tokens"] = max_tokens

if tools:
prompt: Any = self._render_prompt(messages, tools)
tool_call_streamer: ToolCallStreamer | None = ToolCallStreamer(get_parser(self.config.tool_call_parser))
else:
prompt = messages
tool_call_streamer = None

thread = Thread(
target=self.pipeline,
args=(messages,),
args=(prompt,),
kwargs={
"streamer": streamer,
"return_full_text": False,
@@ -173,42 +234,35 @@ async def _stream(
accumulated: list[str] = []
try:
# Per OpenAI spec, the first delta carries `role` only.
yield self._encode_chunk(
ChatCompletionStreamResponse(
id=request_id,
model=self.model_name,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant"),
)
],
created=int(time.time()),
)
)
yield self._delta_chunk(request_id, DeltaMessage(role="assistant"), created_time)

for text_chunk in streamer:
if not text_chunk:
continue
accumulated.append(text_chunk)
yield self._encode_chunk(
ChatCompletionStreamResponse(
id=request_id,
model=self.model_name,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(content=text_chunk),
)
],
created=int(time.time()),
)
)
if tool_call_streamer is None:
yield self._delta_chunk(request_id, DeltaMessage(content=text_chunk), created_time)
await asyncio.sleep(0)
continue
# Re-parse the cumulative text and emit any new content / tool-call
# fragments. A tail bounded by the length of the start marker is held
# back so the client never sees a half-formed `<tool_call>` opening tag.
delta = tool_call_streamer.extract_streaming("".join(accumulated))
if delta is not None:
yield self._delta_chunk(request_id, delta, created_time)
await asyncio.sleep(0)

if tool_call_streamer is not None:
final = tool_call_streamer.finalize()
if final is not None:
yield self._delta_chunk(request_id, final, created_time)
parsed = tool_call_streamer.result
else:
parsed = ParsedToolCalls("".join(accumulated), [])

completion_text = "".join(accumulated)
completion_tokens = len(self.tokenizer.encode(completion_text, add_special_tokens=False))
finish_reason = "length" if (max_tokens is not None and completion_tokens >= max_tokens) else "stop"
finish_reason = self._finish_reason(parsed, completion_tokens, max_tokens)

yield self._encode_chunk(
ChatCompletionStreamResponse(
@@ -221,12 +275,12 @@
finish_reason=finish_reason,
)
],
created=int(time.time()),
created=created_time,
)
)

if include_usage:
prompt_tokens = self._count_prompt_tokens(messages)
prompt_tokens = self._count_prompt_tokens(messages, tools)
yield self._encode_chunk(
ChatCompletionStreamResponse(
id=request_id,
@@ -237,14 +291,26 @@
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
created=int(time.time()),
created=created_time,
)
)

yield "data: [DONE]\n\n"
finally:
thread.join()

def _delta_chunk(self, request_id: str, delta: DeltaMessage, created_time: int) -> str:
return self._encode_chunk(
ChatCompletionStreamResponse(
id=request_id,
model=self.model_name,
choices=[
ChatCompletionResponseStreamChoice(index=0, delta=delta),
],
created=created_time,
)
)

@staticmethod
def _encode_chunk(chunk: ChatCompletionStreamResponse) -> str:
return f"data: {json.dumps(chunk.model_dump(mode='json'))}\n\n"
2 changes: 1 addition & 1 deletion modelship/openai/protocol.py
@@ -180,7 +180,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
temperature: float | None = None
top_p: float | None = None
tools: list[dict[str, Any]] | None = None
tool_choice: str | dict[str, Any] | None = "none"
tool_choice: str | dict[str, Any] | None = None
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None
parallel_tool_calls: bool | None = True
user: str | None = None
22 changes: 22 additions & 0 deletions modelship/openai/tool_calling/__init__.py
@@ -0,0 +1,22 @@
"""Cross-loader tool-calling toolkit.

Loaders without native tool-call support (Transformers today, plugin-wrapped
raw-text engines tomorrow) use the parsers and helpers in this package to
turn raw model output into OpenAI-shape ``tool_calls``. Loaders whose engines
already emit structured tool calls (vLLM, llama.cpp via a function-calling
chat handler) bypass it.
"""

from modelship.openai.tool_calling.input import resolve_tools_for_request
from modelship.openai.tool_calling.parsers import ParsedToolCalls, ToolCallParser, ToolCallStreamer
from modelship.openai.tool_calling.registry import available_parsers, get_parser, register_parser

__all__ = [
"ParsedToolCalls",
"ToolCallParser",
"ToolCallStreamer",
"available_parsers",
"get_parser",
"register_parser",
"resolve_tools_for_request",
]
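
A minimal sketch of the intended call pattern, mirroring the usage in the Transformers chat server (the raw model output below is illustrative):

```python
from modelship.openai.tool_calling import ToolCallStreamer, get_parser

parser = get_parser("hermes")

# Non-streaming: parse the full completion once generation finishes.
raw = '<tool_call>\n{"name": "get_weather", "arguments": {"location": "Paris"}}\n</tool_call>'
parsed = parser.parse(raw)
print(parsed.has_tool_calls)  # True when a well-formed call was found
print(parsed.content)         # the non-tool-call text content
print(parsed.tool_calls)      # OpenAI-shape tool_calls

# Streaming: re-parse the cumulative text as tokens arrive; extract_streaming
# returns None while it is holding back a possible half-formed marker.
streamer = ToolCallStreamer(parser)
accumulated = ""
for chunk in ('<tool_', 'call>\n{"name": "get_weather", '):
    accumulated += chunk
    delta = streamer.extract_streaming(accumulated)
    if delta is not None:
        ...  # forward the DeltaMessage to the client
final = streamer.finalize()  # flush anything still held back
parsed = streamer.result     # ParsedToolCalls, used for finish_reason logic
```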