diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index 93778aafa3..9357734393 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -13,6 +13,7 @@ # limitations under the License. import hashlib +from datetime import timedelta import json import os import time @@ -190,7 +191,7 @@ def destroy_parallel_state(): pass -def setup_distributed() -> None: +def setup_distributed(timeout_minutes: Optional[int] = None) -> None: """Handle NCCL settings, dtype mapping, and basic config setup.""" # Disable dynamo autotune_local_cache to avoid crash when there's already a cache # with different order of node_bundles @@ -198,7 +199,10 @@ def setup_distributed() -> None: # Ensure clean slate before import destroy_parallel_state() # Need to initialize the process group before calling into Megatron-Bridge, otherwise Megatron-Bridge will try to set an incorrect device - torch.distributed.init_process_group("nccl") + kwargs = {} + if timeout_minutes is not None: + kwargs["timeout"] = timedelta(minutes=timeout_minutes) + torch.distributed.init_process_group("nccl", **kwargs) def validate_and_set_config( @@ -399,7 +403,7 @@ def setup_model_config( # Create checkpoint configs checkpoint_config = _create_checkpoint_config( - pretrained_path, weights_path, optimizer_path + pretrained_path, weights_path, optimizer_path, config ) # Validate training configuration @@ -590,7 +594,7 @@ def _validate_chunking_config(config: PolicyConfig) -> None: def _create_checkpoint_config( - pretrained_path: str, weights_path: Optional[str], optimizer_path: Optional[str] + pretrained_path: str, weights_path: Optional[str], optimizer_path: Optional[str], config=None ) -> CheckpointConfig: """Create checkpoint configurations.""" return CheckpointConfig( @@ -599,10 +603,10 @@ def _create_checkpoint_config( load=weights_path, load_optim=optimizer_path is not None, pretrained_checkpoint=pretrained_path, - async_save=False, - fully_parallel_save=True, - fully_parallel_load=True, - load_rng=False, + async_save=(config or {}).get("megatron_cfg", {}).get("async_save", False), + fully_parallel_save=(config or {}).get("megatron_cfg", {}).get("fully_parallel_save", True), + fully_parallel_load=(config or {}).get("megatron_cfg", {}).get("fully_parallel_load", True), + load_rng=(config or {}).get("megatron_cfg", {}).get("load_rng", False), ) @@ -675,7 +679,7 @@ def _create_megatron_config( return ConfigContainer( model=model_cfg, checkpoint=checkpoint_config, - logger=LoggerConfig(logging_level=0), + logger=LoggerConfig(logging_level=config["megatron_cfg"].get("logging_level", 0)), train=TrainingConfig( micro_batch_size=1, # ignored global_batch_size=config["train_global_batch_size"], # ignored @@ -683,7 +687,9 @@ def _create_megatron_config( ), optimizer=OptimizerConfig(**config["megatron_cfg"]["optimizer"]), ddp=DistributedDataParallelConfig( - check_for_nan_in_grad=True, + check_for_nan_in_grad=config["megatron_cfg"][ + "distributed_data_parallel_config" + ].get("check_for_nan_in_grad", True), grad_reduce_in_fp32=config["megatron_cfg"][ "distributed_data_parallel_config" ]["grad_reduce_in_fp32"], diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 7fdf397b4b..221308a830 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -134,7 +134,9 @@ def __init__( self.rank = get_rank_safe() # Step 1: Setup distributed - setup_distributed() + setup_distributed( + timeout_minutes=config.get("megatron_cfg", {}).get("distributed_timeout_minutes"), + ) # Step 2: Validate and setup model paths hf_model_name, pretrained_path, pt_checkpoint_exists = validate_model_paths( diff --git a/tests/unit/models/megatron/test_megatron_setup.py b/tests/unit/models/megatron/test_megatron_setup.py index 8466b0d90a..d160e8935f 100644 --- a/tests/unit/models/megatron/test_megatron_setup.py +++ b/tests/unit/models/megatron/test_megatron_setup.py @@ -559,6 +559,28 @@ def test_basic_checkpoint_config(self, tmp_path): assert checkpoint_config.fully_parallel_load is True assert checkpoint_config.load_rng is False + def test_checkpoint_config_overrides(self, tmp_path): + """Test that checkpoint config fields can be overridden via megatron_cfg.""" + from nemo_rl.models.megatron.setup import _create_checkpoint_config + + config = { + "megatron_cfg": { + "async_save": True, + "fully_parallel_save": False, + "fully_parallel_load": False, + "load_rng": True, + } + } + + checkpoint_config = _create_checkpoint_config( + str(tmp_path / "pretrained"), None, None, config + ) + + assert checkpoint_config.async_save is True + assert checkpoint_config.fully_parallel_save is False + assert checkpoint_config.fully_parallel_load is False + assert checkpoint_config.load_rng is True + @pytest.mark.mcore class TestValidateTrainingConfig: