diff --git a/src/megatron/bridge/models/gpt_provider.py b/src/megatron/bridge/models/gpt_provider.py
index b302c0dec6..36067e1066 100644
--- a/src/megatron/bridge/models/gpt_provider.py
+++ b/src/megatron/bridge/models/gpt_provider.py
@@ -358,7 +358,17 @@ def mtp_block_spec(config: "GPTModelProvider", vp_stage: Optional[int] = None) -
         if hasattr(spec, "layer_specs") and len(spec.layer_specs) == 0:
             # Get the decoder layer spec explicitly if no decoder layer in the last stage,
             # Only happens with block spec (TransformerBlockSubmodules) when using MoE.
-            spec = default_layer_spec(config)
+            # Re-derive all decoder layer specs and use the last one to get the correct
+            # layer type (dense vs MoE) for the MTP transformer layer.
+            from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_layer_specs
+
+            decoder_layer_specs = get_gpt_decoder_layer_specs(
+                config,
+                use_transformer_engine=True,
+                normalization=config.normalization,
+                qk_l2_norm=config.qk_l2_norm,
+            )
+            spec = decoder_layer_specs[-1]
         return get_gpt_mtp_block_spec(config, spec, use_transformer_engine=True, vp_stage=vp_stage)
     else:
         return None
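
For context, a minimal sketch (not part of the change) of the selection logic the hunk relies on. It assumes, as the hunk does, that `get_gpt_decoder_layer_specs` returns one layer spec per decoder layer in model order; the helper name `resolve_mtp_layer_spec` is hypothetical.

```python
# Hypothetical standalone helper mirroring the hunk: when the last pipeline
# stage has no decoder layers, re-derive the per-layer specs and reuse the
# final one so the MTP transformer layer matches the model's last decoder
# layer type (dense vs MoE) rather than a generic default spec.
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_layer_specs


def resolve_mtp_layer_spec(config):
    """Return the layer spec to reuse for the MTP transformer layer."""
    decoder_layer_specs = get_gpt_decoder_layer_specs(
        config,
        use_transformer_engine=True,
        normalization=config.normalization,
        qk_l2_norm=config.qk_l2_norm,
    )
    # The last entry reflects the type of the model's final decoder layer.
    return decoder_layer_specs[-1]
```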