diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
index 2e4bf1f0666..47b84326257 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
+++ b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
@@ -1,5 +1,5 @@
 /*
- * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: Apache-2.0
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
@@ -599,6 +599,13 @@ RequestStatuses CacheTransceiver::checkContextTransferStatus(
 void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)
 {
     bool blockAll = !atLeastRequestNum.has_value();
+    std::optional<int> receiverFutureTimeoutMs = std::nullopt;
+    // If blockAll is true, we want to block and not use a timeout
+    if (!blockAll && mCacheTransceiverConfig.has_value())
+    {
+        receiverFutureTimeoutMs = mCacheTransceiverConfig->getKvTransferSenderFutureTimeoutMs();
+    }
+
     std::vector<LlmRequest::RequestIdType> genTransferReadyRequestIds;
     for (auto&& [request, future] : mRequesterFutures)
     {
@@ -709,20 +716,59 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)
             " checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ",
             toCompleteIdSet.size(), atLeastRequestNum.value_or(0));
     }
+    auto const syncSize = (syncComm != nullptr) ? syncComm->getSize() : 1;
     for (auto it = mRequesterFutures.begin(); it != mRequesterFutures.end();)
     {
         if (blockAll || toCompleteIdSet.find(it->first->mRequestId) != toCompleteIdSet.end())
         {
             try
             {
-                it->second.get();
-                it->first->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
-
-                // Gather the kv cache transfer time from all workers and update to leader rank
-                if (!common::getEnvKVCacheTimeOutputPath().empty())
+                // Wait for up to a specified timeout
+                auto status = it->second.wait_for(
+                    std::chrono::milliseconds(receiverFutureTimeoutMs.value_or(0)));
+                if (status == std::future_status::ready || blockAll)
                 {
-                    auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
-                    updateKVCacheTransferBW(syncComm, it->first);
+                    it->second.get();
+                    it->first->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
+
+                    // Gather the kv cache transfer time from all workers and update to leader rank.
+                    // Only call the timing collective when either all ranks block together (blockAll)
+                    // or the request was confirmed ready on every rank in the initial poll, to avoid
+                    // hanging in allgather when a peer timed out and skipped this request.
+                    if (!common::getEnvKVCacheTimeOutputPath().empty())
+                    {
+                        auto const freqIt = frequencyMap.find(it->first->mRequestId);
+                        if (blockAll || (freqIt != frequencyMap.end() && freqIt->second == syncSize))
+                        {
+                            updateKVCacheTransferBW(syncComm, it->first);
+                        }
+                    }
+                    if (useMPI())
+                    {
+                        TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
+                            "**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
+                            it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
+                    }
+                    else
+                    {
+                        TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
+                            "**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
+                            it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
+                    }
+                    it = mRequesterFutures.erase(it);
+                }
+                else if (status == std::future_status::timeout)
+                {
+                    TLLM_LOG_WARNING("Timed out waiting for generation KV cache transfer after %d milliseconds.",
+                        receiverFutureTimeoutMs.value());
+                    ++it;
+                }
+                else
+                {
+                    TLLM_LOG_ERROR(
+                        "Future returned unexpected status for request %ld. Marking as error", it->first->mRequestId);
+                    it->first->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
+                    it = mRequesterFutures.erase(it);
                 }
             }
             catch (std::exception const& e)
@@ -730,20 +776,8 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)
                 TLLM_LOG_ERROR(
                     "Error occurred during generation transfer for request %ld: %s", it->first->mRequestId, e.what());
                 it->first->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
+                it = mRequesterFutures.erase(it);
             }
-            if (useMPI())
-            {
-                TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
-                    "**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
-                    it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
-            }
-            else
-            {
-                TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
-                    "**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
-                    it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
-            }
-            it = mRequesterFutures.erase(it);
         }
         else
        {