From 82e71bf1b56104efa9fcae7d73630872357676ca Mon Sep 17 00:00:00 2001 From: andreinknv Date: Sat, 9 May 2026 15:38:26 -0400 Subject: [PATCH] feat(server): add /v1/embeddings route via mlx_embeddings Adds an optional POST /v1/embeddings route to mlx_lm.server backed by the mlx_embeddings package. Enables a single mlx_lm.server process to serve both OpenAI-compatible chat AND embeddings. When --embedding-model is omitted, server behavior is unchanged. Chat code path is not modified, so chat throughput is identical before/after. Embedding model is lazy-loaded on first request. Inference is serialised under a class-level lock because MLX inference is not thread-safe under the default stream (see ml-explore/mlx#3078). 162 lines added, 0 deletions, all in mlx_lm/server.py. --- mlx_lm/server.py | 162 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 162 insertions(+) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index ce8d95817..f2dcdffa4 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -11,6 +11,7 @@ import warnings from collections import deque from dataclasses import dataclass, replace +import threading from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from pathlib import Path from queue import Empty as QueueEmpty @@ -1098,10 +1099,25 @@ def do_OPTIONS(self): self._set_completion_headers(204) self.end_headers() + # Embedding route (when --embedding-model was passed at startup). + # Set once in main() before the server starts; per-request handler + # threads read the resolved model + tokenizer, serialised by + # _embed_lock since MLX inference is not thread-safe (see + # ml-explore/mlx#3078). + embedding_model = None + embedding_tokenizer = None + embedding_model_id = None + _embed_model_path = None + _embed_lock = threading.Lock() + _embed_load_lock = threading.Lock() + def do_POST(self): """ Respond to a POST request from a client. """ + if self.path == "/v1/embeddings": + return self._handle_embeddings() + request_factories = { "/v1/completions": self.handle_text_completions, "/v1/chat/completions": self.handle_chat_completions, @@ -1617,6 +1633,129 @@ def handle_text_completions(self) -> CompletionRequest: None, ) + def _handle_embeddings(self): + """Handle POST /v1/embeddings via mlx_embeddings. + + Uses a per-class lock around the forward pass because MLX + inference is not thread-safe under the default stream + (ml-explore/mlx#3078). The lock cost is amortised by clients + sending batched `input` arrays. + + Lazy-loads on first request so the ~1-3s ONNX/MLX session + init is only paid by users who actually call the route. + """ + # Lazy load: first request triggers the actual model load. + if self.embedding_model is None: + if self._embed_model_path is None: + self._set_completion_headers(404) + self.end_headers() + self.wfile.write( + b'{"error": "No embedding model loaded - pass --embedding-model at server start."}' + ) + return + with self._embed_load_lock: + if APIHandler.embedding_model is None: + try: + from mlx_embeddings import load as embed_load + except ImportError: + self._set_completion_headers(500) + self.end_headers() + self.wfile.write( + b'{"error": "Embedding support requires `pip install mlx-embeddings`."}' + ) + return + logging.info( + f"Lazy-loading embedding model: {self._embed_model_path}" + ) + m, t = embed_load(self._embed_model_path) + APIHandler.embedding_model = m + APIHandler.embedding_tokenizer = t + APIHandler.embedding_model_id = self._embed_model_path + logging.info("Embedding model loaded.") + + # Parse + validate body. + content_length = int(self.headers.get("Content-Length", "0")) + if content_length <= 0: + self._set_completion_headers(411) + self.end_headers() + self.wfile.write(b'{"error": "Content-Length required."}') + return + try: + body = json.loads(self.rfile.read(content_length).decode()) + except json.JSONDecodeError as err: + self._set_completion_headers(400) + self.end_headers() + self.wfile.write(json.dumps({"error": f"Invalid JSON: {err}"}).encode()) + return + if not isinstance(body, dict): + self._set_completion_headers(400) + self.end_headers() + self.wfile.write(b'{"error": "Request body must be a JSON object."}') + return + + raw = body.get("input") + if isinstance(raw, str): + texts = [raw] + elif isinstance(raw, list) and all(isinstance(x, str) for x in raw): + texts = list(raw) + else: + self._set_completion_headers(400) + self.end_headers() + self.wfile.write(b'{"error": "input must be a string or list of strings."}') + return + if not texts: + self._set_completion_headers(400) + self.end_headers() + self.wfile.write(b'{"error": "input must be non-empty."}') + return + + # Inference. Loop per-text and serialise via the lock. + # Pooled-vector priority: text_embeds (bi-encoder MLX models) → + # pooler_output (BERT-family [CLS]) → mean-pooled + # last_hidden_state. Explicit `is None` checks because + # mx.array doesn't implement scalar truthiness. + try: + vectors = [] + with self._embed_lock: + for text in texts: + inputs = self.embedding_tokenizer.encode( + text, return_tensors="mlx" + ) + output = self.embedding_model(inputs) + vec = getattr(output, "text_embeds", None) + if vec is None: + vec = getattr(output, "pooler_output", None) + if vec is None: + vec = output.last_hidden_state.mean(axis=1) + mx.eval(vec) + vectors.append(vec.flatten().tolist()) + except Exception as err: # noqa: BLE001 + logging.exception("embedding inference failed") + self._set_completion_headers(500) + self.end_headers() + self.wfile.write(json.dumps({"error": str(err)}).encode()) + return + + # OpenAI-compat usage stats. chars/4 is the existing + # convention used elsewhere in this server for cheap accounting. + prompt_tokens = sum(max(1, len(t) // 4) for t in texts) + response = { + "object": "list", + "data": [ + {"object": "embedding", "embedding": v, "index": i} + for i, v in enumerate(vectors) + ], + # Always echo the loaded model name; ignore any client- + # supplied `model` field (mirrors OpenAI behavior and + # avoids reflecting arbitrary client strings). + "model": self.embedding_model_id or "embedding-model", + "usage": {"prompt_tokens": prompt_tokens, "total_tokens": prompt_tokens}, + } + self._set_completion_headers(200) + self.end_headers() + self.wfile.write(json.dumps(response).encode()) + self.wfile.flush() + def do_GET(self): """ Respond to a GET request from a client. @@ -1884,6 +2023,18 @@ def main(): action="store_true", help="Use pipelining instead of tensor parallelism", ) + parser.add_argument( + "--embedding-model", + type=str, + default=None, + help=( + "Optional embedding model HF repo or local path. When set, " + "the server exposes a POST /v1/embeddings endpoint backed by " + "the `mlx_embeddings` package (must be installed separately). " + "Lazy-loaded on first request; not loaded when the flag is " + "omitted, so chat-only deployments pay no cost." + ), + ) args = parser.parse_args() if mx.metal.is_available(): wired_limit = mx.device_info()["max_recommended_working_set_size"] @@ -1893,6 +2044,17 @@ def main(): level=getattr(logging, args.log_level.upper(), None), format="%(asctime)s - %(levelname)s - %(message)s", ) + + if args.embedding_model: + # Defer the actual load to the first /v1/embeddings request + # (see APIHandler._handle_embeddings) so chat-only sessions + # don't pay the load cost and don't contend GPU memory until + # an embed request actually arrives. + APIHandler._embed_model_path = args.embedding_model + logging.info( + f"Embedding model registered for lazy load: {args.embedding_model}" + ) + run(args.host, args.port, ModelProvider(args))