[ckpt] fix: Prevent int32 overflow for high train sample counts.#3290

Draft
BlueCrescent wants to merge 1 commit into NVIDIA-NeMo:main from BlueCrescent:fix_sample_count_type_in_checkpoint_state

Conversation


@BlueCrescent BlueCrescent commented Apr 12, 2026

What does this PR do?

Previously, when creating the TrainState for checkpointing, int32 was used as the type for encoding the number of consumed samples and related values. For longer training runs, torch detects this as a potential overflow and crashes with RuntimeError: value cannot be converted to type int32 without overflow. We saw this crash while checkpointing step 700,000 of a training run with sequence length 8192 and a batch size of 3072, i.e. at 2,150,400,000 consumed samples (> 2^31 - 1 = 2,147,483,647).
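For scale, a quick back-of-the-envelope check in pure Python, using the batch size from this report (the step arithmetic below is illustrative and assumes a constant global batch size, not taken from the training code):

```python
# Illustrative arithmetic only: when does consumed_train_samples cross
# the signed 32-bit limit, given the batch size from this PR?
INT32_MAX = 2**31 - 1  # 2,147,483,647

global_batch_size = 3072

def consumed_at_step(step: int) -> int:
    """Samples consumed after `step` optimizer steps, assuming a constant batch size."""
    return step * global_batch_size

print(consumed_at_step(700_000))              # 2150400000
print(consumed_at_step(700_000) > INT32_MAX)  # True

# First step at which the counter no longer fits in a signed 32-bit integer:
first_overflow_step = INT32_MAX // global_batch_size + 1
print(first_overflow_step)                    # 699051
```

So the counter had already left the int32 range roughly a thousand steps before the step-700,000 checkpoint was attempted.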

Changelog

  • Switched from int32 to int64 for all values using this type in TrainState.

Summary by CodeRabbit

Bug Fixes

  • Improved integer precision in training state serialization. Counter values are now stored as 64-bit integers instead of 32-bit, preventing potential overflow issues during extended training runs. This ensures reliable checkpoint persistence and stability when handling large training step and sample count values.

Signed-off-by: BlueCrescent <7198877+BlueCrescent@users.noreply.github.com>

copy-pr-bot bot commented Apr 12, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.


coderabbitai bot commented Apr 12, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: e8c35d23-d746-4271-b7ee-f73d9a8e5323

📥 Commits

Reviewing files that changed from the base of the PR and between 7ea8a45 and e357135.

📒 Files selected for processing (1)
  • src/megatron/bridge/training/state.py

📝 Walkthrough

Updated the TrainState.state_dict() method in the training state module to serialize integer training counters (step, consumed_train_samples, skipped_train_samples, consumed_valid_samples) as torch.int64 tensors instead of torch.int32 tensors.
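A minimal sketch of what the dtype change amounts to. The class below is a hypothetical standalone stand-in, not the real TrainState from src/megatron/bridge/training/state.py (which has more fields and logic); only the field names and the int32-to-int64 switch follow the PR description:

```python
import torch

class TrainState:
    """Hypothetical, stripped-down stand-in for the real TrainState."""

    def __init__(self) -> None:
        self.step = 0
        self.consumed_train_samples = 0
        self.skipped_train_samples = 0
        self.consumed_valid_samples = 0

    def state_dict(self) -> dict[str, torch.Tensor]:
        # Before the fix these tensors used dtype=torch.int32, which raises
        # "value cannot be converted to type int32 without overflow" once a
        # counter exceeds 2**31 - 1. int64 makes that practically unreachable.
        return {
            "step": torch.tensor(self.step, dtype=torch.int64),
            "consumed_train_samples": torch.tensor(
                self.consumed_train_samples, dtype=torch.int64
            ),
            "skipped_train_samples": torch.tensor(
                self.skipped_train_samples, dtype=torch.int64
            ),
            "consumed_valid_samples": torch.tensor(
                self.consumed_valid_samples, dtype=torch.int64
            ),
        }
```

With int64, the counter from this PR (2,150,400,000 samples) serializes without error, since int64 holds values up to 2^63 - 1.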

Changes

Training State Serialization — src/megatron/bridge/training/state.py
  Changed dtype of integer training counter tensors from torch.int32 to torch.int64 in the state_dict() method (4 lines modified).

Estimated code review effort

🎯 1 (Trivial) | ⏱️ ~3 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
  • Description Check — ✅ Passed: Check skipped - CodeRabbit’s high-level summary is enabled.
  • Title Check — ✅ Passed: The title clearly and specifically describes the main change: preventing int32 overflow for high train sample counts by switching to int64 serialization.
  • Docstring Coverage — ✅ Passed: Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.
  • Test Results For Major Changes — ✅ Passed: PR contains minor changes (8 lines in one method) and comprehensive unit tests verifying int64 serialization with values exceeding int32 limits.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


@BlueCrescent BlueCrescent marked this pull request as draft April 12, 2026 23:23
@BlueCrescent
Author

Hmm, with this change I was able to write out the iteration 700,000 checkpoint. But when warm-starting from it, I saw a CUDA OOM twice and, at iteration 700004, Unexpected result nan (message='found NaN in local forward loss calculation') once.
Need to investigate further.

@yaoyu-33 added labels on Apr 13, 2026: area:ckpt (Checkpoint conversion, loading, export, and save paths), bug (Something isn't working), needs-author (Author action is required before review or merge can continue)

Labels

  • area:ckpt — Checkpoint conversion, loading, export, and save paths
  • bug — Something isn't working
  • community-request
  • needs-author — Author action is required before review or merge can continue


3 participants