diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 14f42a0ae8a..6efcf28e404 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -231,7 +231,7 @@ def end_transfer(self, request: LlmRequest) -> bool:
             logger.warning(
                 f"Request {request.py_request_id} not found in transfer manager"
             )
-            return
+            return False
 
         if transfer_metadata.end_transfer():
             self._requests_in_transfer.pop(request.py_request_id)
@@ -610,7 +610,12 @@ def _end_transfer_and_maybe_terminate(self, request: LlmRequest):
             self._terminate_request(request)
             return
         if self.async_transfer_manager.end_transfer(request):
-            self._terminate_request(request)
+            # When should_store_blocks is True, _handle_responses already
+            # terminated this request via the early-termination path
+            # (enable_partial_reuse_for_disagg branch). Skip the redundant
+            # termination to avoid double free_resources calls.
+            if not self.async_transfer_manager.should_store_blocks:
+                self._terminate_request(request)
 
     # Performance metrics methods are in PerfMetricsManager (self.perf_manager)
 
diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py
index d2ac243a027..6da01e56adc 100644
--- a/tensorrt_llm/executor/base_worker.py
+++ b/tensorrt_llm/executor/base_worker.py
@@ -453,14 +453,6 @@ def _enqueue_request(self,
                 context_phase_params = request.disaggregated_params.get_context_phase_params(
                 )
 
-            if self._is_pytorch_backend and not self.llm_args.disable_overlap_scheduler \
-                and self.llm_args.kv_cache_config.enable_block_reuse \
-                and self.engine.kv_cache_transceiver is not None \
-                and request_type == tllm.RequestType.REQUEST_TYPE_CONTEXT_ONLY:
-                raise ValueError(
-                    "Context only requests are not supported in pytorch backend when overlap is enabled with block reuse."
-                )
-
         assert request.id is not None
 
         def _deduce_max_tokens(request: GenerationRequest,
diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py
index bae33892d7c..c89961d74ce 100644
--- a/tests/integration/defs/accuracy/test_disaggregated_serving.py
+++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -598,10 +598,6 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     def test_auto_dtype(self, ctx_disable_overlap_scheduler,
                         gen_disable_overlap_scheduler, ctx_enable_block_reuse,
                         gen_enable_block_reuse):
-        if ctx_enable_block_reuse and not ctx_disable_overlap_scheduler:
-            pytest.skip(
-                "Skip this test because overlap scheduler is not supported with block reuse for context server"
-            )
         ctx_server_config = {
             "disable_overlap_scheduler": ctx_disable_overlap_scheduler,
             "kv_cache_config": {
@@ -1514,7 +1510,7 @@ def _test_chunked_prefill_helper(self, *, ctx_pp: int):
         max_batch_size = 32
 
         kv_cache_config = {
-            "enable_block_reuse": True if ctx_pp == 1 else False,
+            "enable_block_reuse": True,
         }
 
         ctx_server_config = {
diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml
index 3a872fbbc95..391f95605b2 100644
--- a/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml
+++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_overlap.yaml
@@ -11,9 +11,9 @@ context_servers:
   tensor_parallel_size: 1
   pipeline_parallel_size: 1
   kv_cache_config:
-    enable_block_reuse: false
+    enable_block_reuse: true
+    enable_partial_reuse: true
     free_gpu_memory_fraction: 0.2
-    enable_partial_reuse: false
   cache_transceiver_config:
     backend: DEFAULT
 generation_servers:
@@ -24,8 +24,8 @@ generation_servers:
   max_num_tokens: 4096
   max_seq_len: 4096
   kv_cache_config:
-    enable_block_reuse: false
+    enable_block_reuse: true
+    enable_partial_reuse: true
     free_gpu_memory_fraction: 0.2
-    enable_partial_reuse: false
   cache_transceiver_config:
     backend: DEFAULT
diff --git a/tests/unittest/_torch/executor/test_overlap_scheduler.py b/tests/unittest/_torch/executor/test_overlap_scheduler.py
index 306c5e45101..08bbb3a087a 100644
--- a/tests/unittest/_torch/executor/test_overlap_scheduler.py
+++ b/tests/unittest/_torch/executor/test_overlap_scheduler.py
@@ -25,14 +25,16 @@ def model_path():
 def create_llm(model_dir,
                disable_overlap_scheduler,
                sampler_type,
-               scheduler_config=None):
+               scheduler_config=None,
+               enable_block_reuse=False):
     """Create LLM with specific overlap scheduler setting"""
     if scheduler_config is None:
         scheduler_config = SchedulerConfig()
     pytorch_config = dict(disable_overlap_scheduler=disable_overlap_scheduler,
                           sampler_type=sampler_type)
 
-    trt_kv_cache_config = TRT_KvCacheConfig(enable_block_reuse=False)
+    trt_kv_cache_config = TRT_KvCacheConfig(
+        enable_block_reuse=enable_block_reuse)
 
     return LLM(
         model=str(model_dir),
@@ -51,10 +53,13 @@ def create_llm(model_dir,
 @pytest.mark.parametrize("sampler_type", ["TorchSampler", "TRTLLMSampler"])
 @pytest.mark.parametrize("use_python_scheduler", [False, True],
                          ids=["cpp_scheduler", "python_scheduler"])
+@pytest.mark.parametrize("enable_block_reuse", [False, True],
+                         ids=["no_reuse", "block_reuse"])
 @pytest.mark.high_cuda_memory
 @pytest.mark.mpi_ray_parity
 def test_overlap_scheduler_consistency(model_path, test_case, sampler_type,
-                                       use_python_scheduler):
+                                       use_python_scheduler,
+                                       enable_block_reuse):
     scheduler_config = SchedulerConfig(
         use_python_scheduler=use_python_scheduler)
 
@@ -76,7 +81,8 @@ def test_overlap_scheduler_consistency(model_path, test_case, sampler_type,
     with create_llm(model_path,
                     disable_overlap_scheduler=False,
                     sampler_type=sampler_type,
-                    scheduler_config=scheduler_config) as llm:
+                    scheduler_config=scheduler_config,
+                    enable_block_reuse=enable_block_reuse) as llm:
         outputs_with_overlap = llm.generate(prompts,
                                             sampling_params=sampling_config,
                                             use_tqdm=True)
@@ -88,7 +94,8 @@ def test_overlap_scheduler_consistency(model_path, test_case, sampler_type,
     with create_llm(model_path,
                     disable_overlap_scheduler=True,
                     sampler_type=sampler_type,
-                    scheduler_config=scheduler_config) as llm:
+                    scheduler_config=scheduler_config,
+                    enable_block_reuse=enable_block_reuse) as llm:
         outputs_without_overlap = llm.generate(prompts,
                                                sampling_params=sampling_config,
                                                use_tqdm=True)
@@ -98,9 +105,48 @@ def test_overlap_scheduler_consistency(model_path, test_case, sampler_type,
 
     # Verify outputs are consistent
     for with_overlap, without_overlap in zip(texts_with_overlap,
-                                             texts_without_overlap):
+                                             texts_without_overlap,
+                                             strict=True):
         assert with_overlap == without_overlap
 
 
+@pytest.mark.parametrize("sampler_type", ["TorchSampler", "TRTLLMSampler"])
+@pytest.mark.high_cuda_memory
+@pytest.mark.mpi_ray_parity
+def test_overlap_scheduler_block_reuse_cache_hit(model_path, test_case,
+                                                 sampler_type):
+    """Verify that blocks are actually reused when sending the same prompt
+    twice with the overlap scheduler enabled. Uses a single prompt to avoid
+    batch-internal cache hits that could make the cold-cache check flaky."""
+    prompt = test_case["prompts"][0]
+    max_new_tokens = test_case["max_new_tokens"]
+    temperature = test_case["temperature"]
+    top_p = test_case["top_p"]
+    stop_words = test_case["stop_words"]
+
+    sampling_config = SamplingParams(max_tokens=max_new_tokens,
+                                     stop=stop_words,
+                                     temperature=temperature,
+                                     top_p=top_p,
+                                     n=1,
+                                     use_beam_search=True)
+
+    with create_llm(model_path,
+                    disable_overlap_scheduler=False,
+                    sampler_type=sampler_type,
+                    enable_block_reuse=True) as llm:
+        output_first = llm.generate([prompt],
+                                    sampling_params=sampling_config,
+                                    use_tqdm=True)[0]
+        assert output_first.cached_tokens == 0, (
+            "First pass should have no cached tokens (cold cache)")
+
+        output_second = llm.generate([prompt],
+                                     sampling_params=sampling_config,
+                                     use_tqdm=True)[0]
+        assert output_second.cached_tokens > 0, (
+            "Second pass should reuse cached blocks")
+
+
 if __name__ == "__main__":
     test_overlap_scheduler_consistency()