Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lmdeploy/pytorch/engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ def create_instance(self, cuda_stream_id=0):

class EngineInstanceBase:

async def async_start_session(self, session_id: int):
    """Ensure the session exists before token streaming starts.

    Abstract hook: concrete engine instances override this to register
    ``session_id`` with the engine before the first streamed request.

    Args:
        session_id (int): identifier of the session to create or verify.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError('This method is not implemented.')

async def async_end(self, session_id: int):
    """End the given session.

    Abstract hook: concrete engine instances override this to release
    the engine-side resources associated with ``session_id``.

    Args:
        session_id (int): identifier of the session to terminate.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError('This method is not implemented.')
Expand Down
6 changes: 4 additions & 2 deletions lmdeploy/pytorch/engine/engine_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,10 @@ def _try_add_session(self, session_id: int):
"""
return try_add_session(self.req_sender, session_id)

async def async_start_session(self, session_id: int):
    """Ensure the session exists before request streaming.

    Delegates to ``_async_try_add_session``.
    # NOTE(review): the "try add" naming suggests this is idempotent for an
    # already-registered session — confirm against try_add_session's contract.

    Args:
        session_id (int): identifier of the session to create or verify.
    """
    await self._async_try_add_session(session_id)

async def async_stream_infer(self,
session_id: int,
input_ids: list[int],
Expand All @@ -188,8 +192,6 @@ async def async_stream_infer(self,
return
gen_config = gen_config or GenerationConfig()
sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
logger.debug(f'session[{session_id}] try add session.')
self.req_sender.send_async(RequestType.ADD_SESSION, dict(session_id=session_id, response=False))
msg = dict(
token_ids=input_ids,
session_id=session_id,
Expand Down
12 changes: 9 additions & 3 deletions lmdeploy/pytorch/engine/mp_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ def __init__(self, engine: MPEngine):
self.engine = engine
self.session_states = engine.session_states

async def async_start_session(self, session_id: int):
    """Ensure the session exists on all workers before streaming.

    Args:
        session_id (int): identifier of the session to create or verify.

    Returns:
        The collective RPC result from the workers.
    """
    # Lazily create local session state at startup.
    # Before API split, this side effect happened in async_stream_infer.
    state = self.session_states[session_id]
    ret = await self.engine._collective_rpc_async('instance_async_start_session', session_id)
    # Only mark the session as existing AFTER the RPC returns, so waiters
    # on is_exists never observe a session the workers haven't created yet.
    state.is_exists.set()
    return ret

async def async_end(self, session_id: int):
"""End the given session."""
if session_id not in self.session_states:
Expand All @@ -117,12 +126,9 @@ async def async_cancel(self, session_id: int):

async def async_stream_infer(self, session_id: int, *args, **kwargs):
    """Send a streaming inference request and yield results as they arrive.

    Args:
        session_id (int): identifier of the session the request belongs to.
        *args: positional arguments forwarded to the worker-side
            ``instance_async_stream_infer``.
        **kwargs: keyword arguments forwarded likewise; ``session_id`` and
            ``notify_add_msg`` are injected here.

    Yields:
        Streaming inference results produced by the workers.
    """
    state = self.session_states[session_id]
    # Forward the session id and request worker-side add-message notification.
    kwargs['session_id'] = session_id
    kwargs['notify_add_msg'] = True
    generator = self.engine._collective_rpc_streaming_async('instance_async_stream_infer', *args, **kwargs)
    # session should have been added
    state.is_exists.set()

    # Relay worker results to the caller as they stream in.
    async for result in generator:
        yield result
9 changes: 9 additions & 0 deletions lmdeploy/pytorch/engine/mp_engine/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ async def async_end(self, session_id: int):
async with self.instance() as instance:
return await instance.async_end(session_id)

async def async_start_session(self, session_id: int):
    """Ensure the session exists before streaming.

    Borrows an instance from the pool for the duration of the call and
    forwards the request to it.

    Args:
        session_id (int): identifier of the session to create or verify.

    Returns:
        Whatever the underlying instance's ``async_start_session`` returns.
    """
    async with self.instance() as instance:
        return await instance.async_start_session(session_id)

async def async_cancel(self, session_id: int):
"""Stop current streaming inference."""
async with self.instance() as instance:
Expand Down Expand Up @@ -120,6 +125,10 @@ async def instance_async_end(self, session_id: int):
"""End the given session."""
return await self.instance_pool.async_end(session_id)

async def instance_async_start_session(self, session_id: int):
    """RPC endpoint: ensure the session exists before streaming.

    Thin forwarder to the instance pool's ``async_start_session``.

    Args:
        session_id (int): identifier of the session to create or verify.
    """
    return await self.instance_pool.async_start_session(session_id)

async def instance_async_cancel(self, session_id: int):
    """RPC endpoint: stop the current streaming inference.

    Thin forwarder to the instance pool's ``async_cancel``.

    Args:
        session_id (int): identifier of the session whose streaming
            inference should be cancelled.
    """
    return await self.instance_pool.async_cancel(session_id)
Expand Down
36 changes: 19 additions & 17 deletions lmdeploy/serve/core/async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,17 +458,21 @@ def is_error(status):
self.session_mgr.remove(session)
return
async with session.request_handle() as handle:
if session.epoch is not None and session.epoch != self.epoch:
logger.info(f'[generate] session {session_id} got aborted before starting inference, '
f'session.epoch={session.epoch}, async_engine.epoch={self.epoch}')
metrics_processor.increase_failed_requests('abort')
yield GenOut(response='',
history_token_len=0,
input_token_len=len(input_ids),
generate_token_len=0,
finish_reason='abort',
token_ids=[])
return
# Serialize same-session lifecycle operations during startup only.
# Once request startup is complete, decode streaming remains lock-free.
async with session._lifecycle_lock:
if session.epoch is not None and session.epoch != self.epoch:
Comment on lines 460 to +464
Copy link

Copilot AI Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential deadlock due to lock ordering: generate() acquires session.request_handle() (which sets session._handle/_active) before taking session._lifecycle_lock. If another task calls session.async_close() between those steps, it can take _lifecycle_lock and then await _active, while generate() is blocked waiting for _lifecycle_lock, leaving _active unset forever. To avoid this, take _lifecycle_lock before entering request_handle(), or ensure async_close() never waits on _active while holding _lifecycle_lock.

Copilot uses AI. Check for mistakes.
logger.info(f'[generate] session {session_id} got aborted before starting inference, '
f'session.epoch={session.epoch}, async_engine.epoch={self.epoch}')
metrics_processor.increase_failed_requests('abort')
yield GenOut(response='',
history_token_len=0,
input_token_len=len(input_ids),
generate_token_len=0,
finish_reason='abort',
token_ids=[])
return
await handle.async_start_session(session.session_id)
token_ids = input_ids.copy()
Comment thread
lvhan028 marked this conversation as resolved.
history_len = session.step
input_len = len(input_ids)
Expand Down Expand Up @@ -600,12 +604,9 @@ def is_error(status):
# until the session is finished, i.e., session.request_handle() context exits.
await handle.async_end(session.session_id)
self.session_mgr.remove(session)
# if sequence_end:
# if self.backend == 'pytorch':
# # manually end pytorch session. session cannot be ended until session.request_handle()
# # context exits
# await session.async_close()
# self.session_mgr.remove(session)
# We cannot call end session after with session.request_handle(), because session.async_close()
# will try to request another free handle, which might be blocked if other sessions are waiting
# for the same request_handle_pool.

def start_loop(self, loop, use_async_api=False):
"""Start engine loop.
Expand Down Expand Up @@ -680,6 +681,7 @@ async def async_get_logits(self,

async def _proc(session, i):
async with session.request_handle() as handle:
await handle.async_start_session(session.session_id)
input_len = len(input_ids[i])
# TODO(lvhan): Fix the ugly code later on
max_new_tokens = 1 if self.backend == 'turbomind' else 0
Expand Down
39 changes: 25 additions & 14 deletions lmdeploy/serve/managers/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def __init__(self, session_id: int, session_mgr: SessionManager, **kwargs):
# Set by api_server to AsyncEngine.epoch when a request binds a session;
# generate() drops work if stop_all_session() bumped epoch after bind.
self.epoch: int | None = None
# Serialize per-session lifecycle operations (generate startup, abort, close)
# without affecting cross-session concurrency.
self._lifecycle_lock = asyncio.Lock()
# event to wait for the session to be active
self._active: asyncio.Event | None = None
self._handle = None # inference instance
Expand Down Expand Up @@ -105,23 +108,31 @@ async def request_handle(self):

async def async_abort(self):
"""Abort the session."""
logger.debug(f'[session] Aborting session {self.session_id}, epoch={self.epoch}')
if self._handle is not None:
await self._handle.async_cancel(self.session_id)
async with self._lifecycle_lock:
logger.debug(f'[session] Aborting session {self.session_id}, epoch={self.epoch}')
# Session already closed/reset; treat as a benign no-op.
if self._session_mgr is None:
return
if self._handle is not None:
await self._handle.async_cancel(self.session_id)

async def async_close(self):
"""End the session."""
logger.info(f'[session] Ending session {self.session_id}')
if self._handle is None and self.step == 0:
return
if self._handle is not None:
await self._active.wait()
async with self.request_handle() as handle:
try:
await handle.async_end(self.session_id)
except (Exception, asyncio.CancelledError, GeneratorExit) as e:
logger.exception(f'[async_close] exception caught: {e}')
self.reset()
async with self._lifecycle_lock:
logger.info(f'[session] Ending session {self.session_id}')
# Already closed/reset; keep end idempotent.
if self._session_mgr is None:
return
if self._handle is None and self.step == 0:
return
if self._handle is not None:
await self._active.wait()
async with self.request_handle() as handle:
try:
await handle.async_end(self.session_id)
except (Exception, asyncio.CancelledError, GeneratorExit) as e:
logger.exception(f'[async_close] exception caught: {e}')
self.reset()
Comment on lines +121 to +135
Copy link

Copilot AI Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Holding _lifecycle_lock while awaiting _active can deadlock with AsyncEngine.generate(): generate() enters request_handle() (sets _handle/_active) and then waits for _lifecycle_lock, while async_close() can acquire _lifecycle_lock first and then wait on _active, preventing generate() from ever progressing to release the handle. Consider releasing _lifecycle_lock before await self._active.wait() (e.g., capture the event under the lock, then await outside), or adjust lock acquisition order so generate() takes _lifecycle_lock before request_handle().

Suggested change
async with self._lifecycle_lock:
logger.info(f'[session] Ending session {self.session_id}')
# Already closed/reset; keep end idempotent.
if self._session_mgr is None:
return
if self._handle is None and self.step == 0:
return
if self._handle is not None:
await self._active.wait()
async with self.request_handle() as handle:
try:
await handle.async_end(self.session_id)
except (Exception, asyncio.CancelledError, GeneratorExit) as e:
logger.exception(f'[async_close] exception caught: {e}')
self.reset()
while True:
active = None
async with self._lifecycle_lock:
logger.info(f'[session] Ending session {self.session_id}')
# Already closed/reset; keep end idempotent.
if self._session_mgr is None:
return
if self._handle is None and self.step == 0:
return
if self._handle is not None:
active = self._active
else:
async with self.request_handle() as handle:
try:
await handle.async_end(self.session_id)
except (Exception, asyncio.CancelledError, GeneratorExit) as e:
logger.exception(f'[async_close] exception caught: {e}')
self.reset()
return
await active.wait()

Copilot uses AI. Check for mistakes.

def abort(self):
"""Abort the session in sync mode."""
Expand Down
6 changes: 5 additions & 1 deletion lmdeploy/turbomind/turbomind.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,10 @@ def prepare_inputs(self,
async def async_cancel(self, session_id: int | None = None):
    """Cancel the in-flight inference on the underlying model instance.

    Args:
        session_id (int | None): accepted for interface compatibility;
            the cancel is issued on the bound model instance regardless.
    """
    self.model_inst.cancel()

async def async_start_session(self, session_id: int):
    """No-op session start.

    TurboMind does not require an explicit add-session request; the method
    exists for interface parity with other engine instances — confirm
    against the shared engine-instance API.
    """
    # Intentionally does nothing and reports success with None.
    return None

def async_end_cb(self, fut: asyncio.Future, status: int):
"""Executing on engine's signaling thread."""
logger.info(f'[async_end_cb] session ended, status = {status}')
Expand Down Expand Up @@ -828,5 +832,5 @@ def _get_generation_config(self, cfg: GenerationConfig):
c.output_logprobs = cfg.logprobs
if cfg.random_seed is not None:
c.random_seed = cfg.random_seed
# print (c)

return c
Loading