Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 23 additions & 16 deletions tpu_inference/offload/offload_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,15 @@ def allocate_for_save(
) -> Tuple[list[CPUChunk], list[int]] | None:
# filter out chunks that are already stored
num_chunks = len(chunk_hashes)
new_chunk_idxs = [
i for i in range(num_chunks)
if chunk_hashes[i] not in self.cpu_cache
]

# Deduplicate chunk_hashes while keeping track of original indices
seen_hashes = {}
new_chunk_idxs = []
for i, chunk_hash in enumerate(chunk_hashes):
if chunk_hash not in self.cpu_cache:
if chunk_hash not in seen_hashes:
seen_hashes[chunk_hash] = i
new_chunk_idxs.append(i)

num_new_chunks = len(new_chunk_idxs)
if num_new_chunks == 0:
Expand Down Expand Up @@ -224,27 +229,27 @@ def complete_save(self, chunk_hashes: list[ChunkHash]) -> None:
""" After store completion, mark the chunk to be ready to load."""
for chunk_hash in chunk_hashes:
chunk = self.cpu_cache[chunk_hash]
assert not chunk.is_ready_to_load
if chunk.is_ready_to_load:
logger.warning(
f"Chunk {chunk_hash} is already ready to load. Ignoring duplicate confirmation in OffloadManager."
)
continue
# mark ready to load
chunk.touch()
assert chunk.is_ready_to_load

def complete_load(self, chunk_hashes: list[ChunkHash]) -> None:
    """Process load-completion confirmations for the given chunks.

    For every hash, the chunk is looked up in ``self.cpu_cache``. A chunk
    that is still in use gets ``untouch()`` called on it once. A chunk that
    is no longer in use is treated as a duplicate confirmation: it is
    logged and skipped instead of aborting the whole batch.

    Args:
        chunk_hashes: Hashes of the chunks whose load has finished. An
            empty list is a no-op.

    Raises:
        KeyError: If a hash is missing from ``self.cpu_cache`` (assuming
            the cache is a plain mapping -- TODO confirm).
    """
    for chunk_hash in chunk_hashes:
        chunk = self.cpu_cache[chunk_hash]
        # A hard `assert chunk.is_in_use` here would make the branch below
        # unreachable and crash on duplicate confirmations (and would be
        # stripped entirely under `python -O`), so tolerate it explicitly.
        if not chunk.is_in_use:
            logger.warning(
                f"Chunk {chunk_hash} is not in use (ref_cnt={chunk.ref_cnt}). Ignoring duplicate load confirmation in OffloadManager."
            )
            continue
        # presumably drops one reference taken when the load was scheduled;
        # verify against the chunk class.
        chunk.untouch()

def mark_completion(self, chunk_ids, operation: Literal['save',
'load']) -> None:
try:
chunk_hashes = [
self.chunk_pool.allocated_id_to_hash_map[chunk_id]
for chunk_id in chunk_ids
]
except Exception as e:
raise ValueError(f' failed to retrieve chunk hashes: {e}')

chunk_hashes = []
unknown_chunk_ids = []
for chunk_id in chunk_ids:
Expand Down Expand Up @@ -401,8 +406,9 @@ def free(self,
num_freed_blocks = num_finished_blocks
if self._blocks_for_load[req_id] < num_freed_blocks:
logger.warning(
f" Req({req_id}) has {num_finished_blocks} load staging buffer to free, but only has {self._blocks_for_load[req_id]} on record."
f" Req({req_id}) has {num_finished_blocks} load staging buffer to free, but only has {self._blocks_for_load[req_id]} on record. Capping to recorded value."
)
num_freed_blocks = max(0, self._blocks_for_load[req_id])

self._blocks_for_load[req_id] -= num_freed_blocks
if self._blocks_for_load[req_id] <= 0:
Expand All @@ -421,8 +427,9 @@ def free(self,
num_freed_blocks = num_finished_blocks
if self._blocks_for_save[req_id] < num_freed_blocks:
logger.warning(
f" Req({req_id}) has {num_finished_blocks} save staging buffer to free, but only has {self._blocks_for_save[req_id]} on record."
f" Req({req_id}) has {num_finished_blocks} save staging buffer to free, but only has {self._blocks_for_save[req_id]} on record. Capping to recorded value."
)
num_freed_blocks = max(0, self._blocks_for_save[req_id])

self._blocks_for_save[req_id] -= num_freed_blocks
if self._blocks_for_save[req_id] <= 0:
Expand Down
132 changes: 101 additions & 31 deletions tpu_inference/offload/tpu_offload_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
import copy
import random
import time
from collections import defaultdict
from collections import Counter, defaultdict
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, Optional
Expand Down Expand Up @@ -349,6 +349,37 @@ def is_empty(self) -> bool:
return self.num_finished_blocks == 0

def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
    """Merge the finished-chunk records of ``other`` into this object.

    Both ``finished_save_chunks`` and ``finished_load_chunks`` are merged
    per request. A request only present in ``other`` is copied over. For a
    request present on both sides, each chunk id is kept with the larger of
    its two occurrence counts (multiset union), so the same confirmation
    reported by several workers is not multiplied.

    Args:
        other: The stats object to merge in. Anything that is not a
            ``KVOffloadConnectorStats`` is logged and ignored.

    Returns:
        ``self``, updated in place.
    """
    # NOTE: This method is functionally critical for control flow in multi-node setups.
    # In addition to passive metrics, this stats object carries completion signals
    # (finished chunk IDs) from Workers to the Scheduler. Aggregation ensures
    # the Scheduler receives a unified view of all finished chunks across all nodes,
    # allowing it to correctly free staging buffers and update internal tracking.
    if not isinstance(other, KVOffloadConnectorStats):
        logger.warning(
            f"Cannot aggregate with non-KVOffloadConnectorStats: {type(other)}"
        )
        return self

    for key in ("finished_save_chunks", "finished_load_chunks"):
        merged = self.data[key]
        for req, chunks in other.data[key].items():
            if req not in merged:
                # Copy so later changes on either side cannot leak into
                # the other object through a shared list.
                merged[req] = list(chunks)
            else:
                # Keep max count for each chunk to avoid multiplication
                # across workers.
                c_max = Counter(merged[req]) | Counter(chunks)
                merged[req] = list(c_max.elements())

    return self

def reduce(self) -> dict[str, int | float]:
Expand All @@ -373,8 +404,13 @@ def reduce(self) -> dict[str, int | float]:

@property
def num_finished_blocks(self) -> int:
    """Total number of finished chunk ids recorded across all requests.

    ``self.data["finished_save_chunks"]`` and
    ``self.data["finished_load_chunks"]`` map a request id to a collection
    of chunk ids. The result is the sum of the sizes of all those
    collections. Taking ``len()`` of the dicts themselves would count
    requests, not chunks.
    """
    return sum(
        len(chunks)
        for key in ("finished_save_chunks", "finished_load_chunks")
        for chunks in self.data[key].values())


# The metadata used for communicating between scheduler and worker connectors.
Expand Down Expand Up @@ -1141,19 +1177,30 @@ def update_connector_output(self, connector_output: KVConnectorOutput):

for req_id, saved_chunk_ids in connector_output.kv_connector_stats.data[
"finished_save_chunks"].items():
num_saved_chunks = len(saved_chunk_ids)
# NOTE(jcgu): there might be in-flight savings even if the request has finished logically.
# This is handled by tracking in-flight operations independently in _reqs_being_saved
# and updating resource managers regardless of request lifecycle status.
valid_saved_chunks = []
for saved_chunk_id in saved_chunk_ids:
if saved_chunk_id in self._reqs_being_saved[req_id]:
self._reqs_being_saved[req_id].remove(saved_chunk_id)
valid_saved_chunks.append(saved_chunk_id)
else:
logger.debug(
f"Ignoring duplicate or unknown saved chunk confirmation: req={req_id}, chunk={saved_chunk_id}"
)

if not valid_saved_chunks:
continue

num_valid_saved_chunks = len(valid_saved_chunks)
logger.debug(
f" finished_save_chunks for {req_id}: {saved_chunk_ids}")
f" valid finished_save_chunks for {req_id}: {valid_saved_chunks}"
)
# free staging blocks
self.staging_buffer_manager.free(
req_id, usage="save", num_finished_blocks=num_saved_chunks)
req_id, usage="save", num_finished_blocks=num_valid_saved_chunks)

# update in-flight save
# NOTE(jcgu): there might be in-flight savings,
# even if the requests has been finished.
for saved_chunk_id in saved_chunk_ids:
assert saved_chunk_id in self._reqs_being_saved[req_id]
self._reqs_being_saved[req_id].remove(saved_chunk_id)
if len(self._reqs_being_saved[req_id]) == 0:
self._reqs_being_saved.pop(req_id, None)
else:
Expand All @@ -1162,30 +1209,40 @@ def update_connector_output(self, connector_output: KVConnectorOutput):
)

# update the status of occupied cpu chunks
self.offload_manager.mark_completion(saved_chunk_ids, "save")
self.offload_manager.mark_completion(valid_saved_chunks, "save")

for req_id, loaded_chunk_ids in connector_output.kv_connector_stats.data[
"finished_load_chunks"].items():
num_loaded_chunks = len(loaded_chunk_ids)
valid_loaded_chunks = []
for loaded_chunk_id in loaded_chunk_ids:
if loaded_chunk_id in self._reqs_being_loaded[req_id]:
self._reqs_being_loaded[req_id].remove(loaded_chunk_id)
valid_loaded_chunks.append(loaded_chunk_id)
else:
logger.debug(
f"Ignoring duplicate or unknown loaded chunk confirmation: req={req_id}, chunk={loaded_chunk_id}"
)

if not valid_loaded_chunks:
continue

num_valid_loaded_chunks = len(valid_loaded_chunks)
logger.debug(
f" finished_load_chunks for {req_id}: {num_loaded_chunks}"
f" valid finished_load_chunks for {req_id}: {num_valid_loaded_chunks}"
)
self.staging_buffer_manager.free(
req_id,
usage="load",
num_finished_blocks=num_loaded_chunks)
# update in-flight save
for loaded_chunk_id in loaded_chunk_ids:
assert loaded_chunk_id in self._reqs_being_loaded[req_id]
self._reqs_being_loaded[req_id].remove(loaded_chunk_id)
num_finished_blocks=num_valid_loaded_chunks)

if len(self._reqs_being_loaded[req_id]) == 0:
self._reqs_being_loaded.pop(req_id, None)
else:
logger.debug(
f" remaining_loading_blocks:{req_id}, {self._reqs_being_loaded[req_id]}."
)
# update the status of occupied cpu chunks
self.offload_manager.mark_completion(loaded_chunk_ids, "load")
self.offload_manager.mark_completion(valid_loaded_chunks, "load")

def request_finished(
self,
Expand Down Expand Up @@ -1714,7 +1771,7 @@ def _batched_gather_tpu_blocks(
return gathered_kv_caches_tpu, manifest, total_num_blocks_to_save

def _transfer_and_register_cpu_chunks(self,
flat_kv_caches_tpu: Any,
chunks_on_cpu: list[jax.Array],
total_num_blocks_to_save: int,
manifest: list[SaveReqInfo],
is_batched: bool = False):
Expand Down Expand Up @@ -1757,13 +1814,11 @@ def _transfer_and_register_cpu_chunks(self,
start_time = time.time()

# 1. Swap Out the buffer
chunks_on_cpu = None
# D2H
chunks_on_cpu = []
for i in range(total_num_blocks_to_save):
chunks_on_cpu.append(
jax.device_put(flat_kv_caches_tpu[i],
self.expanded_host_sharding))
# Note: The actual jax.device_put (D2H dispatch) has been moved to the
# main thread (start_save_kv / _start_batched_save_kv) to ensure
# deterministic, globally aligned dispatch order across nodes.
# Here in the background thread, we only need to wait for the transfer
# to complete.
jax.block_until_ready(chunks_on_cpu)
# no split

Expand Down Expand Up @@ -1841,8 +1896,16 @@ def _async_batch_transfer_task(*args, **kwargs):
# Note: We use manifest for the pending future tracking.
# record_save will be handled in the main thread by _process_completed_saves.

# Dispatch to CPU (Main Thread)
chunks_on_cpu = []
if flat_kv_caches_tpu is not None:
for i in range(total_num_blocks_to_save):
chunks_on_cpu.append(
jax.device_put(flat_kv_caches_tpu[i],
self.expanded_host_sharding))

future = self.save_executor.submit(_async_batch_transfer_task,
flat_kv_caches_tpu,
chunks_on_cpu,
total_num_blocks_to_save,
manifest,
is_batched=True)
Expand Down Expand Up @@ -1954,9 +2017,16 @@ def _async_transfer_task(req_id, *args):
# 2. ASYNC NON-BLOCKING: Transfer to CPU and Register
logger.debug(
f"Submitting transfer task for request {meta.req_id}")
# Dispatch to CPU (Main Thread)
chunks_on_cpu = []
for i in range(num_blocks_to_save):
chunks_on_cpu.append(
jax.device_put(flat_kv_caches_tpu[i],
self.expanded_host_sharding))

future = self.save_executor.submit(_async_transfer_task,
meta.req_id,
flat_kv_caches_tpu,
chunks_on_cpu,
num_blocks_to_save, [info],
False)

Expand Down