[TRTLLM-10939][feat] Enable block reuse with overlap scheduler #12816
chienchunhung wants to merge 1 commit into NVIDIA:main from
Conversation
2b7c7a5 to 27c188b (force-push)

/bot run --disable-fail-fast

PR_Github #42188 [ run ] triggered by Bot. Commit:
📝 Walkthrough

Adjusts PyExecutor termination flow to make end_transfer boolean-consistent and avoid redundant terminations when blocks are stored; removes a backend-only validation blocking block-reuse with overlap scheduling; and enables/parametrizes block-reuse in related integration and unit tests and configs.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 3
🧹 Nitpick comments (1)
tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml (1)
13-15: Make `enable_partial_reuse` explicit in this overlap config.

This config only exercises the disaggregated reuse path because `enable_partial_reuse` currently defaults to `true`. Making that explicit keeps the coverage stable if the default changes later.

📝 Suggested change
```diff
 context_servers:
   num_instances: 1
   max_batch_size: 1
   max_num_tokens: 3000
   max_seq_len: 4096
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
   kv_cache_config:
     enable_block_reuse: true
+    enable_partial_reuse: true
     free_gpu_memory_fraction: 0.2
@@
 generation_servers:
   num_instances: 1
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
   max_batch_size: 256
   max_num_tokens: 4096
   max_seq_len: 4096
   kv_cache_config:
     enable_block_reuse: true
+    enable_partial_reuse: true
     free_gpu_memory_fraction: 0.2
```

Also applies to: 25-27
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml` around lines 13 - 15, The test config omits the kv_cache_config flag enable_partial_reuse which currently defaults to true; make it explicit to ensure this disaggregated reuse path remains covered. Update the YAML under kv_cache_config in disagg_config_overlap.yaml (and the similar block at lines 25-27) to add enable_partial_reuse: true so the test explicitly enables partial reuse (refer to kv_cache_config and enable_partial_reuse).
ℹ️ Review info: Configuration: .coderabbit.yaml | Review profile: CHILL | Plan: Pro | Run ID: a1c7f708-6bee-4c8e-82be-00cfa11a8b12
📒 Files selected for processing (5)

- tensorrt_llm/_torch/pyexecutor/py_executor.py
- tensorrt_llm/executor/base_worker.py
- tests/integration/defs/accuracy/test_disaggregated_serving.py
- tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml
- tests/unittest/_torch/executor/test_overlap_scheduler.py
💤 Files with no reviewable changes (1)
- tensorrt_llm/executor/base_worker.py
```python
if self.async_transfer_manager.end_transfer(request):
    self.active_requests.remove(request)
    self._terminate_request(request)
self._terminated_request_ids.discard(request.py_request_id)
return
if self.async_transfer_manager.end_transfer(request):
    self._terminate_request(request)
self._terminated_request_ids.discard(request.py_request_id)
```
Don't drop the termination guard before the last transfer.
AsyncTransferManager can track more than one transfer for the same request. Clearing `_terminated_request_ids` even when `end_transfer()` returns False drops the guard after the first completion, so a request with both connector and transceiver transfers can hit `_do_terminate_request()` again on the final completion and double-free resources.
🐛 Proposed fix

```diff
 if response:
     response.result.cached_tokens = request.cached_tokens
     self._enqueue_responses([(request.py_request_id, response)])
 if self.async_transfer_manager.end_transfer(request):
     self.active_requests.remove(request)
     self._terminate_request(request)
-self._terminated_request_ids.discard(request.py_request_id)
+    self._terminated_request_ids.discard(request.py_request_id)
 return
 if self.async_transfer_manager.end_transfer(request):
     self._terminate_request(request)
-self._terminated_request_ids.discard(request.py_request_id)
+    self._terminated_request_ids.discard(request.py_request_id)
```

🤖 Prompt for AI Agents
+ self._terminated_request_ids.discard(request.py_request_id)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/pyexecutor/py_executor.py` around lines 609 - 616, The
code currently clears the termination guard unconditionally which can allow
double termination when AsyncTransferManager.end_transfer(request) returns False
for the first of multiple transfers; change the second branch to mirror the
first: call end_transfer(request) and if it returns True then perform
termination work (call self._terminate_request(...) and remove from
self.active_requests if present) and only then discard
self._terminated_request_ids for request.py_request_id, otherwise return
immediately without discarding; update the logic around
AsyncTransferManager.end_transfer, self.active_requests.remove,
self._terminate_request, and self._terminated_request_ids.discard to ensure the
guard is only cleared when end_transfer() returned True.
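The invariant the comment above relies on, that `end_transfer()` reports True only once the last in-flight transfer for a request completes, can be modeled in a few lines. This is an illustrative toy, not TensorRT-LLM code: `TransferTracker` and its methods are hypothetical stand-ins for `AsyncTransferManager`.

```python
class TransferTracker:
    """Toy model: counts in-flight transfers per request id."""

    def __init__(self):
        self._pending = {}  # request_id -> number of in-flight transfers

    def begin_transfer(self, request_id: int) -> None:
        self._pending[request_id] = self._pending.get(request_id, 0) + 1

    def end_transfer(self, request_id: int) -> bool:
        # Returning False (not a bare return) for unknown ids mirrors the
        # KeyError fix mentioned in the PR description.
        if request_id not in self._pending:
            return False
        self._pending[request_id] -= 1
        if self._pending[request_id] == 0:
            del self._pending[request_id]
            return True  # last transfer done: safe to terminate once
        return False


tracker = TransferTracker()
terminated_guard = {42}  # request 42 was already early-terminated

tracker.begin_transfer(42)  # connector transfer
tracker.begin_transfer(42)  # transceiver transfer

# First completion: end_transfer() is False, so the guard must stay set.
if tracker.end_transfer(42):
    terminated_guard.discard(42)
assert 42 in terminated_guard

# Final completion: now it is safe to clear the guard exactly once.
if tracker.end_transfer(42):
    terminated_guard.discard(42)
assert 42 not in terminated_guard
```

Clearing the guard unconditionally, as the flagged code does, collapses the two branches above into one and re-opens the double-termination window.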
```diff
 req_id = request.py_request_id
 if req_id in self._terminated_request_ids:
     return
 self._terminated_request_ids.add(req_id)
 self.resource_manager.free_resources(request)

 if self.gather_all_responses or self.dist.rank == 0:
-    self.result_wait_queues.pop(request.py_request_id, None)
+    self.result_wait_queues.pop(req_id, None)
```
Bound `_terminated_request_ids` outside the transfer callback too.

Line 3409 records every terminated request, but the only visible cleanup is in `_end_transfer_and_maybe_terminate()`. Ordinary completions, and PP requests whose real `_do_terminate_request()` runs later via `DisaggPPTerminationHandler`, never remove their IDs, so this set grows for the lifetime of the executor and will turn termination into a no-op if request IDs are ever reused.
🧹 Proposed fix

```diff
 def _do_terminate_request(self, request: LlmRequest):
     req_id = request.py_request_id
     if req_id in self._terminated_request_ids:
         return
-    self._terminated_request_ids.add(req_id)
-    self.resource_manager.free_resources(request)
+    keep_guard = req_id in self.async_transfer_manager.requests_in_transfer()
+    self._terminated_request_ids.add(req_id)
+    try:
+        self.resource_manager.free_resources(request)
+    finally:
+        if not keep_guard:
+            self._terminated_request_ids.discard(req_id)
     if self.gather_all_responses or self.dist.rank == 0:
         self.result_wait_queues.pop(req_id, None)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/pyexecutor/py_executor.py` around lines 3406 - 3413, The
_terminated_request_ids set is only updated in the transfer callback path,
causing IDs (recorded at line where req_id is taken from request.py_request_id
and added to self._terminated_request_ids) to never be removed for ordinary
completions or for PP requests terminated later by DisaggPPTerminationHandler;
update the logic so that any completion path that finalizes a request also
removes the id from self._terminated_request_ids (and pops result_wait_queues) —
specifically ensure _do_terminate_request(),
_end_transfer_and_maybe_terminate(), and the normal completion code path call a
common cleanup helper (e.g., _cleanup_request_termination(req_id)) that frees
resources via resource_manager.free_resources(request) if needed and removes
req_id from self._terminated_request_ids and result_wait_queues to prevent
permanent growth and reuse conflicts.
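The "common cleanup helper" the prompt above suggests can be sketched in isolation. All names here (`ExecutorSketch`, `_cleanup_request_termination`, `freed`) are hypothetical stand-ins for illustration; this is not the PyExecutor API.

```python
class Request:
    def __init__(self, request_id: int):
        self.py_request_id = request_id


class ExecutorSketch:
    def __init__(self):
        self._terminated_request_ids = set()
        self.result_wait_queues = {}
        self.freed = []  # records which requests had resources freed

    def _do_terminate_request(self, request: Request) -> None:
        req_id = request.py_request_id
        if req_id in self._terminated_request_ids:
            return  # idempotent: a second termination is a no-op
        self._terminated_request_ids.add(req_id)
        self.freed.append(req_id)  # stands in for free_resources(request)

    def _cleanup_request_termination(self, req_id: int) -> None:
        # Called from *every* path that finalizes a request, so ids
        # never accumulate for the lifetime of the executor.
        self._terminated_request_ids.discard(req_id)
        self.result_wait_queues.pop(req_id, None)


ex = ExecutorSketch()
req = Request(7)
ex.result_wait_queues[7] = object()

ex._do_terminate_request(req)   # frees resources once
ex._do_terminate_request(req)   # guarded: no double free
ex._cleanup_request_termination(7)

assert ex.freed == [7]
assert 7 not in ex._terminated_request_ids
assert 7 not in ex.result_wait_queues
```

The key design point is that the guard set bounds itself: every finalization path funnels into one helper, so membership lasts only for the window where a duplicate termination is actually possible.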
```python
prompts = test_case["prompts"]
max_new_tokens = test_case["max_new_tokens"]
temperature = test_case["temperature"]
top_p = test_case["top_p"]
stop_words = test_case["stop_words"]

sampling_config = SamplingParams(max_tokens=max_new_tokens,
                                 stop=stop_words,
                                 temperature=temperature,
                                 top_p=top_p,
                                 n=1,
                                 use_beam_search=True)

with create_llm(model_path,
                disable_overlap_scheduler=False,
                sampler_type=sampler_type,
                enable_block_reuse=True) as llm:
    outputs_first = llm.generate(prompts,
                                 sampling_params=sampling_config,
                                 use_tqdm=True)
    for output in outputs_first:
        assert output.cached_tokens == 0, (
            "First pass should have no cached tokens (cold cache)")

    outputs_second = llm.generate(prompts,
                                  sampling_params=sampling_config,
                                  use_tqdm=True)
    for output in outputs_second:
        assert output.cached_tokens > 0, (
            "Second pass should reuse cached blocks")
```
The cold-cache assertion is batch-order dependent.
With `max_num_tokens=128`, the first `generate(prompts, ...)` call can legitimately produce cache hits for later prompts in the same batch once earlier requests finish chunked prefill. That makes `cached_tokens == 0` on every first-pass output flaky instead of a pure cold-cache check.
💡 Proposed fix

```diff
-    prompts = test_case["prompts"]
+    prompt = test_case["prompts"][0]
@@
-    outputs_first = llm.generate(prompts,
-                                 sampling_params=sampling_config,
-                                 use_tqdm=True)
-    for output in outputs_first:
-        assert output.cached_tokens == 0, (
-            "First pass should have no cached tokens (cold cache)")
+    output_first = llm.generate([prompt],
+                                sampling_params=sampling_config,
+                                use_tqdm=True)[0]
+    assert output_first.cached_tokens == 0, (
+        "First pass should have no cached tokens (cold cache)")
-    outputs_second = llm.generate(prompts,
-                                  sampling_params=sampling_config,
-                                  use_tqdm=True)
-    for output in outputs_second:
-        assert output.cached_tokens > 0, (
-            "Second pass should reuse cached blocks")
+    output_second = llm.generate([prompt],
+                                 sampling_params=sampling_config,
+                                 use_tqdm=True)[0]
+    assert output_second.cached_tokens > 0, (
+        "Second pass should reuse cached blocks")
```

🤖 Prompt for AI Agents
+ "Second pass should reuse cached blocks")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/_torch/executor/test_overlap_scheduler.py` around lines 120 -
149, The current cold-cache check is flaky because calling llm.generate on a
batch can produce cache hits for later prompts; change the first
warmup/generation to invoke llm.generate per-prompt to guarantee a cold cache
for each request: inside the with create_llm(...) block, replace the batched
call that produces outputs_first with a loop over prompts that calls
llm.generate([prompt], sampling_params=sampling_config, use_tqdm=True) and
assert output.cached_tokens == 0 for each single-prompt result (keep the later
batched generate for the cache-reuse check). Reference functions/objects:
create_llm, llm.generate, SamplingParams, outputs_first/outputs_second.
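Why batching breaks the cold-cache assertion can be shown with a runnable toy. `FakeLLM` models a prefix cache at block granularity; it is purely illustrative and not the TensorRT-LLM API, and `BLOCK` is an assumed toy block size.

```python
BLOCK = 4  # tokens per KV block in this toy

class FakeLLM:
    def __init__(self):
        self.stored = set()  # prefixes (as token tuples) already cached

    def generate(self, prompts):
        results = []
        for tokens in prompts:
            # Count how many leading tokens are covered by cached blocks.
            cached = 0
            for end in range(BLOCK, len(tokens) + 1, BLOCK):
                if tuple(tokens[:end]) in self.stored:
                    cached = end
            results.append(cached)
            # After prefill, every full-block prefix becomes reusable --
            # including for *later prompts in the same batch*, which is
            # what makes the batched cold-pass assertion flaky.
            for end in range(BLOCK, len(tokens) + 1, BLOCK):
                self.stored.add(tuple(tokens[:end]))
        return results

shared_prefix = list(range(8))
prompts = [shared_prefix + [100], shared_prefix + [200]]

# Batched "first pass": the second prompt already hits the cache.
batched = FakeLLM().generate(prompts)
assert batched == [0, 8]

# Isolated "first pass" per prompt is deterministically cold.
per_prompt = [FakeLLM().generate([p])[0] for p in prompts]
assert per_prompt == [0, 0]
```

The toy exaggerates the timing (real reuse depends on when chunked prefill stores blocks), but the ordering hazard is the same: a later prompt in the batch can observe blocks stored by an earlier one.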
Make _do_terminate_request idempotent to prevent double-termination when both _handle_responses (early termination) and _end_transfer_and_maybe_terminate fire on the same request under the overlap scheduler.

- Add _terminated_request_ids tracking set to skip redundant free_resources calls
- Remove ValueError guard in base_worker.py that blocked context-only + overlap + block_reuse + disagg
- Remove pytest.skip for overlap + block_reuse in disagg test
- Add enable_block_reuse parameter to overlap scheduler tests
- Add cache-hit verification test
- Fix end_transfer bare return -> return False

Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
Made-with: Cursor
27c188b to 1b3b9d8 (force-push)

/bot run --disable-fail-fast

@CodeRabbit review

✅ Actions performed: Review triggered.

PR_Github #42374 [ run ] triggered by Bot. Commit:
Description
Re-enable KV cache block reuse (prefix caching) when the overlap scheduler is active. Block reuse and the overlap scheduler were previously mutually exclusive due to an explicit guard in `base_worker.py` that rejected context-only requests in disaggregated serving when both features were enabled.

Problem
Block reuse (`enable_block_reuse`, default `True`) and the overlap scheduler (`disable_overlap_scheduler=False`, default) are both enabled by default, but their combination was explicitly blocked for disaggregated context-only requests.

Root cause
Removing the guard exposed a latent issue: both `_handle_responses` (early termination with pinned blocks) and `_end_transfer_and_maybe_terminate` (after KV transfer completes) call `_terminate_request` → `free_resources` on the same request. Under the non-overlap scheduler this was benign because the transfer typically completed before `_handle_responses` ran, so only one path fired. Under the overlap scheduler, the deferred processing creates a window where the transfer is still in-flight when `_handle_responses` terminates the request, causing `end_transfer` to terminate it again.

Fix
In `_end_transfer_and_maybe_terminate`, skip the redundant `_terminate_request` call when `should_store_blocks` is True, since the early-termination path in `_handle_responses` already handled it. This preserves the existing early-termination + pin/unpin mechanism while preventing the double-termination crash.

Changes
- `tensorrt_llm/_torch/pyexecutor/py_executor.py`: In `_end_transfer_and_maybe_terminate`, guard the non-fast-transfer `_terminate_request` call with `if not should_store_blocks`. When `should_store_blocks` is True, `_handle_responses` already terminated the request via the `enable_partial_reuse_for_disagg` early-termination branch. Change `end_transfer` to return `False` (instead of a bare `return`) on `KeyError`, preventing unintended termination by the caller.
- `tensorrt_llm/executor/base_worker.py`: Remove the `ValueError` guard that rejected context-only requests when the overlap scheduler, block reuse, and the KV cache transceiver were all active.
- `tests/unittest/_torch/executor/test_overlap_scheduler.py`: Add an `enable_block_reuse` parameter to the `create_llm` helper and to `test_overlap_scheduler_consistency` as a parametrized axis (`[False, True]`), so the existing consistency test now covers both block reuse configurations. Add `test_overlap_scheduler_block_reuse_cache_hit`, which sends the same prompt twice and verifies blocks are actually reused (`cached_tokens > 0` on the second pass). Add `strict=True` to a `zip()` call for length-mismatch safety.
- `tests/integration/defs/accuracy/test_disaggregated_serving.py`: Remove the `pytest.skip` for the overlap + block reuse combination on context servers. Set `enable_block_reuse` unconditionally in `_test_chunked_prefill_helper` (was gated on `ctx_pp == 1`; regular block reuse has no PP restriction).
- `tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml`: Set `enable_block_reuse: true` and `enable_partial_reuse: true` on both context and generation servers to exercise the full block reuse path in the disaggregated overlap test.

Test Coverage
- `test_overlap_scheduler_consistency[no_reuse-*]`
- `test_overlap_scheduler_consistency[block_reuse-*]`
- `test_overlap_scheduler_block_reuse_cache_hit` (verifies `cached_tokens > 0`)
- `test_auto_dtype` (disaggregated) with the `disagg_config_overlap.yaml` test

PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.