Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions lmdeploy/pytorch/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,13 @@ def update_params(self, request: Any):
"""Update params."""
self.executor.update_params(request)

async def sleep(self, level: int = 1):
    """Put the engine to sleep by delegating to the executor.

    Args:
        level (int): Sleep level forwarded to ``executor.sleep``.
    """
    await self.executor.sleep(level)

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
self.executor.wakeup(tags)
await self.executor.wakeup(tags)

async def async_loop(self):
engine_loop = None
Expand Down
10 changes: 10 additions & 0 deletions lmdeploy/pytorch/engine/engine_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,16 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'):
seq.append_logits(logits)
return dict()

engine_error_msg = getattr(batched_outputs, 'engine_error_msg', None)
if engine_error_msg:
for msg in running:
if msg.status != MessageStatus.RUNNING:
continue
response_reqs(self.req_manager, msg.resp, ResponseType.INTERNAL_ENGINE_ERROR,
data=dict(token_ids=[]), err_msg=engine_error_msg)
msg.state.finish()
return dict()

new_token_timestamp = batched_outputs.new_token_timestamp
logprobs = batched_outputs.logprobs

Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/executor/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async def sleep(self, level: int = 1):
"""Sleep."""
raise NotImplementedError('Not Implemented.')

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
raise NotImplementedError('Not Implemented.')

Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/executor/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ async def sleep(self, level: int = 1):
"""Sleep."""
await self.model_agent.sleep(level)

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
self.model_agent.wakeup(tags)

Expand Down
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/engine/executor/mp_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,14 @@ def warmup(self):
"""Build cache engine."""
self.collective_rpc('warmup')

async def sleep(self, level: int = 1):
    """Broadcast a sleep RPC to every worker process.

    Args:
        level (int): Sleep level forwarded to each worker's ``sleep``.
    """
    await self.collective_rpc_async('sleep', args=(level, ))

async def wakeup(self, tags: list[str] | None = None):
    """Broadcast a wakeup RPC to every worker process.

    Args:
        tags: Optional subset of resources to restore, forwarded to each
            worker's ``wakeup``.
    """
    await self.collective_rpc_async('wakeup', args=(tags, ))

async def _prefetch_outputs(self):
while True:
out = (await self.collective_rpc_async('get_outputs', receiver_mask=1, return_mask=1))[0]
Expand Down
10 changes: 8 additions & 2 deletions lmdeploy/pytorch/engine/executor/ray_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,18 @@ def warmup(self):
"""Build cache engine."""
self.collective_rpc('warmup')

def sleep(self, level: int = 1):
async def sleep(self, level: int = 1):
"""Sleep."""
await asyncio.to_thread(self._sleep_collective_rpc, level)

def _sleep_collective_rpc(self, level: int):
self.collective_rpc('sleep', (level, ))

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
await asyncio.to_thread(self._wakeup_collective_rpc, tags)

def _wakeup_collective_rpc(self, tags: list[str] | None):
if tags is None or 'kv_cache' in tags:
self.update_configs()
self.collective_rpc('wakeup', (tags, ))
Expand Down
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/engine/executor/uni_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,14 @@ async def get_output_async(self, dp_rank: int = 0):
assert dp_rank == 0
return await self.model_agent.get_output_async()

async def sleep(self, level: int = 1):
    """Put the single local model agent to sleep.

    Args:
        level (int): Sleep level forwarded to ``model_agent.sleep``.
    """
    await self.model_agent.sleep(level)

async def wakeup(self, tags: list[str] | None = None):
    """Wakeup on the event-loop thread (CUDA-safe; may block the loop).

    NOTE(review): the agent call is deliberately synchronous here, so a
    long cache rebuild will stall the event loop for its duration.

    Args:
        tags: Optional subset of resources to restore.
    """
    self.model_agent.wakeup(tags)

def get_input_processor(self):
"""Get input processor."""
return self.model_agent.get_input_processor()
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/model_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from lmdeploy.pytorch.devices import DeviceContext, get_device_manager
from lmdeploy.pytorch.distributed import DistContext, get_dist_manager

from .agent import BaseModelAgent, BatchedOutputs # noqa: F401
from .agent import BaseModelAgent, BatchedOutputs, CacheNotReadyError # noqa: F401


def build_model_agent(
Expand Down
54 changes: 52 additions & 2 deletions lmdeploy/pytorch/engine/model_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
logger = get_logger('lmdeploy')


class CacheNotReadyError(RuntimeError):
    """A forward step was attempted while the KV/state cache engines are
    missing (e.g. after sleep or a partial wakeup)."""


@dataclass
class SleepWakeupState:
to_sleep: asyncio.Event = field(default_factory=asyncio.Event)
Expand Down Expand Up @@ -82,6 +88,7 @@ class BatchedOutputs:
new_token_timestamp: int = 0
extra_outputs: ExtraOutputs | None = None
all_routed_experts: torch.Tensor | None = None
engine_error_msg: str | None = None

def to_cpu(self):
"""To cpu."""
Expand Down Expand Up @@ -128,11 +135,15 @@ def msg_with_rank(rank: int, msg: str):
return f'rank[{rank}] - {msg}'


def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict, swap_out_map: dict):
def cache_swapping(cache_engine: CacheEngine | None, swap_in_map: dict, swap_out_map: dict):
"""Perform cache swapping."""
issued_cache_op = False
swap_in_map = swap_in_map or dict()
swap_out_map = swap_out_map or dict()
if cache_engine is None and (len(swap_in_map) > 0 or len(swap_out_map) > 0):
raise CacheNotReadyError(
'KV cache is not available; cannot swap blocks. '
"Restore cache via wakeup with the 'kv_cache' tag before inference.")
if len(swap_in_map) > 0:
cache_engine.swap_in(swap_in_map)
issued_cache_op = True
Expand Down Expand Up @@ -568,6 +579,27 @@ def _push_output(self, output: BatchedOutputs):
event.record()
self._out_que.put_nowait((output, event))

def _batched_outputs_for_cache_error(self, forward_inputs: dict, err_msg: str) -> BatchedOutputs:
    """Build a minimal batch for get_output_async pairing.

    Every sequence in the batch is marked stopped with a dummy token so
    the consumer can unwind cleanly; the error text travels in
    ``engine_error_msg``.
    """
    model_inputs = forward_inputs.get('inputs')
    model_delta = forward_inputs.get('delta')
    # Recover the batch size from whichever payload was scheduled;
    # fall back to a single-sequence batch when neither is present.
    if model_inputs is not None:
        num_seqs = int(model_inputs.seq_length.size(0))
    elif model_delta is not None:
        num_seqs = int(model_delta.block_offsets.size(0))
    else:
        num_seqs = 1
    dev = self.device
    return BatchedOutputs(
        next_token_ids=torch.zeros((num_seqs, 1), dtype=torch.long, device=dev),
        stopped=torch.ones(num_seqs, dtype=torch.bool, device=dev),
        stop_pos=torch.zeros(num_seqs, dtype=torch.long, device=dev),
        engine_error_msg=err_msg,
    )

@contextmanager
def _broadcast_next_token(self, next_token_ids: torch.Tensor, extra_inputs: ExtraInputs, enable: bool = True):
if not enable:
Expand Down Expand Up @@ -644,6 +676,10 @@ def _get_inputs_from_delta(
sampling_inputs: SamplingInputs,
):
"""Get inputs from delta."""
if self.step_inputs.model_inputs is None:
raise CacheNotReadyError(
'Decode step has no cached ModelInputs (e.g. after a KV-cache error or reset). '
"Call wakeup with the 'kv_cache' tag before continuing inference.")
self.step_inputs.update_delta(delta, self)
inputs = self.step_inputs.model_inputs
extra_inputs = self.step_inputs.extra_inputs
Expand All @@ -661,6 +697,10 @@ def _prepare_inputs_prefill(
if delta is not None:
# update decoding inputs with delta
# for second round chat
if self.step_inputs.model_inputs is None:
raise CacheNotReadyError(
'Prefill with delta requires prior ModelInputs (e.g. after a KV-cache error). '
"Call wakeup with the 'kv_cache' tag before continuing inference.")
self.step_inputs.update_delta(delta, self)

if inputs.is_first_chunk:
Expand Down Expand Up @@ -836,6 +876,11 @@ def __update_inputs(
await asyncio.sleep(0.01)
return

if self.cache_engine is None or self.state_cache_engine is None:
raise CacheNotReadyError(
'KV or state cache engine is not built (e.g. after sleep or partial wakeup). '
"Call wakeup with the 'kv_cache' tag before running inference.")

# swap caches
cache_swapping(self.cache_engine, swap_in_map=swap_in_map, swap_out_map=swap_out_map)

Expand Down Expand Up @@ -936,7 +981,12 @@ async def _async_loop_background(self, forward_event: asyncio.Event = None):
while True:
forward_inputs = await input_maker.get()

await self._async_step(**forward_inputs, )
try:
await self._async_step(**forward_inputs, )
except CacheNotReadyError as err:
logger.warning('Forward skipped: %s', err)
self.step_inputs = StepInputs()
self._push_output(self._batched_outputs_for_cache_error(forward_inputs, str(err)))
if forward_event is not None:
forward_event.set()

Expand Down
8 changes: 4 additions & 4 deletions lmdeploy/pytorch/engine/mp_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ def end_session(self, session_id: int):
"""End session."""
return self._collective_rpc('end_session', session_id)

async def sleep(self, level: int):
    """Sleep the engine via an async collective RPC to the engine process.

    Args:
        level (int): Sleep level forwarded over the RPC.
    """
    return await self._collective_rpc_async('sleep', level)

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
return self._collective_rpc('wakeup', tags)
return await self._collective_rpc_async('wakeup', tags)

def update_params(self, request: Any):
"""Update params."""
Expand Down
8 changes: 4 additions & 4 deletions lmdeploy/pytorch/engine/mp_engine/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,13 @@ def p2p_drop_connect(self, drop_conn_request: DistServeDropConnectionRequest):
"""
return self.engine.p2p_drop_connect(drop_conn_request)

async def sleep(self, level: int = 1):
    """Sleep the wrapped engine.

    Args:
        level (int): Sleep level forwarded to ``engine.sleep``.
    """
    return await self.engine.sleep(level)

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wakeup."""
return self.engine.wakeup(tags)
return await self.engine.wakeup(tags)

def update_params(self, request: Any):
"""Update params."""
Expand Down
48 changes: 39 additions & 9 deletions lmdeploy/serve/core/async_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,23 @@ def _build_stat_loggers(self):
# set stats loggers of metrics processor
metrics_processor.stat_loggers = self.stat_loggers

def _if_session_stale(self, session: Session,
                      input_token_len: int) -> GenOut | None:
    """If api_server stamped the session's ``epoch`` and
    ``stop_all_session`` ran since then (epoch changed), drop the
    session.

    Args:
        session: The session to check; its ``epoch`` is compared against
            the engine-wide ``self.epoch``.
        input_token_len (int): Prompt length, echoed into the abort
            ``GenOut`` for bookkeeping.

    Returns:
        An abort ``GenOut`` when the session is stale, else ``None``.
    """
    epoch = session.epoch
    # An unstamped session, or one stamped in the current epoch, is live.
    if epoch is None or epoch == self.epoch:
        return None
    logger.info(
        f'[generate] session {session.session_id} dropped (session.epoch={epoch}, epoch={self.epoch})')
    return GenOut(response='',
                  history_token_len=session.step,
                  input_token_len=input_token_len,
                  generate_token_len=0,
                  finish_reason='abort',
                  token_ids=[])

def get_schedule_metrics(self):
return self.engine.get_schedule_metrics()

Expand All @@ -212,23 +229,25 @@ async def do_log_stats(self):

async def stop_all_session(self):
    """Stop all running sessions.

    Bumping ``self.epoch`` before aborting lets in-flight ``generate``
    calls detect they were started in an older epoch and drop out.
    """
    logger.info(f'stop all sessions, epoch {self.epoch} -> {self.epoch + 1}')
    self.epoch += 1
    await self.session_mgr.async_abort_all()

async def sleep(self, level: int = 1):
    """Sleep the model.

    Args:
        level (int): The sleep level. Level 1 sleep will offload the model
            weights and discard the kv cache. Level 2 sleep will
            discard both the model weights and the kv cache.
    """
    logger.info(f'[async_engine]sleep, level={level}')
    await self.engine.sleep(level)
    # Record which resources are released so wakeup() can validate tags.
    self.sleeping_tags = {'weights', 'kv_cache'}
    self.is_sleeping = True
    logger.info('[async_engine] sleep, done')

def wakeup(self, tags: list[str] | None = None):
async def wakeup(self, tags: list[str] | None = None):
"""Wake up the model.

Args:
Expand All @@ -242,7 +261,7 @@ def wakeup(self, tags: list[str] | None = None):
if any(tag not in self.sleeping_tags for tag in tags):
logger.warning(f'some tag in {tags} not in sleeping tags {self.sleeping_tags}')
return
self.engine.wakeup(tags)
await self.engine.wakeup(tags)
# for TM backend, sleep/wakeup will reset gateway, therefore we need to rebuild instances
if self.backend == 'turbomind' and 'kv_cache' in tags:
self.session_mgr.build_request_handle_pool(self.engine, self.backend_config.max_batch_size)
Expand Down Expand Up @@ -339,7 +358,8 @@ async def generate(
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
"""
epoch = self.epoch
metrics_processor.increase_total_requests()

if (messages is not None) ^ (input_ids is None):
raise ValueError('You must specify exactly one of messages or input_ids')
if isinstance(session_id, Session):
Expand Down Expand Up @@ -386,6 +406,7 @@ async def generate(

if gen_config.max_new_tokens == 0:
logger.info(f'run out of tokens. session={session_id}.')
metrics_processor.increase_failed_requests('error')
yield GenOut(response='',
history_token_len=session.step,
input_token_len=len(input_ids),
Expand All @@ -400,6 +421,7 @@ async def generate(
or gen_config.output_logits == 'all'):
errmsg = ('lmdeploy does not support outputting all token\'s logits or last_hidden_state '
'when prefix caching is ON')
metrics_processor.increase_failed_requests('error')
yield GenOut(response=errmsg,
history_token_len=session.step,
input_token_len=len(input_ids),
Expand All @@ -421,10 +443,18 @@ def is_error(status):
if not gen_config.ignore_eos:
stop_ids = gen_config.stop_token_ids or []

metrics_processor.increase_total_requests()

stale = self._if_session_stale(session, len(prompt_input['input_ids']))
if stale is not None:
metrics_processor.increase_failed_requests('abort')
yield stale
if sequence_end:
self.session_mgr.remove(session)
return
async with session.request_handle() as handle:
if epoch != self.epoch:
logger.info(f'[generate] session {session_id} got aborted before starting inference')
if session.epoch is not None and session.epoch != self.epoch:
logger.warning(f'[generate] session {session_id} got aborted before starting inference, '
f'session.epoch={session.epoch}, epoch={self.epoch}')
metrics_processor.increase_failed_requests('abort')
yield GenOut(response='',
history_token_len=0,
Expand Down
Loading
Loading