[ckpt] fix: Prevent int32 overflow for high train sample counts. #3290
BlueCrescent wants to merge 1 commit into NVIDIA-NeMo:main
Conversation
Signed-off-by: BlueCrescent <7198877+BlueCrescent@users.noreply.github.com>
Hmm, with this change I was able to write out the iteration 700,000 checkpoint. But when warmstarting from it I saw CUDA OOM twice.
What does this PR do?
Previously, when creating the TrainState for checkpointing, int32 was used as the dtype for encoding the number of consumed samples and related values. For longer training runs, torch detects the resulting overflow and crashes with `RuntimeError: value cannot be converted to type int32 without overflow`. We saw this crash while checkpointing step 700,000 of a run with sequence length 8192 and batch size 3072, i.e. at 2,150,400,000 consumed samples (> 2^31 - 1 = 2,147,483,647).

Changelog
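The arithmetic behind the crash can be sketched as follows. This is an illustrative snippet, not code from the PR: `consumed_samples` is a hypothetical helper mirroring how the counter grows with training steps, and `ctypes` is used to show what forcing the value into 32 bits would do (torch instead raises a checked-conversion error).

```python
import ctypes

INT32_MAX = 2**31 - 1  # 2,147,483,647

# Hypothetical helper: consumed samples grow linearly with steps.
def consumed_samples(step: int, global_batch_size: int) -> int:
    return step * global_batch_size

samples = consumed_samples(700_000, 3072)
print(samples)              # 2150400000
print(samples > INT32_MAX)  # True: does not fit in a signed 32-bit int

# Forcing the value into 32 bits silently wraps to a negative number,
# which is why torch performs a checked conversion and raises instead.
print(ctypes.c_int32(samples).value)  # -2144567296
print(ctypes.c_int64(samples).value)  # 2150400000: int64 holds it fine
```

Switching the TrainState fields to a 64-bit integer type, as this PR does, gives headroom up to 2^63 - 1 samples, which is unreachable in practice.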