Fix multi-node KV offloading state desynchronization and JAX dispatch…#2983

Draft
amitkumar307d wants to merge 1 commit into vllm-project:main from amitkumar307d:multinode-kv-offloading
@amitkumar307d amitkumar307d commented Jun 24, 2026

Contributor

Description

This PR fixes state desynchronization, JAX dispatch deadlocks, and resource leaks in Multi-Node KV Offloading, ensuring stability during high-concurrency benchmarks.

Context & Problem:
Previously, running high-concurrency benchmarks (like prefix_repetition) on large models (e.g., Qwen3-Coder 480B) with multi-node KV offloading enabled led to fatal crashes (AssertionError) and runtime hangs.

  1. State Desynchronization: In multi-node setups, completion signals (finished chunk IDs) from worker nodes were being discarded in the stats aggregation phase, causing the Scheduler to miss confirmations and desync buffer accounting.
  2. JAX Dispatch Deadlocks: Running jax.device_put (D2H) in background threads caused non-deterministic dispatch orders across nodes, leading to low-level TPU runtime halts.
  3. Resource Leaks: Duplicate hashes in a batch resulted in multiple physical CPU chunks being allocated for the same content, wasting CPU RAM.
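The third problem can be reproduced in miniature. The sketch below is not the connector's code: `CPUChunk`, `allocate_naive`, and `allocate_deduped` are hypothetical stand-ins, and it only assumes that a batch is a list of content hashes that may contain repeats.

```python
from dataclasses import dataclass
from itertools import count


@dataclass(frozen=True)
class CPUChunk:
    """Hypothetical stand-in for one physical CPU chunk."""
    chunk_id: int


_ids = count()


def allocate_naive(chunk_hashes):
    # One allocation per list entry: a hash that appears twice
    # in the batch ends up with two physical chunks.
    return [CPUChunk(next(_ids)) for _ in chunk_hashes]


def allocate_deduped(chunk_hashes):
    # dict.fromkeys drops repeats while keeping first-seen order,
    # so each unique hash gets exactly one chunk.
    unique = dict.fromkeys(chunk_hashes)
    return {h: CPUChunk(next(_ids)) for h in unique}


batch = ["h1", "h2", "h1", "h3", "h2"]
naive = allocate_naive(batch)
deduped = allocate_deduped(batch)
print(len(naive), len(deduped))
```

With this batch the naive path allocates five chunks for three distinct hashes, which is the kind of waste described in item 3.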

Solution

  • Robust Distributed State Synchronization:
    • Implemented KVOffloadConnectorStats.aggregate using collections.Counter and the union (|) operator to correctly merge completion signals across all worker nodes without artificially multiplying chunk counts.
    • Fixed KVOffloadConnectorStats.num_finished_blocks to accurately count total finished chunks instead of just request entries.
  • Deterministic JAX Dispatch Order: Moved jax.device_put (D2H dispatch) from background threads to the main thread in TPUOffloadConnectorWorker.start_save_kv (and batched version) to ensure globally aligned dispatch order across all TPU nodes.
  • Resource Leak Prevention: Added deduplication of chunk_hashes in OffloadManager.allocate_for_save to ensure only one CPUChunk is allocated per unique content hash in a batch.
  • Resilient Bookkeeping Cleanup: Relaxed strict assertions and replaced remove() with discard() in update_connector_output to handle redundant or late completion signals gracefully without crashing the engine.
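The aggregation and cleanup changes above can be sketched roughly as follows. This is an illustration under assumptions, not the PR's implementation: it assumes each worker reports a mapping from request ID to a list of finished chunk IDs, and `aggregate` and `num_finished_chunks` here are free functions with hypothetical signatures rather than the real `KVOffloadConnectorStats` methods.

```python
from collections import Counter


def aggregate(per_worker_reports):
    """Merge {req_id: [chunk_id, ...]} reports from several workers."""
    merged = {}
    for report in per_worker_reports:
        for req_id, chunk_ids in report.items():
            # Counter | Counter keeps the maximum count per key, so a
            # chunk reported by every worker is counted once, not once
            # per worker (which is what + would do).
            merged[req_id] = merged.get(req_id, Counter()) | Counter(chunk_ids)
    return {req_id: sorted(c.elements()) for req_id, c in merged.items()}


def num_finished_chunks(stats):
    # Count chunks across all requests, not the number of request entries.
    return sum(len(chunk_ids) for chunk_ids in stats.values())


worker0 = {"req-a": [1, 2], "req-b": [7]}
worker1 = {"req-a": [1, 2, 3]}
merged = aggregate([worker0, worker1])
total = num_finished_chunks(merged)

# Tolerant cleanup: discard() ignores IDs that are already gone,
# where remove() would raise KeyError on a redundant or late signal.
pending = {1, 2, 3}
for chunk_id in [1, 2, 2, 9]:  # 2 is repeated, 9 was never pending
    pending.discard(chunk_id)
```

In this toy input, `req-a` ends up with chunks 1, 2, and 3 even though both workers reported 1 and 2, and the repeated or unknown completion signals leave `pending` in a consistent state instead of raising.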

Tests

  • Verified with Qwen3-Coder 480B on a multi-node TPU cluster running prefix_repetition benchmarks under high concurrency.
  • Confirmed that the AssertionError state desync is resolved and the server runs to completion without JAX runtime hangs.

Qwen3-Coder 480B - Server command:

python3 -m vllm.entrypoints.openai.api_server \
--host=0.0.0.0 \
--port=8000 \
--tensor-parallel-size=16 \
--max-model-len=102400 \
--load-format=runai_streamer \
--kv-cache-dtype=fp8 \
--gpu-memory-utilization=0.8 \
--data-parallel-size=1 \
--max-num-batched-tokens=16384 \
--max-num-seqs=512 \
--model=Qwen/Qwen3-Coder-480B-A35B-Instruct \
--served-model-name=Qwen/Qwen3-Coder-480B-A35B-Instruct \
--enable-prefix-caching \
--async-scheduling \
--enable-expert-parallel \
--kv-transfer-config='{"kv_connector": "TPUOffloadConnector", "kv_connector_module_path": "tpu_inference.offload.tpu_offload_connector","kv_role": "kv_both", "kv_connector_extra_config": {"cpu_bytes_to_use": 107374182400, "lazy_offload": false}}'

Qwen3-Coder 480B - Client command:

vllm bench serve \
--backend=openai \
--model=Qwen/Qwen3-Coder-480B-A35B-Instruct \
--dataset-name=prefix_repetition \
--host=localhost \
--port=8000 \
--seed=123 \
--num-prompts=32 \
--max-concurrency=32 \
--prefix-repetition-prefix-len=19424 \
--prefix-repetition-suffix-len=32 \
--prefix-repetition-output-len=1024 \
--prefix-repetition-num-prefixes=4 \
--percentile-metrics='ttft,tpot,itl,e2el' \
--ignore-eos

Logs:

(APIServer pid=246100) INFO 06-24 08:44:22 [loggers.py:273] Engine 000: Avg prompt throughput: 3276.5 tokens/s, Avg generation throughput: 1095.4 tokens/s, Running: 32 reqs, Waiting: 0 reqs, GPU KV cache usage: 7.0%, Prefix cache hit rate: 85.2%, External prefix cache hit rate: 0.0%
(APIServer pid=246100) INFO 06-24 08:44:22 [metrics.py:103] KV Transfer metrics: Num finished save chunks =72, Num finished load chunks=0
(APIServer pid=246100) INFO 06-24 08:44:32 [loggers.py:273] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 1302.3 tokens/s, Running: 32 reqs, Waiting: 0 reqs, GPU KV cache usage: 8.1%, Prefix cache hit rate: 85.2%, External prefix cache hit rate: 0.0%
(APIServer pid=246100) INFO:     127.0.0.1:60682 - "GET /metrics HTTP/1.1" 200 OK
(APIServer pid=246100) INFO:     127.0.0.1:60688 - "GET /metrics HTTP/1.1" 200 OK

Checklist

  • I have performed a self-review of my code.
  • [] I have necessary comments in my code, particularly in hard-to-understand areas.
  • [] I have made or will make corresponding changes to any relevant documentation.

@amitkumar307d amitkumar307d force-pushed the multinode-kv-offloading branch 2 times, most recently from 129d7fb to 792884a Compare June 24, 2026 14:55
@amitkumar307d amitkumar307d force-pushed the multinode-kv-offloading branch 3 times, most recently from b5fc33c to 10e7c53 Compare June 24, 2026 16:40