diff --git a/kt-sft/ktransformers/server/balance_serve/sched_rpc.py b/kt-sft/ktransformers/server/balance_serve/sched_rpc.py index ccc30af58..642f3d986 100644 --- a/kt-sft/ktransformers/server/balance_serve/sched_rpc.py +++ b/kt-sft/ktransformers/server/balance_serve/sched_rpc.py @@ -1,18 +1,100 @@ from datetime import datetime +import hashlib +import hmac +import io import os +import pickle +import secrets from typing import Optional import zmq -import pickle import threading +import torch import torch.multiprocessing as mp import sys current_file_path = os.path.abspath(__file__) -# sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..")) -import pickle import argparse +from safetensors.torch import save, load as st_load from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe +# --------------------------------------------------------------------------- +# HMAC authentication +# --------------------------------------------------------------------------- +# The parent process (balance_serve.py) MUST generate a secret and set +# KTRANSFORMERS_RPC_SECRET in the environment before spawning this process. +# If unset, a random key is generated (works only for single-process or +# fork-inherited setups). 
def _get_rpc_secret() -> bytes:
    """Return the shared HMAC key used to authenticate scheduler RPC frames.

    The key is the UTF-8 encoding of the ``KTRANSFORMERS_RPC_SECRET``
    environment variable.  When the variable is unset or empty, a random
    value is generated and exported so that processes inheriting this
    environment derive the same key.

    Returns:
        The key bytes handed to ``hmac.new``.
    """
    env = os.environ.get("KTRANSFORMERS_RPC_SECRET", "")
    if not env:
        # Export the value and return *its* encoding.  Returning the raw
        # random bytes while exporting their hex form would give this process
        # a different key from every process that reads the variable.
        env = secrets.token_hex(32)
        os.environ["KTRANSFORMERS_RPC_SECRET"] = env
    return env.encode()


_RPC_SECRET = _get_rpc_secret()


def _sign(data: bytes) -> bytes:
    """Return the HMAC-SHA256 digest of ``data`` under the module secret."""
    return hmac.new(_RPC_SECRET, data, hashlib.sha256).digest()


def _verify(data: bytes, sig: bytes) -> bool:
    """Return True when ``sig`` is the valid signature of ``data``.

    The comparison uses ``hmac.compare_digest`` so it does not leak timing
    information about where the two digests differ.
    """
    return hmac.compare_digest(_sign(data), sig)


# ---------------------------------------------------------------------------
# Restricted unpickler - only allow known safe types
# ---------------------------------------------------------------------------
# pickle is still required for the C++ scheduler extension objects
# (QueryAdd, QueryUpdate, BatchQueryTodo, etc.) that are not
# JSON-serializable. This restricted unpickler ensures only explicitly
# allowed types can be deserialized, preventing arbitrary code execution.

_ALLOWED_MODULES = {
    "builtins": {"dict", "list", "tuple", "set", "frozenset", "int", "float",
                 "str", "bytes", "bool", "NoneType", "complex", "range",
                 "slice", "type"},
    "collections": {"OrderedDict", "defaultdict"},
    "datetime": {"datetime", "timedelta", "date", "time"},
    "ktransformers.server.balance_serve.settings": {"*"},
}


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that resolves globals only from an allow-list.

    A global is accepted when its module is listed in ``_ALLOWED_MODULES``
    with that name (or with the ``"*"`` wildcard), or when the module is the
    ``ktransformers`` package or one of its submodules.  Anything else raises
    ``pickle.UnpicklingError``.
    """

    @staticmethod
    def _passes_through_module(module: str, name: str) -> bool:
        """Return True if resolving ``name`` inside ``module`` hits a module.

        Pickle protocol 4+ resolves dotted names attribute by attribute, so a
        name such as ``"os.system"`` looked up in an allowed module that
        merely imports ``os`` would escape the allow-list.
        """
        import importlib

        owner = importlib.import_module(module)
        for part in name.split("."):
            owner = getattr(owner, part, None)
            if owner is None:
                # Unresolvable: let the base class raise its usual error.
                return False
            if isinstance(owner, type(os)):
                return True
        return False

    def find_class(self, module: str, name: str):
        """Resolve ``module.name`` if it is allow-listed.

        Raises:
            pickle.UnpicklingError: the global is not allow-listed, or its
                dotted path reaches a module object.
        """
        allowed = _ALLOWED_MODULES.get(module)
        listed = allowed is not None and ("*" in allowed or name in allowed)
        # Match the package itself or a real submodule; a bare startswith()
        # would also admit unrelated modules such as "ktransformers_evil".
        in_package = (module == "ktransformers"
                      or module.startswith("ktransformers."))
        if not (listed or in_package):
            raise pickle.UnpicklingError(
                f"Restricted unpickler refused {module}.{name}"
            )
        if self._passes_through_module(module, name):
            raise pickle.UnpicklingError(
                f"Restricted unpickler refused {module}.{name}"
            )
        # NOTE(review): the package-wide rule still admits every callable that
        # a ktransformers module defines or re-exports under a plain name;
        # consider replacing it with an explicit list of message types.
        return super().find_class(module, name)


def _safe_loads(data: bytes):
    """Deserialize ``data`` with ``RestrictedUnpickler`` and return the object."""
    return RestrictedUnpickler(io.BytesIO(data)).load()


# ---------------------------------------------------------------------------
# Safetensors helpers for KV cache
# ---------------------------------------------------------------------------
# NOTE(review): SOURCE for this region is a whitespace-collapsed diff.  The
# code below is the post-patch text of the visible hunks, re-indented.  Places
# the diff does not show are marked "[not shown in diff]" and must be taken
# from the real file.

def _serialize_tensors(tensor_dict: dict) -> bytes:
    """Encode a ``name -> torch.Tensor`` mapping as safetensors bytes."""
    # safetensors.torch.save() returns the encoded bytes.  Its second
    # positional parameter is the optional metadata mapping, not a file
    # object, so nothing is ever written to a buffer passed there.
    return save(tensor_dict)


def _deserialize_tensors(data: bytes) -> dict:
    """Decode safetensors bytes back into a ``name -> torch.Tensor`` mapping."""
    return st_load(data)


# ---------------------------------------------------------------------------

if mp.get_start_method(allow_none=True) is None:
    print('set start method')
    # [not shown in diff] the rest of this statement lies between hunks.


class SchedulerServer:
    """ZeroMQ RPC front end for the scheduler created by ``sched_ext``.

    Every message is three frames, ``[signature, pickled payload, tensor
    bytes]``; the signature is the HMAC of the last two frames concatenated.
    """

    def __init__(self, settings, main_args):
        """Create the scheduler and bind the ROUTER/DEALER socket pair.

        Args:
            settings: passed to ``sched_ext.create_scheduler``.
            main_args: must provide ``sched_port``; the optional
                ``sched_bind`` selects the listen address and defaults to
                the loopback interface.
        """
        self.sched = sched_ext.create_scheduler(settings)

        self.context = zmq.Context()
        self.frontend = self.context.socket(zmq.ROUTER)

        bind_addr = getattr(main_args, 'sched_bind', '127.0.0.1')
        print(f"sched zmq rpc server on {bind_addr}:{main_args.sched_port}")
        self.frontend.bind(f"tcp://{bind_addr}:{main_args.sched_port}")

        self.backend = self.context.socket(zmq.DEALER)
        self.backend.bind("inproc://backend")

    # [not shown in diff] members between __init__ and start_proxy, including
    # stop_scheduler(self), named by the hunk header.

    def start_proxy(self):
        """Forward frames between the public and the in-process socket."""
        zmq.proxy(self.frontend, self.backend)

    def _send(self, worker, response: dict, tensor_data: bytes = b""):
        """Pickle ``response``, sign it with ``tensor_data``, and send both."""
        payload = pickle.dumps(response)
        sig = _sign(payload + tensor_data)
        worker.send_multipart([sig, payload, tensor_data])

    def _recv(self, worker) -> dict:
        """Receive one request and return its authenticated, decoded payload.

        Raises:
            ValueError: the message is not three frames, or its signature
                does not verify.
        """
        parts = worker.recv_multipart()
        if len(parts) != 3:
            raise ValueError("Invalid message frame")
        sig, payload, tensor_data = parts
        # Verify before unpickling so unauthenticated bytes never reach pickle.
        if not _verify(payload + tensor_data, sig):
            raise ValueError("HMAC verification failed - unauthorized message")
        return _safe_loads(payload)

    def _dispatch(self, method, params):
        """Run one RPC method and return ``(response dict, tensor bytes)``."""
        if method == 'add_query':
            query_id = self.sched.add_query(params.get('query'))
            return {'status': 'ok', 'query_id': query_id}, b""

        if method == 'cancel_query':
            self.sched.cancel(params.get('query_id'))
            return {'status': 'ok'}, b""

        if method == 'update_last_batch':
            # The scheduler call and the reply both have to happen here;
            # without them the branch would send an undefined (or stale)
            # response and the scheduler would never advance.
            batch_todo = self.sched.update_last_batch(params.get('updates'))
            return {'status': 'ok', 'batch_todo': batch_todo}, b""

        if method == 'get_inference_context':
            inference_context = self.sched.get_inference_context()
            print("Serializing KVCache with safetensors")
            # NOTE(review): before this patch the reply carried
            # mp.reductions.reduce_tensor() handles; it now carries the
            # tensor contents, so the client rebuilds separate copies.
            # Confirm the consumer does not need to share memory with the
            # scheduler.
            tensors = {
                f"k_cache_{i}": t
                for i, t in enumerate(inference_context.k_cache)
            }
            tensors.update(
                (f"v_cache_{i}", t)
                for i, t in enumerate(inference_context.v_cache)
            )
            response = {
                'status': 'ok',
                'k_cache_count': len(inference_context.k_cache),
                'v_cache_count': len(inference_context.v_cache),
            }
            return response, _serialize_tensors(tensors)

        return {'status': 'error', 'message': 'Unknown method'}, b""

    def worker_routine(self):
        """Serve requests from the in-process backend forever.

        Any failure while receiving or handling a request is reported to the
        caller as ``{'status': 'error', 'message': ...}``.
        """
        worker = self.context.socket(zmq.REP)
        worker.connect("inproc://backend")
        while True:
            try:
                data = self._recv(worker)
                response, tensor_bytes = self._dispatch(
                    data.get('method'), data.get('params', {})
                )
                self._send(worker, response, tensor_bytes)
            except Exception as e:
                try:
                    self._send(worker, {'status': 'error', 'message': str(e)})
                except Exception as send_err:
                    # Best effort only, but leave a trace instead of
                    # dropping the failure silently.
                    print(f"sched rpc: failed to send error reply: {send_err}")

    # [not shown in diff] start_rpc_service(self) begins here with a `try:`
    # block; its body is outside the visible hunks.


# [not shown in diff] start_server(settings, main_args): only its final
# statement, `server.start_rpc_service()`, is visible.


class SchedulerClient:
    """Blocking REQ-socket client for ``SchedulerServer``."""

    def __init__(self, sched_port):
        """Connect to the scheduler RPC server on ``localhost:sched_port``."""
        address = f'tcp://localhost:{sched_port}'
        self.address = address
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect(self.address)
        print(f"Connected to server at {self.address}")

    def __del__(self):
        self.socket.close()
        self.context.term()

    def _send(self, request: dict):
        """Pickle and sign ``request`` and send it with an empty tensor frame."""
        payload = pickle.dumps(request)
        sig = _sign(payload + b"")
        self.socket.send_multipart([sig, payload, b""])

    def _recv(self) -> tuple:
        """Receive one reply and return ``(decoded payload, tensor bytes)``.

        Raises:
            ValueError: the message is not three frames, or its signature
                does not verify.
        """
        parts = self.socket.recv_multipart()
        if len(parts) != 3:
            raise ValueError("Invalid message frame")
        sig, payload, tensor_data = parts
        if not _verify(payload + tensor_data, sig):
            raise ValueError("HMAC verification failed - unauthorized message")
        return _safe_loads(payload), tensor_data

    def _call(self, method, params=None):
        """Perform one round trip and return ``(response, tensor bytes)``.

        Raises:
            Exception: the server replied with a status other than ``'ok'``.
        """
        if params is None:
            params = {}
        # The `request = {` line falls in a one-line gap between two hunks;
        # the name is taken from its use in the following visible lines.
        request = {
            'method': method,
            'params': params
        }
        self._send(request)
        response, tensor_data = self._recv()
        if response.get('status') != 'ok':
            raise Exception(f"Error from server: {response.get('message')}")
        return response, tensor_data

    def send_request(self, method, params=None):
        """Call ``method`` on the server and return its response dict."""
        response, _ = self._call(method, params)
        return response

    def add_query(self, query):
        """Submit ``query`` and return the id assigned by the scheduler."""
        response = self.send_request('add_query', {'query': query})
        return response.get('query_id')

    def cancel_query(self, query_id):
        """Ask the scheduler to cancel ``query_id``."""
        self.send_request('cancel_query', {'query_id': query_id})

    def update_last_batch(self, updates):
        """Report ``updates`` and return the ``batch_todo`` from the reply."""
        response = self.send_request('update_last_batch', {'updates': updates})
        return response.get('batch_todo')

    def rebuild_inferece_context(self, response=None, tensor_data=None):
        """Rebuild an ``InferenceContext`` from a ``get_inference_context`` reply.

        Args:
            response: the reply dict, or the whole ``(response, tensor_data)``
                tuple returned by ``get_inference_context_raw`` so that the
                single-argument call keeps working.
            tensor_data: safetensors bytes holding ``k_cache_<i>`` and
                ``v_cache_<i>`` entries.

        Raises:
            ValueError: the tensor bytes or the reply dict is missing.
        """
        if tensor_data is None and isinstance(response, tuple) and len(response) == 2:
            response, tensor_data = response
        if tensor_data is None:
            raise ValueError("No tensor data received")
        if response is None:
            raise ValueError("No response metadata received")
        tensors = _deserialize_tensors(tensor_data)

        inference_context = sched_ext.InferenceContext()
        print('Rebuilding kvcache from safetensors')

        k_count = response.get('k_cache_count', 0)
        v_count = response.get('v_cache_count', 0)
        inference_context.k_cache = [tensors[f"k_cache_{i}"] for i in range(k_count)]
        inference_context.v_cache = [tensors[f"v_cache_{i}"] for i in range(v_count)]
        return inference_context

    def get_inference_context_raw(self):
        """Fetch the inference context as ``(response dict, tensor bytes)``."""
        return self._call('get_inference_context')


if __name__ == '__main__':
    # The parser construction falls in a one-line gap between two hunks; it
    # is inferred from the use of `parser` and the `argparse` import.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()
    with open(args.config, "rb") as f:
        # NOTE(review): the restricted unpickler refuses any global outside
        # its allow-list; confirm the pickled config object only uses
        # allow-listed types.
        main_args = _safe_loads(f.read())
    if main_args.architectures == "Qwen2MoeForCausalLM":
        settings = create_sched_settings_qwen2moe(main_args)
    elif main_args.architectures == "Qwen3MoeForCausalLM":
        settings = create_sched_settings_qwen3moe(main_args)
    # [not shown in diff] the rest of the script is past the end of SOURCE.