Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def end_transfer(self, request: LlmRequest) -> bool:
logger.warning(
f"Request {request.py_request_id} not found in transfer manager"
)
return
return False

if transfer_metadata.end_transfer():
self._requests_in_transfer.pop(request.py_request_id)
Expand Down Expand Up @@ -610,7 +610,12 @@ def _end_transfer_and_maybe_terminate(self, request: LlmRequest):
self._terminate_request(request)
return
if self.async_transfer_manager.end_transfer(request):
self._terminate_request(request)
# When should_store_blocks is True, _handle_responses already
# terminated this request via the early-termination path
# (enable_partial_reuse_for_disagg branch). Skip the redundant
# termination to avoid double free_resources calls.
if not self.async_transfer_manager.should_store_blocks:
self._terminate_request(request)

# Performance metrics methods are in PerfMetricsManager (self.perf_manager)

Expand Down
8 changes: 0 additions & 8 deletions tensorrt_llm/executor/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,14 +453,6 @@ def _enqueue_request(self,
context_phase_params = request.disaggregated_params.get_context_phase_params(
)

if self._is_pytorch_backend and not self.llm_args.disable_overlap_scheduler \
and self.llm_args.kv_cache_config.enable_block_reuse \
and self.engine.kv_cache_transceiver is not None \
and request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY:
raise ValueError(
"Context only requests are not supported in pytorch backend when overlap is enabled with block reuse."
)

assert request.id is not None

def _deduce_max_tokens(request: GenerationRequest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -598,10 +598,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
def test_auto_dtype(self, ctx_disable_overlap_scheduler,
gen_disable_overlap_scheduler, ctx_enable_block_reuse,
gen_enable_block_reuse):
if ctx_enable_block_reuse and not ctx_disable_overlap_scheduler:
pytest.skip(
"Skip this test because overlap scheduler is not supported with block reuse for context server"
)
ctx_server_config = {
"disable_overlap_scheduler": ctx_disable_overlap_scheduler,
"kv_cache_config": {
Expand Down Expand Up @@ -1514,7 +1510,7 @@ def _test_chunked_prefill_helper(self, *, ctx_pp: int):
max_batch_size = 32

kv_cache_config = {
"enable_block_reuse": True if ctx_pp == 1 else False,
"enable_block_reuse": True,
}

ctx_server_config = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ context_servers:
tensor_parallel_size: 1
pipeline_parallel_size: 1
kv_cache_config:
enable_block_reuse: false
enable_block_reuse: true
enable_partial_reuse: true
free_gpu_memory_fraction: 0.2
enable_partial_reuse: false
cache_transceiver_config:
backend: DEFAULT
generation_servers:
Expand All @@ -24,8 +24,8 @@ generation_servers:
max_num_tokens: 4096
max_seq_len: 4096
kv_cache_config:
enable_block_reuse: false
enable_block_reuse: true
enable_partial_reuse: true
free_gpu_memory_fraction: 0.2
enable_partial_reuse: false
cache_transceiver_config:
backend: DEFAULT
58 changes: 52 additions & 6 deletions tests/unittest/_torch/executor/test_overlap_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,16 @@ def model_path():
def create_llm(model_dir,
disable_overlap_scheduler,
sampler_type,
scheduler_config=None):
scheduler_config=None,
enable_block_reuse=False):
"""Create LLM with specific overlap scheduler setting"""
if scheduler_config is None:
scheduler_config = SchedulerConfig()
pytorch_config = dict(disable_overlap_scheduler=disable_overlap_scheduler,
sampler_type=sampler_type)

trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False)
trt_kv_cache_config = TRT_KvCacheConfig(
enable_block_reuse=enable_block_reuse)

return LLM(
model=str(model_dir),
Expand All @@ -51,10 +53,13 @@ def create_llm(model_dir,
@pytest.mark.parametrize("sampler_type", ["TorchSampler", "TRTLLMSampler"])
@pytest.mark.parametrize("use_python_scheduler", [False, True],
ids=["cpp_scheduler", "python_scheduler"])
@pytest.mark.parametrize("enable_block_reuse", [False, True],
ids=["no_reuse", "block_reuse"])
@pytest.mark.high_cuda_memory
@pytest.mark.mpi_ray_parity
def test_overlap_scheduler_consistency(model_path, test_case, sampler_type,
use_python_scheduler):
use_python_scheduler,
enable_block_reuse):
scheduler_config = SchedulerConfig(
use_python_scheduler=use_python_scheduler)

Expand All @@ -76,7 +81,8 @@ def test_overlap_scheduler_consistency(model_path, test_case, sampler_type,
with create_llm(model_path,
disable_overlap_scheduler=False,
sampler_type=sampler_type,
scheduler_config=scheduler_config) as llm:
scheduler_config=scheduler_config,
enable_block_reuse=enable_block_reuse) as llm:
outputs_with_overlap = llm.generate(prompts,
sampling_params=sampling_config,
use_tqdm=True)
Expand All @@ -88,7 +94,8 @@ def test_overlap_scheduler_consistency(model_path, test_case, sampler_type,
with create_llm(model_path,
disable_overlap_scheduler=True,
sampler_type=sampler_type,
scheduler_config=scheduler_config) as llm:
scheduler_config=scheduler_config,
enable_block_reuse=enable_block_reuse) as llm:
outputs_without_overlap = llm.generate(prompts,
sampling_params=sampling_config,
use_tqdm=True)
Expand All @@ -98,9 +105,48 @@ def test_overlap_scheduler_consistency(model_path, test_case, sampler_type,

# Verify outputs are consistent
for with_overlap, without_overlap in zip(texts_with_overlap,
texts_without_overlap):
texts_without_overlap,
strict=True):
assert with_overlap == without_overlap


@pytest.mark.parametrize("sampler_type", ["TorchSampler", "TRTLLMSampler"])
@pytest.mark.high_cuda_memory
@pytest.mark.mpi_ray_parity
def test_overlap_scheduler_block_reuse_cache_hit(model_path, test_case,
                                                 sampler_type):
    """Send one prompt twice through an overlap-scheduler LLM with block
    reuse enabled: the first pass must start from a cold cache (zero cached
    tokens) and the second pass must report reused blocks. A single prompt
    is used so the cold-cache check cannot be polluted by cache hits between
    prompts inside the same batch."""
    single_prompt = test_case["prompts"][0]
    params = SamplingParams(max_tokens=test_case["max_new_tokens"],
                            stop=test_case["stop_words"],
                            temperature=test_case["temperature"],
                            top_p=test_case["top_p"],
                            n=1,
                            use_beam_search=True)

    with create_llm(model_path,
                    disable_overlap_scheduler=False,
                    sampler_type=sampler_type,
                    enable_block_reuse=True) as llm:
        # Pass 1: cold cache — no blocks can have been stored yet.
        first = llm.generate([single_prompt],
                             sampling_params=params,
                             use_tqdm=True)[0]
        assert first.cached_tokens == 0, (
            "First pass should have no cached tokens (cold cache)")

        # Pass 2: the identical prompt must hit the blocks stored by pass 1.
        second = llm.generate([single_prompt],
                              sampling_params=params,
                              use_tqdm=True)[0]
        assert second.cached_tokens > 0, (
            "Second pass should reuse cached blocks")


# NOTE(review): invoking the test directly bypasses pytest, so none of its
# fixture/parametrize-supplied arguments (model_path, test_case, sampler_type,
# use_python_scheduler, enable_block_reuse) are provided — this call would
# raise TypeError. Run the file through pytest instead — TODO confirm and
# update or drop this entry point.
if __name__ == "__main__":
    test_overlap_scheduler_consistency()
Loading