diff --git a/Cargo.lock b/Cargo.lock index 3acda9a73e6a..663e2c024271 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2433,6 +2433,7 @@ dependencies = [ "dynamo-mocker", "dynamo-parsers", "dynamo-protocols", + "dynamo-rl", "dynamo-runtime", "dynamo-tokenizers", "dynamo-tokens", @@ -2606,6 +2607,20 @@ dependencies = [ "uuid", ] +[[package]] +name = "dynamo-rl" +version = "1.2.0" +dependencies = [ + "anyhow", + "axum 0.8.4", + "dynamo-runtime", + "futures", + "serde", + "serde_json", + "tokio", + "tracing", +] + [[package]] name = "dynamo-runtime" version = "1.2.0" diff --git a/Cargo.toml b/Cargo.toml index d7d2f9fd02f6..6cc5e3d7a270 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ members = [ "lib/backend-common/examples/mocker", "lib/bindings/c", "lib/bindings/python/codegen", + "lib/rl", ] resolver = "3" @@ -41,6 +42,7 @@ keywords = ["llm", "genai", "inference", "nvidia", "distributed"] # Local crates dynamo-runtime = { path = "lib/runtime", version = "1.2.0" } dynamo-llm = { path = "lib/llm", version = "1.2.0" } +dynamo-rl = { path = "lib/rl", version = "1.2.0" } dynamo-config = { path = "lib/config", version = "1.2.0" } dynamo-tokenizers = { path = "lib/tokenizers", version = "1.2.0" } dynamo-tokens = { path = "lib/tokens", version = "1.2.0" } diff --git a/components/src/dynamo/frontend/frontend_args.py b/components/src/dynamo/frontend/frontend_args.py index 2040a19ec874..5517386e1eb3 100644 --- a/components/src/dynamo/frontend/frontend_args.py +++ b/components/src/dynamo/frontend/frontend_args.py @@ -56,6 +56,7 @@ class FrontendConfig(RouterConfigBase, KvRouterConfigBase, AicPerfConfigBase): kv_cache_block_size: Optional[int] http_host: str http_port: int + rl_port: int tls_cert_path: Optional[pathlib.Path] tls_key_path: Optional[pathlib.Path] @@ -97,6 +98,8 @@ def validate(self) -> None: raise ValueError( f"--migration-limit must be between 0 and {_U32_MAX} (0=disabled)" ) + if self.rl_port < 0 or self.rl_port > 65535: + raise ValueError("--rl-port must be between 0 and 65535") if self.migration_max_seq_len is not None and ( self.migration_max_seq_len < 1 or self.migration_max_seq_len > _U32_MAX ): @@ -208,6 +211,14 @@ def add_arguments(self, parser) -> None: help="HTTP port for the engine (u16).", arg_type=int, ) + add_argument( + g, + flag_name="--rl-port", + env_var="DYN_RL_PORT", + default=8002, + help="Dedicated HTTP port for RL admin endpoints (u16).", + arg_type=int, + ) add_negatable_bool_argument( g, flag_name="--serve-indexer", diff --git a/components/src/dynamo/frontend/main.py b/components/src/dynamo/frontend/main.py index 361d16ffd25c..a54845d4a558 100644 --- a/components/src/dynamo/frontend/main.py +++ b/components/src/dynamo/frontend/main.py @@ -237,6 +237,7 @@ def signal_handler(): kwargs: dict[str, Any] = { "http_host": config.http_host, "http_port": config.http_port, + "rl_port": config.rl_port, "kv_cache_block_size": config.kv_cache_block_size, "router_config": router_config, "migration_limit": config.migration_limit, diff --git a/components/src/dynamo/frontend/vllm_processor.py b/components/src/dynamo/frontend/vllm_processor.py index 8962bfd02ec5..df29f1bd6dd6 100644 --- a/components/src/dynamo/frontend/vllm_processor.py +++ b/components/src/dynamo/frontend/vllm_processor.py @@ -633,6 +633,40 @@ async def _generate_and_stream( break choice = post.process_output(output) if choice: + # ── RL logprobs injection ────────────────────── + # The vLLM worker sends log_probs/top_logprobs in + # the engine_response dict. 
Since we can't easily + # construct LogprobsLists for EngineCoreOutput, we + # inject them directly into the choice here. + worker_log_probs = engine_response.get("log_probs") + worker_top_logprobs = engine_response.get("top_logprobs") + if worker_log_probs is not None and choice.get("logprobs") is None: + oai_logprobs_content = [] + new_tids = engine_response.get("token_ids", []) + for i, lp in enumerate(worker_log_probs): + # Always populate token/bytes so consumers never see a + # missing key. If top_logprobs is absent or the token + # string cannot be resolved we fall back to the numeric + # ID as a string — better than a KeyError / silent None. + tid_str = str(new_tids[i]) if i < len(new_tids) else "" + entry: dict = { + "logprob": lp, + "token": tid_str, + "bytes": None, + } + # Resolve the human-readable token string and top_logprobs + # from the engine's top_logprobs table when available. + if worker_top_logprobs and i < len(worker_top_logprobs): + tops = worker_top_logprobs[i] + entry["top_logprobs"] = tops + if i < len(new_tids): + for tp in tops: + if tp.get("token_id") == new_tids[i]: + entry["token"] = tp.get("token", tid_str) + break + oai_logprobs_content.append(entry) + choice["logprobs"] = {"content": oai_logprobs_content} + choices.append(choice) if choices: @@ -646,6 +680,11 @@ async def _generate_and_stream( if usage := engine_response.get("completion_usage"): dynamo_out["usage"] = usage + # ── RL: pass output token IDs for nvext.completion_token_ids ── + new_token_ids = engine_response.get("token_ids", []) + if new_token_ids: + dynamo_out["_completion_token_ids"] = new_token_ids + yield dynamo_out _nvtx.end_range(rng_stream) except Exception as e: diff --git a/components/src/dynamo/vllm/handlers.py b/components/src/dynamo/vllm/handlers.py index 46964076b8c8..37c8a42c7706 100644 --- a/components/src/dynamo/vllm/handlers.py +++ b/components/src/dynamo/vllm/handlers.py @@ -13,7 +13,7 @@ from abc import ABC, abstractmethod from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Any, AsyncIterator, Dict, Final, Generic, Optional, TypeVar +from typing import Any, AsyncIterator, Dict, Final, Generic, NoReturn, Optional, TypeVar import torch from vllm.config import ModelConfig, VllmConfig @@ -568,6 +568,15 @@ def __init__( # Store shutdown event for graceful shutdown monitoring self.shutdown_event = shutdown_event + def _shutdown_on_engine_dead(self, e: EngineDeadError) -> NoReturn: + """Common handler for `EngineDeadError`: log, shut down the runtime, + hard-exit. Called from RL admin handler `except` clauses so a dead + engine surfaces as a worker restart instead of silent failure.""" + logger.error(f"vLLM EngineDeadError: {e}") + logger.warning("Initiating Dynamo Runtime shutdown.") + self.runtime.shutdown() + os._exit(1) + def init_embedding_loader( self, config: Config, encode_worker_client: Optional[Client] = None ) -> Optional[MultiModalEmbeddingLoader]: @@ -825,6 +834,678 @@ async def stop_profile(self, body: dict) -> dict: logger.error(f"Failed to stop profiling: {e}") return {"status": "error", "message": str(e)} + # ── RL weight lifecycle engine routes ────────────────────────────── + + async def pause_generation(self, body: dict) -> dict: + """Pause the engine: drain in-flight requests, keep model loaded. + + Called by the RL admin coordinator before weight updates. + Uses engine_client.pause_generation() directly -- does NOT sleep + (no GPU memory release) and does NOT unregister from discovery. 
+ + Body (all optional): + - mode: "keep" | "wait" | "abort" (default "keep" — drain in-flight) + - clear_cache: bool (default False) + """ + body = body or {} + mode = body.get("mode", "keep") + clear_cache = bool(body.get("clear_cache", False)) + if mode not in ("keep", "wait", "abort"): + return { + "status": "error", + "message": f"Invalid mode '{mode}'; expected one of keep|wait|abort", + } + try: + await self.engine_client.pause_generation() + # mode=abort → also abort in-flight requests via vLLM's request abort + if mode == "abort": + try: + # Best-effort abort of all in-flight requests. + # vLLM exposes per-request abort; we don't track ids here so + # rely on engine internals to drain the rest under pause. + await self.engine_client.collective_rpc( + "abort_all_requests", kwargs={} + ) + except Exception as abort_err: + logger.warning( + f"[RL] mode=abort: collective_rpc(abort_all_requests) " + f"unavailable on this engine version: {abort_err}; " + f"in-flight requests will drain naturally" + ) + if clear_cache: + try: + await self.engine_client.reset_prefix_cache() + logger.debug("[RL] pause: prefix cache cleared") + except Exception as flush_err: + logger.warning( + f"[RL] pause: clear_cache requested but reset_prefix_cache failed: {flush_err}" + ) + self._paused = True + logger.debug( + f"[RL] Engine paused (generation quiesced, mode={mode}, clear_cache={clear_cache})" + ) + return { + "status": "ok", + "message": "Engine paused", + "mode": mode, + "clear_cache": clear_cache, + } + except EngineDeadError as e: + self._shutdown_on_engine_dead(e) + except Exception as e: + logger.error(f"[RL] Failed to pause: {e}") + return {"status": "error", "message": str(e)} + + async def resume_generation(self, body: dict) -> dict: + """Resume the engine after a weight update.""" + body = body or {} + try: + await self.engine_client.resume_generation() + self._paused = False + logger.info("[RL] Engine resumed") + return {"status": "ok", "message": "Engine resumed"} + except EngineDeadError as e: + self._shutdown_on_engine_dead(e) + except Exception as e: + logger.error(f"[RL] Failed to resume: {e}") + return {"status": "error", "message": str(e)} + + async def liveness_probe(self, body: dict) -> dict: + """Engine event-loop probe — confirms the engine is responsive. + + Used by ``GET /v1/rl/liveness``. The Rust frontend fans this out with a + short timeout (default 5s). Returning ``alive: True`` requires the + engine_client IPC roundtrip to complete: a hung event loop, deadlocked + worker, or wedged engine will time out at the frontend instead of + returning a stale ``OK`` (which is what the legacy ``/v1/rl/health`` + does — that endpoint is just a frontend-process check). + """ + body = body or {} + try: + # vLLM's AsyncLLM/AsyncEngineClient exposes check_health() as the + # canonical liveness probe. It does a lightweight collective RPC + # to all engine workers and raises if any are unresponsive. + if hasattr(self.engine_client, "check_health"): + await self.engine_client.check_health() + return {"status": "ok", "alive": True} + # Fallback for engines without check_health: a no-op collective_rpc. + # The RPC round-trip itself is the liveness signal — if the engine + # event loop is wedged the frontend's 5s timeout fires. 
+ await self.engine_client.collective_rpc("get_weight_version", kwargs={}) + return {"status": "ok", "alive": True} + except EngineDeadError as e: + self._shutdown_on_engine_dead(e) + except Exception as e: + logger.warning(f"[RL] liveness_probe failed: {e}") + return {"status": "error", "alive": False, "message": str(e)} + + async def get_state(self, body: dict) -> dict: + """Composite per-worker state snapshot for ``GET /v1/rl/state``. + + The Rust frontend aggregates these per-worker payloads into the + fleet-wide ``RlStateResponse`` — a single composite that replaces the + separate ``/v1/rl/health`` + ``/v1/rl/ready`` + ``/v1/rl/weight_version`` + endpoints with one RL-scoped readiness call. + """ + body = body or {} + try: + engine_alive = True + try: + if hasattr(self.engine_client, "check_health"): + await self.engine_client.check_health() + else: + # Same fallback as liveness_probe: a no-op collective_rpc + # round-trip is the liveness signal when check_health is + # absent; otherwise older engines would always look alive. + await self.engine_client.collective_rpc( + "get_weight_version", kwargs={} + ) + except Exception as health_err: + engine_alive = False + logger.warning(f"[RL] get_state: engine_alive=false ({health_err})") + return { + "status": "ok", + "engine_alive": engine_alive, + "pause_state": "paused" if getattr(self, "_paused", False) else "running", + "applied_weight_version": getattr(self, "_weight_version", "initial"), + "loras": [ + {"name": name, "id": info.id, "path": info.path} + for name, info in getattr(self, "loaded_loras", {}).items() + ], + } + except EngineDeadError as e: + self._shutdown_on_engine_dead(e) + except Exception as e: + logger.error(f"[RL] get_state failed: {e}") + return {"status": "error", "message": str(e)} + + async def flush_cache(self, body: dict) -> dict: + """Invalidate prefix/KV cache. Called after weight updates.""" + body = body or {} + try: + await self.engine_client.reset_prefix_cache() + logger.debug("[RL] Prefix cache flushed") + return {"status": "ok", "message": "Cache flushed"} + except EngineDeadError as e: + self._shutdown_on_engine_dead(e) + except Exception as e: + logger.error(f"[RL] Failed to flush cache: {e}") + return {"status": "error", "message": str(e)} + + async def update_weights_from_path(self, body: dict) -> dict: + """Load weights from a filesystem path (safetensors/torch checkpoint). + + Expects body: {"path": "/path/to/weights", "version": "step_N"} + The caller is responsible for pausing/resuming around this call. + """ + body = body or {} + path = body.get("path") + version = body.get("version", "unknown") + if not path: + return {"status": "error", "message": "Missing 'path' in body"} + try: + # Use vLLM's built-in reload_weights via collective RPC. + # This calls Worker.reload_weights() -> GPUModelRunner.reload_weights() + # which handles loading safetensors from a directory using vLLM's + # model loader with proper layerwise reload. 
+ await self.engine_client.collective_rpc( + "reload_weights", + kwargs={"weights_path": path}, + ) + self._weight_version = version + logger.info(f"[RL] Weights loaded from {path} (version={version})") + return { + "status": "ok", + "message": f"Weights loaded from {path}", + "version": version, + } + except EngineDeadError as e: + self._shutdown_on_engine_dead(e) + except Exception as e: + logger.error(f"[RL] Failed to load weights from {path}: {e}") + return {"status": "error", "message": str(e)} + + async def get_weight_version(self, body: dict) -> dict: + """Return the current weight version tag.""" + return {"version": getattr(self, "_weight_version", "initial")} + + async def load_lora_adapter(self, body: dict) -> dict: + """Load (or hot-swap) a LoRA adapter from a filesystem path. + + Expects body: {"lora_name": str, "lora_path": "/path/to/adapter_dir"} + + The adapter directory must contain ``adapter_model.safetensors`` and + ``adapter_config.json`` -- the standard PEFT output layout that Prime-RL + writes each training step. + + Unlike :meth:`load_lora` (which downloads from a URI via ``LoRAManager`` + streaming a gRPC response), this method is the RL admin equivalent used + for training-loop weight updates: + + * Reads the adapter directly from the given filesystem path (no URI / + no network fetch, no LoRAManager needed). + * Hot-swaps if ``lora_name`` is already loaded (remove old id then + re-add) so every training step replaces the same logical adapter. + * Resets the prefix cache after a hot-swap so stale KV entries keyed + to the previous adapter weights do not poison subsequent rollouts. + * Publishes a ModelDeploymentCard the first time a new ``lora_name`` is + loaded. Prime-RL switches its request ``model`` field to the LoRA + name after load (``scheduler.py``: ``self.model_name = self.lora_name``) + so the frontend needs an MDC entry to route ``r16-a32`` → this worker. + On subsequent hot-swaps the MDC is already published and we skip + re-registration. + """ + body = body or {} + lora_name = body.get("lora_name") + lora_path = body.get("lora_path") + if not lora_name: + return {"status": "error", "message": "Missing 'lora_name' in body"} + if not lora_path: + return {"status": "error", "message": "Missing 'lora_path' in body"} + try: + lock = self._get_lora_lock(lora_name) + async with lock: + lora_id = lora_name_to_id(lora_name) + is_hot_swap = lora_name in self.loaded_loras + + # Hot-swap: vLLM's add_lora is a no-op when the lora_int_id is + # already registered, so we must remove the previous adapter + # first. remove_lora is best-effort on a fresh add. + if is_hot_swap: + old_id = self.loaded_loras[lora_name].id + try: + await self.engine_client.remove_lora(old_id) + # Invalidate the cache entry immediately after remove succeeds. + # If add_lora below fails, this prevents a stale entry pointing + # at an adapter the engine no longer holds from poisoning future + # rollouts with wrong importance ratios. + self.loaded_loras.pop(lora_name, None) + except Exception as e: + # remove_lora failure during hot-swap is non-recoverable + # for this request: add_lora below would no-op against + # the still-registered ID. Surface as error so the + # caller doesn't think the swap succeeded. 
+ logger.error( + f"[RL] remove_lora({lora_name}, id={old_id}) failed during hot-swap: {e}" + ) + return { + "status": "error", + "message": ( + f"Failed to remove existing LoRA '{lora_name}' " + f"before hot-swap: {e}" + ), + "lora_name": lora_name, + } + + await self.engine_client.add_lora( + LoRARequest( + lora_name=lora_name, + lora_int_id=lora_id, + lora_path=lora_path, + ) + ) + self.loaded_loras[lora_name] = LoRAInfo(id=lora_id, path=lora_path) + + # Invalidate KV cache on hot-swap so stale prefix entries keyed + # to the previous LoRA weights can't contaminate new rollouts. + if is_hot_swap: + try: + await self.engine_client.reset_prefix_cache() + except Exception as e: + # A failed cache reset means subsequent requests sharing + # a prefix with an old rollout can reuse KV state + # computed under the previous adapter — silent logprobs + # mismatch. Surface as an error so the caller doesn't + # treat the swap as safe to serve. + logger.error( + f"[RL] reset_prefix_cache after LoRA swap failed: {e}" + ) + return { + "status": "error", + "message": ( + f"LoRA '{lora_name}' was loaded but prefix cache " + f"reset failed; worker is not safe to serve until " + f"the next successful swap." + ), + "lora_name": lora_name, + "lora_id": lora_id, + } + + # Publish an MDC for the LoRA on first load so Dynamo's frontend + # can route requests with model=<lora_name> to this worker. + # Mirror the logic in load_lora() (URI variant). Skip on hot-swap + # since the MDC was already published on the first load. + if not is_hot_swap and self.generate_endpoint is not None: + try: + runtime_config = ModelRuntimeConfig() + runtime_config.tool_call_parser = self.config.dyn_tool_call_parser + runtime_config.reasoning_parser = self.config.dyn_reasoning_parser + await register_model( + model_input=ModelInput.Tokens, + model_type=ModelType.Chat | ModelType.Completions, + endpoint=self.generate_endpoint, + model_path=self.config.model, + kv_cache_block_size=self.config.engine_args.block_size, + runtime_config=runtime_config, + user_data={"lora_adapter": True, "lora_id": lora_id}, + lora_name=lora_name, + base_model_path=self.config.model, + ) + logger.debug( + f"[RL] Published LoRA '{lora_name}' ModelDeploymentCard" + ) + except Exception as e: + # Rollback: remove the LoRA from the engine to keep state consistent. + logger.exception( + f"[RL] Failed to publish LoRA '{lora_name}' MDC: {e}; rolling back add_lora" + ) + try: + await self.engine_client.remove_lora(lora_id) + except Exception as rollback_err: + # The adapter is now leaked in the engine: it is registered but + # unreachable via loaded_loras (we pop it below). Log at ERROR + # so this doesn't go unnoticed in production.
+ logger.error( + f"[RL] Rollback remove_lora({lora_name}, id={lora_id}) failed " + f"— adapter is leaked in the engine: {rollback_err}" + ) + self.loaded_loras.pop(lora_name, None) + return { + "status": "error", + "message": f"Failed to register LoRA '{lora_name}' in discovery registry: {e}", + "lora_name": lora_name, + } + + logger.info( + f"[RL] LoRA adapter {'hot-swapped' if is_hot_swap else 'loaded'}: " + f"name={lora_name} id={lora_id} path={lora_path}" + ) + return { + "status": "ok", + "message": f"LoRA adapter '{lora_name}' loaded from {lora_path}", + "lora_name": lora_name, + "lora_id": lora_id, + "hot_swap": is_hot_swap, + } + except EngineDeadError as e: + self._shutdown_on_engine_dead(e) + except Exception as e: + logger.exception( + f"[RL] Failed to load LoRA adapter '{lora_name}' from {lora_path}: {e}" + ) + return {"status": "error", "message": str(e)} + + async def unload_lora_adapter(self, body: dict) -> dict: + """Unload a LoRA adapter previously loaded via :meth:`load_lora_adapter`. + + Expects body: {"lora_name": str} + + Idempotent: unloading an already-absent LoRA returns status=ok so + callers can safely retry without special-casing the not-found path. + """ + body = body or {} + lora_name = body.get("lora_name") + if not lora_name: + return {"status": "error", "message": "Missing 'lora_name' in body"} + try: + lock = self._get_lora_lock(lora_name) + async with lock: + lora = self.loaded_loras.get(lora_name) + if lora is None: + return { + "status": "ok", + "message": f"LoRA adapter '{lora_name}' not loaded (no-op)", + "lora_name": lora_name, + } + lora_id = lora.id + await self.engine_client.remove_lora(lora_id) + del self.loaded_loras[lora_name] + + # Unregister the MDC published on load so the frontend stops + # routing `model=<lora_name>` requests to this worker. If this + # fails the engine no longer has the adapter but the frontend + # still routes to us — `_resolve_lora_request` then falls back + # to the base model, silently changing semantics. Surface as + # an error so the caller can retry / drain explicitly. + if self.generate_endpoint is not None: + try: + await unregister_model( + endpoint=self.generate_endpoint, + lora_name=lora_name, + ) + except Exception as e: + logger.error( + f"[RL] Failed to unregister LoRA '{lora_name}' MDC after engine removal: {e}" + ) + return { + "status": "error", + "message": ( + f"LoRA '{lora_name}' removed from engine but " + f"discovery unregister failed; frontend may " + f"still route to this worker until retried: {e}" + ), + "lora_name": lora_name, + "lora_id": lora_id, + } + + logger.debug( + f"[RL] LoRA adapter unloaded: name={lora_name} id={lora_id}" + ) + return { + "status": "ok", + "message": f"LoRA adapter '{lora_name}' unloaded", + "lora_name": lora_name, + "lora_id": lora_id, + } + except EngineDeadError as e: + self._shutdown_on_engine_dead(e) + except Exception as e: + logger.exception( + f"[RL] Failed to unload LoRA adapter '{lora_name}': {e}" + ) + return {"status": "error", "message": str(e)} + + # ── WeightTransferConfig API (Phase 1+4) ─────────────────────────── + # + # New unified surface paired with the Rust frontend's + # ``/v1/rl/init_transport`` and the discriminated ``/v1/rl/update_weights`` + # body. Backwards-compatible: legacy ``update_weights_from_path`` / + # ``load_lora_adapter`` / ``unload_lora_adapter`` engine routes stay live + # for callers that haven't migrated yet.
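+ # + # Illustrative request bodies (shapes match UpdateWeightsRequest.from_dict + # and weight_transport_init below; the transport_id, paths, version tag, + # and adapter name are made-up examples, not defaults): + # + # init_transport: {"transport_id": "fs0", "backend": "filesystem", "filesystem": {}} + # update_weights (base): {"version": "step_42", "target": {"kind": "base"}, "transport": {"backend": "filesystem", "filesystem": {"path": "/ckpts/step_42"}}} + # update_weights (LoRA swap): {"version": "step_42", "target": {"kind": "lora", "name": "r16-a32", "op": "swap"}, "transport": {"backend": "filesystem", "filesystem": {"path": "/ckpts/lora/step_42"}}}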
+ + def _ensure_weight_transports(self): + """Lazy-init transport registry + vLLM engine adapter.""" + if getattr(self, "_weight_transports", None) is not None: + return + from .weight_transports import VllmEngineAdapter + + adapter = VllmEngineAdapter(self.engine_client) + adapter.bind_lora_helpers( + loader=self._lora_load_via_admin, + unloader=self._lora_unload_via_admin, + ) + self._weight_engine_adapter = adapter + self._weight_transports: dict = {} + + async def _lora_load_via_admin(self, *, name: str, path: str) -> dict: + """Re-use the existing :meth:`load_lora_adapter` path so MDC publish, + hot-swap detection, and prefix-cache reset all stay consistent.""" + return await self.load_lora_adapter( + {"lora_name": name, "lora_path": path} + ) + + async def _lora_unload_via_admin(self, *, name: str) -> dict: + return await self.unload_lora_adapter({"lora_name": name}) + + async def weight_transport_init(self, body: dict) -> dict: + """Idempotent transport setup. Backs ``POST /v1/rl/init_transport``. + + Body: + - transport_id: str (caller-chosen) + - backend: "filesystem" | "nccl" + - <backend>: {…} (backend-specific block) + """ + body = body or {} + backend = body.get("backend") + if backend not in ("filesystem", "nccl"): + return { + "status": "error", + "message": ( + f"Unsupported backend '{backend}'. In scope this iteration: " + "filesystem, nccl. Future (deferred): nixl, model_express, ipc." + ), + } + transport_id = body.get("transport_id", backend) + cfg = dict(body.get(backend) or {}) + cfg.setdefault("transport_id", transport_id) + + try: + self._ensure_weight_transports() + from .weight_transports import build_transport, InitCtx + + existing = self._weight_transports.get(transport_id) + if existing is not None and existing.backend_id == backend: + logger.info( + f"[RL] init_transport: '{transport_id}' already configured " + f"(backend={backend}); idempotent re-init" + ) + # Re-run init for idempotency (e.g. NCCL group bootstrap). + ctx = InitCtx(rank=0, world_size=1, served_model_name="") + result = await existing.init(ctx, cfg) + return result.to_dict() + + transport = build_transport(backend, self._weight_engine_adapter, cfg) + ctx = InitCtx(rank=0, world_size=1, served_model_name="") + result = await transport.init(ctx, cfg) + self._weight_transports[transport_id] = transport + logger.info( + f"[RL] init_transport: backend={backend} transport_id={transport_id} " + f"ready={result.ready}" + ) + return result.to_dict() + except EngineDeadError as e: + self._shutdown_on_engine_dead(e) + except Exception as e: + logger.exception(f"[RL] init_transport failed: {e}") + return {"status": "error", "message": str(e)} + + async def weight_transport_update(self, body: dict) -> dict: + """Backs the new-shape ``POST /v1/rl/update_weights``. + + Body: + - version: str + - target: {"kind": "base"} | {"kind": "lora", "name": str, "op": …} + - transport: {"backend": "filesystem"|"nccl", <backend>: {…}} + """ + try: + from .weight_transports import UpdateWeightsRequest, build_transport, InitCtx + + req = UpdateWeightsRequest.from_dict(body or {}) + self._ensure_weight_transports() + backend = (req.transport or {}).get("backend") + + # LoRA unload may omit the transport block — synthesize filesystem. + if ( + req.target.kind == "lora" + and req.target.op == "unload" + and backend is None + ): + backend = "filesystem" + req.transport = {"backend": "filesystem", "filesystem": {}} + + if backend not in ("filesystem", "nccl"): + return { + "status": "error", + "message": ( + f"Unsupported backend '{backend}'.
In scope this iteration: " + "filesystem, nccl." + ), + } + + # Resolve transport instance: prefer one bound by init_transport; + # for filesystem we lazily build per-call (no setup needed). + transport_id = (req.transport.get(backend) or {}).get( + "transport_id", backend + ) + transport = self._weight_transports.get(transport_id) + if transport is None: + if backend == "filesystem": + transport = build_transport( + backend, self._weight_engine_adapter, {"transport_id": transport_id} + ) + await transport.init( + InitCtx(rank=0, world_size=1, served_model_name=""), {} + ) + self._weight_transports[transport_id] = transport + else: + return { + "status": "error", + "message": ( + f"Transport '{transport_id}' (backend={backend}) is not " + f"initialized. Call POST /v1/rl/init_transport first." + ), + } + elif transport.backend_id != backend: + return { + "status": "error", + "message": ( + f"Transport '{transport_id}' is bound to backend " + f"'{transport.backend_id}', not '{backend}'." + ), + } + + result = await transport.update_weights(req) + self._weight_version = req.version + payload = result.to_dict() + payload.setdefault("version", req.version) + payload.setdefault("backend", backend) + payload.setdefault("transport_id", transport_id) + return payload + except EngineDeadError as e: + self._shutdown_on_engine_dead(e) + except (ValueError, FileNotFoundError, NotImplementedError) as e: + return {"status": "error", "message": str(e)} + except Exception as e: + logger.exception(f"[RL] weight_transport_update failed: {e}") + return {"status": "error", "message": str(e)} + + async def describe_rl(self, body: dict | None = None) -> dict: + """Return lightweight RL worker metadata for SDK topology probes.""" + mode = getattr(self.config, "disaggregation_mode", None) + if hasattr(mode, "value"): + mode = mode.value + return { + "status": "ok", + "namespace": getattr(self.config, "namespace", None), + "component": getattr(self.config, "component", None), + "endpoint": "rl", + "worker_role": mode, + "details": { + "model": getattr(self.config, "model", None), + "served_model_name": ( + getattr(self.config, "served_model_name", None) + or getattr(self.config, "model", None) + ), + "weight_version": getattr(self, "_weight_version", "initial"), + "lora_count": len(self.loaded_loras), + }, + } + + # ── PR B: unified `rl` request-plane endpoint ───────────────────── + # + # Worker registers ``dyn://<namespace>.<component>.rl`` and serves this + # dispatcher. The frontend (dynamo-rl crate) discovers live `rl` + # instances via the standard discovery plane and dispatches via strict + # request-plane direct calls over NATS / shared TCP — no system-port HTTP + # fan-out, no static `DYN_RL_WORKER_SYSTEM_URLS` list. + # + # Wire shape: ``{"op": str, "body": dict}`` where `op` is one of + # ``describe | pause | resume | init_transport | update_weights``. The dispatcher + # routes to the existing per-op handlers and yields a single response + # dict (matching the serve_endpoint async-generator contract used by + # ``generate``, ``load_lora``, etc.). + # + # Legacy ``register_engine_route`` HTTP-on-system-port routes stay + # live during PR B / PR C overlap so unmigrated callers don't break. + async def rl_dispatch(self, request=None): + """Single-endpoint RL admin dispatcher (PR B). + + Async generator yielding exactly one response dict per call.
+ """ + if request is None: + yield {"status": "error", "message": "rl_dispatch: request required"} + return + op = request.get("op") + body = request.get("body") or {} + if not isinstance(op, str) or not op: + yield { + "status": "error", + "message": "rl_dispatch: missing 'op' (str)", + } + return + try: + if op == "describe": + yield await self.describe_rl(body) + elif op == "pause": + yield await self.pause_generation(body) + elif op == "resume": + yield await self.resume_generation(body) + elif op == "init_transport": + yield await self.weight_transport_init(body) + elif op == "update_weights": + yield await self.weight_transport_update(body) + else: + yield { + "status": "error", + "message": ( + f"rl_dispatch: unknown op {op!r}; expected one of " + "describe|pause|resume|init_transport|update_weights" + ), + } + except Exception as e: + logger.exception(f"[RL] rl_dispatch op={op!r} failed: {e}") + yield {"status": "error", "op": op, "message": str(e)} + @abstractmethod def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]: raise NotImplementedError @@ -2012,10 +2693,7 @@ async def generate_tokens( num_output_tokens_so_far[output_idx] = next_total_toks except EngineDeadError as e: - logger.error(f"vLLM EngineDeadError: {e}") - logger.warning("Initiating Dynamo Runtime shutdown.") - self.runtime.shutdown() - os._exit(1) + self._shutdown_on_engine_dead(e) class DecodeWorkerHandler(BaseWorkerHandler): @@ -2257,10 +2935,7 @@ async def _generate_token_mode(self, request, context, request_id): ] = prefill_prompt_tokens_details yield tok except EngineDeadError as e: - logger.error(f"vLLM EngineDeadError: {e}") - logger.warning("Initiating Dynamo Runtime shutdown.") - self.runtime.shutdown() - os._exit(1) + self._shutdown_on_engine_dead(e) async def _generate_text_mode(self, request, context, request_id): """Generate text using OpenAI-compatible format (text-in-text-out).""" diff --git a/components/src/dynamo/vllm/publisher.py b/components/src/dynamo/vllm/publisher.py index d3a8619ad9d9..d233c55f636d 100644 --- a/components/src/dynamo/vllm/publisher.py +++ b/components/src/dynamo/vllm/publisher.py @@ -58,6 +58,7 @@ def record( *args: object, **kwargs: object, ) -> None: + # scheduler_stats can be None right after a weight reload / cache reset. if scheduler_stats is None: return diff --git a/components/src/dynamo/vllm/weight_transports/__init__.py b/components/src/dynamo/vllm/weight_transports/__init__.py new file mode 100644 index 000000000000..391a60562cb9 --- /dev/null +++ b/components/src/dynamo/vllm/weight_transports/__init__.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Weight-transport plug-ins for ``dynamo.vllm`` (Phase 1+4 of the +WeightTransferConfig design). + +In scope this iteration: + +* :class:`FilesystemTransport` — current default, safetensors via shared FS. +* :class:`NcclTransport` — collective broadcast on a pre-formed group + (vLLM ``collective_rpc("update_weights_from_distributed", …)``). + +Future (deferred): ``NixlTransport``, ``ModelExpressTransport``, +``IpcTransport``, plus an ``SglangEngineAdapter`` for the second engine +flavor. 
+""" + +from .base import ( + EngineAdapter, + InitCtx, + InitResult, + TransportState, + UpdateResult, + UpdateWeightsRequest, + WeightTarget, + WeightTransport, +) +from .engine_adapter import VllmEngineAdapter +from .filesystem import FilesystemTransport +from .nccl import NcclTransport + +__all__ = [ + "EngineAdapter", + "FilesystemTransport", + "InitCtx", + "InitResult", + "NcclTransport", + "TransportState", + "UpdateResult", + "UpdateWeightsRequest", + "VllmEngineAdapter", + "WeightTarget", + "WeightTransport", + "build_transport", +] + + +def build_transport(backend: str, engine_adapter, cfg: dict): + """Factory: instantiate the right transport for the given backend id.""" + if backend == "filesystem": + return FilesystemTransport(engine_adapter, cfg) + if backend == "nccl": + return NcclTransport(engine_adapter, cfg) + raise ValueError( + f"Unsupported weight-transport backend '{backend}'. " + "In-scope this iteration: filesystem, nccl. " + "Future (deferred): nixl, model_express, ipc." + ) diff --git a/components/src/dynamo/vllm/weight_transports/base.py b/components/src/dynamo/vllm/weight_transports/base.py new file mode 100644 index 000000000000..f8c6d9e8f791 --- /dev/null +++ b/components/src/dynamo/vllm/weight_transports/base.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Trait + types for the WeightTransferConfig API (vLLM-scoped, Phase 1).""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Literal, Optional, Protocol + + +PauseMode = Literal["keep", "wait", "abort"] +TransportState = Literal["configured", "ready", "receiving", "failed"] +TargetKind = Literal["base", "lora"] +LoraOp = Literal["load", "swap", "unload"] + + +@dataclass(frozen=True) +class WeightTarget: + """What is being updated. + + * ``kind="base"``: the base model itself (full-FT reload). + * ``kind="lora"``: a LoRA adapter; ``name`` is required and ``op`` selects + between load/swap/unload. 
+ """ + + kind: TargetKind + name: Optional[str] = None + op: Optional[LoraOp] = None + + @classmethod + def from_dict(cls, body: dict) -> "WeightTarget": + kind = body.get("kind") + if kind not in ("base", "lora"): + raise ValueError( + f"WeightTarget.kind must be 'base' or 'lora', got {kind!r}" + ) + if kind == "lora": + name = body.get("name") + if not isinstance(name, str) or not name: + raise ValueError( + "WeightTarget.name is required when kind='lora'" + ) + op = body.get("op") + if op not in ("load", "swap", "unload"): + raise ValueError( + f"WeightTarget.op must be 'load'|'swap'|'unload' when " + f"kind='lora', got {op!r}" + ) + return cls(kind="lora", name=name, op=op) + return cls(kind="base") + + +@dataclass +class UpdateWeightsRequest: + """Single discriminated body for ``POST /v1/rl/update_weights``.""" + + version: str + target: WeightTarget + transport: dict # backend-specific block, validated by the transport impl + pause_mode: PauseMode = "keep" + clear_cache: bool = True + + @classmethod + def from_dict(cls, body: dict) -> "UpdateWeightsRequest": + version = body.get("version") + if not isinstance(version, str) or not version: + raise ValueError("update_weights: 'version' is required") + target = WeightTarget.from_dict(body.get("target", {}) or {}) + transport = body.get("transport") or {} + if target.kind == "base" or target.op != "unload": + if not isinstance(transport, dict) or "backend" not in transport: + raise ValueError( + "update_weights: 'transport.backend' is required " + "(except for lora unload)" + ) + pause_mode = body.get("pause_mode", "keep") + if pause_mode not in ("keep", "wait", "abort"): + raise ValueError( + f"update_weights: pause_mode must be 'keep'|'wait'|'abort', " + f"got {pause_mode!r}" + ) + clear_cache = bool(body.get("clear_cache", True)) + return cls( + version=version, + target=target, + transport=transport, + pause_mode=pause_mode, + clear_cache=clear_cache, + ) + + +@dataclass +class InitCtx: + """Constant context passed to every transport ``init`` call.""" + + rank: int + world_size: int + served_model_name: str + + +@dataclass +class InitResult: + status: str + transport_id: str + ready: bool + message: Optional[str] = None + extra: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + out = { + "status": self.status, + "transport_id": self.transport_id, + "ready": self.ready, + } + if self.message: + out["message"] = self.message + if self.extra: + out.update(self.extra) + return out + + +@dataclass +class UpdateResult: + status: str + message: str = "" + version: Optional[str] = None + extra: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + out = {"status": self.status, "message": self.message} + if self.version is not None: + out["version"] = self.version + if self.extra: + out.update(self.extra) + return out + + +class WeightTransport(Protocol): + """One implementation per backend. + + Phase 1: ``FilesystemTransport``. + Phase 4: ``NcclTransport``. + """ + + backend_id: str + + async def init(self, ctx: InitCtx, cfg: dict) -> InitResult: ... + + async def update_weights( + self, req: UpdateWeightsRequest + ) -> UpdateResult: ... + + async def teardown(self) -> None: ... + + @property + def state(self) -> TransportState: ... + + +class EngineAdapter(Protocol): + """Engine-flavor shim. One implementation per engine. + + Phase 1+4 ships :class:`VllmEngineAdapter` only. Future: + ``SglangEngineAdapter`` drops in as one extra subclass without touching + any :class:`WeightTransport` impl. 
+ """ + + async def update_weights_from_disk( + self, *, path: str, version: str, target: WeightTarget + ) -> UpdateResult: ... + + async def update_weights_from_distributed( + self, + *, + group: str, + dtype: str, + version: str, + target: WeightTarget, + weight_names: Optional[list[str]] = None, + ) -> UpdateResult: ... + + async def update_weights_from_tensor( + self, *, tensors: Any, version: str, target: WeightTarget + ) -> UpdateResult: ... + + async def update_weights_from_ipc( + self, *, handle: Any, version: str, target: WeightTarget + ) -> UpdateResult: ... + + async def add_lora(self, *, name: str, source: str) -> UpdateResult: ... + + async def remove_lora(self, *, name: str) -> UpdateResult: ... diff --git a/components/src/dynamo/vllm/weight_transports/engine_adapter.py b/components/src/dynamo/vllm/weight_transports/engine_adapter.py new file mode 100644 index 000000000000..2c3f9e914a34 --- /dev/null +++ b/components/src/dynamo/vllm/weight_transports/engine_adapter.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""vLLM-flavor engine adapter. + +Wraps ``engine_client.collective_rpc(...)`` so each :class:`WeightTransport` +implementation can call a stable, engine-agnostic API. Future: +``SglangEngineAdapter`` will wrap ``tokenizer_manager.update_weights_from_*`` +following the same Protocol. +""" + +from __future__ import annotations + +import logging +from typing import Any, Optional + +from .base import EngineAdapter, UpdateResult, WeightTarget + +logger = logging.getLogger(__name__) + + +class VllmEngineAdapter(EngineAdapter): + """vLLM-flavor :class:`EngineAdapter` backed by an ``engine_client``. + + All four ``update_weights_from_*`` paths route through ``collective_rpc`` + against the in-process worker(s); LoRA ops route through the engine's + ``add_lora`` / ``remove_lora`` (or equivalent collective) calls. + """ + + backend_id = "vllm" + + def __init__(self, engine_client, *, lora_loader=None): + self.engine_client = engine_client + self._lora_loader = lora_loader # optional callable for LoRA add path + + # ---- four canonical update paths --------------------------------------- + + async def update_weights_from_disk( + self, *, path: str, version: str, target: WeightTarget + ) -> UpdateResult: + await self.engine_client.collective_rpc( + "reload_weights", + kwargs={"weights_path": path}, + ) + return UpdateResult( + status="ok", + message=f"Weights loaded from {path}", + version=version, + ) + + async def update_weights_from_distributed( + self, + *, + group: str, + dtype: str, + version: str, + target: WeightTarget, + weight_names: Optional[list[str]] = None, + ) -> UpdateResult: + # vLLM exposes per-name distributed update via the worker's + # `update_weight_from_tensor` / `update_weight` collective. We loop + # over weight_names so the trainer can drive the broadcast iteration. + if not weight_names: + raise ValueError( + "update_weights_from_distributed: weight_names is required so " + "the worker knows which named parameters to receive on the " + "NCCL group." 
) + for name in weight_names: + await self.engine_client.collective_rpc( + "update_weight", + kwargs={"name": name, "dtype": dtype, "shape": None}, + ) + return UpdateResult( + status="ok", + message=f"Updated {len(weight_names)} weights via group '{group}'", + version=version, + extra={"weights_received": len(weight_names)}, + ) + + async def update_weights_from_tensor( + self, *, tensors: Any, version: str, target: WeightTarget + ) -> UpdateResult: + # Future hook for NIXL/MX paths (deferred). + raise NotImplementedError( + "update_weights_from_tensor is reserved for NIXL/ModelExpress " + "transports; not implemented in Phase 1+4." + ) + + async def update_weights_from_ipc( + self, *, handle: Any, version: str, target: WeightTarget + ) -> UpdateResult: + raise NotImplementedError( + "update_weights_from_ipc is reserved for the colocated-trainer " + "path; not implemented in Phase 1+4." + ) + + # ---- LoRA ops ---------------------------------------------------------- + + async def add_lora(self, *, name: str, source: str) -> UpdateResult: + if self._lora_loader is None: + raise RuntimeError( + "VllmEngineAdapter.add_lora called but no lora_loader was " + "supplied at construction. Wire it from the handler." + ) + result = await self._lora_loader(name=name, path=source) + return UpdateResult( + status=result.get("status", "ok"), + message=result.get("message", ""), + extra={k: v for k, v in result.items() if k not in ("status", "message")}, + ) + + async def remove_lora(self, *, name: str) -> UpdateResult: + # The handler wires separate load and unload helpers via + # ``bind_lora_helpers``; removal only needs the unloader, so we + # deliberately do not require ``_lora_loader`` here. + unloader = getattr(self, "_lora_unloader", None) + if unloader is None: + raise RuntimeError( + "VllmEngineAdapter.remove_lora called but no lora_unloader " + "was supplied. Wire it from the handler." + ) + result = await unloader(name=name) + return UpdateResult( + status=result.get("status", "ok"), + message=result.get("message", ""), + extra={k: v for k, v in result.items() if k not in ("status", "message")}, + ) + + # Convenience: handler wires both helpers in one shot. + def bind_lora_helpers(self, *, loader, unloader): + self._lora_loader = loader + self._lora_unloader = unloader diff --git a/components/src/dynamo/vllm/weight_transports/filesystem.py b/components/src/dynamo/vllm/weight_transports/filesystem.py new file mode 100644 index 000000000000..8ec0d7626acc --- /dev/null +++ b/components/src/dynamo/vllm/weight_transports/filesystem.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Filesystem weight transport (Phase 1). + +Equivalent of the existing ``update_weights_from_path`` route, but reachable +through the unified :class:`WeightTransport` Protocol so the same wire shape +covers full-FT and LoRA, and so future backends slot in alongside.
+""" + +from __future__ import annotations + +import logging +import os +from typing import Optional + +from .base import ( + EngineAdapter, + InitCtx, + InitResult, + TransportState, + UpdateResult, + UpdateWeightsRequest, + WeightTransport, +) + +logger = logging.getLogger(__name__) + + +class FilesystemTransport(WeightTransport): + """Filesystem path → engine reload. + + Config (the ``"filesystem"`` block of an ``init_transport`` body or a + ``transport.filesystem`` block of an ``update_weights`` body): + + path: str (required for base / lora-load / lora-swap) + require_marker: str (optional, default 'STABLE') + """ + + backend_id = "filesystem" + + def __init__(self, engine_adapter: EngineAdapter, cfg: dict): + self._engine = engine_adapter + self._cfg = cfg or {} + self._state: TransportState = "configured" + self._transport_id: str = self._cfg.get("transport_id", "filesystem") + + @property + def state(self) -> TransportState: + return self._state + + async def init(self, ctx: InitCtx, cfg: dict) -> InitResult: + # No setup needed for filesystem — degenerate one-shot. + self._cfg = {**self._cfg, **(cfg or {})} + self._transport_id = self._cfg.get("transport_id", self._transport_id) + self._state = "ready" + return InitResult( + status="ok", + transport_id=self._transport_id, + ready=True, + message="filesystem transport ready (no setup required)", + ) + + async def teardown(self) -> None: + self._state = "configured" + + async def update_weights( + self, req: UpdateWeightsRequest + ) -> UpdateResult: + fs = req.transport.get("filesystem") or {} + path: Optional[str] = fs.get("path") + require_marker: Optional[str] = fs.get( + "require_marker", self._cfg.get("require_marker", "STABLE") + ) + + # ---- LoRA unload: no transport, no path ---------------------------- + if req.target.kind == "lora" and req.target.op == "unload": + return await self._engine.remove_lora(name=req.target.name) + + if not path: + raise ValueError( + "filesystem.update_weights: 'transport.filesystem.path' is " + "required (except for lora unload)" + ) + + if require_marker: + marker = os.path.join(path, require_marker) + if not os.path.exists(marker): + raise FileNotFoundError( + f"filesystem transport: require_marker '{require_marker}' " + f"not found under {path!r}" + ) + + if req.target.kind == "base": + self._state = "receiving" + try: + result = await self._engine.update_weights_from_disk( + path=path, version=req.version, target=req.target + ) + finally: + self._state = "ready" + logger.info( + f"[RL] filesystem.update_weights: base reload from {path} " + f"(version={req.version})" + ) + return result + + # target.kind == "lora", op in {load, swap} + result = await self._engine.add_lora(name=req.target.name, source=path) + logger.info( + f"[RL] filesystem.update_weights: lora {req.target.op} " + f"name={req.target.name} from {path}" + ) + return result diff --git a/components/src/dynamo/vllm/weight_transports/nccl.py b/components/src/dynamo/vllm/weight_transports/nccl.py new file mode 100644 index 000000000000..0a3636f5bf61 --- /dev/null +++ b/components/src/dynamo/vllm/weight_transports/nccl.py @@ -0,0 +1,167 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""NCCL weight transport (Phase 4). + +Trainer + dynamo.vllm worker(s) form a NCCL process group at +``init_transport`` time; per-step ``update_weights`` triggers receive via +``collective_rpc("update_weight", ...)`` for each named parameter. 
+ +Phase 4 scope: vLLM only. The trainer side is responsible for driving the +broadcast itself; dynamo just exposes the receiver hook. +""" + +from __future__ import annotations + +import logging +from typing import Optional + +from .base import ( + EngineAdapter, + InitCtx, + InitResult, + TransportState, + UpdateResult, + UpdateWeightsRequest, + WeightTransport, +) + +logger = logging.getLogger(__name__) + + +class NcclTransport(WeightTransport): + """NCCL collective broadcast → engine receive. + + Config (the ``"nccl"`` block of an ``init_transport`` body or a + ``transport.nccl`` block of an ``update_weights`` body): + + master_address: str (required at init; trainer-side rendezvous host) + master_port: int (required at init) + world_size: int (required at init; trainer rank 0 + inference ranks) + rank_offset: int (optional at init, default 1) + + For ``update_weights``: + + names: list[str] (alias: weight_names; iteration order of the + named params the trainer is broadcasting; required) + dtype_names: list[str] (per-name dtype strings; required) + shapes: list[list[int]] (per-name shapes; required) + """ + + backend_id = "nccl" + + def __init__(self, engine_adapter: EngineAdapter, cfg: dict): + self._engine = engine_adapter + self._cfg = cfg or {} + self._state: TransportState = "configured" + self._transport_id: str = self._cfg.get("transport_id", "nccl") + + @property + def state(self) -> TransportState: + return self._state + + async def init(self, ctx: InitCtx, cfg: dict) -> InitResult: + cfg = cfg or {} + merged = {**self._cfg, **cfg} + # vLLM's init_weight_transfer_engine takes: + # master_address, master_port, rank_offset, world_size + # The trainer is rank 0; inference workers are rank_offset..world_size-1. + for required in ("master_address", "master_port", "world_size"): + if required not in merged: + raise ValueError( + f"nccl transport: '{required}' is required in init_transport" + ) + + self._cfg = merged + self._transport_id = merged.get("transport_id", self._transport_id) + + # Drive the worker-side bootstrap via vLLM's + # `init_weight_transfer_engine` collective. + try: + init_info = { + "master_address": str(merged["master_address"]), + "master_port": int(merged["master_port"]), + "rank_offset": int(merged.get("rank_offset", 1)), + "world_size": int(merged["world_size"]), + } + await self._engine.engine_client.collective_rpc( + "init_weight_transfer_engine", + kwargs={"init_info": init_info}, + ) + self._state = "ready" + return InitResult( + status="ok", + transport_id=self._transport_id, + ready=True, + message=( + f"nccl init_weight_transfer_engine ok " + f"(master={init_info['master_address']}:{init_info['master_port']}, " + f"world_size={init_info['world_size']})" + ), + extra={"init_info": init_info}, + ) + except Exception as exc: + self._state = "failed" + logger.error(f"[RL] nccl.init failed: {exc}") + raise + + async def teardown(self) -> None: + # vLLM doesn't expose an explicit destroy hook; engine teardown handles it. + self._state = "configured" + + async def update_weights( + self, req: UpdateWeightsRequest + ) -> UpdateResult: + # NCCL transport does not own LoRA hot-swap in this iteration; LoRA + # adapters are tiny enough that filesystem stays the better path. + if req.target.kind == "lora": + raise NotImplementedError( + "nccl transport: LoRA adapter transfer is deferred. Use " + "transport.backend='filesystem' for LoRA in this iteration." + ) + + nccl = req.transport.get("nccl") or {} + # The trainer must supply (names, dtype_names, shapes) so the worker + # knows how big each `torch.empty(...)` receive buffer should be.
+ names: Optional[list[str]] = nccl.get("names") or nccl.get("weight_names") + dtype_names: Optional[list[str]] = nccl.get("dtype_names") + shapes: Optional[list[list[int]]] = nccl.get("shapes") + if not names: + raise ValueError( + "nccl.update_weights: 'transport.nccl.names' is required" + ) + if not dtype_names or not shapes: + raise ValueError( + "nccl.update_weights: 'transport.nccl.dtype_names' and " + "'transport.nccl.shapes' are required" + ) + if len(dtype_names) != len(names) or len(shapes) != len(names): + raise ValueError( + f"nccl.update_weights: names/dtype_names/shapes length mismatch " + f"({len(names)} / {len(dtype_names)} / {len(shapes)})" + ) + + update_info = { + "names": names, + "dtype_names": dtype_names, + "shapes": shapes, + "is_checkpoint_format": bool(nccl.get("is_checkpoint_format", True)), + "packed": bool(nccl.get("packed", False)), + } + + self._state = "receiving" + try: + await self._engine.engine_client.collective_rpc( + "update_weights", + kwargs={"update_info": update_info}, + ) + finally: + self._state = "ready" + logger.info( + f"[RL] nccl.update_weights: {len(names)} weights received " + f"(version={req.version})" + ) + return UpdateResult( + status="ok", + message=f"Updated {len(names)} weights via nccl", + version=req.version, + extra={"weights_received": len(names)}, + ) diff --git a/components/src/dynamo/vllm/worker_factory.py b/components/src/dynamo/vllm/worker_factory.py index ce3d473fdb24..9018f22a9fd1 100644 --- a/components/src/dynamo/vllm/worker_factory.py +++ b/components/src/dynamo/vllm/worker_factory.py @@ -241,9 +241,19 @@ async def _create_decode_worker( f"{config.namespace}.{config.component}.clear_kv_blocks" ) + # PR B: unified RL admin endpoint on the request plane. Discoverable + # via etcd as ``<namespace>.<component>.rl``; the dynamo-rl frontend crate + # uses Discovery::list(NamespacedEndpoints) + PushRouter::direct to + # fan out admin ops here, replacing the legacy HTTP-on-system-port + # ``register_engine_route("pause_generation", …)`` etc. mechanism. + rl_endpoint = runtime.endpoint( + f"{config.namespace}.{config.component}.rl" + ) + shutdown_endpoints[:] = [ generate_endpoint, clear_endpoint, + rl_endpoint, ] lora_enabled = config.engine_args.enable_lora @@ -366,7 +376,7 @@ async def _create_decode_worker( component_name=config.component, ) - # Register engine routes + # Register engine routes (sleep/wake_up + RL weight-lifecycle + RL LoRA) self.register_engine_routes(runtime, handler) # Parse endpoint types from --endpoint-types flag @@ -442,6 +452,12 @@ async def _create_decode_worker( handler.get_perf_metrics, metrics_labels=model_metrics_labels, ), + # PR B: unified RL admin endpoint (rl_dispatch dispatches + # by op name to pause/resume/init_transport/update_weights). + rl_endpoint.serve_endpoint( + handler.rl_dispatch, + metrics_labels=model_metrics_labels, + ), ] if lora_enabled: @@ -576,7 +592,7 @@ async def _create_prefill_worker( component_name=config.component, ) - # Register engine routes + # Register engine routes (sleep/wake_up + RL weight-lifecycle + RL LoRA) self.register_engine_routes(runtime, handler) await self._maybe_wait_for_failover_lock(handler, runtime, config) @@ -676,6 +692,46 @@ def register_engine_routes( runtime.register_engine_route("wake_up", handler.wake_up) runtime.register_engine_route("scale_elastic_ep", handler.scale_elastic_ep) + # RL weight-lifecycle routes — driven by the + # /v1/rl/{pause,resume,update_weights} bracket in the Rust frontend.
+ # Names line up with the SGLang RL admin routes so a single admin + # coordinator can talk to either backend. + runtime.register_engine_route("pause_generation", handler.pause_generation) + runtime.register_engine_route("resume_generation", handler.resume_generation) + runtime.register_engine_route("flush_cache", handler.flush_cache) + runtime.register_engine_route( + "update_weights_from_path", handler.update_weights_from_path + ) + runtime.register_engine_route("get_weight_version", handler.get_weight_version) + + # RL state + liveness — drive /v1/rl/state and /v1/rl/liveness in the + # Rust frontend. /v1/rl/state aggregates these per-worker snapshots + # into the composite RlStateResponse. + runtime.register_engine_route("get_state", handler.get_state) + runtime.register_engine_route("liveness_probe", handler.liveness_probe) + + # RL LoRA adapter routes: filesystem-native hot-swap used by RL + # trainers every step to broadcast new adapter weights into the engine. + runtime.register_engine_route("load_lora_adapter", handler.load_lora_adapter) + runtime.register_engine_route( + "unload_lora_adapter", handler.unload_lora_adapter + ) + + # RL WeightTransferConfig API (Phase 1+4): unified transport surface + # for filesystem + nccl backends. Coexists with the legacy routes + # above; legacy callers continue to work unchanged. + runtime.register_engine_route( + "weight_transport_init", handler.weight_transport_init + ) + runtime.register_engine_route( + "weight_transport_update", handler.weight_transport_update + ) + logger.info( - "Registered engine routes: /engine/sleep, /engine/wake_up, /engine/scale_elastic_ep, /engine/start_profile, /engine/stop_profile" + "Registered engine routes: sleep, wake_up, scale_elastic_ep, " + "start_profile, stop_profile, pause_generation, resume_generation, " + "flush_cache, update_weights_from_path, get_weight_version, " + "get_state, liveness_probe, " + "load_lora_adapter, unload_lora_adapter, " + "weight_transport_init, weight_transport_update" ) diff --git a/container/deps/vllm/install_vllm.sh b/container/deps/vllm/install_vllm.sh index 5ab0aa0e6896..ffbceff6bb1e 100755 --- a/container/deps/vllm/install_vllm.sh +++ b/container/deps/vllm/install_vllm.sh @@ -13,7 +13,7 @@ set -euo pipefail -VLLM_VER="0.20.0" +VLLM_VER="0.19.1" VLLM_REF="v${VLLM_VER}" DEVICE="cuda" @@ -300,4 +300,39 @@ if [ "$DEVICE" = "cuda" ]; then # TODO we will be able to specify which pplx and deepep commit we want in future TORCH_CUDA_ARCH_LIST="$TORCH_CUDA_ARCH_LIST" bash install_python_libraries.sh fi + +# --------------------------------------------------------------------------- +# prime-rl inference-side vLLM plugin (pinned commit). +# +# Registers the ``vllm.general_plugins`` entry-point that applies prime-rl's +# monkey patches (LoRA adapter load, DP engine pause/resume deadlock, Qwen 3.5 +# LoRA, etc.) automatically in every vLLM worker process -- including spawned +# subprocesses. Required for prime-rl / Dynamo RL training integration. +# +# Pinned to an immutable commit SHA (not a tag) for reproducibility; tags can +# be re-pointed upstream. PRIME_RL_REF is kept for human-readable build logs. +# Override at build time: --build-arg PRIME_RL_COMMIT=<sha> +# --no-deps: prime-rl's full dep tree includes trainer + wandb; Dynamo only +# needs the inference-side plugin and worker-extension classes. +# Python version: prime-rl pins requires-python = "~=3.12.0"; Dynamo containers +# are Python 3.12, so no version override is needed.
For 3.11 local +# dev venvs use the regular pip (not uv) with --ignore-requires-python. +# --------------------------------------------------------------------------- +PRIME_RL_REF="${PRIME_RL_REF:-v0.5.1.dev101}" +PRIME_RL_COMMIT="${PRIME_RL_COMMIT:-d49f3939e7dca29bceb9ed515cc1782497b67e81}" +printf '\n=== Installing prime-rl vLLM plugin (ref=%s commit=%s) ===\n' \ + "$PRIME_RL_REF" "$PRIME_RL_COMMIT" +uv pip install --no-deps \ + "prime-rl @ git+https://github.com/PrimeIntellect-ai/prime-rl@${PRIME_RL_COMMIT}" + +# Sanity-check: confirm vllm.general_plugins entry-point is registered. +python3 - <<'PY_SANITY' +from importlib.metadata import entry_points +names = [ep.name for ep in entry_points(group="vllm.general_plugins")] +assert "prime_rl" in names, ( + f"prime-rl plugin NOT registered; vllm.general_plugins={names}" +) +print(f"✓ prime-rl plugin registered (vllm.general_plugins={names})") +PY_SANITY + echo "\n✅ All installations completed successfully!" diff --git a/docs/RL.md b/docs/RL.md new file mode 100644 index 000000000000..3956acccaafd --- /dev/null +++ b/docs/RL.md @@ -0,0 +1,667 @@ +--- +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +title: RL +--- + +# Dynamo RL + +Dynamo RL support has two separate surfaces: + +1. The inference surface on the normal OpenAI listener, usually + `:8000 /v1/chat/completions`. This carries rollout-time extensions such as + token-in/token-out, cache salt, and weight version metadata. +2. The admin surface for pause, resume, transport setup, and weight updates. + The canonical implementation is the `dynamo-rl` SDK. The optional HTTP + facade exposes the same operations at `:8002 /v1/rl/*` by default. + +The admin surface does not fan out through worker system ports. Workers are +ephemeral, so the SDK snapshots the discovery plane, finds live worker +endpoints named `rl`, and dispatches strict direct calls through Dynamo's +request plane. The request plane may be TCP, NATS, or HTTP depending on the +deployment. + +## Architecture + +```mermaid +flowchart LR + subgraph ClientSide["RL clients"] + Trainer["Trainer or orchestrator"] + Slime["Slime / in-process client"] + Prime["prime-rl HTTP client"] + end + + subgraph Frontend["Dynamo frontend"] + OpenAI[":8000 /v1/chat/completions"] + RlHttp[":8002 /v1/rl/*"] + RlClient["dynamo-rl RlClient"] + end + + subgraph Runtime["Dynamo runtime"] + Discovery["Discovery plane"] + RequestPlane["Request plane\nTCP / NATS / HTTP"] + end + + subgraph Workers["Inference workers"] + W1["namespace.component.rl\ninstance 1"] + W2["namespace.component.rl\ninstance 2"] + Wn["namespace.component.rl\ninstance N"] + Adapter["vLLM RL adapter"] + Transport["WeightTransport\nfilesystem / nccl"] + end + + Trainer --> OpenAI + Trainer --> RlClient + Slime --> RlClient + Prime --> RlHttp + RlHttp --> RlClient + RlClient --> Discovery + RlClient --> RequestPlane + RequestPlane --> W1 + RequestPlane --> W2 + RequestPlane --> Wn + W1 --> Adapter + W2 --> Adapter + Wn --> Adapter + Adapter --> Transport +``` + +System ports remain useful for process health, metrics, and debugging, but +they are not the RL worker fan-out contract. There is no +`DYN_RL_WORKER_SYSTEM_URLS` static worker list. + +## Enablement + +Frontend configuration: + +| Setting | Default | Purpose | +|---|---:|---| +| `DYN_ENABLE_RL` | `false` | Enables inference-plane RL extensions on `/v1/chat/completions`, including automatic token-id return on unary chat responses. 
| 
+| `DYN_ENABLE_RL_ENDPOINTS` | `false` | Enables the optional admin HTTP facade. |
+| `DYN_RL_PORT` or `--rl-port` | `8002` | Dedicated listener for `/v1/rl/*`; routes are not mounted on the main `:8000` listener. |
+| `DYN_NAMESPACE` | `dynamo` | Namespace scanned by the RL SDK. |
+| `DYN_RL_COMPONENT` | unset | Optional component filter. When unset, all live endpoints named `rl` in the namespace are targeted. |
+| `DYN_REQUEST_PLANE` | deployment-specific | Selects the request-plane transport, for example TCP, NATS, or HTTP. |
+| `DYN_DISCOVERY_BACKEND` | deployment-specific | Selects the discovery backend, for example etcd, Kubernetes, file, or memory. |
+
+Example trainer endpoints:
+
+```toml
+base_url = "http://dynamo-frontend:8000/v1"
+admin_base_url = "http://dynamo-frontend:8002/v1/rl"
+```
+
+Worker requirements:
+
+- Workers serving RL workloads register a request-plane endpoint named `rl`.
+- The endpoint receives a single envelope, `{"op": "...", "body": {...}}`.
+- Supported operations are `describe`, `pause`, `resume`, `init_transport`,
+  and `update_weights`.
+
+## Discovery And Fan-Out
+
+The `dynamo-rl` SDK owns membership and dispatch:
+
+```rust
+pub struct RlClient {
+    runtime: Arc<DistributedRuntime>,
+    namespace: String,
+    rl_endpoint: String, // default: "rl"
+    policy: FanoutPolicy,
+}
+
+pub struct MembershipSnapshot {
+    pub epoch: u64,
+    // `RlTarget` name illustrative: one entry per live
+    // (namespace, component, endpoint, instance_id).
+    pub targets: Vec<RlTarget>,
+}
+
+pub struct FanoutPolicy {
+    pub min_workers: usize,
+    pub membership_timeout: Duration,
+    pub request_timeout: Duration,
+    pub strict_direct: bool,
+    pub abort_on_membership_change: bool,
+    pub component_filter: Option<Vec<String>>,
+}
+```
+
+For each admin operation, the SDK:
+
+1. Lists live namespaced endpoints through discovery.
+2. Filters to endpoint name `rl` and the optional component filter.
+3. Builds a `MembershipSnapshot` with an epoch fingerprint.
+4. Groups targets by `(namespace, component, endpoint)`.
+5. Sends a strict direct request to each discovered `instance_id`.
+6. Optionally snapshots membership again and fails with
+   `membership_changed` if the epoch changed during fan-out.
+
+`strict_direct` matters for RL admin calls. A pause or weight update addressed
+to worker instance `A` must not silently fall back to instance `B` if `A`
+disappears. If the target is gone, the call fails and the caller receives a
+per-worker error.
+
+```mermaid
+sequenceDiagram
+    autonumber
+    participant Caller as Trainer / SDK user
+    participant Client as RlClient
+    participant Discovery as Discovery plane
+    participant RP as Request plane
+    participant Worker as Worker rl endpoint
+
+    Caller->>Client: update_weights(request)
+    Client->>Discovery: list namespaced endpoints
+    Discovery-->>Client: live endpoints
+    Client->>Client: filter endpoint == rl and compute epoch
+    loop Each worker instance
+        Client->>RP: strict direct op=update_weights, instance_id
+        RP->>Worker: op=update_weights, body=request
+        Worker-->>RP: status payload
+        RP-->>Client: worker result
+    end
+    Client->>Discovery: list namespaced endpoints
+    Discovery-->>Client: live endpoints
+    alt epoch unchanged
+        Client-->>Caller: FanoutReport with membership_epoch and workers
+    else membership changed
+        Client-->>Caller: error membership_changed
+    end
+```
+
+The snapshot is a consistency guard, not a distributed lock. In a deployment
+where workers are added or removed frequently, callers should treat a
+`membership_changed` response as a retryable orchestration event.
If membership +stays stable but a worker rejects or times out, the HTTP facade returns `502` +with per-worker status so the orchestrator can retry, drain, or rebuild the +fleet. + +The SDK does not poll worker system-port health. It snapshots discovery for +each fan-out operation and waits briefly for the request-plane client to see +the target instance IDs before dispatching. + +## Inference Surface + +RL rollout traffic uses the standard chat-completions route: + +```http +POST /v1/chat/completions +``` + +When `DYN_ENABLE_RL=true`, unary chat responses promote token metadata for RL +clients: + +- `response.prompt_token_ids` is populated from the original messages or from + pre-tokenized input. +- `choices[].token_ids` is populated from completion token IDs. +- `return_token_ids` is auto-enabled for unary chat responses. + +Callers can also request token IDs explicitly with `return_token_ids: true`. +When token IDs are requested, `n > 1` is rejected because the current +aggregation path cannot safely assign one shared completion-token vector back +to multiple choices. + +Supported request extensions include: + +| Field | Direction | Purpose | +|---|---|---| +| `prompt_token_ids` | request | Token-in/token-out path. Send pre-tokenized prompt IDs instead of messages. | +| `tokens` | request | Legacy pre-tokenized prompt path mapped into `nvext.token_data`. | +| `return_token_ids` | request | Requests completion token IDs in the response. | +| `cache_salt` | request | Salts prefix-cache identity for rollout isolation. | +| `weight_version` | request | Routes or annotates requests against a caller-selected weight version. | +| `stop_token_ids` | request | Stop generation when any listed token ID is produced. | +| `allowed_token_ids` | request | Sampling constraint passthrough. | +| `bad_words_token_ids` | request | Sampling constraint passthrough. | +| `truncate_prompt_tokens` | request | Prompt truncation passthrough. | +| `return_prompt_logprobs` | request | Requests prompt logprobs where supported by the backend. | +| `return_routed_experts` | request | Requests routed expert metadata where supported by the backend. | + +TITO callers should send `prompt_token_ids` on `/v1/chat/completions`. The +separate `/v1/chat/completions/tokens` route is not part of the current +surface. + +## Admin HTTP Facade + +The HTTP facade is optional. It exists for clients that cannot embed the SDK +but still need the same fan-out semantics. The facade is mounted only when +`DYN_ENABLE_RL_ENDPOINTS=true` or the service configuration enables RL +endpoints. + +Routes: + +| Method | Path | Description | +|---|---|---| +| `POST` | `/v1/rl/pause` | Fan out `pause` to every discovered worker. | +| `POST` | `/v1/rl/resume` | Fan out `resume` to every discovered worker. | +| `POST` | `/v1/rl/init_transport` | Initialize a weight-transfer backend on every worker. | +| `POST` | `/v1/rl/update_weights` | Apply a base-model or LoRA weight update on every worker. | + +Read-side RL routes are not part of the current HTTP surface: +`/v1/rl/state`, `/v1/rl/health`, `/v1/rl/ready`, `/v1/rl/liveness`, and +`/v1/rl/weight_version` are dropped. Use the frontend's existing `/live` and +`/health` process checks for Kubernetes probes. SDK callers can use +`describe` for topology and worker metadata probes. 
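+
+For clients that cannot embed the SDK, the facade can be driven with plain
+HTTP. The sketch below walks one pause, update, resume cycle; it assumes the
+default ports documented above and the request bodies documented in the
+sections that follow, and the `admin` helper itself is illustrative, not part
+of any shipped client.
+
+```python
+import requests
+
+ADMIN = "http://dynamo-frontend:8002/v1/rl"  # assumed deployment URL
+
+def admin(op: str, body: dict | None = None, **params) -> dict:
+    """POST one facade operation and return the fan-out report payload."""
+    resp = requests.post(f"{ADMIN}/{op}", params=params, json=body, timeout=300)
+    payload = resp.json()
+    if resp.status_code != 200:
+        # 503 no_workers / 409 membership_changed / 502 fanout_failed
+        raise RuntimeError(f"{op} failed ({resp.status_code}): {payload}")
+    return payload
+
+admin("pause", mode="keep", clear_cache="false")
+admin("update_weights", {
+    "version": "step_42",
+    "target": {"kind": "base"},
+    "transport": {
+        "backend": "filesystem",
+        "filesystem": {"path": "/share/broadcasts/step_42", "require_marker": "STABLE"},
+    },
+})
+admin("resume")
+```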
+ +### Pause + +```http +POST /v1/rl/pause?mode=keep&clear_cache=false +``` + +Query parameters: + +| Parameter | Values | Default | +|---|---|---| +| `mode` | `keep`, `wait`, `abort` | `keep` | +| `clear_cache` | `true`, `false` | `false` | + +Successful response: + +```json +{ + "status": "ok", + "mode": "keep", + "clear_cache": false, + "membership_epoch": 129837465, + "workers": [ + { + "status": "ok", + "version": "initial" + } + ] +} +``` + +### Resume + +```http +POST /v1/rl/resume +``` + +Successful response: + +```json +{ + "status": "ok", + "membership_epoch": 129837465, + "workers": [ + { + "status": "ok" + } + ] +} +``` + +### Init Transport + +`init_transport` is idempotent setup for a weight-transfer backend. Filesystem +is a no-op that marks the transport ready. NCCL initializes the worker-side +vLLM weight-transfer engine. + +Filesystem: + +```http +POST /v1/rl/init_transport +``` + +```json +{ + "transport_id": "fs-step", + "backend": "filesystem", + "filesystem": { + "require_marker": "STABLE" + } +} +``` + +NCCL: + +```http +POST /v1/rl/init_transport +``` + +```json +{ + "transport_id": "rl-nccl", + "backend": "nccl", + "nccl": { + "master_address": "trainer-0.trainer", + "master_port": 29500, + "world_size": 9, + "rank_offset": 1 + } +} +``` + +Successful response: + +```json +{ + "status": "ok", + "transport_id": "rl-nccl", + "backend": "nccl", + "ready": true, + "membership_epoch": 129837465, + "workers": [ + { + "status": "ok", + "transport_id": "rl-nccl", + "ready": true + } + ] +} +``` + +### Update Weights + +All weight updates use one discriminated body: + +```json +{ + "version": "step_42", + "target": { + "kind": "base" + }, + "transport": { + "backend": "filesystem" + }, + "pause_mode": "keep", + "clear_cache": true +} +``` + +Required fields: + +| Field | Description | +|---|---| +| `version` | Caller-assigned version string applied to the update. | +| `target.kind` | `base` or `lora`. | +| `transport.backend` | `filesystem` or `nccl` for the current vLLM implementation. Not required for LoRA unload. | + +Optional fields: + +| Field | Default | Description | +|---|---|---| +| `pause_mode` | `keep` | Worker-side pause behavior: `keep`, `wait`, or `abort`. | +| `clear_cache` | `true` | Whether the worker should clear prefix/KV cache where supported. | + +Successful response: + +```json +{ + "status": "ok", + "applied_weight_version": "step_42", + "backend": "filesystem", + "membership_epoch": 129837465, + "workers": [ + { + "status": "ok", + "message": "Updated weights from filesystem", + "version": "step_42" + } + ] +} +``` + +#### Base Model From Filesystem + +The trainer writes a checkpoint to shared storage, creates the marker file +after the checkpoint is complete, then calls `update_weights`. + +```json +{ + "version": "step_42", + "target": { + "kind": "base" + }, + "transport": { + "backend": "filesystem", + "filesystem": { + "path": "/share/broadcasts/step_42", + "require_marker": "STABLE" + } + }, + "pause_mode": "keep", + "clear_cache": true +} +``` + +#### Base Model From NCCL + +The trainer and inference workers form a group during `init_transport`. On +each update, the trainer broadcasts the named tensors and the workers receive +through vLLM's weight-update collective. 
+ +```json +{ + "version": "step_42", + "target": { + "kind": "base" + }, + "transport": { + "backend": "nccl", + "nccl": { + "transport_id": "rl-nccl", + "names": [ + "model.layers.0.self_attn.q_proj.weight" + ], + "dtype_names": [ + "bfloat16" + ], + "shapes": [ + [4096, 4096] + ], + "is_checkpoint_format": true, + "packed": false + } + }, + "pause_mode": "keep", + "clear_cache": true +} +``` + +#### LoRA Load, Swap, And Unload + +LoRA uses the same `update_weights` route. In the current vLLM implementation, +LoRA transfer uses the filesystem backend. NCCL LoRA transfer is deferred. + +Load: + +```json +{ + "version": "step_42", + "target": { + "kind": "lora", + "name": "qwen3-06b-gsm8k", + "op": "load" + }, + "transport": { + "backend": "filesystem", + "filesystem": { + "path": "/share/lora/qwen3-06b-gsm8k/step_42", + "require_marker": "STABLE" + } + }, + "pause_mode": "wait", + "clear_cache": false +} +``` + +Swap: + +```json +{ + "version": "step_43", + "target": { + "kind": "lora", + "name": "qwen3-06b-gsm8k", + "op": "swap" + }, + "transport": { + "backend": "filesystem", + "filesystem": { + "path": "/share/lora/qwen3-06b-gsm8k/step_43", + "require_marker": "STABLE" + } + }, + "pause_mode": "wait", + "clear_cache": false +} +``` + +Unload: + +```json +{ + "version": "step_44", + "target": { + "kind": "lora", + "name": "qwen3-06b-gsm8k", + "op": "unload" + } +} +``` + +Dedicated `load_lora_adapter` and `unload_lora_adapter` RL routes are not part +of the current surface. + +## Weight-Update Sequence + +```mermaid +sequenceDiagram + autonumber + participant Trainer + participant Admin as RlClient or HTTP facade + participant Discovery as Discovery plane + participant RP as Request plane + participant Worker as Worker rl endpoint + participant Adapter as vLLM adapter + participant Engine as vLLM engine + + Trainer->>Admin: pause(mode=keep) + Admin->>Discovery: snapshot rl workers + Admin->>RP: strict direct op=pause to each instance + RP->>Worker: pause + Worker->>Adapter: pause_generation + Adapter-->>Worker: ok + Worker-->>Admin: worker result + + alt filesystem backend + Trainer->>Trainer: write checkpoint and marker to shared storage + Trainer->>Admin: update_weights(version, filesystem path) + Admin->>Discovery: snapshot rl workers + Admin->>RP: strict direct op=update_weights + RP->>Worker: update_weights + Worker->>Adapter: FilesystemTransport.update_weights + Adapter->>Adapter: verify require_marker + Adapter->>Engine: reload weights from path + Engine-->>Adapter: ok + Adapter-->>Worker: ok + else nccl backend + Trainer->>Admin: init_transport(nccl) + Admin->>RP: strict direct op=init_transport + RP->>Worker: init_transport + Worker->>Engine: init_weight_transfer_engine + Engine-->>Worker: ready + Trainer->>Trainer: prepare named tensor broadcast + Trainer->>Admin: update_weights(version, tensor metadata) + Admin->>RP: strict direct op=update_weights + RP->>Worker: update_weights + Worker->>Engine: update_weights receive collective + Engine-->>Worker: ok + end + + Admin-->>Trainer: FanoutReport + Trainer->>Admin: resume() + Admin->>RP: strict direct op=resume to each instance + RP->>Worker: resume + Worker->>Adapter: resume_generation + Adapter-->>Worker: ok + Admin-->>Trainer: FanoutReport +``` + +Weight updates are not atomic across workers. If some workers update and one +worker fails, the fleet can be left at mixed versions. The response includes +per-worker results so the orchestrator can decide whether to retry, drain the +failed worker, or rebuild the serving group. 
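+
+How an orchestrator acts on those per-worker results is deployment policy.
+The sketch below shows one possibility: re-apply the fan-out while any worker
+reports an error, then surface a mixed-version failure. Field names follow
+the response examples above; the helper and the retry policy are
+illustrative, and re-applying assumes the transport tolerates it (true for
+the filesystem backend, while an NCCL re-broadcast needs the trainer to
+participate again).
+
+```python
+import requests
+
+def update_and_verify(admin_base: str, body: dict, retries: int = 1) -> dict:
+    """Apply update_weights and insist every worker lands on body["version"]."""
+    report: dict = {}
+    for attempt in range(retries + 1):
+        resp = requests.post(f"{admin_base}/update_weights", json=body, timeout=600)
+        report = resp.json()
+        failed = [w for w in report.get("workers", []) if w.get("status") != "ok"]
+        if resp.status_code == 200 and not failed:
+            return report  # whole fleet at the new version
+        if resp.status_code == 409 or (failed and attempt < retries):
+            continue  # membership moved or a worker failed: re-apply fleet-wide
+        break
+    raise RuntimeError(f"fleet may be at mixed weight versions: {report}")
+```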
+ +## Kubernetes + +Kubernetes deployments should expose two frontend ports: + +- Main inference port, usually `8000`, for OpenAI-compatible traffic and + standard `/health` and `/live` checks. +- RL admin port, usually `8002`, for `/v1/rl/*`. Keep this port + cluster-internal and protect it with service policy or network policy. + +Workers do not need their system ports exposed for RL admin fan-out. They must +be discoverable through the configured Dynamo discovery backend and reachable +through the configured request plane. + +Transport-specific Kubernetes notes: + +- Filesystem transfer requires shared storage mounted at the same path on the + trainer and every inference worker, or a path mapping layer in the + orchestrator. +- NCCL transfer requires the trainer and workers to resolve the NCCL + `master_address` and connect to `master_port`. This rendezvous is separate + from Dynamo's request plane. +- NATS request-plane deployments need the worker and frontend pods connected + to the same NATS deployment. +- TCP request-plane deployments need pod-to-pod connectivity for the Dynamo + request-plane endpoints. + +## Error Responses + +The HTTP facade maps SDK errors to stable status codes: + +| Status | `error_type` | Meaning | +|---:|---|---| +| `503` | `no_workers` | Discovery found fewer than `min_workers` live `rl` endpoints. | +| `409` | `membership_changed` | Membership changed during fan-out and the policy requires a stable epoch. | +| `502` | `fanout_failed` | Request-plane setup failed, worker dispatch failed, or one or more workers returned an error. | + +Per-worker failures also return `502` with a `workers` array: + +```json +{ + "status": "error", + "stage": "weight_transport_update", + "backend": "filesystem", + "membership_epoch": 129837465, + "workers": [ + { + "status": "ok", + "version": "step_42" + }, + { + "status": "error", + "message": "filesystem transport: require_marker 'STABLE' not found" + } + ] +} +``` + +## Backend Status + +Current implementation scope: + +- `dynamo-rl` Rust SDK and HTTP facade. +- Discovery-backed membership snapshots. +- Request-plane strict direct fan-out. +- vLLM worker `rl` dispatcher. +- vLLM filesystem base-model and LoRA updates. +- vLLM NCCL base-model updates. + +Deferred or backend-specific: + +- SGLang weight-transfer adapter parity. +- NCCL LoRA transfer. +- NIXL, Model Express, CUDA IPC, and tensor-handle transports. +- Public read-side RL state endpoints. +- Auth and RBAC inside the RL facade. Deploy the admin port as an internal + control-plane surface. + +`call_tokenizer_manager` is SGLang-specific tokenizer-manager passthrough. It +is not the generic Dynamo RL admin fan-out path. The portable RL admin contract +is the discoverable worker endpoint named `rl` plus the SDK fan-out policy. 
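+
+For backend authors, that portable contract is deliberately small: expose one
+request-plane endpoint named `rl` that accepts the `{"op": ..., "body": ...}`
+envelope and returns a status payload per operation. A minimal dispatcher
+sketch follows; the registry and handler bodies are illustrative (the shipped
+vLLM `rl_dispatch` handler is the reference implementation).
+
+```python
+from typing import Awaitable, Callable, Dict
+
+Handler = Callable[[dict], Awaitable[dict]]
+OPS: Dict[str, Handler] = {}
+
+def op(name: str) -> Callable[[Handler], Handler]:
+    def register(fn: Handler) -> Handler:
+        OPS[name] = fn
+        return fn
+    return register
+
+@op("describe")
+async def describe(body: dict) -> dict:
+    return {"status": "ok", "ops": sorted(OPS)}
+
+@op("pause")
+async def pause(body: dict) -> dict:
+    # Drain in-flight requests, keep the model loaded.
+    return {"status": "ok", "mode": body.get("mode", "keep")}
+
+async def rl_endpoint(envelope: dict) -> dict:
+    """Entry point served on the worker's `rl` request-plane endpoint."""
+    handler = OPS.get(envelope.get("op", ""))
+    if handler is None:
+        return {"status": "error", "message": f"unknown op {envelope.get('op')!r}"}
+    return await handler(envelope.get("body") or {})
+```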
diff --git a/docs/index.yml b/docs/index.yml index 14b48b8e012d..a2b1aa8fbd2f 100644 --- a/docs/index.yml +++ b/docs/index.yml @@ -346,6 +346,9 @@ navigation: - section: Additional Resources hidden: true contents: + # -- RL -- + - page: RL + path: RL.md # -- Development -- - page: Runtime Guide path: development/runtime-guide.md diff --git a/lib/bindings/python/Cargo.lock b/lib/bindings/python/Cargo.lock index e9bb77d0e246..c8b32e0b194d 100644 --- a/lib/bindings/python/Cargo.lock +++ b/lib/bindings/python/Cargo.lock @@ -2071,6 +2071,7 @@ dependencies = [ "dynamo-mocker", "dynamo-parsers", "dynamo-protocols", + "dynamo-rl", "dynamo-runtime", "dynamo-tokenizers", "dynamo-tokens", @@ -2241,6 +2242,20 @@ dependencies = [ "uuid", ] +[[package]] +name = "dynamo-rl" +version = "1.2.0" +dependencies = [ + "anyhow", + "axum", + "dynamo-runtime", + "futures", + "serde", + "serde_json", + "tokio", + "tracing", +] + [[package]] name = "dynamo-runtime" version = "1.2.0" diff --git a/lib/bindings/python/rust/lib.rs b/lib/bindings/python/rust/lib.rs index c2ff76c0b509..2fb9ffd1898b 100644 --- a/lib/bindings/python/rust/lib.rs +++ b/lib/bindings/python/rust/lib.rs @@ -1159,6 +1159,49 @@ impl Client { Ok(AsyncResponseStream::new(rx, annotated)) }) } + + /// Directly send a request to a specific endpoint without fallback re-selection. + #[pyo3(signature = (request, instance_id, annotated=DEFAULT_ANNOTATED_SETTING, context=None))] + fn direct_strict<'p>( + &self, + py: Python<'p>, + request: PyObject, + instance_id: u64, + annotated: Option, + context: Option, + ) -> PyResult> { + let request: serde_json::Value = pythonize::depythonize(&request.into_bound(py))?; + let request_ctx = create_request_context(request, &context); + let annotated = annotated.unwrap_or(false); + + let (tx, rx) = tokio::sync::mpsc::channel(32); + let client = self.router.clone(); + + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let stream = match context { + Some(context) => { + let span = get_span_for_direct_context( + &context, + "direct_strict", + &instance_id.to_string(), + ); + client + .direct_strict(request_ctx, instance_id) + .instrument(span) + .await + .map_err(to_pyerr)? 
+ } + _ => client + .direct_strict(request_ctx, instance_id) + .await + .map_err(to_pyerr)?, + }; + + tokio::spawn(process_stream(stream, tx)); + + Ok(AsyncResponseStream::new(rx, annotated)) + }) + } } async fn process_stream( diff --git a/lib/bindings/python/rust/llm/entrypoint.rs b/lib/bindings/python/rust/llm/entrypoint.rs index 87663b5e4af5..808b561b5770 100644 --- a/lib/bindings/python/rust/llm/entrypoint.rs +++ b/lib/bindings/python/rust/llm/entrypoint.rs @@ -320,6 +320,7 @@ pub(crate) struct EntrypointArgs { kv_cache_block_size: Option, http_host: Option, http_port: u16, + rl_port: Option, http_metrics_port: Option, tls_cert_path: Option, tls_key_path: Option, @@ -339,7 +340,7 @@ pub(crate) struct EntrypointArgs { impl EntrypointArgs { #[allow(clippy::too_many_arguments)] #[new] - #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, migration_max_seq_len=None, chat_engine_factory=None, aic_perf_config=None))] + #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, rl_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, migration_max_seq_len=None, chat_engine_factory=None, aic_perf_config=None))] pub fn new( py: Python<'_>, engine_type: EngineType, @@ -352,6 +353,7 @@ impl EntrypointArgs { kv_cache_block_size: Option, http_host: Option, http_port: Option, + rl_port: Option, http_metrics_port: Option, tls_cert_path: Option, tls_key_path: Option, @@ -402,6 +404,7 @@ impl EntrypointArgs { kv_cache_block_size, http_host, http_port: http_port.unwrap_or(DEFAULT_HTTP_PORT), + rl_port, http_metrics_port, tls_cert_path, tls_key_path, @@ -450,6 +453,7 @@ pub fn make_engine<'p>( .migration_max_seq_len(args.migration_max_seq_len) .http_host(args.http_host.clone()) .http_port(args.http_port) + .rl_port(args.rl_port) .http_metrics_port(args.http_metrics_port) .tls_cert_path(args.tls_cert_path.clone()) .tls_key_path(args.tls_key_path.clone()) diff --git a/lib/bindings/python/src/dynamo/_core.pyi b/lib/bindings/python/src/dynamo/_core.pyi index dba5466d0532..eb3a847abeca 100644 --- a/lib/bindings/python/src/dynamo/_core.pyi +++ b/lib/bindings/python/src/dynamo/_core.pyi @@ -246,6 +246,18 @@ class Client: """ ... + async def direct_strict( + self, + request: JsonLike, + instance_id: int, + annotated: bool | None = True, + context: Context | None = None, + ) -> AsyncIterator[JsonLike]: + """ + Pick a specific instance of the endpoint without fallback re-selection. + """ + ... 
+ async def generate( self, request: JsonLike, @@ -2114,6 +2126,7 @@ class EntrypointArgs: kv_cache_block_size: Optional[int] = None, http_host: Optional[str] = None, http_port: Optional[int] = None, + rl_port: Optional[int] = None, http_metrics_port: Optional[int] = None, tls_cert_path: Optional[str] = None, tls_key_path: Optional[str] = None, @@ -2141,6 +2154,7 @@ class EntrypointArgs: kv_cache_block_size: Optional KV cache block size http_host: HTTP host to bind to http_port: HTTP port to bind to + rl_port: Dedicated RL admin HTTP port to bind to http_metrics_port: HTTP metrics port (for gRPC service) tls_cert_path: TLS certificate path (PEM format) tls_key_path: TLS key path (PEM format) diff --git a/lib/llm/Cargo.toml b/lib/llm/Cargo.toml index 31559be92ed4..d384bf1d5510 100644 --- a/lib/llm/Cargo.toml +++ b/lib/llm/Cargo.toml @@ -56,6 +56,7 @@ dynamo-config = { workspace = true } dynamo-kv-router = { workspace = true, features = ["metrics", "runtime-protocols"] } dynamo-memory = { workspace = true } dynamo-mocker = { workspace = true } +dynamo-rl = { workspace = true } dynamo-runtime = { workspace = true } dynamo-tokenizers = { workspace = true } dynamo-tokens = { workspace = true } diff --git a/lib/llm/src/audit/stream.rs b/lib/llm/src/audit/stream.rs index 2663cdd7eeb0..4985139d37c8 100644 --- a/lib/llm/src/audit/stream.rs +++ b/lib/llm/src/audit/stream.rs @@ -101,6 +101,7 @@ where service_tier: None, }, nvext: None, + prompt_token_ids: None, } }) }), @@ -138,6 +139,7 @@ where service_tier: None, }, nvext: None, + prompt_token_ids: None, }; let _ = tx.send(fallback.clone()); final_response_to_one_chunk_stream(fallback) @@ -160,6 +162,7 @@ where service_tier: None, }, nvext: None, + prompt_token_ids: None, } }) }); diff --git a/lib/llm/src/entrypoint/input/http.rs b/lib/llm/src/entrypoint/input/http.rs index a9bdf1c6c09a..b54a0f799fd5 100644 --- a/lib/llm/src/entrypoint/input/http.rs +++ b/lib/llm/src/entrypoint/input/http.rs @@ -49,6 +49,9 @@ pub async fn run( if let Some(http_host) = local_model.http_host() { http_service_builder = http_service_builder.host(http_host); } + if let Some(rl_port) = local_model.rl_port() { + http_service_builder = http_service_builder.rl_port(rl_port); + } http_service_builder = http_service_builder.cancel_token(Some(distributed_runtime.primary_token())); http_service_builder = @@ -63,6 +66,12 @@ pub async fn run( http_service_builder = http_service_builder.drt_discovery(Some(distributed_runtime.discovery())); + // Wire the full DRT so the RL admin router (when DYN_ENABLE_RL_ENDPOINTS=true) + // can use the discovery + request planes to fan out to live `..rl` + // worker endpoints. 
+ http_service_builder = + http_service_builder.runtime(Some(Arc::new(distributed_runtime.clone()))); + let http_service = match engine_config { EngineConfig::Dynamic { ref model, diff --git a/lib/llm/src/entrypoint/input/text.rs b/lib/llm/src/entrypoint/input/text.rs index 1c0138fd34b3..6c7046602d7e 100644 --- a/lib/llm/src/entrypoint/input/text.rs +++ b/lib/llm/src/entrypoint/input/text.rs @@ -116,6 +116,8 @@ async fn main_loop( chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + tokens: None, + return_token_ids: None, }; // Call the model diff --git a/lib/llm/src/http/service/openai.rs b/lib/llm/src/http/service/openai.rs index a7733428105e..c6a87a43638d 100644 --- a/lib/llm/src/http/service/openai.rs +++ b/lib/llm/src/http/service/openai.rs @@ -200,9 +200,23 @@ impl ErrorMessage { ) } - /// Not Implemented Error - /// Return this error when the client requests a feature that is not yet implemented. - /// This should be used for features that are planned but not available. + /// Bad Request Error. + /// Return this error when the client sends an invalid request — malformed + /// JSON, schema mismatch, or fields that fail `validate.rs` gating. + #[allow(dead_code)] // exposed for downstream crates; not directly called in lib/llm + pub fn bad_request(msg: &str) -> ErrorResponse { + let code = StatusCode::BAD_REQUEST; + let error_type = map_error_code_to_error_type(code); + ( + code, + Json(ErrorMessage { + message: msg.to_string(), + error_type, + code: code.as_u16(), + }), + ) + } + pub fn not_implemented_error(msg: T) -> ErrorResponse { tracing::error!("Not Implemented error: {msg}"); let code = StatusCode::NOT_IMPLEMENTED; @@ -910,6 +924,82 @@ async fn handler_chat_completions( request.nvext = apply_header_routing_overrides(request.nvext.take(), &headers); + // RL field promotion: wire `tokens` and `return_token_ids` when provided on the standard + // chat completions endpoint. This eliminates the need for the rl_admin Python proxy to + // intercept and rewrite these fields. + // + // If `return_token_ids` is true, request completion_token_ids in the response. + // Auto-enable when DYN_ENABLE_RL is set -- ensures token IDs flow even if the + // client forgets to request them. + let rl_want_token_ids = request + .return_token_ids + .take() + .unwrap_or_else(|| dynamo_runtime::config::env_is_truthy("DYN_ENABLE_RL")); + if rl_want_token_ids { + // Reject n > 1 for the RL token-id path: the streaming aggregator + // accumulates `completion_token_ids` into a single `Vec` shared + // across all choices, so the per-choice promotion downstream cannot + // recover which tokens belong to which choice. A keyed-by-index + // accumulator is the long-term fix. + if request.inner.n.unwrap_or(1) > 1 { + return Err(bad_request( + "n > 1 is not supported when RL token IDs are requested. \ + Send separate requests instead.", + )); + } + tracing::debug!("RL: want_token_ids=true, will promote nvext.extra_fields"); + } + { + // If `tokens` is provided, inject into nvext.token_data (pre-tokenized prompt path). 
+ let token_data = request.tokens.take(); + + if token_data.is_some() || rl_want_token_ids { + let mut nvext = request.nvext.take().unwrap_or_default(); + + if let Some(ids) = token_data { + if !ids.is_empty() { + nvext.token_data = Some(ids); + // Ensure messages is non-empty for model lookup / chat template + if request.inner.messages.is_empty() { + use dynamo_protocols::types::{ + ChatCompletionRequestMessage, ChatCompletionRequestUserMessage, + ChatCompletionRequestUserMessageContent, + }; + request + .inner + .messages + .push(ChatCompletionRequestMessage::User( + ChatCompletionRequestUserMessage { + content: ChatCompletionRequestUserMessageContent::Text( + "(token-in mode)".to_string(), + ), + name: None, + }, + )); + } + } + } + + if rl_want_token_ids { + let mut extra_fields = nvext.extra_fields.take().unwrap_or_default(); + for field in &["token_ids", "completion_token_ids"] { + if !extra_fields.contains(&field.to_string()) { + extra_fields.push(field.to_string()); + } + } + nvext.extra_fields = Some(extra_fields); + // RL token-id extraction depends on logprobs being enabled at + // the engine. Override unconditionally — an explicit + // logprobs=false would otherwise drop completion_token_ids + // from the response while we silently still claim to return + // them. + request.inner.logprobs = Some(true); + } + + request.nvext = Some(nvext); + } + } + // create the context for the request let request_id = get_or_create_request_id(&headers); let streaming = request.inner.stream.unwrap_or(false); @@ -1195,6 +1285,26 @@ async fn chat_completions( // todo - decide on default let streaming = request.inner.stream.unwrap_or(false); + // RL: save messages for post-response prompt tokenization (needed for prompt_token_ids). + let rl_saved_messages = if dynamo_runtime::config::env_is_truthy("DYN_ENABLE_RL") && !streaming + { + Some(request.inner.messages.clone()) + } else { + None + }; + + // RL: for TITO requests the caller (handler_chat_completions_tokens) injects a + // placeholder message so Dynamo can select a chat template, but then saves the + // real token IDs in nvext.token_data. Capture them now — before the request is + // consumed by engine.generate() — so the post-processing step can use them + // directly as prompt_token_ids instead of re-tokenizing the placeholder. + let rl_tito_token_ids: Option> = + if dynamo_runtime::config::env_is_truthy("DYN_ENABLE_RL") && !streaming { + request.nvext.as_ref().and_then(|nv| nv.token_data.clone()) + } else { + None + }; + // Apply template values first to resolve the model before creating metrics guards if let Some(template) = template { if request.inner.model.is_empty() { @@ -1406,6 +1516,41 @@ async fn chat_completions( if ctx.is_killed() { inflight_guard.mark_error(ErrorType::Cancelled); } + + // RL post-processing: when DYN_ENABLE_RL is active, promote + // token IDs to the top-level locations that Prime-RL / verifiers expects: + // response.prompt_token_ids (from tokenizing the prompt) + // response.choices[i].token_ids (from nvext.completion_token_ids) + let response = if let Some(ref messages) = rl_saved_messages { + let mut response = response; + // For TITO requests, nvext.token_data IS the prompt — use those IDs + // directly. Falling back to rl_tokenize_prompt would re-tokenize the + // placeholder message injected by handler_chat_completions_tokens and + // return the wrong IDs. 
+ response.prompt_token_ids = + rl_tito_token_ids.or_else(|| rl_tokenize_prompt(&state, &model, messages)); + match serde_json::to_value(&response) { + Ok(mut json_val) => { + rl_promote_token_ids_in_response(&mut json_val); + return Ok(Json(json_val).into_response()); + } + Err(e) => { + // This path means choice.token_ids will NOT be promoted — Prime-RL + // will see None for completion token IDs and may silently drop the + // rollout or crash. Log at error so data-loss does not go unnoticed. + tracing::error!( + request_id, + "rl_promote_token_ids: serde_json serialization failed — \ + choice.token_ids will NOT be promoted to top-level; \ + Prime-RL rollout may be dropped or corrupt: {e}" + ); + } + } + response + } else { + response + }; + Ok(Json(response).into_response()) } } @@ -1441,7 +1586,17 @@ pub fn validate_chat_completion_required_fields( ) -> Result<(), ErrorResponse> { let inner = &request.inner; - if inner.messages.is_empty() { + // RL renderer / TITO callers send `prompt_token_ids` (or legacy + // `nvext.token_data`) in place of `messages`. Treat either pre-tokenized + // input as satisfying the "non-empty input" requirement. + let has_pretokenized_input = request.unsupported_fields.contains_key("prompt_token_ids") + || request + .nvext + .as_ref() + .and_then(|ext| ext.token_data.as_ref()) + .is_some(); + + if inner.messages.is_empty() && !has_pretokenized_input { return Err(ErrorMessage::from_http_error(HttpError { code: 400, message: VALIDATION_PREFIX.to_string() @@ -1909,6 +2064,60 @@ pub(crate) fn check_ready(_state: &Arc) -> Result<(), ErrorRe Ok(()) } +// ── Tokenize / Detokenize ──────────────────────────────────────────── + +fn bad_request>(message: T) -> ErrorResponse { + let code = StatusCode::BAD_REQUEST; + ( + code, + Json(ErrorMessage { + message: message.into(), + error_type: map_error_code_to_error_type(code), + code: code.as_u16(), + }), + ) +} + +fn resolve_tokenizer_model_name( + state: &Arc, + requested_model: Option<&str>, +) -> Result { + if let Some(model) = requested_model { + if state.manager().has_model_any(model) { + return Ok(model.to_string()); + } + return Err(ErrorMessage::model_not_found()); + } + let mut served_models = state.manager().model_display_names(); + if served_models.len() == 1 { + if let Some(only) = served_models.drain().next() { + return Ok(only); + } + } + Err(bad_request( + "Model must be specified when more than one model is served.", + )) +} + +fn resolve_model_card( + state: &Arc, + requested_model: Option<&str>, +) -> Result<(String, crate::model_card::ModelDeploymentCard), ErrorResponse> { + let model = resolve_tokenizer_model_name(state, requested_model)?; + let card = state + .manager() + .get_model_cards() + .into_iter() + .find(|card| card.display_name == model) + .ok_or_else(|| { + ErrorMessage::internal_server_error(&format!( + "Tokenizer metadata is not available for model '{}'", + model + )) + })?; + Ok((model, card)) +} + /// openai compatible format /// Example: /// { @@ -2020,6 +2229,32 @@ pub fn chat_completions_router( (vec![doc], router) } +/// Create an Axum [`Router`] for the RL TITO (Token-In / Token-Out) endpoint. +/// +/// This endpoint accepts Prime-RL's `tokens` field (pre-tokenized prompt), +/// translates it to `nvext.token_data`, forces logprobs on, and delegates +/// to the standard chat_completions handler -- all in Rust, eliminating the +/// Python rl-admin proxy from the hot inference path. +/// +/// If no path is provided, the default path is `/v1/chat/completions/tokens`. 
+/// +/// Dropped from the v2 surface (see `service_v2.rs`). TITO callers retarget +/// to `/v1/chat/completions` with the `prompt_token_ids` extension — vLLM +/// 0.20+ skips chat templating when that field is present, identical +/// behavior. The handler is kept as `#[allow(dead_code)]` for downstream +/// code that still references it; deletion is a follow-up cleanup. + +/// Handler for TITO (Token-In / Token-Out) chat completions. +/// +/// Accepts Prime-RL's request format which includes a `tokens` field containing +/// pre-tokenized prompt token IDs. The handler: +/// 1. Extracts the `tokens` field +/// 2. Injects them as `nvext.token_data` (Dynamo's native pre-tokenized input) +/// 3. Requests `token_ids` and `completion_token_ids` in the response via `nvext.extra_fields` +/// 4. Forces `logprobs = true` (RL always needs logprobs) +/// 5. Ensures `messages` is non-empty (Dynamo requires it for chat template selection) +/// 6. Delegates to the standard `chat_completions()` internal function (zero HTTP proxy) + /// Create an Axum [`Router`] for the OpenAI API Embeddings endpoint /// If not path is provided, the default path is `/v1/embeddings` pub fn embeddings_router( @@ -2639,6 +2874,136 @@ pub fn audios_router( (vec![doc], router) } +// ────────────────────────────────────────────────────────────────────────── +// RL Admin router: /v1/rl/* +// ────────────────────────────────────────────────────────────────────────── +/// Tokenize chat messages using the model's tokenizer and return prompt token IDs. +/// Used by the RL post-processing path to populate `response.prompt_token_ids`. +fn rl_tokenize_prompt( + state: &Arc, + model: &str, + messages: &[dynamo_protocols::types::ChatCompletionRequestMessage], +) -> Option> { + if messages.is_empty() { + return None; + } + let (_, card) = resolve_model_card(state, Some(model)).ok()?; + let tokenizer = card.tokenizer().ok()?; + let formatter = crate::preprocessor::prompt::PromptFormatter::from_mdc(&card).ok()?; + let inner_request = dynamo_protocols::types::CreateChatCompletionRequest { + model: model.to_string(), + messages: messages.to_vec(), + ..Default::default() + }; + let wrapped = crate::protocols::openai::chat_completions::NvCreateChatCompletionRequest { + inner: inner_request, + common: Default::default(), + nvext: None, + chat_template_args: None, + media_io_kwargs: None, + tokens: None, + return_token_ids: None, + unsupported_fields: Default::default(), + }; + let prompt = match formatter { + crate::preprocessor::prompt::PromptFormatter::OAI(f) => f.render(&wrapped), + } + .ok()?; + let encoding = tokenizer.encode_with_special_tokens(&prompt, true).ok()?; + Some(encoding.token_ids().to_vec()) +} + +/// Promote token IDs from the Dynamo `nvext` response object to the top-level +/// locations that Prime-RL / verifiers expects: +/// +/// response.nvext.completion_token_ids → response.choices[i].token_ids +/// +/// This lets RL clients read `choice.token_ids` without knowing about the +/// `nvext` extension structure. Called on non-streaming responses when RL +/// token ID mode is active. 
+fn rl_promote_token_ids_in_response(json_val: &mut serde_json::Value) { + // Move completion_token_ids from response-level nvext to each choice, + // because some RL clients expect: + // response.choices[i].token_ids (not response.nvext.completion_token_ids) + let has_nvext = json_val.get("nvext").is_some(); + let has_completion_ids = json_val + .get("nvext") + .and_then(|nv| nv.get("completion_token_ids")) + .is_some(); + + tracing::debug!( + has_nvext, + has_completion_ids, + "rl_promote_token_ids_in_response: inspecting response" + ); + + if let Some(nvext) = json_val.get("nvext") { + if let Some(completion_ids) = nvext.get("completion_token_ids").cloned() { + let n = completion_ids.as_array().map(|a| a.len()).unwrap_or(0); + tracing::info!( + n_completion_ids = n, + "rl_promote: copying completion_token_ids to choices[].token_ids" + ); + if let Some(choices) = json_val.get_mut("choices").and_then(|c| c.as_array_mut()) { + for choice in choices.iter_mut() { + if let Some(obj) = choice.as_object_mut() { + obj.insert("token_ids".to_string(), completion_ids.clone()); + } + } + } + } + } +} + +/// `GET /v1/rl/health` — lightweight health check for Prime-RL admin client. +/// +/// RL admin clients that POST `GET /health` against the admin client land +/// here when `admin_base_url = ["http://dynamo:8000/v1/rl"]`. Returns 200 OK +/// if the frontend process is running (no deep probe needed — the frontend's +/// own `/health` endpoint handles that separately). +/// +/// **Deprecated in favor of `/v1/rl/state.ingress_alive`.** Kept for +/// back-compat until existing clients migrate to `/v1/rl/state`; will be +/// removed in a follow-up. + +// ── RL admin router ──────────────────────────────────────────────────── +// All `/v1/rl/*` handlers, `RlState`, body types, and fan-out logic now +// live in the `dynamo-rl` crate (see `plans/rl-crate.md`). This shim +// delegates and wraps the result into dynamo-llm's `RouteDoc` plus the +// shared `smart_json_error_middleware` that all OpenAI-side routes use. + +/// Build the `/v1/rl/*` router. Delegates to `dynamo_rl::rl_router()` and +/// wraps the documentation tuples into `RouteDoc`. Wraps the router with +/// `smart_json_error_middleware` so 422s are coerced to 400s consistently +/// with the OpenAI-compat surface. +/// +/// Exposed only on the dedicated RL listener when +/// `DYN_ENABLE_RL_ENDPOINTS=true` or `HttpServiceConfig.enable_rl` is set. 
+pub fn rl_router( + drt: std::sync::Arc, +) -> anyhow::Result<(Vec, Router)> { + let namespace = std::env::var("DYN_NAMESPACE").unwrap_or_else(|_| "dynamo".into()); + let mut policy = dynamo_rl::FanoutPolicy::default_admin(); + if let Ok(component) = std::env::var("DYN_RL_COMPONENT") { + policy = policy.with_component_filter(vec![component]); + } + + let client = dynamo_rl::RlClient::new(dynamo_rl::RlClientConfig { + runtime: drt, + namespace, + rl_endpoint: dynamo_rl::DEFAULT_RL_ENDPOINT.to_string(), + policy, + })?; + + let (rl_docs, router) = dynamo_rl::rl_router(dynamo_rl::RlHttpDeps { client })?; + let docs = rl_docs + .into_iter() + .map(|d| RouteDoc::new(d.method, d.path)) + .collect(); + let router = router.layer(middleware::from_fn(smart_json_error_middleware)); + Ok((docs, router)) +} + #[cfg(test)] mod tests { @@ -2880,6 +3245,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_required_fields(&request); assert!(result.is_err()); @@ -2912,6 +3280,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_required_fields(&request); assert!(result.is_ok()); @@ -3128,6 +3499,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_fields_generic(&request); @@ -3158,6 +3532,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_fields_generic(&request); assert!(result.is_err()); @@ -3187,6 +3564,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_fields_generic(&request); assert!(result.is_err()); @@ -3216,6 +3596,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_fields_generic(&request); assert!(result.is_err()); @@ -3247,6 +3630,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_fields_generic(&request); assert!(result.is_err()); @@ -3276,6 +3662,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let result = validate_chat_completion_fields_generic(&request); assert!(result.is_err()); @@ -4303,4 +4692,63 @@ mod tests { let json = extract_sse_data_json(events[0].as_ref().unwrap()); assert_eq!(json["reasoning_content"], "让我想想 🤔 分析完成 ✅"); } + + // ── RL admin types ────────────────────────────────────────────────── + + #[test] + fn test_pause_mode_serde_roundtrip() { + for (mode, lower) in [ + (PauseMode::Keep, "keep"), + (PauseMode::Wait, "wait"), + (PauseMode::Abort, "abort"), + ] { + let json = serde_json::to_string(&mode).unwrap(); + assert_eq!(json, format!("\"{lower}\"")); + assert_eq!(mode.as_str(), lower); + let parsed: PauseMode = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, mode); + } + 
assert_eq!(PauseMode::default(), PauseMode::Keep); + } + + #[test] + fn test_pause_mode_rejects_unknown_value() { + // Axum returns 400 on this deserialize failure before the handler + // runs — that's the whole point of the typed enum vs the prior + // string match. + let err = + serde_json::from_str::("\"foo\"").expect_err("foo is not a valid PauseMode"); + assert!(err.to_string().to_lowercase().contains("foo")); + } + + #[test] + fn test_rl_update_weights_body_defaults() { + let body: RlUpdateWeightsBody = serde_json::from_str(r#"{}"#).unwrap(); + assert!(body.weight_dir.is_none()); + assert!(body.weight_version.is_none()); + assert!(body.reset_prefix_cache); + + let body: RlUpdateWeightsBody = serde_json::from_str(r#"{"weight_dir":null}"#).unwrap(); + assert!(body.weight_dir.is_none()); + assert!(body.reset_prefix_cache); + + let body: RlUpdateWeightsBody = + serde_json::from_str(r#"{"weight_dir":"/path","reset_prefix_cache":false}"#).unwrap(); + assert_eq!(body.weight_dir.as_deref(), Some("/path")); + assert!(!body.reset_prefix_cache); + } + + #[test] + fn test_rl_state_new_constructs_without_env() { + // Sanity check the testability constructor — needed so future + // route-level tests can build an `RlState` without env vars or a + // real network client. + let state = RlState::new( + vec!["http://w0:9090".to_string(), "http://w1:9090".to_string()], + reqwest::Client::new(), + std::time::Duration::from_millis(100), + ); + assert_eq!(state.worker_system_urls.len(), 2); + assert_eq!(state.probe_timeout, std::time::Duration::from_millis(100)); + } } diff --git a/lib/llm/src/http/service/service_v2.rs b/lib/llm/src/http/service/service_v2.rs index 1eeaefdff912..e87edbf87c21 100644 --- a/lib/llm/src/http/service/service_v2.rs +++ b/lib/llm/src/http/service/service_v2.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use std::env::var; +use std::io::ErrorKind; use std::path::PathBuf; use std::sync::Arc; use std::sync::atomic::AtomicBool; @@ -204,6 +205,8 @@ pub struct HttpService { router: axum::Router, port: u16, + rl_router: Option, + rl_port: u16, host: String, enable_tls: bool, tls_cert_path: Option, @@ -217,6 +220,9 @@ pub struct HttpServiceConfig { #[builder(default = "8787")] port: u16, + #[builder(default = "default_rl_port()")] + rl_port: u16, + #[builder(setter(into), default = "String::from(\"0.0.0.0\")")] host: String, @@ -246,6 +252,12 @@ pub struct HttpServiceConfig { #[builder(default = "false")] enable_anthropic_endpoints: bool, + /// When true, expose the RL admin routes at `/v1/rl/*` on the dedicated + /// `rl_port` listener. Fan-out uses dynamo-rl over the discovery and + /// request planes; worker system ports are not part of this contract. + #[builder(default = "false")] + enable_rl: bool, + #[builder(default = "None")] request_template: Option, @@ -264,6 +276,12 @@ pub struct HttpServiceConfig { /// are registered using discovery.instance_id() and exposed on /metrics. #[builder(default = "None")] drt_discovery: Option>, + + /// Required when `enable_rl` (or `DYN_ENABLE_RL_ENDPOINTS=true`): the + /// dynamo-rl crate uses this runtime's discovery + request planes to + /// fan out admin calls to live `..rl` endpoint instances. 
+ #[builder(default = "None")] + runtime: Option>, } impl HttpService { @@ -289,103 +307,57 @@ impl HttpService { } pub async fn run(&self, cancel_token: CancellationToken) -> Result<()> { - let address = format!("{}:{}", self.host, self.port); - let protocol = if self.enable_tls { "HTTPS" } else { "HTTP" }; - tracing::info!(protocol, address, "Starting HTTP(S) service"); - - let router = self.router.clone(); - let observer = cancel_token.child_token(); - - let state_cancel = self.state.cancel_token().clone(); - - let addr: SocketAddr = address - .parse() - .map_err(|e| anyhow::anyhow!("Invalid address '{}': {}", address, e))?; - - if self.enable_tls { - let cert_path = self - .tls_cert_path - .as_ref() - .ok_or_else(|| anyhow::anyhow!("TLS certificate path not provided"))?; - let key_path = self - .tls_key_path - .as_ref() - .ok_or_else(|| anyhow::anyhow!("TLS private key path not provided"))?; - - // aws_lc_rs is the default but other crates pull in `ring` also, - // so rustls doesn't know which one to use. Tell it. - if let Err(e) = rustls::crypto::aws_lc_rs::default_provider().install_default() { - tracing::debug!("TLS crypto provider already installed: {e:?}"); - } + let mut handles = vec![spawn_http_listener(HttpListenerConfig { + name: "openai", + router: self.router.clone(), + host: self.host.clone(), + port: self.port, + port_arg: "--http-port", + enable_tls: self.enable_tls, + tls_cert_path: self.tls_cert_path.clone(), + tls_key_path: self.tls_key_path.clone(), + cancel_token: cancel_token.clone(), + state_cancel: self.state.cancel_token().clone(), + })]; + + if let Some(router) = self.rl_router.clone() { + handles.push(spawn_http_listener(HttpListenerConfig { + name: "rl", + router, + host: self.host.clone(), + port: self.rl_port, + port_arg: "--rl-port", + enable_tls: self.enable_tls, + tls_cert_path: self.tls_cert_path.clone(), + tls_key_path: self.tls_key_path.clone(), + cancel_token: cancel_token.clone(), + state_cancel: self.state.cancel_token().clone(), + })); + } - let config = RustlsConfig::from_pem_file(cert_path, key_path) - .await - .map_err(|e| anyhow::anyhow!("Failed to create TLS config: {}", e))?; + tokio::spawn(tokio_metrics_and_canary_loop(cancel_token.clone())); - let handle = axum_server::Handle::new(); - let server = axum_server::bind_rustls(addr, config) - .handle(handle.clone()) - .serve(router.into_make_service()); + let (first_result, _idx, remaining) = futures::future::select_all(handles).await; + cancel_token.cancel(); - // Spawn canary after all fallible startup so it won't leak on early errors - tokio::spawn(tokio_metrics_and_canary_loop(cancel_token.clone())); + let mut result = match first_result { + Ok(result) => result, + Err(err) => Err(anyhow::anyhow!("HTTP listener task failed: {err}")), + }; - tokio::select! 
{ - result = server => { - let result = result.map_err(|e| anyhow::anyhow!("HTTPS server error: {}", e)); - cancel_token.cancel(); - result?; - } - _ = observer.cancelled() => { - state_cancel.cancel(); - tracing::info!("HTTPS server shutdown requested"); - // accepting requests for 5 more seconds, to allow incorrectly routed requests to arrive - handle.graceful_shutdown(Some(Duration::from_secs(get_graceful_shutdown_timeout() as u64))); - // no longer accepting requests, draining all existing connections + for handle in remaining { + match handle.await { + Ok(Ok(())) => {} + Ok(Err(err)) if result.is_ok() => result = Err(err), + Ok(Err(_)) => {} + Err(err) if result.is_ok() => { + result = Err(anyhow::anyhow!("HTTP listener task failed: {err}")); } + Err(_) => {} } - } else { - let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| { - tracing::error!( - protocol = %protocol, - address = %address, - error = %e, - "Failed to bind server to address" - ); - match e.kind() { - std::io::ErrorKind::AddrInUse => anyhow::anyhow!( - "Failed to start {} server: port {} already in use. Use --http-port to specify a different port.", - protocol, - self.port - ), - _ => anyhow::anyhow!( - "Failed to start {} server on {}: {}", - protocol, - address, - e - ), - } - })?; - - // Spawn canary after all fallible startup so it won't leak on early errors - tokio::spawn(tokio_metrics_and_canary_loop(cancel_token.clone())); - - axum::serve(listener, router) - .with_graceful_shutdown(async move { - observer.cancelled_owned().await; - state_cancel.cancel(); - tracing::info!("HTTP server shutdown requested"); - // accepting requests for 5 more seconds, to allow incorrectly routed requests to arrive - tokio::time::sleep(Duration::from_secs(get_graceful_shutdown_timeout() as u64)) - .await; - // no longer accepting requests, draining all existing connections - }) - .await - .inspect_err(|_| cancel_token.cancel())?; - cancel_token.cancel(); } - Ok(()) + result } /// Documentation of exposed HTTP endpoints @@ -403,6 +375,150 @@ impl HttpService { } } +struct HttpListenerConfig { + name: &'static str, + router: axum::Router, + host: String, + port: u16, + port_arg: &'static str, + enable_tls: bool, + tls_cert_path: Option, + tls_key_path: Option, + cancel_token: CancellationToken, + state_cancel: CancellationToken, +} + +fn spawn_http_listener(config: HttpListenerConfig) -> JoinHandle> { + tokio::spawn(run_http_listener(config)) +} + +async fn run_http_listener(config: HttpListenerConfig) -> Result<()> { + let address = format!("{}:{}", config.host, config.port); + let protocol = if config.enable_tls { "HTTPS" } else { "HTTP" }; + tracing::info!( + listener = config.name, + protocol, + address, + "Starting HTTP listener" + ); + + let addr: SocketAddr = address + .parse() + .map_err(|e| anyhow::anyhow!("Invalid address '{}': {}", address, e))?; + + if config.enable_tls { + run_tls_listener(config, addr, protocol, address).await + } else { + run_plain_listener(config, addr, protocol, address).await + } +} + +async fn run_tls_listener( + config: HttpListenerConfig, + addr: SocketAddr, + protocol: &'static str, + address: String, +) -> Result<()> { + let cert_path = config + .tls_cert_path + .as_ref() + .ok_or_else(|| anyhow::anyhow!("TLS certificate path not provided"))?; + let key_path = config + .tls_key_path + .as_ref() + .ok_or_else(|| anyhow::anyhow!("TLS private key path not provided"))?; + + // aws_lc_rs is the default but other crates pull in `ring` also, + // so rustls doesn't know which one to use. 
Tell it. + if let Err(e) = rustls::crypto::aws_lc_rs::default_provider().install_default() { + tracing::debug!("TLS crypto provider already installed: {e:?}"); + } + + let tls_config = RustlsConfig::from_pem_file(cert_path, key_path) + .await + .map_err(|e| anyhow::anyhow!("Failed to create TLS config: {}", e))?; + + let handle = axum_server::Handle::new(); + let observer = config.cancel_token.child_token(); + let state_cancel = config.state_cancel.clone(); + let listener_name = config.name; + let server = axum_server::bind_rustls(addr, tls_config) + .handle(handle.clone()) + .serve(config.router.into_make_service()); + + tokio::select! { + result = server => { + result.map_err(|e| { + tracing::error!( + listener = listener_name, + protocol = %protocol, + address = %address, + error = %e, + "HTTP listener failed" + ); + anyhow::anyhow!("{} listener '{}' error: {}", protocol, listener_name, e) + })?; + } + _ = observer.cancelled_owned() => { + state_cancel.cancel(); + tracing::info!(listener = listener_name, "HTTP listener shutdown requested"); + // accepting requests for a short window allows incorrectly routed + // requests already in flight to arrive before draining connections. + handle.graceful_shutdown(Some(Duration::from_secs(get_graceful_shutdown_timeout() as u64))); + } + } + + Ok(()) +} + +async fn run_plain_listener( + config: HttpListenerConfig, + addr: SocketAddr, + protocol: &'static str, + address: String, +) -> Result<()> { + let listener = tokio::net::TcpListener::bind(addr).await.map_err(|e| { + tracing::error!( + listener = config.name, + protocol = %protocol, + address = %address, + error = %e, + "Failed to bind HTTP listener to address" + ); + match e.kind() { + ErrorKind::AddrInUse => anyhow::anyhow!( + "Failed to start {} listener '{}': port {} already in use. Use {} to specify a different port.", + protocol, + config.name, + config.port, + config.port_arg + ), + _ => anyhow::anyhow!( + "Failed to start {} listener '{}' on {}: {}", + protocol, + config.name, + address, + e + ), + } + })?; + + let observer = config.cancel_token.child_token(); + let state_cancel = config.state_cancel.clone(); + let listener_name = config.name; + + axum::serve(listener, config.router) + .with_graceful_shutdown(async move { + observer.cancelled_owned().await; + state_cancel.cancel(); + tracing::info!(listener = listener_name, "HTTP listener shutdown requested"); + tokio::time::sleep(Duration::from_secs(get_graceful_shutdown_timeout() as u64)).await; + }) + .await?; + + Ok(()) +} + fn get_graceful_shutdown_timeout() -> usize { std::env::var(env_llm::DYN_HTTP_GRACEFUL_SHUTDOWN_TIMEOUT_SECS) .ok() @@ -410,6 +526,16 @@ fn get_graceful_shutdown_timeout() -> usize { .unwrap_or(5) } +const DEFAULT_RL_PORT: u16 = 8002; +const DYN_RL_PORT_ENV: &str = "DYN_RL_PORT"; + +fn default_rl_port() -> u16 { + std::env::var(DYN_RL_PORT_ENV) + .ok() + .and_then(|s| s.parse::().ok()) + .unwrap_or(DEFAULT_RL_PORT) +} + /// Environment variable to set the metrics endpoint path (default: `/metrics`) static HTTP_SVC_METRICS_PATH_ENV: &str = "DYN_HTTP_SVC_METRICS_PATH"; /// Environment variable to set the models endpoint path (default: `/v1/models`) @@ -536,6 +662,42 @@ impl HttpServiceConfigBuilder { super::health::live_check_router(state.clone(), var(HTTP_SVC_LIVE_PATH_ENV).ok()), super::busy_threshold::busy_threshold_router(state.clone(), None), ]; + // RL admin routes: gated by `DYN_ENABLE_RL_ENDPOINTS` (frontend-only) + // and served on a separate listener (`DYN_RL_PORT`, default 8002). 
+ // `DYN_ENABLE_RL` remains the inference-plane flag and no longer + // mounts admin routes on the OpenAI listener. + let rl_router = if config.enable_rl || env_is_truthy("DYN_ENABLE_RL_ENDPOINTS") { + match config.runtime.as_ref() { + Some(drt) => { + tracing::info!( + rl_port = config.rl_port, + "RL admin routes enabled at /v1/rl/* on dedicated listener" + ); + let (rl_docs, router) = super::openai::rl_router(drt.clone())?; + let (_openapi_docs, openapi_route) = + super::openapi_docs::openapi_router(rl_docs, None); + let router = router + .merge(openapi_route) + .layer( + TraceLayer::new_for_http() + .make_span_with(make_system_request_span) + .on_response(on_response), + ) + .layer(axum::middleware::from_fn(echo_request_id_header)); + Some(router) + } + None => { + tracing::warn!( + "RL admin routes requested (DYN_ENABLE_RL_ENDPOINTS=true) but \ + HttpServiceConfigBuilder.runtime is None — skipping mount. \ + The frontend caller must supply the DistributedRuntime." + ); + None + } + } + } else { + None + }; let mut system_router = axum::Router::new(); for (route_docs, route) in system_routes { system_router = system_router.merge(route); @@ -576,6 +738,8 @@ impl HttpServiceConfigBuilder { state, router, port: config.port, + rl_router, + rl_port: config.rl_port, host: config.host, enable_tls: config.enable_tls, tls_cert_path: config.tls_cert_path, @@ -600,6 +764,11 @@ impl HttpServiceConfigBuilder { request_template.clone(), var(HTTP_SVC_CHAT_PATH_ENV).ok(), ); + // The legacy `/v1/chat/completions/tokens` TITO fork URI is dropped. + // TITO callers send `prompt_token_ids` as a top-level extension on + // `/v1/chat/completions` (allowlisted by `validate.rs::PASSTHROUGH_EXTRA_FIELDS`); + // vLLM 0.20+ skips chat templating when that field is present. + let (cmpl_docs, cmpl_route) = super::openai::completions_router(state.clone(), var(HTTP_SVC_CMP_PATH_ENV).ok()); let (embed_docs, embed_route) = @@ -612,6 +781,8 @@ impl HttpServiceConfigBuilder { request_template.clone(), var(HTTP_SVC_RESPONSES_PATH_ENV).ok(), ); + // Phase 5: TITO fork URI dropped — chat route stands alone now. 
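To make the collapsed TITO surface concrete, here is a hypothetical request body a renderer-mode client would POST to `/v1/chat/completions`. The model name and token values are invented; the extension fields are the ones allowlisted in `validate.rs::PASSTHROUGH_EXTRA_FIELDS` below.

```rust
use serde_json::json;

// Hypothetical TITO-mode body. `messages` is omitted because it is
// mutually exclusive with the top-level `prompt_token_ids` extension;
// `stop_token_ids` rides the same passthrough allowlist, and the nvext
// opt-in asks for output token IDs on the final chunk.
fn example_tito_body() -> serde_json::Value {
    json!({
        "model": "example-model",
        "prompt_token_ids": [151644, 872, 198, 9906, 151645],
        "stop_token_ids": [151645],
        "max_tokens": 128,
        "nvext": { "extra_fields": ["completion_token_ids"] }
    })
}
```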
+ let mut endpoint_routes = HashMap::new(); endpoint_routes.insert(EndpointType::Chat, (chat_docs, chat_route)); endpoint_routes.insert(EndpointType::Completion, (cmpl_docs, cmpl_route)); diff --git a/lib/llm/src/local_model.rs b/lib/llm/src/local_model.rs index 000f87a300b3..3c88dd64e254 100644 --- a/lib/llm/src/local_model.rs +++ b/lib/llm/src/local_model.rs @@ -57,6 +57,7 @@ pub struct LocalModelBuilder { kv_cache_block_size: u32, http_host: Option, http_port: u16, + rl_port: Option, http_metrics_port: Option, tls_cert_path: Option, tls_key_path: Option, @@ -80,6 +81,7 @@ impl Default for LocalModelBuilder { kv_cache_block_size: DEFAULT_KV_CACHE_BLOCK_SIZE, http_host: Default::default(), http_port: DEFAULT_HTTP_PORT, + rl_port: None, http_metrics_port: None, tls_cert_path: Default::default(), tls_key_path: Default::default(), @@ -152,6 +154,11 @@ impl LocalModelBuilder { self } + pub fn rl_port(&mut self, port: Option) -> &mut Self { + self.rl_port = port; + self + } + pub fn http_metrics_port(&mut self, port: Option) -> &mut Self { self.http_metrics_port = port; self @@ -282,6 +289,7 @@ impl LocalModelBuilder { template, http_host: self.http_host.take(), http_port: self.http_port, + rl_port: self.rl_port, http_metrics_port: self.http_metrics_port, tls_cert_path: self.tls_cert_path.take(), tls_key_path: self.tls_key_path.take(), @@ -339,6 +347,7 @@ impl LocalModelBuilder { template, http_host: self.http_host.take(), http_port: self.http_port, + rl_port: self.rl_port, http_metrics_port: self.http_metrics_port, tls_cert_path: self.tls_cert_path.take(), tls_key_path: self.tls_key_path.take(), @@ -362,6 +371,7 @@ pub struct LocalModel { template: Option, http_host: Option, http_port: u16, + rl_port: Option, http_metrics_port: Option, tls_cert_path: Option, tls_key_path: Option, @@ -418,6 +428,10 @@ impl LocalModel { self.http_port } + pub fn rl_port(&self) -> Option { + self.rl_port + } + pub fn http_metrics_port(&self) -> Option { self.http_metrics_port } diff --git a/lib/llm/src/preprocessor.rs b/lib/llm/src/preprocessor.rs index 53756bb90c49..3fa355a98053 100644 --- a/lib/llm/src/preprocessor.rs +++ b/lib/llm/src/preprocessor.rs @@ -379,6 +379,16 @@ impl OpenAIPreprocessor { &self, request: &R, ) -> Result> { + // Renderer / TITO callers post `prompt_token_ids` (or legacy + // `nvext.token_data`); chat templating is bypassed entirely. + // `gather_tokens` reads the same channel via `get_pretokenized_input` + // and feeds the engine directly, so we must not attempt to render + // a chat template here (would fail with "undefined value" when + // `messages` is empty). + if request.get_pretokenized_input().is_some() { + return Ok(None); + } + if let PromptInput::Text(_) = request.prompt_input_type() && let Some(TextInput::Single(_)) = request.extract_text() { @@ -593,18 +603,31 @@ impl OpenAIPreprocessor { .and_then(|ext| ext.backend_instance_id) .is_some(); - let token_data = - request.nvext().and_then(|ext| ext.token_data.as_ref()); - + // get_pretokenized_input() consults both + // `nvext.token_data` (legacy GAIE/EPP/TITO path) AND + // top-level `prompt_token_ids` extension (renderer / TITO + // canonical path now that `/v1/chat/completions/tokens` is + // dropped). Either channel produces the same engine input. + let token_data = request.get_pretokenized_input(); + + // Use token_data when provided (TITO / EPP / RL), + // regardless of backend_instance_id. 
+ // + // skip_token_annotation = has_backend_instance_id: GAIE EPP-style + // callers (which set backend_instance_id) pre-tokenize and don't + // want the annotation echoed back; RL / TITO callers (no + // backend_instance_id) DO want the token_ids annotation in the + // response so the trainer can validate. let (tokens_vec, skip_token_annotation) = if let Some(tokens) = token_data { tracing::info!( token_count = tokens.len(), first_tokens = ?&tokens[..std::cmp::min(5, tokens.len())], - "[SIDECAR-SKIP-TOKENIZE] Found nvext.token_data — using pre-computed tokens, SKIPPING tokenization" + backend_instance_id = has_backend_instance_id, + "[SIDECAR-SKIP-TOKENIZE] Found pre-tokenized input (nvext.token_data or prompt_token_ids extension) — using pre-computed tokens, SKIPPING tokenization" ); - (tokens.clone(), true) + (tokens, has_backend_instance_id) } else if has_backend_instance_id { tracing::warn!( "backend_instance_id provided but no token_data; tokenizing prompt" diff --git a/lib/llm/src/protocols/anthropic/types.rs b/lib/llm/src/protocols/anthropic/types.rs index 5214ee10b0a7..db2ef8406375 100644 --- a/lib/llm/src/protocols/anthropic/types.rs +++ b/lib/llm/src/protocols/anthropic/types.rs @@ -141,6 +141,8 @@ impl TryFrom for NvCreateChatCompletionRequest { }, media_io_kwargs: None, unsupported_fields: Default::default(), + tokens: None, + return_token_ids: None, }) } } @@ -821,6 +823,8 @@ mod tests { }), }, nvext: None, + + prompt_token_ids: None, }; let response = chat_completion_to_anthropic_response(chat_resp, "test-model", None); diff --git a/lib/llm/src/protocols/openai.rs b/lib/llm/src/protocols/openai.rs index 42ef621f8797..e22d4ae12f80 100644 --- a/lib/llm/src/protocols/openai.rs +++ b/lib/llm/src/protocols/openai.rs @@ -80,6 +80,19 @@ pub(crate) trait OpenAIStopConditionsProvider { fn get_max_thinking_tokens(&self) -> Option { self.nvext().and_then(|nv| nv.max_thinking_tokens) } + + /// Get token-id-based stop conditions (renderer / TITO parity with vLLM + /// `/inference/v1/generate`'s `sampling_params.stop_token_ids`). + /// + /// Default returns `Ok(None)`; chat-completions / completions impls + /// override to read the field from `unsupported_fields["stop_token_ids"]` + /// after the `validate.rs` PASSTHROUGH allowlist accepts it. Plumbed + /// through into `common::StopConditions::stop_token_ids_hidden` by + /// `extract_stop_conditions`. Returns `Result` so malformed payloads + /// surface as a typed 400 instead of silently dropping the field. + fn get_stop_token_ids(&self) -> anyhow::Result>> { + Ok(None) + } } pub(crate) trait OpenAIOutputOptionsProvider { @@ -90,6 +103,10 @@ pub(crate) trait OpenAIOutputOptionsProvider { fn get_skip_special_tokens(&self) -> Option; fn get_formatted_prompt(&self) -> Option; + + fn get_return_tokens_as_token_ids(&self) -> Option { + None + } } impl SamplingOptionsProvider for T { @@ -176,6 +193,13 @@ impl StopConditionsProvider for T { let min_tokens = self.get_min_tokens(); let stop = self.get_stop(); let max_thinking_tokens = self.get_max_thinking_tokens(); + // Token-id stop conditions ride through PASSTHROUGH_EXTRA_FIELDS on the + // chat-completions surface; impls of this trait read it from the + // request's `unsupported_fields` map. Engine receives it as + // `stop_token_ids_hidden` (already wired in `common::StopConditions`). + // The `?` propagates a typed 400 on malformed payloads (e.g. + // `stop_token_ids: "not-an-array"`). 
+ let stop_token_ids_hidden = self.get_stop_token_ids()?; if let Some(stop) = &stop && stop.len() > 4 @@ -190,7 +214,7 @@ impl StopConditionsProvider for T { max_tokens, min_tokens, stop, - stop_token_ids_hidden: None, + stop_token_ids_hidden, ignore_eos, max_thinking_tokens, }) @@ -203,7 +227,6 @@ impl OutputOptionsProvider for T { let prompt_logprobs = self.get_prompt_logprobs(); let skip_special_tokens = self.get_skip_special_tokens(); let formatted_prompt = self.get_formatted_prompt(); - Ok(common::OutputOptions { logprobs, prompt_logprobs, diff --git a/lib/llm/src/protocols/openai/chat_completions.rs b/lib/llm/src/protocols/openai/chat_completions.rs index 8a77038d5834..6d1f071d5a18 100644 --- a/lib/llm/src/protocols/openai/chat_completions.rs +++ b/lib/llm/src/protocols/openai/chat_completions.rs @@ -59,6 +59,19 @@ pub struct NvCreateChatCompletionRequest { #[serde(default, skip_serializing_if = "Option::is_none")] pub media_io_kwargs: Option, + /// Legacy RL field (pre-tokenized prompt). Accepted but ignored on + /// `/v1/chat/completions` — the canonical TITO channel is now the + /// top-level `prompt_token_ids` extension on the same endpoint + /// (allowlisted in `validate.rs::PASSTHROUGH_EXTRA_FIELDS`). Kept here + /// so older clients still sending `tokens` don't 400. + #[serde(default, skip_serializing)] + pub tokens: Option>, + + /// Legacy RL field. Accepted but ignored on standard chat completions — + /// use `nvext.extra_fields = ["completion_token_ids"]` instead. + #[serde(default, skip_serializing)] + pub return_token_ids: Option, + /// Catch-all for unsupported fields - checked during validation #[serde(flatten, default, skip_serializing)] pub unsupported_fields: std::collections::HashMap, @@ -72,6 +85,10 @@ pub struct NvCreateChatCompletionResponse { pub inner: dynamo_protocols::types::CreateChatCompletionResponse, #[serde(skip_serializing_if = "Option::is_none")] pub nvext: Option, + /// RL: Prompt token IDs for Prime-RL/verifiers alignment. + /// Populated when `DYN_ENABLE_RL=true` or `return_token_ids=true`. + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_token_ids: Option>, } /// A response structure for streamed chat completions, embedding OpenAI's @@ -96,6 +113,25 @@ impl NvExtProvider for NvCreateChatCompletionRequest { fn raw_prompt(&self) -> Option { None } + + /// Pre-tokenized input — checks `nvext.token_data` first (legacy path), + /// falls back to top-level `prompt_token_ids` extension that + /// `validate.rs` PASSTHROUGH_EXTRA_FIELDS allowlists. The two channels + /// are equivalent at the engine level; the top-level extension is the + /// canonical home now that `/v1/chat/completions/tokens` is dropped. + fn get_pretokenized_input(&self) -> Option> { + // 1. Prefer nvext.token_data when present (existing GAIE/EPP path). + if let Some(token_data) = self.nvext.as_ref().and_then(|ext| ext.token_data.as_ref()) { + return Some(token_data.clone()); + } + // 2. Fall back to top-level `prompt_token_ids` extension. Renderer + // and TITO callers post here directly — the field rides through + // PASSTHROUGH_EXTRA_FIELDS without 400, then we promote it to + // the engine path here. 
+ self.unsupported_fields + .get("prompt_token_ids") + .and_then(|v| serde_json::from_value::>(v.clone()).ok()) + } } /// Implements `AnnotationsProvider` for `NvCreateChatCompletionRequest`, @@ -305,6 +341,26 @@ impl OpenAIStopConditionsProvider for NvCreateChatCompletionRequest { fn get_ignore_eos(&self) -> Option { self.common.ignore_eos } + + /// Read `stop_token_ids` from `unsupported_fields` (allowlisted via + /// `validate.rs` PASSTHROUGH_EXTRA_FIELDS). RL renderer / TITO callers + /// rely on this for stop-on-token-id conditions that don't tokenize + /// cleanly as strings (custom EOS, model-specific control tokens). + /// Malformed values surface as a typed `anyhow::Error` so the caller + /// gets a 400 with a useful diagnostic rather than a silent drop. + fn get_stop_token_ids(&self) -> anyhow::Result>> { + let Some(value) = self.unsupported_fields.get("stop_token_ids") else { + return Ok(None); + }; + if value.is_null() { + return Ok(None); + } + serde_json::from_value(value.clone()) + .map(Some) + .map_err(|err| { + anyhow::anyhow!("stop_token_ids must be an array of unsigned token IDs: {err}") + }) + } } impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest { @@ -337,7 +393,31 @@ impl OpenAIOutputOptionsProvider for NvCreateChatCompletionRequest { impl ValidateRequest for NvCreateChatCompletionRequest { fn validate(&self) -> Result<(), anyhow::Error> { validate::validate_no_unsupported_fields(&self.unsupported_fields)?; - validate::validate_messages(&self.inner.messages)?; + // Mutual-exclusivity applies ONLY to the canonical top-level + // `prompt_token_ids` extension (the new vLLM-0.20-aligned channel). + // The legacy `nvext.token_data` channel is intentionally allowed to + // coexist with non-empty messages — that's how the renderer transport + // ships pre-tokenized inputs alongside placeholder messages + // (PrimeIntellect-ai/verifiers PR #1287's `dynamo_chat_nvext` mode). + // Empty messages are accepted when EITHER channel carries tokens. 
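The resolution order just described can be condensed into a free function. Types and names here are simplified stand-ins for the request structs, not code from this PR.

```rust
use std::collections::HashMap;

// Minimal sketch of the precedence rules above, under the assumption that
// both channels deserialize to Vec<u32>. Returns the effective tokens plus
// the skip_token_annotation flag derived from backend_instance_id.
fn resolve_tokens(
    nvext_token_data: Option<&Vec<u32>>,
    extra_fields: &HashMap<String, serde_json::Value>,
    has_backend_instance_id: bool,
) -> Option<(Vec<u32>, bool)> {
    // 1. Legacy GAIE/EPP channel (`nvext.token_data`) wins when present.
    let tokens = nvext_token_data.cloned().or_else(|| {
        // 2. Top-level `prompt_token_ids` extension (canonical TITO channel).
        extra_fields
            .get("prompt_token_ids")
            .and_then(|v| serde_json::from_value(v.clone()).ok())
    })?;
    // EPP-style callers (backend_instance_id set) skip the token annotation;
    // RL / TITO callers get it echoed back for trainer-side validation.
    Some((tokens, has_backend_instance_id))
}
```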
+ let has_top_level_prompt_token_ids = + self.unsupported_fields.contains_key("prompt_token_ids"); + let has_nvext_token_data = self + .nvext + .as_ref() + .and_then(|ext| ext.token_data.as_ref()) + .is_some(); + + if has_top_level_prompt_token_ids && !self.inner.messages.is_empty() { + anyhow::bail!( + "messages and prompt_token_ids are mutually exclusive; \ + send one (use prompt_token_ids for renderer / TITO mode, \ + messages for MITO mode)" + ); + } + if !has_top_level_prompt_token_ids && !has_nvext_token_data { + validate::validate_messages(&self.inner.messages)?; + } validate::validate_model(&self.inner.model)?; // none for store validate::validate_reasoning_effort(&self.inner.reasoning_effort)?; diff --git a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs index 5def709448d8..2b5fc16e1414 100644 --- a/lib/llm/src/protocols/openai/chat_completions/aggregator.rs +++ b/lib/llm/src/protocols/openai/chat_completions/aggregator.rs @@ -419,6 +419,7 @@ impl DeltaAggregator { service_tier: aggregator.service_tier, }, nvext: aggregator.nvext, + prompt_token_ids: None, }; Ok(response) diff --git a/lib/llm/src/protocols/openai/chat_completions/delta.rs b/lib/llm/src/protocols/openai/chat_completions/delta.rs index 4a3dfae2d405..5ff1ed5522c7 100644 --- a/lib/llm/src/protocols/openai/chat_completions/delta.rs +++ b/lib/llm/src/protocols/openai/chat_completions/delta.rs @@ -10,7 +10,7 @@ use crate::{ common::{self, timing::RequestTracker}, openai::{ convert_backend_top_logprobs, - nvext::{NvExtProvider, NvExtResponseFieldSelection}, + nvext::{NvExtProvider, NvExtResponse, NvExtResponseFieldSelection}, token_to_utf8_bytes, }, }, @@ -51,6 +51,7 @@ impl NvCreateChatCompletionRequest { /// # Returns /// * [`DeltaGenerator`] configured with model name and response options. pub fn response_generator(&self, request_id: String) -> DeltaGenerator { + // `completion_token_ids` is parsed by from_nvext into response_fields. let response_fields = NvExtResponseFieldSelection::from_nvext(self.nvext()); let options = DeltaGeneratorOptions { @@ -86,6 +87,7 @@ pub struct DeltaGeneratorOptions { /// Determines whether log probabilities should be included in the response. pub enable_logprobs: bool, /// Determines which nvext response fields may be emitted for this request. + /// (Includes `completion_token_ids` for the RL inference path.) pub response_fields: NvExtResponseFieldSelection, pub runtime_config: ModelRuntimeConfig, @@ -112,6 +114,10 @@ pub struct DeltaGenerator { options: DeltaGeneratorOptions, /// Optional request tracker for per-request metrics (shared with PreprocessedRequest). tracker: Option>, + /// Accumulated output token IDs across chunks. Only used when + /// `options.response_fields.completion_token_ids` is true. Emitted in `nvext.completion_token_ids` + /// on the final (finish_reason-bearing) chunk. 
+ accumulated_completion_token_ids: Vec, } impl DeltaGenerator { @@ -160,6 +166,7 @@ impl DeltaGenerator { msg_counter: 0, options, tracker, + accumulated_completion_token_ids: Vec::new(), } } @@ -353,6 +360,12 @@ impl crate::protocols::openai::DeltaGeneratorExt { + stream_response.nvext = Some(nvext_json); + if let Some(ref info) = nvext_response.worker_id { + tracing::debug!( + "Injected worker_id into chat completion nvext: prefill={:?}, decode={:?}", + info.prefill_worker_id, + info.decode_worker_id + ); + } + if let Some(ref tokens) = nvext_response.token_ids { + tracing::debug!( + "Injected token_ids into chat completion nvext: {} tokens", + tokens.len() + ); + } + if let Some(ref tokens) = nvext_response.completion_token_ids { + tracing::debug!( + "Injected completion_token_ids into chat completion nvext: {} tokens", + tokens.len() + ); + } + } + Err(err) => { + tracing::warn!( + error = %err, + "chat completion nvext: serde_json::to_value failed, dropping nvext payload \ + (RL trainer will not receive token_ids / weight_version this chunk)", + ); + } } } @@ -496,6 +560,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } @@ -589,6 +656,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } diff --git a/lib/llm/src/protocols/openai/completions/delta.rs b/lib/llm/src/protocols/openai/completions/delta.rs index 3f039399e1cd..a47d9215a6ef 100644 --- a/lib/llm/src/protocols/openai/completions/delta.rs +++ b/lib/llm/src/protocols/openai/completions/delta.rs @@ -313,21 +313,33 @@ impl crate::protocols::openai::DeltaGeneratorExt for delta.disaggregated_params.as_ref(), finish_reason.is_some(), delta.engine_data, - ) && let Ok(nvext_json) = serde_json::to_value(&nvext_response) - { - response.nvext = Some(nvext_json); - if let Some(ref info) = nvext_response.worker_id { - tracing::debug!( - "Injected worker_id into completions nvext: prefill={:?}, decode={:?}", - info.prefill_worker_id, - info.decode_worker_id - ); - } - if let Some(ref tokens) = nvext_response.token_ids { - tracing::debug!( - "Injected token_ids into completions nvext: {} tokens", - tokens.len() - ); + ) { + // Log a warning if serialization fails instead of silently + // dropping the nvext payload (would mean promoted fields never + // reach the client). + match serde_json::to_value(&nvext_response) { + Ok(nvext_json) => { + response.nvext = Some(nvext_json); + if let Some(ref info) = nvext_response.worker_id { + tracing::debug!( + "Injected worker_id into completions nvext: prefill={:?}, decode={:?}", + info.prefill_worker_id, + info.decode_worker_id + ); + } + if let Some(ref tokens) = nvext_response.token_ids { + tracing::debug!( + "Injected token_ids into completions nvext: {} tokens", + tokens.len() + ); + } + } + Err(err) => { + tracing::warn!( + error = %err, + "completions nvext: serde_json::to_value failed, dropping nvext payload", + ); + } } } diff --git a/lib/llm/src/protocols/openai/nvext.rs b/lib/llm/src/protocols/openai/nvext.rs index 90919836fc29..837cd29b72e7 100644 --- a/lib/llm/src/protocols/openai/nvext.rs +++ b/lib/llm/src/protocols/openai/nvext.rs @@ -73,6 +73,26 @@ pub fn apply_header_routing_overrides(nvext: Option, headers: &HeaderMap) pub trait NvExtProvider { fn nvext(&self) -> Option<&NvExt>; fn raw_prompt(&self) -> Option; + + /// Pre-tokenized input that bypasses chat templating. 
+ /// + /// Two callers populate this today: + /// - GAIE EPP / TITO via `nvext.token_data` (existing path). + /// - Renderer / TITO via top-level `prompt_token_ids` extension on + /// `/v1/chat/completions` (allowlisted by `validate.rs` + /// PASSTHROUGH_EXTRA_FIELDS). This is the canonical home now that + /// `/v1/chat/completions/tokens` is dropped. + /// + /// Default reads only `nvext.token_data`; the chat-completions impl + /// also falls back to `unsupported_fields["prompt_token_ids"]` so the + /// preprocessor sees one effective value regardless of which channel + /// the client used. Returns owned Vec because the top-level field + /// arrives as a JSON value that has to be deserialized fresh. + fn get_pretokenized_input(&self) -> Option> { + self.nvext() + .and_then(|ext| ext.token_data.as_ref()) + .cloned() + } } /// Worker ID information for disaggregated serving @@ -120,6 +140,13 @@ pub struct NvExtResponse { /// Dynamo does not inspect this; it is forwarded as-is to the client. #[serde(skip_serializing_if = "Option::is_none")] pub engine_data: Option, + + /// Output token IDs generated by the engine (RL inference path). + /// Populated when client requests `extra_fields: ["completion_token_ids"]` + /// or auto-enabled under `DYN_ENABLE_RL=true` for the chat-completions path. + /// For RL: `len(completion_token_ids) == len(logprobs.content)` is a hard invariant. + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_token_ids: Option>, } /// Response nvext fields requested for a given request. @@ -137,6 +164,7 @@ pub struct NvExtResponseFieldSelection { pub token_ids: bool, pub routed_experts: bool, pub engine_data: bool, + pub completion_token_ids: bool, } impl NvExtResponseFieldSelection { @@ -153,6 +181,8 @@ impl NvExtResponseFieldSelection { "timing" => selection.timing = true, "routed_experts" => selection.routed_experts = true, "engine_data" => selection.engine_data = true, + "completion_token_ids" => selection.completion_token_ids = true, + "token_ids" => selection.token_ids = true, _ => {} } } @@ -181,6 +211,8 @@ impl NvExtResponseFieldSelection { /// - `worker_id` requires the selection flag **and** `tracker.get_worker_info()` to return `Some`. /// - `token_ids` requires the selection flag **and** a `"token_ids"` key on `disaggregated_params` /// that deserializes into `Vec`; malformed values silently fall back to `None`. + /// - `completion_token_ids` requires the selection flag **and** a `"completion_token_ids"` key on + /// `disaggregated_params` that deserializes into `Vec`; malformed values silently fall back to `None`. /// - `routed_experts` requires the selection flag **and** a `"routed_experts"` key on /// `disaggregated_params` (cloned as-is, no validation). /// - `timing` requires the selection flag, `finish_reason_present == true`, **and** a tracker. 
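On the consuming side, a trainer-side helper might read the field like this; the helper is illustrative and assumes the final chunk has already been parsed into a `serde_json::Value`.

```rust
// Illustrative consumer: pull `nvext.completion_token_ids` off the final
// (finish_reason-bearing) chunk. Per the contract above, the resulting
// length must equal the number of logprob entries for the choice.
fn completion_ids(chunk: &serde_json::Value) -> Option<Vec<u32>> {
    chunk
        .get("nvext")?
        .get("completion_token_ids")?
        .as_array()?
        .iter()
        .map(|v| u32::try_from(v.as_u64()?).ok())
        .collect()
}
```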
@@ -206,6 +238,14 @@ impl NvExtResponseFieldSelection { None }; + let completion_token_ids = if self.completion_token_ids { + disaggregated_params + .and_then(|params| params.get("completion_token_ids")) + .and_then(|v| serde_json::from_value::>(v.clone()).ok()) + } else { + None + }; + let routed_experts = if self.routed_experts { disaggregated_params .and_then(|params| params.get("routed_experts")) @@ -228,6 +268,7 @@ impl NvExtResponseFieldSelection { if worker_id.is_none() && token_ids.is_none() + && completion_token_ids.is_none() && routed_experts.is_none() && timing.is_none() && engine_data.is_none() @@ -241,6 +282,7 @@ impl NvExtResponseFieldSelection { token_ids, routed_experts, engine_data, + completion_token_ids, }) } } @@ -868,6 +910,8 @@ mod tests { token_ids: true, routed_experts: true, engine_data: false, + + completion_token_ids: false, }; let tracker = tracker_with_prefill_worker(); let params = disagg_params_full(); @@ -904,6 +948,8 @@ mod tests { token_ids: false, // only enabled via query_instance_id routed_experts: true, engine_data: false, + + completion_token_ids: false, } ); } diff --git a/lib/llm/src/protocols/openai/responses/mod.rs b/lib/llm/src/protocols/openai/responses/mod.rs index 2dcc530573c6..d6c008a537c8 100644 --- a/lib/llm/src/protocols/openai/responses/mod.rs +++ b/lib/llm/src/protocols/openai/responses/mod.rs @@ -732,6 +732,8 @@ impl TryFrom for NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + tokens: None, + return_token_ids: None, }) } } @@ -2113,6 +2115,8 @@ mod tests { usage: None, }, nvext: None, + + prompt_token_ids: None, }; let wrapped = @@ -2174,6 +2178,8 @@ mod tests { usage: None, }, nvext: None, + + prompt_token_ids: None, }; let wrapped = @@ -2379,6 +2385,8 @@ thinking usage: None, }, nvext: None, + + prompt_token_ids: None, }; let resp = chat_completion_to_response(chat_resp, ¶ms, None).unwrap(); @@ -2412,6 +2420,8 @@ thinking usage: None, }, nvext: None, + + prompt_token_ids: None, }; let resp = chat_completion_to_response(chat_resp, ¶ms, None).unwrap(); @@ -2440,6 +2450,8 @@ thinking usage: None, }, nvext: None, + + prompt_token_ids: None, }; let resp = chat_completion_to_response(chat_resp, ¶ms, None).unwrap(); @@ -2465,6 +2477,8 @@ thinking usage: None, }, nvext: None, + + prompt_token_ids: None, }; let resp = chat_completion_to_response(chat_resp, ¶ms, None).unwrap(); @@ -2583,6 +2597,8 @@ thinking usage: None, }, nvext: None, + + prompt_token_ids: None, } } diff --git a/lib/llm/src/protocols/openai/validate.rs b/lib/llm/src/protocols/openai/validate.rs index 237e84bc75be..1e694a0c6af2 100644 --- a/lib/llm/src/protocols/openai/validate.rs +++ b/lib/llm/src/protocols/openai/validate.rs @@ -97,16 +97,69 @@ pub const MAX_REPETITION_PENALTY: f32 = 2.0; // Shared Fields // +/// Fields that RL clients may send as extra_body hints which Dynamo does not +/// implement but should not reject with a 400. They are silently accepted +/// (the chat-completions handler reads what it understands and ignores the +/// rest) so the RL client stack is forward-compatible with new extension +/// fields without churning Dynamo. +/// +/// This is the canonical home for typed RL extension fields; the prior +/// `nvext.extra_fields = ["completion_token_ids", ...]` opt-in mechanism +/// still works alongside it but the named fields here are the recommended +/// path. +const PASSTHROUGH_EXTRA_FIELDS: &[&str] = &[ + // KV prefix-cache isolation hint. 
The equivalent `X-Tenant-Id` request + // header is also accepted; the header takes precedence when both are + // present. + "cache_salt", + // Pre-tokenized prompt for the RL TITO path. Mutually exclusive with + // `messages`; when present, vLLM 0.20+ skips chat templating. The + // "tokens variant of /v1/chat/completions" collapses into the same URI + // with this extension field instead of a forked + // /v1/chat/completions/tokens. Today RL clients can also pre-tokenize + // and pass via `nvext.token_data` (handled in preprocessor.rs); the + // typed top-level field shipped here is the long-term canonical entry + // for clients written against the vLLM 0.20 schema. + "prompt_token_ids", + // RL routing filter — only dispatch to workers reporting this applied + // weight version. Used by IS-correction strict-version mode and by + // eval-on-subset workflows. Today accepted-and-ignored at the request + // level; the routing-side filter lands in a follow-up. + "weight_version", + // Per-request gate for MoE Routing Replay capture. Honored by + // `nvext.extra_fields = ["routed_experts"]` already (see + // `NvExtResponseFieldSelection`); accepted here as a typed alias. + "return_routed_experts", + // Per-request opt-in for `nvext.completion_token_ids` on the response. + // Today the `extra_fields = ["completion_token_ids"]` mechanism is the + // canonical; this typed alias is the long-term form. + "return_token_ids", + // Opt-in for `nvext.prompt_logprobs` on the response. Aliased through + // to vLLM's `sampling_params.prompt_logprobs` in a follow-up. + "return_prompt_logprobs", + // Token-level sampling controls. Without these, callers in renderer / + // TITO mode can't express stop-on-token-id, constrained sampling, or + // bad-word filtering — which is the whole reason vLLM 0.20's + // `/inference/v1/generate` exists. Promoting them to the chat-completions + // surface as PASSTHROUGH extras keeps `/v1/chat/completions` as the + // single canonical RL data path with full SamplingParams parity. 
+ "stop_token_ids", + "bad_words_token_ids", + "allowed_token_ids", + "truncate_prompt_tokens", +]; + /// Validates that no unsupported fields are present in the request pub fn validate_no_unsupported_fields( unsupported_fields: &std::collections::HashMap, ) -> Result<(), anyhow::Error> { - if !unsupported_fields.is_empty() { - let fields: Vec<_> = unsupported_fields - .keys() - .map(|s| format!("`{}`", s)) - .collect(); - anyhow::bail!("Unsupported parameter(s): {}", fields.join(", ")); + let unknown: Vec<_> = unsupported_fields + .keys() + .filter(|k| !PASSTHROUGH_EXTRA_FIELDS.contains(&k.as_str())) + .map(|s| format!("`{}`", s)) + .collect(); + if !unknown.is_empty() { + anyhow::bail!("Unsupported parameter(s): {}", unknown.join(", ")); } Ok(()) } diff --git a/lib/llm/src/protocols/unified.rs b/lib/llm/src/protocols/unified.rs index 6ce62744e7f3..d5f1fc55b3de 100644 --- a/lib/llm/src/protocols/unified.rs +++ b/lib/llm/src/protocols/unified.rs @@ -535,6 +535,9 @@ mod tests { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, }; let unified = UnifiedRequest::from(req.clone()); diff --git a/lib/llm/tests/parallel_tool_call_integration.rs b/lib/llm/tests/parallel_tool_call_integration.rs index 2827239d4754..c81915c5f724 100644 --- a/lib/llm/tests/parallel_tool_call_integration.rs +++ b/lib/llm/tests/parallel_tool_call_integration.rs @@ -93,6 +93,9 @@ fn create_mock_chat_completion_request() -> NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } diff --git a/lib/llm/tests/preprocessor.rs b/lib/llm/tests/preprocessor.rs index 2896e01c427f..29dbfb8b0297 100644 --- a/lib/llm/tests/preprocessor.rs +++ b/lib/llm/tests/preprocessor.rs @@ -261,6 +261,9 @@ impl Request { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } } @@ -701,6 +704,9 @@ mod context_length_validation { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } diff --git a/lib/llm/tests/test_common_ext.rs b/lib/llm/tests/test_common_ext.rs index 8e49c7377b09..f8bfd3ed6232 100644 --- a/lib/llm/tests/test_common_ext.rs +++ b/lib/llm/tests/test_common_ext.rs @@ -1,13 +1,16 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -use dynamo_llm::protocols::{ - common::StopConditionsProvider, - openai::{ - chat_completions::NvCreateChatCompletionRequest, - common_ext::{CommonExt, CommonExtProvider}, - completions::NvCreateCompletionRequest, - nvext::NvExt, +use dynamo_llm::{ + engines::ValidateRequest, + protocols::{ + common::StopConditionsProvider, + openai::{ + chat_completions::NvCreateChatCompletionRequest, + common_ext::{CommonExt, CommonExtProvider}, + completions::NvCreateCompletionRequest, + nvext::NvExt, + }, }, }; @@ -70,6 +73,8 @@ fn test_sampling_parameters_include_stop_str_in_output_extraction() { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + tokens: None, + return_token_ids: None, }; let sampling = request.extract_sampling_options().unwrap(); @@ -213,6 +218,54 @@ fn test_max_thinking_tokens_extraction() { assert_eq!(stop_conditions_none.max_thinking_tokens, None); } +#[test] +fn test_chat_completions_stop_token_ids_extraction() { + // Renderer / TITO callers send `stop_token_ids` as a top-level field + // alongside `nvext.token_data`. Both ride PASSTHROUGH_EXTRA_FIELDS; + // `extract_stop_conditions` plumbs the IDs into + // `common::StopConditions::stop_token_ids_hidden` so the engine layer + // honors them. (Lifted from PR #9141.) + let json_str = r#"{ + "model": "test-model", + "messages": [{"role": "user", "content": "(token-in mode)"}], + "nvext": { + "token_data": [1, 2, 3] + }, + "stop_token_ids": [151645, 151643] + }"#; + + let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); + + request.validate().unwrap(); + let stop_conditions = request.extract_stop_conditions().unwrap(); + assert_eq!( + stop_conditions.stop_token_ids_hidden, + Some(vec![151645, 151643]) + ); +} + +#[test] +fn test_chat_completions_stop_token_ids_malformed_returns_400() { + // Malformed stop_token_ids must NOT silently fall back to None — it + // surfaces as a typed anyhow::Error so the HTTP layer returns 400 with + // a useful diagnostic. (PR #9141 contract.) 
+ let json_str = r#"{ + "model": "test-model", + "messages": [{"role": "user", "content": "x"}], + "stop_token_ids": "not-an-array" + }"#; + + let request: NvCreateChatCompletionRequest = serde_json::from_str(json_str).unwrap(); + let err = request + .extract_stop_conditions() + .expect_err("malformed stop_token_ids should error"); + assert!( + err.to_string() + .contains("stop_token_ids must be an array of unsigned token IDs"), + "got: {err}" + ); +} + #[test] fn test_chat_completions_no_common_values() { // Test that when no common values are set, we get None @@ -300,6 +353,8 @@ fn test_serialization_preserves_structure() { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + tokens: None, + return_token_ids: None, }; let json = serde_json::to_value(&request).unwrap(); @@ -352,6 +407,8 @@ fn test_sampling_parameters_extraction() { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + tokens: None, + return_token_ids: None, }; let sampling_options = request.extract_sampling_options().unwrap(); diff --git a/lib/llm/tests/test_streaming_usage.rs b/lib/llm/tests/test_streaming_usage.rs index 0a6fd3178bf6..5357fbb2ee31 100644 --- a/lib/llm/tests/test_streaming_usage.rs +++ b/lib/llm/tests/test_streaming_usage.rs @@ -195,6 +195,9 @@ fn create_chat_request( chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } @@ -529,6 +532,9 @@ fn create_nonstreaming_chat_request() -> NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } diff --git a/lib/llm/tests/tool_choice.rs b/lib/llm/tests/tool_choice.rs index bbff1bd38508..0be68c5015d0 100644 --- a/lib/llm/tests/tool_choice.rs +++ b/lib/llm/tests/tool_choice.rs @@ -41,6 +41,9 @@ fn create_test_request() -> NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } diff --git a/lib/llm/tests/tool_choice_finish_reasons.rs b/lib/llm/tests/tool_choice_finish_reasons.rs index d3d190c3953c..9141556cae21 100644 --- a/lib/llm/tests/tool_choice_finish_reasons.rs +++ b/lib/llm/tests/tool_choice_finish_reasons.rs @@ -34,6 +34,9 @@ fn create_test_request() -> NvCreateChatCompletionRequest { chat_template_args: None, media_io_kwargs: None, unsupported_fields: Default::default(), + + return_token_ids: None, + tokens: None, } } diff --git a/lib/rl/Cargo.toml b/lib/rl/Cargo.toml new file mode 100644 index 000000000000..99ed788636e7 --- /dev/null +++ b/lib/rl/Cargo.toml @@ -0,0 +1,26 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "dynamo-rl" +description = "RL admin control plane — handlers, state, fan-out, and HTTP facade for /v1/rl/*" +version.workspace = true +edition.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true + +# Dependency direction: dynamo-llm -> dynamo-rl -> dynamo-runtime. +# This crate must NOT depend on dynamo-llm. 
+
+[dependencies]
+dynamo-runtime = { workspace = true }
+
+axum = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+tokio = { workspace = true }
+tracing = { workspace = true }
+anyhow = { workspace = true }
+futures = { workspace = true }
diff --git a/lib/rl/src/lib.rs b/lib/rl/src/lib.rs
new file mode 100644
index 000000000000..dfd453afc330
--- /dev/null
+++ b/lib/rl/src/lib.rs
@@ -0,0 +1,1072 @@
+// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+//! Dynamo RL admin control plane — handlers, state, fan-out for `/v1/rl/*`.
+//!
+//! See `plans/rl-crate.md` and `plans/weight-transfer-config.md`.
+//!
+//! **PR B status:** request-plane fan-out via the dynamo discovery plane.
+//! Workers register one endpoint `dyn://<namespace>.<component>.rl` (see
+//! `worker_factory.py::rl_endpoint.serve_endpoint(handler.rl_dispatch, …)`)
+//! and the frontend dispatches by snapshotting live `rl` instances and calling
+//! each via strict request-plane direct routing. The legacy `register_engine_route`
+//! HTTP-on-system-port mechanism + `DYN_RL_WORKER_SYSTEM_URLS` static URL
+//! list are gone.
+
+use std::{
+    collections::{HashMap, hash_map::DefaultHasher},
+    hash::{Hash, Hasher},
+    sync::Arc,
+    time::Duration,
+};
+
+use axum::{
+    Json, Router,
+    extract::State,
+    http::{Method, StatusCode},
+    response::IntoResponse,
+    routing::post,
+};
+use dynamo_runtime::{
+    DistributedRuntime,
+    component::Client,
+    discovery::{DiscoveryInstance, DiscoveryQuery},
+    pipeline::{
+        SingleIn,
+        network::egress::push_router::{PushRouter, RouterMode},
+    },
+    protocols::annotated::Annotated,
+};
+use futures::{FutureExt, StreamExt};
+
+pub const DEFAULT_RL_ENDPOINT: &str = "rl";
+
+#[derive(Debug, Clone)]
+pub enum RlError {
+    NoWorkers {
+        namespace: String,
+        rl_endpoint: String,
+    },
+    MembershipChanged {
+        before_epoch: u64,
+        after_epoch: u64,
+    },
+}
+
+impl std::fmt::Display for RlError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            RlError::NoWorkers {
+                namespace,
+                rl_endpoint,
+            } => write!(
+                f,
+                "no live RL workers found in namespace '{namespace}' for endpoint '{rl_endpoint}'"
+            ),
+            RlError::MembershipChanged {
+                before_epoch,
+                after_epoch,
+            } => write!(
+                f,
+                "RL worker membership changed during fan-out (before={before_epoch}, after={after_epoch})"
+            ),
+        }
+    }
+}
+
+impl std::error::Error for RlError {}
+
+#[derive(Debug, Clone)]
+pub struct RlClientConfig {
+    pub runtime: Arc<DistributedRuntime>,
+    pub namespace: String,
+    pub rl_endpoint: String,
+    pub policy: FanoutPolicy,
+}
+
+#[derive(Debug, Clone)]
+pub struct FanoutPolicy {
+    pub min_workers: usize,
+    pub membership_timeout: Duration,
+    pub request_timeout: Duration,
+    pub strict_direct: bool,
+    pub abort_on_membership_change: bool,
+    pub component_filter: Option<Vec<String>>,
+}
+
+impl FanoutPolicy {
+    pub fn default_admin() -> Self {
+        Self {
+            min_workers: 1,
+            membership_timeout: Duration::from_secs(5),
+            request_timeout: Duration::from_secs(30),
+            strict_direct: true,
+            abort_on_membership_change: true,
+            component_filter: None,
+        }
+    }
+
+    pub fn with_component_filter(mut self, components: Vec<String>) -> Self {
+        let components: Vec<String> = components
+            .into_iter()
+            .map(|c| c.trim().to_string())
+            .filter(|c| !c.is_empty())
+            .collect();
+        self.component_filter = if components.is_empty() {
+            None
+        } else {
+            Some(components)
+        };
+        self
+    }
+}
+
+impl Default for FanoutPolicy {
+    fn default() -> Self {
+        Self::default_admin()
+    }
+}
+
+#[derive(
+    Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq, Hash, PartialOrd, Ord,
+)]
+pub struct WorkerTarget {
+    pub namespace: String,
+    pub component: String,
+    pub endpoint: String,
+    pub instance_id: u64,
+}
+
+impl WorkerTarget {
+    fn endpoint_key(&self) -> (String, String, String) {
+        (
+            self.namespace.clone(),
+            self.component.clone(),
+            self.endpoint.clone(),
+        )
+    }
+}
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct MembershipSnapshot {
+    pub epoch: u64,
+    pub targets: Vec<WorkerTarget>,
+}
+
+impl MembershipSnapshot {
+    fn new(mut targets: Vec<WorkerTarget>) -> Self {
+        targets.sort();
+        targets.dedup();
+
+        let mut hasher = DefaultHasher::new();
+        targets.hash(&mut hasher);
+        let epoch = hasher.finish();
+
+        Self { epoch, targets }
+    }
+
+    pub fn is_empty(&self) -> bool {
+        self.targets.is_empty()
+    }
+}
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct RlRequest {
+    pub op: String,
+    #[serde(default)]
+    pub body: serde_json::Value,
+}
+
+impl RlRequest {
+    pub fn new(op: impl Into<String>, body: serde_json::Value) -> Self {
+        Self {
+            op: op.into(),
+            body,
+        }
+    }
+
+    pub fn describe(_req: DescribeRequest) -> Self {
+        Self::new("describe", serde_json::json!({}))
+    }
+
+    pub fn pause(req: PauseRequest) -> Self {
+        Self::new("pause", serde_json::to_value(req).unwrap_or_default())
+    }
+
+    pub fn resume(_req: ResumeRequest) -> Self {
+        Self::new("resume", serde_json::json!({}))
+    }
+
+    pub fn init_transport(req: InitTransportRequest) -> Self {
+        Self::new("init_transport", req.0)
+    }
+
+    pub fn update_weights(req: UpdateWeightsRequest) -> Self {
+        Self::new("update_weights", req.into_body())
+    }
+}
+
+#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
+pub struct DescribeRequest {}
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct PauseRequest {
+    pub mode: String,
+    pub clear_cache: bool,
+}
+
+impl Default for PauseRequest {
+    fn default() -> Self {
+        Self {
+            mode: "keep".to_string(),
+            clear_cache: false,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
+pub struct ResumeRequest {}
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+#[serde(transparent)]
+pub struct InitTransportRequest(pub serde_json::Value);
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct UpdateWeightsRequest {
+    pub version: String,
+    pub target: serde_json::Value,
+    #[serde(default)]
+    pub transport: serde_json::Value,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub pause_mode: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub clear_cache: Option<bool>,
+}
+
+impl UpdateWeightsRequest {
+    fn into_body(self) -> serde_json::Value {
+        let mut body = serde_json::json!({
+            "version": self.version,
+            "target": self.target,
+            "transport": self.transport,
+        });
+        if let Some(pause_mode) = self.pause_mode {
+            body["pause_mode"] = serde_json::Value::String(pause_mode);
+        }
+        if let Some(clear_cache) = self.clear_cache {
+            body["clear_cache"] = serde_json::Value::Bool(clear_cache);
+        }
+        body
+    }
+}
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct WorkerResult {
+    pub target: WorkerTarget,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub response: Option<serde_json::Value>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub error: Option<String>,
+}
+
+impl WorkerResult {
+    fn ok(target: WorkerTarget, response: serde_json::Value) -> Self {
+        Self {
+            target,
+            response: Some(response),
+            error: None,
+        }
+    }
+
+    fn error(target: WorkerTarget, error: impl Into<String>) -> Self {
+        Self {
+            target,
+            response: None,
+            error: Some(error.into()),
+        }
+    }
+
+    pub fn is_ok(&self) -> bool {
+        self.error.is_none()
+            && self
+                .response
+                .as_ref()
+                .and_then(|r| r.get("status"))
+                .and_then(|s| s.as_str())
+                == Some("ok")
+    }
+
+    pub fn payload(&self) -> serde_json::Value {
+        match (&self.response, &self.error) {
+            (Some(response), None) => response.clone(),
+            (_, Some(error)) => serde_json::json!({
+                "status": "error",
+                "namespace": self.target.namespace,
+                "component": self.target.component,
+                "endpoint": self.target.endpoint,
+                "instance_id": self.target.instance_id,
+                "message": error,
+            }),
+            _ => serde_json::json!({
+                "status": "error",
+                "namespace": self.target.namespace,
+                "component": self.target.component,
+                "endpoint": self.target.endpoint,
+                "instance_id": self.target.instance_id,
+                "message": "missing worker response",
+            }),
+        }
+    }
+}
+
+#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
+pub struct FanoutReport {
+    pub snapshot: MembershipSnapshot,
+    pub workers: Vec<WorkerResult>,
+}
+
+impl FanoutReport {
+    pub fn all_ok(&self) -> bool {
+        !self.workers.is_empty() && self.workers.iter().all(WorkerResult::is_ok)
+    }
+
+    pub fn worker_payloads(&self) -> Vec<serde_json::Value> {
+        self.workers.iter().map(WorkerResult::payload).collect()
+    }
+}
+
+#[derive(Clone)]
+pub struct RlClient {
+    runtime: Arc<DistributedRuntime>,
+    namespace: String,
+    rl_endpoint: String,
+    policy: FanoutPolicy,
+}
+
+impl RlClient {
+    pub fn new(config: RlClientConfig) -> anyhow::Result<Self> {
+        if config.namespace.trim().is_empty() {
+            anyhow::bail!("RlClientConfig.namespace must not be empty");
+        }
+        if config.rl_endpoint.trim().is_empty() {
+            anyhow::bail!("RlClientConfig.rl_endpoint must not be empty");
+        }
+
+        Ok(Self {
+            runtime: config.runtime,
+            namespace: config.namespace,
+            rl_endpoint: config.rl_endpoint,
+            policy: config.policy,
+        })
+    }
+
+    pub async fn snapshot(&self) -> anyhow::Result<MembershipSnapshot> {
+        let instances = self
+            .runtime
+            .discovery()
+            .list(DiscoveryQuery::NamespacedEndpoints {
+                namespace: self.namespace.clone(),
+            })
+            .await?;
+
+        let targets = instances
+            .into_iter()
+            .filter_map(|instance| match instance {
+                DiscoveryInstance::Endpoint(instance) if instance.endpoint == self.rl_endpoint => {
+                    Some(instance)
+                }
+                _ => None,
+            })
+            .filter(|instance| {
+                self.policy
+                    .component_filter
+                    .as_ref()
+                    .map(|components| components.iter().any(|c| c == &instance.component))
+                    .unwrap_or(true)
+            })
+            .map(|instance| WorkerTarget {
+                namespace: instance.namespace,
+                component: instance.component,
+                endpoint: instance.endpoint,
+                instance_id: instance.instance_id,
+            })
+            .collect();
+
+        Ok(MembershipSnapshot::new(targets))
+    }
+
+    pub async fn describe(&self) -> anyhow::Result<FanoutReport> {
+        self.fanout(RlRequest::describe(DescribeRequest::default()))
+            .await
+    }
+
+    pub async fn pause(&self, req: PauseRequest) -> anyhow::Result<FanoutReport> {
+        self.fanout(RlRequest::pause(req)).await
+    }
+
+    pub async fn resume(&self, req: ResumeRequest) -> anyhow::Result<FanoutReport> {
+        self.fanout(RlRequest::resume(req)).await
+    }
+
+    pub async fn init_transport(&self, req: InitTransportRequest) -> anyhow::Result<FanoutReport> {
+        self.fanout(RlRequest::init_transport(req)).await
+    }
+
+    pub async fn update_weights(&self, req: UpdateWeightsRequest) -> anyhow::Result<FanoutReport> {
+        self.fanout(RlRequest::update_weights(req)).await
+    }
+
+    pub async fn fanout(&self, request: RlRequest) -> anyhow::Result<FanoutReport> {
+        let snapshot = self.snapshot().await?;
+        self.fanout_snapshot(snapshot, request).await
+    }
+
+    pub async fn fanout_snapshot(
+        &self,
+        snapshot: MembershipSnapshot,
+        request: RlRequest,
+    ) -> anyhow::Result<FanoutReport> {
+        if snapshot.targets.len() < self.policy.min_workers {
+            return Err(RlError::NoWorkers {
+                namespace: self.namespace.clone(),
+                rl_endpoint: self.rl_endpoint.clone(),
+            }
+            .into());
+        }
+
+        let mut grouped: HashMap<(String, String, String), Vec<WorkerTarget>> = HashMap::new();
+        for target in &snapshot.targets {
+            grouped
+                .entry(target.endpoint_key())
+                .or_default()
+                .push(target.clone());
+        }
+
+        let mut calls: Vec<futures::future::BoxFuture<'static, WorkerResult>> = Vec::new();
+        for ((namespace, component, endpoint_name), targets) in grouped {
+            let endpoint = match self
+                .runtime
+                .namespace(&namespace)
+                .and_then(|ns| ns.component(&component))
+            {
+                Ok(component) => component.endpoint(endpoint_name),
+                Err(err) => {
+                    for target in targets {
+                        calls.push(
+                            futures::future::ready(WorkerResult::error(
+                                target,
+                                format!("endpoint build failed: {err}"),
+                            ))
+                            .boxed(),
+                        );
+                    }
+                    continue;
+                }
+            };
+
+            let client = match endpoint.client().await {
+                Ok(client) => client,
+                Err(err) => {
+                    for target in targets {
+                        calls.push(
+                            futures::future::ready(WorkerResult::error(
+                                target,
+                                format!("client create failed: {err}"),
+                            ))
+                            .boxed(),
+                        );
+                    }
+                    continue;
+                }
+            };
+
+            let target_ids: Vec<u64> = targets.iter().map(|target| target.instance_id).collect();
+            wait_for_client_targets(&client, &target_ids, self.policy.membership_timeout).await;
+
+            let router =
+                match PushRouter::<serde_json::Value, Annotated<serde_json::Value>>::from_client(
+                    client,
+                    RouterMode::Direct,
+                )
+                .await
+                {
+                    Ok(router) => router,
+                    Err(err) => {
+                        for target in targets {
+                            calls.push(
+                                futures::future::ready(WorkerResult::error(
+                                    target,
+                                    format!("PushRouter build failed: {err}"),
+                                ))
+                                .boxed(),
+                            );
+                        }
+                        continue;
+                    }
+                };
+
+            for target in targets {
+                calls.push(
+                    call_worker(
+                        router.clone(),
+                        target,
+                        request.clone(),
+                        self.policy.request_timeout,
+                        self.policy.strict_direct,
+                    )
+                    .boxed(),
+                );
+            }
+        }
+
+        let workers = futures::future::join_all(calls).await;
+
+        if self.policy.abort_on_membership_change {
+            let after = self.snapshot().await?;
+            if after.epoch != snapshot.epoch {
+                return Err(RlError::MembershipChanged {
+                    before_epoch: snapshot.epoch,
+                    after_epoch: after.epoch,
+                }
+                .into());
+            }
+        }
+
+        Ok(FanoutReport { snapshot, workers })
+    }
+}
+
+async fn wait_for_client_targets(client: &Client, target_ids: &[u64], timeout: Duration) {
+    let wait = async {
+        loop {
+            let instance_ids = client.instance_ids();
+            if target_ids.iter().all(|id| instance_ids.contains(id)) {
+                break;
+            }
+            tokio::time::sleep(Duration::from_millis(25)).await;
+        }
+    };
+
+    let _ = tokio::time::timeout(timeout, wait).await;
+}
+
+async fn call_worker(
+    router: PushRouter<serde_json::Value, Annotated<serde_json::Value>>,
+    target: WorkerTarget,
+    request: RlRequest,
+    timeout: Duration,
+    strict_direct: bool,
+) -> WorkerResult {
+    let request_value = match serde_json::to_value(request) {
+        Ok(value) => value,
+        Err(err) => return WorkerResult::error(target, format!("request encode failed: {err}")),
+    };
+
+    let instance_id = target.instance_id;
+    let dispatch = async {
+        let req = SingleIn::new(request_value);
+        let mut stream = if strict_direct {
+            router.direct_strict(req, instance_id).await?
+        } else {
+            router.direct(req, instance_id).await?
+ }; + + while let Some(chunk) = stream.next().await { + if let Some(data) = chunk.data { + return anyhow::Ok(data); + } + if let Some(err) = chunk.error { + anyhow::bail!(err.to_string()); + } + } + + anyhow::bail!("empty response stream"); + }; + + match tokio::time::timeout(timeout, dispatch).await { + Ok(Ok(response)) => WorkerResult::ok(target, response), + Ok(Err(err)) => WorkerResult::error(target, format!("dispatch failed: {err}")), + Err(_) => WorkerResult::error( + target, + format!("dispatch timed out after {}s", timeout.as_secs()), + ), + } +} + +/// Documentation tuple for an RL admin route. The dynamo-llm caller wraps +/// each tuple into its own `RouteDoc` for `/openapi.json` aggregation. +#[derive(Debug, Clone)] +pub struct RlRouteDoc { + pub method: Method, + pub path: String, +} + +impl RlRouteDoc { + fn new(method: Method, path: impl Into) -> Self { + Self { + method, + path: path.into(), + } + } +} + +/// Shared state for the RL admin HTTP facade. +#[derive(Clone)] +struct RlState { + client: RlClient, +} + +impl RlState { + fn new(client: RlClient) -> Self { + Self { client } + } + + async fn fan_out(&self, route: &str, body: serde_json::Value) -> anyhow::Result { + self.client + .fanout(RlRequest::new(route_to_op(route), body)) + .await + } +} + +#[derive(Clone)] +pub struct RlHttpDeps { + pub client: RlClient, +} + +/// Map a legacy engine-route name to the corresponding `rl_dispatch` op. +fn route_to_op(route: &str) -> &str { + match route { + "pause_generation" => "pause", + "resume_generation" => "resume", + "weight_transport_init" => "init_transport", + "weight_transport_update" => "update_weights", + // Anything else — pass through verbatim so `rl_dispatch` can return + // a meaningful "unknown op" error instead of us silently rewriting. + other => other, + } +} + +/// `POST /v1/rl/pause` — fan out `pause_generation` to all workers. +/// +/// Query params (both optional): +/// - `mode`: `keep` | `wait` | `abort` (default `keep`) +/// - `clear_cache`: `true` | `false` (default `false`) +/// +/// Three-mode pause matches what vLLM exposes (abort / wait / keep). The +/// default `mode=keep&clear_cache=false` preserves the original single-mode +/// pause behavior so existing callers keep working without changes. +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +#[serde(rename_all = "lowercase")] +enum PauseMode { + Keep, + Wait, + Abort, +} + +impl PauseMode { + fn as_str(self) -> &'static str { + match self { + PauseMode::Keep => "keep", + PauseMode::Wait => "wait", + PauseMode::Abort => "abort", + } + } +} + +impl Default for PauseMode { + fn default() -> Self { + PauseMode::Keep + } +} + +#[derive(Debug, serde::Deserialize)] +struct RlPauseQuery { + /// Axum returns 400 automatically if this fails to deserialize as a + /// `PauseMode` (i.e. on `mode=invalid`), so we don't need a runtime check. + #[serde(default)] + mode: Option, + #[serde(default)] + clear_cache: Option, +} + +fn rl_error_response(err: anyhow::Error) -> (StatusCode, Json) { + let (status, error_type) = match err.downcast_ref::() { + Some(RlError::NoWorkers { .. }) => (StatusCode::SERVICE_UNAVAILABLE, "no_workers"), + Some(RlError::MembershipChanged { .. 
}) => (StatusCode::CONFLICT, "membership_changed"), + None => (StatusCode::BAD_GATEWAY, "fanout_failed"), + }; + + ( + status, + Json(serde_json::json!({ + "status": "error", + "error_type": error_type, + "message": err.to_string(), + })), + ) +} + +async fn rl_pause( + State(state): State>, + axum::extract::Query(q): axum::extract::Query, +) -> impl IntoResponse { + let mode = q.mode.unwrap_or_default(); + let clear_cache = q.clear_cache.unwrap_or(false); + let report = match state + .fan_out( + "pause_generation", + serde_json::json!({"mode": mode.as_str(), "clear_cache": clear_cache}), + ) + .await + { + Ok(report) => report, + Err(err) => return rl_error_response(err), + }; + + let workers = report.worker_payloads(); + if report.all_ok() { + tracing::info!( + worker_count = workers.len(), + membership_epoch = report.snapshot.epoch, + mode = %mode.as_str(), + clear_cache, + "RL pause: all workers paused" + ); + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "ok", + "mode": mode.as_str(), + "clear_cache": clear_cache, + "membership_epoch": report.snapshot.epoch, + "workers": workers, + })), + ) + } else { + tracing::warn!(?workers, "RL pause: some workers failed"); + ( + StatusCode::BAD_GATEWAY, + Json(serde_json::json!({ + "status": "error", + "membership_epoch": report.snapshot.epoch, + "workers": workers, + })), + ) + } +} + +/// `POST /v1/rl/resume` — fan out `resume_generation` to all workers. +async fn rl_resume(State(state): State>) -> impl IntoResponse { + let report = match state + .fan_out("resume_generation", serde_json::json!({})) + .await + { + Ok(report) => report, + Err(err) => return rl_error_response(err), + }; + + let workers = report.worker_payloads(); + if report.all_ok() { + tracing::info!( + worker_count = workers.len(), + membership_epoch = report.snapshot.epoch, + "RL resume: all workers resumed" + ); + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "ok", + "membership_epoch": report.snapshot.epoch, + "workers": workers, + })), + ) + } else { + tracing::warn!(?workers, "RL resume: some workers failed"); + ( + StatusCode::BAD_GATEWAY, + Json(serde_json::json!({ + "status": "error", + "membership_epoch": report.snapshot.epoch, + "workers": workers, + })), + ) + } +} + +/// `POST /v1/rl/update_weights` — fan out `flush_cache → update_weights_from_path` to all workers. +/// +/// **Not atomic.** If `update_weights_from_path` succeeds on workers `0..N-1` +/// and fails on worker `N`, the fleet is left in a mixed-version state: the +/// successful workers serve the new version while worker `N` still runs the +/// previous one. The response carries per-worker status so callers can +/// retry / drain manually; a true rollback layer is a follow-up. 
+///
+/// Two body shapes have existed for this route; as of Phase 3 only the
+/// second (WeightTransferConfig) shape is accepted. See the Phase 3 note
+/// below.
+///
+/// **Legacy** (Phase 1 backward-compat, now removed):
+/// ```json
+/// {
+///   "weight_dir": "/path/to/checkpoint" | null,   // null → NCCL mode no-op
+///   "weight_version": "step_42",                  // optional; derived from
+///                                                 // weight_dir basename if missing
+///   "reset_prefix_cache": true
+/// }
+/// ```
+///
+/// **WeightTransferConfig** (current, single shape across backends):
+/// ```json
+/// {
+///   "version": "step_42",
+///   "target": {"kind": "base"} | {"kind": "lora", "name": "...", "op": "load|swap|unload"},
+///   "transport": {
+///     "backend": "filesystem" | "nccl",
+///     "filesystem": {"path": "...", "require_marker": "STABLE"},
+///     "nccl": {"transport_id": "...", "weight_names": [...], "dtype": "bf16"}
+///   }
+/// }
+/// ```
+///
+/// Returns `{ "status": "ok", "applied_weight_version": "step_42", "workers": [...] }` on success.
+///
+/// The pause/resume envelope is left to the caller; full-FT updates MUST
+/// bracket this call with `/v1/rl/pause` and `/v1/rl/resume`.
+///
+/// **Phase 3 (PR C):** the legacy `{weight_dir, weight_version, reset_prefix_cache}`
+/// body is gone. Every caller now provides `version`, `target`, and
+/// `transport`. LoRA load/swap/unload also go through this same body via
+/// `target.kind = "lora"` — see `weight-transfer-config.md` § 2.
+#[derive(Debug, serde::Deserialize)]
+struct RlUpdateWeightsBody {
+    version: String,
+    target: serde_json::Value,
+    #[serde(default)]
+    transport: serde_json::Value,
+    #[serde(default)]
+    pause_mode: Option<String>,
+    #[serde(default)]
+    clear_cache: Option<bool>,
+}
+
+async fn rl_update_weights(
+    State(state): State<Arc<RlState>>,
+    body: axum::extract::Json<RlUpdateWeightsBody>,
+) -> impl IntoResponse {
+    let RlUpdateWeightsBody {
+        version,
+        target,
+        mut transport,
+        pause_mode,
+        clear_cache,
+    } = body.0;
+    if is_lora_unload(&target) && transport.get("backend").is_none() {
+        transport = serde_json::json!({
+            "backend": "filesystem",
+            "filesystem": {},
+        });
+    }
+    rl_update_weights_inner(state, version, target, transport, pause_mode, clear_cache).await
+}
+
+fn is_lora_unload(target: &serde_json::Value) -> bool {
+    target.get("kind").and_then(|v| v.as_str()) == Some("lora")
+        && target.get("op").and_then(|v| v.as_str()) == Some("unload")
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn update_weights_body_accepts_lora_unload_without_transport() {
+        let body: RlUpdateWeightsBody = serde_json::from_value(serde_json::json!({
+            "version": "step_44",
+            "target": {
+                "kind": "lora",
+                "name": "adapter",
+                "op": "unload"
+            }
+        }))
+        .unwrap();
+
+        assert!(is_lora_unload(&body.target));
+        assert!(body.transport.is_null());
+    }
+}
+
+/// WeightTransferConfig path — fans out to `weight_transport_update`.
+/// WeightTransferConfig path — fans out to `weight_transport_update`.
+async fn rl_update_weights_inner(
+    state: Arc<RlState>,
+    version: String,
+    target: serde_json::Value,
+    transport: serde_json::Value,
+    pause_mode: Option<String>,
+    clear_cache: Option<bool>,
+) -> (StatusCode, Json<serde_json::Value>) {
+    let backend = transport
+        .get("backend")
+        .and_then(|v| v.as_str())
+        .unwrap_or("");
+    tracing::info!(
+        version = %version,
+        backend = %backend,
+        ?target,
+        "RL update_weights"
+    );
+    let mut body = serde_json::json!({
+        "version": version,
+        "target": target,
+        "transport": transport,
+    });
+    if let Some(pm) = pause_mode {
+        body["pause_mode"] = serde_json::Value::String(pm);
+    }
+    if let Some(cc) = clear_cache {
+        body["clear_cache"] = serde_json::Value::Bool(cc);
+    }
+    let report = match state.fan_out("weight_transport_update", body).await {
+        Ok(report) => report,
+        Err(err) => return rl_error_response(err),
+    };
+
+    let workers = report.worker_payloads();
+    if report.all_ok() {
+        tracing::info!(
+            worker_count = workers.len(),
+            membership_epoch = report.snapshot.epoch,
+            backend = %backend,
+            version = %version,
+            "RL update_weights: all workers updated"
+        );
+        (
+            StatusCode::OK,
+            Json(serde_json::json!({
+                "status": "ok",
+                "applied_weight_version": version,
+                "backend": backend,
+                "membership_epoch": report.snapshot.epoch,
+                "workers": workers,
+            })),
+        )
+    } else {
+        tracing::warn!(?workers, backend = %backend, "RL update_weights: some workers failed");
+        (
+            StatusCode::BAD_GATEWAY,
+            Json(serde_json::json!({
+                "status": "error",
+                "stage": "weight_transport_update",
+                "backend": backend,
+                "membership_epoch": report.snapshot.epoch,
+                "workers": workers,
+            })),
+        )
+    }
+}
+
+/// `POST /v1/rl/init_transport` — idempotent one-time setup for a weight
+/// transport (filesystem / nccl). Replaces backend-specific bring-up
+/// endpoints with a single discriminated body.
+///
+/// Body:
+/// ```json
+/// {
+///   "transport_id": "rl-weights-step",
+///   "backend": "filesystem" | "nccl",
+///   "filesystem": { … } | "nccl": { … }
+/// }
+/// ```
+///
+/// `filesystem` is a no-op (transport state goes `ready` immediately).
+/// `nccl` triggers the worker-side group bootstrap.
+async fn rl_init_transport(
+    State(state): State<Arc<RlState>>,
+    body: axum::extract::Json<serde_json::Value>,
+) -> impl IntoResponse {
+    let body = body.0;
+    let backend = body
+        .get("backend")
+        .and_then(|v| v.as_str())
+        .unwrap_or("")
+        .to_string();
+    let transport_id = body
+        .get("transport_id")
+        .and_then(|v| v.as_str())
+        .unwrap_or(&backend)
+        .to_string();
+    tracing::info!(%backend, %transport_id, "RL init_transport");
+
+    let report = match state.fan_out("weight_transport_init", body).await {
+        Ok(report) => report,
+        Err(err) => return rl_error_response(err),
+    };
+
+    let workers = report.worker_payloads();
+    if report.all_ok() {
+        tracing::info!(
+            worker_count = workers.len(),
+            membership_epoch = report.snapshot.epoch,
+            %backend,
+            %transport_id,
+            "RL init_transport: all workers ready"
+        );
+        (
+            StatusCode::OK,
+            Json(serde_json::json!({
+                "status": "ok",
+                "transport_id": transport_id,
+                "backend": backend,
+                "ready": true,
+                "membership_epoch": report.snapshot.epoch,
+                "workers": workers,
+            })),
+        )
+    } else {
+        tracing::warn!(?workers, %backend, "RL init_transport: some workers failed");
+        (
+            StatusCode::BAD_GATEWAY,
+            Json(serde_json::json!({
+                "status": "error",
+                "transport_id": transport_id,
+                "backend": backend,
+                "membership_epoch": report.snapshot.epoch,
+                "workers": workers,
+            })),
+        )
+    }
+}
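Sequencing, made concrete: `init_transport` is called once and is idempotent, `update_weights` is called per step, and the two meet on `transport_id`. A sketch of the NCCL call order (the transport ID, weight names, and the empty init options are illustrative; the doc comment above deliberately elides backend options):

```rust
use serde_json::json;

fn nccl_call_order() -> (serde_json::Value, serde_json::Value) {
    // POST /v1/rl/init_transport — once per training run; idempotent.
    let init = json!({
        "transport_id": "rl-weights-step",
        "backend": "nccl",
        "nccl": {}, // backend bring-up options elided, as in the doc above
    });
    // POST /v1/rl/update_weights — once per optimizer step, same transport_id.
    let update = json!({
        "version": "step_43",
        "target": {"kind": "base"},
        "transport": {
            "backend": "nccl",
            "nccl": {
                "transport_id": "rl-weights-step",
                "weight_names": ["model.embed_tokens.weight"], // illustrative
                "dtype": "bf16",
            },
        },
    });
    (init, update)
}
```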
+/// Create an Axum [`Router`] for the RL admin endpoints at `/v1/rl/*`.
+///
+/// Fan-out goes through [`RlClient`], which snapshots the discovery plane,
+/// groups live `<namespace>.<component>.rl` workers, and dispatches with
+/// request-plane strict direct calls over NATS / TCP / HTTP.
+///
+/// **Surface:** four POST routes after Phase 3:
+/// `pause`, `resume`, `init_transport`, `update_weights`. Read-side
+/// endpoints (`state`, `health`, `ready`, `liveness`, `weight_version`)
+/// and the dedicated LoRA routes (`load_lora_adapter`, `unload_lora_adapter`)
+/// are dropped — replacements piggyback on the frontend's existing `/live`
+/// and `/health`, and LoRA flows through `update_weights {target.kind="lora"}`.
+/// See `weight-transfer-config.md` § "Constraints from existing surface".
+///
+/// Mounted on the dedicated `/v1/rl/*` listener when
+/// `DYN_ENABLE_RL_ENDPOINTS=true`. prime-rl usage:
+/// `admin_base_url = "http://dynamo-frontend:8002/v1/rl"`.
+pub fn rl_router(deps: RlHttpDeps) -> anyhow::Result<(Vec<RlRouteDoc>, Router)> {
+    let rl_state_arc = Arc::new(RlState::new(deps.client));
+    let docs = vec![
+        // Pause / resume bracket.
+        RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/pause"),
+        RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/resume"),
+        // WeightTransferConfig API: init + discriminated update_weights body
+        // covering both base-model reload and LoRA load/swap/unload.
+        RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/init_transport"),
+        RlRouteDoc::new(axum::http::Method::POST, "/v1/rl/update_weights"),
+    ];
+    let router = Router::new()
+        .route("/v1/rl/pause", post(rl_pause))
+        .route("/v1/rl/resume", post(rl_resume))
+        .route("/v1/rl/init_transport", post(rl_init_transport))
+        .route("/v1/rl/update_weights", post(rl_update_weights))
+        .with_state(rl_state_arc);
+    Ok((docs, router))
+}
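For orientation, a minimal sketch of standing this router up on the dedicated admin port. The frontend wires `--rl-port` (default 8002) elsewhere in this PR; the bind address and listener handling below are illustrative, not this PR's wiring:

```rust
use tokio::net::TcpListener;

async fn serve_rl(deps: RlHttpDeps) -> anyhow::Result<()> {
    let (_docs, router) = rl_router(deps)?;
    // 8002 mirrors the --rl-port / DYN_RL_PORT default added in this PR.
    let listener = TcpListener::bind(("0.0.0.0", 8002)).await?;
    axum::serve(listener, router).await?;
    Ok(())
}
```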
diff --git a/lib/runtime/src/pipeline/network/egress/push_router.rs b/lib/runtime/src/pipeline/network/egress/push_router.rs
index 41fa024caced..001865c95dca 100644
--- a/lib/runtime/src/pipeline/network/egress/push_router.rs
+++ b/lib/runtime/src/pipeline/network/egress/push_router.rs
@@ -452,6 +452,27 @@ where
         .await
     }
 
+    /// Issue a request to a specific endpoint without fallback re-selection.
+    ///
+    /// This is intended for admin/control-plane operations where the caller has
+    /// already selected a concrete membership snapshot and routing the request
+    /// to any other instance would be incorrect.
+    pub async fn direct_strict(
+        &self,
+        request: SingleIn<T>,
+        instance_id: u64,
+    ) -> anyhow::Result<ManyOut<U>> {
+        if !self.client.instance_ids().contains(&instance_id) {
+            return Err(anyhow::anyhow!(
+                "instance_id={instance_id} not found for endpoint {}",
+                self.client.endpoint.id()
+            ));
+        }
+
+        self.generate_with_fault_detection_options(instance_id, request, false)
+            .await
+    }
+
     /// Issue a request using device-aware weighted routing.
     ///
     /// Instances are partitioned by device type (CPU vs non-CPU), then the router
@@ -650,9 +671,19 @@ where
     */
     async fn generate_with_fault_detection(
+        &self,
+        instance_id: u64,
+        request: SingleIn<T>,
+    ) -> anyhow::Result<ManyOut<U>> {
+        self.generate_with_fault_detection_options(instance_id, request, true)
+            .await
+    }
+
+    async fn generate_with_fault_detection_options(
         &self,
         mut instance_id: u64,
         request: SingleIn<T>,
+        allow_fallback: bool,
     ) -> anyhow::Result<ManyOut<U>> {
         let route_start = Instant::now();
         let request_id = request.id().to_string();
@@ -734,6 +765,12 @@ where
 
         if let Some(result) = resolve_transport(instance_id) {
             result
+        } else if !allow_fallback {
+            return Err(anyhow::anyhow!(
+                "Instance {} not found for endpoint {}",
+                instance_id,
+                self.client.endpoint.id()
+            ));
         } else {
             // Instance vanished — pick a different one from the current
             // availability list and retry the lookup once.
diff --git a/lib/tokenizers/src/fastokens.rs b/lib/tokenizers/src/fastokens.rs
index 93e855cc5c58..4295e69fcb91 100644
--- a/lib/tokenizers/src/fastokens.rs
+++ b/lib/tokenizers/src/fastokens.rs
@@ -39,16 +39,28 @@ impl FastTokenizer {
 impl Encoder for FastTokenizer {
     fn encode(&self, input: &str) -> Result<Encoding> {
+        self.encode_with_special_tokens(input, false)
+    }
+
+    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
+        inputs.par_iter().map(|input| self.encode(input)).collect()
+    }
+
+    fn encode_with_special_tokens(
+        &self,
+        input: &str,
+        add_special_tokens: bool,
+    ) -> Result<Encoding> {
+        if add_special_tokens {
+            return self.hf_decoder.encode_with_special_tokens(input, true);
+        }
+
         let ids = self
             .fast_encoder
             .encode(input)
             .map_err(|e| Error::msg(format!("Fastokens encode error: {e}")))?;
         Ok(Encoding::Sp(ids))
     }
-
-    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
-        inputs.par_iter().map(|input| self.encode(input)).collect()
-    }
 }
 
 impl Decoder for FastTokenizer {
@@ -57,7 +69,11 @@
     }
 }
 
-impl Tokenizer for FastTokenizer {}
+impl Tokenizer for FastTokenizer {
+    fn convert_ids_to_tokens(&self, token_ids: &[TokenIdType]) -> Result<Vec<String>> {
+        self.hf_decoder.convert_ids_to_tokens(token_ids)
+    }
+}
 
 #[cfg(test)]
 mod tests {
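The `allow_fallback` flag is the crux of the push-router change above: admin fan-out must surface a vanished instance as an error rather than silently re-route it. A self-contained toy of that snapshot-then-strict-dispatch shape (types are simplified stand-ins; only the no-fallback behaviour is taken from the diff, `PushRouter`'s real generics and transport are stripped away):

```rust
use std::collections::HashSet;

// Stand-in for the real per-instance transport call.
async fn call_instance(id: u64) -> anyhow::Result<String> {
    Ok(format!("ok from {id}"))
}

// Dispatch to every instance in a pre-taken snapshot; a missing instance
// becomes an error for that entry instead of a retry on a different worker.
async fn fan_out_strict(
    snapshot: &[u64],
    live: &HashSet<u64>,
) -> Vec<(u64, anyhow::Result<String>)> {
    let mut out = Vec::with_capacity(snapshot.len());
    for &id in snapshot {
        let res = if live.contains(&id) {
            call_instance(id).await
        } else {
            Err(anyhow::anyhow!("instance_id={id} not found")) // no fallback
        };
        out.push((id, res));
    }
    out
}
```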
diff --git a/lib/tokenizers/src/hf.rs b/lib/tokenizers/src/hf.rs
index 080a775719fe..68c720c3ea8c 100644
--- a/lib/tokenizers/src/hf.rs
+++ b/lib/tokenizers/src/hf.rs
@@ -27,19 +27,18 @@ impl HuggingFaceTokenizer {
 impl Encoder for HuggingFaceTokenizer {
     fn encode(&self, input: &str) -> Result<Encoding> {
-        // This self.tokenizer is the library
-        let encoding = self
-            .tokenizer
-            .encode(input, false)
-            .map_err(|err| Error::msg(format!("Error tokenizing input: {err}")))?;
-
-        Ok(Encoding::Hf(Box::new(encoding)))
+        // Use add_special_tokens=true to match TikTokenTokenizer::encode() behaviour.
+        // Both backends must agree on whether BOS/EOS are included so that callers
+        // (e.g. /v1/tokenize, rl_tokenize_prompt) get consistent token counts
+        // regardless of which backend is active. Callers that explicitly need no
+        // special tokens should call encode_with_special_tokens(input, false) directly.
+        self.encode_with_special_tokens(input, true)
     }
 
     fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
         let hf_encodings = self
             .tokenizer
-            .encode_batch(inputs.to_vec(), false)
+            .encode_batch(inputs.to_vec(), true) // true to match encode() above
             .map_err(|err| Error::msg(format!("Error batch tokenizing input: {err}")))?;
 
         let encodings = hf_encodings
@@ -49,6 +48,20 @@ impl Encoder for HuggingFaceTokenizer {
 
         Ok(encodings)
     }
+
+    fn encode_with_special_tokens(
+        &self,
+        input: &str,
+        add_special_tokens: bool,
+    ) -> Result<Encoding> {
+        // self.tokenizer is the underlying `tokenizers` library tokenizer.
+        let encoding = self
+            .tokenizer
+            .encode(input, add_special_tokens)
+            .map_err(|err| Error::msg(format!("Error tokenizing input: {err}")))?;
+
+        Ok(Encoding::Hf(Box::new(encoding)))
+    }
 }
 
 impl Decoder for HuggingFaceTokenizer {
@@ -63,7 +76,14 @@ impl Decoder for HuggingFaceTokenizer {
     }
 }
 
-impl Tokenizer for HuggingFaceTokenizer {}
+impl Tokenizer for HuggingFaceTokenizer {
+    fn convert_ids_to_tokens(&self, token_ids: &[TokenIdType]) -> Result<Vec<String>> {
+        Ok(token_ids
+            .iter()
+            .map(|&id| self.tokenizer.id_to_token(id).unwrap_or_default())
+            .collect())
+    }
+}
 
 impl From<HfTokenizer> for HuggingFaceTokenizer {
     fn from(tokenizer: HfTokenizer) -> Self {
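Why every real backend overrides `encode_with_special_tokens`: the trait default (added in lib.rs below) ignores the flag and just delegates to `encode()`. A self-contained toy makes the pitfall visible; the types here are simplified stand-ins for `Encoding`, not this crate's API:

```rust
use anyhow::Result;

trait Encoder {
    fn encode(&self, input: &str) -> Result<Vec<u32>>;
    fn encode_with_special_tokens(&self, input: &str, _add_special_tokens: bool) -> Result<Vec<u32>> {
        self.encode(input) // default: the flag is silently ignored
    }
}

struct Toy;
impl Encoder for Toy {
    fn encode(&self, input: &str) -> Result<Vec<u32>> {
        Ok(input.bytes().map(u32::from).collect()) // byte-level stand-in, no BOS/EOS
    }
}

fn main() -> Result<()> {
    let t = Toy;
    // Both calls agree only because Toy never adds special tokens. A backend
    // whose encode() injects BOS/EOS must override the method to honour `false`.
    assert_eq!(t.encode("hi")?, t.encode_with_special_tokens("hi", false)?);
    Ok(())
}
```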
diff --git a/lib/tokenizers/src/lib.rs b/lib/tokenizers/src/lib.rs
index 95494b4f73f0..dbf00955b0bc 100644
--- a/lib/tokenizers/src/lib.rs
+++ b/lib/tokenizers/src/lib.rs
@@ -63,6 +63,14 @@ pub mod traits {
     pub trait Encoder: Send + Sync {
         fn encode(&self, input: &str) -> Result<Encoding>;
         fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
+
+        fn encode_with_special_tokens(
+            &self,
+            input: &str,
+            _add_special_tokens: bool,
+        ) -> Result<Encoding> {
+            self.encode(input)
+        }
     }
 
     /// Result of decoding token IDs to text.
@@ -128,8 +136,17 @@ pub mod traits {
     }
 
     pub trait Tokenizer: Encoder + Decoder {
-        // fn get_vocab_size(&self) -> usize;
-        // fn make_unique_clone(&self) -> Box<dyn Tokenizer>;
+        fn convert_ids_to_tokens(&self, token_ids: &[TokenIdType]) -> Result<Vec<String>> {
+            // Decoder::decode returns DecodeResult (Complete/Partial); the existing
+            // `impl From<DecodeResult> for String` unwraps it to the inner string.
+            token_ids
+                .iter()
+                .map(|id| {
+                    self.decode(std::slice::from_ref(id), false)
+                        .map(String::from)
+                })
+                .collect()
+        }
     }
 }
 
@@ -224,6 +241,18 @@ impl Tokenizer {
         Ok(Tokenizer(create_tokenizer_from_file(file_path)?))
     }
 
+    pub fn encode_with_special_tokens(
+        &self,
+        input: &str,
+        add_special_tokens: bool,
+    ) -> Result<Encoding> {
+        self.0.encode_with_special_tokens(input, add_special_tokens)
+    }
+
+    pub fn convert_ids_to_tokens(&self, token_ids: &[TokenIdType]) -> Result<Vec<String>> {
+        self.0.convert_ids_to_tokens(token_ids)
+    }
+
     /// Create a stateful sequence object for decoding token_ids into text
     pub fn decode_stream(
         &self,
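Taken together, the facade now exposes both knobs the RL path needs: exact token counts with special tokens excluded, and per-id display strings for logprob entries. A hedged usage sketch; the `token_ids()` accessor on `Encoding` is assumed here, and the prompt is illustrative:

```rust
// Sketch only: assumes Encoding exposes a token_ids() -> &[TokenIdType] view.
fn logprob_token_strings(tok: &Tokenizer, prompt: &str) -> anyhow::Result<Vec<String>> {
    // Count prompt tokens exactly as the worker computes them, BOS/EOS excluded.
    let enc = tok.encode_with_special_tokens(prompt, false)?;
    // Map ids back to display strings for per-token logprob entries.
    tok.convert_ids_to_tokens(enc.token_ids())
}
```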
diff --git a/lib/tokenizers/src/tiktoken.rs b/lib/tokenizers/src/tiktoken.rs
index 7082acb0f6a6..cc049c13001d 100644
--- a/lib/tokenizers/src/tiktoken.rs
+++ b/lib/tokenizers/src/tiktoken.rs
@@ -24,6 +24,8 @@ const KIMI_PATTERN: &str = r#"[\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p
 pub struct TikTokenTokenizer {
     bpe: CoreBPE,
     special_token_ids: HashSet<u32>,
+    decoder_tokens: FxHashMap<u32, Vec<u8>>,
+    special_tokens_decoder: FxHashMap<u32, Vec<u8>>,
 }
 
 impl TikTokenTokenizer {
@@ -39,6 +41,14 @@ impl TikTokenTokenizer {
         special_tokens: FxHashMap<String, u32>,
     ) -> Result<Self> {
         let encoder = parse_tiktoken_file(path)?;
+        let decoder_tokens: FxHashMap<u32, Vec<u8>> = encoder
+            .iter()
+            .map(|(bytes, &id)| (id, bytes.clone()))
+            .collect();
+        let special_tokens_decoder: FxHashMap<u32, Vec<u8>> = special_tokens
+            .iter()
+            .map(|(token, &id)| (id, token.as_bytes().to_vec()))
+            .collect();
         let special_token_ids: HashSet<u32> = special_tokens.values().copied().collect();
 
         let bpe = CoreBPE::new(encoder, special_tokens, pattern)
@@ -47,6 +57,8 @@ impl TikTokenTokenizer {
         Ok(Self {
             bpe,
             special_token_ids,
+            decoder_tokens,
+            special_tokens_decoder,
         })
     }
 
@@ -62,9 +74,17 @@ impl TikTokenTokenizer {
         let pattern = detect_bpe_pattern(directory)?;
 
         let encoder = parse_tiktoken_file(path)?;
+        let decoder_tokens: FxHashMap<u32, Vec<u8>> = encoder
+            .iter()
+            .map(|(bytes, &id)| (id, bytes.clone()))
+            .collect();
         // Use max rank + 1 (not len) to avoid ID collisions with sparse/non-contiguous ranks
         let num_base_tokens = encoder.values().max().map_or(0, |&m| m + 1) as usize;
         let special_tokens = load_special_tokens(directory, num_base_tokens)?;
+        let special_tokens_decoder: FxHashMap<u32, Vec<u8>> = special_tokens
+            .iter()
+            .map(|(token, &id)| (id, token.as_bytes().to_vec()))
+            .collect();
         let special_token_ids: HashSet<u32> = special_tokens.values().copied().collect();
 
         let bpe = CoreBPE::new(encoder, special_tokens, pattern)
@@ -73,19 +93,33 @@ impl TikTokenTokenizer {
         Ok(Self {
             bpe,
             special_token_ids,
+            decoder_tokens,
+            special_tokens_decoder,
         })
     }
 }
 
 impl Encoder for TikTokenTokenizer {
     fn encode(&self, input: &str) -> Result<Encoding> {
-        let token_ids: Vec<u32> = self.bpe.encode_with_special_tokens(input);
-        Ok(Encoding::Sp(token_ids))
+        self.encode_with_special_tokens(input, true)
     }
 
     fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
         inputs.par_iter().map(|input| self.encode(input)).collect()
     }
+
+    fn encode_with_special_tokens(
+        &self,
+        input: &str,
+        add_special_tokens: bool,
+    ) -> Result<Encoding> {
+        let token_ids: Vec<u32> = if add_special_tokens {
+            self.bpe.encode_with_special_tokens(input)
+        } else {
+            self.bpe.encode_ordinary(input)
+        };
+        Ok(Encoding::Sp(token_ids))
+    }
 }
 
 impl Decoder for TikTokenTokenizer {
@@ -119,7 +153,20 @@
     }
 }
 
-impl Tokenizer for TikTokenTokenizer {}
+impl Tokenizer for TikTokenTokenizer {
+    fn convert_ids_to_tokens(&self, token_ids: &[TokenIdType]) -> Result<Vec<String>> {
+        Ok(token_ids
+            .iter()
+            .map(|id| {
+                self.decoder_tokens
+                    .get(id)
+                    .or_else(|| self.special_tokens_decoder.get(id))
+                    .map(|bytes| String::from_utf8_lossy(bytes).into_owned())
+                    .unwrap_or_default()
+            })
+            .collect())
+    }
+}
 
 /// Parse a tiktoken model file (base64-encoded token + rank per line).
 fn parse_tiktoken_file(path: &str) -> Result<FxHashMap<Vec<u8>, u32>> {
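One caveat worth spelling out for `convert_ids_to_tokens` on byte-level BPE: an individual piece need not be valid UTF-8 on its own, which is exactly why the impl above goes through `String::from_utf8_lossy`. A self-contained illustration with a toy reverse map, not this crate's types:

```rust
use std::collections::HashMap;

fn main() {
    let mut decoder_tokens: HashMap<u32, Vec<u8>> = HashMap::new();
    decoder_tokens.insert(7, vec![0xE4, 0xB8]); // first two bytes of U+4E2D '中'
    decoder_tokens.insert(8, vec![0xAD]);       // its trailing continuation byte

    let piece = |id: &u32| {
        decoder_tokens
            .get(id)
            .map(|b| String::from_utf8_lossy(b).into_owned())
            .unwrap_or_default()
    };
    // Each piece alone renders as U+FFFD; only decode() over the full id
    // slice reconstructs the original character.
    assert_eq!(piece(&7), "\u{FFFD}");
    assert_eq!(piece(&8), "\u{FFFD}");
}
```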