-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Fix unsafe pickle deserialization in gRPC PolicyServer (CVE-2026-26210) #1944
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,17 +1,59 @@ | ||||||||||
| from datetime import datetime | ||||||||||
| import hashlib | ||||||||||
| import hmac | ||||||||||
| import io | ||||||||||
| import json | ||||||||||
| import os | ||||||||||
| import secrets | ||||||||||
| from typing import Optional | ||||||||||
| import zmq | ||||||||||
| import pickle | ||||||||||
| import threading | ||||||||||
| import torch | ||||||||||
| import torch.multiprocessing as mp | ||||||||||
| import sys | ||||||||||
| current_file_path = os.path.abspath(__file__) | ||||||||||
| # sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) | ||||||||||
| import pickle | ||||||||||
| import argparse | ||||||||||
| from safetensors.torch import save, load as st_load | ||||||||||
| from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe | ||||||||||
|
|
||||||||||
| # HMAC key for message authentication between server and client. | ||||||||||
| # Set KTRANSFORMERS_RPC_SECRET in the environment, or a random key is | ||||||||||
| # generated at import time (single-process / inherited-by-fork use). | ||||||||||
| _RPC_SECRET = os.environ.get( | ||||||||||
| "KTRANSFORMERS_RPC_SECRET", "" | ||||||||||
| ).encode() or secrets.token_bytes(32) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _sign(data: bytes) -> bytes: | ||||||||||
| return hmac.new(_RPC_SECRET, data, hashlib.sha256).digest() | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _verify(data: bytes, sig: bytes) -> bool: | ||||||||||
| return hmac.compare_digest(_sign(data), sig) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _serialize_msg(obj: dict) -> bytes: | ||||||||||
| """Serialize an RPC message to JSON bytes. Non-JSON-serializable values | ||||||||||
| are dropped with a placeholder so the frame always round-trips.""" | ||||||||||
| def _default(o): | ||||||||||
| return f"<non-serializable:{type(o).__name__}>" | ||||||||||
| return json.dumps(obj, default=_default).encode() | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The switch from |
||||||||||
|
|
||||||||||
|
|
||||||||||
| def _deserialize_msg(data: bytes) -> dict: | ||||||||||
| return json.loads(data) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _serialize_tensors(tensor_dict: dict) -> bytes: | ||||||||||
| """Serialize a flat {name: tensor} dict with safetensors.""" | ||||||||||
| buf = io.BytesIO() | ||||||||||
| save(tensor_dict, buf) | ||||||||||
| return buf.getvalue() | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _deserialize_tensors(data: bytes) -> dict: | ||||||||||
| """Deserialize safetensors bytes back to {name: tensor}.""" | ||||||||||
| return st_load(data) | ||||||||||
|
|
||||||||||
|
|
||||||||||
| if mp.get_start_method(allow_none=True) is None: | ||||||||||
|
|
@@ -24,11 +66,13 @@ | |||||||||
| class SchedulerServer: | ||||||||||
| def __init__(self, settings, main_args): | ||||||||||
| self.sched = sched_ext.create_scheduler(settings) | ||||||||||
|
|
||||||||||
| self.context = zmq.Context() | ||||||||||
| self.frontend = self.context.socket(zmq.ROUTER) | ||||||||||
| print(f"sched zmq rpc server on port {main_args.sched_port}") | ||||||||||
| self.frontend.bind(f"tcp://*:{main_args.sched_port}") | ||||||||||
|
|
||||||||||
| bind_addr = getattr(main_args, 'sched_bind', '127.0.0.1') | ||||||||||
| print(f"sched zmq rpc server on {bind_addr}:{main_args.sched_port}") | ||||||||||
| self.frontend.bind(f"tcp://{bind_addr}:{main_args.sched_port}") | ||||||||||
|
|
||||||||||
| self.backend = self.context.socket(zmq.DEALER) | ||||||||||
| self.backend.bind("inproc://backend") | ||||||||||
|
|
@@ -42,62 +86,78 @@ def stop_scheduler(self): | |||||||||
| def start_proxy(self): | ||||||||||
| zmq.proxy(self.frontend, self.backend) | ||||||||||
|
|
||||||||||
| def _send(self, worker, response: dict, tensor_data: bytes = b""): | ||||||||||
| payload = _serialize_msg(response) | ||||||||||
| sig = _sign(payload + tensor_data) | ||||||||||
| worker.send_multipart([sig, payload, tensor_data]) | ||||||||||
|
|
||||||||||
| def _recv(self, worker) -> dict: | ||||||||||
| parts = worker.recv_multipart() | ||||||||||
| if len(parts) != 3: | ||||||||||
| raise ValueError("Invalid message frame") | ||||||||||
| sig, payload, _ = parts | ||||||||||
| if not _verify(payload, sig): | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The HMAC verification in
Suggested change
|
||||||||||
| raise ValueError("HMAC verification failed") | ||||||||||
| return _deserialize_msg(payload) | ||||||||||
|
|
||||||||||
| def worker_routine(self): | ||||||||||
| worker = self.context.socket(zmq.REP) | ||||||||||
| worker.connect("inproc://backend") | ||||||||||
| while True: | ||||||||||
| try: | ||||||||||
| message = worker.recv() | ||||||||||
| data = pickle.loads(message) | ||||||||||
| data = self._recv(worker) | ||||||||||
|
|
||||||||||
| method = data.get('method') | ||||||||||
| params = data.get('params', {}) | ||||||||||
| # print(f"Received request: {method}") | ||||||||||
|
|
||||||||||
| if method == 'add_query': | ||||||||||
| query_add = params.get('query') | ||||||||||
| query_id = self.sched.add_query(query_add) | ||||||||||
| response = {'status': 'ok', 'query_id': query_id} | ||||||||||
| worker.send(pickle.dumps(response)) | ||||||||||
| self._send(worker, response) | ||||||||||
|
|
||||||||||
| elif method == 'cancel_query': | ||||||||||
| query_id = params.get('query_id') | ||||||||||
| self.sched.cancel(query_id) | ||||||||||
| response = {'status': 'ok'} | ||||||||||
| worker.send(pickle.dumps(response)) | ||||||||||
| self._send(worker, response) | ||||||||||
|
|
||||||||||
| elif method == 'update_last_batch': | ||||||||||
| updates = params.get('updates') | ||||||||||
|
|
||||||||||
| batch_todo = self.sched.update_last_batch(updates) | ||||||||||
|
|
||||||||||
| response = {'status': 'ok', 'batch_todo': batch_todo} | ||||||||||
| # print (batch_todo.query_lengths, batch_todo.query_ids) | ||||||||||
| worker.send(pickle.dumps(response)) | ||||||||||
| self._send(worker, response) | ||||||||||
|
|
||||||||||
| elif method == 'get_inference_context': | ||||||||||
| inference_context = self.sched.get_inference_context() | ||||||||||
| data = { | ||||||||||
| "k_cache":inference_context.k_cache, | ||||||||||
| "v_cache":inference_context.v_cache | ||||||||||
| print("Serializing KVCache with safetensors") | ||||||||||
|
|
||||||||||
| tensors = {} | ||||||||||
| for i, t in enumerate(inference_context.k_cache): | ||||||||||
| tensors[f"k_cache_{i}"] = t | ||||||||||
| for i, t in enumerate(inference_context.v_cache): | ||||||||||
| tensors[f"v_cache_{i}"] = t | ||||||||||
|
|
||||||||||
| tensor_bytes = _serialize_tensors(tensors) | ||||||||||
| k_count = len(inference_context.k_cache) | ||||||||||
| v_count = len(inference_context.v_cache) | ||||||||||
| response = { | ||||||||||
| 'status': 'ok', | ||||||||||
| 'k_cache_count': k_count, | ||||||||||
| 'v_cache_count': v_count, | ||||||||||
| } | ||||||||||
| print(f"Serializing KVCache") | ||||||||||
| data["k_cache"] = [mp.reductions.reduce_tensor(t) for t in data['k_cache']] | ||||||||||
| data["v_cache"] = [mp.reductions.reduce_tensor(t) for t in data['v_cache']] | ||||||||||
| # print(data) | ||||||||||
| response = {'status': 'ok', 'inference_context': data} | ||||||||||
|
|
||||||||||
| worker.send(pickle.dumps(response)) | ||||||||||
| # response['inference_context'].k_cache[0][0, 0, 0, 0, 0] = 1 | ||||||||||
| # print("k_cache update") | ||||||||||
| self._send(worker, response, tensor_bytes) | ||||||||||
|
|
||||||||||
| else: | ||||||||||
| response = {'status': 'error', 'message': 'Unknown method'} | ||||||||||
| worker.send(pickle.dumps(response)) | ||||||||||
| self._send(worker, response) | ||||||||||
|
|
||||||||||
| except Exception as e: | ||||||||||
| response = {'status': 'error', 'message': str(e)} | ||||||||||
| worker.send(pickle.dumps(response)) | ||||||||||
| try: | ||||||||||
| response = {'status': 'error', 'message': str(e)} | ||||||||||
| self._send(worker, response) | ||||||||||
| except Exception: | ||||||||||
| pass | ||||||||||
|
|
||||||||||
| def start_rpc_service(self): | ||||||||||
| try: | ||||||||||
|
|
@@ -125,70 +185,93 @@ def start_server(settings, main_args): | |||||||||
| server.start_rpc_service() | ||||||||||
|
|
||||||||||
|
|
||||||||||
| # Add async client for webserver | ||||||||||
| class SchedulerClient: | ||||||||||
| def __init__(self, sched_port): | ||||||||||
| address=f'tcp://localhost:{sched_port}' | ||||||||||
| address = f'tcp://localhost:{sched_port}' | ||||||||||
| self.address = address | ||||||||||
| self.context = zmq.Context() | ||||||||||
| self.socket = self.context.socket(zmq.REQ) | ||||||||||
| self.socket.connect(self.address) | ||||||||||
| print(f"Connected to server at {self.address}") | ||||||||||
|
|
||||||||||
| def __del__(self): | ||||||||||
| self.socket.close() | ||||||||||
| self.context.term() | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _send(self, request: dict): | ||||||||||
| payload = _serialize_msg(request) | ||||||||||
| sig = _sign(payload) | ||||||||||
| self.socket.send_multipart([sig, payload, b""]) | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For protocol consistency, the client should sign the message including the empty tensor data frame, matching the server's expectation that the signature covers all data parts.
Suggested change
|
||||||||||
|
|
||||||||||
| def _recv(self) -> tuple: | ||||||||||
| parts = self.socket.recv_multipart() | ||||||||||
| if len(parts) != 3: | ||||||||||
| raise ValueError("Invalid message frame") | ||||||||||
| sig, payload, tensor_data = parts | ||||||||||
| if not _verify(payload + tensor_data, sig): | ||||||||||
| raise ValueError("HMAC verification failed") | ||||||||||
| return _deserialize_msg(payload), tensor_data | ||||||||||
|
|
||||||||||
| def send_request(self, method, params=None): | ||||||||||
| if params is None: | ||||||||||
| params = {} | ||||||||||
| request = { | ||||||||||
| 'method': method, | ||||||||||
| 'params': params | ||||||||||
| } | ||||||||||
| # print(f'send request {request}') | ||||||||||
| self.socket.send(pickle.dumps(request)) | ||||||||||
| response = self.socket.recv() | ||||||||||
| # print(response) | ||||||||||
| response = pickle.loads(response) | ||||||||||
| self._send(request) | ||||||||||
| response, _ = self._recv() | ||||||||||
| if response.get('status') == 'ok': | ||||||||||
| return response | ||||||||||
| else: | ||||||||||
| raise Exception(f"Error from server: {response.get('message')}") | ||||||||||
|
|
||||||||||
| def add_query(self, query): | ||||||||||
| response = self.send_request('add_query', {'query': query}) | ||||||||||
| return response.get('query_id') | ||||||||||
|
|
||||||||||
| def cancel_query(self, query_id): | ||||||||||
| self.send_request('cancel_query', {'query_id': query_id}) | ||||||||||
|
|
||||||||||
| def update_last_batch(self, updates): | ||||||||||
| response = self.send_request('update_last_batch', {'updates': updates}) | ||||||||||
| # print(f"update_last_batch response {response}") | ||||||||||
| return response.get('batch_todo') | ||||||||||
|
|
||||||||||
| def rebuild_inferece_context(self,response): | ||||||||||
| data = response.get('inference_context') | ||||||||||
|
|
||||||||||
| def rebuild_inferece_context(self, response=None, tensor_data=None): | ||||||||||
| if tensor_data is None: | ||||||||||
| raise ValueError("No tensor data received") | ||||||||||
| tensors = _deserialize_tensors(tensor_data) | ||||||||||
|
|
||||||||||
| inference_context = sched_ext.InferenceContext() | ||||||||||
| print('Rebuilding kvcache') | ||||||||||
| inference_context.k_cache = [fn(*args) for fn,args in data['k_cache']] | ||||||||||
| inference_context.v_cache = [fn(*args) for fn,args in data['v_cache']] | ||||||||||
| print('Rebuilding kvcache from safetensors') | ||||||||||
|
|
||||||||||
| k_count = response.get('k_cache_count', 0) | ||||||||||
| v_count = response.get('v_cache_count', 0) | ||||||||||
| inference_context.k_cache = [tensors[f"k_cache_{i}"] for i in range(k_count)] | ||||||||||
| inference_context.v_cache = [tensors[f"v_cache_{i}"] for i in range(v_count)] | ||||||||||
| return inference_context | ||||||||||
|
|
||||||||||
| def get_inference_context_raw(self): | ||||||||||
| response = self.send_request('get_inference_context') | ||||||||||
| return response | ||||||||||
|
|
||||||||||
| request = { | ||||||||||
| 'method': 'get_inference_context', | ||||||||||
| 'params': {} | ||||||||||
| } | ||||||||||
| self._send(request) | ||||||||||
| response, tensor_data = self._recv() | ||||||||||
| if response.get('status') == 'ok': | ||||||||||
| return response, tensor_data | ||||||||||
| else: | ||||||||||
| raise Exception(f"Error from server: {response.get('message')}") | ||||||||||
|
|
||||||||||
|
|
||||||||||
| if __name__ == '__main__': | ||||||||||
| parser = argparse.ArgumentParser() | ||||||||||
| parser.add_argument("--config", type=str, required=True) | ||||||||||
| args = parser.parse_args() | ||||||||||
| with open(args.config, "rb") as f: | ||||||||||
| main_args = pickle.load(f) | ||||||||||
| if main_args.architectures == "Qwen2MoeForCausalLM": | ||||||||||
| with open(args.config, "r") as f: | ||||||||||
| main_args = json.load(f) | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a mismatch in configuration file formats. While this file has been updated to use |
||||||||||
| main_args = argparse.Namespace(**main_args) | ||||||||||
| if main_args.architectures == "Qwen2MoeForCausalLM": | ||||||||||
| settings = create_sched_settings_qwen2moe(main_args) | ||||||||||
| elif main_args.architectures == "Qwen3MoeForCausalLM": | ||||||||||
| settings = create_sched_settings_qwen3moe(main_args) | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generating a random
`_RPC_SECRET` at import time using `secrets.token_bytes(32)` will cause HMAC verification to fail if the client and server run in separate processes (which is the case here, as `balance_serve.py` spawns `sched_rpc.py` via `subprocess.Popen`). Each process will generate its own unique secret. To fix this, ensure the secret is either provided via the `KTRANSFORMERS_RPC_SECRET` environment variable or that the parent process generates it and explicitly sets it in the environment before spawning the child process.