Skip to content
13 changes: 12 additions & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,15 @@ def __init__(self,

self.engine_timers_cache = {}

if self.optimizer_name() or self.client_optimizer is not None:
if self.optimizer is None:
raise RuntimeError("DeepSpeedEngine: Optimizer initialization failed. Check for JIT compilation errors.")
# ZeRO-3 specific check to prevent step 0 deadlocks
if self.zero_optimization_stage() == 3:
if not hasattr(self.optimizer, 'step'):
raise AttributeError("DeepSpeedEngine: ZeRO-3 optimizer is missing core functional attributes (.step). "
"This usually indicates a toolchain mismatch or failed JIT kernels.")

if self.global_rank == 0:
self._config.print("DeepSpeedEngine configuration")
if self.dump_state():
Expand Down Expand Up @@ -2413,7 +2422,9 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
# ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
if hasattr(self.optimizer, 'overlapping_partition_gradients_reduce_epilogue'):
self.optimizer.overlapping_partition_gradients_reduce_epilogue()


# Communicate only at gradient accumulation boundaries
elif self.is_gradient_accumulation_boundary():
Expand Down
Loading