diff --git a/tests/offload/tpu_offload_connector_worker_test.py b/tests/offload/tpu_offload_connector_worker_test.py index f6714ef485..011de2670e 100644 --- a/tests/offload/tpu_offload_connector_worker_test.py +++ b/tests/offload/tpu_offload_connector_worker_test.py @@ -838,3 +838,162 @@ def test_host_memory_kind_default(self): connector = self._create_connector() worker = connector.connector_worker self.assertEqual(worker.host_sharding.memory_kind, "pinned_host") + + @parameterized.named_parameters( + dict(testcase_name="_delay_0.5s", delay=0.5), + dict(testcase_name="_delay_1.0s", delay=1.0), + ) + def test_tpu_connector_raw_race_protection(self, delay): + """ + Verifies that start_load_kv blocks if a load is requested for a chunk + that is still being saved locally in the background. + """ + # 1. Setup + connector = self._create_connector() + worker = connector.connector_worker + + block_to_save = 0 + dst_chunk = 42 + + # 2. Prepare Save Metadata + save_spec = SaveSpec( + num_skip_leading_tokens=0, + num_total_tokens=self.block_size, + is_final_save=False, + skip_save=False, + src_blocks=[block_to_save], + dst_chunks=[dst_chunk], + ) + req_meta_save = TPUReqMeta( + req_id="raw_save_req", + token_ids=list(range(self.block_size)), + local_block_ids=[block_to_save], + save_spec=save_spec, + ) + + # 3. Prepare Load Metadata (same chunk) + load_spec = LoadSpec( + num_matched_tokens=self.block_size, + dst_blocks=[block_to_save], + src_chunks=[dst_chunk], + can_load=True, + num_skip_leading_tokens=0, + ) + req_meta_load = TPUReqMeta( + req_id="raw_load_req", + token_ids=list(range(self.block_size)), + local_block_ids=[block_to_save], + load_spec=load_spec, + ) + + # 4. Patch _transfer_and_register_cpu_chunks to inject delay + original_transfer = worker._transfer_and_register_cpu_chunks + + def delayed_transfer(*args, **kwargs): + time.sleep(delay) # Inject delay in background save + return original_transfer(*args, **kwargs) + + # 5. Execute Save (Async) + connector.bind_connector_metadata( + TPUOffloadConnectorMetadata(requests_meta=[req_meta_save])) + + with mock.patch.object(worker, '_transfer_and_register_cpu_chunks', wraps=delayed_transfer): + worker.start_save_kv() + + # Verify it was submitted (it is in flight) + self.assertIn(dst_chunk, worker._local_in_flight_saves) + + # 6. Execute Load (Blocking) immediately + connector.bind_connector_metadata( + TPUOffloadConnectorMetadata(requests_meta=[req_meta_load])) + + start_time = time.time() + worker.start_load_kv(fwd_ctx=None) + end_time = time.time() + + # 7. Verification + # Load should have taken at least around the delay time because it was blocked by the delayed save + load_duration = end_time - start_time + logger.info(f"RAW Load duration: {load_duration:.4f}s (expected >= {delay}s)") + self.assertGreaterEqual(load_duration, delay * 0.8, "Load did not appear to block on in-flight save") + + # Verify chunk is removed from in-flight tracking + self.assertNotIn(dst_chunk, worker._local_in_flight_saves) + + # Verify data was loaded (basic check) + self.assertEqual(len(worker.runner.kv_caches), self.num_layers) + + @parameterized.named_parameters( + dict(testcase_name="_2_saves", num_saves=2), + dict(testcase_name="_5_saves", num_saves=5), + ) + def test_tpu_connector_waw_race_protection(self, num_saves): + """ + Verifies that background saves targeting the same CPU slot are chained + and execute sequentially (FIFO) to prevent WAW races. + """ + # 1. Setup + connector = self._create_connector() + worker = connector.connector_worker + + block_to_save = 0 + dst_chunk = 42 + + # 2. Prepare Save Metadata for N saves + save_reqs = [] + for i in range(num_saves): + save_spec = SaveSpec( + num_skip_leading_tokens=0, + num_total_tokens=self.block_size, + is_final_save=False, + skip_save=False, + src_blocks=[block_to_save], + dst_chunks=[dst_chunk], + ) + req_meta = TPUReqMeta( + req_id=f"waw_save_{i}", + token_ids=list(range(self.block_size)), + local_block_ids=[block_to_save], + save_spec=save_spec, + ) + save_reqs.append(req_meta) + + # 3. Patch _transfer_and_register_cpu_chunks to track execution and inject delay + execution_order = [] + original_transfer = worker._transfer_and_register_cpu_chunks + + def tracked_transfer(chunks_on_cpu, num_blocks, manifest, is_batched=False): + req_id = manifest[0].req_id + execution_order.append(f"start_{req_id}") + if req_id == "waw_save_0": + time.sleep(0.5) # Inject delay in first save to hold up the chain + execution_order.append(f"end_{req_id}") + return original_transfer(chunks_on_cpu, num_blocks, manifest, is_batched) + + # 4. Execute Saves sequentially in dispatch, but they run concurrently/chained in background + with mock.patch.object(worker, '_transfer_and_register_cpu_chunks', wraps=tracked_transfer): + for i in range(num_saves): + connector.bind_connector_metadata( + TPUOffloadConnectorMetadata(requests_meta=[save_reqs[i]])) + worker.start_save_kv() + + # Reset flag to allow next save in same step (simulating fast recycling) + worker._processed_save_for_step = False + + # Wait for all to finish + while worker._pending_save_futures: + worker._process_completed_saves() + time.sleep(0.01) + + # 5. Verification + logger.info(f"WAW Execution order for {num_saves} saves: {execution_order}") + + # We expect saves to execute in strict FIFO order: start_0, end_0, start_1, end_1... + expected_order = [] + for i in range(num_saves): + expected_order.append(f"start_waw_save_{i}") + expected_order.append(f"end_waw_save_{i}") + + self.assertEqual(execution_order, expected_order, + "WAW Saves did not execute in strict FIFO order") + diff --git a/tests/offload/tpu_offload_cpu_backend_test.py b/tests/offload/tpu_offload_cpu_backend_test.py index f961847a4a..5895f51490 100644 --- a/tests/offload/tpu_offload_cpu_backend_test.py +++ b/tests/offload/tpu_offload_cpu_backend_test.py @@ -93,3 +93,40 @@ def test_reclaim_unoccupied_chunks(self): assert backend.current_size_bytes == 0 assert len(backend.cache) == 0 + + def test_concurrent_access(self): + """Simulates concurrent access to verify thread safety.""" + import threading + import concurrent.futures + + backend = LocalCPUBackend(num_cpu_chunks=100) + num_threads = 10 + ops_per_thread = 50 + + def worker_task(thread_idx): + for i in range(ops_per_thread): + chunk_id = CpuChunkId((thread_idx * ops_per_thread + i) % 100) + value = create_mock_jax_array(10) + + # Alternate between add and get + if i % 2 == 0: + backend.add(chunk_id, value) + else: + backend.get(chunk_id) + + # Periodically reclaim + if i % 10 == 0: + # Keep some random subset of chunks + backend.reclaim_unoccupied_chunks(occupied_chunk_ids=[ + CpuChunkId(0), CpuChunkId(50) + ]) + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker_task, i) for i in range(num_threads)] + for future in concurrent.futures.as_completed(futures): + future.result() # Raise exception if any occurred + + # If we reach here without deadlock or exception, basic thread safety is confirmed. + # We can verify internal state consistency if needed. + assert backend.current_size_bytes >= 0 + diff --git a/tpu_inference/offload/cpu_backend.py b/tpu_inference/offload/cpu_backend.py index d6fb277331..cdb6e32258 100644 --- a/tpu_inference/offload/cpu_backend.py +++ b/tpu_inference/offload/cpu_backend.py @@ -13,6 +13,7 @@ # limitations under the License. import sys +import threading from collections import OrderedDict from typing import Any, Optional @@ -42,6 +43,7 @@ def __init__(self, num_cpu_chunks: int): self.current_size_bytes = 0 self._num_saved_cpu_chunks = 0 self.metrics_collector = TPUKVCacheMetrics.get_or_create() + self.lock = threading.Lock() logger.info( "LocalCPUBackend initialized." f"CPU cache capacity: {self.max_num_cpu_chunks} chunks / pages.") @@ -70,54 +72,57 @@ def add(self, chunk_id: CpuChunkId, value: Any) -> bool: If the cache is full, it evicts the least recently used, unpinned entries until there is enough space. """ - if chunk_id < 0 or chunk_id >= self.max_num_cpu_chunks: - # TODO(jcgu): report failure when offload scheduler / worker - # can handle failed operations. - raise ValueError(f" get invalid chunk_id: {chunk_id}") - - # Add the new item. - if chunk_id in self.cache: - old_value = self.cache.pop(chunk_id) - self.current_size_bytes -= self._get_value_size(old_value) - del old_value - self._num_saved_cpu_chunks -= 1 - - self.cache[chunk_id] = value - self._num_saved_cpu_chunks += 1 - value_size = self._get_value_size(value) - self.current_size_bytes += value_size - logger.debug( - f"Added chunk_id: {chunk_id} (size:{value_size}) to CPU backend.") - logger.debug( - f"Cache: {self.current_size_bytes} bytes, {self._num_saved_cpu_chunks} occupied chunks." - ) - self.metrics_collector.record_host_memory_usage( - self.current_size_bytes) - return True + with self.lock: + if chunk_id < 0 or chunk_id >= self.max_num_cpu_chunks: + # TODO(jcgu): report failure when offload scheduler / worker + # can handle failed operations. + raise ValueError(f" get invalid chunk_id: {chunk_id}") + + # Add the new item. + if chunk_id in self.cache: + old_value = self.cache.pop(chunk_id) + self.current_size_bytes -= self._get_value_size(old_value) + del old_value + self._num_saved_cpu_chunks -= 1 + + self.cache[chunk_id] = value + self._num_saved_cpu_chunks += 1 + value_size = self._get_value_size(value) + self.current_size_bytes += value_size + logger.debug( + f"Added chunk_id: {chunk_id} (size:{value_size}) to CPU backend.") + logger.debug( + f"Cache: {self.current_size_bytes} bytes, {self._num_saved_cpu_chunks} occupied chunks." + ) + self.metrics_collector.record_host_memory_usage( + self.current_size_bytes) + return True def get(self, chunk_id: CpuChunkId) -> Optional[Any]: """ Gets the value for a given chunk_id and marks it as recently used. """ - if chunk_id in self.cache: - return self.cache[chunk_id] - return None + with self.lock: + if chunk_id in self.cache: + return self.cache[chunk_id] + return None def reclaim_unoccupied_chunks(self, occupied_chunk_ids: list[CpuChunkId]): - chunk_ids = list(self.cache.keys()) - unoccupied_chunk_ids = [ - chunk_id for chunk_id in chunk_ids - if chunk_id not in occupied_chunk_ids - ] - reclaimed_size_bytes = 0 - for chunk_id in unoccupied_chunk_ids: - dummy_value = self.cache.pop(chunk_id) - reclaimed_size_bytes += self._get_value_size(dummy_value) - del dummy_value - self.current_size_bytes -= reclaimed_size_bytes - - logger.debug( - f" Reclaimed {len(unoccupied_chunk_ids)} unoccupied chunks, " - f"with {reclaimed_size_bytes} bytes.") - self.metrics_collector.record_host_memory_usage( - self.current_size_bytes) + with self.lock: + chunk_ids = list(self.cache.keys()) + unoccupied_chunk_ids = [ + chunk_id for chunk_id in chunk_ids + if chunk_id not in occupied_chunk_ids + ] + reclaimed_size_bytes = 0 + for chunk_id in unoccupied_chunk_ids: + dummy_value = self.cache.pop(chunk_id) + reclaimed_size_bytes += self._get_value_size(dummy_value) + del dummy_value + self.current_size_bytes -= reclaimed_size_bytes + + logger.debug( + f" Reclaimed {len(unoccupied_chunk_ids)} unoccupied chunks, " + f"with {reclaimed_size_bytes} bytes.") + self.metrics_collector.record_host_memory_usage( + self.current_size_bytes) diff --git a/tpu_inference/offload/offload_manager.py b/tpu_inference/offload/offload_manager.py index 4fe1fd2994..09ea24dae7 100644 --- a/tpu_inference/offload/offload_manager.py +++ b/tpu_inference/offload/offload_manager.py @@ -164,12 +164,14 @@ def touch(self, chunk_hashes: list[ChunkHash]) -> int: def allocate_for_save( self, chunk_hashes: list[ChunkHash] ) -> Tuple[list[CPUChunk], list[int]] | None: - # filter out chunks that are already stored - num_chunks = len(chunk_hashes) - new_chunk_idxs = [ - i for i in range(num_chunks) - if chunk_hashes[i] not in self.cpu_cache - ] + # Deduplicate chunk_hashes while keeping track of original indices + seen_hashes = {} + new_chunk_idxs = [] + for i, chunk_hash in enumerate(chunk_hashes): + if chunk_hash not in self.cpu_cache: + if chunk_hash not in seen_hashes: + seen_hashes[chunk_hash] = i + new_chunk_idxs.append(i) num_new_chunks = len(new_chunk_idxs) if num_new_chunks == 0: @@ -224,7 +226,11 @@ def complete_save(self, chunk_hashes: list[ChunkHash]) -> None: """ After store completion, mark the chunk to be ready to load.""" for chunk_hash in chunk_hashes: chunk = self.cpu_cache[chunk_hash] - assert not chunk.is_ready_to_load + if chunk.is_ready_to_load: + logger.warning( + f"Chunk {chunk_hash} is already ready to load. Ignoring duplicate confirmation in OffloadManager." + ) + continue # mark ready to load chunk.touch() assert chunk.is_ready_to_load @@ -232,19 +238,15 @@ def complete_save(self, chunk_hashes: list[ChunkHash]) -> None: def complete_load(self, chunk_hashes: list[ChunkHash]) -> None: for chunk_hash in chunk_hashes: chunk = self.cpu_cache[chunk_hash] - assert chunk.is_in_use + if not chunk.is_in_use: + logger.warning( + f"Chunk {chunk_hash} is not in use (ref_cnt={chunk.ref_cnt}). Ignoring duplicate load confirmation in OffloadManager." + ) + continue chunk.untouch() def mark_completion(self, chunk_ids, operation: Literal['save', 'load']) -> None: - try: - chunk_hashes = [ - self.chunk_pool.allocated_id_to_hash_map[chunk_id] - for chunk_id in chunk_ids - ] - except Exception as e: - raise ValueError(f' failed to retrieve chunk hashes: {e}') - chunk_hashes = [] unknown_chunk_ids = [] for chunk_id in chunk_ids: @@ -401,8 +403,9 @@ def free(self, num_freed_blocks = num_finished_blocks if self._blocks_for_load[req_id] < num_freed_blocks: logger.warning( - f" Req({req_id}) has {num_finished_blocks} load staging buffer to free, but only has {self._blocks_for_load[req_id]} on record." + f" Req({req_id}) has {num_finished_blocks} load staging buffer to free, but only has {self._blocks_for_load[req_id]} on record. Capping to recorded value." ) + num_freed_blocks = max(0, self._blocks_for_load[req_id]) self._blocks_for_load[req_id] -= num_freed_blocks if self._blocks_for_load[req_id] <= 0: @@ -421,8 +424,9 @@ def free(self, num_freed_blocks = num_finished_blocks if self._blocks_for_save[req_id] < num_freed_blocks: logger.warning( - f" Req({req_id}) has {num_finished_blocks} save staging buffer to free, but only has {self._blocks_for_save[req_id]} on record." + f" Req({req_id}) has {num_finished_blocks} save staging buffer to free, but only has {self._blocks_for_save[req_id]} on record. Capping to recorded value." ) + num_freed_blocks = max(0, self._blocks_for_save[req_id]) self._blocks_for_save[req_id] -= num_freed_blocks if self._blocks_for_save[req_id] <= 0: diff --git a/tpu_inference/offload/tpu_offload_connector.py b/tpu_inference/offload/tpu_offload_connector.py index 7908c9017b..247f8d301e 100644 --- a/tpu_inference/offload/tpu_offload_connector.py +++ b/tpu_inference/offload/tpu_offload_connector.py @@ -102,7 +102,7 @@ import copy import random import time -from collections import defaultdict +from collections import Counter, defaultdict from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal, Optional @@ -364,18 +364,36 @@ def aggregate(self, other: KVConnectorStats) -> KVConnectorStats: Returns: KVConnectorStats: The updated instance (self) containing the aggregated stats. """ - if isinstance(other, KVOffloadConnectorStats): - other_saves = other.data.get("finished_save_chunks", {}) - for k, v in other_saves.items(): - if k not in self.data["finished_save_chunks"]: - self.data["finished_save_chunks"][k] = [] - self.data["finished_save_chunks"][k].extend(v) - - other_loads = other.data.get("finished_load_chunks", {}) - for k, v in other_loads.items(): - if k not in self.data["finished_load_chunks"]: - self.data["finished_load_chunks"][k] = [] - self.data["finished_load_chunks"][k].extend(v) + # NOTE: This method is functionally critical for control flow in multi-node setups. + # In addition to passive metrics, this stats object carries completion signals + # (finished chunk IDs) from Workers to the Scheduler. Aggregation ensures + # the Scheduler receives a unified view of all finished chunks across all nodes, + # allowing it to correctly free staging buffers and update internal tracking. + if not isinstance(other, KVOffloadConnectorStats): + logger.warning( + f"Cannot aggregate with non-KVOffloadConnectorStats: {type(other)}" + ) + return self + + for req, chunks in other.data["finished_save_chunks"].items(): + if req not in self.data["finished_save_chunks"]: + self.data["finished_save_chunks"][req] = chunks + else: + c1 = Counter(self.data["finished_save_chunks"][req]) + c2 = Counter(chunks) + # Keep max count for each chunk to avoid multiplication across workers + c_max = c1 | c2 + self.data["finished_save_chunks"][req] = list(c_max.elements()) + + for req, chunks in other.data["finished_load_chunks"].items(): + if req not in self.data["finished_load_chunks"]: + self.data["finished_load_chunks"][req] = chunks + else: + c1 = Counter(self.data["finished_load_chunks"][req]) + c2 = Counter(chunks) + # Keep max count for each chunk to avoid multiplication across workers + c_max = c1 | c2 + self.data["finished_load_chunks"][req] = list(c_max.elements()) return self @@ -1198,21 +1216,30 @@ def update_connector_output(self, connector_output: KVConnectorOutput): for req_id, saved_chunk_ids in connector_output.kv_connector_stats.data[ "finished_save_chunks"].items(): - if req_id not in self._reqs_being_saved: + # NOTE(jcgu): there might be in-flight savings even if the request has finished logically. + # This is handled by tracking in-flight operations independently in _reqs_being_saved + # and updating resource managers regardless of request lifecycle status. + valid_saved_chunks = [] + for saved_chunk_id in saved_chunk_ids: + if saved_chunk_id in self._reqs_being_saved[req_id]: + self._reqs_being_saved[req_id].remove(saved_chunk_id) + valid_saved_chunks.append(saved_chunk_id) + else: + logger.debug( + f"Ignoring duplicate or unknown saved chunk confirmation: req={req_id}, chunk={saved_chunk_id}" + ) + + if not valid_saved_chunks: continue - num_saved_chunks = len(saved_chunk_ids) + + num_valid_saved_chunks = len(valid_saved_chunks) logger.debug( - f" finished_save_chunks for {req_id}: {saved_chunk_ids}") + f" valid finished_save_chunks for {req_id}: {valid_saved_chunks}" + ) # free staging blocks self.staging_buffer_manager.free( - req_id, usage="save", num_finished_blocks=num_saved_chunks) + req_id, usage="save", num_finished_blocks=num_valid_saved_chunks) - # update in-flight save - # NOTE(jcgu): there might be in-flight savings, - # even if the requests has been finished. - for saved_chunk_id in saved_chunk_ids: - assert saved_chunk_id in self._reqs_being_saved[req_id] - self._reqs_being_saved[req_id].remove(saved_chunk_id) if len(self._reqs_being_saved[req_id]) == 0: self._reqs_being_saved.pop(req_id, None) else: @@ -1221,24 +1248,32 @@ def update_connector_output(self, connector_output: KVConnectorOutput): ) # update the status of occupied cpu chunks - self.offload_manager.mark_completion(saved_chunk_ids, "save") + self.offload_manager.mark_completion(valid_saved_chunks, "save") for req_id, loaded_chunk_ids in connector_output.kv_connector_stats.data[ "finished_load_chunks"].items(): - if req_id not in self._reqs_being_loaded: + valid_loaded_chunks = [] + for loaded_chunk_id in loaded_chunk_ids: + if loaded_chunk_id in self._reqs_being_loaded[req_id]: + self._reqs_being_loaded[req_id].remove(loaded_chunk_id) + valid_loaded_chunks.append(loaded_chunk_id) + else: + logger.debug( + f"Ignoring duplicate or unknown loaded chunk confirmation: req={req_id}, chunk={loaded_chunk_id}" + ) + + if not valid_loaded_chunks: continue - num_loaded_chunks = len(loaded_chunk_ids) + + num_valid_loaded_chunks = len(valid_loaded_chunks) logger.debug( - f" finished_load_chunks for {req_id}: {num_loaded_chunks}" + f" valid finished_load_chunks for {req_id}: {num_valid_loaded_chunks}" ) self.staging_buffer_manager.free( req_id, usage="load", - num_finished_blocks=num_loaded_chunks) - # update in-flight save - for loaded_chunk_id in loaded_chunk_ids: - assert loaded_chunk_id in self._reqs_being_loaded[req_id] - self._reqs_being_loaded[req_id].remove(loaded_chunk_id) + num_finished_blocks=num_valid_loaded_chunks) + if len(self._reqs_being_loaded[req_id]) == 0: self._reqs_being_loaded.pop(req_id, None) else: @@ -1246,7 +1281,7 @@ def update_connector_output(self, connector_output: KVConnectorOutput): f" remaining_loading_blocks:{req_id}, {self._reqs_being_loaded[req_id]}." ) # update the status of occupied cpu chunks - self.offload_manager.mark_completion(loaded_chunk_ids, "load") + self.offload_manager.mark_completion(valid_loaded_chunks, "load") def request_finished( self, @@ -1327,6 +1362,8 @@ def __init__(self, vllm_config: VllmConfig, self._processed_save_for_step = False # On-going asynchronous save operations tracking futures and their associated manifest. self._pending_save_futures: list[tuple[Future, list[SaveReqInfo]]] = [] + # Tracks local in-flight saves to prevent read-before-write race conditions during loads. + self._local_in_flight_saves: dict[CpuChunkId, Future] = {} # record finished save / load blocks (with req_ids) for each iteration self.offload_stats = KVOffloadConnectorStats() @@ -1793,7 +1830,7 @@ def _batched_gather_tpu_blocks( return gathered_kv_caches_tpu, manifest, total_num_blocks_to_save def _transfer_and_register_cpu_chunks(self, - flat_kv_caches_tpu: Any, + chunks_on_cpu: list[jax.Array], total_num_blocks_to_save: int, manifest: list[SaveReqInfo], is_batched: bool = False): @@ -1836,13 +1873,11 @@ def _transfer_and_register_cpu_chunks(self, start_time = time.time() # 1. Swap Out the buffer - chunks_on_cpu = None - # D2H - chunks_on_cpu = [] - for i in range(total_num_blocks_to_save): - chunks_on_cpu.append( - jax.device_put(flat_kv_caches_tpu[i], - self.expanded_host_sharding)) + # Note: The actual jax.device_put (D2H dispatch) has been moved to the + # main thread (start_save_kv / _start_batched_save_kv) to ensure + # deterministic, globally aligned dispatch order across nodes. + # Here in the background thread, we only need to wait for the transfer + # to complete. jax.block_until_ready(chunks_on_cpu) # no split @@ -1908,9 +1943,11 @@ def _start_batched_save_kv(self, metadata: TPUOffloadConnectorMetadata): flat_kv_caches_tpu, manifest, total_num_blocks_to_save = gather_result # 2. ASYNC NON-BLOCKING: Single Batch Transfer - def _async_batch_transfer_task(*args, **kwargs): + def _async_batch_transfer_task(dependencies, chunks_on_cpu, total_num_blocks, manifest, is_batched=True): try: - self._transfer_and_register_cpu_chunks(*args, **kwargs) + for dep in dependencies: + dep.result() + self._transfer_and_register_cpu_chunks(chunks_on_cpu, total_num_blocks, manifest, is_batched) except Exception as e: logger.error(f"Error in batched transfer: {e}", exc_info=True) @@ -1920,11 +1957,32 @@ def _async_batch_transfer_task(*args, **kwargs): # Note: We use manifest for the pending future tracking. # record_save will be handled in the main thread by _process_completed_saves. + # Dispatch to CPU (Main Thread) + chunks_on_cpu = [] + if flat_kv_caches_tpu is not None: + for i in range(total_num_blocks_to_save): + chunks_on_cpu.append( + jax.device_put(flat_kv_caches_tpu[i], + self.expanded_host_sharding)) + + # Collect dependencies to prevent WAW race + dependencies = set() + for info in manifest: + for chunk_id in info.dst_chunks: + if chunk_id in self._local_in_flight_saves: + dependencies.add(self._local_in_flight_saves[chunk_id]) + future = self.save_executor.submit(_async_batch_transfer_task, - flat_kv_caches_tpu, + dependencies, + chunks_on_cpu, total_num_blocks_to_save, manifest, is_batched=True) + + for info in manifest: + for chunk_id in info.dst_chunks: + self._local_in_flight_saves[chunk_id] = future + self._pending_save_futures.append((future, manifest)) def _get_blocks_for_req_from_metadata( @@ -2021,9 +2079,11 @@ def start_save_kv(self): is_final_save=meta.save_spec.is_final_save) # Define a safe wrapper for the async part to ensure logging - def _async_transfer_task(req_id, *args): + def _async_transfer_task(req_id, dependencies, chunks_on_cpu, num_blocks, manifest, is_batched): try: - self._transfer_and_register_cpu_chunks(*args) + for dep in dependencies: + dep.result() + self._transfer_and_register_cpu_chunks(chunks_on_cpu, num_blocks, manifest, is_batched) except Exception as e: raise ValueError( f"Error transferring blocks for request {req_id}: {e}" @@ -2033,12 +2093,29 @@ def _async_transfer_task(req_id, *args): # 2. ASYNC NON-BLOCKING: Transfer to CPU and Register logger.debug( f"Submitting transfer task for request {meta.req_id}") + # Dispatch to CPU (Main Thread) + chunks_on_cpu = [] + for i in range(num_blocks_to_save): + chunks_on_cpu.append( + jax.device_put(flat_kv_caches_tpu[i], + self.expanded_host_sharding)) + + # Collect dependencies to prevent WAW race + dependencies = set() + for chunk_id in dst_chunks: + if chunk_id in self._local_in_flight_saves: + dependencies.add(self._local_in_flight_saves[chunk_id]) + future = self.save_executor.submit(_async_transfer_task, meta.req_id, - flat_kv_caches_tpu, + dependencies, + chunks_on_cpu, num_blocks_to_save, [info], False) + for chunk_id in dst_chunks: + self._local_in_flight_saves[chunk_id] = future + self._pending_save_futures.append((future, [info])) self.metrics_collector.record_d2h_operation() @@ -2068,6 +2145,12 @@ def _process_completed_saves(self): self.offload_stats.record_save( req=info.req_id, saved_chunk_ids=info.dst_chunks) + + # Clean up local in-flight tracking safely + for chunk_id in info.dst_chunks: + if self._local_in_flight_saves.get(chunk_id) == future: + self._local_in_flight_saves.pop(chunk_id, None) + # TODO: Metrics data transfer complete completed_count += 1 @@ -2119,6 +2202,7 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: # Process each request that needs its KV cache loaded load_times = [] + waited_futures = set() for meta in metadata.requests_meta: if not (meta.load_spec and meta.load_spec.can_load): continue @@ -2169,6 +2253,24 @@ def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: assembled_kv_on_cpu = [] for i in range(num_blocks_to_load): src_chunk_id = src_chunks[i] + + # Check if this chunk is currently being saved locally + if src_chunk_id in self._local_in_flight_saves: + future = self._local_in_flight_saves[src_chunk_id] + if future not in waited_futures: + logger.warning( + f"Load of chunk {src_chunk_id} requested before save completed locally. " + "Blocking worker main thread until local save finishes..." + ) + # NOTE(amitmkumar): Condition C - Add a configurable timeout (e.g., 120s) + # to future.result() to handle hanging futures gracefully. + future.result() + waited_futures.add(future) + + # Clean up the completed future tracking safely + if self._local_in_flight_saves.get(src_chunk_id) == future: + del self._local_in_flight_saves[src_chunk_id] + cached_value = self.cpu_backend.get(src_chunk_id) if cached_value is not None: assembled_kv_on_cpu.append(cached_value)