Expose hardcoded Megatron infrastructure parameters to user config#2230
nic-nvidia wants to merge 2 commits into NVIDIA-NeMo:main
Conversation
Read checkpoint, timeout, and diagnostic settings from `megatron_cfg` with backward-compatible defaults instead of hardcoding them.

New `megatron_cfg` fields (all optional, existing defaults preserved):
- `async_save`, `fully_parallel_save`, `fully_parallel_load`, `load_rng`
- `distributed_timeout_minutes`
- `logging_level`

New `distributed_data_parallel_config` field:
- `check_for_nan_in_grad`

Closes NVIDIA-NeMo#2229
Hi @nic-nvidia, thanks for the enhancement! LGTM except for the default config placement. Could you help to update it?
Also @yaoyu-33 @cuichenx, could you help check whether the params in this PR are well supported in MBridge?
```diff
  # Step 1: Setup distributed
- setup_distributed()
+ setup_distributed(
+     timeout_minutes=config.get("megatron_cfg", {}).get("distributed_timeout_minutes"),
```
We encourage setting default values in config.yaml instead of in code, so that people can see what features we have and their default behavior without looking into the code.
Can you help to:
- Update to the suggestion below, and also the other configs
- Add the param (set to the default value) to several base configs? (Other configs inherit from the base ones, so they don't need to change.)
  - examples/configs/distillation_math.yaml
  - examples/configs/dpo.yaml
  - examples/configs/grpo_math_1B.yaml
  - examples/configs/rm.yaml
  - examples/configs/sft.yaml
  - examples/nemo_gym/grpo_nanov3.yaml
  - examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml
  - research/template_project/configs/grpo_math_1B.yaml
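A base-config addition along the lines the reviewer requests might look like the YAML sketch below. Field names come from this PR; the nesting under `policy.megatron_cfg` and the values for `distributed_timeout_minutes` and `logging_level` are assumptions, not values confirmed in this thread:

```yaml
policy:
  megatron_cfg:
    # Minutes before distributed collectives time out (value is a placeholder)
    distributed_timeout_minutes: 30
    # Checkpoint behavior; defaults match the previously hardcoded values
    async_save: false
    fully_parallel_save: true
    fully_parallel_load: true
    load_rng: false
    # Python logging level (placeholder; 20 == logging.INFO)
    logging_level: 20
    distributed_data_parallel_config:
      # Extra NaN check on gradients (placeholder default)
      check_for_nan_in_grad: false
```

Putting the defaults here, rather than in `.get(..., default)` calls, makes them discoverable from the config file alone.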
```diff
- timeout_minutes=config.get("megatron_cfg", {}).get("distributed_timeout_minutes"),
+ timeout_minutes=config["megatron_cfg"]["distributed_timeout_minutes"],
```
Similarly, this is `cfg.dist.distributed_timeout_minutes` in Megatron Bridge. Just checking that this is fine.
```python
async_save=(config or {}).get("megatron_cfg", {}).get("async_save", False),
fully_parallel_save=(config or {}).get("megatron_cfg", {}).get("fully_parallel_save", True),
fully_parallel_load=(config or {}).get("megatron_cfg", {}).get("fully_parallel_load", True),
load_rng=(config or {}).get("megatron_cfg", {}).get("load_rng", False),
```
These args are accessed from Megatron as `cfg.checkpoint.async_save`, but I don't see the "checkpoint" part here. Could you double-check this part?
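A minimal sketch of the nested read the reviewer is describing. The `checkpoint` key placement is assumed from the comment about `cfg.checkpoint.async_save`; the helper name is hypothetical, and the actual Megatron Bridge layout should be verified against its docs:

```python
def get_checkpoint_cfg(config):
    """Read checkpoint settings with backward-compatible defaults.

    Nests lookups under a "checkpoint" section to mirror how Megatron
    Bridge accesses them (cfg.checkpoint.async_save). Defaults match
    the previously hardcoded values from this PR.
    """
    ckpt = (config or {}).get("megatron_cfg", {}).get("checkpoint", {})
    return {
        "async_save": ckpt.get("async_save", False),
        "fully_parallel_save": ckpt.get("fully_parallel_save", True),
        "fully_parallel_load": ckpt.get("fully_parallel_load", True),
        "load_rng": ckpt.get("load_rng", False),
    }

# No config at all -> the hardcoded defaults are preserved.
print(get_checkpoint_cfg(None))
# An explicit override takes effect.
print(get_checkpoint_cfg({"megatron_cfg": {"checkpoint": {"async_save": True}}}))
```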
Summary
- Read checkpoint, timeout, and diagnostic settings from `megatron_cfg` with backward-compatible defaults instead of hardcoding them in `setup.py`
- New `megatron_cfg` fields: `async_save`, `fully_parallel_save`, `fully_parallel_load`, `load_rng`, `distributed_timeout_minutes`, `logging_level`
- New `distributed_data_parallel_config` field: `check_for_nan_in_grad`
- Closes #2229
Test plan
- `test_basic_checkpoint_config` passes (backward compat, no config arg)
- `test_checkpoint_config_overrides` validates all 4 checkpoint fields
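The backward-compat behavior these tests cover can be sketched with the chained-`.get` pattern from this PR. The helper name is hypothetical; the point is only that a missing or `None` config yields the old hardcoded default, while an explicit config value wins:

```python
def read_megatron_field(config, field, default):
    # Chained .get: tolerate config being None, megatron_cfg being
    # absent, or the field being unset, falling back to `default`.
    return (config or {}).get("megatron_cfg", {}).get(field, default)

# Backward compat: no config arg -> hardcoded default preserved.
assert read_megatron_field(None, "async_save", False) is False

# Explicit override takes effect.
assert read_megatron_field({"megatron_cfg": {"async_save": True}}, "async_save", False) is True
```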