Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions lmdeploy/pytorch/engine/engine_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,14 @@ def _lazy_create_ray_store():
try:
_SHARED_STORE = ray.get_actor(name, namespace='lmdeploy')
except ValueError:
_SHARED_STORE = ray.remote(num_cpus=0,)(SharedStore).options(
name=name,
namespace='lmdeploy',
lifetime='detached',
).remote()
try:
_SHARED_STORE = ray.remote(num_cpus=0,)(SharedStore).options(
name=name,
namespace='lmdeploy',
lifetime='detached',
).remote()
except ray.exceptions.ActorAlreadyExistsError:
_SHARED_STORE = ray.get_actor(name, namespace='lmdeploy')


class EngineInstance(EngineInstanceBase):
Expand Down Expand Up @@ -169,6 +172,7 @@ async def async_stream_infer(self,
gen_config: GenerationConfig = None,
multimodal: InputMultiModalType = None,
adapter_name: str = None,
notify_add_msg_func = None,
**kwargs):
"""Send stream inference request.

Expand Down Expand Up @@ -202,6 +206,10 @@ async def async_stream_infer(self,
)
logger.debug(f'session[{session_id}] add message: num_input_ids={len(input_ids)}.')
resp = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)
# notify add msg
if notify_add_msg_func is not None:
notify_add_msg_func()

output_offset = 0

while True:
Expand Down
10 changes: 5 additions & 5 deletions lmdeploy/pytorch/engine/mp_engine/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ async def _collective_rpc_async(self, func, *args, **kwargs):
"""Collective rpc call."""
raise NotImplementedError('This method has not been implemented yet.')

async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
async def _collective_rpc_streaming_async(self, func: str, sess_event: asyncio.Event, *args, **kwargs):
"""Collective rpc call."""
raise NotImplementedError('This method has not been implemented yet.')

Expand Down Expand Up @@ -119,10 +119,10 @@ async def async_stream_infer(self, session_id: int, *args, **kwargs):
"""Send stream inference request."""
state = self.session_states[session_id]
kwargs['session_id'] = session_id
kwargs['notify_add_msg'] = True
generator = self.engine._collective_rpc_streaming_async('instance_async_stream_infer', *args, **kwargs)
# session should have been added
state.is_exists.set()
generator = self.engine._collective_rpc_streaming_async('instance_async_stream_infer',
state.is_exists,
*args,
**kwargs)

async for result in generator:
yield result
4 changes: 3 additions & 1 deletion lmdeploy/pytorch/engine/mp_engine/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def add(self, stream_id, result):
def pop(self, stream_id, result):
if not isinstance(result, EngineOutput):
return result
output = self._output.pop(stream_id)
output = self._output.pop(stream_id, None)
if output is None:
return result
result.token_ids = output.token_ids or []
result.logprobs = output.logprobs or None
return result
22 changes: 17 additions & 5 deletions lmdeploy/pytorch/engine/mp_engine/ray_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import ray
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from lmdeploy.messages import PytorchEngineConfig
from lmdeploy.messages import EngineOutput, PytorchEngineConfig, ResponseType
from lmdeploy.pytorch import envs as _envs
from lmdeploy.pytorch.ray import RayContext, get_device_str, get_resource_kwargs
from lmdeploy.utils import get_logger
Expand Down Expand Up @@ -40,17 +40,29 @@ async def _stream_task_wrapper(self, stream_id: int, init_event: asyncio.Event,
"""Create a stream task."""
method = getattr(self, func)
event = self._stream_aiter[stream_id][0]

# notify after add msg
def _notify_add_msg():
nonlocal init_event
init_event.set()

if func == 'async_stream_infer':
Comment thread
RunningLeon marked this conversation as resolved.
Outdated
kwargs['notify_add_msg_func'] = _notify_add_msg

result = EngineOutput(ResponseType.INTERNAL_ENGINE_ERROR, [])
try:
generator = method(*args, **kwargs)
init_event.set()
async for result in generator:
self._engine_output_gather.add(stream_id, result)
self._stream_aiter[stream_id][1] = (result, False)
event.set()
except Exception:
logger.exception(f'Stream task {stream_id} failed.')
finally:
if not init_event.is_set():
init_event.set()
self._stream_aiter[stream_id][1] = (result, True)
event.set()
init_event.set()

async def create_stream_task(self, func, *args, **kwargs):
"""Create a stream task."""
Expand Down Expand Up @@ -147,11 +159,11 @@ async def _collective_rpc_async(self, func, *args, **kwargs):
method = getattr(self.worker, func)
return await method.remote(*args, **kwargs)

async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
async def _collective_rpc_streaming_async(self, func: str, sess_event: asyncio.Event, *args, **kwargs):
"""Collective rpc call."""
# ray generator would try cache every result, which is too verbose.
stream_id = await self._collective_rpc_async('create_stream_task', func, *args, **kwargs)

sess_event.set()
stopped = False
while not stopped:
result, stopped = await self._collective_rpc_async('get_stream_task_result', stream_id)
Expand Down
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/engine/mp_engine/zmq_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ async def _collective_rpc_async(self, func, *args, **kwargs):
"""Collective rpc call."""
return await self.rpc_client.async_call(func, *args, **kwargs)

async def _collective_rpc_streaming_async(self, func, *args, **kwargs):
async def _collective_rpc_streaming_async(self, func: str, sess_event: asyncio.Event, *args, **kwargs):
"""Collective rpc call."""
async for out in self.rpc_client.async_stream_call(func, *args, **kwargs):
async for out in self.rpc_client.async_stream_call(func, sess_event, *args, **kwargs):
yield out

def close(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions lmdeploy/pytorch/engine/mp_engine/zmq_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,10 @@ async def async_call(self, method, *args, **kwargs):
"""Async call."""
return await self._async_call_impl(method, False, *args, **kwargs)

async def async_stream_call(self, method, *args, **kwargs):
async def async_stream_call(self, method, sess_event: asyncio.Event, *args, **kwargs):
"""Streaming call."""
stream_id = await self._async_call_impl(method, True, *args, **kwargs)

sess_event.set()
stopped = False
Comment thread
RunningLeon marked this conversation as resolved.
while not stopped:
output, stopped = await self.async_call('_asyncrpcserver_get_stream_output', stream_id)
Expand Down
6 changes: 4 additions & 2 deletions lmdeploy/pytorch/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,16 @@ def init_ray_cluster(world_size: int, ray_address: str = None, dp: int = 1, devi
try:
num_cpus = world_size
object_store_memory = _get_obj_store_memory(dp=dp)
ray.init(address=ray_address,
ctx = ray.init(address=ray_address,
ignore_reinit_error=True,
num_cpus=num_cpus,
object_store_memory=object_store_memory)
logger.info(f'Ray initialized with address: {ctx.address_info["address"]}')
except ValueError as e:
if e.args is not None and len(e.args) >= 1 and e.args[
0] == 'When connecting to an existing cluster, num_cpus and num_gpus must not be provided.':
ray.init(address=ray_address, ignore_reinit_error=True)
ctx = ray.init(address=ray_address, ignore_reinit_error=True)
logger.info(f'Ray initialized with address: {ctx.address_info["address"]}')
else:
raise

Expand Down
Loading