[None][feat] AutoDeploy: Moved to #12861 - Gemma4 vision #12810
bmarimuthu-nv wants to merge 8 commits into NVIDIA:main from
Conversation
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
…mask) Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
📝 Walkthrough
This pull request introduces support for Gemma 3n and Gemma 4 models with comprehensive attention-mechanism enhancements, including shared KV-cache semantics, custom attention masks, and sliding-window support across multiple attention backends (FlashInfer, TensorRT-LLM, Triton, PyTorch). New model implementations, configuration files, and infrastructure for attention mask providers are included alongside extensive test coverage.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant Graph as FX Graph
    participant Transform as InjectCustomAttentionMask
    participant Provider as AttentionMaskProvider
    participant AttentionOp as torch_attention Node
    Client->>Graph: Export model to FX
    Graph->>Transform: Apply transform
    Transform->>Transform: Scan for torch_attention nodes
    Transform->>Provider: Lookup provider (model_type, backend)
    Provider->>Provider: Check registry
    Provider-->>Transform: Return provider function
    Transform->>Provider: Call provider with context
    Provider->>Graph: Create mask nodes in graph
    Graph-->>Provider: Return mask node
    Provider-->>Transform: Return mask node
    Transform->>AttentionOp: Insert mask node before attention
    Transform->>AttentionOp: Set attn_mask argument
    AttentionOp-->>Graph: Updated graph with mask inputs
    Graph-->>Client: Return transformed graph
```
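The provider lookup step in the diagram can be sketched as a plain registry keyed by (model_type, backend). All names below are hypothetical stand-ins, not the actual TensorRT-LLM API; a real provider would create mask nodes in the FX graph rather than return a string:

```python
# Hypothetical registry mirroring the (model_type, backend) lookup in the diagram.
_MASK_PROVIDERS = {}

def register_mask_provider(model_type, backend):
    """Decorator that registers a mask provider under (model_type, backend)."""
    def deco(fn):
        _MASK_PROVIDERS[(model_type, backend)] = fn
        return fn
    return deco

def lookup_mask_provider(model_type, backend):
    try:
        return _MASK_PROVIDERS[(model_type, backend)]
    except KeyError:
        raise KeyError(f"no mask provider registered for ({model_type}, {backend})")

@register_mask_provider("gemma4", "triton")
def gemma4_triton_mask(ctx):
    # A real provider would insert mask nodes into the graph; this placeholder
    # just returns a descriptive string.
    return f"mask_node(sliding_window={ctx['sliding_window']})"

provider = lookup_mask_provider("gemma4", "triton")
print(provider({"sliding_window": 512}))  # mask_node(sliding_window=512)
```

Keying the registry on the (model_type, backend) pair keeps backend-specific mask construction out of the transform itself, which only scans for attention nodes and delegates.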
```mermaid
sequenceDiagram
    participant Prefill as Prefill Phase
    participant Decode as Decode Phase
    participant TritonOp as TritonPagedAttention
    participant KVCache as KV Cache
    participant SlideWindow as Sliding Window
    Prefill->>TritonOp: triton_paged_context(Q, KV_cache, ...)
    TritonOp->>SlideWindow: Check sliding_window param
    SlideWindow->>TritonOp: Compute first_valid_pos
    TritonOp->>TritonOp: Load KV pages from first_window_page
    TritonOp->>TritonOp: Apply sliding window mask
    TritonOp->>KVCache: Update KV_cache (write phase)
    KVCache-->>TritonOp: Cache updated
    Decode->>TritonOp: triton_paged_decode(Q, KV_cache, sliding_window)
    TritonOp->>TritonOp: Filter KV pages by sliding window
    TritonOp->>TritonOp: Apply per-page window constraint
    TritonOp->>KVCache: Read from filtered pages (read-only)
    KVCache-->>TritonOp: Return cached KV
    TritonOp-->>Decode: Return attention output
```
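The first_valid_pos computation and page filtering shown in the decode path can be illustrated with a small pure-Python sketch; the function names here are hypothetical stand-ins for the kernel-side logic:

```python
def first_valid_pos(total_kv_len: int, sliding_window: int) -> int:
    # Oldest KV position a query at the end of the sequence may still attend to.
    return max(0, total_kv_len - sliding_window)

def pages_in_window(total_kv_len: int, sliding_window: int, page_size: int) -> list[int]:
    # Page indices that hold at least one position inside the sliding window.
    start = first_valid_pos(total_kv_len, sliding_window)
    first_page = start // page_size
    last_page = (total_kv_len - 1) // page_size
    return list(range(first_page, last_page + 1))

print(first_valid_pos(100, 32))      # 68
print(pages_in_window(100, 32, 16))  # [4, 5, 6] -- pages covering positions 68..99
```

The decode kernel only needs to read the pages this filter selects, which is what keeps sliding-window attention cheap even for long cached prefixes.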
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ❌ 3 failed (2 warnings, 1 inconclusive)
Actionable comments posted: 15
Note
Due to the large number of review comments, Critical and Major severity comments were prioritized as inline comments.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py (1)
166-190: ⚠️ Potential issue | 🟠 Major
Keep GELU gated-only in the torch reference ops.
This helper now resolves `ActivationType.Gelu` unconditionally, and Line 315, Line 485, Line 634, and Line 822 use it without checking `is_gated_mlp`. That makes `is_gated_mlp=False, act_fn=ActivationType.Gelu` succeed in the torch reference MoE ops even though the fused TRTLLM path still rejects non-gated GELU, so the two implementations no longer share the same supported surface.
🛠️ Suggested fix
```diff
+def _validate_mlp_style_and_act_fn(is_gated_mlp: bool, act_fn: int) -> None:
+    if is_gated_mlp and act_fn in (
+        ActivationType.Silu,
+        ActivationType.Swiglu,
+        ActivationType.Gelu,
+        ActivationType.Geglu,
+    ):
+        return
+    if not is_gated_mlp and act_fn in (ActivationType.Silu, ActivationType.Relu2):
+        return
+    raise ValueError(
+        f"Unsupported combination: is_gated_mlp='{is_gated_mlp}', act_fn='{act_fn}'."
+    )
+
 def torch_moe(
     x: torch.Tensor,
     selected_experts: torch.Tensor,
@@
 ) -> torch.Tensor:
+    _validate_mlp_style_and_act_fn(is_gated_mlp, act_fn)
     torch_act_fn = _resolve_torch_fn(act_fn)
```
Apply the same validation to the other `torch_*_moe` entry points as well.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py` around lines 166-190: The helper _resolve_torch_fn currently exposes GELU unconditionally; change it to enforce "GELU is allowed only for gated MLPs" by adding an is_gated_mlp: bool parameter (or otherwise checking gating) and assert/raise if act_fn == ActivationType.Gelu and not is_gated_mlp; then update every torch_*_moe entry point that calls _resolve_torch_fn (the torch reference MoE functions referenced in the review) to pass the is_gated_mlp flag from their own parameters so the same validation occurs in the torch path as in the fused TRTLLM path.
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py (1)
1269-1296: ⚠️ Potential issue | 🟠 Major
Add `"out"` to the `mutates_args` tuple.
`triton_paged_mha_with_cache()` writes to `out` when provided (lines 1347, 1360, 1373, 1380), but `mutates_args` only lists `kv_cache`. PyTorch requires all mutated Tensor inputs, including optional `out` buffers, to be declared in `mutates_args`; omitting them leaves export/functionalization with an incomplete alias contract and undefined behavior.
Change line 1269:
```python
@torch.library.custom_op("auto_deploy::triton_paged_mha_with_cache", mutates_args=("kv_cache", "out"))
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py` around lines 1269-1296: The custom op triton_paged_mha_with_cache currently declares only "kv_cache" in mutates_args but also writes to the optional out buffer; update the decorator on triton_paged_mha_with_cache to include "out" in the mutates_args tuple so that both kv_cache and out are declared as mutated (i.e., change the mutates_args to include "out" alongside "kv_cache" in the `@torch.library.custom_op` decorator for triton_paged_mha_with_cache).
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py (2)
451-479: ⚠️ Potential issue | 🔴 Critical
Declare the cache/output side effects in the custom-op schema.
This op writes `k_cache`, `v_cache`, and optionally `out`. Registering it with `mutates_args=()` violates PyTorch's `torch.library.custom_op` contract and allows the framework to treat it as pure during export/functionalization, silently eliding or reordering cache updates.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py` around lines 451 - 479, The custom op registration for "auto_deploy::torch_cached_attention_with_cache" incorrectly declares mutates_args=(), but the implementation mutates k_cache, v_cache and sometimes out; update the torch.library.custom_op decorator on torch_backend_mha_with_cache to list the arg indices (or names if supported) corresponding to k_cache, v_cache, and out so PyTorch knows these tensors are mutated (e.g., replace mutates_args=() with the appropriate tuple including the positions of k_cache, v_cache, and out), ensuring the schema matches the actual side effects.
571-597: ⚠️ Potential issue | 🟠 Major
Move `custom_attn_mask` after `read_cache_only` to match the real operator schema.
The fake implementation has `custom_attn_mask` positioned before `scale`, but the real operator (decorated with `@torch.library.custom_op`) places it after `read_cache_only`. PyTorch's `.register_fake()` requires the signature to match exactly, including argument order. This mismatch will cause positional arguments to be misaligned when the fake implementation is invoked.
Suggested fix
```diff
 def torch_backend_mha_with_cache_fake(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     batch_info_host: torch.Tensor,
     seq_len: torch.Tensor,
     input_pos: torch.Tensor,
     slot_idx: torch.Tensor,
     cu_seqlen: torch.Tensor,
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
-    custom_attn_mask: Optional[torch.Tensor] = None,
-    # BUFFERS
-    # <none>
-    # CONSTANTS
     scale: Optional[float] = None,
     sinks: Optional[torch.Tensor] = None,
     sliding_window_size: Optional[int] = None,
     logit_cap: Optional[float] = None,
     read_cache_only: bool = False,
+    custom_attn_mask: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py` around lines 571 - 597, The fake implementation torch_backend_mha_with_cache_fake registered via torch_backend_mha_with_cache.register_fake has custom_attn_mask incorrectly placed before scale, causing a mismatch with the real operator schema; update the function signature of torch_backend_mha_with_cache_fake so that custom_attn_mask is moved to after the read_cache_only parameter (preserving its Optional[torch.Tensor] default) to exactly match the real `@torch.library.custom_op` schema and avoid positional argument misalignment when the fake is invoked.
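The exact-signature requirement behind this finding can be demonstrated without torch, using inspect.signature to compare parameter order. The function names below are illustrative only, shaped loosely after the operator in question:

```python
import inspect

# Illustrative signatures only; not the real TensorRT-LLM operator.
def real_op(q, k, scale=None, read_cache_only=False, custom_attn_mask=None, out=None):
    pass

def fake_op_wrong(q, k, custom_attn_mask=None, scale=None, read_cache_only=False, out=None):
    pass

def fake_op_right(q, k, scale=None, read_cache_only=False, custom_attn_mask=None, out=None):
    pass

def signatures_match(a, b):
    # Compare parameter names in declaration order, mirroring the exact-match
    # requirement register_fake imposes on real vs. fake signatures.
    return list(inspect.signature(a).parameters) == list(inspect.signature(b).parameters)

print(signatures_match(real_op, fake_op_wrong))  # False: custom_attn_mask is misplaced
print(signatures_match(real_op, fake_op_right))  # True
```

Because the fake is often invoked positionally during tracing, a same-names-different-order mismatch silently binds tensors to the wrong parameters rather than failing loudly, which is why argument order matters here and not just the parameter set.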
🟡 Minor comments (9)
tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py-559-564 (1)
559-564: ⚠️ Potential issue | 🟡 Minor
Make kernel dumping best-effort.
If `AD_DUMP_KERNELS_DIR` is invalid or unwritable, Lines 560-564 raise and abort kernel generation even though this path is only diagnostic. Please catch `OSError` and warn instead.
♻️ Suggested hardening
```diff
 _kernel_dump_dir = _os.environ.get("AD_DUMP_KERNELS_DIR")
 if _kernel_dump_dir:
-    _dump_path = _os.path.join(_kernel_dump_dir, f"triton_gen_{sg_hash}.py")
-    _os.makedirs(_kernel_dump_dir, exist_ok=True)
-    with open(_dump_path, "w") as _f:
-        _f.write(full_src)
+    try:
+        _dump_path = _os.path.join(_kernel_dump_dir, f"triton_gen_{sg_hash}.py")
+        _os.makedirs(_kernel_dump_dir, exist_ok=True)
+        with open(_dump_path, "w", encoding="utf-8") as _f:
+            _f.write(full_src)
+    except OSError:
+        _logging.getLogger("mlir_codegen").warning(
+            "Failed to dump generated kernel %s into %s",
+            sg_hash,
+            _kernel_dump_dir,
+            exc_info=True,
+        )
```
🤖 Prompt for AI Agents
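The suggested best-effort pattern as a standalone runnable sketch; dump_kernel is a hypothetical stand-in for the emitter code, not the actual function:

```python
import logging
import os
import tempfile

logger = logging.getLogger("mlir_codegen")

def dump_kernel(dump_dir, name, src):
    """Best-effort kernel dump: log and continue instead of raising on I/O errors."""
    if not dump_dir:
        return False
    try:
        os.makedirs(dump_dir, exist_ok=True)
        path = os.path.join(dump_dir, f"triton_gen_{name}.py")
        with open(path, "w", encoding="utf-8") as f:
            f.write(src)
        return True
    except OSError:
        logger.warning("Failed to dump kernel %s into %s", name, dump_dir, exc_info=True)
        return False

ok = dump_kernel(tempfile.mkdtemp(), "abc123", "# generated kernel\n")
# Using a regular file as the parent directory makes makedirs raise an OSError subclass.
with tempfile.NamedTemporaryFile() as tf:
    bad = dump_kernel(os.path.join(tf.name, "sub"), "abc123", "# generated kernel\n")
print(ok, bad)
```

Catching OSError covers the whole family of filesystem failures (NotADirectoryError, PermissionError, ENOSPC) in one clause, which is the right granularity for a purely diagnostic code path.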
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py` around lines 559-564: The kernel-dump block using _kernel_dump_dir/_dump_path and writing full_src should be made best-effort: wrap os.makedirs and the file open/write in a try/except that catches OSError and does not re-raise; on error emit a warning (e.g., via logging.warning or warnings.warn) that includes the directory/path and exception details so kernel generation continues even if AD_DUMP_KERNELS_DIR is invalid or unwritable.
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py-162-165 (1)
162-165: ⚠️ Potential issue | 🟡 Minor
Prefer explicit dictionary lookup to avoid unnecessary method calls.
The `meta.get("cached_attn_op", ...)` pattern on lines 162–163 evaluates `attn_descriptor.get_cached_attention_op()` eagerly on every call, even when the key exists in `meta`. While current implementations are lightweight (simple torch op references), the fallback pattern is unnecessary since `cached_attn_op` is guaranteed to be set by `_insert_cached_attn_node()` before this code path executes. Use an explicit check to clarify the expected flow:
```diff
-        cached_attn_op = module._node_ref.meta.get(
-            "cached_attn_op", attn_descriptor.get_cached_attention_op()
-        )
+        if "cached_attn_op" in module._node_ref.meta:
+            cached_attn_op = module._node_ref.meta["cached_attn_op"]
+        else:
+            cached_attn_op = attn_descriptor.get_cached_attention_op()
```
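The eager-default pitfall is easy to reproduce with any dict: dict.get evaluates its default argument before the lookup, so the fallback runs even when the key is present. The names below are stand-ins for the real descriptor call:

```python
def expensive_default():
    # Stand-in for attn_descriptor.get_cached_attention_op(); counts its calls.
    expensive_default.calls += 1
    return "fallback_op"

expensive_default.calls = 0

meta = {"cached_attn_op": "cached_op"}

# dict.get evaluates its default argument eagerly, even though the key exists.
value = meta.get("cached_attn_op", expensive_default())
assert value == "cached_op" and expensive_default.calls == 1

# An explicit membership check only computes the fallback when it is needed.
if "cached_attn_op" in meta:
    value = meta["cached_attn_op"]
else:
    value = expensive_default()
assert expensive_default.calls == 1  # no extra call
```

This is a language-level behavior, not a library quirk: arguments are evaluated before the method is entered, so any default with side effects or nontrivial cost should go behind an explicit branch.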
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py` around lines 162-165: Replace the dict.get fallback with an explicit lookup since cached_attn_op is always set by _insert_cached_attn_node(); specifically, check module._node_ref.meta for the "cached_attn_op" key and assign cached_attn_op from that mapping (instead of calling attn_descriptor.get_cached_attention_op() as the default), so the eager call to attn_descriptor.get_cached_attention_op() is avoided and the code clearly reflects the invariant established by _insert_cached_attn_node().
tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py-29-29 (1)
29-29: ⚠️ Potential issue | 🟡 Minor
Update the header year to include 2026.
Line 1 still has a 2025-only copyright notice even though this file is modified in this 2026 PR.
As per coding guidelines, "Include NVIDIA copyright header on ALL new files (update year on modified files)."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py` at line 29: Update the file header copyright year to include 2026 (e.g., change the existing "2025" to "2025-2026" or add 2026) so the modified file's header is current; locate the top-of-file copyright comment in tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py and edit the header line accordingly.
tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py-2-3 (1)
2-3: ⚠️ Potential issue | 🟡 Minor
Add the standard NVIDIA SPDX header to this module.
This file is being modified here, but it still starts directly with imports. Please add the standard NVIDIA copyright/license block above Line 1.
As per coding guidelines, "All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header with the year of latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py` around lines 2-3: Add the standard NVIDIA SPDX copyright/license header at the top of the module (above the existing imports) in tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py; ensure the header includes the NVIDIA copyright line, the SPDX license identifier and the year of latest meaningful modification, and keep the existing imports for Gemma3nForCausalLM, Gemma3nForConditionalGeneration, Gemma4ForCausalLM, and Gemma4ForConditionalGeneration unchanged below the header.
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py-13-14 (1)
13-14: ⚠️ Potential issue | 🟡 Minor
Add the standard NVIDIA SPDX header to this test file.
This file is modified in this PR but still has no header block at the top.
As per coding guidelines, "All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header with the year of latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py` around lines 13-14: Add the standard NVIDIA SPDX copyright header block at the top of this test file; open tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py and insert the NVIDIA SPDX header (including copyright line with the year of latest meaningful modification and SPDX-License-Identifier) before any imports or code, formatted like other repository files, so the test function test_torch_backend_attention_custom_bool_mask_context() remains unchanged.
tests/unittest/auto_deploy/_utils_test/torch_attention_reference.py-70-71 (1)
70-71: ⚠️ Potential issue | 🟡 Minor
Add the standard NVIDIA SPDX header to this helper.
This file is modified in this PR but still has no header block at the top.
As per coding guidelines, "All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header with the year of latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/auto_deploy/_utils_test/torch_attention_reference.py` around lines 70-71: This helper module torch_attention_reference.py is missing the required NVIDIA SPDX copyright header; add the standard NVIDIA copyright/SPDX header block at the top of the file (with the year of latest meaningful modification) so the file contains the same header format used across the repo; ensure the header is a comment block placed before any imports or code (e.g., before definitions that reference symbols like scale) and matches the project's canonical NVIDIA header text and SPDX identifier.
tests/integration/defs/accuracy/test_llm_api_autodeploy.py-1011-1017 (1)
1011-1017: ⚠️ Potential issue | 🟡 Minor
Remove the commented-out tasks before merging.
This block is dead code and currently trips the file's lint (E265). If those evaluations are intentionally deferred, track them outside the test instead of checking in commented calls.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration/defs/accuracy/test_llm_api_autodeploy.py` around lines 1011-1017: Remove the dead/commented evaluation code lines that reference MMLU and GSM8K (the commented block calling MMLU(self.MODEL_NAME).evaluate(...) and GSM8K(self.MODEL_NAME).evaluate(...)) to eliminate the lint error E265; if those evaluations must be preserved for later, move them out of this test into a separate helper or test file and reference MODEL_NAME and EXTRA_EVALUATOR_KWARGS there instead of leaving commented calls in tests/integration/defs/accuracy/test_llm_api_autodeploy.py.
tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py-1-13 (1)
1-13: ⚠️ Potential issue | 🟡 Minor
Update this new file's copyright year.
This file was added in 2026, but the header still says 2025.
Suggested change
```diff
-# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
```
As per coding guidelines, "All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header with the year of latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py` around lines 1-13: Update the copyright header year from 2025 to 2026 in the new test file test_inject_custom_attention_mask.py so the NVIDIA copyright line reflects the latest modification year; locate the top-of-file header block (the Apache License comment) and change the year in the line that currently reads "Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved." to 2026.
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py-357-362 (1)
357-362: ⚠️ Potential issue | 🟡 Minor
Make this `out=` test describe the same batch it executes.
`batch_info_host.update([1, 3, 0, 0, 1, 1])` says there are 2 sequences and 4 total tokens, but `q`, `k`, `v`, `seq_len`, `slot_idx`, and `cu_seqlen` only encode a single 3-token prefill sequence. The assertion passes because the kernel truncates metadata-derived lengths, so this test never actually validates the bookkeeping around `out`.
Suggested test fix
```diff
 batch_info_host = BatchInfo()
-batch_info_host.update([1, 3, 0, 0, 1, 1])
+batch_info_host.update([1, 3, 0, 0, 0, 0])
 seq_len = torch.tensor([3], dtype=torch.int32)
-input_pos = torch.tensor([0, 1, 2], dtype=torch.int32)
-slot_idx = torch.tensor([0, 1, 2], dtype=torch.int32)
+input_pos = torch.tensor([0], dtype=torch.int32)
+slot_idx = torch.tensor([0], dtype=torch.int32)
 cu_seqlen = torch.tensor([0], dtype=torch.int32)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py` around lines 357 - 362, The BatchInfo metadata doesn't match the single 3-token prefill used by q/k/v/etc., so change the BatchInfo.update call to describe the same batch (one sequence of length 3 and appropriate zeroed/padded slots) so out-based assertions validate bookkeeping; specifically update the array passed to batch_info_host.update to reflect a single 3-token sequence that matches seq_len, input_pos, slot_idx and cu_seqlen, then re-run the assertions on out to verify the kernel's bookkeeping.
🧹 Nitpick comments (5)
tensorrt_llm/_torch/auto_deploy/llm.py (1)
155-158: Add a None check before wrapping factory.model in Path().
The `factory.model` property returns `Optional[str]` due to its type signature, even though in practice `_model` is required. To fail fast with a clear message rather than relying on `Path(None)` raising `TypeError`, consider validating that the resolved model path is not None before assignment. Additionally, since downstream utilities like lm-eval expect `config.json` to exist at this path, a simple assertion is reasonable defensive programming.
💡 Suggested defensive check
```diff
         # AutoDeploy resolves HF checkpoints through the factory prefetch path rather than the
         # shared model loader. Preserve that resolved snapshot path for downstream utilities such
         # as multimodal lm-eval, which read config.json from ``_hf_model_dir``.
+        assert self.factory.model is not None, "Model path not resolved after prefetch"
         self._hf_model_dir = Path(self.factory.model)
```
🤖 Prompt for AI Agents
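The fail-fast idea can be sketched as a small helper; resolve_model_dir is hypothetical, not part of the LLM API:

```python
from pathlib import Path
from typing import Optional

def resolve_model_dir(model: Optional[str]) -> Path:
    # Fail fast with a descriptive error instead of letting Path(None)
    # raise a bare TypeError deeper in the stack.
    if model is None:
        raise ValueError("Model path not resolved after prefetch")
    return Path(model)

print(resolve_model_dir("/tmp/snapshot"))
try:
    resolve_model_dir(None)
except ValueError as exc:
    print(exc)
```

The payoff is in the traceback: a ValueError naming the prefetch step points straight at the misconfiguration, whereas Path(None)'s TypeError surfaces far from the cause.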
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/llm.py` around lines 155-158: factory.model is typed Optional[str], so wrapping it directly with Path(self.factory.model) can raise an unclear TypeError; before assigning to self._hf_model_dir, validate that self.factory.model is not None and raise an explicit error (or use an assertion) with a clear message that the factory did not resolve a model path, then set self._hf_model_dir = Path(self.factory.model). This ensures downstream consumers expecting config.json at self._hf_model_dir fail fast with a readable error instead of a TypeError from Path().
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py (1)
134-138: Please add one assertion for the new non-default code paths.
These updated calls still exercise only `sliding_window=None` and the default `read_cache_only=False`. The new behavior that translates the window and skips `append_paged_kv_cache()` for shared-KV is where regressions are most likely, so this suite should pin at least one `read_cache_only=True` case and verify that `kv_cache` is left untouched.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py` around lines 134-138: Add a new assertion and test call that exercises the non-default read_cache_only=True path so the shared-KV early-return is validated: invoke the same test helper (the call surrounding sliding_window and read_cache_only parameters in test_flashinfer_attention_op.py) with read_cache_only=True (and a non-default sliding_window if applicable), capture kv_cache before the call, and assert after the call that kv_cache is unchanged (identity or equality). Also ensure the test verifies that append_paged_kv_cache() is not invoked for the shared-KV case (e.g., by using an existing spy/mocker or by asserting no changes to paged cache structures). Use the symbols read_cache_only, sliding_window, kv_cache, and append_paged_kv_cache to locate and implement the assertions.
tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py (1)
244-244: Add a type hint for the new `cached_attn_op` parameter.
Please type the new argument explicitly (and return `None` for this mutator) to keep this interface consistent and safer to maintain.
✍️ Proposed typing update
```diff
-from ...custom_ops.attention_interface import AttentionDescriptor, Constant, PrepareMetadataCallable
+from ...custom_ops.attention_interface import (
+    AttentionDescriptor,
+    Constant,
+    MHACallable,
+    PrepareMetadataCallable,
+)
@@
     def _insert_cached_attn_node(
         self,
         gm: GraphModule,
         attn_node: Node,
-        cached_attn_op,
+        cached_attn_op: MHACallable,
         qkv_nodes: List[Node],
         meta_nodes_std: List[Node],
         meta_nodes_extra: List[Node],
         cache_nodes: List[Node],
         constants: List[Constant],
-    ):
+    ) -> None:
```
As per coding guidelines, "Always annotate functions with type hints; make the return type None if the function does not return anything."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py` at line 244: The new parameter cached_attn_op added to the mutator function needs an explicit type annotation and the function's return type must be annotated as None; update the function signature that declares cached_attn_op to annotate cached_attn_op with an appropriate type (e.g., typing.Any or a more specific Callable/Module type used in this module) and set the function return type to -> None, importing the required typing symbol if necessary.
tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py (1)
370-440: Add a packed softcapping case for the real memory-reduction path.
This test only exercises the softcap pattern on generate-format input, where `gather_required=False` and the tensor stays `[batch, 1, hidden]`. The OOM/regression described in the docstring is the packed path that would otherwise materialize `[num_tokens, vocab_size]` before gather, so mirroring `test_transform_packed_format` with `SoftcapLMHeadModel` would make this guard much stronger.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py` around lines 370-440: Add a packed-format variant of test_transform_with_softcapping that mirrors the logic in test_transform_packed_format: instantiate SoftcapLMHeadModel and export to gm with hidden_states shaped for packed input (e.g., [batch, max_seq_len, hidden_size]) and appropriate logit_gather_ids/seq_len to force gather_required=True, then apply the same gather_logits_before_lm_head transform via InferenceOptimizer and perform the same assertions (gather exists, gather index < lm_head linear index, and forward output shape). Ensure you set BatchInfo.update_tokens_gather_info(batch_size, True) and pass token_gather_indices and batch_info_host when invoking gm_transformed so the packed path (memory-reduction path) is exercised.
tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py (1)
1099-1105: Use `zip(..., strict=True)` in the encoder weight-copy helper.
Silent truncation here hides layer-count mismatches between `Gemma4VisionEncoder` and `_RefVisionEncoder`, which is exactly the kind of structural drift this helper should fail fast on.
Suggested change
```diff
-    for ad_layer, ref_layer in zip(ad_encoder.layers, ref_encoder.layers):
+    for ad_layer, ref_layer in zip(ad_encoder.layers, ref_encoder.layers, strict=True):
         _transfer_vision_encoder_layer_weights(ad_layer, ref_layer)
```
Based on learnings, "In TensorRT-LLM (Python requires >=3.10 and <4 as per setup.py), you can use Python 3.10+ features (e.g., PEP 585 generics like dict[str, int], list[str], etc.) throughout the codebase, and you do not need to add from __future__ import annotations."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py` around lines 1099 - 1105, The helper _transfer_vision_encoder_weights currently uses zip(ad_encoder.layers, ref_encoder.layers) which silently truncates if layer counts differ; change it to use zip(ad_encoder.layers, ref_encoder.layers, strict=True) so mismatched layer lengths between Gemma4VisionEncoder and _RefVisionEncoder raise immediately (update the call in _transfer_vision_encoder_weights to use strict=True).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb`:
- Around line 56-57: The notebook cell that runs "%pip install torch openai"
should not reinstall torch inside the TRT-LLM container because the release
image already provides a CUDA/TensorRT-compatible Torch; remove "torch" from the
install command or replace the cell with a guarded install that only
pip-installs openai (or first tries import torch and only installs if
missing/mismatched). Update the cell content that currently contains "%pip
install torch openai" accordingly so only "openai" is installed (or add an
import-check guard) to avoid replacing the bundled Torch wheel.
- Around line 160-164: The BASE_URL used to construct the OpenAI client is
currently set to the bind address "http://0.0.0.0:8000/v1"; update the BASE_URL
constant to a routable client address (e.g., "http://127.0.0.1:8000/v1" or
"http://localhost:8000/v1") so the OpenAI(...) client connects correctly; change
the string assigned to BASE_URL (referenced where client =
OpenAI(base_url=BASE_URL, api_key=API_KEY)) and keep the rest of the
instantiation unchanged.
In `@examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml`:
- Around line 7-22: This Gemma4 base config is missing the
gather_logits_before_lm_head transform and can materialize [num_tokens,
vocab_size] before the LM head; update the YAML for the
Gemma4ForConditionalGeneration export to enable the same transform as in
gemma4_moe.yaml by adding gather_logits_before_lm_head under transforms
(alongside compile_model.piecewise_enabled) so the model uses the
gather-before-lm-head transform during export and avoids the piecewise
CUDA-graph memory regression.
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py`:
- Around line 363-364: The call to _GlobalFlashInferPlanner.reset(q.device)
inside the attention invocation should be removed so it doesn't clear
plan_params_prefill and plan_params_decode per layer;
prepare_flashinfer_metadata() already resets the planner once per forward. Edit
the attention implementation to stop calling
_GlobalFlashInferPlanner.reset(q.device) (or add a guard so reset only runs when
metadata hasn't been prepared), ensuring plan_prefill() and plan_decode() can
reuse existing plan parameters across same-shape layers.
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`:
- Around line 671-677: The code currently only rejects attention when dropout_p
!= 0.0 but no longer rejects non-causal attention; update the checks that pull
attn_mask, dropout_p, is_causal via extract_op_args for source_attn_node to also
reject (log and skip) when is_causal is False so non-causal torch_attention
nodes are not rewritten to causal ops—apply the same fix to the equivalent block
that handles the other cached torch path (the block that also extracts
attn_mask, dropout_p, is_causal around the other source node). Ensure
ad_logger.debug includes source_attn_node and is_causal in the message and that
the function/logic returns/continues when is_causal is False.
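A minimal sketch of the match filter being requested (the function name is hypothetical; the real check lives inline in the rewrite logic):

```python
def can_rewrite_to_cached_causal(attn_mask, dropout_p, is_causal):
    # The cached torch paths always apply causal masking, so the matcher must
    # screen out dropout AND non-causal sources; dropping the is_causal check
    # would silently turn non-causal attention into causal attention.
    if dropout_p != 0.0:
        return False
    if not is_causal:
        return False
    return True


print(can_rewrite_to_cached_causal(None, 0.0, True))   # True
print(can_rewrite_to_cached_causal(None, 0.0, False))  # False
```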
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`:
- Around line 1485-1498: The current matching allows non-causal torch_attention
nodes to be rewritten to the paged (causal-only) kernel because is_causal is
extracted but not enforced; update the decision logic where you extract
is_causal (in extract_op_args usage around source_attn_node) to check that
is_causal is True before proceeding to the paged kernel path (same place that
currently checks layout and dropout_p), and also ensure get_constants() or the
higher-level matching rejects nodes with is_causal=False so torch_attention(...,
is_causal=False) is not rewritten to the causal paged kernel; reference the
symbols extract_op_args, is_causal, get_constants(), torch_attention, and
source_attn_node when making the change.
- Around line 889-890: The masked sliding-window path uses query_positions =
q_offsets (local positions) instead of absolute query indices, causing incorrect
masking for extend requests; fix by computing the cached-prefix offset as
cached_prefix = total_kv_len - q_len and use query_positions = q_offsets +
cached_prefix (i.e., add that prefix to q_offsets) in the masked kernel where
total_kv_len is loaded (variable seq_len_with_cache_ptr / total_kv_len and
q_offsets/q_len); apply the same change to the other masked block referenced
(the code around the second occurrence noted in the comment).
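The offset arithmetic can be shown with plain integers (helper name is illustrative; the real kernel does this with Triton tensor ops):

```python
def absolute_query_positions(q_offsets, total_kv_len, q_len):
    # For an extend request the kernel's q_offsets are local (0 .. q_len-1),
    # but sliding-window masking needs positions relative to the full sequence
    # including the cached prefix already sitting in the KV cache.
    cached_prefix = total_kv_len - q_len
    return [off + cached_prefix for off in q_offsets]


# 8 tokens already cached, 3 new query tokens: absolute positions are 8, 9, 10.
# Masking with local offsets 0..2 would wrongly exclude most of the window.
print(absolute_query_positions([0, 1, 2], total_kv_len=11, q_len=3))  # [8, 9, 10]
```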
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py`:
- Around line 482-487: The activation normalization unconditionally rewrites
Silu→Swiglu and Gelu→Geglu which breaks non‑gated MLPs; change
_normalize_trtllm_act_fn to accept an is_gated_mlp: bool and only rewrite to
gated variants when is_gated_mlp is True, leaving act_fn unchanged otherwise,
then update all call sites that pass act_fn to also pass the is_gated_mlp flag
(the three quantized torch.ops.trtllm.* call sites that currently call
_normalize_trtllm_act_fn), and replace the assert in
_validate_mlp_style_and_act_fn with a proper exception raise (e.g., ValueError)
for invalid combinations so validation works in optimized Python.
- Around line 471-479: The helper _validate_mlp_style_and_act_fn currently uses
assert which is removed under python -O; change it to an explicit runtime check
and raise a ValueError with the existing descriptive message when the
(is_gated_mlp, act_fn) combination is unsupported (e.g., compute the boolean
condition currently in the assert and if False raise ValueError(f"...") so
public custom ops registered via `@torch.library.custom_op` fail immediately with
the same explanatory text).
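The two fixes above can be sketched together. This is a simplified model, not the real trtllm_moe.py code: the enum is a minimal stand-in and the function names mirror, but are not, the actual helpers:

```python
from enum import Enum


class ActivationType(Enum):  # minimal stand-in for the real enum
    Silu = "silu"
    Gelu = "gelu"
    Swiglu = "swiglu"
    Geglu = "geglu"


def normalize_act_fn(act_fn, is_gated_mlp):
    # Only rewrite to the gated variants when the MLP is actually gated;
    # a plain (non-gated) MLP keeps its activation unchanged.
    if not is_gated_mlp:
        return act_fn
    gated = {ActivationType.Silu: ActivationType.Swiglu,
             ActivationType.Gelu: ActivationType.Geglu}
    return gated.get(act_fn, act_fn)


def validate_mlp_style_and_act_fn(is_gated_mlp, act_fn):
    # Raise instead of assert so the check survives `python -O`.
    if not is_gated_mlp and act_fn in (ActivationType.Swiglu, ActivationType.Geglu):
        raise ValueError(f"{act_fn} requires a gated MLP")


print(normalize_act_fn(ActivationType.Silu, is_gated_mlp=True).name)   # Swiglu
print(normalize_act_fn(ActivationType.Gelu, is_gated_mlp=False).name)  # Gelu
```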
In `@tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py`:
- Line 90: The softplus emitter entry "ad.softplus" uses unstable log(1+exp(x));
change it to the numerically stable form max(x, 0) + log1p(exp(-abs(x))) to
avoid overflow for large x: update the lambda for "ad.softplus" in
triton_emitter.py to compute tl.maximum({a}, 0) + tl.log1p(tl.exp(-tl.abs({a})))
(use the same tl.* helpers as other emitters and mirror the implementation
pattern from tensorrt_llm/_torch/modules/mamba/softplus.py).
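The numerical difference is easy to demonstrate in plain Python (the Triton version substitutes `tl.maximum`, `tl.log1p`, `tl.abs`, and `tl.exp` for the `math` calls):

```python
import math


def softplus_naive(x):
    return math.log(1.0 + math.exp(x))  # exp(x) overflows for large x


def softplus_stable(x):
    # max(x, 0) + log1p(exp(-|x|)) never exponentiates a large positive value,
    # mirroring the pattern in tensorrt_llm/_torch/modules/mamba/softplus.py.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


print(round(softplus_stable(0.0), 6))  # 0.693147 (= ln 2)
print(softplus_stable(1000.0))         # 1000.0; the naive form overflows here
```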
In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py`:
- Around line 562-567: The get_per_layer_inputs implementation feeds raw
input_ids into embed_tokens_per_layer which is sized by
config.vocab_size_per_layer_input and will crash when vocab_size_per_layer_input
< full vocab; before calling embed_tokens_per_layer in get_per_layer_inputs,
mask/map input_ids into the per-layer vocab (e.g., replace input_ids with
input_ids % config.vocab_size_per_layer_input or otherwise clamp/map them to [0,
vocab_size_per_layer_input) ) so indices are valid, and apply the identical fix
to the other occurrence that also calls embed_tokens_per_layer around lines
608-610.
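The masking step can be sketched on plain lists (the helper name and fill value are illustrative; the torch version uses `torch.where` over a validity mask):

```python
def mask_to_per_layer_vocab(input_ids, vocab_size_per_layer_input, fill_id=0):
    # Map any token id outside the smaller per-layer vocab to a safe fill id
    # before indexing the per-layer embedding table, so valid full-vocab
    # tokens can never produce an out-of-range embedding lookup.
    return [t if 0 <= t < vocab_size_per_layer_input else fill_id
            for t in input_ids]


# e.g. a per-layer table with only 262 entries while the text vocab is larger:
print(mask_to_per_layer_vocab([5, 300, 17, 261, 100000], 262))  # [5, 0, 17, 261, 0]
```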
In `@tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py`:
- Around line 48-53: The current broad except swallows all errors from
factory._get_model_config(); change it to only catch the expected
lookup/missing-provider errors (e.g., LookupError/KeyError) and let any other
exceptions propagate so real bugs surface. Concretely, in the block that calls
get_model_config() (the local get_model_config = getattr(...) and subsequent try
around get_model_config()), replace except Exception with a narrow except
(LookupError, KeyError): return None (or similar expected-failure handling) and
re-raise any other exceptions (i.e., do not catch them).
In
`@tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py`:
- Around line 65-79: The current backward walk from lm_head_node blindly follows
current.all_input_nodes[0] until is_linear_op(current), which can traverse
unsafe multi-input or sequence-mixing ops; instead, limit walking to a small
whitelist of known-safe post-LM-head ops (e.g., elementwise/unary ops like Div,
Mul, Tanh, Clamp, Add) and only descend through an input that is actually the
linear path: for each current node check current.op_type (or kind) against the
whitelist and then pick the input node that either (a) is_linear_op(input) or
(b) has matching tensor shape/dtype or a single producer that looks like the
logits path; stop and fallback to lm_head_node if none match or max depth
exceeded. Update the loop that uses current, is_linear_op, and node_to_gather to
perform these guarded checks and log which branch was chosen.
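A toy version of the guarded backward walk, using a tiny stand-in `Node` class (the whitelist contents, attribute names, and `find_gather_anchor` are all illustrative, not the FX API):

```python
SAFE_POST_LM_HEAD_OPS = {"div", "mul", "tanh", "clamp", "add"}


class Node:  # tiny illustrative stand-in for an FX node
    def __init__(self, op_type, inputs=()):
        self.op_type = op_type
        self.inputs = list(inputs)


def is_linear_op(node):
    return node.op_type == "linear"


def find_gather_anchor(lm_head_node, max_depth=8):
    # Walk backwards only through whitelisted elementwise/unary ops; anything
    # unexpected (or exceeding max_depth) falls back to lm_head_node itself.
    current = lm_head_node
    for _ in range(max_depth):
        if is_linear_op(current):
            return current
        if current.op_type not in SAFE_POST_LM_HEAD_OPS or not current.inputs:
            return lm_head_node
        current = current.inputs[0]
    return lm_head_node


# Softcapping chain: linear -> div -> tanh -> mul (as in Gemma logit softcap).
linear = Node("linear")
chain = Node("mul", [Node("tanh", [Node("div", [linear])])])
print(find_gather_anchor(chain) is linear)  # True

# An unknown op stops the walk and conservatively anchors at the head node.
odd = Node("softmax", [linear])
print(find_gather_anchor(odd) is odd)  # True
```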
In `@tensorrt_llm/_torch/auto_deploy/utils/node_utils.py`:
- Around line 1052-1065: get_op_schema currently accepts an ambiguous op and
picks an arbitrary overload from op._schemas, causing
extract_op_args/set_op_args to use wrong slots; update get_op_schema to accept
typed parameters (e.g., op: Union[OpOverload, OpOverloadPacket]) or restrict to
OpOverload, and when receiving an OpOverloadPacket either explicitly select the
.default overload or raise a clear RuntimeError if multiple overloads exist,
then return a torch.FunctionSchema; also add the missing return type annotation
to _get_op_schema (-> torch.FunctionSchema) and annotate its parameter as Node,
update get_op_schema signature with the proper typing, and add the required
NVIDIA copyright header to the top of the file.
In `@tests/integration/defs/accuracy/test_llm_api_autodeploy.py`:
- Around line 1001-1010: The test currently passes the raw HF model id
(self.MODEL_NAME) into AutoDeployLLM causing network/auth dependency; change the
argument to point at the local cached artifact by resolving the model id to the
local cache path before constructing AutoDeployLLM (use the repo's test artifact
resolver or helper used elsewhere in the file—e.g., the existing local model
cache helper or method on the test class) so AutoDeployLLM(model=...) receives
the local filesystem path rather than the HF id; keep other args (tokenizer,
world_size, yaml_extra) unchanged and ensure the resolver uses the same cache
mechanism as the other tests that rely on local artifacts.
---
Outside diff comments:
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`:
- Around line 451-479: The custom op registration for
"auto_deploy::torch_cached_attention_with_cache" incorrectly declares
mutates_args=(), but the implementation mutates k_cache, v_cache and sometimes
out; update the torch.library.custom_op decorator on
torch_backend_mha_with_cache to list the arg indices (or names if supported)
corresponding to k_cache, v_cache, and out so PyTorch knows these tensors are
mutated (e.g., replace mutates_args=() with the appropriate tuple including the
positions of k_cache, v_cache, and out), ensuring the schema matches the actual
side effects.
- Around line 571-597: The fake implementation torch_backend_mha_with_cache_fake
registered via torch_backend_mha_with_cache.register_fake has custom_attn_mask
incorrectly placed before scale, causing a mismatch with the real operator
schema; update the function signature of torch_backend_mha_with_cache_fake so
that custom_attn_mask is moved to after the read_cache_only parameter
(preserving its Optional[torch.Tensor] default) to exactly match the real
`@torch.library.custom_op` schema and avoid positional argument misalignment when
the fake is invoked.
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py`:
- Around line 1269-1296: The custom op triton_paged_mha_with_cache currently
declares only "kv_cache" in mutates_args but also writes to the optional out
buffer; update the decorator on triton_paged_mha_with_cache to include "out" in
the mutates_args tuple so that both kv_cache and out are declared as mutated
(i.e., change the mutates_args to include "out" alongside "kv_cache" in the
`@torch.library.custom_op` decorator for triton_paged_mha_with_cache).
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py`:
- Around line 166-190: The helper _resolve_torch_fn currently exposes GELU
unconditionally; change it to enforce "GELU is allowed only for gated MLPs" by
adding an is_gated_mlp: bool parameter (or otherwise checking gating) and
assert/raise if act_fn == ActivationType.Gelu and not is_gated_mlp; then update
every torch_*_moe entry point that calls _resolve_torch_fn (the torch reference
MoE functions referenced in the review) to pass the is_gated_mlp flag from their
own parameters so the same validation occurs in the torch path as in the fused
TRTLLM path.
---
Minor comments:
In `@tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py`:
- Around line 559-564: The kernel-dump block using _kernel_dump_dir/_dump_path
and writing full_src should be made best-effort: wrap os.makedirs and the file
open/write in a try/except that catches OSError and does not re-raise; on error
emit a warning (e.g., via logging.warning or warnings.warn) that includes the
directory/path and exception details so kernel generation continues even if
AD_DUMP_KERNELS_DIR is invalid or unwritable.
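A best-effort sketch of the dump path (function name is hypothetical; only the try/except-OSError shape matters):

```python
import logging
import os
import tempfile


def dump_kernel_source(dump_dir, name, full_src):
    # Best-effort: an invalid or unwritable dump dir (e.g. a bad
    # AD_DUMP_KERNELS_DIR) must not abort kernel generation.
    try:
        os.makedirs(dump_dir, exist_ok=True)
        path = os.path.join(dump_dir, f"{name}.py")
        with open(path, "w") as f:
            f.write(full_src)
        return path
    except OSError as e:
        logging.warning("Skipping kernel dump of %s to %s: %s", name, dump_dir, e)
        return None


with tempfile.TemporaryDirectory() as d:
    ok = dump_kernel_source(d, "k0", "# kernel source")
    # Nesting a directory under the file we just wrote raises OSError:
    # the helper warns and returns None instead of propagating.
    bad = dump_kernel_source(os.path.join(ok, "sub"), "k1", "# kernel source")

print(ok is not None, bad)  # True None
```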
In `@tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py`:
- Around line 2-3: Add the standard NVIDIA SPDX copyright/license header at the
top of the module (above the existing imports) in
tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py; ensure the header
includes the NVIDIA copyright line, the SPDX license identifier and the year of
latest meaningful modification, and keep the existing imports for
Gemma3nForCausalLM, Gemma3nForConditionalGeneration, Gemma4ForCausalLM, and
Gemma4ForConditionalGeneration unchanged below the header.
In `@tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py`:
- Around line 162-165: Replace the dict.get fallback with an explicit lookup
since cached_attn_op is always set by _insert_cached_attn_node(); specifically,
check module._node_ref.meta for the "cached_attn_op" key and assign
cached_attn_op from that mapping (instead of calling
attn_descriptor.get_cached_attention_op() as the default), so the eager call to
attn_descriptor.get_cached_attention_op() is avoided and the code clearly
reflects the invariant established by _insert_cached_attn_node().
In `@tests/integration/defs/accuracy/test_llm_api_autodeploy.py`:
- Around line 1011-1017: Remove the dead/commented evaluation code lines that
reference MMLU and GSM8K (the commented block calling
MMLU(self.MODEL_NAME).evaluate(...) and GSM8K(self.MODEL_NAME).evaluate(...)) to
eliminate the lint error E265; if those evaluations must be preserved for later,
move them out of this test into a separate helper or test file and reference
MODEL_NAME and EXTRA_EVALUATOR_KWARGS there instead of leaving commented calls
in tests/integration/defs/accuracy/test_llm_api_autodeploy.py.
In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py`:
- Around line 357-362: The BatchInfo metadata doesn't match the single 3-token
prefill used by q/k/v/etc., so change the BatchInfo.update call to describe the
same batch (one sequence of length 3 and appropriate zeroed/padded slots) so
out-based assertions validate bookkeeping; specifically update the array passed
to batch_info_host.update to reflect a single 3-token sequence that matches
seq_len, input_pos, slot_idx and cu_seqlen, then re-run the assertions on out to
verify the kernel's bookkeeping.
In `@tests/unittest/auto_deploy/_utils_test/torch_attention_reference.py`:
- Around line 70-71: This helper module torch_attention_reference.py is missing
the required NVIDIA SPDX copyright header; add the standard NVIDIA
copyright/SPDX header block at the top of the file (with the year of latest
meaningful modification) so the file contains the same header format used across
the repo; ensure the header is a comment block placed before any imports or code
(e.g., before definitions that reference symbols like scale) and matches the
project's canonical NVIDIA header text and SPDX identifier.
In
`@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py`:
- Around line 13-14: Add the standard NVIDIA SPDX copyright header block at the
top of this test file; open
tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py
and insert the NVIDIA SPDX header (including copyright line with the year of
latest meaningful modification and SPDX-License-Identifier) before any imports
or code—ensure the header is applied to this file and similarly formatted as
other repository files so the test function
test_torch_backend_attention_custom_bool_mask_context() remains unchanged.
In
`@tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py`:
- Line 29: Update the file header copyright year to include 2026 (e.g., change
the existing "2025" to "2025-2026" or add 2026) so the modified file's header is
current; locate the top-of-file copyright comment in
tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
and edit the header line accordingly.
In
`@tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py`:
- Around line 1-13: Update the copyright header year from 2025 to 2026 in the
new test file test_inject_custom_attention_mask.py so the NVIDIA copyright line
reflects the latest modification year; locate the top-of-file header block (the
Apache License comment) and change the year in the line that currently reads
"Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved." to 2026.
---
Nitpick comments:
In `@tensorrt_llm/_torch/auto_deploy/llm.py`:
- Around line 155-158: factory.model is typed Optional[str], so wrapping it
directly with Path(self.factory.model) can raise an unclear TypeError; before
assigning to self._hf_model_dir, validate that self.factory.model is not None
and raise an explicit error (or use an assertion) with a clear message that the
factory did not resolve a model path, then set self._hf_model_dir =
Path(self.factory.model). This ensures downstream consumers expecting
config.json at self._hf_model_dir fail fast with a readable error instead of a
TypeError from Path().
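A minimal sketch of the fail-fast validation (the function is a hypothetical extraction of the assignment site, not the actual llm.py code):

```python
from pathlib import Path
from typing import Optional


def resolve_hf_model_dir(model: Optional[str]) -> Path:
    # Fail fast with a readable message instead of letting Path(None)
    # raise an opaque TypeError far from the root cause.
    if model is None:
        raise ValueError(
            "factory did not resolve a model path; config.json cannot be located"
        )
    return Path(model)


print(resolve_hf_model_dir("/models/gemma-4"))  # /models/gemma-4
```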
In `@tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py`:
- Line 244: The new parameter cached_attn_op added to the mutator function needs
an explicit type annotation and the function's return type must be annotated as
None; update the function signature that declares cached_attn_op to annotate
cached_attn_op with an appropriate type (e.g., typing.Any or a more specific
Callable/Module type used in this module) and set the function return type to ->
None, importing the required typing symbol if necessary.
In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py`:
- Around line 1099-1105: The helper _transfer_vision_encoder_weights currently
uses zip(ad_encoder.layers, ref_encoder.layers) which silently truncates if
layer counts differ; change it to use zip(ad_encoder.layers, ref_encoder.layers,
strict=True) so mismatched layer lengths between Gemma4VisionEncoder and
_RefVisionEncoder raise immediately (update the call in
_transfer_vision_encoder_weights to use strict=True).
In
`@tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py`:
- Around line 134-138: Add a new assertion and test call that exercises the
non-default read_cache_only=True path so the shared-KV early-return is
validated: invoke the same test helper (the call surrounding sliding_window and
read_cache_only parameters in test_flashinfer_attention_op.py) with
read_cache_only=True (and a non-default sliding_window if applicable), capture
kv_cache before the call, and assert after the call that kv_cache is unchanged
(identity or equality). Also ensure the test verifies that
append_paged_kv_cache() is not invoked for the shared-KV case (e.g., by using an
existing spy/mocker or by asserting no changes to paged cache structures). Use
the symbols read_cache_only, sliding_window, kv_cache, and append_paged_kv_cache
to locate and implement the assertions.
In
`@tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py`:
- Around line 370-440: Add a packed-format variant of
test_transform_with_softcapping that mirrors the logic in
test_transform_packed_format: instantiate SoftcapLMHeadModel and export to gm
with hidden_states shaped for packed input (e.g., [batch, max_seq_len,
hidden_size]) and appropriate logit_gather_ids/seq_len to force
gather_required=True, then apply the same gather_logits_before_lm_head transform
via InferenceOptimizer and perform the same assertions (gather exists, gather
index < lm_head linear index, and forward output shape). Ensure you set
BatchInfo.update_tokens_gather_info(batch_size, True) and pass
token_gather_indices and batch_info_host when invoking gm_transformed so the
packed path (memory-reduction path) is exercised.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: f43e7035-fa98-4d5c-81f1-e6ee4c2464c1
📒 Files selected for processing (44)
- docs/source/models/supported-models.md
- examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb
- examples/auto_deploy/model_registry/configs/gemma3n_e2b_it.yaml
- examples/auto_deploy/model_registry/configs/gemma4_moe.yaml
- examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml
- examples/auto_deploy/model_registry/models.yaml
- tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
- tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py
- tensorrt_llm/_torch/auto_deploy/config/default.yaml
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_paged_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/torch_moe.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/trtllm_moe.py
- tensorrt_llm/_torch/auto_deploy/export/export.py
- tensorrt_llm/_torch/auto_deploy/llm.py
- tensorrt_llm/_torch/auto_deploy/mlir/codegen/triton_emitter.py
- tensorrt_llm/_torch/auto_deploy/models/custom/__init__.py
- tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py
- tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma4.py
- tensorrt_llm/_torch/auto_deploy/transform/__init__.py
- tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py
- tensorrt_llm/_torch/auto_deploy/transform/attention_mask_providers.py
- tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py
- tensorrt_llm/_torch/auto_deploy/transform/library/inject_custom_attention_mask.py
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
- tensorrt_llm/_torch/auto_deploy/transform/library/kvcache_transformers.py
- tensorrt_llm/_torch/auto_deploy/utils/_graph.py
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
- tests/integration/defs/accuracy/test_llm_api_autodeploy.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma3n_modeling.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_gemma4_modeling.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_shared_kv_attention.py
- tests/unittest/auto_deploy/_utils_test/torch_attention_reference.py
- tests/unittest/auto_deploy/singlegpu/compile/test_captured_graph.py
- tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_flashinfer_attention_op.py
- tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_torch_attention_op.py
- tests/unittest/auto_deploy/singlegpu/custom_ops/attention/test_triton_paged_attention.py
- tests/unittest/auto_deploy/singlegpu/transformations/library/test_gather_logits_before_lm_head.py
- tests/unittest/auto_deploy/singlegpu/transformations/library/test_inject_custom_attention_mask.py
| "%pip install torch openai" | ||
| ] |
Don't reinstall torch inside the TRT-LLM container.
The release image already ships a Torch build that matches its CUDA/TensorRT stack. %pip install torch here can replace it with an incompatible wheel and break the rest of the notebook; only openai should be installed on top, or this should be guarded by an import check.
Suggested change
-%pip install torch openai
+%pip install openai
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| "%pip install torch openai" | |
| ] | |
| "%pip install openai" | |
| ] |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb` around lines 56
- 57, The notebook cell that runs "%pip install torch openai" should not
reinstall torch inside the TRT-LLM container because the release image already
provides a CUDA/TensorRT-compatible Torch; remove "torch" from the install
command or replace the cell with a guarded install that only pip-installs openai
(or first tries import torch and only installs if missing/mismatched). Update
the cell content that currently contains "%pip install torch openai" accordingly
so only "openai" is installed (or add an import-check guard) to avoid replacing
the bundled Torch wheel.
| "BASE_URL = \"http://0.0.0.0:8000/v1\"\n", | ||
| "API_KEY = \"null\"\n", | ||
| "MODEL_ID = \"google/gemma-4-26B-A4B-it\"\n", | ||
| "\n", | ||
| "client = OpenAI(base_url=BASE_URL, api_key=API_KEY)" |
Use a routable client URL here.
0.0.0.0 is a bind address, not a client destination. The server can listen on 0.0.0.0, but the OpenAI client should connect to 127.0.0.1 or localhost.
Suggested change
-BASE_URL = "http://0.0.0.0:8000/v1"
+BASE_URL = "http://127.0.0.1:8000/v1"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/auto_deploy/cookbooks/gemma_4_trtllm_cookbook.ipynb` around lines
160 - 164, The BASE_URL used to construct the OpenAI client is currently set to
the bind address "http://0.0.0.0:8000/v1"; update the BASE_URL constant to a
routable client address (e.g., "http://127.0.0.1:8000/v1" or
"http://localhost:8000/v1") so the OpenAI(...) client connects correctly; change
the string assigned to BASE_URL (referenced where client =
OpenAI(base_url=BASE_URL, api_key=API_KEY)) and keep the rest of the
instantiation unchanged.
model_factory: Gemma4ForConditionalGeneration
tokenizer: google/gemma-4-26B-A4B
attn_backend: triton_paged
compile_backend: torch-cudagraph
cuda_graph_config:
  batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
max_num_tokens: 8192
max_batch_size: 512
max_seq_len: 8192
enable_chunked_prefill: true
kv_cache_config:
  enable_block_reuse: false
  free_gpu_memory_fraction: 0.8
transforms:
  compile_model:
    piecewise_enabled: true
Enable gather_logits_before_lm_head in the base Gemma4 config too.
This config uses the same Gemma4ForConditionalGeneration export path and piecewise_enabled: true setup as examples/auto_deploy/model_registry/configs/gemma4_moe.yaml, but it never opts into the transform that the new Gemma4 softcapping test is guarding. Without it, the base model can still materialize [num_tokens, vocab_size] before gather and hit the same piecewise CUDA-graph memory regression.
🔧 Suggested config change
 transforms:
   compile_model:
     piecewise_enabled: true
+  gather_logits_before_lm_head:
+    enabled: true
+ enabled: true📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| model_factory: Gemma4ForConditionalGeneration | |
| tokenizer: google/gemma-4-26B-A4B | |
| attn_backend: triton_paged | |
| compile_backend: torch-cudagraph | |
| cuda_graph_config: | |
| batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] | |
| max_num_tokens: 8192 | |
| max_batch_size: 512 | |
| max_seq_len: 8192 | |
| enable_chunked_prefill: true | |
| kv_cache_config: | |
| enable_block_reuse: false | |
| free_gpu_memory_fraction: 0.8 | |
| transforms: | |
| compile_model: | |
| piecewise_enabled: true | |
| model_factory: Gemma4ForConditionalGeneration | |
| tokenizer: google/gemma-4-26B-A4B | |
| attn_backend: triton_paged | |
| compile_backend: torch-cudagraph | |
| cuda_graph_config: | |
| batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] | |
| max_num_tokens: 8192 | |
| max_batch_size: 512 | |
| max_seq_len: 8192 | |
| enable_chunked_prefill: true | |
| kv_cache_config: | |
| enable_block_reuse: false | |
| free_gpu_memory_fraction: 0.8 | |
| transforms: | |
| compile_model: | |
| piecewise_enabled: true | |
| gather_logits_before_lm_head: | |
| enabled: true |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/auto_deploy/model_registry/configs/gemma4_moe_base.yaml` around
lines 7 - 22, This Gemma4 base config is missing the
gather_logits_before_lm_head transform and can materialize [num_tokens,
vocab_size] before the LM head; update the YAML for the
Gemma4ForConditionalGeneration export to enable the same transform as in
gemma4_moe.yaml by adding gather_logits_before_lm_head under transforms
(alongside compile_model.piecewise_enabled) so the model uses the
gather-before-lm-head transform during export and avoids the piecewise
CUDA-graph memory regression.
_GlobalFlashInferPlanner.reset(q.device)
Avoid resetting the planner inside every attention invocation.
prepare_flashinfer_metadata() already resets the global planner once per forward. Doing it again here clears plan_params_prefill and plan_params_decode before every layer, so same-shape layers lose the reuse that plan_prefill() / plan_decode() are built around and end up re-planning on each call.
💡 Minimal fix
- _GlobalFlashInferPlanner.reset(q.device)
+ if _GlobalFlashInferPlanner.workspace_buffer is None:
+     _GlobalFlashInferPlanner.reset(q.device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py`
around lines 363 - 364, The call to _GlobalFlashInferPlanner.reset(q.device)
inside the attention invocation should be removed so it doesn't clear
plan_params_prefill and plan_params_decode per layer;
prepare_flashinfer_metadata() already resets the planner once per forward. Edit
the attention implementation to stop calling
_GlobalFlashInferPlanner.reset(q.device) (or add a guard so reset only runs when
metadata hasn't been prepared), ensuring plan_prefill() and plan_decode() can
reuse existing plan parameters across same-shape layers.
 attn_mask, dropout_p, is_causal = extract_op_args(
     source_attn_node, "attn_mask", "dropout_p", "is_causal"
 )
-if attn_mask is not None or dropout_p != 0.0 or not is_causal:
+if dropout_p != 0.0:
     ad_logger.debug(
-        "Unsupported attention arguments for "
-        f"{source_attn_node=}: {attn_mask=}, {dropout_p=}, {is_causal=}"
+        f"Unsupported attention arguments for {source_attn_node=}: {dropout_p=}"
     )
Keep rejecting non-causal source attention here.
Both cached torch paths always apply causal masking. After this change is_causal=False is no longer screened out, so a non-causal torch_attention node can now be rewritten to a stricter causal op.
Also applies to: 695-700
🧰 Tools
🪛 Ruff (0.15.9)
[warning] 671-671: Unpacked variable attn_mask is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
[warning] 671-671: Unpacked variable is_causal is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`
around lines 671 - 677, The code currently only rejects attention when dropout_p
!= 0.0 but no longer rejects non-causal attention; update the checks that pull
attn_mask, dropout_p, is_causal via extract_op_args for source_attn_node to also
reject (log and skip) when is_causal is False so non-causal torch_attention
nodes are not rewritten to causal ops—apply the same fix to the equivalent block
that handles the other cached torch path (the block that also extracts
attn_mask, dropout_p, is_causal around the other source node). Ensure
ad_logger.debug includes source_attn_node and is_causal in the message and that
the function/logic returns/continues when is_causal is False.
def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
    return self.embed_tokens_per_layer(input_ids).reshape(
        *input_ids.shape,
        self.config.num_hidden_layers,
        self.hidden_size_per_layer_input,
    )
Mask per-layer input IDs in the text-only path.
embed_tokens_per_layer is sized by vocab_size_per_layer_input, but this branch feeds it raw input_ids from the full text vocab. Any config with vocab_size_per_layer_input < vocab_size will crash in Gemma3nForCausalLM for otherwise valid token IDs. The conditional wrapper already masks these IDs before calling the helper.
Suggested fix
if input_ids is not None:
inputs_embeds = self.embed_tokens(input_ids)
- per_layer_inputs = self.get_per_layer_inputs(input_ids)
+ per_layer_inputs_mask = torch.logical_and(
+ input_ids >= 0, input_ids < self.config.vocab_size_per_layer_input
+ )
+ per_layer_inputs_tokens = torch.where(
+ per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)
+ )
+ per_layer_inputs = self.get_per_layer_inputs(per_layer_inputs_tokens)
Also applies to: 608-610
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/auto_deploy/models/custom/modeling_gemma3n.py` around
lines 562 - 567, The get_per_layer_inputs implementation feeds raw input_ids
into embed_tokens_per_layer which is sized by config.vocab_size_per_layer_input
and will crash when vocab_size_per_layer_input < full vocab; before calling
embed_tokens_per_layer in get_per_layer_inputs, mask/map input_ids into the
per-layer vocab (e.g., replace input_ids with input_ids %
config.vocab_size_per_layer_input or otherwise clamp/map them to [0,
vocab_size_per_layer_input) ) so indices are valid, and apply the identical fix
to the other occurrence that also calls embed_tokens_per_layer around lines
608-610.
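The masking the review asks for can be illustrated in plain Python (the real model applies the same logic with `torch.where` on tensors; the function name here is illustrative only):

```python
# Plain-Python illustration of the requested masking: token ids outside
# [0, vocab_size_per_layer_input) are replaced with 0 before they index the
# smaller per-layer embedding table, so full-text-vocab ids can no longer
# index past the end of embed_tokens_per_layer.
def mask_per_layer_token_ids(input_ids, vocab_size_per_layer_input):
    return [
        tok if 0 <= tok < vocab_size_per_layer_input else 0
        for tok in input_ids
    ]
```

With a per-layer vocab of 4, the ids `[1, 7, 3, -2]` become `[1, 0, 3, 0]`.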
get_model_config = getattr(factory, "_get_model_config", None)
if callable(get_model_config):
    try:
        model_config, _unused_kwargs = get_model_config()
    except Exception:
        return None
Don't hide _get_model_config() failures behind a missing-provider fallback.
If _get_model_config() starts throwing because of a real bug, this returns None, InjectCustomAttentionMask treats that as "no provider", and the model runs without the backend mask. Catch only the expected lookup failures here and let unexpected exceptions surface.
As per coding guidelines, "When using try-except blocks, limit the except clause to the smallest set of errors possible."
🧰 Tools
🪛 Ruff (0.15.9)
[warning] 52-52: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/auto_deploy/transform/attention_mask_provider.py` around
lines 48 - 53, The current broad except swallows all errors from
factory._get_model_config(); change it to only catch the expected
lookup/missing-provider errors (e.g., LookupError/KeyError) and let any other
exceptions propagate so real bugs surface. Concretely, in the block that calls
get_model_config() (the local get_model_config = getattr(...) and subsequent try
around get_model_config()), replace except Exception with a narrow except
(LookupError, KeyError): return None (or similar expected-failure handling) and
re-raise any other exceptions (i.e., do not catch them).
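A narrowed version of that handler could look like the sketch below. The `factory` shape and the choice of `LookupError` are assumptions about what a missing provider actually raises; note `KeyError` is a subclass of `LookupError`, so catching the latter covers both.

```python
# Sketch of the narrowed error handling: only expected "provider not found"
# failures are swallowed, while real bugs propagate to the caller.
def try_get_model_config(factory):
    get_model_config = getattr(factory, "_get_model_config", None)
    if not callable(get_model_config):
        return None
    try:
        model_config, _unused_kwargs = get_model_config()
    except LookupError:
        # Covers KeyError too. Unexpected exceptions (TypeError,
        # AttributeError, ...) are deliberately not caught here.
        return None
    return model_config
```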
# Walk backward through elementwise/unary ops (e.g. softcapping: div, tanh, mul)
# to find the actual lm_head linear node.
current = lm_head_node
while current is not None and not is_linear_op(current):
    inputs = current.all_input_nodes
    current = inputs[0] if len(inputs) >= 1 else None

if current is not None and is_linear_op(current):
    node_to_gather = current.all_input_nodes[0]
    self._log_info(
        f"Found LM head linear through post-processing chain: {current.name}"
    )
else:
    node_to_gather = lm_head_node
    self._log_info("lm_head node is not linear, using it as the node to gather")
Don't walk an arbitrary first-input chain to find the LM head.
This loop follows all_input_nodes[0] through any op until it hits a linear. That assumes input 0 is always the logits data path and that every intervening op is safe to move gather_tokens across, which is not true for generic multi-input or sequence-mixing post-processing. Please restrict the walk to a small whitelist of known-safe post-LM-head ops and pick the input that actually comes from the linear path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py`
around lines 65 - 79, The current backward walk from lm_head_node blindly
follows current.all_input_nodes[0] until is_linear_op(current), which can
traverse unsafe multi-input or sequence-mixing ops; instead, limit walking to a
small whitelist of known-safe post-LM-head ops (e.g., elementwise/unary ops like
Div, Mul, Tanh, Clamp, Add) and only descend through an input that is actually
the linear path: for each current node check current.op_type (or kind) against
the whitelist and then pick the input node that either (a) is_linear_op(input)
or (b) has matching tensor shape/dtype or a single producer that looks like the
logits path; stop and fallback to lm_head_node if none match or max depth
exceeded. Update the loop that uses current, is_linear_op, and node_to_gather to
perform these guarded checks and log which branch was chosen.
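The guarded walk can be sketched on toy node objects. The real transform operates on `torch.fx` nodes, and the whitelist below is an assumption about which post-LM-head ops are safe to move `gather_tokens` across:

```python
# Minimal sketch of a whitelisted backward walk: only descend through known
# elementwise/unary ops, bounded by max_depth; anything else (multi-input or
# sequence-mixing ops) aborts so the caller can fall back to lm_head_node.
SAFE_UNARY_OPS = {"div", "mul", "tanh", "clamp"}

class ToyNode:
    def __init__(self, name, op_kind, inputs=()):
        self.name = name
        self.op_kind = op_kind  # e.g. "linear", "tanh", "div"
        self.all_input_nodes = list(inputs)

def find_lm_head_linear(lm_head_node, max_depth=8):
    current = lm_head_node
    for _ in range(max_depth):
        if current.op_kind == "linear":
            return current
        if current.op_kind not in SAFE_UNARY_OPS or not current.all_input_nodes:
            return None  # unsafe op on the path: fall back to lm_head_node
        current = current.all_input_nodes[0]
    return None
```

A softcapping chain `linear -> div -> tanh -> mul` resolves to the linear node, while a path through a `matmul` (sequence-mixing) aborts.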
def get_op_schema(op) -> torch.FunctionSchema:
    """Return the schema for an op or op overload packet."""
    if hasattr(op, "_schemas"):
        return next(iter(op._schemas.values()))
    elif hasattr(op, "_schema"):
        return op._schema
    raise RuntimeError(f"No schema found on op {op}")


def _get_op_schema(node: Node):
    """Return the op schema for a call_function node."""
    if node.op != "call_function":
        raise ValueError(f"_get_op_schema only supports call_function nodes, got {node.op}")
    return get_op_schema(node.target)
🧩 Analysis chain
🏁 Script executed:
# Check file size and read the relevant section
wc -l tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
Repository: NVIDIA/TensorRT-LLM
Length of output: 119
🏁 Script executed:
# Read lines around 1052-1065 with context
sed -n '1040,1080p' tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
Length of output: 1362
🏁 Script executed:
# Check for NVIDIA copyright header in the file
head -20 tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
Length of output: 615
🏁 Script executed:
# Search for all usages of get_op_schema and _get_op_schema in the codebase
rg "get_op_schema|_get_op_schema" -A 3 -B 1
Length of output: 6684
🏁 Script executed:
# Check for NVIDIA copyright/license header anywhere in the file
head -100 tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | grep -i "nvidia\|copyright\|license" || echo "No header found in first 100 lines"
Length of output: 97
🏁 Script executed:
# Check the git history to see if this file was recently modified
git log --oneline -n 5 tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
Length of output: 99
🏁 Script executed:
# Verify if OpOverloadPacket can have multiple schemas
web_search "PyTorch OpOverloadPacket _schemas multiple overloads"
Length of output: 113
🌐 Web query:
PyTorch OpOverloadPacket _schemas multiple overloads torch._ops
💡 Result:
In PyTorch, torch.ops.<ns>.<op> (e.g. torch.ops.aten.add) is an OpOverloadPacket: a packet that can contain multiple overloads of the same operator (e.g. add.Tensor, add.Scalar, etc.). (android.googlesource.com)
What OpOverloadPacket._schemas is
OpOverloadPacket._schemas is a property that builds a dict of all overload schemas for that op by iterating self._overload_names and calling torch._C._get_schema(qualified_name, overload_name) for each overload. (android.googlesource.com)
So you'll get something like:
- keys: overload names (the C++ "overload_name"; "" means the default overload)
- values: FunctionSchema objects (what you see as ._schema on an OpOverload)
How to list overloads and get each schema
import torch
p = torch.ops.aten.add # OpOverloadPacket
p.overloads() # e.g. ["Tensor", "Scalar", "default", ...]
schemas = p._schemas # dict: overload_name -> schema ("" key can appear)
overloads() returns names but normalizes "" to "default". (android.googlesource.com)
How to select a specific overload (recommended)
Use attribute access on the packet to get an OpOverload, then read ._schema:
torch.ops.aten.add.Tensor._schema
torch.ops.aten.add.Scalar._schema
torch.ops.aten.add.default._schema # "default" maps to overload_name ""
Internally, "default" is translated to "" when resolving the overload. (android.googlesource.com)
Citations:
- 1: https://android.googlesource.com/platform/external/pytorch/%2B/HEAD/torch/_ops.py
- 2: https://android.googlesource.com/platform/external/pytorch/%2B/HEAD/torch/_ops.py
- 3: https://android.googlesource.com/platform/external/pytorch/%2B/HEAD/torch/_ops.py
- 4: https://android.googlesource.com/platform/external/pytorch/%2B/HEAD/torch/_ops.py
🏁 Script executed:
# Check git log more detail to see when this file was created/modified
git log --format="%H %ai %s" tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | head -20
Length of output: 157
🏁 Script executed:
# Check git diff to see if lines 1052-1065 were recently changed
git log -p -S "get_op_schema" -- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py | head -150
Length of output: 45
Add type hints and resolve OpOverloadPacket ambiguity in schema selection.
get_op_schema(op) silently picks an arbitrary overload from op._schemas when passed an OpOverloadPacket with multiple overloads. This causes extract_op_args() and set_op_args() to read or write incorrect argument slots if called with an unresolved packet. Per PyTorch conventions, explicitly select a specific overload (e.g., .default) or fail fast on ambiguous packets.
Add missing type hints:
- Annotate the op parameter as OpOverload | OpOverloadPacket (or accept only OpOverload if packets should not be passed)
- Add the return type torch.FunctionSchema to _get_op_schema()
Also add NVIDIA copyright header per coding guidelines for modified Python files.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tensorrt_llm/_torch/auto_deploy/utils/node_utils.py` around lines 1052 -
1065, get_op_schema currently accepts an ambiguous op and picks an arbitrary
overload from op._schemas, causing extract_op_args/set_op_args to use wrong
slots; update get_op_schema to accept typed parameters (e.g., op:
Union[OpOverload, OpOverloadPacket]) or restrict to OpOverload, and when
receiving an OpOverloadPacket either explicitly select the .default overload or
raise a clear RuntimeError if multiple overloads exist, then return a
torch.FunctionSchema; also add the missing return type annotation to
_get_op_schema (-> torch.FunctionSchema) and annotate its parameter as Node,
update get_op_schema signature with the proper typing, and add the required
NVIDIA copyright header to the top of the file.
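The unambiguous selection the review asks for can be sketched against mock objects. A real implementation would type against `torch._ops.OpOverload`/`OpOverloadPacket`; here two tiny stand-in classes model the `_schemas` dict and `_schema` attribute described above.

```python
# Sketch of deterministic schema resolution: prefer the default overload
# (overload_name ""), accept a packet with exactly one overload, and fail
# fast on genuinely ambiguous packets instead of picking arbitrarily.
def resolve_schema(op):
    schemas = getattr(op, "_schemas", None)
    if schemas is not None:  # OpOverloadPacket-like
        if "" in schemas:    # "" is the default overload's name
            return schemas[""]
        if len(schemas) == 1:
            return next(iter(schemas.values()))
        raise RuntimeError(f"Ambiguous overload packet {op}: {sorted(schemas)}")
    schema = getattr(op, "_schema", None)
    if schema is not None:   # already-resolved OpOverload
        return schema
    raise RuntimeError(f"No schema found on op {op}")
```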
yaml_paths, registry_world_size = _get_registry_yaml_extra(
    self.MODEL_NAME)
if get_device_count() < registry_world_size:
    pytest.skip("Not enough devices for world size, skipping test")

self.get_default_sampling_params()
with AutoDeployLLM(model=self.MODEL_NAME,
                   tokenizer=self.MODEL_NAME,
                   world_size=registry_world_size,
                   yaml_extra=yaml_paths) as llm:
Use the local model cache for this accuracy test.
This is the only accuracy case here that passes the raw HF id into AutoDeployLLM. That makes the run depend on network/auth during CI instead of the local artifact cache used by the rest of the file.
Suggested change
yaml_paths, registry_world_size = _get_registry_yaml_extra(
self.MODEL_NAME)
if get_device_count() < registry_world_size:
pytest.skip("Not enough devices for world size, skipping test")
+ model_path = hf_id_to_local_model_dir(self.MODEL_NAME)
self.get_default_sampling_params()
- with AutoDeployLLM(model=self.MODEL_NAME,
- tokenizer=self.MODEL_NAME,
+ with AutoDeployLLM(model=model_path,
+ tokenizer=model_path,
world_size=registry_world_size,
                   yaml_extra=yaml_paths) as llm:
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/integration/defs/accuracy/test_llm_api_autodeploy.py` around lines 1001
- 1010, The test currently passes the raw HF model id (self.MODEL_NAME) into
AutoDeployLLM causing network/auth dependency; change the argument to point at
the local cached artifact by resolving the model id to the local cache path
before constructing AutoDeployLLM (use the repo's test artifact resolver or
helper used elsewhere in the file—e.g., the existing local model cache helper or
method on the test class) so AutoDeployLLM(model=...) receives the local
filesystem path rather than the HF id; keep other args (tokenizer, world_size,
yaml_extra) unchanged and ensure the resolver uses the same cache mechanism as
the other tests that rely on local artifacts.
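The kind of resolver the suggested change relies on can be sketched as below. Both the function name (taken from the suggested diff's `hf_id_to_local_model_dir`) and the `LLM_MODELS_ROOT`-based mapping are assumptions; the real helper in the test file may differ.

```python
import os

# Hypothetical sketch: map an HF model id like "google/gemma-3n-E4B-it" to a
# path under a local artifact cache so CI never hits the network. The env
# variable name and directory layout are assumptions, not the real helper.
def hf_id_to_local_model_dir(model_name, models_root=None):
    root = models_root or os.environ.get("LLM_MODELS_ROOT", "/scratch/llm-models")
    return os.path.join(root, model_name.split("/")[-1])
```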
why do we need a sync here?
could we rebase and trigger CI? @bmarimuthu-nv

needed a lot of structural changes and Gemma4 base had also merged. Created a cleaner PR: #12861.
Moved to #12861
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.