diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 27848de026..bf8b7e88bb 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -443,13 +443,13 @@ def update_params(self, request: Any): """Update params.""" self.executor.update_params(request) - def sleep(self, level: int = 1): + async def sleep(self, level: int = 1): """Sleep.""" - self.executor.sleep(level) + await self.executor.sleep(level) - def wakeup(self, tags: list[str] | None = None): + async def wakeup(self, tags: list[str] | None = None): """Wakeup.""" - self.executor.wakeup(tags) + await self.executor.wakeup(tags) async def async_loop(self): engine_loop = None diff --git a/lmdeploy/pytorch/engine/engine_loop.py b/lmdeploy/pytorch/engine/engine_loop.py index 2584b18e00..caf44e65c5 100644 --- a/lmdeploy/pytorch/engine/engine_loop.py +++ b/lmdeploy/pytorch/engine/engine_loop.py @@ -258,6 +258,16 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'): seq.append_logits(logits) return dict() + engine_error_msg = getattr(batched_outputs, 'engine_error_msg', None) + if engine_error_msg: + for msg in running: + if msg.status != MessageStatus.RUNNING: + continue + response_reqs(self.req_manager, msg.resp, ResponseType.INTERNAL_ENGINE_ERROR, + data=dict(token_ids=[]), err_msg=engine_error_msg) + msg.state.finish() + return dict() + new_token_timestamp = batched_outputs.new_token_timestamp logprobs = batched_outputs.logprobs diff --git a/lmdeploy/pytorch/engine/executor/base.py b/lmdeploy/pytorch/engine/executor/base.py index 9e17a87ba2..a01da86f41 100644 --- a/lmdeploy/pytorch/engine/executor/base.py +++ b/lmdeploy/pytorch/engine/executor/base.py @@ -78,7 +78,7 @@ async def sleep(self, level: int = 1): """Sleep.""" raise NotImplementedError('Not Implemented.') - def wakeup(self, tags: list[str] | None = None): + async def wakeup(self, tags: list[str] | None = None): """Wakeup.""" raise NotImplementedError('Not Implemented.') diff --git a/lmdeploy/pytorch/engine/executor/base_worker.py b/lmdeploy/pytorch/engine/executor/base_worker.py index d78ab9867f..0812ec9302 100644 --- a/lmdeploy/pytorch/engine/executor/base_worker.py +++ b/lmdeploy/pytorch/engine/executor/base_worker.py @@ -123,7 +123,7 @@ async def sleep(self, level: int = 1): """Sleep.""" await self.model_agent.sleep(level) - def wakeup(self, tags: list[str] | None = None): + async def wakeup(self, tags: list[str] | None = None): """Wakeup.""" self.model_agent.wakeup(tags) diff --git a/lmdeploy/pytorch/engine/executor/mp_executor.py b/lmdeploy/pytorch/engine/executor/mp_executor.py index 9f457eec3c..e41d99483f 100644 --- a/lmdeploy/pytorch/engine/executor/mp_executor.py +++ b/lmdeploy/pytorch/engine/executor/mp_executor.py @@ -373,6 +373,14 @@ def warmup(self): """Build cache engine.""" self.collective_rpc('warmup') + async def sleep(self, level: int = 1): + """Sleep.""" + await self.collective_rpc_async('sleep', args=(level, )) + + async def wakeup(self, tags: list[str] | None = None): + """Wakeup.""" + await self.collective_rpc_async('wakeup', args=(tags, )) + async def _prefetch_outputs(self): while True: out = (await self.collective_rpc_async('get_outputs', receiver_mask=1, return_mask=1))[0] diff --git a/lmdeploy/pytorch/engine/executor/ray_executor.py b/lmdeploy/pytorch/engine/executor/ray_executor.py index d585e78c15..63f1cf5f20 100644 --- a/lmdeploy/pytorch/engine/executor/ray_executor.py +++ b/lmdeploy/pytorch/engine/executor/ray_executor.py @@ -353,12 +353,18 @@ def warmup(self): """Build cache engine.""" self.collective_rpc('warmup') - def sleep(self, level: int = 1): + async def sleep(self, level: int = 1): """Sleep.""" + await asyncio.to_thread(self._sleep_collective_rpc, level) + + def _sleep_collective_rpc(self, level: int): self.collective_rpc('sleep', (level, )) - def wakeup(self, tags: list[str] | None = None): + async def wakeup(self, tags: list[str] | None = None): """Wakeup.""" + await asyncio.to_thread(self._wakeup_collective_rpc, tags) + + def _wakeup_collective_rpc(self, tags: list[str] | None): if tags is None or 'kv_cache' in tags: self.update_configs() self.collective_rpc('wakeup', (tags, )) diff --git a/lmdeploy/pytorch/engine/executor/uni_executor.py b/lmdeploy/pytorch/engine/executor/uni_executor.py index 34c7412ee6..84d5e50cf3 100644 --- a/lmdeploy/pytorch/engine/executor/uni_executor.py +++ b/lmdeploy/pytorch/engine/executor/uni_executor.py @@ -108,6 +108,14 @@ async def get_output_async(self, dp_rank: int = 0): assert dp_rank == 0 return await self.model_agent.get_output_async() + async def sleep(self, level: int = 1): + """Sleep.""" + await self.model_agent.sleep(level) + + async def wakeup(self, tags: list[str] | None = None): + """Wakeup on the event-loop thread (CUDA-safe; may block the loop).""" + self.model_agent.wakeup(tags) + def get_input_processor(self): """Get input processor.""" return self.model_agent.get_input_processor() diff --git a/lmdeploy/pytorch/engine/model_agent/__init__.py b/lmdeploy/pytorch/engine/model_agent/__init__.py index 083e6a1fe4..85bc568606 100644 --- a/lmdeploy/pytorch/engine/model_agent/__init__.py +++ b/lmdeploy/pytorch/engine/model_agent/__init__.py @@ -4,7 +4,7 @@ from lmdeploy.pytorch.devices import DeviceContext, get_device_manager from lmdeploy.pytorch.distributed import DistContext, get_dist_manager -from .agent import BaseModelAgent, BatchedOutputs # noqa: F401 +from .agent import BaseModelAgent, BatchedOutputs, CacheNotReadyError # noqa: F401 def build_model_agent( diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py index 07353c1dcc..91fc059fc7 100644 --- a/lmdeploy/pytorch/engine/model_agent/agent.py +++ b/lmdeploy/pytorch/engine/model_agent/agent.py @@ -38,6 +38,12 @@ logger = get_logger('lmdeploy') +class CacheNotReadyError(RuntimeError): + """Raised when a forward runs while KV/state cache engines are missing.""" + + pass + + @dataclass class SleepWakeupState: to_sleep: asyncio.Event = field(default_factory=asyncio.Event) @@ -82,6 +88,7 @@ class BatchedOutputs: new_token_timestamp: int = 0 extra_outputs: ExtraOutputs | None = None all_routed_experts: torch.Tensor | None = None + engine_error_msg: str | None = None def to_cpu(self): """To cpu.""" @@ -128,11 +135,15 @@ def msg_with_rank(rank: int, msg: str): return f'rank[{rank}] - {msg}' -def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict): +def cache_swapping(cache_engine: CacheEngine | None, swap_in_map: dict, swap_out_map: dict): """Perform cache swapping.""" issued_cache_op = False swap_in_map = swap_in_map or dict() swap_out_map = swap_out_map or dict() + if cache_engine is None and (len(swap_in_map) > 0 or len(swap_out_map) > 0): + raise CacheNotReadyError( + 'KV cache is not available; cannot swap blocks. ' + "Restore cache via wakeup with the 'kv_cache' tag before inference.") if len(swap_in_map) > 0: cache_engine.swap_in(swap_in_map) issued_cache_op = True @@ -568,6 +579,27 @@ def _push_output(self, output: BatchedOutputs): event.record() self._out_que.put_nowait((output, event)) + def _batched_outputs_for_cache_error(self, forward_inputs: dict, err_msg: str) -> BatchedOutputs: + """Build a minimal batch for get_output_async pairing.""" + inputs = forward_inputs.get('inputs') + delta = forward_inputs.get('delta') + device = self.device + if inputs is not None: + batch_size = int(inputs.seq_length.size(0)) + elif delta is not None: + batch_size = int(delta.block_offsets.size(0)) + else: + batch_size = 1 + next_token_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=device) + stopped = torch.ones(batch_size, dtype=torch.bool, device=device) + stop_pos = torch.zeros(batch_size, dtype=torch.long, device=device) + return BatchedOutputs( + next_token_ids=next_token_ids, + stopped=stopped, + stop_pos=stop_pos, + engine_error_msg=err_msg, + ) + @contextmanager def _broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, enable: bool = True): if not enable: @@ -644,6 +676,10 @@ def _get_inputs_from_delta( sampling_inputs: SamplingInputs, ): """Get inputs from delta.""" + if self.step_inputs.model_inputs is None: + raise CacheNotReadyError( + 'Decode step has no cached ModelInputs (e.g. after a KV-cache error or reset). ' + "Call wakeup with the 'kv_cache' tag before continuing inference.") self.step_inputs.update_delta(delta, self) inputs = self.step_inputs.model_inputs extra_inputs = self.step_inputs.extra_inputs @@ -661,6 +697,10 @@ def _prepare_inputs_prefill( if delta is not None: # update decoding inputs with delta # for second round chat + if self.step_inputs.model_inputs is None: + raise CacheNotReadyError( + 'Prefill with delta requires prior ModelInputs (e.g. after a KV-cache error). ' + "Call wakeup with the 'kv_cache' tag before continuing inference.") self.step_inputs.update_delta(delta, self) if inputs.is_first_chunk: @@ -836,6 +876,11 @@ def __update_inputs( await asyncio.sleep(0.01) return + if self.cache_engine is None or self.state_cache_engine is None: + raise CacheNotReadyError( + 'KV or state cache engine is not built (e.g. after sleep or partial wakeup). ' + "Call wakeup with the 'kv_cache' tag before running inference.") + # swap caches cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map) @@ -936,7 +981,12 @@ async def _async_loop_background(self, forward_event: asyncio.Event = None): while True: forward_inputs = await input_maker.get() - await self._async_step(**forward_inputs, ) + try: + await self._async_step(**forward_inputs, ) + except CacheNotReadyError as err: + logger.warning('Forward skipped: %s', err) + self.step_inputs = StepInputs() + self._push_output(self._batched_outputs_for_cache_error(forward_inputs, str(err))) if forward_event is not None: forward_event.set() diff --git a/lmdeploy/pytorch/engine/mp_engine/base.py b/lmdeploy/pytorch/engine/mp_engine/base.py index a5c16dd967..d98131b280 100644 --- a/lmdeploy/pytorch/engine/mp_engine/base.py +++ b/lmdeploy/pytorch/engine/mp_engine/base.py @@ -53,13 +53,13 @@ def end_session(self, session_id: int): """End session.""" return self._collective_rpc('end_session', session_id) - def sleep(self, level: int): + async def sleep(self, level: int): """sleep.""" - return self._collective_rpc('sleep', level) + return await self._collective_rpc_async('sleep', level) - def wakeup(self, tags: list[str] | None = None): + async def wakeup(self, tags: list[str] | None = None): """Wakeup.""" - return self._collective_rpc('wakeup', tags) + return await self._collective_rpc_async('wakeup', tags) def update_params(self, request: Any): """Update params.""" diff --git a/lmdeploy/pytorch/engine/mp_engine/base_worker.py b/lmdeploy/pytorch/engine/mp_engine/base_worker.py index 0e0fa0fa82..cd2816f92c 100644 --- a/lmdeploy/pytorch/engine/mp_engine/base_worker.py +++ b/lmdeploy/pytorch/engine/mp_engine/base_worker.py @@ -100,13 +100,13 @@ def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest): """ return self.engine.p2p_drop_connect(drop_conn_request) - def sleep(self, level: int = 1): + async def sleep(self, level: int = 1): """sleep.""" - return self.engine.sleep(level) + return await self.engine.sleep(level) - def wakeup(self, tags: list[str] | None = None): + async def wakeup(self, tags: list[str] | None = None): """Wakeup.""" - return self.engine.wakeup(tags) + return await self.engine.wakeup(tags) def update_params(self, request: Any): """Update params.""" diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py index 60807bcc77..13d0cc9477 100644 --- a/lmdeploy/serve/core/async_engine.py +++ b/lmdeploy/serve/core/async_engine.py @@ -46,7 +46,7 @@ class GenOut: history_token_len: int input_token_len: int generate_token_len: int - finish_reason: Literal['stop', 'length', 'error'] | None = None + finish_reason: Literal['stop', 'length', 'error', 'abort'] | None = None token_ids: list[int] | None = None logprobs: list[dict[int, float]] | None = None logits: Any = None @@ -201,6 +201,23 @@ def _build_stat_loggers(self): # set stats loggers of metrics processor metrics_processor.stat_loggers = self.stat_loggers + def _if_session_stale(self, session: Session, + input_token_len: int) -> GenOut | None: + """If ``session.epoch`` was stamped by api_server and + ``stop_all_session`` ran since then (the engine epoch changed), drop + the session.""" + epoch = session.epoch + if epoch is None or epoch == self.epoch: + return None + logger.info( + f'[generate] session {session.session_id} dropped (session.epoch={epoch}, epoch={self.epoch})') + return GenOut(response='', + history_token_len=session.step, + input_token_len=input_token_len, + generate_token_len=0, + finish_reason='abort', + token_ids=[]) + def get_schedule_metrics(self): return self.engine.get_schedule_metrics() @@ -212,11 +229,11 @@ async def do_log_stats(self): async def stop_all_session(self): """Stop all running sessions.""" - logger.info('stop all sessions') + logger.info(f'stop all sessions, epoch {self.epoch} -> {self.epoch + 1}') self.epoch += 1 await self.session_mgr.async_abort_all() - def sleep(self, level: int = 1): + async def sleep(self, level: int = 1): """Sleep the model. Args: @@ -224,11 +241,13 @@ def sleep(self, level: int = 1): weights and discard the kv cache. Level 2 sleep will discard both the model weights and the kv cache. """ - self.engine.sleep(level) + logger.info(f'[async_engine]sleep, level={level}') + await self.engine.sleep(level) self.sleeping_tags = {'weights', 'kv_cache'} self.is_sleeping = True + logger.info('[async_engine] sleep, done') - def wakeup(self, tags: list[str] | None = None): + async def wakeup(self, tags: list[str] | None = None): """Wake up the model. Args: @@ -242,7 +261,7 @@ def wakeup(self, tags: list[str] | None = None): if any(tag not in self.sleeping_tags for tag in tags): logger.warning(f'some tag in {tags} not in sleeping tags {self.sleeping_tags}') return - self.engine.wakeup(tags) + await self.engine.wakeup(tags) # for TM backend, sleep/wakeup will reset gateway, therefore we need to rebuild instances if self.backend == 'turbomind' and 'kv_cache' in tags: self.session_mgr.build_request_handle_pool(self.engine, self.backend_config.max_batch_size) @@ -339,7 +358,8 @@ async def generate( do_preprocess (bool): whether pre-process the messages. Default to True, which means chat_template will be applied. """ - epoch = self.epoch + metrics_processor.increase_total_requests() + if (messages is not None) ^ (input_ids is None): raise ValueError('You must specify exactly one of messages or input_ids') if isinstance(session_id, Session): @@ -386,6 +406,7 @@ async def generate( if gen_config.max_new_tokens == 0: logger.info(f'run out of tokens. session={session_id}.') + metrics_processor.increase_failed_requests('error') yield GenOut(response='', history_token_len=session.step, input_token_len=len(input_ids), @@ -400,6 +421,7 @@ async def generate( or gen_config.output_logits == 'all'): errmsg = ('lmdeploy does not support outputting all token\'s logits or last_hidden_state ' 'when prefix caching is ON') + metrics_processor.increase_failed_requests('error') yield GenOut(response=errmsg, history_token_len=session.step, input_token_len=len(input_ids), @@ -421,10 +443,18 @@ def is_error(status): if not gen_config.ignore_eos: stop_ids = gen_config.stop_token_ids or [] - metrics_processor.increase_total_requests() + + stale = self._if_session_stale(session, len(prompt_input['input_ids'])) + if stale is not None: + metrics_processor.increase_failed_requests('abort') + yield stale + if sequence_end: + self.session_mgr.remove(session) + return async with session.request_handle() as handle: - if epoch != self.epoch: - logger.info(f'[generate] session {session_id} got aborted before starting inference') + if session.epoch is not None and session.epoch != self.epoch: + logger.info(f'[generate] session {session_id} got aborted before starting inference, ' + f'session.epoch={session.epoch}, epoch={self.epoch}') metrics_processor.increase_failed_requests('abort') yield GenOut(response='', history_token_len=0, diff --git a/lmdeploy/serve/managers/session_manager.py b/lmdeploy/serve/managers/session_manager.py index 0ac7e1465f..c3aa0f66ef 100644 --- a/lmdeploy/serve/managers/session_manager.py +++ b/lmdeploy/serve/managers/session_manager.py @@ -24,6 +24,9 @@ def __init__(self, session_id: int, session_mgr: SessionManager, **kwargs): self.history: list[tuple[Any, str]] = [] self.gen_config: GenerationConfig | None = None self.step: int = 0 + # Set by api_server to AsyncEngine.epoch when a request binds a session; + # generate() drops work if stop_all_session() bumped epoch after bind. + self.epoch: int | None = None # event to wait for the session to be active self._active: asyncio.Event | None = None self._handle = None # inference instance @@ -64,6 +67,7 @@ def reset(self): self.history = [] self.gen_config = None self.step = 0 + self.epoch = None self._active = None self._handle = None self._session_mgr = None @@ -101,7 +105,7 @@ async def request_handle(self): async def async_abort(self): """Abort the session.""" - logger.info(f'[session] Aborting session {self.session_id}') + logger.info(f'[session] Aborting session {self.session_id}, epoch={self.epoch}') if self._handle is not None: await self._handle.async_cancel(self.session_id) @@ -205,7 +209,7 @@ def get(self, session_id: int | None = None, **kwargs) -> Session: session.update(**kwargs) return session else: - logger.info(f'[SessionManager] session {session_id} not found. Creating...') + logger.debug(f'[SessionManager] session {session_id} not found. Creating...') session = Session(session_id, self, **kwargs) self.sessions[session_id] = session return session diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 5e84ee4221..09ac5d2324 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from __future__ import annotations + # yapf: disable import asyncio import copy @@ -10,7 +12,7 @@ from contextlib import asynccontextmanager from functools import partial from http import HTTPStatus -from typing import Literal +from typing import TYPE_CHECKING, Literal import uvicorn from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status @@ -76,10 +78,13 @@ ) from lmdeploy.serve.openai.reasoning_parser.reasoning_parser import ReasoningParser, ReasoningParserManager from lmdeploy.serve.openai.tool_parser.tool_parser import ToolParser, ToolParserManager -from lmdeploy.serve.utils.server_utils import validate_json_request +from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware, EngineSleepingMiddleware, validate_json_request from lmdeploy.tokenizer import DetokenizeState, Tokenizer from lmdeploy.utils import get_logger +if TYPE_CHECKING: + from lmdeploy.serve.managers import Session + # yapf: enable logger = get_logger('lmdeploy') @@ -100,12 +105,15 @@ class VariableInterface: enable_abort_handling: bool = False @staticmethod - def get_session(session_id: int) -> int: + def get_session(session_id: int) -> Session: session_mgr = VariableInterface.get_session_manager() if session_id == -1: - return session_mgr.get() + session = session_mgr.get() else: - return session_mgr.get(session_id) + session = session_mgr.get(session_id) + # Stamp epoch for ``stop_all_session`` / ``abort_all`` coordination in ``AsyncEngine.generate``. + session.epoch = VariableInterface.async_engine.epoch + return session @staticmethod def get_session_manager(): @@ -152,6 +160,19 @@ def create_error_response(status: HTTPStatus, message: str, error_type='invalid_ status_code=status.value) +def reject_if_engine_sleeping() -> JSONResponse | None: + """Return an error response when the engine is in sleep mode (see POST + /sleep, /wakeup).""" + eng = VariableInterface.async_engine + if eng is None or not eng.is_sleeping: + return None + return create_error_response( + HTTPStatus.SERVICE_UNAVAILABLE, + 'Engine is sleeping; call POST /wakeup before inference (e.g. tags=weights&tags=kv_cache).', + error_type='engine_sleeping', + ) + + def check_request(request) -> JSONResponse | None: """Check if a request is valid.""" if hasattr(request, 'model') and request.model not in get_model_list(): @@ -409,6 +430,9 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque error_check_ret = check_request(request) if error_check_ret is not None: return error_check_ret + sleeping_ret = reject_if_engine_sleeping() + if sleeping_ret is not None: + return sleeping_ret if VariableInterface.tool_parser is not None: request = VariableInterface.tool_parser.adjust_request(request) session = VariableInterface.get_session(request.session_id) @@ -769,7 +793,6 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None error_check_ret = check_request(request) if error_check_ret is not None: return error_check_ret - json_request = await raw_request.json() migration_request = json_request.pop('migration_request', None) with_cache = json_request.pop('with_cache', False) @@ -1175,7 +1198,9 @@ def update_params(request: UpdateParamsRequest, raw_request: Request = None): @router.post('/sleep', dependencies=[Depends(validate_json_request)]) async def sleep(raw_request: Request = None): level = raw_request.query_params.get('level', '1') - VariableInterface.async_engine.sleep(int(level)) + async_engine = VariableInterface.async_engine + await async_engine.stop_all_session() + await async_engine.sleep(int(level)) return Response(status_code=200) @@ -1183,7 +1208,7 @@ async def sleep(raw_request: Request = None): async def wakeup(raw_request: Request = None): tags = raw_request.query_params.getlist('tags') tags = tags or None - VariableInterface.async_engine.wakeup(tags) + await VariableInterface.async_engine.wakeup(tags) return Response(status_code=200) @@ -1526,10 +1551,13 @@ def serve(model_path: str, ) if api_keys is not None and (tokens := [key for key in api_keys if key]): - from lmdeploy.serve.utils.server_utils import AuthenticationMiddleware - app.add_middleware(AuthenticationMiddleware, tokens=tokens) + def is_engine_sleeping() -> bool: + eng = VariableInterface.async_engine + return eng is not None and eng.is_sleeping + app.add_middleware(EngineSleepingMiddleware, is_sleeping=is_engine_sleeping) + # set the maximum number of concurrent requests if max_concurrent_requests is not None: app.add_middleware(ConcurrencyLimitMiddleware, max_concurrent_requests=max_concurrent_requests) diff --git a/lmdeploy/serve/utils/server_utils.py b/lmdeploy/serve/utils/server_utils.py index f7bcbdfa49..f032b2bc14 100644 --- a/lmdeploy/serve/utils/server_utils.py +++ b/lmdeploy/serve/utils/server_utils.py @@ -2,7 +2,8 @@ # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/server_utils.py import hashlib import secrets -from collections.abc import Awaitable +from collections.abc import Awaitable, Callable +from http import HTTPStatus from fastapi import Request from fastapi.exceptions import RequestValidationError @@ -18,6 +19,58 @@ def validate_json_request(raw_request: Request): raise RequestValidationError(errors=["Unsupported Media Type: Only 'application/json' is allowed"]) +class EngineSleepingMiddleware: + """Pure ASGI middleware that returns 503 for configured inference routes + when ``is_sleeping()`` is true (after ``POST /sleep``, until ``POST + /wakeup``). + + Notes + ----- + - Skips non-http scopes (except ``http``/``websocket`` are passed through + to the app; only ``http`` requests are gated). + - HTTP ``OPTIONS`` is passed through so CORS preflight is unaffected. + """ + + # POST routes rejected while sleeping (see POST /sleep, /wakeup). + DEFAULT_PROTECTED_INFERENCE_ROUTES = frozenset({ + ('POST', '/v1/chat/completions'), + ('POST', '/v1/completions'), + ('POST', '/generate'), + }) + + def __init__( + self, + app: ASGIApp, + is_sleeping: Callable[[], bool], + protected_routes: frozenset[tuple[str, str]] | None = None, + ) -> None: + self.app = app + self.is_sleeping = is_sleeping + self.protected_routes = protected_routes or type(self).DEFAULT_PROTECTED_INFERENCE_ROUTES + + def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: + if scope['type'] not in ('http', 'websocket'): + return self.app(scope, receive, send) + if scope['type'] == 'http' and scope['method'] == 'OPTIONS': + return self.app(scope, receive, send) + if scope['type'] == 'http': + root_path = scope.get('root_path', '') + url_path = URL(scope=scope).path.removeprefix(root_path) + key = (scope['method'], url_path) + if key in self.protected_routes and self.is_sleeping(): + response = JSONResponse( + content={ + 'error': ( + 'Engine is sleeping; call POST /wakeup before inference ' + '(e.g. tags=weights&tags=kv_cache).' + ), + }, + status_code=HTTPStatus.SERVICE_UNAVAILABLE, + ) + return response(scope, receive, send) + return self.app(scope, receive, send) + + class AuthenticationMiddleware: """Pure ASGI middleware that authenticates each request by checking if the Authorization Bearer token exists and equals anyof "{api_key}". diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index f95b2b93ca..527444e79f 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -285,20 +285,28 @@ def _from_hf(self, model_path: str, engine_config: TurbomindEngineConfig): self._tm_model = tm_model return model_comm - def sleep(self, level: int = 1): - """Sleep the model.""" + def _sleep_sync(self, level: int = 1): + """Synchronous sleep implementation (runs in worker thread).""" with ThreadPoolExecutor(max_workers=self.gpu_count) as e: for _ in e.map(self.model_comm.sleep, range(self.gpu_count), [level] * self.gpu_count): pass - def wakeup(self, tags: list[str] | None = None): - """Wakeup the model.""" + async def sleep(self, level: int = 1): + """Sleep the model.""" + await asyncio.to_thread(self._sleep_sync, level) + + def _wakeup_sync(self, tags: list[str] | None = None): + """Synchronous wakeup implementation (runs in worker thread).""" if tags is None: tags = ['weights', 'kv_cache'] with ThreadPoolExecutor(max_workers=self.gpu_count) as e: for _ in e.map(self.model_comm.wakeup, range(self.gpu_count), [tags] * self.gpu_count): pass + async def wakeup(self, tags: list[str] | None = None): + """Wakeup the model.""" + await asyncio.to_thread(self._wakeup_sync, tags) + def update_params(self, request: UpdateParamsRequest): """Update params.