[DSV3] Fix the ckpt loading issue when no MoE layer on the mtp rank #3315
Conversation
📝 Walkthrough
The pull request modifies layer specification derivation logic in `src/megatron/bridge/models/gpt_provider.py`.

Changes
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 4 passed
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/megatron/bridge/models/gpt_provider.py`:
- Around line 365-370: The GPTModelProvider dataclass is missing the qk_l2_norm
attribute but the code in get_gpt_decoder_layer_specs accesses
config.qk_l2_norm; add a new field to the GPTModelProvider dataclass named
qk_l2_norm: bool = False (matching the style/pattern of qk_layernorm and fp8) so
instances have this attribute and runtime AttributeError is avoided; update any
dataclass default initialization or constructor usage as needed to include the
new field.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: bb23f6f8-8f8b-442c-b0b5-75eda2b121db
📒 Files selected for processing (1)
src/megatron/bridge/models/gpt_provider.py
```python
decoder_layer_specs = get_gpt_decoder_layer_specs(
    config,
    use_transformer_engine=True,
    normalization=config.normalization,
    qk_l2_norm=config.qk_l2_norm,
)
```
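For context, the call site above reads `config.qk_l2_norm` directly, which raises `AttributeError` when the field is missing. A hypothetical defensive alternative (not the fix CodeRabbit proposes, which is to add the field to the dataclass) would fall back to a default; the `SimpleNamespace` config below is an illustrative stand-in, not the real provider object:

```python
from types import SimpleNamespace

# Illustrative stand-in for a config object that defines qk_layernorm
# but not qk_l2_norm, mirroring the situation described in the review.
config = SimpleNamespace(normalization="LayerNorm", qk_layernorm=True)

# Direct access would raise AttributeError:
#   config.qk_l2_norm
# A defensive read falls back to a default instead:
qk_l2_norm = getattr(config, "qk_l2_norm", False)
print(qk_l2_norm)  # False
```

The trade-off is that `getattr` fallbacks scatter the default across call sites, whereas declaring the field on the dataclass keeps a single source of truth.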
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
echo "1) Check qk-related fields/usage in gpt_provider.py"
rg -n -C2 '\bqk_(layernorm|l2_norm)\b' src/megatron/bridge/models/gpt_provider.py
echo
echo "2) Find TransformerConfig definitions and qk fields"
fd -i 'transformer_config.py' | while read -r f; do
  echo "=== $f ==="
  rg -n -C2 '\bqk_(layernorm|l2_norm)\b' "$f" || true
done
echo
echo "3) Inspect get_gpt_decoder_layer_specs definition/signature"
rg -n -C4 'def\s+get_gpt_decoder_layer_specs\s*\('
echo
echo "4) Compare argument naming at all call sites"
rg -n -C2 'get_gpt_decoder_layer_specs\s*\('
```

Repository: NVIDIA-NeMo/Megatron-Bridge
Length of output: 1626
🏁 Script executed:

```shell
# Search for get_gpt_decoder_layer_specs definition more broadly
rg -n 'def\s+get_gpt_decoder_layer_specs' --type py
# Search for qk_l2_norm references anywhere in the codebase
rg -n 'qk_l2_norm' --type py
# Check the TransformerConfig class in gpt_provider.py more completely
sed -n '165,190p' src/megatron/bridge/models/gpt_provider.py
```

Length of output: 1749
🏁 Script executed:

```shell
# Search for get_gpt_decoder_layer_specs definition across the repo
find . -type f -name "*.py" -exec grep -l "def get_gpt_decoder_layer_specs" {} \;
# Also check imports in gpt_provider.py to find where this function comes from
grep -n "get_gpt_decoder_layer_specs" src/megatron/bridge/models/gpt_provider.py | head -20
# Check the imports at the top of gpt_provider.py
head -50 src/megatron/bridge/models/gpt_provider.py
```

Length of output: 2201
🏁 Script executed:

```shell
# Find all usages of get_gpt_decoder_layer_specs in the repo
rg -n -B2 -A5 'get_gpt_decoder_layer_specs\s*\(' src/megatron/bridge --type py
# Check if there's a parent/base TransformerConfig that defines qk_l2_norm
rg -n 'qk_l2_norm' src/megatron/bridge/models/transformer_config.py
# Check GPTModelProvider config class more comprehensively
sed -n '150,210p' src/megatron/bridge/models/gpt_provider.py
# Look at the llama4_utils.py usage to understand the pattern
rg -n -B5 -A5 'qk_l2_norm' src/megatron/bridge/models/llama/llama4_utils.py
```

Length of output: 4172
🏁 Script executed:

```shell
# Check the base class of GPTModelProvider
grep -n "class GPTModelProvider" src/megatron/bridge/models/gpt_provider.py
# Check if there's a parent config that might define qk_l2_norm
sed -n '140,180p' src/megatron/bridge/models/gpt_provider.py
# Check the full dataclass definition
rg -n "@dataclass|class GPTModelProvider" src/megatron/bridge/models/gpt_provider.py -A1
# Look for any initialization or property definition of qk_l2_norm
rg -n "qk_l2_norm" src/megatron/bridge/models/gpt_provider.py
# Check if ModelProviderMixin or parent classes define this field
grep -rn "class ModelProviderMixin" src/megatron/bridge/models/
```

Length of output: 2339
🏁 Script executed:

```shell
# Check the base TransformerConfig class from megatron.bridge.models.transformer_config
cat src/megatron/bridge/models/transformer_config.py
# Also check if there's a relationship between qk_l2_norm and qk_layernorm
rg -n "qk_" src/megatron/bridge/models/transformer_config.py
```

Length of output: 11837
🏁 Script executed:

```shell
# Check if qk_l2_norm is defined in megatron.core by looking at the function signature
# that calls get_gpt_decoder_layer_specs in other files to understand the expected parameter
cat src/megatron/bridge/training/mlm_compat/model.py | head -120
# Also check if there's any fallback or default logic for qk_l2_norm
rg -n "getattr.*qk_l2_norm|hasattr.*qk_l2_norm" --type py
```

Length of output: 4966
🏁 Script executed:

```shell
# Look at the actual function signature or docstring for get_gpt_decoder_layer_specs
# by checking what parameters it accepts in different contexts
rg -n "def get_gpt_decoder_layer_specs|def get_gpt_decoder_block_spec" --type py -A 10
# Check what fields args has in mlm_compat
rg -n "add_argument.*qk_l2_norm|add_argument.*qk_layernorm" --type py -A 2 -B 2
```

Length of output: 53
🏁 Script executed:

```shell
# Check the arguments defined in mlm_compat to see what fields it expects
rg -n "qk_l2_norm|qk_layernorm" src/megatron/bridge/training/mlm_compat/arguments.py -B 2 -A 2
# Check if there's any documentation or patterns for how config fields are managed
rg -n "class.*TransformerConfig|@dataclass" src/megatron/bridge/models/gpt_provider.py | head -5
# Look for any property or __getattr__ definitions that might create qk_l2_norm dynamically
rg -n "property|__getattr__|def __init__" src/megatron/bridge/models/gpt_provider.py
```

Length of output: 260
🏁 Script executed:

```shell
# Check if qk_l2_norm might be inherited from megatron.core.transformer.transformer_config.TransformerConfig
# by searching for it in megatron imports or documentation
rg -n "qk_l2_norm" src/megatron/bridge/models/ -l
# Also check the exact inheritance chain and what TransformerConfig from megatron.core provides
# by looking at how it's used in mlm_compat where both qk_layernorm and qk_l2_norm are passed
sed -n '40,80p' src/megatron/bridge/training/mlm_compat/model.py
# Check if the args object in mlm_compat defines qk_l2_norm
sed -n '1,200p' src/megatron/bridge/training/mlm_compat/arguments.py | grep -A 5 -B 5 "qk_l2_norm"
```

Length of output: 1602
🏁 Script executed:

```shell
# Check the full arguments.py file to see if qk_l2_norm is defined as an argument
cat src/megatron/bridge/training/mlm_compat/arguments.py | grep -n "qk_" -A 2 -B 2
# Also check what _transformer_config_from_args does to understand the config construction
rg -n "_transformer_config_from_args" src/megatron/bridge/training/mlm_compat/arguments.py -A 30
```

Length of output: 1811
🏁 Script executed:

```shell
# Check if qk_l2_norm is defined in megatron.core.transformer.transformer_config
# by searching for references in the entire repo
rg -n "class TransformerConfig" src/megatron/bridge/models/transformer_config.py -A 50 | head -100
# Look at line 75-77 of arguments.py more carefully - it copies fields that exist in args
# Let's see what args defines
rg -n "add_argument.*qk" src/megatron/bridge/training/mlm_compat/arguments.py
# Check if qk_l2_norm is added anywhere in the argument parser
# (note: rg's -r flag means --replace, so plain -n is used here)
rg -n "add_argument.*qk_l2_norm" src/megatron/bridge/training/
```

Length of output: 2411
🏁 Script executed:

```shell
# Let's check what MCoreTransformerConfig defines for qk fields
# by examining what fields are expected in the inherited class
rg -n "qk_l2_norm|qk_layernorm" tests/unit_tests/training/mlm_compat/test_model.py -B 3 -A 3
# Also look at what the actual error would be if config.qk_l2_norm doesn't exist
# Check if there's a fallback or default mechanism
sed -n '360,375p' src/megatron/bridge/models/gpt_provider.py
```

Length of output: 1836
🏁 Script executed:

```shell
# Check the exact line 369 context and what config is being used there
sed -n '355,375p' src/megatron/bridge/models/gpt_provider.py
# Check if GPTModelProvider or its base TransformerConfig class actually has qk_l2_norm field
# by examining what fields it exposes
rg -n "qk_l2_norm.*:" src/megatron/bridge/models/gpt_provider.py
# Check if there's a default or if it's only conditionally added
rg -n "qk_l2_norm\s*=" src/megatron/bridge/models/gpt_provider.py
```

Length of output: 1166
🏁 Script executed:

```shell
# Check if qk_l2_norm is used with getattr elsewhere for safe access
rg -n "getattr.*qk_l2_norm|getattr.*qk_layernorm" --type py
# Look at how other config accesses are done in the same MoE fallback section
sed -n '350,380p' src/megatron/bridge/models/gpt_provider.py
# Verify what fields are in GPTModelProvider vs what might be inherited
rg -n "^\s+[a-z_]+:\s+(Optional\[)?[a-zA-Z]" src/megatron/bridge/models/gpt_provider.py | grep -E "qk|norm|fp8"
```

Length of output: 1603
🏁 Script executed:

```shell
# Check if qk_l2_norm should be added as a field to GPTModelProvider
# by looking at all fields defined in the class
sed -n '124,230p' src/megatron/bridge/models/gpt_provider.py | grep -n ":" | head -30
# Also check if there's documentation or comments about this field being intentionally missing
rg -n "qk_l2_norm|qk_layernorm" src/megatron/bridge/models/gpt_provider.py
# Look at llama4_utils to see if it defines qk_l2_norm or just uses it
sed -n '140,160p' src/megatron/bridge/models/llama/llama4_utils.py
```

Length of output: 2896
Add missing `qk_l2_norm` field to `GPTModelProvider` config.
The code at line 369 accesses `config.qk_l2_norm`, but this field is not defined in the `GPTModelProvider` dataclass; only `qk_layernorm` is defined (line 180). Direct attribute access will raise an `AttributeError` at runtime if the config instance lacks this field. Add `qk_l2_norm: bool = False` to the dataclass to match the pattern of related fields such as `qk_layernorm` and `fp8`.
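A minimal sketch of the suggested fix. The class below is an illustrative stand-in, not the real `GPTModelProvider` (which defines many more fields); it shows only how the new field would sit alongside the existing `qk_layernorm`:

```python
from dataclasses import dataclass


# Illustrative stand-in for GPTModelProvider in gpt_provider.py;
# only the qk-related fields are shown.
@dataclass
class GPTModelProvider:
    qk_layernorm: bool = False  # existing field
    qk_l2_norm: bool = False    # suggested new field, disabled by default

config = GPTModelProvider()
# With the field declared, the call-site access no longer raises AttributeError:
print(config.qk_l2_norm)  # False
```

Defaulting to `False` preserves current behavior for every existing provider instance; only configs that explicitly opt in enable the L2 norm path.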
What does this PR do?
Add a one line overview of what this PR aims to accomplish.
Changelog
GitHub Actions CI
See the CI section in the Contributing doc for how to trigger the CI. An NVIDIA developer will need to approve and trigger the CI for external contributors.
Before your PR is "Ready for review"
Pre checks:
If you haven't finished some of the above items you can still open "Draft" PR.
Additional Information