Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,33 @@ std::vector<BlockPtr> getAllSequenceBlocks(BlockPtr lastBlock)
return sequenceBlocks;
}

// Determine how many unique tokens already have their state materialized in the KV cache by
// prefill and generation. With chunked prefill, only tokens up to the current context position
// have been written, so the count is clamped there while
// LlmRequest::getContextRemainingLength() reports outstanding prefill work (> 0). Once prefill
// is finished (remaining length == 0), every unique token stored in the request has been
// computed and the full unique-token count is returned.
SizeType32 getMaterializedUniqueTokenCountForReuse(
    VecUniqueTokens const& uniqueTokens, tensorrt_llm::batch_manager::LlmRequest const& llmRequest)
{
    auto const tokenCount = static_cast<SizeType32>(uniqueTokens.size());
    auto const prefillInProgress = llmRequest.getContextRemainingLength() > 0;
    return prefillInProgress ? std::min(tokenCount, llmRequest.getContextCurrentPosition()) : tokenCount;
}

// Number of tokens whose KV state may be stored for reuse. The state of the most recently
// computed token is never written to the cache, so the usable count is
// getMaterializedUniqueTokenCountForReuse() minus one, clamped at zero.
SizeType32 getUsableUniqueTokenCountForReuse(
    VecUniqueTokens const& uniqueTokens, tensorrt_llm::batch_manager::LlmRequest const& llmRequest)
{
    auto const materializedCount = getMaterializedUniqueTokenCountForReuse(uniqueTokens, llmRequest);
    if (materializedCount <= 0)
    {
        return 0;
    }
    return materializedCount - 1;
}

} // namespace

namespace tensorrt_llm::batch_manager::kv_cache_manager
Expand Down Expand Up @@ -838,8 +865,9 @@ void BlockManager::storeContextBlocks(GenerationRequest& sequence, LlmRequest co
auto cacheBlockIds = sequence.getCacheBlockIds(windowSize);
auto const& uniqueTokens = llmRequest.getUniqueTokens(beamIdx);

auto const usableUniqueTokenCount = getUsableUniqueTokenCountForReuse(uniqueTokens, llmRequest);
auto blockedUniqueTokens
= chopVectorIntoBlocks<UniqueToken>(uniqueTokens, uniqueTokens.size() - 1, getTokensPerBlock(), false);
= chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableUniqueTokenCount, getTokensPerBlock(), false);
auto blockKeys = buildBlockKeys(blockedUniqueTokens, llmRequest);
(void) mWindowBlockManagers.at(windowSize).storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx]);
}
Expand Down Expand Up @@ -1998,11 +2026,9 @@ std::vector<KVCacheBlock::IdType> WindowBlockManager::storeBlocksForReuse(
auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx);
auto const& cacheBlockIds = sequence.getCacheBlockIds(mWindowSize);

// TODO: get the caller to mark tokens as filled / not filled, so that the kv-cache manager doesn't
// have to guess. Only (length - 1) tokens of the sequence have their kv-state recorded in kv-cache. We assume
// the last token's state is not filled yet.
auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, true);
auto const usableUniqueTokenCount = getUsableUniqueTokenCountForReuse(uniqueTokens, *llmRequest);
auto blockedUniqueTokens
= chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableUniqueTokenCount, mTokensPerBlock, true);
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);

auto [numStored, pinnedBlockIds] = storeBlocks(std::move(blockKeys), cacheBlockIds[beamIdx], pinBlocks);
Expand Down Expand Up @@ -2035,11 +2061,9 @@ std::optional<KVCacheBlock::IdType> WindowBlockManager::releaseBlocks(
sequence.getRequestId());
}
auto const& uniqueTokens = llmRequest->getUniqueTokens(/*beamIdx=*/0);
// Only (length - 1) tokens of the sequence have their kv-state
// recorded in kv-cache. We assume the last token's state is not filled yet.
auto const usableSize = static_cast<runtime::SizeType32>(uniqueTokens.size()) - 1;
auto blockedUniqueTokens
= chopVectorIntoBlocks<UniqueToken>(uniqueTokens, usableSize, mTokensPerBlock, /*allowPartial=*/true);
auto const usableUniqueTokenCount = getUsableUniqueTokenCountForReuse(uniqueTokens, *llmRequest);
auto blockedUniqueTokens = chopVectorIntoBlocks<UniqueToken>(
uniqueTokens, usableUniqueTokenCount, mTokensPerBlock, /*allowPartial=*/true);
auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest);

std::vector<KVCacheBlock::IdType> cacheBlockIds(allocatedBlocks.size());
Expand Down
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/nanobind/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set(SRCS
runtime/bindings.cpp
runtime/hostfunc.cpp
runtime/moeBindings.cpp
testing/kvCacheManagerTestUtilBinding.cpp
testing/modelSpecBinding.cpp
userbuffers/bindings.cpp
thop/bindings.cpp
Expand Down
2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/nanobind/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
#include "tensorrt_llm/nanobind/executor/bindings.h"
#include "tensorrt_llm/nanobind/process_group/bindings.h"
#include "tensorrt_llm/nanobind/runtime/bindings.h"
#include "tensorrt_llm/nanobind/testing/kvCacheManagerTestUtilBinding.h"
#include "tensorrt_llm/nanobind/testing/modelSpecBinding.h"
#include "tensorrt_llm/nanobind/thop/bindings.h"
#include "tensorrt_llm/nanobind/userbuffers/bindings.h"
Expand Down Expand Up @@ -498,6 +499,7 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
tpb::Buffers::initBindings(mInternalBatchManager);
tensorrt_llm::nanobind::runtime::initBindings(mInternalRuntime);
tensorrt_llm::nanobind::testing::initBindings(mInternalTesting);
tensorrt_llm::nanobind::testing::initKvCacheTestUtilBindings(mInternalTesting);
tpb::initBindings(mInternalBatchManager);

tb::kv_cache_manager::KVCacheManagerConnectorBindings::initBindings(mInternalBatchManager);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kvCacheManagerTestUtilBinding.h"
#include "tensorrt_llm/nanobind/common/customCasters.h"
#include "tensorrt_llm/testing/kvCacheManagerTestUtil.h"

#include <nanobind/nanobind.h>

namespace nb = nanobind;

namespace tensorrt_llm::nanobind::testing
{

// Register KV-cache test-utility bindings on the given nanobind module.
// Exposes KvCacheManagerTestUtil::simulatePrefillCompletion to Python as
// simulate_prefill_completion_only_use_for_testing, taking a single llm_request
// argument; the GIL is released for the duration of the call via nb::call_guard.
void initKvCacheTestUtilBindings(nb::module_& m)
{
    m.def("simulate_prefill_completion_only_use_for_testing",
        &tensorrt_llm::testing::KvCacheManagerTestUtil::simulatePrefillCompletion, nb::arg("llm_request"),
        nb::call_guard<nb::gil_scoped_release>(),
        "NEVER USE IN PRODUCTION. Simulates prefill completion on an LlmRequest for test purposes.");
}

} // namespace tensorrt_llm::nanobind::testing
29 changes: 29 additions & 0 deletions cpp/tensorrt_llm/nanobind/testing/kvCacheManagerTestUtilBinding.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <nanobind/nanobind.h>

namespace nb = nanobind;

namespace tensorrt_llm::nanobind::testing
{

/// @brief Registers the KV-cache-manager test-utility bindings (test-only helpers) on module @p m.
void initKvCacheTestUtilBindings(nb::module_& m);

} // namespace tensorrt_llm::nanobind::testing
43 changes: 43 additions & 0 deletions cpp/tensorrt_llm/testing/kvCacheManagerTestUtil.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "tensorrt_llm/batch_manager/llmRequest.h"

namespace tensorrt_llm::testing
{

/// @brief Helpers for KV cache manager unit tests. NEVER use in production code.
class KvCacheManagerTestUtil
{
public:
    /// @brief Puts @p llmRequest into the state it would have right after prefill finished.
    ///
    /// NEVER CALL FROM PRODUCTION CODE. This exists solely for unit tests.
    ///
    /// Several BlockManager/KVCacheManager entry points (storeContextBlocks, releaseBlocks,
    /// removeSequence, releaseSequence) assume prefill has completed by the time they are
    /// invoked. Advancing the context position to the prompt length mimics that state, so
    /// tests can exercise those functions without running an actual prefill pass.
    static void simulatePrefillCompletion(batch_manager::LlmRequest& llmRequest)
    {
        auto const promptLength = llmRequest.getPromptLen();
        llmRequest.setContextCurrentPosition(promptLength);
    }
};

} // namespace tensorrt_llm::testing
2 changes: 2 additions & 0 deletions cpp/tests/unit_tests/batch_manager/capacitySchedulerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/requestUtils.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/testing/kvCacheManagerTestUtil.h"

#include <NvInferPlugin.h>

Expand Down Expand Up @@ -401,6 +402,7 @@ int runTest(CapacityScheduler& capacityScheduler,

if (llmReq->getContextRemainingLength() == 0)
{
tensorrt_llm::testing::KvCacheManagerTestUtil::simulatePrefillCompletion(*llmReq);
kvCacheManager->storeContextBlocks(*llmReq);
if (crossKvCacheManager)
{
Expand Down
Loading
Loading