[None][feat]: Add test_moe_semantics.py to help agent understand the semantics of the MoE module #12797
WeiHaocheng wants to merge 1 commit into NVIDIA:main
Conversation
📝 Walkthrough

Added a comprehensive semantic reference test suite for MoE (Mixture of Experts) functionality. The test suite validates weight shapes, forward-pass semantics, routing methods, activation functions, and configuration modes against pure PyTorch reference implementations across multiple test classes.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~35 minutes

🚥 Pre-merge checks: ❌ 3 failed (2 warnings, 1 inconclusive)
Actionable comments posted: 2
🧹 Nitpick comments (1)
tests/unittest/_torch/modules/moe/test_moe_semantics.py (1)
824-846: Exercise `weight_loading_mode` through an actual MoE instance.

This section proves the algebra for `torch.cat([W3, W1])`, but it never constructs a `VanillaMoE` or `create_moe` instance with `weight_loading_mode=FUSED_GATE_UP_PROJ`. If the real loader flips the fused layout or splits it incorrectly, this test still passes. Please compare two actual modules initialized from the same weights instead of only comparing hand-written matmuls.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/modules/moe/test_moe_semantics.py` around lines 824 - 846, The test only verifies the algebra of concatenating W3 and W1 but doesn't exercise the actual loader or MoE modules; update test_fused_gate_up_equivalence to instantiate two MoE modules (e.g., VanillaMoE or via create_moe) with identical initial weights and different weight_loading_mode settings (normal vs FUSED_GATE_UP_PROJ), load the same W1/W3 into each (or set state_dicts accordingly), run the same input x through both modules, and assert their gate and up outputs match (use torch.allclose with the existing tolerances) so the real loader/layout logic is validated; reference VanillaMoE/create_moe, weight_loading_mode and FUSED_GATE_UP_PROJ to find where to change.
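As a toy illustration of the algebra this nitpick refers to (hypothetical 2×2 shapes, plain Python rather than the real `VanillaMoE` loader), stacking `W3` on top of `W1` and doing one matmul is equivalent to the two separate projections:

```python
# Toy sketch of the fused gate/up layout: multiplying by the row-wise
# concatenation [W3; W1] equals performing the two matmuls separately,
# so splitting the fused output must recover the two halves.
# Shapes and values are hypothetical; this is not TensorRT-LLM code.
def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]

W3 = [[1.0, 2.0], [3.0, 4.0]]  # "up" projection, shape (I, H)
W1 = [[5.0, 6.0], [7.0, 8.0]]  # "gate" projection, shape (I, H)
x = [1.0, 0.5]

fused = matvec(W3 + W1, x)     # rows concatenated, like torch.cat([W3, W1])
assert fused[:2] == matvec(W3, x)
assert fused[2:] == matvec(W1, x)
```

The review's point is that this algebra alone cannot catch a loader that splits the fused tensor at the wrong offset; only running two real modules side by side would.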
📒 Files selected for processing (1)

- tests/unittest/_torch/modules/moe/test_moe_semantics.py
```python
# Indices may be in different order (sorted=False), so compare sets
for t in range(num_tokens):
    actual_set = set(indices[t].cpu().tolist())
    ref_set = set(ref_topk_idx[t].cpu().tolist())
    assert actual_set == ref_set, \
        f"Token {t}: experts mismatch {actual_set} vs {ref_set}"
# Scales should sum to 1.0
assert torch.allclose(scales.sum(dim=-1),
                      torch.ones(num_tokens,
                                 device="cuda",
                                 dtype=torch.float32),
                      atol=1e-5)
```
Compare the per-expert weights, not just the selected expert set.
The loop at Line 561 only checks which experts were chosen, and Lines 567-571 only check that the returned weights sum to 1. A regression that keeps the same expert ids but attaches the wrong weight to each one will still pass. Canonicalize both rows by expert id and compare scales against ref_top_k_weights element-wise.
Suggested assertion update

```diff
-# Indices may be in different order (sorted=False), so compare sets
-for t in range(num_tokens):
-    actual_set = set(indices[t].cpu().tolist())
-    ref_set = set(ref_topk_idx[t].cpu().tolist())
-    assert actual_set == ref_set, \
-        f"Token {t}: experts mismatch {actual_set} vs {ref_set}"
+# Indices may be in different order (sorted=False), so compare
+# expert/weight pairs after canonicalizing by expert id.
+for t in range(num_tokens):
+    actual_order = torch.argsort(indices[t].to(torch.int64))
+    ref_order = torch.argsort(ref_topk_idx[t])
+    assert torch.equal(
+        indices[t][actual_order],
+        ref_topk_idx[t][ref_order].to(torch.int32),
+    )
+    assert torch.allclose(
+        scales[t][actual_order],
+        ref_top_k_weights[t][ref_order],
+        atol=1e-6,
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/_torch/modules/moe/test_moe_semantics.py` around lines 560 -
571, The test currently only compares selected expert ids and that scales sum to
1; instead, for each token t canonicalize by expert id and compare per-expert
weights: for each t use indices[t] and scales[t] to build a dict/array mapping
expert_id -> weight and do the same for ref_topk_idx[t] and
ref_top_k_weights[t], then assert element-wise equality (e.g., torch.allclose)
between the aligned weight vectors; refer to symbols indices, scales,
ref_topk_idx (or ref_top_k_weights) and num_tokens to locate where to implement
this per-token alignment and comparison.
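The per-token alignment the prompt describes can be modeled outside torch with a stdlib-only sketch (hypothetical expert ids and weights): sort each row's (expert_id, weight) pairs by expert id, then compare element-wise.

```python
# Sketch of canonicalizing top-k routing results by expert id before
# comparison. Data is made up; in the real test the ids would come from
# indices[t]/ref_topk_idx[t] and the weights from scales[t]/ref_top_k_weights[t].
def canonicalize(expert_ids, weights):
    # pair each expert id with its weight and sort by id
    return sorted(zip(expert_ids, weights))

actual = canonicalize([5, 2, 9], [0.2, 0.5, 0.3])
ref = canonicalize([9, 5, 2], [0.3, 0.2, 0.5])

# same experts AND the same weight attached to each expert
assert actual == ref

# a regression that swaps two weights now fails, unlike the set comparison
bad = canonicalize([5, 2, 9], [0.5, 0.2, 0.3])
assert bad != ref
```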
```python
def test_swiglu_kernel(self):
    """§15a: silu_and_mul — silu(gate) * up."""
    torch.manual_seed(42)
    intermediate_size = 64
    x = torch.randn(8, intermediate_size * 2, device="cuda")

    result = torch.ops.trtllm.silu_and_mul(x)

    # Reference: silu(first_half) * second_half
    gate = x[..., :intermediate_size]
    up = x[..., intermediate_size:]
    ref = F.silu(gate) * up

    assert result.shape == (8, intermediate_size)
    assert torch.allclose(result, ref, atol=1e-5)
```
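For readers without a GPU, the same reference semantics can be checked with a stdlib-only sketch (hypothetical values; the real test exercises `torch.ops.trtllm.silu_and_mul` on CUDA tensors):

```python
import math

def silu(v):
    # silu(v) = v * sigmoid(v)
    return v / (1.0 + math.exp(-v))

def silu_and_mul(row):
    # split the row into [gate | up] halves and return silu(gate) * up,
    # mirroring the reference computation in the test above
    n = len(row) // 2
    gate, up = row[:n], row[n:]
    return [silu(g) * u for g, u in zip(gate, up)]

out = silu_and_mul([1.0, -1.0, 2.0, 0.5])
assert len(out) == 2  # output width is half the input width
```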
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "Call site and local imports in the new test:"
rg -n -C2 'torch\.ops\.trtllm\.silu_and_mul|import tensorrt_llm|from tensorrt_llm' \
  tests/unittest/_torch/modules/moe/test_moe_semantics.py
echo
echo "Possible registration paths in pytest setup or package init:"
fd -i 'conftest.py' tests tensorrt_llm | xargs -r rg -n -C2 'torch_custom_ops|custom_ops|silu_and_mul'
fd -i '__init__.py' tensorrt_llm | xargs -r rg -n -C2 'torch_custom_ops|custom_ops|silu_and_mul'
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 4096
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "=== Search for silu_and_mul registration ==="
rg -n 'silu_and_mul' tensorrt_llm --type py
echo
echo "=== Check test file module-level imports (first 50 lines) ==="
head -50 tests/unittest/_torch/modules/moe/test_moe_semantics.py
echo
echo "=== Look for torch_custom_ops loading in package init ==="
rg -r 'torch_custom_ops|trtllm.*register' tensorrt_llm/__init__.py tensorrt_llm/_torch/__init__.py 2>/dev/null || echo "(files may not exist)"
echo
echo "=== Check if torch_custom_ops.py exists and what it contains ==="
find tensorrt_llm -name '*custom_ops*' -o -name '*torch_custom_ops*' | head -20
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 6441
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "=== Check torch_custom_ops.py registration definition ==="
sed -n '1670,1710p' tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
echo
echo "=== Check custom_ops/__init__.py for registration imports ==="
cat tensorrt_llm/_torch/custom_ops/__init__.py
echo
echo "=== Check if test conftest loads custom_ops ==="
find tests -name 'conftest.py' -exec grep -l 'custom_ops\|torch_custom_ops\|trtllm' {} \;
echo
echo "=== Check test directory structure for conftest at test level ==="
ls -la tests/unittest/_torch/modules/moe/ | grep conftest
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 3064
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "=== Check tests/unittest/conftest.py for custom_ops imports ==="
cat tests/unittest/conftest.py
echo
echo "=== Check if torch_custom_ops.py is imported anywhere at module level in test file ==="
head -100 tests/unittest/_torch/modules/moe/test_moe_semantics.py | grep -n 'import\|from'
echo
echo "=== Check how the test class is structured - look for setUp or fixtures ==="
sed -n '700,750p' tests/unittest/_torch/modules/moe/test_moe_semantics.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 14680
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "=== Check tensorrt_llm/__init__.py ==="
if [ -f tensorrt_llm/__init__.py ]; then
  rg -n 'torch_custom_ops|custom_ops' tensorrt_llm/__init__.py || echo "No matches"
fi
echo
echo "=== Check tensorrt_llm/_torch/__init__.py ==="
if [ -f tensorrt_llm/_torch/__init__.py ]; then
  cat tensorrt_llm/_torch/__init__.py
else
  echo "File does not exist"
fi
echo
echo "=== Search for imports of torch_custom_ops in _torch package ==="
rg -n 'from.*torch_custom_ops|import.*torch_custom_ops' tensorrt_llm/_torch/
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 484
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "=== Check if fused_moe imports torch_custom_ops ==="
rg -n 'torch_custom_ops|silu_and_mul' tensorrt_llm/_torch/modules/fused_moe/__init__.py
echo
echo "=== Check modules/swiglu.py (which uses torch.ops.trtllm.silu_and_mul) ==="
head -30 tensorrt_llm/_torch/modules/swiglu.py
echo
echo "=== Check if swiglu imports torch_custom_ops ==="
rg -n 'import' tensorrt_llm/_torch/modules/swiglu.py | head -20
echo
echo "=== Trace dependency: Does any module imported by test transitively load torch_custom_ops? ==="
rg -n 'from.*torch_custom_ops|import.*torch_custom_ops' tensorrt_llm/_torch/modules/fused_moe/ tensorrt_llm/_torch/utils.py tensorrt_llm/_torch/model_config.py 2>/dev/null || echo "(checking...)"
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 114
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "=== Check swiglu.py full content (it uses torch.ops.trtllm.silu_and_mul) ==="
cat tensorrt_llm/_torch/modules/swiglu.py
echo
echo "=== Check where swiglu.py is imported ==="
rg -l 'from.*swiglu|import.*swiglu' tensorrt_llm/ tests/
echo
echo "=== Try to find ANY import of torch_custom_ops in modules/ ==="
find tensorrt_llm/_torch/modules -name '*.py' -exec grep -l 'torch_custom_ops' {} \;
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 2892
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "=== Check tensorrt_llm/_torch/utils.py for torch_custom_ops imports ==="
rg -n 'torch_custom_ops|custom_ops' tensorrt_llm/_torch/utils.py || echo "No matches"
echo
echo "=== Check tensorrt_llm/_torch/modules/fused_moe/__init__.py ==="
cat tensorrt_llm/_torch/modules/fused_moe/__init__.py 2>/dev/null || echo "File not found"
echo
echo "=== Check all imports in fused_moe/ that might load torch_custom_ops ==="
rg -n 'torch_custom_ops|from.*custom_ops' tensorrt_llm/_torch/modules/fused_moe/ 2>/dev/null | head -20
echo
echo "=== Specifically check fused_moe/create_moe.py (main entry point) ==="
rg -n 'torch_custom_ops|import.*custom_ops' tensorrt_llm/_torch/modules/fused_moe/create_moe.py 2>/dev/null || echo "No matches"
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 2550
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "=== Full module-level imports of test file (lines 1-100) ==="
head -100 tests/unittest/_torch/modules/moe/test_moe_semantics.py | grep -n '^import\|^from'
echo
echo "=== Check if relu2 or any imported function transitively imports torch_custom_ops ==="
rg -n 'relu2' tensorrt_llm/_torch/utils.py -A 5 -B 5
echo
echo "=== Check if there's a conftest.py at the moe test level ==="
ls -la tests/unittest/_torch/modules/moe/conftest.py 2>/dev/null || echo "No conftest.py in moe test directory"
echo
echo "=== Final check: Does torch_custom_ops get imported somewhere that always loads before any test? ==="
rg -r 'import.*torch_custom_ops' tensorrt_llm --type py | grep -v test | head -5
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 1288
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "=== Search for ANY import of torch_custom_ops (exact, not regex) ==="
rg 'from.*torch_custom_ops import|import.*torch_custom_ops' tensorrt_llm tensorrt_llm/_torch --type py
echo
echo "=== Check if pytest conftest or plugin setup loads it ==="
find tests -name 'conftest.py' -o -name 'pytest.ini' -o -name 'setup.cfg' | xargs grep -l 'torch_custom_ops' 2>/dev/null || echo "Not found in pytest config"
echo
echo "=== Check the actual test run - is there any fixture in the test class? ==="
sed -n '680,730p' tests/unittest/_torch/modules/moe/test_moe_semantics.py
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 2574
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "=== Confirm: torch_custom_ops is NOT imported anywhere that loads before test ==="
echo "Checking if test file has anything that imports from modules that import torch_custom_ops:"
grep -n 'from tensorrt_llm' tests/unittest/_torch/modules/moe/test_moe_semantics.py | head -20
echo
echo "=== Does any of those modules (on first import line 121+) import torch_custom_ops? ==="
python3 << 'EOF'
import sys
sys.path.insert(0, '/repo')
# Try importing what the test imports
try:
    from tensorrt_llm._torch.model_config import ModelConfig
    print("ModelConfig imported successfully")
except Exception as e:
    print(f"Error: {e}")
# Check if torch.ops.trtllm.silu_and_mul is registered
import torch
try:
    # This should fail if torch_custom_ops was never imported
    torch.ops.trtllm.silu_and_mul
    print("silu_and_mul IS registered!")
except AttributeError as e:
    print(f"silu_and_mul NOT registered: {e}")
EOF
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 1905
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "=== Check if fused_moe modules import torch_custom_ops directly or transitively ==="
for file in tensorrt_llm/_torch/modules/fused_moe/*.py; do
  if grep -q 'torch_custom_ops\|custom_ops' "$file" 2>/dev/null; then
    echo "Found in: $file"
  fi
done
echo
echo "=== Check fused_moe_triton.py (imported in test) ==="
head -40 tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py | grep -n 'import\|torch_custom_ops'
echo
echo "=== Check if swiglu.py is imported ANYWHERE in normal module execution ==="
rg -n 'import.*swiglu' tensorrt_llm/_torch/ | grep -v test | grep -v fused_moe
echo
echo "=== Definitive check: Is torch.ops.trtllm registration guaranteed before test_swiglu_kernel? ==="
echo "Searching for where silu_and_mul op gets registered..."
rg -n '@torch.library.custom_op.*silu_and_mul' tensorrt_llm/
```

Repository: NVIDIA/TensorRT-LLM
Length of output: 2399
Import `torch_custom_ops` (or a public registration path) to make custom-op registration explicit in this test.

Line 711 calls `torch.ops.trtllm.silu_and_mul`, but `torch_custom_ops.py` is never explicitly imported in this test. The op's `@torch.library.custom_op` decorator runs at module load time, so registration depends on whether `torch_custom_ops` is loaded before the test runs. This creates an implicit dependency on import order that breaks isolated test runs like `-k test_swiglu_kernel`. Add an explicit import at the test level or module level to guarantee registration.
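The import-order pitfall can be modeled with a toy registry (plain Python, not the real `torch.library` machinery): when registration is a decorator side effect at module load time, the op simply does not exist until the registering module has been imported.

```python
# Toy model of load-time custom-op registration. In TensorRT-LLM the
# decorator is @torch.library.custom_op inside torch_custom_ops.py;
# importing that module is what makes torch.ops.trtllm.silu_and_mul visible.
registry = {}

def custom_op(name):
    def decorate(fn):
        registry[name] = fn  # side effect that runs at import time
        return fn
    return decorate

# before this decorated definition executes, lookups would fail
@custom_op("trtllm::silu_and_mul")
def _silu_and_mul(x):
    return x

assert "trtllm::silu_and_mul" in registry
```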
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/unittest/_torch/modules/moe/test_moe_semantics.py` around lines 705 -
719, This test calls the custom operator via torch.ops.trtllm.silu_and_mul but
never imports the module that registers it; explicitly import the module that
contains the `@torch.library.custom_op` registration (e.g., import
torch_custom_ops or the library's public registration path) at the top of the
test file or inside test_swiglu_kernel so the custom op is registered before
calling torch.ops.trtllm.silu_and_mul.
Summary by CodeRabbit
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:

- PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
- PR follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
- Test cases are provided for new code paths (see test instructions).
- Any new dependencies have been scanned for license and vulnerabilities.
- CODEOWNERS updated if ownership changes.
- Documentation updated as needed.
- Update tava architecture diagram if there is a significant design change in PR.
- The reviewers assigned automatically/manually are appropriate for the PR.
- Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help

To see a list of available CI bot commands, please comment `/bot help`.