-
Notifications
You must be signed in to change notification settings - Fork 690
Map user-input session_id to internal session_id to maintain session identity #4523
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -103,25 +103,43 @@ class VariableInterface: | |
| tool_parser: ToolParser | None = None | ||
| allow_terminate_by_client: bool = False | ||
| enable_abort_handling: bool = False | ||
|
|
||
| @staticmethod | ||
| def get_session(session_id: int) -> Session: | ||
| session_mgr = VariableInterface.get_session_manager() | ||
| if session_id == -1: | ||
| # map user input session_id to inside session_id | ||
| user_session_id_map: dict[int, int] = {} | ||
|
|
||
|
lvhan028 marked this conversation as resolved.
Outdated
|
||
| @classmethod | ||
| def create_session(cls, user_session_id: int | None = None) -> Session: | ||
| session_mgr = cls.get_session_manager() | ||
| if user_session_id is None or user_session_id == -1: | ||
| # user doesn't input session_id, so we need to generate a new one | ||
| session = session_mgr.get() | ||
| else: | ||
| session = session_mgr.get(session_id) | ||
| # find the inside session_id by user_session_id, create a new one | ||
| # if it doesn't exist and update the user_session_id_map | ||
| session_id = cls.user_session_id_map.get(user_session_id, None) | ||
| session = session_mgr.get(session_id, create_if_not_exists=True) | ||
| cls.user_session_id_map[user_session_id] = session.session_id | ||
|
lvhan028 marked this conversation as resolved.
Outdated
|
||
| # Stamp epoch for ``stop_all_session`` / ``abort_all`` coordination in ``AsyncEngine.generate``. | ||
| session.epoch = VariableInterface.async_engine.epoch | ||
| session.epoch = cls.async_engine.epoch | ||
| return session | ||
|
|
||
| @staticmethod | ||
| def get_session_manager(): | ||
| return VariableInterface.async_engine.session_mgr | ||
| @classmethod | ||
| def find_session(cls, user_session_id: int) -> Session | None: | ||
| """Find the session by user_session_id. | ||
|
|
||
| Users cannot access inner session_id directly. | ||
| """ | ||
| if user_session_id not in cls.user_session_id_map: | ||
| return None | ||
| session_id = cls.user_session_id_map.get(user_session_id, None) | ||
| return cls.get_session_manager().get(session_id, create_if_not_exists=False) | ||
|
|
||
| @staticmethod | ||
| def get_engine_config(): | ||
| return VariableInterface.async_engine.backend_config | ||
| @classmethod | ||
| def get_session_manager(cls): | ||
| return cls.async_engine.session_mgr | ||
|
|
||
| @classmethod | ||
| def get_engine_config(cls): | ||
| return cls.async_engine.backend_config | ||
|
|
||
|
|
||
| router = APIRouter() | ||
|
|
@@ -419,7 +437,7 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque | |
| return error_check_ret | ||
| if VariableInterface.tool_parser is not None: | ||
| request = VariableInterface.tool_parser.adjust_request(request) | ||
| session = VariableInterface.get_session(request.session_id) | ||
| session = VariableInterface.create_session(request.session_id) | ||
|
||
|
|
||
| json_request = await raw_request.json() | ||
| migration_request = json_request.pop('migration_request', None) | ||
|
|
@@ -793,10 +811,10 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None | |
| sessions = [] | ||
| if isinstance(request.prompt, str): | ||
| request.prompt = [request.prompt] | ||
| sessions.append(VariableInterface.get_session(request.session_id)) | ||
| sessions.append(VariableInterface.create_session(request.session_id)) | ||
| elif isinstance(request.prompt, list): | ||
| for i in range(len(request.prompt)): | ||
| sessions.append(VariableInterface.get_session(i + 1)) | ||
| sessions.append(VariableInterface.create_session(i + 1)) | ||
| if isinstance(request.stop, str): | ||
| request.stop = [request.stop] | ||
| random_seed = request.seed if request.seed else None | ||
|
|
@@ -971,7 +989,7 @@ async def generate(request: GenerateReqInput, raw_request: Request = None): | |
| if error_check_ret is not None: | ||
| return error_check_ret | ||
|
|
||
| session = VariableInterface.get_session(request.session_id) | ||
| session = VariableInterface.create_session(request.session_id) | ||
|
|
||
| prompt = request.prompt | ||
| input_ids = request.input_ids | ||
|
|
@@ -1265,8 +1283,12 @@ async def abort_request(request: AbortRequest, raw_request: Request = None): | |
| if request.abort_all: | ||
| await VariableInterface.async_engine.stop_all_session() | ||
| else: | ||
| session = VariableInterface.get_session(request.session_id) | ||
| session = VariableInterface.find_session(request.session_id) | ||
| if session is None: | ||
| return create_error_response(HTTPStatus.BAD_REQUEST, f'Session {request.session_id} not found.') | ||
| await session.async_abort() | ||
| session_mgr = VariableInterface.get_session_manager() | ||
| session_mgr.remove(session) | ||
|
lvhan028 marked this conversation as resolved.
|
||
| return Response(status_code=200) | ||
|
|
||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.