diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 6ac05aa68ee..aa5e01ef0b6 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -730,8 +730,11 @@ def __init__( self._lora_uid_counter = 0 self._lora_uid_to_low_ranks: Dict[str, Dict[int, Dict[str, int]]] = {} - # hold the torch tensors and prevent them from being freed - # TODO(enweiz): free device tensors if it's used for c++ runtime only + # When cpp_peft_cache_manager is provided (PyTorch backend), the C++ + # PeftCacheManager manages its own GPU cache with proper eviction. + # The Python-side GPU tensors are only needed by the legacy TRT backend + # which reads raw data_ptr() values via input_buffers(). + self._retain_device_tensors = cpp_peft_cache_manager is None self._lora_weights: List[torch.Tensor] = [] self._lora_weights_pointers_list: Dict[str, Dict[int, Dict[str, List[int]]]] = {} self._cpp_lora_weights: Dict[str, torch.Tensor] = {} # on cpu @@ -864,15 +867,14 @@ def load_from_model_file(uid, model_file): t_out = t_out.cuda().to(str_dtype_to_torch(model_config.dtype)).contiguous() rank = t_in.shape[0] self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = int(rank) - self._lora_weights_pointers_list[uid][layer_idx][lora_module] = [ - t_in.data_ptr(), - t_out.data_ptr(), - 0, - ] - - # prevent torch free this buffer - self._lora_weights.append(t_in) - self._lora_weights.append(t_out) + if self._retain_device_tensors: + self._lora_weights_pointers_list[uid][layer_idx][lora_module] = [ + t_in.data_ptr(), + t_out.data_ptr(), + 0, + ] + self._lora_weights.append(t_in) + self._lora_weights.append(t_out) self._cpp_lora_weights[uid].append( torch.concatenate([t_in.flatten().cpu(), t_out.flatten().cpu()]) ) @@ -1161,17 +1163,16 @@ def load_from_model_dir(uid, model_dir, hf_config): t_mag = t_mag.to(str_dtype_to_torch(model_config.dtype)) self._lora_uid_to_low_ranks[uid][layer_idx][lora_module] = effective_rank - self._lora_weights_pointers_list[uid][layer_idx][lora_module] = [ - t_in.data_ptr(), - t_out.data_ptr(), - t_mag.data_ptr() if (is_dora and t_mag is not None) else 0, - ] - - # prevent torch free this buffer - self._lora_weights.append(t_in) - self._lora_weights.append(t_out) - if is_dora and t_mag is not None: - self._lora_weights.append(t_mag) + if self._retain_device_tensors: + self._lora_weights_pointers_list[uid][layer_idx][lora_module] = [ + t_in.data_ptr(), + t_out.data_ptr(), + t_mag.data_ptr() if (is_dora and t_mag is not None) else 0, + ] + self._lora_weights.append(t_in) + self._lora_weights.append(t_out) + if is_dora and t_mag is not None: + self._lora_weights.append(t_mag) t_in_cpu = t_in.flatten().cpu() t_out_cpu = t_out.flatten().cpu() diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index 229a7f1e180..997017531c3 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -480,6 +480,7 @@ test_e2e.py::test_eagle3_output_repetition_4gpus[Qwen3/saved_models_Qwen3-235B-A test_e2e.py::test_eagle3_output_repetition_4gpus[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-Llama-4-Maverick-17B-128E-Eagle3] test_e2e.py::test_eagle3_output_repetition_4gpus[Qwen3/saved_models_Qwen3-235B-A22B_nvfp4_hf-Qwen3/qwen3-235B-eagle3] unittest/llmapi/test_llm_pytorch.py::test_gemma3_1b_instruct_multi_lora +unittest/llmapi/test_llm_pytorch.py::test_lora_many_adapters_no_memory_leak llmapi/test_llm_examples.py::test_llmapi_server_example # e2e serve test diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 80bca38171a..6c51374b67e 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -34,6 +34,7 @@ l0_a10: - unittest/inputs/test_chat_template_dispatch.py - unittest/inputs/test_content_format.py - unittest/others/test_convert_utils.py + - unittest/others/test_lora_manager.py - unittest/others/test_time_breakdown.py - unittest/others/test_tracing.py - unittest/disaggregated/test_disagg_openai_client.py diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 44b64d78080..83e10ece062 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -176,7 +176,6 @@ triton_server/test_triton_rcca.py::test_rcca_bug_4934893[Temperature:0.5-TOP_P:0 examples/test_gpt.py::test_llm_minitron_fp8_with_pseudo_loras[4b] SKIP (https://nvbugs/5606233) test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-hf-nvfp4-False-False] SKIP (https://nvbugs/5629791) accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False] SKIP (https://nvbugs/5629792) -llmapi/test_llm_examples.py::test_llmapi_example_multilora SKIP (https://nvbugs/5636857) accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5616182) full:H100_PCIe/unittest/llmapi/test_llm_pytorch.py::test_llama_7b_multi_lora_evict_and_reload_lora_gpu_cache SKIP (https://nvbugs/5682551) test_e2e.py::test_openai_completions_example[trt] SKIP (https://nvbugs/5701450) diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index fe801c9e482..f80a2558c70 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -800,6 +800,84 @@ def test_gemma3_1b_instruct_multi_lora(cuda_graph_config) -> None: assert len(outputs) == 2 +@skip_gpu_memory_less_than_40gb +@pytest.mark.part3 +def test_lora_many_adapters_no_memory_leak() -> None: + """Verify GPU memory stays bounded when loading many unique LoRA adapters. + + Creates 20 dummy adapters but sets max_loras=2 and max_cpu_loras=4 to + force eviction. Without proper cleanup, _lora_weights can accumulate + GPU tensors for every loaded adapter, causing unbounded memory growth. + """ + model_dir = f"{llm_models_root()}/gemma/gemma-3-1b-it" + num_adapters = 20 + target_modules = ['attn_q', 'attn_k', 'attn_v'] + + with tempfile.TemporaryDirectory() as lora_dir: + model = AutoModelForCausalLM.from_pretrained(model_dir, + dtype=torch.bfloat16, + device_map="auto") + hf_modules = ["q_proj", "k_proj", "v_proj"] + peft_lora_config = PeftLoraConfig(r=8, + target_modules=hf_modules, + bias="none", + task_type="CAUSAL_LM") + lora_paths = [] + for i in range(num_adapters): + lora_model = get_peft_model(model, peft_lora_config) + for param in lora_model.parameters(): + param.data.zero_() + lora_path = f"{lora_dir}/lora_{i}" + lora_model.save_pretrained(lora_path) + lora_paths.append(lora_path) + + del model + torch.cuda.empty_cache() + + trtllm_lora_config = LoraConfig(lora_dir=lora_paths[:1], + lora_target_modules=target_modules, + max_lora_rank=8, + max_loras=2, + max_cpu_loras=4) + kv_cache_config = KvCacheConfig(enable_block_reuse=False, + enable_partial_reuse=False) + llm = LLM(model_dir, + lora_config=trtllm_lora_config, + kv_cache_config=kv_cache_config) + + sampling_params = SamplingParams(max_tokens=20) + warmup_count = 5 + + mem_samples = [] + for i in range(num_adapters): + lora_req = LoRARequest(f"lora-{i}", i, lora_paths[i]) + output = llm.generate("Hello, tell me a story.", + sampling_params, + lora_request=lora_req) + assert output.outputs[0].text != "" + + if i >= warmup_count: + mem_samples.append(torch.cuda.memory_allocated()) + + num_measured = len(mem_samples) + assert num_measured >= 2, "Not enough samples to measure growth" + + total_growth = mem_samples[-1] - mem_samples[0] + per_adapter_mb = (total_growth / (num_measured - 1)) / (1024 * 1024) + + # Each adapter is ~3 MB on GPU (r=8, 3 modules, 26 layers, bf16). + # The C++ PeftCacheManager handles eviction and _lora_weights + # stays empty, so per-adapter growth should be ~0. If GPU tensors + # leak, we would see ~3 MB/adapter of linear growth. Threshold + # of 1 MB/adapter catches leaks while tolerating noise from + # allocator fragmentation averaged over many samples. + max_per_adapter_mb = 1.0 + assert per_adapter_mb < max_per_adapter_mb, ( + f"GPU memory growing at {per_adapter_mb:.2f} MB/adapter over " + f"{num_measured} adapters (total {total_growth / (1024**2):.1f} MB). " + f"Possible _lora_weights leak.") + + @pytest.mark.parametrize( "lora_rank,max_lora_rank,description", [ diff --git a/tests/unittest/others/test_lora_manager.py b/tests/unittest/others/test_lora_manager.py new file mode 100644 index 00000000000..ff36aeebda8 --- /dev/null +++ b/tests/unittest/others/test_lora_manager.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for LoraManager._retain_device_tensors behavior. + +Verifies that GPU tensors are not accumulated in _lora_weights when the +PyTorch backend's C++ PeftCacheManager is provided, preventing OOM with +many unique LoRA adapters. +""" + +import json +import tempfile +import unittest +from dataclasses import dataclass, field +from pathlib import Path +from unittest.mock import MagicMock + +import torch +from safetensors.torch import save_file + +from tensorrt_llm.lora_manager import LoraManager +from tensorrt_llm.mapping import Mapping + + +@dataclass +class MockModelConfig: + """Minimal model config for LoraManager tests.""" + + lora_target_modules: list = field(default_factory=lambda: ["attn_q", "attn_k", "attn_v"]) + trtllm_modules_to_hf_modules: dict = field( + default_factory=lambda: { + "attn_q": "q_proj", + "attn_k": "k_proj", + "attn_v": "v_proj", + } + ) + hidden_size: int = 64 + dtype: str = "float16" + swap_gate_up_proj_lora_b_weight: bool = True + + +def _create_dummy_hf_lora_adapter( + adapter_dir: Path, hidden_size: int = 64, rank: int = 8, num_layers: int = 2 +): + """Create a minimal HF-format LoRA adapter on disk.""" + config = { + "r": rank, + "lora_alpha": rank, + "target_modules": ["q_proj", "k_proj", "v_proj"], + "bias": "none", + "peft_type": "LORA", + "task_type": "CAUSAL_LM", + } + with open(adapter_dir / "adapter_config.json", "w") as f: + json.dump(config, f) + + weights = {} + for layer_idx in range(num_layers): + for module in ["q_proj", "k_proj", "v_proj"]: + prefix = f"base_model.model.model.layers.{layer_idx}.self_attn.{module}" + weights[f"{prefix}.lora_A.weight"] = torch.randn(rank, hidden_size, dtype=torch.float16) + weights[f"{prefix}.lora_B.weight"] = torch.randn(hidden_size, rank, dtype=torch.float16) + + save_file(weights, str(adapter_dir / "adapter_model.safetensors")) + + +@unittest.skipUnless(torch.cuda.is_available(), "CUDA required") +class TestLoraManagerRetainDeviceTensors(unittest.TestCase): + """Tests for the _retain_device_tensors flag that prevents GPU memory leaks.""" + + def _create_manager(self, cpp_peft_cache_manager=None): + mapping = Mapping(world_size=1, rank=0, tp_size=1) + model_config = MockModelConfig() + return LoraManager( + mapping=mapping, + model_config=model_config, + cpp_peft_cache_manager=cpp_peft_cache_manager, + ) + + def test_retain_device_tensors_true_when_no_cpp_cache(self): + """Legacy TRT path: cpp_peft_cache_manager=None retains GPU tensors.""" + manager = self._create_manager(cpp_peft_cache_manager=None) + self.assertTrue(manager._retain_device_tensors) + + def test_retain_device_tensors_false_when_cpp_cache_provided(self): + """PyTorch path: cpp_peft_cache_manager provided skips GPU tensor retention.""" + mock_cache = MagicMock() + manager = self._create_manager(cpp_peft_cache_manager=mock_cache) + self.assertFalse(manager._retain_device_tensors) + + def test_lora_weights_empty_with_cpp_cache(self): + """With cpp_peft_cache_manager, _lora_weights stays empty after loading.""" + mock_cache = MagicMock() + manager = self._create_manager(cpp_peft_cache_manager=mock_cache) + + with tempfile.TemporaryDirectory() as tmpdir: + adapter_dir = Path(tmpdir) / "adapter_0" + adapter_dir.mkdir() + _create_dummy_hf_lora_adapter(adapter_dir) + + model_config = MockModelConfig() + manager.load_from_hf( + model_dirs=[str(adapter_dir)], + model_config=model_config, + uids=["test-uid-0"], + ) + + self.assertEqual(len(manager._lora_weights), 0) + self.assertIn("test-uid-0", manager._cpp_lora_weights) + + def test_lora_weights_populated_without_cpp_cache(self): + """Without cpp_peft_cache_manager (TRT), _lora_weights has GPU tensors.""" + manager = self._create_manager(cpp_peft_cache_manager=None) + + with tempfile.TemporaryDirectory() as tmpdir: + adapter_dir = Path(tmpdir) / "adapter_0" + adapter_dir.mkdir() + _create_dummy_hf_lora_adapter(adapter_dir) + + model_config = MockModelConfig() + manager.load_from_hf( + model_dirs=[str(adapter_dir)], + model_config=model_config, + uids=["test-uid-0"], + ) + + self.assertGreater(len(manager._lora_weights), 0) + self.assertTrue(all(t.is_cuda for t in manager._lora_weights)) + self.assertIn("test-uid-0", manager._lora_weights_pointers_list) + + def test_many_adapters_no_gpu_accumulation(self): + """Loading many adapters with cpp_cache does not accumulate GPU tensors.""" + mock_cache = MagicMock() + manager = self._create_manager(cpp_peft_cache_manager=mock_cache) + model_config = MockModelConfig() + + num_adapters = 20 + with tempfile.TemporaryDirectory() as tmpdir: + for i in range(num_adapters): + adapter_dir = Path(tmpdir) / f"adapter_{i}" + adapter_dir.mkdir() + _create_dummy_hf_lora_adapter(adapter_dir) + + manager.load_from_hf( + model_dirs=[str(adapter_dir)], + model_config=model_config, + uids=[f"uid-{i}"], + ) + + self.assertEqual(len(manager._lora_weights), 0) + self.assertEqual(len(manager._cpp_lora_weights), num_adapters) + + +if __name__ == "__main__": + unittest.main()