Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions nemo_rl/models/megatron/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import hashlib
from datetime import timedelta
import json
import os
import time
Expand Down Expand Up @@ -190,15 +191,18 @@ def destroy_parallel_state():
pass


def setup_distributed() -> None:
def setup_distributed(timeout_minutes: Optional[int] = None) -> None:
"""Handle NCCL settings, dtype mapping, and basic config setup."""
# Disable dynamo autotune_local_cache to avoid crash when there's already a cache
# with different order of node_bundles
configure_dynamo_cache()
# Ensure clean slate before import
destroy_parallel_state()
# Need to initialize the process group before calling into Megatron-Bridge, otherwise Megatron-Bridge will try to set an incorrect device
torch.distributed.init_process_group("nccl")
kwargs = {}
if timeout_minutes is not None:
kwargs["timeout"] = timedelta(minutes=timeout_minutes)
torch.distributed.init_process_group("nccl", **kwargs)


def validate_and_set_config(
Expand Down Expand Up @@ -399,7 +403,7 @@ def setup_model_config(

# Create checkpoint configs
checkpoint_config = _create_checkpoint_config(
pretrained_path, weights_path, optimizer_path
pretrained_path, weights_path, optimizer_path, config
)

# Validate training configuration
Expand Down Expand Up @@ -590,7 +594,7 @@ def _validate_chunking_config(config: PolicyConfig) -> None:


def _create_checkpoint_config(
pretrained_path: str, weights_path: Optional[str], optimizer_path: Optional[str]
pretrained_path: str, weights_path: Optional[str], optimizer_path: Optional[str], config=None
) -> CheckpointConfig:
"""Create checkpoint configurations."""
return CheckpointConfig(
Expand All @@ -599,10 +603,10 @@ def _create_checkpoint_config(
load=weights_path,
load_optim=optimizer_path is not None,
pretrained_checkpoint=pretrained_path,
async_save=False,
fully_parallel_save=True,
fully_parallel_load=True,
load_rng=False,
async_save=(config or {}).get("megatron_cfg", {}).get("async_save", False),
fully_parallel_save=(config or {}).get("megatron_cfg", {}).get("fully_parallel_save", True),
fully_parallel_load=(config or {}).get("megatron_cfg", {}).get("fully_parallel_load", True),
load_rng=(config or {}).get("megatron_cfg", {}).get("load_rng", False),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these args are accessed from Megatron as cfg.checkpoint.async_save, but I don't see the "checkpoint" part here — could you double-check this part?

)


Expand Down Expand Up @@ -675,15 +679,17 @@ def _create_megatron_config(
return ConfigContainer(
model=model_cfg,
checkpoint=checkpoint_config,
logger=LoggerConfig(logging_level=0),
logger=LoggerConfig(logging_level=config["megatron_cfg"].get("logging_level", 0)),
train=TrainingConfig(
micro_batch_size=1, # ignored
global_batch_size=config["train_global_batch_size"], # ignored
train_iters=config["megatron_cfg"]["train_iters"],
),
optimizer=OptimizerConfig(**config["megatron_cfg"]["optimizer"]),
ddp=DistributedDataParallelConfig(
check_for_nan_in_grad=True,
check_for_nan_in_grad=config["megatron_cfg"][
"distributed_data_parallel_config"
].get("check_for_nan_in_grad", True),
grad_reduce_in_fp32=config["megatron_cfg"][
"distributed_data_parallel_config"
]["grad_reduce_in_fp32"],
Expand Down
4 changes: 3 additions & 1 deletion nemo_rl/models/policy/workers/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,9 @@ def __init__(
self.rank = get_rank_safe()

# Step 1: Setup distributed
setup_distributed()
setup_distributed(
timeout_minutes=config.get("megatron_cfg", {}).get("distributed_timeout_minutes"),
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We encourage to set default value in config.yaml instead of in code, so that people can know what feature we have and their default behavior w/o looking into the code.

Can you help to:

  1. Update to the below, also other configs
  2. Add the param (set to the default value) to several base configs? (other configs will inherit from the base one so don't need to change)
    1. examples/configs/distillation_math.yaml
    2. examples/configs/dpo.yaml
    3. examples/configs/grpo_math_1B.yaml
    4. examples/configs/rm.yaml
    5. examples/configs/sft.yaml
    6. examples/nemo_gym/grpo_nanov3.yaml
    7. examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml
    8. research/template_project/configs/grpo_math_1B.yaml
Suggested change
timeout_minutes=config.get("megatron_cfg", {}).get("distributed_timeout_minutes"),
timeout_minutes=config["megatron_cfg"]["distributed_timeout_minutes"],

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, this is cfg.dist.distributed_timeout_minutes in Megatron-Bridge. Just checking this is fine.

)

# Step 2: Validate and setup model paths
hf_model_name, pretrained_path, pt_checkpoint_exists = validate_model_paths(
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/models/megatron/test_megatron_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,28 @@ def test_basic_checkpoint_config(self, tmp_path):
assert checkpoint_config.fully_parallel_load is True
assert checkpoint_config.load_rng is False

def test_checkpoint_config_overrides(self, tmp_path):
    """Test that checkpoint config fields can be overridden via megatron_cfg."""
    from nemo_rl.models.megatron.setup import _create_checkpoint_config

    # Flip every field away from its documented default so a pass proves the
    # override path (rather than the default) produced each value.
    overrides = {
        "async_save": True,
        "fully_parallel_save": False,
        "fully_parallel_load": False,
        "load_rng": True,
    }

    checkpoint_config = _create_checkpoint_config(
        str(tmp_path / "pretrained"),
        None,
        None,
        {"megatron_cfg": overrides},
    )

    # Each overridden field must round-trip onto the resulting config object.
    for field_name, expected in overrides.items():
        assert getattr(checkpoint_config, field_name) is expected


@pytest.mark.mcore
class TestValidateTrainingConfig:
Expand Down
Loading