From e3571353b2fe624fd28963eeeaaa09ebcf4b8e4d Mon Sep 17 00:00:00 2001 From: BlueCrescent <7198877+BlueCrescent@users.noreply.github.com> Date: Mon, 13 Apr 2026 00:31:11 +0200 Subject: [PATCH] [ckpt] fix: Prevent int32 overflow for high train sample counts. Signed-off-by: BlueCrescent <7198877+BlueCrescent@users.noreply.github.com> --- src/megatron/bridge/training/state.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/megatron/bridge/training/state.py b/src/megatron/bridge/training/state.py index 14113e4d7b..867e54c2b7 100644 --- a/src/megatron/bridge/training/state.py +++ b/src/megatron/bridge/training/state.py @@ -70,10 +70,10 @@ def state_dict(self) -> dict[str, torch.Tensor]: their corresponding tensor representations. """ return { - "step": torch.tensor(self.step, dtype=torch.int32), - "consumed_train_samples": torch.tensor(self.consumed_train_samples, dtype=torch.int32), - "skipped_train_samples": torch.tensor(self.skipped_train_samples, dtype=torch.int32), - "consumed_valid_samples": torch.tensor(self.consumed_valid_samples, dtype=torch.int32), + "step": torch.tensor(self.step, dtype=torch.int64), + "consumed_train_samples": torch.tensor(self.consumed_train_samples, dtype=torch.int64), + "skipped_train_samples": torch.tensor(self.skipped_train_samples, dtype=torch.int64), + "consumed_valid_samples": torch.tensor(self.consumed_valid_samples, dtype=torch.int64), "floating_point_operations_so_far": torch.tensor( self.floating_point_operations_so_far, dtype=torch.float64 ),