diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index 0bbe5bf49b..1ff0557e29 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -525,13 +525,17 @@ def train( config.grad_scale_func = optimizer.scale_loss config.timers = None if isinstance(model[0], DDP) and args.overlap_grad_reduce: - assert config.no_sync_func is None, ( - "When overlap_grad_reduce is True, config.no_sync_func must be None; " - "a custom no_sync_func is not supported when overlapping grad-reduce" - ) - config.no_sync_func = [model_chunk.no_sync for model_chunk in model] - if len(model) == 1: - config.no_sync_func = config.no_sync_func[0] + # Install model.no_sync as the DDP grad-sync gate. Required by + # --overlap-grad-reduce so the pipeline scheduler wraps intermediate + # micro-batches in no_sync_func(). Megatron's upstream asserts + # `config.no_sync_func is None` here (written for a one-shot + # pretraining loop), but slime calls train() once per rollout — on + # the 2nd entry the value we installed last time is still there. + # Skip the install if any value is already set (not verified to be ours). + if config.no_sync_func is None: + config.no_sync_func = [model_chunk.no_sync for model_chunk in model] + if len(model) == 1: + config.no_sync_func = config.no_sync_func[0] if args.align_grad_reduce: config.grad_sync_func = [model_chunk.start_grad_sync for model_chunk in model] if len(model) == 1: @@ -577,7 +581,10 @@ def train( # or random initialization don't propagate to all ranks in first all-gather (which is a # no-op if things work correctly). if should_disable_forward_pre_hook(args): - disable_forward_pre_hook(model, param_sync=False) + # On re-entry (2nd+ rollout), hooks were already disabled at the end of + # the previous call, so only disable if any chunk still has active hooks. 
+ if any(len(m.remove_forward_pre_hook_handles) > 0 for m in model if isinstance(m, DDP)): + disable_forward_pre_hook(model, param_sync=False) # Also remove param_sync_func temporarily so that sync calls made in # `forward_backward_func` are no-ops. param_sync_func = config.param_sync_func