diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py
index a259fbdd90..c64abb9d39 100644
--- a/lmdeploy/serve/core/async_engine.py
+++ b/lmdeploy/serve/core/async_engine.py
@@ -235,11 +235,7 @@ async def stop_all_session(self):
         logger.info(f'stop all sessions, epoch {self.epoch} -> {self.epoch + 1}')
         self.epoch += 1
         await self.session_mgr.async_abort_all()
-
-    def prepare_sleep(self):
-        """Reject new inference requests before backend sleep starts."""
-        self.sleeping_tags = {'weights', 'kv_cache'}
-        self.is_sleeping = True
+        logger.info('stopped all sessions')
 
     async def sleep(self, level: int = 1):
         """Sleep the model.
@@ -249,9 +245,10 @@ async def sleep(self, level: int = 1):
                 weights and discard the kv cache. Level 2 sleep will
                 discard both the model weights and the kv cache.
         """
-        await self.engine.sleep(level)
-        self.sleeping_tags = {'weights', 'kv_cache'}
         self.is_sleeping = True
+        self.sleeping_tags = {'weights', 'kv_cache'}
+        await self.stop_all_session()
+        await self.engine.sleep(level)
 
     def wakeup(self, tags: list[str] | None = None):
         """Wake up the model.
diff --git a/lmdeploy/serve/managers/session_manager.py b/lmdeploy/serve/managers/session_manager.py
index 685631091f..0dfa20e5aa 100644
--- a/lmdeploy/serve/managers/session_manager.py
+++ b/lmdeploy/serve/managers/session_manager.py
@@ -196,13 +196,34 @@ def __init__(self):
         """Initialize the session manager."""
 
         self.sessions = {}
-        self.session_id_generator = itertools.count(1)
+        self.session_id_generator = itertools.count(0)
         self.request_handle_pool = None
         self.loop = None
+        # user_session_id->session_id. If user specifies
+        # a session_id when visiting the api_server's endpoint,
+        # we map the user_session_id to the session_id to keep
+        # session's id globally identical across different requests.
+        self.user_session_id_map = {}
+        # session_id->user_session_id map.
+        self.session_id_map = {}
+
+    def map_user_session_id(self, user_session_id: int) -> int:
+        """Map a user_session_id to a session_id."""
+        if user_session_id in self.user_session_id_map:
+            raise ValueError(f'User session id {user_session_id} already exists')
+        session_id = next(self.session_id_generator)
+        self.user_session_id_map[user_session_id] = session_id
+        self.session_id_map[session_id] = user_session_id
+        return session_id
+
+    def get(self, session_id: int | None = None, create_if_not_exists: bool = True, **kwargs) -> Session | None:
+        """Get or create a session."""
+        if not create_if_not_exists:
+            return self.sessions.get(session_id, None)
+
+        if session_id is None:
+            session_id = next(self.session_id_generator)
 
-    def get(self, session_id: int | None = None, **kwargs) -> Session:
-        """Create a new session."""
-        session_id = session_id or next(self.session_id_generator)
         if session_id in self.sessions:
             logger.debug(f'[SessionManager] session {session_id} already exists. Updating...')
             session = self.sessions[session_id]
@@ -224,17 +245,24 @@ async def async_abort_all(self):
         # "abort all" is designed for async RL. The aborted sessions will be no longer used,
         # so we clear the sessions here.
         self.sessions.clear()
+        self.user_session_id_map.clear()
+        self.session_id_map.clear()
 
     def has(self, session_id):
         return session_id in self.sessions
 
     def remove(self, session: Session):
         self.sessions.pop(session.session_id, None)
+        user_session_id = self.session_id_map.pop(session.session_id, None)
+        if user_session_id is not None:
+            self.user_session_id_map.pop(user_session_id, None)
 
     def clear(self):
         self.sessions.clear()
+        self.user_session_id_map.clear()
+        self.session_id_map.clear()
         # reset the session id generator
-        self.session_id_generator = itertools.count(1)
+        self.session_id_generator = itertools.count(0)
 
     def attach_event_loop(self, loop):
         self.loop = loop
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index 4e93acc9f3..5f4b83ffcd 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -104,24 +104,44 @@ class VariableInterface:
     allow_terminate_by_client: bool = False
     enable_abort_handling: bool = False
 
-    @staticmethod
-    def get_session(session_id: int) -> Session:
-        session_mgr = VariableInterface.get_session_manager()
-        if session_id == -1:
+    @classmethod
+    def create_session(cls, user_session_id: int | None = None) -> Session:
+        """Create a session for an incoming request.
+
+        Raises ValueError if ``user_session_id`` is already registered.
+        """
+        session_mgr = cls.get_session_manager()
+        if user_session_id is None or user_session_id == -1:
+            # user doesn't input session_id, so we need to generate a new one
             session = session_mgr.get()
         else:
+            # register user_session_id and allocate the inner session_id for it;
+            # ``map_user_session_id`` raises ValueError if the id is already mapped
+            session_id = session_mgr.map_user_session_id(user_session_id)
             session = session_mgr.get(session_id)
         # Stamp epoch for ``stop_all_session`` / ``abort_all`` coordination in ``AsyncEngine.generate``.
-        session.epoch = VariableInterface.async_engine.epoch
+        session.epoch = cls.async_engine.epoch
         return session
 
-    @staticmethod
-    def get_session_manager():
-        return VariableInterface.async_engine.session_mgr
+    @classmethod
+    def find_session(cls, user_session_id: int) -> Session | None:
+        """Find the session by user_session_id.
+
+        Users cannot access inner session_id directly.
+        """
+        session_mgr = cls.get_session_manager()
+        session_id = session_mgr.user_session_id_map.get(user_session_id, None)
+        if session_id is None:
+            return None
+        return session_mgr.get(session_id, create_if_not_exists=False)
 
-    @staticmethod
-    def get_engine_config():
-        return VariableInterface.async_engine.backend_config
+    @classmethod
+    def get_session_manager(cls):
+        return cls.async_engine.session_mgr
+
+    @classmethod
+    def get_engine_config(cls):
+        return cls.async_engine.backend_config
 
 
 router = APIRouter()
@@ -419,7 +439,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
         return error_check_ret
     if VariableInterface.tool_parser is not None:
         request = VariableInterface.tool_parser.adjust_request(request)
-    session = VariableInterface.get_session(request.session_id)
+    session = VariableInterface.create_session(request.session_id)
 
     json_request = await raw_request.json()
     migration_request = json_request.pop('migration_request', None)
@@ -793,10 +813,10 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
     sessions = []
     if isinstance(request.prompt, str):
         request.prompt = [request.prompt]
-        sessions.append(VariableInterface.get_session(request.session_id))
+        sessions.append(VariableInterface.create_session(request.session_id))
     elif isinstance(request.prompt, list):
-        for i in range(len(request.prompt)):
-            sessions.append(VariableInterface.get_session(i + 1))
+        for _ in request.prompt:
+            sessions.append(VariableInterface.create_session())
     if isinstance(request.stop, str):
         request.stop = [request.stop]
     random_seed = request.seed if request.seed else None
@@ -971,7 +991,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
     if error_check_ret is not None:
         return error_check_ret
 
-    session = VariableInterface.get_session(request.session_id)
+    session = VariableInterface.create_session(request.session_id)
 
     prompt = request.prompt
     input_ids = request.input_ids
@@ -1190,8 +1210,6 @@ async def sleep(raw_request: Request = None):
     if level not in (1, 2):
         return create_error_response(HTTPStatus.BAD_REQUEST, 'The "level" query parameter must be 1 or 2.')
     async_engine = VariableInterface.async_engine
-    async_engine.prepare_sleep()
-    await async_engine.stop_all_session()
     await async_engine.sleep(level)
     return Response(status_code=200)
 
@@ -1265,8 +1283,12 @@ async def abort_request(request: AbortRequest, raw_request: Request = None):
     if request.abort_all:
         await VariableInterface.async_engine.stop_all_session()
     else:
-        session = VariableInterface.get_session(request.session_id)
+        session = VariableInterface.find_session(request.session_id)
+        if session is None:
+            return create_error_response(HTTPStatus.BAD_REQUEST, f'Session {request.session_id} not found.')
         await session.async_abort()
+        session_mgr = VariableInterface.get_session_manager()
+        session_mgr.remove(session)
     return Response(status_code=200)
 
 