diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py
index 217e1d4609..c08595deaf 100644
--- a/lmdeploy/pytorch/engine/engine_instance.py
+++ b/lmdeploy/pytorch/engine/engine_instance.py
@@ -106,11 +106,14 @@ def _lazy_create_ray_store():
     try:
         _SHARED_STORE = ray.get_actor(name, namespace='lmdeploy')
     except ValueError:
-        _SHARED_STORE = ray.remote(num_cpus=0,)(SharedStore).options(
-            name=name,
-            namespace='lmdeploy',
-            lifetime='detached',
-        ).remote()
+        try:
+            _SHARED_STORE = ray.remote(num_cpus=0,)(SharedStore).options(
+                name=name,
+                namespace='lmdeploy',
+                lifetime='detached',
+            ).remote()
+        except ray.exceptions.ActorAlreadyExistsError:
+            _SHARED_STORE = ray.get_actor(name, namespace='lmdeploy')
 
 
 class EngineInstance(EngineInstanceBase):
@@ -134,13 +137,19 @@ def __del__(self):
         """Destructor."""
         self.engine.req_manager.senders.pop(self.req_sender.sender_id)
 
-    def _get_extra_outputs(self, resp: Response):
+    def _get_extra_outputs(self, resp: Response, num_all_ids: int):
         """Get extra outputs."""
         outputs = dict(routed_experts=None)
         routed_experts = resp.data.get('routed_experts', None) if resp.data else None
         if routed_experts is not None and resp.type in [ResponseType.FINISH, ResponseType.CANCEL]:
             if self._enable_transfer_obj_ref:
                 import ray
+                # validate experts
+                num_expected_experts = num_all_ids - 1
+                if routed_experts.shape[0] != num_expected_experts:
+                    logger.warning(f'Expected number of routed_experts: {num_expected_experts}, '
+                                   f'but got {routed_experts.shape[0]}')
+                    routed_experts = routed_experts[:num_expected_experts]
                 key = ray.get(_SHARED_STORE.put.remote(routed_experts))
                 outputs['routed_experts'] = key
             else:
@@ -169,6 +178,7 @@ async def async_stream_infer(self,
                                  gen_config: GenerationConfig = None,
                                  multimodal: InputMultiModalType = None,
                                  adapter_name: str = None,
+                                 notify_add_msg_func=None,
                                  **kwargs):
         """Send stream inference request.
@@ -202,7 +212,12 @@ async def async_stream_infer(self,
         )
         logger.debug(f'session[{session_id}] add message: num_input_ids={len(input_ids)}.')
         resp = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)
+        # notify add msg
+        if notify_add_msg_func is not None:
+            notify_add_msg_func()
+
         output_offset = 0
+        prompt_ids_len = len(input_ids)
         while True:
             resp = await self.req_sender.async_recv(resp, wait_main=True)
@@ -210,8 +225,6 @@ async def async_stream_infer(self,
             cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
             req_metrics = resp.data.get('req_metrics', None) if resp.data else None
             logprobs = resp.data.pop('logprobs', None) if resp.data else None
-            extra_outputs = self._get_extra_outputs(resp)
-            routed_experts = extra_outputs.get('routed_experts', None)
 
             if resp.type == ResponseType.SUCCESS:
                 token_ids = resp.data['token_ids']
@@ -221,19 +234,24 @@ async def async_stream_infer(self,
                                    token_ids[output_offset:].tolist(),
                                    cache_block_ids=cache_block_ids,
                                    req_metrics=req_metrics,
-                                   routed_experts=routed_experts,
                                    logprobs=logprobs)
                 output_offset = len(token_ids)
             elif resp.type in (ResponseType.FINISH, ResponseType.CANCEL):
                 resp_data = resp.data
-                if resp_data is None:
+                token_ids = []
+                logits = None
+                if resp_data is not None:
                     # request might be cancelled before any output
-                    token_ids = []
-                    logits = None
-                else:
-                    token_ids = resp_data['token_ids'][output_offset:].tolist()
                     logits = resp_data.get('logits', None)
-                num_ids = len(token_ids) - output_offset
+                    gen_token_ids = resp_data.get('token_ids', None)
+                    if gen_token_ids is not None:
+                        token_ids = gen_token_ids[output_offset:].tolist()
+
+                num_ids = len(token_ids)
+                num_all_ids = prompt_ids_len + output_offset + num_ids
+                extra_outputs = self._get_extra_outputs(resp, num_all_ids)
+                routed_experts = extra_outputs.get('routed_experts', None)
+
                 logger.debug(f'session[{session_id}] finish: num_out_ids={num_ids}.')
                 yield EngineOutput(resp.type,
                                    token_ids,
diff --git a/lmdeploy/pytorch/engine/model_agent/agent.py b/lmdeploy/pytorch/engine/model_agent/agent.py
index 55fb6e807b..f4c710d85d 100644
--- a/lmdeploy/pytorch/engine/model_agent/agent.py
+++ b/lmdeploy/pytorch/engine/model_agent/agent.py
@@ -577,7 +577,7 @@ def _prepare_inputs_prefill(
             # for second round chat
             self.step_inputs.reindex(delta)
 
-        if inputs.is_first_chunk:
+        if inputs.is_first_chunk or not inputs.is_chunk:
             self._prev_chunk_output = None
 
         # check long context
diff --git a/lmdeploy/pytorch/engine/mp_engine/base.py b/lmdeploy/pytorch/engine/mp_engine/base.py
index 2dfe423ed7..2d52ce1497 100644
--- a/lmdeploy/pytorch/engine/mp_engine/base.py
+++ b/lmdeploy/pytorch/engine/mp_engine/base.py
@@ -4,7 +4,7 @@
 from dataclasses import dataclass, field
 from typing import Any
 
-from lmdeploy.messages import ResponseType
+from lmdeploy.messages import EngineOutput, ResponseType
 from lmdeploy.pytorch.disagg.conn.protocol import (
     DistServeConnectionRequest,
     DistServeDropConnectionRequest,
@@ -20,6 +20,7 @@
 @dataclass
 class SessionState:
     is_exists: asyncio.Event = field(default_factory=asyncio.Event)
+    cancelled: bool = False
 
 
 class MPEngine(EngineBase):
@@ -28,6 +29,7 @@ def __init__(self) -> None:
         """Initialize mp engine."""
         self.session_states = defaultdict(SessionState)
         self.engine_config = self._collective_rpc('get_engine_config')
+        self.pending_cancel_sessions = set()
 
     def _collective_rpc(self, func, *args, **kwargs):
         """Collective rpc call."""
@@ -37,7 +39,7 @@ async def _collective_rpc_async(self, func, *args, **kwargs):
         """Collective rpc call."""
         raise NotImplementedError('This method has not been implemented yet.')
 
-    async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
+    async def _collective_rpc_streaming_async(self, func: str, sess_event: asyncio.Event, *args, **kwargs):
         """Collective rpc call."""
         raise NotImplementedError('This method has not been implemented yet.')
@@ -100,29 +102,41 @@ def __init__(self, engine: MPEngine):
     async def async_end(self, session_id: int):
         """End the given session."""
         if session_id not in self.session_states:
+            self.engine.pending_cancel_sessions.discard(session_id)
             logger.warning(f'Session {session_id} not found when end session.')
             return ResponseType.SESSION_NOT_EXIST
         await self.session_states[session_id].is_exists.wait()
         ret = await self.engine._collective_rpc_async('instance_async_end', session_id)
         self.session_states.pop(session_id)
+        self.engine.pending_cancel_sessions.discard(session_id)
         return ret
 
     async def async_cancel(self, session_id: int):
         """Stop current streaming inference."""
         if session_id not in self.session_states:
-            logger.warning(f'Session {session_id} not found when cancel session.')
+            logger.debug(f'Session {session_id} not found when cancel session.')
             return ResponseType.SESSION_NOT_EXIST
-        await self.session_states[session_id].is_exists.wait()
+        state = self.session_states[session_id]
+        self.engine.pending_cancel_sessions.add(session_id)
+        if not state.is_exists.is_set():
+            logger.debug(f'Session {session_id} not started yet, recording pending cancel.')
+            state.cancelled = True
+            return ResponseType.SUCCESS
         return await self.engine._collective_rpc_async('instance_async_cancel', session_id)
 
     async def async_stream_infer(self, session_id: int, *args, **kwargs):
         """Send stream inference request."""
         state = self.session_states[session_id]
+        if state.cancelled or session_id in self.engine.pending_cancel_sessions:
+            state.is_exists.set()
+            logger.debug(f'Session {session_id} cancelled, skip async_stream_infer.')
+            yield EngineOutput(ResponseType.CANCEL, [])
+            return
         kwargs['session_id'] = session_id
-        kwargs['notify_add_msg'] = True
-        generator = self.engine._collective_rpc_streaming_async('instance_async_stream_infer', *args, **kwargs)
-        # session should have been added
-        state.is_exists.set()
+        generator = self.engine._collective_rpc_streaming_async('instance_async_stream_infer',
+                                                                 state.is_exists,
+                                                                 *args,
+                                                                 **kwargs)
         async for result in generator:
             yield result
diff --git a/lmdeploy/pytorch/engine/mp_engine/base_worker.py b/lmdeploy/pytorch/engine/mp_engine/base_worker.py
index bc2076863a..48157a9478 100644
--- a/lmdeploy/pytorch/engine/mp_engine/base_worker.py
+++ b/lmdeploy/pytorch/engine/mp_engine/base_worker.py
@@ -151,7 +151,9 @@ def add(self, stream_id, result):
     def pop(self, stream_id, result):
         if not isinstance(result, EngineOutput):
             return result
-        output = self._output.pop(stream_id)
+        output = self._output.pop(stream_id, None)
+        if output is None:
+            return result
         result.token_ids = output.token_ids or []
         result.logprobs = output.logprobs or None
         return result
diff --git a/lmdeploy/pytorch/engine/mp_engine/ray_engine.py b/lmdeploy/pytorch/engine/mp_engine/ray_engine.py
index 3c14d3fd0c..cec2785027 100644
--- a/lmdeploy/pytorch/engine/mp_engine/ray_engine.py
+++ b/lmdeploy/pytorch/engine/mp_engine/ray_engine.py
@@ -4,7 +4,7 @@
 import ray
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
-from lmdeploy.messages import PytorchEngineConfig
+from lmdeploy.messages import EngineOutput, PytorchEngineConfig, ResponseType
 from lmdeploy.pytorch import envs as _envs
 from lmdeploy.pytorch.ray import RayContext, get_device_str, get_resource_kwargs
 from lmdeploy.utils import get_logger
@@ -40,17 +40,29 @@ async def _stream_task_wrapper(self, stream_id: int, init_event: asyncio.Event,
         """Create a stream task."""
         method = getattr(self, func)
         event = self._stream_aiter[stream_id][0]
+
+        # notify after add msg
+        def _notify_add_msg():
+            nonlocal init_event
+            init_event.set()
+
+        if func == 'instance_async_stream_infer':
+            kwargs['notify_add_msg_func'] = _notify_add_msg
+
+        result = EngineOutput(ResponseType.INTERNAL_ENGINE_ERROR, [])
         try:
             generator = method(*args, **kwargs)
-            init_event.set()
             async for result in generator:
                 self._engine_output_gather.add(stream_id, result)
                 self._stream_aiter[stream_id][1] = (result, False)
                 event.set()
+        except Exception:
+            logger.exception(f'Stream task {stream_id} failed.')
         finally:
+            if not init_event.is_set():
+                init_event.set()
             self._stream_aiter[stream_id][1] = (result, True)
             event.set()
-            init_event.set()
 
     async def create_stream_task(self, func, *args, **kwargs):
         """Create a stream task."""
@@ -147,11 +159,11 @@ async def _collective_rpc_async(self, func, *args, **kwargs):
         method = getattr(self.worker, func)
         return await method.remote(*args, **kwargs)
 
-    async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
+    async def _collective_rpc_streaming_async(self, func: str, sess_event: asyncio.Event, *args, **kwargs):
         """Collective rpc call."""
         # ray generator would try cache every result, which is too verbose.
         stream_id = await self._collective_rpc_async('create_stream_task', func, *args, **kwargs)
-
+        sess_event.set()
         stopped = False
         while not stopped:
             result, stopped = await self._collective_rpc_async('get_stream_task_result', stream_id)
diff --git a/lmdeploy/pytorch/engine/mp_engine/zmq_engine.py b/lmdeploy/pytorch/engine/mp_engine/zmq_engine.py
index 760d17aeb0..712b69eca8 100644
--- a/lmdeploy/pytorch/engine/mp_engine/zmq_engine.py
+++ b/lmdeploy/pytorch/engine/mp_engine/zmq_engine.py
@@ -172,9 +172,9 @@ async def _collective_rpc_async(self, func, *args, **kwargs):
         """Collective rpc call."""
         return await self.rpc_client.async_call(func, *args, **kwargs)
 
-    async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
+    async def _collective_rpc_streaming_async(self, func: str, sess_event: asyncio.Event, *args, **kwargs):
         """Collective rpc call."""
-        async for out in self.rpc_client.async_stream_call(func, *args, **kwargs):
+        async for out in self.rpc_client.async_stream_call(func, sess_event, *args, **kwargs):
             yield out
 
     def close(self) -> None:
diff --git a/lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py b/lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py
index 5d43f5ccf5..5a5870fad0 100644
--- a/lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py
+++ b/lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py
@@ -303,10 +303,10 @@ async def async_call(self, method, *args, **kwargs):
         """Async call."""
         return await self._async_call_impl(method, False, *args, **kwargs)
 
-    async def async_stream_call(self, method, *args, **kwargs):
+    async def async_stream_call(self, method, sess_event: asyncio.Event, *args, **kwargs):
         """Streaming call."""
         stream_id = await self._async_call_impl(method, True, *args, **kwargs)
-
+        sess_event.set()
         stopped = False
         while not stopped:
             output, stopped = await self.async_call('_asyncrpcserver_get_stream_output', stream_id)
diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py
index 86e37fa28c..004341c9d2 100644
--- a/lmdeploy/pytorch/engine/request.py
+++ b/lmdeploy/pytorch/engine/request.py
@@ -170,8 +170,8 @@ def __init__(self):
         self.senders: dict[int, RequestSender] = dict()
         self.callbacks: dict[RequestType, Callable] = dict()
         self.request_priority: list[RequestType] = [
-            RequestType.STOP_ENGINE, RequestType.ADD_SESSION, RequestType.STOP_SESSION, RequestType.END_SESSION,
-            RequestType.ADD_MESSAGE
+            RequestType.STOP_ENGINE, RequestType.ADD_SESSION, RequestType.ADD_MESSAGE,
+            RequestType.STOP_SESSION, RequestType.END_SESSION,
         ]
         self.requests: asyncio.Queue = None
         self._loop_task: asyncio.Future = None
diff --git a/lmdeploy/pytorch/ray.py b/lmdeploy/pytorch/ray.py
index bb575df98d..330ed7077a 100644
--- a/lmdeploy/pytorch/ray.py
+++ b/lmdeploy/pytorch/ray.py
@@ -100,14 +100,16 @@ def init_ray_cluster(world_size: int, ray_address: str = None, dp: int = 1, devi
     try:
         num_cpus = world_size
         object_store_memory = _get_obj_store_memory(dp=dp)
-        ray.init(address=ray_address,
+        ctx = ray.init(address=ray_address,
                  ignore_reinit_error=True,
                  num_cpus=num_cpus,
                  object_store_memory=object_store_memory)
+        logger.info(f'Ray initialized with address: {ctx.address_info["address"]}')
     except ValueError as e:
         if e.args is not None and len(e.args) >= 1 and e.args[
                 0] == 'When connecting to an existing cluster, num_cpus and num_gpus must not be provided.':
-            ray.init(address=ray_address, ignore_reinit_error=True)
+            ctx = ray.init(address=ray_address, ignore_reinit_error=True)
+            logger.info(f'Ray initialized with address: {ctx.address_info["address"]}')
         else:
             raise
diff --git a/lmdeploy/pytorch/spec_decode/spec_agent.py b/lmdeploy/pytorch/spec_decode/spec_agent.py
index 8a0e2ba64a..cfdfbaca21 100644
--- a/lmdeploy/pytorch/spec_decode/spec_agent.py
+++ b/lmdeploy/pytorch/spec_decode/spec_agent.py
@@ -178,6 +178,8 @@ def _prepare_inputs_from_main(self, model_inputs: ModelInputs, extra_inputs: Ext
         history_lengths = model_inputs.history_lengths.clone()
 
         if not model_inputs.is_chunk:
+            # clear each time
+            self._prev_chunk_last.clear()
             # Case A: non-chunked — shift left by 1, place next_token at end
             input_ids = model_inputs.input_ids.clone()
             input_ids[:, :-1] = model_inputs.input_ids[:, 1:]
@@ -192,6 +194,8 @@ def _prepare_inputs_from_main(self, model_inputs: ModelInputs, extra_inputs: Ext
         else:
             if model_inputs.is_first_chunk:
+                # clear each time
+                self._prev_chunk_last.clear()
                 # Case B: first chunk — skip first token, save last for next chunk
                 input_ids = model_inputs.input_ids[:, 1:]
                 seq_length = model_inputs.seq_length - 1
diff --git a/lmdeploy/serve/managers/session_manager.py b/lmdeploy/serve/managers/session_manager.py
index 685631091f..b9a264233c 100644
--- a/lmdeploy/serve/managers/session_manager.py
+++ b/lmdeploy/serve/managers/session_manager.py
@@ -221,8 +221,6 @@ async def async_abort_all(self):
         for session in list(self.sessions.values()):
             tasks.append(session.async_abort())
         await asyncio.gather(*tasks, return_exceptions=True)
-        # "abort all" is designed for async RL. The aborted sessions will be no longer used,
-        # so we clear the sessions here.
         self.sessions.clear()
 
     def has(self, session_id):
diff --git a/tests/pytorch/engine/test_zmq_rpc.py b/tests/pytorch/engine/test_zmq_rpc.py
index 2be0d2fb11..a6e24eb4f4 100644
--- a/tests/pytorch/engine/test_zmq_rpc.py
+++ b/tests/pytorch/engine/test_zmq_rpc.py
@@ -48,7 +48,7 @@ async def async_main(self, port):
         result = await client.async_call('method', 'test1')
         assert result == 'test1: method'
 
-        async for result in client.async_stream_call('streaming_method', 'test3'):
+        async for result in client.async_stream_call('streaming_method', asyncio.Event(), 'test3'):
             pass
         assert result == 'test3: streaming method 2'