-
Notifications
You must be signed in to change notification settings - Fork 4.8k
Fix ZeRO-3 optimizer initialization validation (#7844) #7929
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 4 commits
a39180e
d931be0
ddca910
b36d39a
b2e17ab
0711f9b
d749679
b1a3bc1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -414,6 +414,31 @@ def __init__(self, | |
|
|
||
| self.engine_timers_cache = {} | ||
|
|
||
| if self.optimizer_name() or self.client_optimizer is not None: | ||
| if self.optimizer is None: | ||
| raise RuntimeError( | ||
| "DeepSpeedEngine: Optimizer initialization failed. Check for JIT compilation errors.") | ||
|
|
||
| optimizer_methods = ['step', 'load_state_dict'] | ||
|
|
||
| if self.zero_optimization_partition_gradients(): | ||
| optimizer_methods.append('overlapping_partition_gradients_reduce_epilogue') | ||
|
|
||
| for method in optimizer_methods: | ||
| attr = getattr(self.optimizer, method, None) | ||
| if attr is None or not callable(attr): | ||
| raise RuntimeError( | ||
| f"DeepSpeedEngine: Optimizer missing callable `{method}`. " | ||
| "This indicates incomplete initialization (e.g., JIT/toolchain failure)." | ||
| ) | ||
|
|
||
| # Validate engine separately | ||
| if not hasattr(self, "backward") or not callable(getattr(self, "backward")): | ||
|
||
| raise RuntimeError( | ||
| "DeepSpeedEngine initialization failed: missing callable `backward`. " | ||
| "Engine may be partially initialized." | ||
| ) | ||
|
|
||
| if self.global_rank == 0: | ||
| self._config.print("DeepSpeedEngine configuration") | ||
| if self.dump_state(): | ||
|
|
@@ -2413,7 +2438,8 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): | |
| self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary() | ||
| # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well | ||
| if self.zero_optimization_partition_gradients(): | ||
| self.optimizer.overlapping_partition_gradients_reduce_epilogue() | ||
| if hasattr(self.optimizer, 'overlapping_partition_gradients_reduce_epilogue'): | ||
|
||
| self.optimizer.overlapping_partition_gradients_reduce_epilogue() | ||
|
|
||
| # Communicate only at gradient accumulation boundaries | ||
| elif self.is_gradient_accumulation_boundary(): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add
backwardto this list.