Skip to content
28 changes: 27 additions & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,31 @@ def __init__(self,

self.engine_timers_cache = {}

if self.optimizer_name() or self.client_optimizer is not None:
if self.optimizer is None:
raise RuntimeError(
"DeepSpeedEngine: Optimizer initialization failed. Check for JIT compilation errors.")

optimizer_methods = ['step', 'load_state_dict']
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add `backward` to this list.


if self.zero_optimization_partition_gradients():
optimizer_methods.append('overlapping_partition_gradients_reduce_epilogue')

for method in optimizer_methods:
attr = getattr(self.optimizer, method, None)
if attr is None or not callable(attr):
raise RuntimeError(
f"DeepSpeedEngine: Optimizer missing callable `{method}`. "
"This indicates incomplete initialization (e.g., JIT/toolchain failure)."
)

# Validate engine separately
if not hasattr(self, "backward") or not callable(getattr(self, "backward")):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies if I have not been clear previously, but it is `self.optimizer.backward` that needs validating, not `self.backward`.

raise RuntimeError(
"DeepSpeedEngine initialization failed: missing callable `backward`. "
"Engine may be partially initialized."
)

if self.global_rank == 0:
self._config.print("DeepSpeedEngine configuration")
if self.dump_state():
Expand Down Expand Up @@ -2413,7 +2438,8 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
# ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
if hasattr(self.optimizer, 'overlapping_partition_gradients_reduce_epilogue'):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems this check is now redundant due to line 425.

self.optimizer.overlapping_partition_gradients_reduce_epilogue()

# Communicate only at gradient accumulation boundaries
elif self.is_gradient_accumulation_boundary():
Expand Down
Loading