From a9d8a9b78bafc064d828a86f147549a53ac33fd0 Mon Sep 17 00:00:00 2001
From: Jonathan Thomm
Date: Fri, 17 Apr 2026 17:24:37 +0000
Subject: [PATCH] Fix double prepare_grads / loss-scaler-double-update in train_one_step

When args.check_for_nan_in_loss_and_grad=False, train_one_step called
optimizer.prepare_grads() and then optimizer.step(). Megatron's
MixedPrecisionOptimizer.step() calls prepare_grads() internally, so
prepare_grads ran twice per step. With fp16 + a grad scaler, that:

- advanced grad_scaler.update() twice per step, breaking dynamic loss
  scaling cadence;
- on configurations where model_param.main_grad persists (typical DDP),
  re-copied scaled grads into main grads and unscaled them again.

Additionally, when optimizer.step() returned (False, None, None) on a
grad-scaler overflow, the subsequent assert update_successful fired even
though that's a legitimate skipped-step signal, not a programming error.

Call optimizer.step() exactly once and derive the overflow signal from
its return values, skipping the LR scheduler advance and emitting a
warning when an overflow occurs. Move the MTP CI gradient check before
optimizer.step() since step() modifies gradients.
--- slime/backends/megatron_utils/model.py | 32 +++++++++++++++----------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/slime/backends/megatron_utils/model.py b/slime/backends/megatron_utils/model.py index 0bbe5bf49b..6a4f207c41 100644 --- a/slime/backends/megatron_utils/model.py +++ b/slime/backends/megatron_utils/model.py @@ -431,33 +431,39 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p forward_only=False, ) + # CI check: verify only MTP parameters have non-zero gradients when truncation happens + # This check must happen before optimizer.step() as gradients may be modified during step + if args.ci_test and args.enable_mtp_training: + from slime.backends.megatron_utils.ci_utils import check_mtp_only_grad + + check_mtp_only_grad(model, step_id) + + # Update parameters. valid_step = True - grad_norm = float("nan") + update_successful, grad_norm, num_zeros_in_grad = optimizer.step() + if not getattr(args, "check_for_nan_in_loss_and_grad", True): - found_inf_flag = optimizer.prepare_grads() + found_inf_flag = not update_successful and grad_norm is None and num_zeros_in_grad is None if found_inf_flag: valid_step = False + current_scale = optimizer.get_loss_scale().item() + logger.warning( + "Inf found in gradients (step_id=%d, loss_scale=%s), skipping parameter update (dynamic loss scaling will reduce scale)", + step_id, + current_scale, + ) else: - grad_norm = optimizer.get_grad_norm() if isinstance(grad_norm, torch.Tensor): valid_step = not (torch.isnan(grad_norm) or torch.isinf(grad_norm)) else: valid_step = not (math.isnan(grad_norm) or math.isinf(grad_norm)) - # CI check: verify only MTP parameters have non-zero gradients when truncation happens - # This check must happen before optimizer.step() as gradients may be modified during step - if args.ci_test and args.enable_mtp_training: - from slime.backends.megatron_utils.ci_utils import check_mtp_only_grad - - check_mtp_only_grad(model, step_id) - if 
valid_step: - # Update parameters. - update_successful, grad_norm, num_zeros_in_grad = optimizer.step() - # Update learning rate. assert update_successful opt_param_scheduler.step(increment=args.global_batch_size) + else: + grad_norm = float("nan") # release grad for model_chunk in model: