Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 55 additions & 21 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -599,6 +599,13 @@ RequestStatuses CacheTransceiver::checkContextTransferStatus(
void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastRequestNum)
{
bool blockAll = !atLeastRequestNum.has_value();
std::optional<int> receiverFutureTimeoutMs = std::nullopt;
// If blockAll is true, we want to block and not use a timeout
if (!blockAll && mCacheTransceiverConfig.has_value())
{
receiverFutureTimeoutMs = mCacheTransceiverConfig->getKvTransferSenderFutureTimeoutMs();
}

std::vector<LlmRequest::RequestIdType> genTransferReadyRequestIds;
for (auto&& [request, future] : mRequesterFutures)
{
Expand Down Expand Up @@ -709,41 +716,68 @@ void CacheTransceiver::checkGenTransferStatus(std::optional<int> const& atLeastR
" checkGenTransferStatus toCompleteIdSet size: %zu, atLeastRequestNum: %d ", toCompleteIdSet.size(),
atLeastRequestNum.value_or(0));
}
auto const syncSize = (syncComm != nullptr) ? syncComm->getSize() : 1;
for (auto it = mRequesterFutures.begin(); it != mRequesterFutures.end();)
{
if (blockAll || toCompleteIdSet.find(it->first->mRequestId) != toCompleteIdSet.end())
{
try
{
it->second.get();
it->first->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);

// Gather the kv cache transfer time from all workers and update to leader rank
if (!common::getEnvKVCacheTimeOutputPath().empty())
// Wait for up to a specified timeout
auto status = it->second.wait_for(
std::chrono::milliseconds(receiverFutureTimeoutMs.value_or(0)));
if (status == std::future_status::ready || blockAll)
{
auto syncComm = mCacheState->getParallelConfig().mEnableAttentionDP ? mGroupDataComm : mGroupComm;
updateKVCacheTransferBW(syncComm, it->first);
it->second.get();
it->first->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);

// Gather the kv cache transfer time from all workers and update to leader rank.
// Only call the timing collective when either all ranks block together (blockAll)
// or the request was confirmed ready on every rank in the initial poll, to avoid
// hanging in allgather when a peer timed out and skipped this request.
if (!common::getEnvKVCacheTimeOutputPath().empty())
{
auto const freqIt = frequencyMap.find(it->first->mRequestId);
if (blockAll || (freqIt != frequencyMap.end() && freqIt->second == syncSize))
{
updateKVCacheTransferBW(syncComm, it->first);
}
}
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
}
else
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
}
it = mRequesterFutures.erase(it);
}
else if (status == std::future_status::timeout)
{
TLLM_LOG_WARNING("Timed out waiting for generation KV cache transfer after %d milliseconds.",
receiverFutureTimeoutMs.value());
++it;
}
else
{
TLLM_LOG_ERROR(
"Future returned unexpected status for request %ld. Marking as error", it->first->mRequestId);
it->first->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
it = mRequesterFutures.erase(it);
}
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR(
"Error occurred during generation transfer for request %ld: %s", it->first->mRequestId, e.what());
it->first->setState(LlmRequestState::kDISAGG_TRANS_ERROR);
it = mRequesterFutures.erase(it);
}
if (useMPI())
{
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
}
else
{
TLLM_LOG_DEBUG(tensorrt_llm::pg_utils::get_world_pg()->getRank(),
"**** it->first->mRequestId: %ld, context request ID: %ld ******** get feature ***",
it->first->mRequestId, it->first->getContextPhaseParams().value().getReqId());
}
it = mRequesterFutures.erase(it);
}
else
{
Expand Down
Loading