Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions lmdeploy/serve/core/async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,7 @@ async def stop_all_session(self):
logger.info(f'stop all sessions, epoch {self.epoch} -> {self.epoch + 1}')
self.epoch += 1
await self.session_mgr.async_abort_all()

def prepare_sleep(self):
"""Reject new inference requests before backend sleep starts."""
self.sleeping_tags = {'weights', 'kv_cache'}
self.is_sleeping = True
logger.info('stopped all sessions')

async def sleep(self, level: int = 1):
"""Sleep the model.
Expand All @@ -249,9 +245,10 @@ async def sleep(self, level: int = 1):
weights and discard the kv cache. Level 2 sleep will
discard both the model weights and the kv cache.
"""
await self.engine.sleep(level)
self.sleeping_tags = {'weights', 'kv_cache'}
self.is_sleeping = True
self.sleeping_tags = {'weights', 'kv_cache'}
await self.stop_all_session()
await self.engine.sleep(level)

def wakeup(self, tags: list[str] | None = None):
"""Wake up the model.
Expand Down
38 changes: 33 additions & 5 deletions lmdeploy/serve/managers/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,34 @@ def __init__(self):
"""Initialize the session manager."""

self.sessions = {}
self.session_id_generator = itertools.count(1)
self.session_id_generator = itertools.count(0)
self.request_handle_pool = None
Comment thread
lvhan028 marked this conversation as resolved.
self.loop = None
# user_session_id->session_id. If user specifies
# a session_id when visiting the api_server's endpoint,
# we map the user_session_id to the session_id to keep
# session's id globally identical across different requests.
self.user_session_id_map = {}
# session_id->user_session_id map.
self.session_id_map = {}

def map_user_session_id(self, user_session_id: int) -> int:
    """Return the internal session_id mapped to ``user_session_id``.

    The lookup is idempotent. The first call for a given
    ``user_session_id`` draws a fresh id from ``session_id_generator`` and
    records it in both ``user_session_id_map`` and ``session_id_map``;
    later calls with the same ``user_session_id`` return that same id
    instead of raising, so a client may reuse its session id across
    requests.

    Args:
        user_session_id: The session id supplied by the API client.

    Returns:
        The internal session_id associated with ``user_session_id``.
    """
    # Compare against None explicitly: 0 is a valid internal id because the
    # generator starts counting at 0.
    session_id = self.user_session_id_map.get(user_session_id)
    if session_id is None:
        session_id = next(self.session_id_generator)
        self.user_session_id_map[user_session_id] = session_id
        self.session_id_map[session_id] = user_session_id
    return session_id
Comment on lines +210 to +217
Copy link

Copilot AI Apr 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

map_user_session_id() currently raises if user_session_id is already present. Given the API server expects the same user-provided session_id to be reusable across requests (and create_session() calls this on every request), this should be idempotent: return the existing mapped session_id when present, and only allocate a new internal session_id when absent. Otherwise, repeated calls with the same user_session_id will crash the request flow.

Copilot uses AI. Check for mistakes.

def get(self, session_id: int | None = None, create_if_not_exists: bool = True, **kwargs) -> Session | None:
"""Get or create a session."""
if not create_if_not_exists:
return self.sessions.get(session_id, None)
Comment thread
lvhan028 marked this conversation as resolved.

if session_id is None:
session_id = next(self.session_id_generator)

def get(self, session_id: int | None = None, **kwargs) -> Session:
"""Create a new session."""
session_id = session_id or next(self.session_id_generator)
if session_id in self.sessions:
logger.debug(f'[SessionManager] session {session_id} already exists. Updating...')
session = self.sessions[session_id]
Expand All @@ -224,17 +245,24 @@ async def async_abort_all(self):
# "abort all" is designed for async RL. The aborted sessions will be no longer used,
# so we clear the sessions here.
self.sessions.clear()
self.user_session_id_map.clear()
self.session_id_map.clear()

def has(self, session_id):
    """Return whether ``session_id`` is a key of ``self.sessions``."""
    registry = self.sessions
    return session_id in registry

def remove(self, session: Session):
    """Drop ``session`` and any user-session-id mapping that points to it.

    Missing entries are ignored: every lookup uses ``pop`` with a ``None``
    default, so removing an unknown session does not raise.
    """
    self.sessions.pop(session.session_id, None)
    # Unlink both directions of the user_session_id <-> session_id mapping.
    mapped_user_id = self.session_id_map.pop(session.session_id, None)
    if mapped_user_id is None:
        return
    self.user_session_id_map.pop(mapped_user_id, None)

def clear(self):
self.sessions.clear()
self.user_session_id_map.clear()
self.session_id_map.clear()
# reset the session id generator
self.session_id_generator = itertools.count(1)
self.session_id_generator = itertools.count(0)

def attach_event_loop(self, loop):
self.loop = loop
Expand Down
55 changes: 37 additions & 18 deletions lmdeploy/serve/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,24 +104,41 @@ class VariableInterface:
allow_terminate_by_client: bool = False
enable_abort_handling: bool = False

@staticmethod
def get_session(session_id: int) -> Session:
session_mgr = VariableInterface.get_session_manager()
if session_id == -1:
@classmethod
def create_session(cls, user_session_id: int | None = None) -> Session:
session_mgr = cls.get_session_manager()
if user_session_id is None or user_session_id == -1:
# user doesn't input session_id, so we need to generate a new one
session = session_mgr.get()
else:
# find the inside session_id by user_session_id, create a new one
# if it doesn't exist and update the user_session_id_map
session_id = session_mgr.map_user_session_id(user_session_id)
logger.info(f'created session {session_id} for user_session_id {user_session_id}')
Comment thread
lvhan028 marked this conversation as resolved.
Outdated
Copy link

Copilot AI Apr 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create_session() calls session_mgr.map_user_session_id(user_session_id) unconditionally. As implemented, map_user_session_id raises if the user_session_id already exists, so repeated requests (or interactive continuation) with the same session_id will fail with an unhandled ValueError (500) instead of reusing the existing session mapping. Make the mapping lookup idempotent (return the existing mapped session_id when present) or check user_session_id_map first and only create a new mapping when it doesn't exist.

Suggested change
session_id = session_mgr.map_user_session_id(user_session_id)
logger.info(f'created session {session_id} for user_session_id {user_session_id}')
session_id = session_mgr.user_session_id_map.get(user_session_id)
if session_id is None:
session_id = session_mgr.map_user_session_id(user_session_id)
logger.info(f'created session {session_id} for user_session_id {user_session_id}')

Copilot uses AI. Check for mistakes.
session = session_mgr.get(session_id)
# Stamp epoch for ``stop_all_session`` / ``abort_all`` coordination in ``AsyncEngine.generate``.
session.epoch = VariableInterface.async_engine.epoch
session.epoch = cls.async_engine.epoch
return session

@staticmethod
def get_session_manager():
return VariableInterface.async_engine.session_mgr
@classmethod
def find_session(cls, user_session_id: int) -> Session | None:
    """Look up the session registered for a client-supplied id.

    Callers pass the user-facing ``user_session_id`` only; it is translated
    to the internal session_id through the session manager's
    ``user_session_id_map``. The internal id is never accepted directly.

    Returns:
        ``None`` when ``user_session_id`` has no mapping; otherwise the
        result of the manager's ``get`` called with
        ``create_if_not_exists=False``.
    """
    manager = cls.get_session_manager()
    inner_id = manager.user_session_id_map.get(user_session_id)
    # Explicit None check: 0 is a legitimate internal session_id.
    return None if inner_id is None else manager.get(inner_id, create_if_not_exists=False)

@staticmethod
def get_engine_config():
return VariableInterface.async_engine.backend_config
@classmethod
def get_session_manager(cls):
    """Return the ``session_mgr`` attribute of the class-level ``async_engine``."""
    engine = cls.async_engine
    return engine.session_mgr

@classmethod
def get_engine_config(cls):
    """Return the ``backend_config`` attribute of the class-level ``async_engine``."""
    engine = cls.async_engine
    return engine.backend_config


router = APIRouter()
Expand Down Expand Up @@ -419,7 +436,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque
return error_check_ret
if VariableInterface.tool_parser is not None:
request = VariableInterface.tool_parser.adjust_request(request)
session = VariableInterface.get_session(request.session_id)
session = VariableInterface.create_session(request.session_id)
Copy link

Copilot AI Apr 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

chat_completions_v1 now passes a mapped/internal session via create_session(). Later in this handler the response id is derived from session.session_id, which will now be the internal id (not the user-provided session_id) and can effectively expose internal ids to clients. Consider switching response ids to use the user-provided session id (or a separate request UUID) so internal ids remain an implementation detail, and keep behavior consistent with /v1/completions.

Copilot uses AI. Check for mistakes.

json_request = await raw_request.json()
migration_request = json_request.pop('migration_request', None)
Expand Down Expand Up @@ -793,10 +810,10 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None
sessions = []
if isinstance(request.prompt, str):
request.prompt = [request.prompt]
sessions.append(VariableInterface.get_session(request.session_id))
sessions.append(VariableInterface.create_session(request.session_id))
elif isinstance(request.prompt, list):
for i in range(len(request.prompt)):
sessions.append(VariableInterface.get_session(i + 1))
sessions.append(VariableInterface.create_session(i + 1))
if isinstance(request.stop, str):
request.stop = [request.stop]
random_seed = request.seed if request.seed else None
Expand Down Expand Up @@ -971,7 +988,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
if error_check_ret is not None:
return error_check_ret

session = VariableInterface.get_session(request.session_id)
session = VariableInterface.create_session(request.session_id)

prompt = request.prompt
input_ids = request.input_ids
Expand Down Expand Up @@ -1190,8 +1207,6 @@ async def sleep(raw_request: Request = None):
if level not in (1, 2):
return create_error_response(HTTPStatus.BAD_REQUEST, 'The "level" query parameter must be 1 or 2.')
async_engine = VariableInterface.async_engine
async_engine.prepare_sleep()
await async_engine.stop_all_session()
await async_engine.sleep(level)
return Response(status_code=200)

Expand Down Expand Up @@ -1265,8 +1280,12 @@ async def abort_request(request: AbortRequest, raw_request: Request = None):
if request.abort_all:
await VariableInterface.async_engine.stop_all_session()
else:
session = VariableInterface.get_session(request.session_id)
session = VariableInterface.find_session(request.session_id)
if session is None:
return create_error_response(HTTPStatus.BAD_REQUEST, f'Session {request.session_id} not found.')
await session.async_abort()
session_mgr = VariableInterface.get_session_manager()
session_mgr.remove(session)
Comment thread
lvhan028 marked this conversation as resolved.
return Response(status_code=200)


Expand Down
Loading