From 8856f3667a9c90cd2bd569f194f9122b9628822c Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Wed, 22 Apr 2026 05:28:17 +0000 Subject: [PATCH 1/3] refactor: split session startup from streaming to enforce add-cancel-end ordering --- lmdeploy/pytorch/engine/base.py | 4 ++ lmdeploy/pytorch/engine/engine_instance.py | 6 ++- lmdeploy/pytorch/engine/mp_engine/base.py | 12 ++++-- .../pytorch/engine/mp_engine/base_worker.py | 9 +++++ lmdeploy/serve/core/async_engine.py | 35 +++++++++-------- lmdeploy/serve/managers/session_manager.py | 39 ++++++++++++------- lmdeploy/turbomind/turbomind.py | 31 ++++++++++++++- 7 files changed, 98 insertions(+), 38 deletions(-) diff --git a/lmdeploy/pytorch/engine/base.py b/lmdeploy/pytorch/engine/base.py index 590a1ec82a..46aa0ed3b6 100644 --- a/lmdeploy/pytorch/engine/base.py +++ b/lmdeploy/pytorch/engine/base.py @@ -42,6 +42,10 @@ def create_instance(self, cuda_stream_id=0): class EngineInstanceBase: + async def async_start_session(self, session_id: int): + """Ensure session exists before streaming starts.""" + raise NotImplementedError('This method is not implemented.') + async def async_end(self, session_id: int): """End the given session.""" raise NotImplementedError('This method is not implemented.') diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py index 217e1d4609..cf08a49125 100644 --- a/lmdeploy/pytorch/engine/engine_instance.py +++ b/lmdeploy/pytorch/engine/engine_instance.py @@ -163,6 +163,10 @@ def _try_add_session(self, session_id: int): """ return try_add_session(self.req_sender, session_id) + async def async_start_session(self, session_id: int): + """Ensure the session exists before request streaming.""" + await self._async_try_add_session(session_id) + async def async_stream_infer(self, session_id: int, input_ids: list[int], @@ -188,8 +192,6 @@ async def async_stream_infer(self, return gen_config = gen_config or GenerationConfig() sampling_param = SamplingParam.from_gen_config(gen_config=gen_config) - logger.debug(f'session[{session_id}] try add session.') - self.req_sender.send_async(RequestType.ADD_SESSION, dict(session_id=session_id, response=False)) msg = dict( token_ids=input_ids, session_id=session_id, diff --git a/lmdeploy/pytorch/engine/mp_engine/base.py b/lmdeploy/pytorch/engine/mp_engine/base.py index 2dfe423ed7..7b7ece6fff 100644 --- a/lmdeploy/pytorch/engine/mp_engine/base.py +++ b/lmdeploy/pytorch/engine/mp_engine/base.py @@ -97,6 +97,15 @@ def __init__(self, engine: MPEngine): self.engine = engine self.session_states = engine.session_states + async def async_start_session(self, session_id: int): + """Ensure session exists on all workers before streaming.""" + # Lazily create local session state at startup. + # Before API split, this side effect happened in async_stream_infer. + state = self.session_states[session_id] + ret = await self.engine._collective_rpc_async('instance_async_start_session', session_id) + state.is_exists.set() + return ret + async def async_end(self, session_id: int): """End the given session.""" if session_id not in self.session_states: @@ -117,12 +126,9 @@ async def async_cancel(self, session_id: int): async def async_stream_infer(self, session_id: int, *args, **kwargs): """Send stream inference request.""" - state = self.session_states[session_id] kwargs['session_id'] = session_id kwargs['notify_add_msg'] = True generator = self.engine._collective_rpc_streaming_async('instance_async_stream_infer', *args, **kwargs) - # session should have been added - state.is_exists.set() async for result in generator: yield result diff --git a/lmdeploy/pytorch/engine/mp_engine/base_worker.py b/lmdeploy/pytorch/engine/mp_engine/base_worker.py index bc2076863a..0a4da5d381 100644 --- a/lmdeploy/pytorch/engine/mp_engine/base_worker.py +++ b/lmdeploy/pytorch/engine/mp_engine/base_worker.py @@ -52,6 +52,11 @@ async def async_end(self, session_id: int): async with self.instance() as instance: return await instance.async_end(session_id) + async def async_start_session(self, session_id: int): + """Ensure session exists before streaming.""" + async with self.instance() as instance: + return await instance.async_start_session(session_id) + async def async_cancel(self, session_id: int): """Stop current streaming inference.""" async with self.instance() as instance: @@ -120,6 +125,10 @@ async def instance_async_end(self, session_id: int): """End the given session.""" return await self.instance_pool.async_end(session_id) + async def instance_async_start_session(self, session_id: int): + """Ensure session exists before streaming.""" + return await self.instance_pool.async_start_session(session_id) + async def instance_async_cancel(self, session_id: int): """Stop current streaming inference.""" return await self.instance_pool.async_cancel(session_id) diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py index a259fbdd90..d9b22f399e 100644 --- a/lmdeploy/serve/core/async_engine.py +++ b/lmdeploy/serve/core/async_engine.py @@ -458,17 +458,21 @@ def is_error(status): self.session_mgr.remove(session) return async with session.request_handle() as handle: - if session.epoch is not None and session.epoch != self.epoch: - logger.info(f'[generate] session {session_id} got aborted before starting inference, ' - f'session.epoch={session.epoch}, async_engine.epoch={self.epoch}') - metrics_processor.increase_failed_requests('abort') - yield GenOut(response='', - history_token_len=0, - input_token_len=len(input_ids), - generate_token_len=0, - finish_reason='abort', - token_ids=[]) - return + # Serialize same-session lifecycle operations during startup only. + # Once request startup is complete, decode streaming remains lock-free. + async with session._lifecycle_lock: + if session.epoch is not None and session.epoch != self.epoch: + logger.info(f'[generate] session {session_id} got aborted before starting inference, ' + f'session.epoch={session.epoch}, async_engine.epoch={self.epoch}') + metrics_processor.increase_failed_requests('abort') + yield GenOut(response='', + history_token_len=0, + input_token_len=len(input_ids), + generate_token_len=0, + finish_reason='abort', + token_ids=[]) + return + await handle.async_start_session(session.session_id) token_ids = input_ids.copy() history_len = session.step input_len = len(input_ids) @@ -600,12 +604,9 @@ def is_error(status): # until the session is finished, i.e., session.request_handle() context exits. await handle.async_end(session.session_id) self.session_mgr.remove(session) - # if sequence_end: - # if self.backend == 'pytorch': - # # manually end pytorch session. session cannot be ended until session.request_handle() - # # context exits - # await session.async_close() - # self.session_mgr.remove(session) + # We cannot call end session after with session.request_handle(), because session.async_close() + # will try to request another free handle, which might be blocked if other sessions are waiting + # for the same request_handle_pool. def start_loop(self, loop, use_async_api=False): """Start engine loop. diff --git a/lmdeploy/serve/managers/session_manager.py b/lmdeploy/serve/managers/session_manager.py index 685631091f..d7edb37a3f 100644 --- a/lmdeploy/serve/managers/session_manager.py +++ b/lmdeploy/serve/managers/session_manager.py @@ -27,6 +27,9 @@ def __init__(self, session_id: int, session_mgr: SessionManager, **kwargs): # Set by api_server to AsyncEngine.epoch when a request binds a session; # generate() drops work if stop_all_session() bumped epoch after bind. self.epoch: int | None = None + # Serialize per-session lifecycle operations (generate startup, abort, close) + # without affecting cross-session concurrency. + self._lifecycle_lock = asyncio.Lock() # event to wait for the session to be active self._active: asyncio.Event | None = None self._handle = None # inference instance @@ -105,23 +108,31 @@ async def request_handle(self): async def async_abort(self): """Abort the session.""" - logger.debug(f'[session] Aborting session {self.session_id}, epoch={self.epoch}') - if self._handle is not None: - await self._handle.async_cancel(self.session_id) + async with self._lifecycle_lock: + logger.debug(f'[session] Aborting session {self.session_id}, epoch={self.epoch}') + # Session already closed/reset; treat as a benign no-op. + if self._session_mgr is None: + return + if self._handle is not None: + await self._handle.async_cancel(self.session_id) async def async_close(self): """End the session.""" - logger.info(f'[session] Ending session {self.session_id}') - if self._handle is None and self.step == 0: - return - if self._handle is not None: - await self._active.wait() - async with self.request_handle() as handle: - try: - await handle.async_end(self.session_id) - except (Exception, asyncio.CancelledError, GeneratorExit) as e: - logger.exception(f'[async_close] exception caught: {e}') - self.reset() + async with self._lifecycle_lock: + logger.info(f'[session] Ending session {self.session_id}') + # Already closed/reset; keep end idempotent. + if self._session_mgr is None: + return + if self._handle is None and self.step == 0: + return + if self._handle is not None: + await self._active.wait() + async with self.request_handle() as handle: + try: + await handle.async_end(self.session_id) + except (Exception, asyncio.CancelledError, GeneratorExit) as e: + logger.exception(f'[async_close] exception caught: {e}') + self.reset() def abort(self): """Abort the session in sync mode.""" diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 5271262f7b..0353e4d7cd 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -642,6 +642,10 @@ def prepare_inputs(self, async def async_cancel(self, session_id: int = None): self.model_inst.cancel() + async def async_start_session(self, session_id: int): + """TurboMind does not require an explicit add-session request.""" + return None + def async_end_cb(self, fut: asyncio.Future, status: int): """Executing on engine's signaling thread.""" logger.info(f'[async_end_cb] session ended, status = {status}') @@ -735,8 +739,26 @@ async def async_stream_infer(self, sem = StreamingSemaphore() signal_cb = partial(self.async_signal_cb, sem) + logits_cb = None + ppl_state = None + if gen_config.output_ppl: + ppl_state = [0.0, 0] # [accumulated_loss, accumulated_count] + + def logits_cb(logits_chunk, vocab_size, begin, count): + targets = torch.tensor( + input_ids[begin + 1:begin + count + 1], + device=logits_chunk.device) + valid = min(count, len(targets)) + if valid > 0: + loss = torch.nn.functional.cross_entropy( + logits_chunk[:valid].float(), targets[:valid], + reduction='sum') + ppl_state[0] += loss.item() + ppl_state[1] += valid + outputs, shared_state, metrics = self.model_inst.forward(inputs, session, gen_cfg, stream_output, - self.tm_model.engine_config.enable_metrics, signal_cb) + self.tm_model.engine_config.enable_metrics, signal_cb, + logits_cb=logits_cb) outputs = _tm_dict_to_torch_dict(outputs) @@ -774,6 +796,10 @@ async def async_stream_infer(self, for f in extra_fs: f(output, seq_len) + if finish and ppl_state is not None: + output.ppl_loss = ppl_state[0] + output.ppl_count = ppl_state[1] + prev_len = seq_len yield output @@ -828,5 +854,6 @@ def _get_generation_config(self, cfg: GenerationConfig): c.output_logprobs = cfg.logprobs if cfg.random_seed is not None: c.random_seed = cfg.random_seed - # print (c) + if cfg.output_ppl: + c.compute_ppl = True return c From a19a45237684a4d92348896865c517164df75f6e Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Wed, 22 Apr 2026 05:45:49 +0000 Subject: [PATCH 2/3] rollback ppl feature in turbomind --- lmdeploy/turbomind/turbomind.py | 27 ++------------------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index 0353e4d7cd..418857cf61 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -739,26 +739,8 @@ async def async_stream_infer(self, sem = StreamingSemaphore() signal_cb = partial(self.async_signal_cb, sem) - logits_cb = None - ppl_state = None - if gen_config.output_ppl: - ppl_state = [0.0, 0] # [accumulated_loss, accumulated_count] - - def logits_cb(logits_chunk, vocab_size, begin, count): - targets = torch.tensor( - input_ids[begin + 1:begin + count + 1], - device=logits_chunk.device) - valid = min(count, len(targets)) - if valid > 0: - loss = torch.nn.functional.cross_entropy( - logits_chunk[:valid].float(), targets[:valid], - reduction='sum') - ppl_state[0] += loss.item() - ppl_state[1] += valid - outputs, shared_state, metrics = self.model_inst.forward(inputs, session, gen_cfg, stream_output, - self.tm_model.engine_config.enable_metrics, signal_cb, - logits_cb=logits_cb) + self.tm_model.engine_config.enable_metrics, signal_cb) outputs = _tm_dict_to_torch_dict(outputs) @@ -796,10 +778,6 @@ def logits_cb(logits_chunk, vocab_size, begin, count): for f in extra_fs: f(output, seq_len) - if finish and ppl_state is not None: - output.ppl_loss = ppl_state[0] - output.ppl_count = ppl_state[1] - prev_len = seq_len yield output @@ -854,6 +832,5 @@ def _get_generation_config(self, cfg: GenerationConfig): c.output_logprobs = cfg.logprobs if cfg.random_seed is not None: c.random_seed = cfg.random_seed - if cfg.output_ppl: - c.compute_ppl = True + return c From f175f1acecd07bb328949a6d106703d4c3814814 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Wed, 22 Apr 2026 07:19:57 +0000 Subject: [PATCH 3/3] fix --- lmdeploy/serve/core/async_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lmdeploy/serve/core/async_engine.py b/lmdeploy/serve/core/async_engine.py index d9b22f399e..713d635b14 100644 --- a/lmdeploy/serve/core/async_engine.py +++ b/lmdeploy/serve/core/async_engine.py @@ -681,6 +681,7 @@ async def async_get_logits(self, async def _proc(session, i): async with session.request_handle() as handle: + await handle.async_start_session(session.session_id) input_len = len(input_ids[i]) # TODO(lvhan): Fix the ugly code later on max_new_tokens = 1 if self.backend == 'turbomind' else 0