Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 138 additions & 55 deletions kt-sft/ktransformers/server/balance_serve/sched_rpc.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,59 @@
from datetime import datetime
import hashlib
import hmac
import io
import json
import os
import secrets
from typing import Optional
import zmq
import pickle
import threading
import torch
import torch.multiprocessing as mp
import sys
current_file_path = os.path.abspath(__file__)
# sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", ".."))
import pickle
import argparse
from safetensors.torch import save, load as st_load
from ktransformers.server.balance_serve.settings import sched_ext, create_sched_settings, create_sched_settings_qwen2moe, create_sched_settings_qwen3moe

# HMAC key for message authentication between server and client.
# Set KTRANSFORMERS_RPC_SECRET in the environment, or a random key is
# generated at import time (single-process / inherited-by-fork use).
# NOTE(review): when the client and server live in *separate* processes
# (e.g. the scheduler is spawned via subprocess.Popen), each process
# generates its own random key and every HMAC check fails. The parent
# must generate the secret and export KTRANSFORMERS_RPC_SECRET into the
# child's environment before spawning — confirm the spawning code does so.
_RPC_SECRET = os.environ.get(
    "KTRANSFORMERS_RPC_SECRET", ""
).encode() or secrets.token_bytes(32)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Generating a random _RPC_SECRET at import time using secrets.token_bytes(32) will cause HMAC verification to fail if the client and server run in separate processes (which is the case here, as balance_serve.py spawns sched_rpc.py via subprocess.Popen). Each process will generate its own unique secret. To fix this, ensure the secret is either provided via the KTRANSFORMERS_RPC_SECRET environment variable or that the parent process generates it and explicitly sets it in the environment before spawning the child process.



def _sign(data: bytes) -> bytes:
    """Return the HMAC-SHA256 tag of *data* under the shared RPC secret."""
    mac = hmac.new(_RPC_SECRET, msg=data, digestmod=hashlib.sha256)
    return mac.digest()


def _verify(data: bytes, sig: bytes) -> bool:
    """Constant-time check that *sig* is a valid tag for *data*."""
    expected = _sign(data)
    return hmac.compare_digest(expected, sig)


def _serialize_msg(obj: dict) -> bytes:
    """Serialize an RPC message to JSON bytes. Non-JSON-serializable values
    are dropped with a placeholder so the frame always round-trips.

    NOTE(review): the RPC protocol passes sched_ext objects (QueryAdd,
    QueryUpdate, BatchQueryTodo) through this path; they are not
    JSON-serializable and will be silently replaced by the placeholder
    string, so the peer receives a string where it expects an object.
    Callers must restrict payloads to JSON-friendly values, or these types
    need an explicit to_dict/from_dict scheme — confirm before relying on
    this serializer for those messages.
    """
    def _default(o):
        # Keeps json.dumps from raising; the original value is lost.
        return f"<non-serializable:{type(o).__name__}>"
    return json.dumps(obj, default=_default).encode()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The switch from pickle to json for RPC message serialization is a critical breaking change. The _default function drops non-serializable objects, but the RPC protocol relies on passing complex objects like QueryAdd, QueryUpdate, and BatchQueryTodo (from the sched_ext extension). These objects are not JSON-serializable and will be replaced by placeholder strings (e.g., "<non-serializable:QueryAdd>"), causing the scheduler to fail when it receives these strings instead of the expected objects. You must implement a proper to_dict/from_dict mechanism for these types or use a serialization format that supports them.



def _deserialize_msg(data: bytes) -> dict:
return json.loads(data)


def _serialize_tensors(tensor_dict: dict) -> bytes:
    """Serialize a flat {name: tensor} dict with safetensors."""
    buffer = io.BytesIO()
    save(tensor_dict, buffer)
    serialized = buffer.getvalue()
    return serialized


def _deserialize_tensors(data: bytes) -> dict:
    """Inverse of :func:`_serialize_tensors`: safetensors bytes -> {name: tensor}."""
    tensors = st_load(data)
    return tensors


if mp.get_start_method(allow_none=True) is None:
Expand All @@ -24,11 +66,13 @@
class SchedulerServer:
def __init__(self, settings, main_args):
self.sched = sched_ext.create_scheduler(settings)

self.context = zmq.Context()
self.frontend = self.context.socket(zmq.ROUTER)
print(f"sched zmq rpc server on port {main_args.sched_port}")
self.frontend.bind(f"tcp://*:{main_args.sched_port}")

bind_addr = getattr(main_args, 'sched_bind', '127.0.0.1')
print(f"sched zmq rpc server on {bind_addr}:{main_args.sched_port}")
self.frontend.bind(f"tcp://{bind_addr}:{main_args.sched_port}")

self.backend = self.context.socket(zmq.DEALER)
self.backend.bind("inproc://backend")
Expand All @@ -42,62 +86,78 @@ def stop_scheduler(self):
def start_proxy(self):
    # Blocks forever, shuttling frames between the ROUTER frontend and the
    # in-proc DEALER backend (standard zmq broker pattern).
    zmq.proxy(self.frontend, self.backend)

def _send(self, worker, response: dict, tensor_data: bytes = b""):
    """Sign and send *response* plus optional raw tensor bytes as a
    three-frame multipart message: [signature, json-payload, tensor-bytes].
    The signature covers payload AND tensor bytes."""
    body = _serialize_msg(response)
    tag = _sign(body + tensor_data)
    frames = [tag, body, tensor_data]
    worker.send_multipart(frames)

def _recv(self, worker) -> dict:
    """Receive one signed multipart request and return the decoded payload.

    Frame layout is [signature, json-payload, tensor-bytes]; the signature
    covers payload + tensor bytes, mirroring ``_send`` on both sides.

    Raises:
        ValueError: on a malformed frame or a failed HMAC check.
    """
    parts = worker.recv_multipart()
    if len(parts) != 3:
        raise ValueError("Invalid message frame")
    sig, payload, tensor_data = parts
    # Verify over payload AND the tensor frame, matching what _send signs:
    # otherwise a client sending non-empty tensor data would always fail
    # verification, and tampering with the tensor frame would go undetected.
    if not _verify(payload + tensor_data, sig):
        raise ValueError("HMAC verification failed")
    return _deserialize_msg(payload)

def worker_routine(self):
    """REP worker loop: receive signed RPC requests, dispatch them to the
    scheduler, and send a signed reply. Runs forever on its own thread.

    Supported methods: add_query, cancel_query, update_last_batch,
    get_inference_context. Unknown methods get an error response.
    """
    worker = self.context.socket(zmq.REP)
    worker.connect("inproc://backend")
    while True:
        try:
            data = self._recv(worker)

            method = data.get('method')
            params = data.get('params', {})

            if method == 'add_query':
                query_add = params.get('query')
                query_id = self.sched.add_query(query_add)
                response = {'status': 'ok', 'query_id': query_id}
                self._send(worker, response)

            elif method == 'cancel_query':
                query_id = params.get('query_id')
                self.sched.cancel(query_id)
                response = {'status': 'ok'}
                self._send(worker, response)

            elif method == 'update_last_batch':
                updates = params.get('updates')
                batch_todo = self.sched.update_last_batch(updates)
                response = {'status': 'ok', 'batch_todo': batch_todo}
                self._send(worker, response)

            elif method == 'get_inference_context':
                inference_context = self.sched.get_inference_context()
                print("Serializing KVCache with safetensors")

                # Flatten the per-layer cache lists into one {name: tensor}
                # dict; the counts in the response let the client rebuild
                # the original k_cache / v_cache lists in order.
                tensors = {}
                for i, t in enumerate(inference_context.k_cache):
                    tensors[f"k_cache_{i}"] = t
                for i, t in enumerate(inference_context.v_cache):
                    tensors[f"v_cache_{i}"] = t

                tensor_bytes = _serialize_tensors(tensors)
                k_count = len(inference_context.k_cache)
                v_count = len(inference_context.v_cache)
                response = {
                    'status': 'ok',
                    'k_cache_count': k_count,
                    'v_cache_count': v_count,
                }
                self._send(worker, response, tensor_bytes)

            else:
                response = {'status': 'error', 'message': 'Unknown method'}
                self._send(worker, response)

        except Exception as e:
            # Best-effort error reply; if even the reply fails (e.g. the
            # REP socket is in a bad state), swallow it so the worker loop
            # survives and can serve the next request.
            try:
                response = {'status': 'error', 'message': str(e)}
                self._send(worker, response)
            except Exception:
                pass

def start_rpc_service(self):
try:
Expand Down Expand Up @@ -125,70 +185,93 @@ def start_server(settings, main_args):
server.start_rpc_service()


# Add async client for webserver
class SchedulerClient:
    """Synchronous ZMQ REQ client for the scheduler RPC server.

    Wire format (both directions): a three-frame multipart message
    [hmac-sha256 signature, json payload, raw tensor bytes], where the
    signature covers payload + tensor bytes.
    """

    def __init__(self, sched_port):
        address = f'tcp://localhost:{sched_port}'
        self.address = address
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect(self.address)
        print(f"Connected to server at {self.address}")

    def __del__(self):
        # Best-effort teardown; REQ socket first, then the context.
        self.socket.close()
        self.context.term()

    def _send(self, request: dict, tensor_data: bytes = b""):
        """Sign and send one request as [signature, payload, tensor-bytes].

        The signature covers payload + tensor bytes so it matches the
        server's verification, which checks both frames.
        """
        payload = _serialize_msg(request)
        sig = _sign(payload + tensor_data)
        self.socket.send_multipart([sig, payload, tensor_data])

    def _recv(self) -> tuple:
        """Receive one signed reply; return ``(message_dict, tensor_bytes)``.

        Raises:
            ValueError: on a malformed frame or a failed HMAC check.
        """
        parts = self.socket.recv_multipart()
        if len(parts) != 3:
            raise ValueError("Invalid message frame")
        sig, payload, tensor_data = parts
        if not _verify(payload + tensor_data, sig):
            raise ValueError("HMAC verification failed")
        return _deserialize_msg(payload), tensor_data

    def send_request(self, method, params=None):
        """Round-trip one RPC call and return the response dict on success.

        Raises:
            Exception: with the server-reported message on an error status.
        """
        if params is None:
            params = {}
        request = {
            'method': method,
            'params': params
        }
        self._send(request)
        response, _ = self._recv()
        if response.get('status') == 'ok':
            return response
        raise Exception(f"Error from server: {response.get('message')}")

    def add_query(self, query):
        """Register a query with the scheduler; return its assigned id."""
        response = self.send_request('add_query', {'query': query})
        return response.get('query_id')

    def cancel_query(self, query_id):
        """Cancel a previously added query by id."""
        self.send_request('cancel_query', {'query_id': query_id})

    def update_last_batch(self, updates):
        """Report results for the last batch; return the next batch todo."""
        response = self.send_request('update_last_batch', {'updates': updates})
        return response.get('batch_todo')

    # NOTE: the method name keeps the historical misspelling ("inferece")
    # because existing callers reference it by this exact name.
    def rebuild_inferece_context(self, response=None, tensor_data=None):
        """Rebuild an InferenceContext from a get_inference_context reply.

        Args:
            response: the response dict carrying k_cache_count/v_cache_count.
            tensor_data: raw safetensors bytes from the third reply frame.

        Raises:
            ValueError: if no tensor data was received.
        """
        if tensor_data is None:
            raise ValueError("No tensor data received")
        tensors = _deserialize_tensors(tensor_data)

        inference_context = sched_ext.InferenceContext()
        print('Rebuilding kvcache from safetensors')

        k_count = response.get('k_cache_count', 0)
        v_count = response.get('v_cache_count', 0)
        inference_context.k_cache = [tensors[f"k_cache_{i}"] for i in range(k_count)]
        inference_context.v_cache = [tensors[f"v_cache_{i}"] for i in range(v_count)]
        return inference_context

    def get_inference_context_raw(self):
        """Fetch the KV cache as ``(response_dict, tensor_bytes)``.

        Raises:
            Exception: with the server-reported message on an error status.
        """
        request = {
            'method': 'get_inference_context',
            'params': {}
        }
        self._send(request)
        response, tensor_data = self._recv()
        if response.get('status') == 'ok':
            return response, tensor_data
        raise Exception(f"Error from server: {response.get('message')}")


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
args = parser.parse_args()
with open(args.config, "rb") as f:
main_args = pickle.load(f)
if main_args.architectures == "Qwen2MoeForCausalLM":
with open(args.config, "r") as f:
main_args = json.load(f)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a mismatch in configuration file formats. While this file has been updated to use json.load, the calling code in kt-sft/ktransformers/server/backend/interfaces/balance_serve.py (line 322) still uses pickle.dump(args, temp_file). This will result in a json.decoder.JSONDecodeError when the scheduler attempts to start. Both sides must be updated to use JSON.

main_args = argparse.Namespace(**main_args)
if main_args.architectures == "Qwen2MoeForCausalLM":
settings = create_sched_settings_qwen2moe(main_args)
elif main_args.architectures == "Qwen3MoeForCausalLM":
settings = create_sched_settings_qwen3moe(main_args)
Expand Down