-
Notifications
You must be signed in to change notification settings - Fork 690
Map user-input session_id to internal session_id to maintain session identity #4523
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -196,13 +196,34 @@ def __init__(self): | |
| """Initialize the session manager.""" | ||
|
|
||
| self.sessions = {} | ||
| self.session_id_generator = itertools.count(1) | ||
| self.session_id_generator = itertools.count(0) | ||
| self.request_handle_pool = None | ||
| self.loop = None | ||
| # user_session_id->session_id. If user specifies | ||
| # a session_id when visiting the api_server's endpoint, | ||
| # we map the user_session_id to the session_id to keep | ||
| # session's id globally identical across different requests. | ||
| self.user_session_id_map = {} | ||
| # session_id->user_session_id map. | ||
| self.session_id_map = {} | ||
|
|
||
| def map_user_session_id(self, user_session_id: int) -> int: | ||
| """Map a user_session_id to a session_id.""" | ||
| if user_session_id in self.user_session_id_map: | ||
| raise ValueError(f'User session id {user_session_id} already exists') | ||
| session_id = next(self.session_id_generator) | ||
| self.user_session_id_map[user_session_id] = session_id | ||
| self.session_id_map[session_id] = user_session_id | ||
| return session_id | ||
|
Comment on lines
+210
to
+217
|
||
|
|
||
| def get(self, session_id: int | None = None, create_if_not_exists: bool = True, **kwargs) -> Session | None: | ||
| """Get or create a session.""" | ||
| if not create_if_not_exists: | ||
| return self.sessions.get(session_id, None) | ||
|
lvhan028 marked this conversation as resolved.
|
||
|
|
||
| if session_id is None: | ||
| session_id = next(self.session_id_generator) | ||
|
|
||
| def get(self, session_id: int | None = None, **kwargs) -> Session: | ||
| """Create a new session.""" | ||
| session_id = session_id or next(self.session_id_generator) | ||
| if session_id in self.sessions: | ||
| logger.debug(f'[SessionManager] session {session_id} already exists. Updating...') | ||
| session = self.sessions[session_id] | ||
|
|
@@ -224,17 +245,24 @@ async def async_abort_all(self): | |
| # "abort all" is designed for async RL. The aborted sessions will be no longer used, | ||
| # so we clear the sessions here. | ||
| self.sessions.clear() | ||
| self.user_session_id_map.clear() | ||
| self.session_id_map.clear() | ||
|
|
||
| def has(self, session_id): | ||
| return session_id in self.sessions | ||
|
|
||
| def remove(self, session: Session): | ||
| self.sessions.pop(session.session_id, None) | ||
| user_session_id = self.session_id_map.pop(session.session_id, None) | ||
| if user_session_id is not None: | ||
| self.user_session_id_map.pop(user_session_id, None) | ||
|
|
||
| def clear(self): | ||
| self.sessions.clear() | ||
| self.user_session_id_map.clear() | ||
| self.session_id_map.clear() | ||
| # reset the session id generator | ||
| self.session_id_generator = itertools.count(1) | ||
| self.session_id_generator = itertools.count(0) | ||
|
|
||
| def attach_event_loop(self, loop): | ||
| self.loop = loop | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -101,24 +101,40 @@ class VariableInterface: | |
| enable_abort_handling: bool = False | ||
| response_parser_cls: type[ResponseParser] | None = None | ||
|
|
||
| @staticmethod | ||
| def get_session(session_id: int) -> Session: | ||
| session_mgr = VariableInterface.get_session_manager() | ||
| if session_id == -1: | ||
| @classmethod | ||
| def create_session(cls, user_session_id: int | None = None) -> Session: | ||
| session_mgr = cls.get_session_manager() | ||
| if user_session_id is None or user_session_id == -1: | ||
| # user doesn't input session_id, so we need to generate a new one | ||
| session = session_mgr.get() | ||
| else: | ||
| # find the inside session_id by user_session_id, create a new one | ||
| # if it doesn't exist and update the user_session_id_map | ||
| session_id = session_mgr.map_user_session_id(user_session_id) | ||
| session = session_mgr.get(session_id) | ||
| # Stamp epoch for ``stop_all_session`` / ``abort_all`` coordination in ``AsyncEngine.generate``. | ||
| session.epoch = VariableInterface.async_engine.epoch | ||
| session.epoch = cls.async_engine.epoch | ||
| return session | ||
|
|
||
| @staticmethod | ||
| def get_session_manager(): | ||
| return VariableInterface.async_engine.session_mgr | ||
| @classmethod | ||
| def find_session(cls, user_session_id: int) -> Session | None: | ||
| """Find the session by user_session_id. | ||
|
|
||
| Users cannot access inner session_id directly. | ||
| """ | ||
| session_mgr = cls.get_session_manager() | ||
| session_id = session_mgr.user_session_id_map.get(user_session_id, None) | ||
| if session_id is None: | ||
| return None | ||
| return session_mgr.get(session_id, create_if_not_exists=False) | ||
|
|
||
| @staticmethod | ||
| def get_engine_config(): | ||
| return VariableInterface.async_engine.backend_config | ||
| @classmethod | ||
| def get_session_manager(cls): | ||
| return cls.async_engine.session_mgr | ||
|
|
||
| @classmethod | ||
| def get_engine_config(cls): | ||
| return cls.async_engine.backend_config | ||
|
|
||
|
|
||
| router = APIRouter() | ||
|
|
@@ -360,7 +376,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque | |
| error_check_ret = check_request(request) | ||
| if error_check_ret is not None: | ||
| return error_check_ret | ||
| session = VariableInterface.get_session(request.session_id) | ||
| session = VariableInterface.create_session(request.session_id) | ||
|
||
|
|
||
| json_request = await raw_request.json() | ||
| migration_request = json_request.pop('migration_request', None) | ||
|
|
@@ -688,10 +704,10 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None | |
| sessions = [] | ||
| if isinstance(request.prompt, str): | ||
| request.prompt = [request.prompt] | ||
| sessions.append(VariableInterface.get_session(request.session_id)) | ||
| sessions.append(VariableInterface.create_session(request.session_id)) | ||
| elif isinstance(request.prompt, list): | ||
| for i in range(len(request.prompt)): | ||
| sessions.append(VariableInterface.get_session(i + 1)) | ||
| sessions.append(VariableInterface.create_session(i + 1)) | ||
| if isinstance(request.stop, str): | ||
| request.stop = [request.stop] | ||
| random_seed = request.seed if request.seed is not None else None | ||
|
|
@@ -854,7 +870,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None): | |
| if error_check_ret is not None: | ||
| return error_check_ret | ||
|
|
||
| session = VariableInterface.get_session(request.session_id) | ||
| session = VariableInterface.create_session(request.session_id) | ||
|
|
||
| prompt = request.prompt | ||
| input_ids = request.input_ids | ||
|
|
@@ -1077,8 +1093,6 @@ async def sleep(raw_request: Request = None): | |
| if level not in (1, 2): | ||
| return create_error_response(HTTPStatus.BAD_REQUEST, 'The "level" query parameter must be 1 or 2.') | ||
| async_engine = VariableInterface.async_engine | ||
| async_engine.prepare_sleep() | ||
| await async_engine.stop_all_session() | ||
| await async_engine.sleep(level) | ||
| return Response(status_code=200) | ||
|
|
||
|
|
@@ -1152,8 +1166,12 @@ async def abort_request(request: AbortRequest, raw_request: Request = None): | |
| if request.abort_all: | ||
| await VariableInterface.async_engine.stop_all_session() | ||
| else: | ||
| session = VariableInterface.get_session(request.session_id) | ||
| session = VariableInterface.find_session(request.session_id) | ||
| if session is None: | ||
| return create_error_response(HTTPStatus.BAD_REQUEST, f'Session {request.session_id} not found.') | ||
| await session.async_abort() | ||
| session_mgr = VariableInterface.get_session_manager() | ||
| session_mgr.remove(session) | ||
|
lvhan028 marked this conversation as resolved.
|
||
| return Response(status_code=200) | ||
|
|
||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.