diff --git a/src/megatron/bridge/training/state.py b/src/megatron/bridge/training/state.py index 14113e4d7b..867e54c2b7 100644 --- a/src/megatron/bridge/training/state.py +++ b/src/megatron/bridge/training/state.py @@ -70,10 +70,10 @@ def state_dict(self) -> dict[str, torch.Tensor]: their corresponding tensor representations. """ return { - "step": torch.tensor(self.step, dtype=torch.int32), - "consumed_train_samples": torch.tensor(self.consumed_train_samples, dtype=torch.int32), - "skipped_train_samples": torch.tensor(self.skipped_train_samples, dtype=torch.int32), - "consumed_valid_samples": torch.tensor(self.consumed_valid_samples, dtype=torch.int32), + "step": torch.tensor(self.step, dtype=torch.int64), + "consumed_train_samples": torch.tensor(self.consumed_train_samples, dtype=torch.int64), + "skipped_train_samples": torch.tensor(self.skipped_train_samples, dtype=torch.int64), + "consumed_valid_samples": torch.tensor(self.consumed_valid_samples, dtype=torch.int64), "floating_point_operations_so_far": torch.tensor( self.floating_point_operations_so_far, dtype=torch.float64 ),