Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions mlx_lm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import warnings
from collections import deque
from dataclasses import dataclass, replace
import threading
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from queue import Empty as QueueEmpty
Expand Down Expand Up @@ -1098,10 +1099,25 @@ def do_OPTIONS(self):
self._set_completion_headers(204)
self.end_headers()

# Embedding route state (used when --embedding-model is passed at startup).
# main() only records the model location in _embed_model_path; the model,
# tokenizer and reported model id stay None until _handle_embeddings
# lazy-loads them on the first /v1/embeddings request. These are class
# attributes, so every per-request handler instance sees the same values.
embedding_model = None
embedding_tokenizer = None
embedding_model_id = None
# HF repo or local path given via --embedding-model; None disables the route.
_embed_model_path = None
# Serialises the embedding forward pass across handler threads, since MLX
# inference is not thread-safe (see ml-explore/mlx#3078).
_embed_lock = threading.Lock()
# Guards the one-time lazy load so concurrent first requests load only once.
_embed_load_lock = threading.Lock()

def do_POST(self):
"""
Respond to a POST request from a client.
"""
if self.path == "/v1/embeddings":
return self._handle_embeddings()

request_factories = {
"/v1/completions": self.handle_text_completions,
"/v1/chat/completions": self.handle_chat_completions,
Expand Down Expand Up @@ -1617,6 +1633,129 @@ def handle_text_completions(self) -> CompletionRequest:
None,
)

def _handle_embeddings(self):
    """Handle ``POST /v1/embeddings`` using the ``mlx_embeddings`` package.

    The request body must be a JSON object whose ``input`` is a non-empty
    string or a non-empty list of strings. The response follows the
    OpenAI embeddings shape: ``object``, ``data`` (one entry per input,
    in order), ``model`` and ``usage``.

    The embedding model is loaded lazily on the first request so that
    deployments which never call this route do not pay the load cost.
    The forward pass runs under a class-level lock because MLX inference
    is not thread-safe under the default stream (ml-explore/mlx#3078);
    clients amortise the lock by sending batched ``input`` arrays.

    Every failure is reported to the client as a JSON ``{"error": ...}``
    body instead of an unhandled exception in the handler thread:

    * 404 - no ``--embedding-model`` was configured.
    * 400 - malformed ``Content-Length``, non-UTF-8 body, invalid JSON,
      or an ``input`` of the wrong type / empty.
    * 411 - missing or non-positive ``Content-Length``.
    * 500 - ``mlx_embeddings`` is not installed, the model failed to
      load, or inference raised.
    """

    def send_json(status, payload):
        # Single place for the status + headers + JSON body sequence.
        self._set_completion_headers(status)
        self.end_headers()
        self.wfile.write(json.dumps(payload).encode())

    def send_error(status, message):
        send_json(status, {"error": message})

    # Lazy load: the first request triggers the actual model load.
    if self.embedding_model is None:
        if self._embed_model_path is None:
            send_error(
                404,
                "No embedding model loaded - pass --embedding-model at server start.",
            )
            return
        with self._embed_load_lock:
            # Re-check under the lock: another handler thread may have
            # finished loading while this one was waiting.
            if APIHandler.embedding_model is None:
                try:
                    from mlx_embeddings import load as embed_load
                except ImportError:
                    send_error(
                        500,
                        "Embedding support requires `pip install mlx-embeddings`.",
                    )
                    return
                logging.info(
                    f"Lazy-loading embedding model: {self._embed_model_path}"
                )
                try:
                    loaded_model, loaded_tokenizer = embed_load(
                        self._embed_model_path
                    )
                except Exception as err:  # noqa: BLE001
                    # A bad repo id / path must not kill the handler
                    # thread without a response; the next request retries.
                    logging.exception("embedding model load failed")
                    send_error(500, f"Failed to load embedding model: {err}")
                    return
                # `embedding_model` doubles as the "ready" flag for the
                # unlocked check above, so publish it LAST: a concurrent
                # request must never observe a model without a tokenizer.
                APIHandler.embedding_tokenizer = loaded_tokenizer
                APIHandler.embedding_model_id = self._embed_model_path
                APIHandler.embedding_model = loaded_model
                logging.info("Embedding model loaded.")

    # Parse + validate body.
    try:
        content_length = int(self.headers.get("Content-Length", "0"))
    except ValueError:
        send_error(400, "Content-Length must be an integer.")
        return
    if content_length <= 0:
        send_error(411, "Content-Length required.")
        return
    try:
        body = json.loads(self.rfile.read(content_length).decode())
    except UnicodeDecodeError:
        send_error(400, "Request body must be valid UTF-8.")
        return
    except json.JSONDecodeError as err:
        send_error(400, f"Invalid JSON: {err}")
        return
    if not isinstance(body, dict):
        send_error(400, "Request body must be a JSON object.")
        return

    raw = body.get("input")
    if isinstance(raw, str):
        texts = [raw]
    elif isinstance(raw, list) and all(isinstance(x, str) for x in raw):
        texts = list(raw)
    else:
        send_error(400, "input must be a string or list of strings.")
        return
    if not texts:
        send_error(400, "input must be non-empty.")
        return

    # Inference. Loop per-text and serialise via the lock.
    # Pooled-vector priority: text_embeds -> pooler_output -> mean of
    # last_hidden_state over axis 1. Explicit `is None` checks because
    # mx.array doesn't implement scalar truthiness.
    try:
        vectors = []
        with self._embed_lock:
            model = self.embedding_model
            tokenizer = self.embedding_tokenizer
            for text in texts:
                inputs = tokenizer.encode(text, return_tensors="mlx")
                output = model(inputs)
                vec = getattr(output, "text_embeds", None)
                if vec is None:
                    vec = getattr(output, "pooler_output", None)
                if vec is None:
                    vec = output.last_hidden_state.mean(axis=1)
                # Force evaluation while still holding the lock so the
                # lazy MLX graph is not computed concurrently later.
                mx.eval(vec)
                vectors.append(vec.flatten().tolist())
    except Exception as err:  # noqa: BLE001
        logging.exception("embedding inference failed")
        send_error(500, str(err))
        return

    # OpenAI-compat usage stats: a cheap chars/4 estimate per input,
    # with a floor of one token each.
    prompt_tokens = sum(max(1, len(t) // 4) for t in texts)
    response = {
        "object": "list",
        "data": [
            {"object": "embedding", "embedding": v, "index": i}
            for i, v in enumerate(vectors)
        ],
        # Always echo the loaded model name; any client-supplied
        # `model` field is ignored so arbitrary client strings are
        # never reflected back.
        "model": self.embedding_model_id or "embedding-model",
        "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens},
    }
    send_json(200, response)
    self.wfile.flush()

def do_GET(self):
"""
Respond to a GET request from a client.
Expand Down Expand Up @@ -1884,6 +2023,18 @@ def main():
action="store_true",
help="Use pipelining instead of tensor parallelism",
)
parser.add_argument(
"--embedding-model",
type=str,
default=None,
help=(
"Optional embedding model HF repo or local path. When set, "
"the server exposes a POST /v1/embeddings endpoint backed by "
"the `mlx_embeddings` package (must be installed separately). "
"Lazy-loaded on first request; not loaded when the flag is "
"omitted, so chat-only deployments pay no cost."
),
)
args = parser.parse_args()
if mx.metal.is_available():
wired_limit = mx.device_info()["max_recommended_working_set_size"]
Expand All @@ -1893,6 +2044,17 @@ def main():
level=getattr(logging, args.log_level.upper(), None),
format="%(asctime)s - %(levelname)s - %(message)s",
)

if args.embedding_model:
# Defer the actual load to the first /v1/embeddings request
# (see APIHandler._handle_embeddings) so chat-only sessions
# don't pay the load cost and don't contend GPU memory until
# an embed request actually arrives.
APIHandler._embed_model_path = args.embedding_model
logging.info(
f"Embedding model registered for lazy load: {args.embedding_model}"
)

run(args.host, args.port, ModelProvider(args))


Expand Down