diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 5fc07b3a7238..2314ca3f6ab6 100755 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -24,17 +24,29 @@ import deepspeed from deepspeed import comm as dist -from deepspeed.runtime.utils import see_memory_usage, DummyOptim, register_output_backward_hooks, check_internal_apis_for_count_used_parameters +from deepspeed.runtime.utils import ( + see_memory_usage, + DummyOptim, + register_output_backward_hooks, + check_internal_apis_for_count_used_parameters, +) from .zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum from deepspeed.runtime.base_optimizer import ZeROOptimizer from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer from deepspeed.runtime.zenflow.zenflow_stage_1_and_2 import ZenFlowZeroOptimizer from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus -from deepspeed.runtime.zero.utils import is_zero_supported_optimizer, ZeRORuntimeException +from deepspeed.runtime.zero.utils import ( + is_zero_supported_optimizer, + ZeRORuntimeException, +) from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION -from deepspeed.runtime.zenflow.engine import (configure_zenflow, zenflow_step, is_zenflow_update_boundary, - sync_zenflow_optimizer_lr) +from deepspeed.runtime.zenflow.engine import ( + configure_zenflow, + zenflow_step, + is_zenflow_update_boundary, + sync_zenflow_optimizer_lr, +) from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer from deepspeed.runtime.fp16.loss_scaler import LossScaleConfig, LossScaleProfile @@ -42,35 +54,73 @@ from deepspeed.runtime.bf16_optimizer import BF16_Optimizer from deepspeed.linear.optimized_linear import LoRAOptimizedLinear -from deepspeed.module_inject.layers import GatherReplacedLayerParams, configure_tensor_parallel_runtime, collect_autotp_universal_checkpoint_info -from deepspeed.runtime.config import DEEPSPEED_OPTIMIZERS, \ - ADAGRAD_OPTIMIZER, ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ - TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT, ZERO_ONE_ADAM_OPTIMIZER, MUADAM_OPTIMIZER, MUADAMW_OPTIMIZER, \ - MUSGD_OPTIMIZER, LION_OPTIMIZER, MUON_OPTIMIZER - -from deepspeed.runtime.model_checkpointing.constants import ValidationMode, \ - CHECKPOINT_TAG_VALIDATION, CHECKPOINT_WRITER, CHECKPOINT_SERIALIZATION +from deepspeed.module_inject.layers import ( + GatherReplacedLayerParams, + configure_tensor_parallel_runtime, + collect_autotp_universal_checkpoint_info, +) +from deepspeed.runtime.config import ( + DEEPSPEED_OPTIMIZERS, + ADAGRAD_OPTIMIZER, + ADAM_OPTIMIZER, + ADAMW_OPTIMIZER, + LAMB_OPTIMIZER, + ONEBIT_ADAM_OPTIMIZER, + ONEBIT_LAMB_OPTIMIZER, + TORCH_ADAM_PARAM, + ADAM_W_MODE, + ADAM_W_MODE_DEFAULT, + ZERO_ONE_ADAM_OPTIMIZER, + MUADAM_OPTIMIZER, + MUADAMW_OPTIMIZER, + MUSGD_OPTIMIZER, + LION_OPTIMIZER, + MUON_OPTIMIZER, +) + +from deepspeed.runtime.model_checkpointing.constants import ( + ValidationMode, + CHECKPOINT_TAG_VALIDATION, + CHECKPOINT_WRITER, + CHECKPOINT_SERIALIZATION, +) from deepspeed.runtime.dataloader import DeepSpeedDataLoader from deepspeed.runtime.zero.muon.muon_optimizer import MuonWithAuxAdam -from deepspeed.runtime.constants import \ - ROUTE_TRAIN, ROUTE_PREDICT, ROUTE_EVAL, \ - PLD_THETA, PLD_GAMMA, BFLOAT16, FP16, AMP, GRADIENT_ACCUMULATION_STEPS, \ - DATA_PARALLEL_GROUP, GLOBAL_RANK, DDP_BFLOAT16 +from deepspeed.runtime.constants import ( + 
ROUTE_TRAIN, + ROUTE_PREDICT, + ROUTE_EVAL, + PLD_THETA, + PLD_GAMMA, + BFLOAT16, + FP16, + AMP, + GRADIENT_ACCUMULATION_STEPS, + DATA_PARALLEL_GROUP, + GLOBAL_RANK, + DDP_BFLOAT16, +) from deepspeed.runtime.zero.config import ZeroStageEnum from deepspeed.compression import compression_scheduler -from deepspeed.compression.constants import \ - WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, \ - WEIGHT_QUANTIZATION, SHARED_PARAMETERS, \ - WEIGHT_QUANTIZE_ENABLED, \ - WEIGHT_QUANTIZE_GROUPS, \ - WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE, \ - WEIGHT_QUANTIZE_CHANGE_RATIO, \ - WEIGHT_QUANTIZE_TYPE, \ - WEIGHT_QUANTIZE_ROUNDING, \ - WEIGHT_QUANTIZE_VERBOSE, \ - WEIGHT_QUANTIZE_KERNEL -from deepspeed.checkpoint.constants import OPTIMIZER_STATE_DICT, FROZEN_PARAM_FRAGMENTS, UNIVERSAL_CHECKPOINT_INFO +from deepspeed.compression.constants import ( + WEIGHT_QUANTIZE_IN_FORWARD_ENABLED, + WEIGHT_QUANTIZATION, + SHARED_PARAMETERS, + WEIGHT_QUANTIZE_ENABLED, + WEIGHT_QUANTIZE_GROUPS, + WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE, + WEIGHT_QUANTIZE_CHANGE_RATIO, + WEIGHT_QUANTIZE_TYPE, + WEIGHT_QUANTIZE_ROUNDING, + WEIGHT_QUANTIZE_VERBOSE, + WEIGHT_QUANTIZE_KERNEL, +) +from deepspeed.checkpoint.constants import ( + OPTIMIZER_STATE_DICT, + FROZEN_PARAM_FRAGMENTS, + UNIVERSAL_CHECKPOINT_INFO, +) from deepspeed.checkpoint.utils import clone_tensors_for_torch_save from deepspeed.checkpoint.ds_to_universal import dp_index_to_str from deepspeed.runtime.sparse_tensor import SparseTensor @@ -80,31 +130,71 @@ from deepspeed.utils import logger, log_dist, log_dist_once, instrument_w_nvtx from deepspeed.utils.torch import required_torch_version from deepspeed.utils.z3_leaf_module import apply_zero_leaf_module_config -from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \ - FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \ - STEP_MICRO_TIMER, \ - FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, \ - STEP_GLOBAL_TIMER -from deepspeed.utils.debug import debug_extract_module_and_param_names, debug_clear_module_and_param_names +from deepspeed.utils.timer import ( + NoopTimer, + ThroughputTimer, + SynchronizedWallClockTimer, + FORWARD_MICRO_TIMER, + BACKWARD_MICRO_TIMER, + BACKWARD_INNER_MICRO_TIMER, + BACKWARD_REDUCE_MICRO_TIMER, + STEP_MICRO_TIMER, + FORWARD_GLOBAL_TIMER, + BACKWARD_GLOBAL_TIMER, + BACKWARD_INNER_GLOBAL_TIMER, + BACKWARD_REDUCE_GLOBAL_TIMER, + STEP_GLOBAL_TIMER, +) +from deepspeed.utils.debug import ( + debug_extract_module_and_param_names, + debug_clear_module_and_param_names, +) from deepspeed.monitor.monitor import MonitorMaster from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop -from deepspeed.runtime.utils import clip_grad_norm_, compare_tensors_in_structures, maybe_loss_for_backward +from deepspeed.runtime.utils import ( + clip_grad_norm_, + compare_tensors_in_structures, + maybe_loss_for_backward, +) from deepspeed.runtime.eigenvalue import Eigenvalue -from deepspeed.runtime.data_pipeline.constants import DATA_SAMPLING, \ - DATA_ROUTING, DATA_SAMPLING_ENABLED, CURRICULUM_LEARNING, \ - CURRICULUM_LEARNING_ENABLED, DATA_SAMPLING_NUM_WORKERS, RANDOM_LTD, \ - RANDOM_LTD_ENABLED, RANDOM_LTD_LAYER_ID, RANDOM_LTD_LAYER_NUM, \ - RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE, RANDOM_LTD_LAYER_TOKEN_LR_ENABLED, \ - RANDOM_LTD_GLOBAL_BATCH_SIZE, RANDOM_LTD_MICRO_BATCH_SIZE, DATA_EFFICIENCY +from deepspeed.runtime.data_pipeline.constants import ( + DATA_SAMPLING, + DATA_ROUTING, 
+ DATA_SAMPLING_ENABLED, + CURRICULUM_LEARNING, + CURRICULUM_LEARNING_ENABLED, + DATA_SAMPLING_NUM_WORKERS, + RANDOM_LTD, + RANDOM_LTD_ENABLED, + RANDOM_LTD_LAYER_ID, + RANDOM_LTD_LAYER_NUM, + RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE, + RANDOM_LTD_LAYER_TOKEN_LR_ENABLED, + RANDOM_LTD_GLOBAL_BATCH_SIZE, + RANDOM_LTD_MICRO_BATCH_SIZE, + DATA_EFFICIENCY, +) from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler -from deepspeed.runtime.checkpoint_engine import (create_checkpoint_engine, TorchCheckpointEngine, CheckpointCommitInfo) +from deepspeed.runtime.checkpoint_engine import ( + create_checkpoint_engine, + TorchCheckpointEngine, + CheckpointCommitInfo, +) from deepspeed.runtime.data_pipeline.data_routing.scheduler import RandomLTDScheduler -from deepspeed.runtime.data_pipeline.data_routing.helper import remove_random_ltd_state_dict -from deepspeed.runtime.data_pipeline.data_routing.basic_layer import RandomLayerTokenDrop +from deepspeed.runtime.data_pipeline.data_routing.helper import ( + remove_random_ltd_state_dict, +) +from deepspeed.runtime.data_pipeline.data_routing.basic_layer import ( + RandomLayerTokenDrop, +) from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint -from deepspeed.runtime.torch_autocast import init_autocast_params, get_default_autocast_lower_precision_modules, autocast_if_enabled +from deepspeed.runtime.torch_autocast import ( + init_autocast_params, + get_default_autocast_lower_precision_modules, + autocast_if_enabled, +) from .pipe.module import PipelineModule from .utils import get_ma_status @@ -122,22 +212,33 @@ from deepspeed.runtime.config import DtypeEnum -from deepspeed.compile.util import is_deepcompile_supported, get_deepcompile_handle, deepcompile_backward_prologue +from deepspeed.compile.util import ( + is_deepcompile_supported, + get_deepcompile_handle, + deepcompile_backward_prologue, +) from deepspeed.compile.backend import register_compile_pass, opt_passes -from deepspeed.compile.passes import zero3_compile, prefetch, selective_gather, offload_adam_states +from deepspeed.compile.passes import ( + zero3_compile, + prefetch, + selective_gather, + offload_adam_states, +) from deepspeed.compile.init_z1 import init_z1 from deepspeed.compile.init_z3 import init_z3 from deepspeed.compile.init_sp import init_autosp MEMORY_OPT_ALLREDUCE_SIZE = 500000000 -DeepSpeedOptimizerCallable = \ - Callable[[Union[Iterable[Parameter], Dict[str, Iterable]]], Optimizer] +DeepSpeedOptimizerCallable = Callable[ + [Union[Iterable[Parameter], Dict[str, Iterable]]], Optimizer +] DeepSpeedSchedulerCallable = Callable[[Optimizer], _LRScheduler] try: import apex from apex import amp + APEX_INSTALLED = True except ImportError: # Fail silently so we don't spam logs unnecessarily if user isn't using amp @@ -149,7 +250,9 @@ def split_half_float_double_sparse(tensors): supported_types = get_accelerator().supported_dtypes() for t in tensors: - assert t.dtype in supported_types, f"attempting to reduce an unsupported grad type: {t.dtype}" + assert ( + t.dtype in supported_types + ), f"attempting to reduce an unsupported grad type: {t.dtype}" sparse_tensor_buckets, dense_tensor_buckets = [], [] for i, dtype in enumerate(supported_types): @@ -186,8 +289,11 @@ def __init__(self, enable_micro_timers, enable_global_timers): self.backward_reduce_timers += [BACKWARD_REDUCE_MICRO_TIMER] self.step_timers += [STEP_MICRO_TIMER] self.micro_timers += [ - FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, 
BACKWARD_REDUCE_MICRO_TIMER, - STEP_MICRO_TIMER + FORWARD_MICRO_TIMER, + BACKWARD_MICRO_TIMER, + BACKWARD_INNER_MICRO_TIMER, + BACKWARD_REDUCE_MICRO_TIMER, + STEP_MICRO_TIMER, ] if enable_global_timers: @@ -197,8 +303,11 @@ def __init__(self, enable_micro_timers, enable_global_timers): self.backward_reduce_timers += [BACKWARD_REDUCE_GLOBAL_TIMER] self.step_timers += [STEP_GLOBAL_TIMER] self.global_timers += [ - FORWARD_GLOBAL_TIMER, BACKWARD_GLOBAL_TIMER, BACKWARD_INNER_GLOBAL_TIMER, BACKWARD_REDUCE_GLOBAL_TIMER, - STEP_GLOBAL_TIMER + FORWARD_GLOBAL_TIMER, + BACKWARD_GLOBAL_TIMER, + BACKWARD_INNER_GLOBAL_TIMER, + BACKWARD_REDUCE_GLOBAL_TIMER, + STEP_GLOBAL_TIMER, ] def active_timers(self): @@ -208,20 +317,22 @@ def active_timers(self): class DeepSpeedEngine(Module): r"""DeepSpeed engine for training.""" - def __init__(self, - args, - model, - optimizer=None, - model_parameters=None, - training_data=None, - lr_scheduler=None, - mpu=None, - dist_init_required=None, - collate_fn=None, - config=None, - config_class=None, - mesh_device=None, - dont_change_device=False): + def __init__( + self, + args, + model, + optimizer=None, + model_parameters=None, + training_data=None, + lr_scheduler=None, + mpu=None, + dist_init_required=None, + collate_fn=None, + config=None, + config_class=None, + mesh_device=None, + dont_change_device=False, + ): super(DeepSpeedEngine, self).__init__() self.dont_change_device = dont_change_device self.client_optimizer = optimizer @@ -254,7 +365,9 @@ def __init__(self, self.moe_layers = [] self._step_applied = False self._global_grad_norm = None - self.use_ds_comm = False # False --> Use torch.dist, True --> Use ds.comm backend. + self.use_ds_comm = ( + False # False --> Use torch.dist, True --> Use ds.comm backend. + ) self.checkpoint_engine = None self.optimizer = None self.basic_optimizer = None @@ -279,12 +392,16 @@ def __init__(self, self._do_sanity_check() if self.autotp_size() > 1: self._configure_tensor_parallel(model, self.tensor_parallel_config()) - see_memory_usage("DeepSpeed Engine: After args sanity test", force=self.memory_breakdown()) + see_memory_usage( + "DeepSpeed Engine: After args sanity test", force=self.memory_breakdown() + ) if mpu is not None: if self.elasticity_enabled(): if not self.is_elastic_model_parallel_supported(): - assert not self.elasticity_enabled(), ("Elasticity is not currently supported" - " with model parallelism.") + assert not self.elasticity_enabled(), ( + "Elasticity is not currently supported" + " with model parallelism." 
+ ) self._set_distributed_vars(args) @@ -318,15 +435,22 @@ def __init__(self, # Configure wall clock timers self.timers = SynchronizedWallClockTimer() # Throughput timer - self.tput_timer = ThroughputTimer(self._config.timers_config, - batch_size=self.train_batch_size(), - steps_per_output=self.steps_per_print(), - monitor_memory=False) + self.tput_timer = ThroughputTimer( + self._config.timers_config, + batch_size=self.train_batch_size(), + steps_per_output=self.steps_per_print(), + monitor_memory=False, + ) - log_dist(f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", ranks=[0]) + log_dist( + f"DeepSpeed Flops Profiler Enabled: {self.flops_profiler_enabled()}", + ranks=[0], + ) if self.flops_profiler_enabled(): - self.flops_profiler = FlopsProfiler(self.module, self, self.flops_profiler_recompute_fwd_factor()) + self.flops_profiler = FlopsProfiler( + self.module, self, self.flops_profiler_recompute_fwd_factor() + ) if training_data: self.training_dataloader = self.deepspeed_io(training_data) @@ -350,9 +474,18 @@ def __init__(self, # ZeRO1/2/3 optimizers have their own grad scaler logic self.torch_autocast_z0_gradscaler = None if self.torch_autocast_enabled(): - init_autocast_params(self, self.torch_autocast_dtype(), self.torch_autocast_lower_precision_safe_modules()) - if (not self.zero_optimization() and self.torch_autocast_dtype() == torch.float16): - self.torch_autocast_z0_gradscaler = torch.amp.GradScaler(device=get_accelerator().device_name()) + init_autocast_params( + self, + self.torch_autocast_dtype(), + self.torch_autocast_lower_precision_safe_modules(), + ) + if ( + not self.zero_optimization() + and self.torch_autocast_dtype() == torch.float16 + ): + self.torch_autocast_z0_gradscaler = torch.amp.GradScaler( + device=get_accelerator().device_name() + ) self._configure_zenflow = lambda: configure_zenflow(self) self._is_zenflow_update_boundary = lambda: is_zenflow_update_boundary(self) @@ -372,8 +505,9 @@ def __init__(self, self.optimizer = self._configure_bf16_optimizer(optimizer=None) # Hook optimizer for snip_momentum pruning - if hasattr(model, 'pruners'): + if hasattr(model, "pruners"): from ..compression.helper import rewrite_optimizer_step + self.optimizer.pruners = model.pruners rewrite_optimizer_step(self.optimizer) @@ -381,9 +515,14 @@ def __init__(self, self.sparse_tensor_module_names = set() # if self.sparse_gradients_enabled(): for name, module in self.module.named_modules(): - if isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)) and self.sparse_gradients_enabled(): + if ( + isinstance(module, (torch.nn.Embedding, torch.nn.EmbeddingBag)) + and self.sparse_gradients_enabled() + ): self.sparse_tensor_module_names.add(name + ".weight") - logger.info("Will convert {} to sparse tensor during training".format(name)) + logger.info( + "Will convert {} to sparse tensor during training".format(name) + ) self._optimized_linear_offload_setup() @@ -399,22 +538,51 @@ def __init__(self, self.progressive_layer_drop = self._configure_progressive_layer_drop() if self.curriculum_enabled_legacy(): - self.curriculum_scheduler_legacy = self._configure_curriculum_scheduler_legacy() + self.curriculum_scheduler_legacy = ( + self._configure_curriculum_scheduler_legacy() + ) if self.random_ltd_enabled(): random_ltd_config = self.random_ltd_config() random_ltd_config[RANDOM_LTD_GLOBAL_BATCH_SIZE] = self.train_batch_size() - random_ltd_config[RANDOM_LTD_MICRO_BATCH_SIZE] = self.train_micro_batch_size_per_gpu() - self.random_ltd_scheduler = 
self._configure_random_ltd_scheduler(random_ltd_config) + random_ltd_config[RANDOM_LTD_MICRO_BATCH_SIZE] = ( + self.train_micro_batch_size_per_gpu() + ) + self.random_ltd_scheduler = self._configure_random_ltd_scheduler( + random_ltd_config + ) # Engine timers - self.engine_timers = EngineTimers(enable_micro_timers=self.wall_clock_breakdown(), - enable_global_timers=self.wall_clock_breakdown() - or self.flops_profiler_enabled()) + self.engine_timers = EngineTimers( + enable_micro_timers=self.wall_clock_breakdown(), + enable_global_timers=self.wall_clock_breakdown() + or self.flops_profiler_enabled(), + ) self.engine_timers_cache = {} + if self.optimizer_name() or self.client_optimizer is not None: + if self.optimizer is None: + raise RuntimeError( + "DeepSpeedEngine: Optimizer initialization failed. Check for JIT compilation errors." + ) + + optimizer_methods = ["step", "load_state_dict", "backward"] + + if self.zero_optimization_partition_gradients(): + optimizer_methods.append( + "overlapping_partition_gradients_reduce_epilogue" + ) + + for method in optimizer_methods: + attr = getattr(self.optimizer, method, None) + if attr is None or not callable(attr): + raise RuntimeError( + f"DeepSpeedEngine: Optimizer missing callable `{method}`. " + "This indicates incomplete initialization (e.g., JIT/toolchain failure)." + ) + if self.global_rank == 0: self._config.print("DeepSpeedEngine configuration") if self.dump_state(): @@ -427,10 +595,16 @@ def __init__(self, self._is_compiled = False if is_deepcompile_supported(): # Predefined compile passes - self.register_compile_pass(zero3_compile.NAME, zero3_compile.add_z3_gather_release) + self.register_compile_pass( + zero3_compile.NAME, zero3_compile.add_z3_gather_release + ) self.register_compile_pass(prefetch.NAME, prefetch.schedule_prefetch) - self.register_compile_pass(selective_gather.NAME, selective_gather.selective_gather) - self.register_compile_pass(offload_adam_states.NAME, offload_adam_states.move_opt_states) + self.register_compile_pass( + selective_gather.NAME, selective_gather.selective_gather + ) + self.register_compile_pass( + offload_adam_states.NAME, offload_adam_states.move_opt_states + ) # We now support PyTorch style backward, but it relies on the counter in ZeRO optimizers. # However, we need some internal APIs to count the number of only used parameters. @@ -441,7 +615,10 @@ def __init__(self, self._support_torch_style_backward = False # Flag to control whether gradients should be scaled by gradient accumulation steps self._scale_wrt_gas = True - if isinstance(self.optimizer, ZeROOptimizer) and check_internal_apis_for_count_used_parameters(): + if ( + isinstance(self.optimizer, ZeROOptimizer) + and check_internal_apis_for_count_used_parameters() + ): self._support_torch_style_backward = True # These hooks are used for non-scalar backward support, such as `out.backward(out_grad)`, # not for `engine.backward(loss)`. 
In this case, we need to ensure that the preprocessing @@ -470,8 +647,9 @@ def _optimized_linear_offload_setup(self): if isinstance(module, LoRAOptimizedLinear): self.optimized_linear_lora_enabled = True if offload_ratio is not None: - assert offload_ratio == module.lora_config.offload_ratio, \ - "all lora_config offload ratios should be the same across the model" + assert ( + offload_ratio == module.lora_config.offload_ratio + ), "all lora_config offload ratios should be the same across the model" offload_ratio = module.lora_config.offload_ratio if module.zero_shards > 1: # set attr so checkpoint saving can handle BWS properly @@ -483,14 +661,16 @@ def _optimized_linear_offload_setup(self): total_params = 0 for _, p in self.module.named_parameters(): - if hasattr(p, 'ds_optim_param'): + if hasattr(p, "ds_optim_param"): total_params += p.numel() offload_limit = total_params * offload_ratio - logger.info(f'offloading {offload_ratio*100}% of eligible params, specifically {offload_limit} params') + logger.info( + f"offloading {offload_ratio*100}% of eligible params, specifically {offload_limit} params" + ) total_offloaded = 0 for _, p in self.module.named_parameters(): - if hasattr(p, 'ds_optim_param'): + if hasattr(p, "ds_optim_param"): if total_offloaded < offload_limit: total_offloaded += p.numel() p.ds_offload = True @@ -512,8 +692,9 @@ def _configure_tensor_parallel_states(self, model): self._set_client_model(model) # sanity check # currently, the compatibility between 'autotp' and 'zero > 1' has not been validated - assert self.zero_optimization_stage( - ) <= 2, "Currently, the compatibility between 'autotp' and 'zero_stage = 3' has not been validated" + assert ( + self.zero_optimization_stage() <= 2 + ), "Currently, the compatibility between 'autotp' and 'zero_stage = 3' has not been validated" self.mpu = groups self.mpu._init_tp_mesh_device(tensor_model_parallel_size=self.autotp_size()) @@ -528,30 +709,44 @@ def broadcast_and_check(args, bcast_rank, bcast_group): if len(args) > 0: if self.mpu.get_tensor_model_parallel_rank() == 0: _src_args = [args] - dist.broadcast_object_list(object_list=_src_args, - src=bcast_rank, - group=bcast_group, - device=torch.device(get_accelerator().current_device_name())) + dist.broadcast_object_list( + object_list=_src_args, + src=bcast_rank, + group=bcast_group, + device=torch.device( + get_accelerator().current_device_name() + ), + ) # Rank 0 does not need to compare with itself is_equal = True else: _src_args = [None] - dist.broadcast_object_list(object_list=_src_args, - src=bcast_rank, - group=bcast_group, - device=torch.device(get_accelerator().current_device_name())) + dist.broadcast_object_list( + object_list=_src_args, + src=bcast_rank, + group=bcast_group, + device=torch.device( + get_accelerator().current_device_name() + ), + ) is_equal = compare_tensors_in_structures(args, _src_args[0]) - equal_tensor = torch.tensor(is_equal, - dtype=self.communication_data_type, - device=torch.device(get_accelerator().current_device_name())) + equal_tensor = torch.tensor( + is_equal, + dtype=self.communication_data_type, + device=torch.device(get_accelerator().current_device_name()), + ) dist.all_reduce(equal_tensor, group=bcast_group) assert torch.equal( equal_tensor, - torch.tensor(groups.get_tensor_model_parallel_world_size(), - dtype=self.communication_data_type, - device=torch.device(get_accelerator().current_device_name())) + torch.tensor( + groups.get_tensor_model_parallel_world_size(), + dtype=self.communication_data_type, + device=torch.device( + 
get_accelerator().current_device_name() + ), + ), ), "Data inconsistency within the TP group. Please check the Dataloader implementation to ensure consistency." bcast_rank = self.mpu.get_tensor_model_parallel_src_rank() @@ -563,9 +758,9 @@ def broadcast_and_check(args, bcast_rank, bcast_group): logger.info(":The Dataloader has passed the TP group consistency check.") self.first_dataloader_check.remove() - self.first_dataloader_check = self.module.register_forward_pre_hook(check_dataloader_inputs_same_across_ranks, - prepend=True, - with_kwargs=True) + self.first_dataloader_check = self.module.register_forward_pre_hook( + check_dataloader_inputs_same_across_ranks, prepend=True, with_kwargs=True + ) def _apply_autotp_partitioning(self, model, tp_config): if getattr(model, "ds_autotp_parsed", False): @@ -575,11 +770,15 @@ def _apply_autotp_partitioning(self, model, tp_config): tp_size = self.autotp_size() if tp_config.tensor_parallel.tp_size not in (1, tp_size): - raise ValueError(f"tensor_parallel.tp.tp_size ({tp_config.tensor_parallel.tp_size}) " - f"does not match tensor_parallel.autotp_size ({tp_size}).") + raise ValueError( + f"tensor_parallel.tp.tp_size ({tp_config.tensor_parallel.tp_size}) " + f"does not match tensor_parallel.autotp_size ({tp_size})." + ) tp_config.tensor_parallel.tp_size = tp_size if tp_config.tensor_parallel.tp_group is None: - tp_config.tensor_parallel.tp_group = groups.get_tensor_model_parallel_group() + tp_config.tensor_parallel.tp_group = ( + groups.get_tensor_model_parallel_group() + ) from deepspeed.module_inject.auto_tp import AutoTP @@ -589,18 +788,26 @@ def _apply_autotp_partitioning(self, model, tp_config): partition_config = tp_config.get_partition_config_object() if partition_config is not None: - autotp = AutoTP(module=model, - all_reduce_linears=(), - prefix="", - state_dict=None, - linear_layer_setting=(torch.nn.Linear, torch.nn.Embedding), - orig_layer_impl=None, - keep_module_on_host=tp_config.keep_module_on_host, - partition_config=partition_config) - autotp.set_tensor_parallel_config(tp_size, tp_config.tensor_parallel.tp_group) + autotp = AutoTP( + module=model, + all_reduce_linears=(), + prefix="", + state_dict=None, + linear_layer_setting=(torch.nn.Linear, torch.nn.Embedding), + orig_layer_impl=None, + keep_module_on_host=tp_config.keep_module_on_host, + partition_config=partition_config, + ) + autotp.set_tensor_parallel_config( + tp_size, tp_config.tensor_parallel.tp_group + ) autotp.update_linear_policies() autotp._replace_module(model) - setattr(model, UNIVERSAL_CHECKPOINT_INFO, collect_autotp_universal_checkpoint_info(model)) + setattr( + model, + UNIVERSAL_CHECKPOINT_INFO, + collect_autotp_universal_checkpoint_info(model), + ) setattr(model, "ds_autotp_parsed", True) return @@ -620,7 +827,9 @@ def _apply_autotp_partitioning(self, model, tp_config): layer_specs = TPPlanConverter.convert(hf_tp_plan) if layer_specs is not None: - logger.info(f"Using HuggingFace tp_plan with {len(layer_specs)} layer specifications") + logger.info( + f"Using HuggingFace tp_plan with {len(layer_specs)} layer specifications" + ) tp_plan_config = AutoTPConfig(tp_size=tp_size, layer_specs=layer_specs) autotp = AutoTP( module=model, @@ -632,7 +841,9 @@ def _apply_autotp_partitioning(self, model, tp_config): keep_module_on_host=tp_config.keep_module_on_host, partition_config=tp_plan_config, ) - autotp.set_tensor_parallel_config(tp_size, tp_config.tensor_parallel.tp_group) + autotp.set_tensor_parallel_config( + tp_size, tp_config.tensor_parallel.tp_group + ) 
autotp.update_linear_policies() autotp._replace_module(model) setattr(model, "ds_autotp_parsed", True) @@ -641,9 +852,15 @@ def _apply_autotp_partitioning(self, model, tp_config): parser_dict = AutoTP.tp_parser(model) for client_module, injection_policy in parser_dict: tp_config.injection_policy_tuple = injection_policy - replace_transformer_layer(client_module, model, None, tp_config, model_config) + replace_transformer_layer( + client_module, model, None, tp_config, model_config + ) - setattr(model, UNIVERSAL_CHECKPOINT_INFO, collect_autotp_universal_checkpoint_info(model)) + setattr( + model, + UNIVERSAL_CHECKPOINT_INFO, + collect_autotp_universal_checkpoint_info(model), + ) setattr(model, "ds_autotp_parsed", True) def __del__(self): @@ -651,11 +868,13 @@ def __del__(self): self.destroy() except Exception as exc: # Avoid destructor-time exceptions for partially initialized engines. - logger.debug("DeepSpeedEngine.__del__ cleanup skipped: %s", exc, exc_info=True) + logger.debug( + "DeepSpeedEngine.__del__ cleanup skipped: %s", exc, exc_info=True + ) def destroy(self): optimizer = getattr(self, "optimizer", None) - if optimizer is not None and hasattr(optimizer, 'destroy'): + if optimizer is not None and hasattr(optimizer, "destroy"): optimizer.destroy() if self.is_deepcompile_active(): get_deepcompile_handle().cleanup() @@ -682,8 +901,12 @@ def _get_model_parameters(self): if p.requires_grad: trainable_num_params += n if self.global_rank == 0: - self.autotuning_model_info["num_params"] = num_params * self.mp_world_size - self.autotuning_model_info["trainable_num_params"] = trainable_num_params * self.mp_world_size + self.autotuning_model_info["num_params"] = ( + num_params * self.mp_world_size + ) + self.autotuning_model_info["trainable_num_params"] = ( + trainable_num_params * self.mp_world_size + ) logger.info(f"model parameter = {num_params}") @@ -713,10 +936,18 @@ def set_train_batch_size(self, train_batch_size): ValueError: if ``train_batch_size`` is not divisible by the configured micro-batch size and data parallelism. """ - if train_batch_size % (self.train_micro_batch_size_per_gpu() * self.dp_world_size) != 0: - #print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}') - raise ValueError('Train batch size must be divisible by micro-batch data parallelism') - new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() * self.dp_world_size) + if ( + train_batch_size + % (self.train_micro_batch_size_per_gpu() * self.dp_world_size) + != 0 + ): + # print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}') + raise ValueError( + "Train batch size must be divisible by micro-batch data parallelism" + ) + new_gas = train_batch_size // ( + self.train_micro_batch_size_per_gpu() * self.dp_world_size + ) # overwrite config self._config.train_batch_size = train_batch_size self._config.gradient_accumulation_steps = new_gas @@ -728,7 +959,11 @@ def set_train_micro_batch_size(self, micro_batch_size): micro_batch_size (int): The new micro batch size for training. 
""" # overwrite config - new_global_batch_size = micro_batch_size * self._config.gradient_accumulation_steps * self.dp_world_size + new_global_batch_size = ( + micro_batch_size + * self._config.gradient_accumulation_steps + * self.dp_world_size + ) self._config.train_batch_size = new_global_batch_size self._config.train_micro_batch_size_per_gpu = micro_batch_size @@ -738,7 +973,9 @@ def set_data_post_process_func(self, post_process_func): def set_custom_curriculum_learning_schedule(self, schedule_func_dict): if self.training_dataloader is not None and self.curriculum_learning_enabled(): - self.training_dataloader.data_sampler.set_custom_curriculum_learning_schedule(schedule_func_dict) + self.training_dataloader.data_sampler.set_custom_curriculum_learning_schedule( + schedule_func_dict + ) def get_global_grad_norm(self) -> float: """Return the 2-norm of all gradients. If there is model parallelism, @@ -759,13 +996,15 @@ def __getattr__(self, name): _module = {} if "module" in self.__dict__: - _module = self.__dict__['module'] + _module = self.__dict__["module"] if name in dir(self): return getattr(self, name) elif name in dir(_module): return getattr(_module, name) else: - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) def checkpoint_serialization_enabled(self): return self._config.checkpoint_config[CHECKPOINT_SERIALIZATION] @@ -774,10 +1013,16 @@ def checkpoint_writer_enabled(self): return self._config.checkpoint_config[CHECKPOINT_WRITER] is not None def checkpoint_tag_validation_enabled(self): - return self._config.checkpoint_config[CHECKPOINT_TAG_VALIDATION] != ValidationMode.IGNORE + return ( + self._config.checkpoint_config[CHECKPOINT_TAG_VALIDATION] + != ValidationMode.IGNORE + ) def checkpoint_tag_validation_fail(self): - return self._config.checkpoint_config[CHECKPOINT_TAG_VALIDATION] == ValidationMode.FAIL + return ( + self._config.checkpoint_config[CHECKPOINT_TAG_VALIDATION] + == ValidationMode.FAIL + ) def elasticity_enabled(self): return self._config.elasticity_enabled @@ -785,7 +1030,11 @@ def elasticity_enabled(self): def is_elastic_model_parallel_supported(self): if self.elasticity_enabled(): # Add code for finding number of GPUs per node automatically - if self._config.num_gpus_per_node % self._config.elastic_model_parallel_size == 0: + if ( + self._config.num_gpus_per_node + % self._config.elastic_model_parallel_size + == 0 + ): return True else: return False @@ -845,13 +1094,17 @@ def data_sampling_config(self): return self._config.data_efficiency_config[DATA_SAMPLING] def curriculum_learning_enabled(self): - return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][CURRICULUM_LEARNING_ENABLED] + return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING][ + CURRICULUM_LEARNING_ENABLED + ] def curriculum_learning_config(self): return self._config.data_efficiency_config[DATA_SAMPLING][CURRICULUM_LEARNING] def random_ltd_enabled(self): - return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD][RANDOM_LTD_ENABLED] + return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD][ + RANDOM_LTD_ENABLED + ] def random_ltd_config(self): return self._config.data_efficiency_config[DATA_ROUTING][RANDOM_LTD] @@ -859,23 +1112,33 @@ def random_ltd_config(self): def random_ltd_initialize(self): assert self.random_ltd_enabled() random_ltd_config = self.random_ltd_config() - random_ltd_queue = deque([x for x 
in sorted(random_ltd_config[RANDOM_LTD_LAYER_ID])]) + random_ltd_queue = deque( + [x for x in sorted(random_ltd_config[RANDOM_LTD_LAYER_ID])] + ) count = 0 for name, layer in self.module.named_modules(): if isinstance(layer, RandomLayerTokenDrop): - if len(random_ltd_queue) != 0 and str(random_ltd_queue[0]) in name: ###[1,2,3] - layer.init_config(random_ltd_config, self.random_ltd_scheduler, count) + if ( + len(random_ltd_queue) != 0 and str(random_ltd_queue[0]) in name + ): ###[1,2,3] + layer.init_config( + random_ltd_config, self.random_ltd_scheduler, count + ) random_ltd_queue.popleft() count += 1 if random_ltd_config[RANDOM_LTD_LAYER_NUM] != count: - raise ValueError(f'random_ltd_layer_num {random_ltd_config[RANDOM_LTD_LAYER_NUM]} must be \ - equivalent to the len of random_ltd_layer_id {count}') + raise ValueError( + f"random_ltd_layer_num {random_ltd_config[RANDOM_LTD_LAYER_NUM]} must be \ + equivalent to the len of random_ltd_layer_id {count}" + ) - if random_ltd_config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][RANDOM_LTD_LAYER_TOKEN_LR_ENABLED]: + if random_ltd_config[RANDOM_LTD_LAYER_TOKEN_LR_SCHEDULE][ + RANDOM_LTD_LAYER_TOKEN_LR_ENABLED + ]: assert self.client_lr_scheduler is None - raise ValueError('not yet support') - #self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler) + raise ValueError("not yet support") + # self.lr_scheduler = lr_schedules.WarmupLayerTokenDecayLR(self.optimizer, self.random_ltd_scheduler) def get_data_parallel_rank(self): return groups.get_data_parallel_rank() @@ -946,9 +1209,11 @@ def autotuning_metric(self): return self._config.autotuning_config.metric def autotuning_profile_model_info(self): - return self.autotuning_enabled( - ) and self._config.autotuning_config.model_info and self._config.autotuning_config.model_info.get( - "profile", False) + return ( + self.autotuning_enabled() + and self._config.autotuning_config.model_info + and self._config.autotuning_config.model_info.get("profile", False) + ) def sparse_gradients_enabled(self): return self._config.sparse_gradients_enabled @@ -960,7 +1225,11 @@ def train_micro_batch_size_per_gpu(self): return self._config.train_micro_batch_size_per_gpu def optimizer_name(self): - return (self.client_optimizer.__class__.__name__ if self.client_optimizer else self._config.optimizer_name) + return ( + self.client_optimizer.__class__.__name__ + if self.client_optimizer + else self._config.optimizer_name + ) def optimizer_params(self): return self._config.optimizer_params @@ -976,17 +1245,33 @@ def scheduler_params(self): def quantize_training(self): return ( - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_IN_FORWARD_ENABLED], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_ENABLED], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_GROUPS], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS] - [WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_CHANGE_RATIO], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_TYPE], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_ROUNDING], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_VERBOSE], - self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][WEIGHT_QUANTIZE_KERNEL], + 
self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_IN_FORWARD_ENABLED + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_ENABLED + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_GROUPS + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_FP16_MIXED_QUANTIZE + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_CHANGE_RATIO + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_TYPE + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_ROUNDING + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_VERBOSE + ], + self._config.compression_config[WEIGHT_QUANTIZATION][SHARED_PARAMETERS][ + WEIGHT_QUANTIZE_KERNEL + ], ) def zero_optimization(self): @@ -1012,22 +1297,32 @@ def zero_offload_param(self): def zero_use_cpu_optimizer(self): if self._config.zero_config.offload_optimizer is not None: - return self._config.zero_config.offload_optimizer.device in [OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme] + return self._config.zero_config.offload_optimizer.device in [ + OffloadDeviceEnum.cpu, + OffloadDeviceEnum.nvme, + ] return False def zero_cpu_offload(self): if self._config.zero_config.offload_optimizer is not None: - return self._config.zero_config.offload_optimizer.device == OffloadDeviceEnum.cpu + return ( + self._config.zero_config.offload_optimizer.device + == OffloadDeviceEnum.cpu + ) return False def zero_partial_offload(self): return getattr(self._config.zero_config.offload_optimizer, "ratio", 1.0) def super_offload(self): - return getattr(self._config.zero_config.offload_optimizer, "super_offload", False) + return getattr( + self._config.zero_config.offload_optimizer, "super_offload", False + ) def cpuadam_cores_perc(self): - return getattr(self._config.zero_config.offload_optimizer, "cpuadam_cores_perc", 0.9) + return getattr( + self._config.zero_config.offload_optimizer, "cpuadam_cores_perc", 0.9 + ) def zero_sub_group_size(self): return self._config.zero_config.sub_group_size @@ -1062,8 +1357,11 @@ def zero_optimization_partition_weights(self): return self.zero_optimization_stage() >= ZeroStageEnum.weights def is_first_weights_partition_group(self): - ret = True if self.mics_shard_size() < 0 \ - and self.zero_optimization_partition_weights() else False + ret = ( + True + if self.mics_shard_size() < 0 and self.zero_optimization_partition_weights() + else False + ) if self.mics_shard_size() > 0 and self.global_rank < self.mics_shard_size(): ret = True return ret @@ -1151,7 +1449,11 @@ def torch_autocast_dtype(self) -> torch.dtype: def torch_autocast_lower_precision_safe_modules(self) -> List[str]: module_names = self._config.torch_autocast_lower_precision_safe_modules - return get_default_autocast_lower_precision_modules() if module_names is None else module_names + return ( + get_default_autocast_lower_precision_modules() + if module_names is None + else module_names + ) def fp16_auto_cast(self): return self._config.float16_config.auto_cast @@ -1263,45 +1565,58 @@ def get_data_types(self): return (model_dtype, grad_accum_dtype) def _optimizer_has_ckpt_event_prologue(self): - return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_prologue') + return self.optimizer is not None and hasattr( + self.optimizer, 
"checkpoint_event_prologue" + ) def _optimizer_has_ckpt_event_epilogue(self): - return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_epilogue') + return self.optimizer is not None and hasattr( + self.optimizer, "checkpoint_event_epilogue" + ) def _configure_lr_scheduler(self): if self.client_lr_scheduler: if isinstance(self.client_lr_scheduler, Callable): - log_dist('DeepSpeed using client callable to create LR scheduler', ranks=[0]) + log_dist( + "DeepSpeed using client callable to create LR scheduler", ranks=[0] + ) self.lr_scheduler = self.client_lr_scheduler(self.basic_optimizer) else: - log_dist('DeepSpeed using client LR scheduler', ranks=[0]) + log_dist("DeepSpeed using client LR scheduler", ranks=[0]) self.lr_scheduler = self.client_lr_scheduler else: # load lr scheduler from json configuration if lr scheduler is not defined and passed in lr_scheduler = self._scheduler_from_config(self.optimizer) - log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0]) + log_dist( + f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", + ranks=[0], + ) self.lr_scheduler = lr_scheduler - log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0]) + log_dist(f"DeepSpeed LR Scheduler = {self.lr_scheduler}", ranks=[0]) def _configure_checkpointing(self): # Enable optimization to parallelize checkpointing of DP state optimize_dp_state = not self.zero_optimization_partition_weights() - self.checkpoint_engine = create_checkpoint_engine(config_params=self._config, - groups=groups, - zero_stage=self.zero_optimization_stage(), - has_moe_layers=self.has_moe_layers, - optimize_dp_state=optimize_dp_state) + self.checkpoint_engine = create_checkpoint_engine( + config_params=self._config, + groups=groups, + zero_stage=self.zero_optimization_stage(), + has_moe_layers=self.has_moe_layers, + optimize_dp_state=optimize_dp_state, + ) dp_rank = groups._get_sequence_data_parallel_rank() rank = self.local_rank if self.use_node_local_storage() else dp_rank # Determine if this data parallel process needs to store the model checkpoint - if self.checkpoint_engine.is_data_parallel_writer(rank) \ - or (self.zero_optimization_partition_weights() and self.is_first_weights_partition_group()): + if self.checkpoint_engine.is_data_parallel_writer(rank) or ( + self.zero_optimization_partition_weights() + and self.is_first_weights_partition_group() + ): self.save_non_zero_checkpoint = True - if hasattr(self.optimizer, 'dp_process_group'): + if hasattr(self.optimizer, "dp_process_group"): param_rank = dist.get_rank(group=self.optimizer.dp_process_group) # Only the first parameter parallel process needs to store the @@ -1314,8 +1629,9 @@ def _scheduler_from_config(self, optimizer): if hasattr(lr_schedules, scheduler_name): scheduler = getattr(lr_schedules, scheduler_name) else: - assert hasattr(torch.optim.lr_scheduler, - scheduler_name), f"DeepSpeed does not recognize LR scheduler {scheduler_name}" + assert hasattr( + torch.optim.lr_scheduler, scheduler_name + ), f"DeepSpeed does not recognize LR scheduler {scheduler_name}" scheduler = getattr(torch.optim.lr_scheduler, scheduler_name) @@ -1326,7 +1642,11 @@ def _scheduler_from_config(self, optimizer): return None def _set_distributed_vars(self, args): - device_rank = args.device_rank if args is not None and hasattr(args, 'device_rank') else self.local_rank + device_rank = ( + args.device_rank + if args is not None and hasattr(args, "device_rank") + else self.local_rank + ) if device_rank >= 0: 
get_accelerator().set_device(device_rank) self.device = torch.device(get_accelerator().device_name(device_rank)) @@ -1346,24 +1666,31 @@ def _configure_with_arguments(self, args, mpu): if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ: ompi_local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK") - local_rank = os.environ.get('LOCAL_RANK', ompi_local_rank) - assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " \ + local_rank = os.environ.get("LOCAL_RANK", ompi_local_rank) + assert ompi_local_rank == local_rank, ( + f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " "not sure how to proceed as we're seeing conflicting local rank info." - os.environ['LOCAL_RANK'] = local_rank + ) + os.environ["LOCAL_RANK"] = local_rank - self.local_rank = int(os.environ['LOCAL_RANK']) - if hasattr(args, 'local_rank'): + self.local_rank = int(os.environ["LOCAL_RANK"]) + if hasattr(args, "local_rank"): args.local_rank = self.local_rank # Validate command line arguments def _do_args_sanity_check(self, args): - assert "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ, "DeepSpeed requires the LOCAL_RANK environment " \ - "variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " \ + assert ( + "LOCAL_RANK" in os.environ or "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ + ), ( + "DeepSpeed requires the LOCAL_RANK environment " + "variable, it is set by the deepspeed launcher, deepspeed.init_distributed, or the torch's launcher. If using a " "different launcher please ensure LOCAL_RANK is set prior to initializing deepspeed." + ) - if hasattr(args, 'local_rank') and args.local_rank is not None: - assert isinstance(args.local_rank, - int), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}" + if hasattr(args, "local_rank") and args.local_rank is not None: + assert isinstance( + args.local_rank, int + ), f"args.local_rank of {args.local_rank} is an unknown type {type(args.local_rank)}" if args.local_rank >= 0: env_local_rank = int(os.environ.get("LOCAL_RANK")) assert ( @@ -1371,7 +1698,10 @@ def _do_args_sanity_check(self, args): ), f"Mismatch in local rank setting, args.local_rank={args.local_rank} but env['LOCAL_RANK']={env_local_rank}." 
def _is_supported_optimizer(self, optimizer_name): - return (optimizer_name in DEEPSPEED_OPTIMIZERS or getattr(torch.optim, optimizer_name, None) is not None) + return ( + optimizer_name in DEEPSPEED_OPTIMIZERS + or getattr(torch.optim, optimizer_name, None) is not None + ) def _supported_optims(self): FairseqOptimizer = None @@ -1396,22 +1726,33 @@ def _do_sanity_check(self): expected_optim_types = self._supported_optims() expected_optim_types += [type(None), Callable] - assert isinstance(self.client_optimizer, tuple(expected_optim_types)), \ - f'Client Optimizer is of unexpected type {type(self.client_optimizer)}' + assert isinstance( + self.client_optimizer, tuple(expected_optim_types) + ), f"Client Optimizer is of unexpected type {type(self.client_optimizer)}" if not self.client_optimizer: if self.optimizer_name() is not None: assert self._is_supported_optimizer( - self.optimizer_name()), "{} is not a supported DeepSpeed Optimizer".format(self.optimizer_name()) + self.optimizer_name() + ), "{} is not a supported DeepSpeed Optimizer".format( + self.optimizer_name() + ) - if (self.optimizer_name() == LAMB_OPTIMIZER or self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER): - assert (self.dynamic_loss_scale()), "DeepSpeed {} optimizer requires dynamic loss scaling".format( - self.optimizer_name()) + if ( + self.optimizer_name() == LAMB_OPTIMIZER + or self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER + ): + assert ( + self.dynamic_loss_scale() + ), "DeepSpeed {} optimizer requires dynamic loss scaling".format( + self.optimizer_name() + ) # Detect invalid combinations of client optimizer and client scheduler if isinstance(self.client_lr_scheduler, _LRScheduler): - assert isinstance(self.client_optimizer, Optimizer), \ - f'Client Optimizer (type = {type(self.client_optimizer)} is not instantiated but Client LR Scheduler is instantiated' + assert isinstance( + self.client_optimizer, Optimizer + ), f"Client Optimizer (type = {type(self.client_optimizer)} is not instantiated but Client LR Scheduler is instantiated" def _broadcast_model(self): if self.dist_backend is None: @@ -1420,7 +1761,7 @@ def _broadcast_model(self): def is_replicated(p): if hasattr(p, "ds_status") and p.ds_status is not ZeroParamStatus.AVAILABLE: return False - elif hasattr(p, 'ds_optim_param'): + elif hasattr(p, "ds_optim_param"): # do not broadcast OptimizedLinear parameters, they are unique per base weight shard return False return True @@ -1429,12 +1770,18 @@ def is_replicated(p): # Broadcast the model for different parameters if is_moe_param(p): if torch.is_tensor(p) and is_replicated(p): - dist.broadcast(p.data, - groups._get_expert_broadcast_src_rank(p.group_name), - group=self.expert_data_parallel_group[p.group_name]) + dist.broadcast( + p.data, + groups._get_expert_broadcast_src_rank(p.group_name), + group=self.expert_data_parallel_group[p.group_name], + ) else: if torch.is_tensor(p) and is_replicated(p): - dist.broadcast(p.data, groups._get_broadcast_src_rank(), group=self.seq_data_parallel_group) + dist.broadcast( + p.data, + groups._get_broadcast_src_rank(), + group=self.seq_data_parallel_group, + ) @staticmethod def __check_params(model: Module, dtype: torch.dtype) -> None: @@ -1442,16 +1789,19 @@ def __check_params(model: Module, dtype: torch.dtype) -> None: def _set_client_model(self, model): # register client model in _modules so that nn.module methods work correctly - modules = self.__dict__.get('_modules') - modules['module'] = model + modules = self.__dict__.get("_modules") + modules["module"] = model # register 
module attribute in engine but avoid getattr - self.__dict__['module'] = model + self.__dict__["module"] = model def _configure_distributed_model(self, model): self._set_client_model(model) - apply_zero_leaf_module_config(self.module, getattr(self._config.zero_config, "leaf_module", None)) + apply_zero_leaf_module_config( + self.module, getattr(self._config.zero_config, "leaf_module", None) + ) is_zero_init_model = self.zero_optimization_partition_weights() and any( - [hasattr(param, "ds_id") for param in self.module.parameters()]) + [hasattr(param, "ds_id") for param in self.module.parameters()] + ) if self.fp16_enabled(): if is_zero_init_model: @@ -1491,13 +1841,19 @@ def _configure_distributed_model(self, model): # Set deepspeed parallelism spec. for the model including expert parallelism for _, module in self.module.named_modules(): - if hasattr(module, 'set_deepspeed_parallelism'): - module.set_deepspeed_parallelism(self._config.use_data_before_expert_parallel_) + if hasattr(module, "set_deepspeed_parallelism"): + module.set_deepspeed_parallelism( + self._config.use_data_before_expert_parallel_ + ) # Query the groups module to get information about various parallel groups self.local_all_to_all_group = None if self.zero_quantized_gradients(): - message = "Using LoCo quantized gradients" if self.zeropp_loco_param() else "Using quantized gradients" + message = ( + "Using LoCo quantized gradients" + if self.zeropp_loco_param() + else "Using quantized gradients" + ) log_dist(message, ranks=[0]) self.local_all_to_all_group = groups._get_local_all_to_all_group() self.data_parallel_group = groups._get_data_parallel_group() @@ -1516,8 +1872,11 @@ def _configure_distributed_model(self, model): "rank indexing errors during backward pass when sp_size < world_size. " "Please use the weighted all-reduce workaround shown in the regression test " "(https://github.com/deepspeedai/DeepSpeed/blob/master/tests/unit/sequence_parallelism/test_ulysses.py) " - "or upgrade to PyTorch 2.3+.") - self.communication_data_type = self._config.seq_parallel_communication_data_type + "or upgrade to PyTorch 2.3+." + ) + self.communication_data_type = ( + self._config.seq_parallel_communication_data_type + ) self.seq_parallel_group = groups._get_sequence_parallel_group() if dist.get_rank() == 0: @@ -1540,19 +1899,27 @@ def _check_for_duplicates(self, optimizer): def ids_list(group): return [id(param) for param in group] - occurrence = sum([ - ids_list(group['params']).count(param_id) if param_id in ids_list(group['params']) else 0 - for group in optimizer.param_groups - ]) - assert occurrence <= 1, f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behavior." + occurrence = sum( + [ + ( + ids_list(group["params"]).count(param_id) + if param_id in ids_list(group["params"]) + else 0 + ) + for group in optimizer.param_groups + ] + ) + assert ( + occurrence <= 1 + ), f"Parameter with name: {name} occurs multiple times in optimizer.param_groups. Make sure it only appears once to prevent undefined behavior." 
def _do_optimizer_sanity_check(self, basic_optimizer): model_dtype, grad_accum_dtype = self.get_data_types() zero_enabled = self.zero_optimization() amp_enabled = self.amp_enabled() # config based assertions - assert ( - not (amp_enabled and zero_enabled) + assert not ( + amp_enabled and zero_enabled ), "Amp and ZeRO are not currently compatible, please use (legacy) fp16 mode which performs similar to amp opt_mode=O2" if zero_enabled: if not is_zero_supported_optimizer(basic_optimizer): @@ -1561,22 +1928,33 @@ def _do_optimizer_sanity_check(self, basic_optimizer): ), 'You are using an untested ZeRO Optimizer. Please add <"zero_allow_untested_optimizer": true> in the configuration file to use it.' if self.global_rank == 0: - logger.warning("**** You are using ZeRO with an untested optimizer, proceed with caution *****") - if model_dtype == torch.bfloat16 and grad_accum_dtype == torch.float32 and self.zero_optimization_stage( - ) == 1 and not self.zero_cpu_offload(): + logger.warning( + "**** You are using ZeRO with an untested optimizer, proceed with caution *****" + ) + if ( + model_dtype == torch.bfloat16 + and grad_accum_dtype == torch.float32 + and self.zero_optimization_stage() == 1 + and not self.zero_cpu_offload() + ): return BFLOAT16 return ZERO_OPTIMIZATION elif amp_enabled: if model_dtype != grad_accum_dtype: raise NotImplementedError( - "Model data type and gradient accumulation data type must be equal to use Amp") + "Model data type and gradient accumulation data type must be equal to use Amp" + ) if model_dtype == torch.bfloat16 or model_dtype == torch.float16: - raise NotImplementedError("Cannot enable both amp with (legacy) fp16 or bfloat16 mode") + raise NotImplementedError( + "Cannot enable both amp with (legacy) fp16 or bfloat16 mode" + ) try: logger.info("Initializing Apex amp from: {}".format(amp.__path__)) except NameError: # If apex/amp is available it will be imported above - raise RuntimeError("Unable to import apex/amp, please make sure it is installed") + raise RuntimeError( + "Unable to import apex/amp, please make sure it is installed" + ) return AMP # data type checks elif model_dtype == grad_accum_dtype: @@ -1589,7 +1967,9 @@ def _do_optimizer_sanity_check(self, basic_optimizer): return BFLOAT16 return FP16 if model_dtype == torch.float16 else DDP_BFLOAT16 else: - raise NotImplementedError(f"unsupported mix of {model_dtype=} and {grad_accum_dtype=}") + raise NotImplementedError( + f"unsupported mix of {model_dtype=} and {grad_accum_dtype=}" + ) return None @@ -1599,28 +1979,42 @@ def _configure_optimizer(self, client_optimizer, model_parameters): if self.has_moe_layers: model_parameters = configure_moe_param_groups(model_parameters) basic_optimizer = self._configure_basic_optimizer(model_parameters) - log_dist(f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", ranks=[0]) + log_dist( + f"Using DeepSpeed Optimizer param name {self.optimizer_name()} as basic optimizer", + ranks=[0], + ) else: if isinstance(client_optimizer, tuple(self._supported_optims())): basic_optimizer = client_optimizer - log_dist('Using client Optimizer as basic optimizer', ranks=[0]) + log_dist("Using client Optimizer as basic optimizer", ranks=[0]) else: basic_optimizer = client_optimizer(model_parameters) - log_dist('Using client callable to create basic optimizer', ranks=[0]) + log_dist("Using client callable to create basic optimizer", ranks=[0]) - if (self.zero_use_cpu_optimizer() and not isinstance(basic_optimizer, deepspeed.ops.adam.DeepSpeedCPUAdam) 
- and not isinstance(basic_optimizer, deepspeed.ops.lion.DeepSpeedCPULion)): + if ( + self.zero_use_cpu_optimizer() + and not isinstance(basic_optimizer, deepspeed.ops.adam.DeepSpeedCPUAdam) + and not isinstance(basic_optimizer, deepspeed.ops.lion.DeepSpeedCPULion) + ): if self.zero_force_ds_cpu_optimizer(): msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer w. ZeRO-Offload and understand the performance impacts you can also set <"zero_force_ds_cpu_optimizer": false> in your configuration file.' raise ZeRORuntimeException(msg) - basic_optimizer.param_groups[:] = [pg for pg in basic_optimizer.param_groups if len(pg["params"]) != 0] - log_dist("Removing param_group that has no 'params' in the basic Optimizer", ranks=[0]) + basic_optimizer.param_groups[:] = [ + pg for pg in basic_optimizer.param_groups if len(pg["params"]) != 0 + ] + log_dist( + "Removing param_group that has no 'params' in the basic Optimizer", + ranks=[0], + ) self._check_for_duplicates(basic_optimizer) self.basic_optimizer = basic_optimizer - log_dist(f"DeepSpeed Basic Optimizer = {basic_optimizer.__class__.__name__}", ranks=[0]) + log_dist( + f"DeepSpeed Basic Optimizer = {basic_optimizer.__class__.__name__}", + ranks=[0], + ) optimizer_wrapper = self._do_optimizer_sanity_check(basic_optimizer) @@ -1629,7 +2023,9 @@ def _configure_optimizer(self, client_optimizer, model_parameters): elif optimizer_wrapper == AMP: amp_params = self.amp_params() log_dist(f"Initializing AMP with these params: {amp_params}", ranks=[0]) - model, self.optimizer = amp.initialize(self.module, basic_optimizer, **amp_params) + model, self.optimizer = amp.initialize( + self.module, basic_optimizer, **amp_params + ) self._set_client_model(model) self._broadcast_model() # TODO: maybe need to broadcast experts differently? 
@@ -1641,7 +2037,10 @@ def _configure_optimizer(self, client_optimizer, model_parameters): else: self.optimizer = basic_optimizer - log_dist("DeepSpeed Final Optimizer = {}".format(self.optimizer.__class__.__name__), ranks=[0]) + log_dist( + "DeepSpeed Final Optimizer = {}".format(self.optimizer.__class__.__name__), + ranks=[0], + ) self.compression_scheduler = self._configure_compression_scheduler() self.quantizer = self._configure_quantization() @@ -1661,23 +2060,34 @@ def _configure_basic_optimizer(self, model_parameters): adam_w_mode = optimizer_parameters.pop(ADAM_W_MODE, ADAM_W_MODE_DEFAULT) # Optimizer name of Adam forces AdamW logic unless adam_w_mode is explicitly set - effective_adam_w_mode = self.optimizer_name() == ADAMW_OPTIMIZER or adam_w_mode + effective_adam_w_mode = ( + self.optimizer_name() == ADAMW_OPTIMIZER or adam_w_mode + ) if torch_adam: if not effective_adam_w_mode: - optimizer = torch.optim.Adam(model_parameters, **optimizer_parameters) + optimizer = torch.optim.Adam( + model_parameters, **optimizer_parameters + ) else: - optimizer = torch.optim.AdamW(model_parameters, **optimizer_parameters) + optimizer = torch.optim.AdamW( + model_parameters, **optimizer_parameters + ) else: if self.zero_use_cpu_optimizer(): from deepspeed.ops.adam import DeepSpeedCPUAdam, ZenFlowCPUAdam + CPUAdam = ZenFlowCPUAdam if self.zenflow else DeepSpeedCPUAdam - zenflow_kwargs = {'overlap_step': self.overlap_step} if self.zenflow else {} - optimizer = CPUAdam(model_parameters, - **optimizer_parameters, - adamw_mode=effective_adam_w_mode, - **zenflow_kwargs) + zenflow_kwargs = ( + {"overlap_step": self.overlap_step} if self.zenflow else {} + ) + optimizer = CPUAdam( + model_parameters, + **optimizer_parameters, + adamw_mode=effective_adam_w_mode, + **zenflow_kwargs, + ) else: from deepspeed.ops.adam import FusedAdam @@ -1690,9 +2100,14 @@ def _configure_basic_optimizer(self, model_parameters): elif self.optimizer_name() == ADAGRAD_OPTIMIZER: if self.zero_use_cpu_optimizer(): from deepspeed.ops.adagrad import DeepSpeedCPUAdagrad - optimizer = DeepSpeedCPUAdagrad(model_parameters, **optimizer_parameters) + + optimizer = DeepSpeedCPUAdagrad( + model_parameters, **optimizer_parameters + ) else: - optimizer = torch.optim.Adagrad(model_parameters, **optimizer_parameters) + optimizer = torch.optim.Adagrad( + model_parameters, **optimizer_parameters + ) elif self.optimizer_name() == LAMB_OPTIMIZER: from deepspeed.ops.lamb import FusedLamb @@ -1703,27 +2118,35 @@ def _configure_basic_optimizer(self, model_parameters): optimizer = OnebitAdam(model_parameters, self, **optimizer_parameters) if not self.fp16_enabled(): - logger.warning("Currently the convergence of 1-bit Adam is only verified under FP16") + logger.warning( + "Currently the convergence of 1-bit Adam is only verified under FP16" + ) elif self.optimizer_name() == ZERO_ONE_ADAM_OPTIMIZER: assert not self.zero_optimization(), "0/1 Adam is not compatible with ZeRO" from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam optimizer = ZeroOneAdam(model_parameters, self, **optimizer_parameters) if not self.fp16_enabled(): - logger.warning('Currently the convergence of 0/1 Adam is only verified under FP16') + logger.warning( + "Currently the convergence of 0/1 Adam is only verified under FP16" + ) elif self.optimizer_name() == ONEBIT_LAMB_OPTIMIZER: assert not self.zero_optimization(), "1bit-Lamb is not compatible with ZeRO" from deepspeed.runtime.fp16.onebit.lamb import OnebitLamb optimizer = OnebitLamb(model_parameters, self, 
**optimizer_parameters)
            if not self.fp16_enabled():
-                logger.warning("Currently the convergence of 1-bit Lamb is only verified under FP16")
+                logger.warning(
+                    "Currently the convergence of 1-bit Lamb is only verified under FP16"
+                )
        elif self.optimizer_name() == LION_OPTIMIZER:
            if self.zero_use_cpu_optimizer():
                from deepspeed.ops.lion import DeepSpeedCPULion
+
                optimizer = DeepSpeedCPULion(model_parameters, **optimizer_parameters)
            else:
                from deepspeed.ops.lion import FusedLion
+
                optimizer = FusedLion(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == MUADAM_OPTIMIZER:
            try:
@@ -1745,31 +2168,41 @@ def _configure_basic_optimizer(self, model_parameters):
            optimizer = MuSGD(model_parameters, **optimizer_parameters)
        elif self.optimizer_name() == MUON_OPTIMIZER:
            zero_stage = self.zero_optimization_stage()
-            if not all([hasattr(p, 'use_muon') for p in model_parameters]):
-                msg = "Muon optimizer is used, but the use_muon attribute is NOT configured for some of the model parameters, " \
-                    "please set by `param.use_muon = True / False` for all params"
+            if not all([hasattr(p, "use_muon") for p in model_parameters]):
+                msg = (
+                    "Muon optimizer is used, but the use_muon attribute is NOT configured for some of the model parameters, "
+                    "please set it via `param.use_muon = True / False` for all params"
+                )
                logger.error(msg)
-            muon_params = [p for p in model_parameters if p.use_muon and p.requires_grad]
-            non_muon_params = [p for p in model_parameters if (not p.use_muon) and p.requires_grad]
+            muon_params = [
+                p for p in model_parameters if p.use_muon and p.requires_grad
+            ]
+            non_muon_params = [
+                p for p in model_parameters if (not p.use_muon) and p.requires_grad
+            ]
            param_groups = []
            if muon_params:
                accepted_parameters = dict()
                for key in ["lr", "momentum", "weight_decay", "muon_lr"]:
                    if key in optimizer_parameters:
                        if key == "muon_lr":  # muon_lr will override lr
-                            accepted_parameters['lr'] = optimizer_parameters[key]
+                            accepted_parameters["lr"] = optimizer_parameters[key]
                        else:
                            accepted_parameters[key] = optimizer_parameters[key]
-                param_groups.append(dict(params=muon_params, use_muon=True, **accepted_parameters))
+                param_groups.append(
+                    dict(params=muon_params, use_muon=True, **accepted_parameters)
+                )
            if non_muon_params:
                accepted_parameters = dict()
                for key in ["lr", "betas", "eps", "weight_decay", "adam_lr"]:
                    if key in optimizer_parameters:
                        if key == "adam_lr":  # adam_lr will override lr
-                            accepted_parameters['lr'] = optimizer_parameters[key]
+                            accepted_parameters["lr"] = optimizer_parameters[key]
                        else:
                            accepted_parameters[key] = optimizer_parameters[key]
-                param_groups.append(dict(params=non_muon_params, use_muon=False, **accepted_parameters))
+                param_groups.append(
+                    dict(params=non_muon_params, use_muon=False, **accepted_parameters)
+                )
            optimizer = MuonWithAuxAdam(param_groups)
        else:
            torch_optimizer = getattr(torch.optim, self.optimizer_name())
@@ -1795,7 +2228,8 @@ def _configure_quantization(self):
            use_quantizer_kernel,
        ) = self.quantize_training()
        if quantize_enabled and not quantize_weight_in_forward:
-            assert self.fp16_enabled(
+            assert (
+                self.fp16_enabled()
            ), "MoQ (quantize in optimization step) weight quantization is only supported for FP16"
        quantizer = None
        if quantize_enabled and not quantize_weight_in_forward:
@@ -1823,10 +2257,17 @@ def _configure_fp16_optimizer(self, optimizer, low_precision_dtype):
        else:
            fused_opts = FusedAdam

-        use_fused_optimizer = isinstance(optimizer, fused_opts) \
-            or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER,
ZERO_ONE_ADAM_OPTIMIZER] - loss_scale_profile = LossScaleProfile.FUSED if use_fused_optimizer else LossScaleProfile.UNFUSED - initial_dynamic_scale = self.initial_dynamic_scale() if loss_scale_profile == LossScaleProfile.FUSED else None + use_fused_optimizer = isinstance( + optimizer, fused_opts + ) or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER] + loss_scale_profile = ( + LossScaleProfile.FUSED if use_fused_optimizer else LossScaleProfile.UNFUSED + ) + initial_dynamic_scale = ( + self.initial_dynamic_scale() + if loss_scale_profile == LossScaleProfile.FUSED + else None + ) loss_scale_config = LossScaleConfig( low_precision_dtype=low_precision_dtype, dynamic_loss_scale=self.dynamic_loss_scale(), @@ -1838,9 +2279,12 @@ def _configure_fp16_optimizer(self, optimizer, low_precision_dtype): if use_fused_optimizer: if loss_scale_config.dynamic_loss_scale: - log_dist('Creating fp16 optimizer with dynamic loss scale', ranks=[0]) + log_dist("Creating fp16 optimizer with dynamic loss scale", ranks=[0]) else: - log_dist(f'Creating fp16 optimizer with static loss scale: {loss_scale_config.cur_scale}', ranks=[0]) + log_dist( + f"Creating fp16 optimizer with static loss scale: {loss_scale_config.cur_scale}", + ranks=[0], + ) timers = self.timers if self.wall_clock_breakdown() else NoopTimer() optimizer = FP16_Optimizer( optimizer, @@ -1855,10 +2299,14 @@ def _configure_fp16_optimizer(self, optimizer, low_precision_dtype): ) else: if loss_scale_config.dynamic_loss_scale: - log_dist('Creating fp16 unfused optimizer with dynamic loss scale', ranks=[0]) + log_dist( + "Creating fp16 unfused optimizer with dynamic loss scale", ranks=[0] + ) else: - log_dist(f'Creating fp16 unfused optimizer with static loss scale: {loss_scale_config.cur_scale}', - ranks=[0]) + log_dist( + f"Creating fp16 unfused optimizer with static loss scale: {loss_scale_config.cur_scale}", + ranks=[0], + ) optimizer = FP16_UnfusedOptimizer( optimizer, deepspeed=self, @@ -1877,20 +2325,22 @@ def _configure_bf16_optimizer(self, optimizer): if optimizer is None: optimizer = DummyOptim(list(self.module.parameters())) - log_dist('Creating BF16 optimizer', ranks=[0]) + log_dist("Creating BF16 optimizer", ranks=[0]) timers = self.timers if self.wall_clock_breakdown() else NoopTimer() - optimizer = BF16_Optimizer(optimizer, - self.param_names, - bfloat16_config=self._config.bfloat16_config, - mpu=self.mpu, - clip_grad=clip_grad, - allgather_bucket_size=self.zero_allgather_bucket_size(), - dp_process_group=self.seq_data_parallel_group, - timers=timers, - grad_acc_dtype=self.get_data_types()[1], - graph_harvesting=self.graph_harvesting(), - has_moe_layers=self.has_moe_layers) + optimizer = BF16_Optimizer( + optimizer, + self.param_names, + bfloat16_config=self._config.bfloat16_config, + mpu=self.mpu, + clip_grad=clip_grad, + allgather_bucket_size=self.zero_allgather_bucket_size(), + dp_process_group=self.seq_data_parallel_group, + timers=timers, + grad_acc_dtype=self.get_data_types()[1], + graph_harvesting=self.graph_harvesting(), + has_moe_layers=self.has_moe_layers, + ) return optimizer @@ -1921,16 +2371,25 @@ def _configure_zero_optimizer(self, optimizer): overlap_comm = self.zero_overlap_comm() contiguous_gradients = self.zero_contiguous_gradients() round_robin_gradients = self.zero_round_robin_gradients() - assert not isinstance(optimizer, DummyOptim), "zero stage {} requires an optimizer".format(zero_stage) + assert not isinstance( + optimizer, DummyOptim + ), "zero stage {} requires an 
optimizer".format(zero_stage) - log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0]) + log_dist( + f"Creating {model_dtype} ZeRO stage {zero_stage} optimizer", ranks=[0] + ) if isinstance(self.module, PipelineModule): if overlap_comm: - logger.warning("Pipeline parallelism does not support overlapped communication, will be disabled.") + logger.warning( + "Pipeline parallelism does not support overlapped communication, will be disabled." + ) overlap_comm = False - Stage1And2ZeroOptimizer = DeepSpeedZeroOptimizer if not self.zenflow else ZenFlowZeroOptimizer.create( - zenflow_config=self.zenflow_config()) + Stage1And2ZeroOptimizer = ( + DeepSpeedZeroOptimizer + if not self.zenflow + else ZenFlowZeroOptimizer.create(zenflow_config=self.zenflow_config()) + ) optimizer = Stage1And2ZeroOptimizer( optimizer, @@ -1946,8 +2405,12 @@ def _configure_zero_optimizer(self, optimizer): use_multi_rank_bucket_allreduce=self.zero_multi_rank_bucket_allreduce(), allgather_bucket_size=self.zero_allgather_bucket_size(), dp_process_group=self.seq_data_parallel_group, - expert_parallel_group=self.expert_parallel_group if self.has_moe_layers else None, - expert_data_parallel_group=self.expert_data_parallel_group if self.has_moe_layers else None, + expert_parallel_group=( + self.expert_parallel_group if self.has_moe_layers else None + ), + expert_data_parallel_group=( + self.expert_data_parallel_group if self.has_moe_layers else None + ), reduce_scatter=self.zero_reduce_scatter(), overlap_comm=overlap_comm, offload_optimizer_config=self.zero_offload_optimizer(), @@ -1966,16 +2429,24 @@ def _configure_zero_optimizer(self, optimizer): gradient_accumulation_dtype=gradient_accumulation_dtype, communication_data_type=self.communication_data_type, elastic_checkpoint=self.zero_elastic_checkpoint(), - check_grad_overflow=check_grad_overflow) + check_grad_overflow=check_grad_overflow, + ) elif zero_stage == ZeroStageEnum.weights: assert not self.has_moe_layers, "MoE not supported with Stage 3" if isinstance(optimizer, DummyOptim): log_dist("Creating ZeRO Offload", ranks=[0]) - zero_param_parallel_group = groups._get_zero_param_intra_parallel_group() - if self.zero_hpz_partition_size() > 1 and zero_param_parallel_group is None: + zero_param_parallel_group = ( + groups._get_zero_param_intra_parallel_group() + ) + if ( + self.zero_hpz_partition_size() > 1 + and zero_param_parallel_group is None + ): self._set_zero_group_parallelism() - zero_param_parallel_group = groups._get_zero_param_intra_parallel_group() + zero_param_parallel_group = ( + groups._get_zero_param_intra_parallel_group() + ) optimizer = DeepSpeedZeRoOffload( self.module, timers=timers, @@ -1996,22 +2467,35 @@ def _configure_zero_optimizer(self, optimizer): ) else: log_dist( - f'Creating fp16 ZeRO stage {zero_stage} optimizer,' - f' MiCS is enabled {mics_shard_size>0},' - f' Hierarchical params gather {self._config.mics_hierarchial_params_gather}', - ranks=[0]) + f"Creating fp16 ZeRO stage {zero_stage} optimizer," + f" MiCS is enabled {mics_shard_size>0}," + f" Hierarchical params gather {self._config.mics_hierarchial_params_gather}", + ranks=[0], + ) if mics_shard_size > 0: return self._return_mics_optimizer(optimizer, timers) if self.zero_allgather_sequential(): - log_dist(f"If zero_allgather_sequential is True, set prefetch_bucket_size to 1", ranks=[0]) + log_dist( + "If zero_allgather_sequential is True, set prefetch_bucket_size to 1", + ranks=[0], + ) self._config.zero_config.prefetch_bucket_size = 1 - log_dist(f'Creating 
{model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0]) + log_dist( + f"Creating {model_dtype} ZeRO stage {zero_stage} optimizer", + ranks=[0], + ) from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 - from deepspeed.runtime.superoffload.superoffload_stage3 import SuperOffloadOptimizer_Stage3 - Stage3ZeroOptimizer = DeepSpeedZeroOptimizer_Stage3 if not self.super_offload( - ) else SuperOffloadOptimizer_Stage3 + from deepspeed.runtime.superoffload.superoffload_stage3 import ( + SuperOffloadOptimizer_Stage3, + ) + + Stage3ZeroOptimizer = ( + DeepSpeedZeroOptimizer_Stage3 + if not self.super_offload() + else SuperOffloadOptimizer_Stage3 + ) optimizer = Stage3ZeroOptimizer( self.module, optimizer, @@ -2060,45 +2544,50 @@ def _configure_zero_optimizer(self, optimizer): ) else: - raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage)) + raise NotImplementedError( + "ZeRO stage {} not implemented".format(zero_stage) + ) return optimizer def _return_mics_optimizer(self, basic_optimizer, timers): from deepspeed.runtime.zero.mics import MiCS_Optimizer + model_dtype, gradient_accumulation_dtype = self.get_data_types() - optimizer = MiCS_Optimizer(self.module, - basic_optimizer, - self.param_names, - timers=timers, - ds_config=self.config, - static_loss_scale=self.loss_scale(), - dynamic_loss_scale=self.dynamic_loss_scale(), - dynamic_loss_args=self.dynamic_loss_scale_args(), - clip_grad=self.gradient_clipping(), - contiguous_gradients=self.zero_contiguous_gradients(), - reduce_bucket_size=self.zero_reduce_bucket_size(), - prefetch_bucket_size=self.zero_prefetch_bucket_size(), - max_reuse_distance=self.zero_max_reuse_distance(), - max_live_parameters=self.zero_max_live_parameters(), - param_persistence_threshold=self.zero_param_persistence_threshold(), - model_persistence_threshold=self.zero_model_persistence_threshold(), - dp_process_group=self.seq_data_parallel_group, - reduce_scatter=self.zero_reduce_scatter(), - overlap_comm=self.zero_overlap_comm(), - offload_optimizer_config=self.zero_offload_optimizer(), - offload_param_config=self.zero_offload_param(), - sub_group_size=self.zero_sub_group_size(), - mpu=self.mpu, - postscale_gradients=self.postscale_gradients(), - gradient_predivide_factor=self.gradient_predivide_factor(), - gradient_accumulation_steps=self.gradient_accumulation_steps(), - aio_config=self.aio_config(), - gradient_accumulation_dtype=gradient_accumulation_dtype, - communication_data_type=self.communication_data_type, - fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(), - bf16_master_weights_and_gradients=self.bf16_master_weights_and_gradients(), - bf16_optimizer_states=self.bf16_optimizer_states()) + optimizer = MiCS_Optimizer( + self.module, + basic_optimizer, + self.param_names, + timers=timers, + ds_config=self.config, + static_loss_scale=self.loss_scale(), + dynamic_loss_scale=self.dynamic_loss_scale(), + dynamic_loss_args=self.dynamic_loss_scale_args(), + clip_grad=self.gradient_clipping(), + contiguous_gradients=self.zero_contiguous_gradients(), + reduce_bucket_size=self.zero_reduce_bucket_size(), + prefetch_bucket_size=self.zero_prefetch_bucket_size(), + max_reuse_distance=self.zero_max_reuse_distance(), + max_live_parameters=self.zero_max_live_parameters(), + param_persistence_threshold=self.zero_param_persistence_threshold(), + model_persistence_threshold=self.zero_model_persistence_threshold(), + dp_process_group=self.seq_data_parallel_group, + reduce_scatter=self.zero_reduce_scatter(), + 
overlap_comm=self.zero_overlap_comm(), + offload_optimizer_config=self.zero_offload_optimizer(), + offload_param_config=self.zero_offload_param(), + sub_group_size=self.zero_sub_group_size(), + mpu=self.mpu, + postscale_gradients=self.postscale_gradients(), + gradient_predivide_factor=self.gradient_predivide_factor(), + gradient_accumulation_steps=self.gradient_accumulation_steps(), + aio_config=self.aio_config(), + gradient_accumulation_dtype=gradient_accumulation_dtype, + communication_data_type=self.communication_data_type, + fp16_master_weights_and_gradients=self.fp16_master_weights_and_gradients(), + bf16_master_weights_and_gradients=self.bf16_master_weights_and_gradients(), + bf16_optimizer_states=self.bf16_optimizer_states(), + ) return optimizer def _configure_eigenvalue(self): @@ -2129,7 +2618,9 @@ def is_map_style_dataset(obj): @staticmethod def is_iterable_style_dataset(obj): - return isinstance(obj, torch.utils.data.IterableDataset) # hasattr(obj, "__iter__") should work as well + return isinstance( + obj, torch.utils.data.IterableDataset + ) # hasattr(obj, "__iter__") should work as well def dataloader_drop_last(self): return self._config.dataloader_drop_last @@ -2144,15 +2635,20 @@ def was_step_applied(self) -> bool: """ return self._step_applied - def deepspeed_io(self, - dataset, - batch_size=None, - route=ROUTE_TRAIN, - pin_memory=True, - data_sampler=None, - collate_fn=None, - num_local_io_workers=None): - if not (self.is_map_style_dataset(dataset) or self.is_iterable_style_dataset(dataset)): + def deepspeed_io( + self, + dataset, + batch_size=None, + route=ROUTE_TRAIN, + pin_memory=True, + data_sampler=None, + collate_fn=None, + num_local_io_workers=None, + ): + if not ( + self.is_map_style_dataset(dataset) + or self.is_iterable_style_dataset(dataset) + ): raise ValueError("Training data must be a torch Dataset") if batch_size is None: @@ -2189,20 +2685,24 @@ def deepspeed_io(self, DATA_PARALLEL_GROUP: self.data_parallel_group, GRADIENT_ACCUMULATION_STEPS: self.gradient_accumulation_steps(), GLOBAL_RANK: self.global_rank, - DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[DATA_SAMPLING_NUM_WORKERS] + DATA_SAMPLING_NUM_WORKERS: self.data_sampling_config()[ + DATA_SAMPLING_NUM_WORKERS + ], } - return DeepSpeedDataLoader(dataset=dataset, - batch_size=batch_size, - pin_memory=pin_memory, - collate_fn=collate_fn, - local_rank=self.local_rank, - tput_timer=deepspeed_io_timer, - num_local_io_workers=num_local_io_workers, - data_sampler=data_sampler, - data_parallel_world_size=data_parallel_world_size, - data_parallel_rank=data_parallel_rank, - dataloader_drop_last=self.dataloader_drop_last(), - deepspeed_dataloader_config=deepspeed_dataloader_config) + return DeepSpeedDataLoader( + dataset=dataset, + batch_size=batch_size, + pin_memory=pin_memory, + collate_fn=collate_fn, + local_rank=self.local_rank, + tput_timer=deepspeed_io_timer, + num_local_io_workers=num_local_io_workers, + data_sampler=data_sampler, + data_parallel_world_size=data_parallel_world_size, + data_parallel_rank=data_parallel_rank, + dataloader_drop_last=self.dataloader_drop_last(), + deepspeed_dataloader_config=deepspeed_dataloader_config, + ) def train(self, mode=True): r"""""" @@ -2219,7 +2719,11 @@ def eval(self): def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None): # In pipeline evaluation, there is an option to use different micro-bs, which creates different number of # micro batches, thus the training gas, is not valid in this case. 
need to use the number of eval_micro_batches - scaling_factor = self.gradient_accumulation_steps() if eval_micro_batches is None else eval_micro_batches + scaling_factor = ( + self.gradient_accumulation_steps() + if eval_micro_batches is None + else eval_micro_batches + ) if isinstance(prescaled_loss, torch.Tensor): scaled_loss = prescaled_loss / scaling_factor elif isinstance(prescaled_loss, tuple) or isinstance(prescaled_loss, list): @@ -2232,7 +2736,9 @@ def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None): else: scaled_loss = prescaled_loss if self.warn_unscaled_loss: - logger.warning(f"DeepSpeed unable to scale loss because of type: {type(prescaled_loss)}") + logger.warning( + f"DeepSpeed unable to scale loss because of type: {type(prescaled_loss)}" + ) self.warn_unscaled_loss = False return scaled_loss @@ -2242,7 +2748,9 @@ def _create_module_forward_pre_hook(self): def _module_forward_pre_hook(module, inputs, kwargs): return self._forward_prologue(inputs, kwargs) - return self.module.register_forward_pre_hook(_module_forward_pre_hook, prepend=False, with_kwargs=True) + return self.module.register_forward_pre_hook( + _module_forward_pre_hook, prepend=False, with_kwargs=True + ) def _create_module_forward_post_hook(self): @@ -2257,15 +2765,21 @@ def _forward_prologue(self, inputs, kwargs): if not self.autotuning_profile_model_info(): see_memory_usage("Engine before forward", force=self.memory_breakdown()) - flops_profiler_active = (self.flops_profiler_enabled() - and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0) + flops_profiler_active = ( + self.flops_profiler_enabled() + and self.global_steps == self.flops_profiler_profile_step() + and self.global_rank == 0 + ) # used to check quantization happens at step 0! if self.global_steps == 0 and hasattr(self, "compression_scheduler"): self.compression_scheduler.step(step_zero_check=True) if self.quantizer: - tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage( - ) == 2 else self.optimizer.fp16_groups + tensor_to_quantize = ( + self.optimizer.bit16_groups + if self.zero_optimization_stage() == 2 + else self.optimizer.fp16_groups + ) if self.compression_scheduler.weight_quantization_enabled: self.quantizer.quantize( tensor_to_quantize, @@ -2287,9 +2801,15 @@ def _forward_prologue(self, inputs, kwargs): # TODO: The above if condition is a HACK since for PipelineEngine # it's difficult to inject argument in forward pass. 
if self.module.training and self.curriculum_enabled_legacy(): - self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1) + self.curriculum_scheduler_legacy.update_difficulty( + self.global_steps + 1 + ) if self.curriculum_params_legacy()["curriculum_type"] == "seqlen": - kwargs.update({"curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty()}) + kwargs.update( + { + "curriculum_seqlen": self.curriculum_scheduler_legacy.get_current_difficulty() + } + ) return_modified = True if self.module.training and self.random_ltd_enabled(): @@ -2321,8 +2841,11 @@ def _forward_epilogue(self): self._stop_timers(self.engine_timers.forward_timers) - flops_profiler_active = (self.flops_profiler_enabled() - and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0) + flops_profiler_active = ( + self.flops_profiler_enabled() + and self.global_steps == self.flops_profiler_profile_step() + and self.global_rank == 0 + ) if flops_profiler_active: self.flops_profiler.stop_profile() @@ -2345,10 +2868,15 @@ def forward(self, *inputs, **kwargs): if self.autotuning_profile_model_info(): ma = get_ma_status() - if self.is_deepcompile_enabled() and not self.is_deepcompile_active() and not self.is_compiled: + if ( + self.is_deepcompile_enabled() + and not self.is_deepcompile_active() + and not self.is_compiled + ): log_dist_once( "DeepCompile is enabled but engine.compile() has not been called; executing without DeepCompile until compile() runs.", - ranks=[0]) + ranks=[0], + ) if self.is_deepcompile_active() and hasattr(self, "launch_compile_passes"): # We can't have this in forward prologue as the compiler compiles hooks including the forward prologue. @@ -2360,14 +2888,18 @@ def forward(self, *inputs, **kwargs): # Register output backward hooks # preprocess_once_fn is called for preprocessing # preprocess_per_tensor_fn scales a tensor for gradient accumulation - register_output_backward_hooks(loss, - preprocess_once_fn=self._backward_prologue, - preprocess_per_tensor_fn=self._backward_prologue_per_tensor) + register_output_backward_hooks( + loss, + preprocess_once_fn=self._backward_prologue, + preprocess_per_tensor_fn=self._backward_prologue_per_tensor, + ) if self.autotuning_profile_model_info(): activation_mem = get_ma_status() - ma self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem - print_json_dist(self.autotuning_model_info, [0], path=self.autotuning_model_info_path()) + print_json_dist( + self.autotuning_model_info, [0], path=self.autotuning_model_info_path() + ) exit() return loss @@ -2383,7 +2915,7 @@ def _cast_inputs_half(self, inputs): for k, v in inputs.items(): new_inputs[k] = self._cast_inputs_half(v) return new_inputs - elif hasattr(inputs, 'half') and inputs.is_floating_point(): + elif hasattr(inputs, "half") and inputs.is_floating_point(): return inputs.half() else: return inputs @@ -2395,11 +2927,11 @@ def print_forward_breakdown(self, fwd_time): salltoall = 0.0 for gate in self.gate_modules: - #logger.info(f"Individual TopK gate time: {gate.gate_time:.2f} ms") + # logger.info(f"Individual TopK gate time: {gate.gate_time:.2f} ms") gate_time += gate.gate_time for l in self.moe_layers: - #logger.info(f"MoE layer; total: {l.time_moe:.2f} ms, first alltoall: {l.time_falltoall:.2f}, second alltoall: {l.time_salltoall:.2f}") + # logger.info(f"MoE layer; total: {l.time_moe:.2f} ms, first alltoall: {l.time_falltoall:.2f}, second alltoall: {l.time_salltoall:.2f}") moe_time += l.time_moe falltoall += l.time_falltoall salltoall 
+= l.time_salltoall
@@ -2409,7 +2941,8 @@ def print_forward_breakdown(self, fwd_time):
        # if deepspeed.comm.get_rank() == 0:
        log_dist(
            f"time (ms) | fwd: {fwd_time:.2f} (fwd_moe: {moe_time:.2f}, 1st_a2a: {falltoall:.2f}, 2nd_a2a: {salltoall:.2f}, top_k: {gate_time:.2f})",
-            ranks=[0])
+            ranks=[0],
+        )

    @instrument_w_nvtx
    def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
@@ -2419,19 +2952,27 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
            return

        # Pass (PP) gas boundary flag to optimizer (required for zero)
-        self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
+        self.optimizer.is_gradient_accumulation_boundary = (
+            self.is_gradient_accumulation_boundary()
+        )
        # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
        if self.zero_optimization_partition_gradients():
            self.optimizer.overlapping_partition_gradients_reduce_epilogue()
        # Communicate only at gradient accumulation boundaries
        elif self.is_gradient_accumulation_boundary():
-            if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states and hasattr(
-                    self.optimizer, 'reduce_gradients'):
-                self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
+            if (
+                self.zero_optimization_stage() == ZeroStageEnum.optimizer_states
+                and hasattr(self.optimizer, "reduce_gradients")
+            ):
+                self.optimizer.reduce_gradients(
+                    pipeline_parallel=self.pipeline_parallelism
+                )
            else:
                grads = None
-                self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)
+                self.buffered_allreduce_fallback(
+                    grads=grads, elements_per_buffer=bucket_size
+                )
        elif self.zenflow:
            self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
@@ -2441,15 +2982,21 @@ def _backward_prologue(self):
        # When necessary internal APIs are not available, we disable direct calls to tensor.backward()
        # and limit to engine.backward(loss) only.
        if not self._support_torch_style_backward and not self._running_engine_backward:
-            raise RuntimeError("Direct calls to tensor.backward() are not supported in this configuration. "
-                               "This occurs when either: (1) your PyTorch version lacks required internal APIs, "
-                               "or (2) using ZeRO stage 0. "
-                               "Please use engine.backward(loss) instead.")
+            raise RuntimeError(
+                "Direct calls to tensor.backward() are not supported in this configuration. "
+                "This occurs when either: (1) your PyTorch version lacks required internal APIs, "
+                "or (2) you are using ZeRO stage 0. "
+                "Please use engine.backward(loss) instead."
+ ) see_memory_usage("Engine before backward", force=self.memory_breakdown()) - assert not self.eigenvalue_enabled(), "Eigenvalue is not supported with non-scalar backward" - assert not self.amp_enabled(), "Apex AMP is not supported with non-scalar backward" + assert ( + not self.eigenvalue_enabled() + ), "Eigenvalue is not supported with non-scalar backward" + assert ( + not self.amp_enabled() + ), "Apex AMP is not supported with non-scalar backward" if self.is_deepcompile_active(): deepcompile_backward_prologue(self.is_gradient_accumulation_boundary()) @@ -2463,7 +3010,9 @@ def _backward_prologue(self): self.optimizer.zenflow_state ^= 1 if self.zero_optimization(): - self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary() + self.optimizer.is_gradient_accumulation_boundary = ( + self.is_gradient_accumulation_boundary() + ) self._start_timers(self.engine_timers.backward_inner_timers) @@ -2501,12 +3050,16 @@ def _backward_post_hook(self): if needs_scaler and not self._manual_backward_expected: # User called backward() directly without using engine.scale() or engine.backward() - error_msg = ("Loss scaling is required for this configuration, but backward() was called " - "directly without scaling the loss. Please use one of the following:" - " 1. engine.backward(loss)" - " 2. engine.scale(loss).backward()") + error_msg = ( + "Loss scaling is required for this configuration, but backward() was called " + "directly without scaling the loss. Please use one of the following:" + " 1. engine.backward(loss)" + " 2. engine.scale(loss).backward()" + ) if self.amp_enabled(): - error_msg += " Note: AMP (NVIDIA Apex) only supports engine.backward(loss)." + error_msg += ( + " Note: AMP (NVIDIA Apex) only supports engine.backward(loss)." + ) raise RuntimeError(error_msg) # Clear the flag for next backward @@ -2517,16 +3070,19 @@ def _backward_post_hook(self): @contextmanager def no_sync(self): r""" - Context manager to disable gradient reduction during backward pass. - This context manager has the following effects on other DeepSpeed features: - 1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning. - 2. It is illegal to call engine.step() within the context manager. - 3. Tracking of gradient accumulation steps is disabled. + Context manager to disable gradient reduction during backward pass. + This context manager has the following effects on other DeepSpeed features: + 1. Incompatible with ZeRO stage 2/3 which rely on reduction for gradient partitioning. + 2. It is illegal to call engine.step() within the context manager. + 3. Tracking of gradient accumulation steps is disabled. """ - assert not self.zero_optimization_partition_gradients(), \ - f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}" + assert ( + not self.zero_optimization_partition_gradients() + ), f"no_sync context manager is incompatible with gradient partitioning logic of ZeRO stage {self.zero_optimization_stage()}" - assert not self.inside_no_sync_ctxt, "no_sync context manager reentry is unsupported" + assert ( + not self.inside_no_sync_ctxt + ), "no_sync context manager reentry is unsupported" self.inside_no_sync_ctxt = True try: @@ -2563,16 +3119,20 @@ def scale(self, loss): AssertionError: If loss is not a scalar tensor with grad_fn, or if no optimizer is configured. 
""" - assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \ - "must provide optimizer during init in order to use scale" - assert maybe_loss_for_backward(loss), \ - "loss must be a scalar tensor with grad_fn. For non-scalar tensors, use tensor.backward(grad)" + assert self.optimizer is not None and not isinstance( + self.optimizer, DummyOptim + ), "must provide optimizer during init in order to use scale" + assert maybe_loss_for_backward( + loss + ), "loss must be a scalar tensor with grad_fn. For non-scalar tensors, use tensor.backward(grad)" # AMP (NVIDIA Apex) uses a context manager that wraps both scaling and backward, # so it cannot be used with manual backward calls if self.amp_enabled(): - raise RuntimeError("engine.scale() is not compatible with AMP (NVIDIA Apex). " - "When using AMP, you must call engine.backward(loss) instead of manual backward.") + raise RuntimeError( + "engine.scale() is not compatible with AMP (NVIDIA Apex). " + "When using AMP, you must call engine.backward(loss) instead of manual backward." + ) # Apply loss scaler based on optimizer type scaled_loss = loss @@ -2596,10 +3156,12 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True): scale_wrt_gas: bool, default: true whether to scale gradients and return value by gradient accumulation steps """ - assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \ - "must provide optimizer during init in order to use backward" + assert self.optimizer is not None and not isinstance( + self.optimizer, DummyOptim + ), "must provide optimizer during init in order to use backward" assert maybe_loss_for_backward( - loss), "loss must be a scalar tensor. If you need to pass output gradients, backward() of output tensors" + loss + ), "loss must be a scalar tensor. 
If you need to pass output gradients, call backward() on the output tensors directly."

        self._running_engine_backward = True
        # Store scale_wrt_gas so the hook can respect it
@@ -2612,7 +3174,9 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
            backward_kwargs["retain_graph"] = True

        # Used only for return value
-        gas_scaled_loss = loss / self.gradient_accumulation_steps() if scale_wrt_gas else loss
+        gas_scaled_loss = (
+            loss / self.gradient_accumulation_steps() if scale_wrt_gas else loss
+        )

        # TODO: handle this scaling with direct calls to loss.backward()
        if isinstance(self.optimizer, ZeROOptimizer):
@@ -2620,14 +3184,18 @@ def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
        elif self.torch_autocast_z0_gradscaler:
            loss = self.torch_autocast_z0_gradscaler.scale(loss)

-        with compiled_autograd(self._is_compiled_autograd_enabled, self._compile_kwargs):
+        with compiled_autograd(
+            self._is_compiled_autograd_enabled, self._compile_kwargs
+        ):
            if self.zero_optimization() or not self.amp_enabled():
                loss.backward(**backward_kwargs)
            elif self.amp_enabled():
                # AMP requires delaying unscale when inside gradient accumulation boundaries
                # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations
                delay_unscale = not self.is_gradient_accumulation_boundary()
-                with amp.scale_loss(loss, self.optimizer, delay_unscale=delay_unscale) as scaled_loss:
+                with amp.scale_loss(
+                    loss, self.optimizer, delay_unscale=delay_unscale
+                ) as scaled_loss:
                    scaled_loss.backward(**backward_kwargs)

        # backward_epilogue is not called in a hook when self._support_torch_style_backward is False
@@ -2687,33 +3255,49 @@ def zero_grad(self):
                param.grad = None

    def clip_fp32_gradients(self):
-        clip_grad_norm_(parameters=self.module.parameters(), max_norm=self.gradient_clipping(), mpu=self.mpu)
+        clip_grad_norm_(
+            parameters=self.module.parameters(),
+            max_norm=self.gradient_clipping(),
+            mpu=self.mpu,
+        )

    def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
        if self.gradient_clipping() > 0.0:
            if self.torch_autocast_z0_gradscaler:
                # Unscale for gradient clipping
                self.torch_autocast_z0_gradscaler.unscale_(self.optimizer)
-            if not (self.fp16_enabled() or self.bfloat16_enabled() or self.amp_enabled() or self.zero_optimization()):
+            if not (
+                self.fp16_enabled()
+                or self.bfloat16_enabled()
+                or self.amp_enabled()
+                or self.zero_optimization()
+            ):
                self.clip_fp32_gradients()
            elif self.amp_enabled():
                # AMP's recommended way of doing clipping
                # https://nvidia.github.io/apex/advanced.html#gradient-clipping
                master_params = amp.master_params(self.optimizer)
-                clip_grad_norm_(parameters=master_params, max_norm=self.gradient_clipping(), mpu=self.mpu)
+                clip_grad_norm_(
+                    parameters=master_params,
+                    max_norm=self.gradient_clipping(),
+                    mpu=self.mpu,
+                )
        if self.torch_autocast_z0_gradscaler:
            self.torch_autocast_z0_gradscaler.step(self.optimizer)
            self.torch_autocast_z0_gradscaler.update()
        else:
            self.optimizer.step()

-        if hasattr(self.optimizer, '_global_grad_norm'):
+        if hasattr(self.optimizer, "_global_grad_norm"):
            self._global_grad_norm = self.optimizer._global_grad_norm

        # Quantize the updated parameter if there is no overflow
        if self.quantizer:
-            tensor_to_quantize = self.optimizer.bit16_groups if self.zero_optimization_stage(
-            ) == 2 else self.optimizer.fp16_groups
+            tensor_to_quantize = (
+                self.optimizer.bit16_groups
+                if self.zero_optimization_stage() == 2
+                else self.optimizer.fp16_groups
+            )
            if self.compression_scheduler.weight_quantization_enabled:
                self.quantizer.quantize(
                    tensor_to_quantize,
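[Reviewer note] Taken together, the `scale()`/`backward()` hunks above pin down the engine's backward contract: a bare `tensor.backward()` is rejected whenever loss scaling is required and the loss was not routed through the engine, and Apex AMP supports only the engine-managed path. A hypothetical training-loop sketch of the two supported call patterns, assuming `model_engine` comes from `deepspeed.initialize()` (names are illustrative, not from the DeepSpeed test suite):

```python
def train_epoch(model_engine, data_loader):
    """Hypothetical loop showing the two backward patterns the assertions
    above enforce; batch preparation and loss computation are elided."""
    for batch in data_loader:
        loss = model_engine(batch)  # forward pass through the engine

        # Pattern 1: engine-managed backward. Scales the loss, runs the
        # backward pass, and fires the backward epilogue. This is the only
        # pattern supported under Apex AMP.
        model_engine.backward(loss)

        # Pattern 2 (manual, only where torch-style backward is supported):
        #     model_engine.scale(loss).backward()
        # A bare loss.backward() without engine.scale() trips the
        # loss-scaling RuntimeError raised in _backward_post_hook.

        model_engine.step()  # weight update at the gradient accumulation boundary
```

Judging from the `# Used only for return value` comment above, `engine.backward(loss)` also returns the GAS-scaled loss (`gas_scaled_loss`) when `scale_wrt_gas=True`, so callers that log the return value see the per-micro-batch contribution.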
@@ -2755,7 +3339,10 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}): if self.steps_per_print() is not None: report_progress = self.global_rank == 0 if self.global_rank else True - if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0: + if ( + report_progress + and (self.global_steps + 1) % self.steps_per_print() == 0 + ): self._report_progress(self.global_steps + 1) self.losses = None @@ -2766,20 +3353,25 @@ def step(self, lr_kwargs=None): r"""Execute the weight update step after forward and backward propagation on effective_train_batch. """ - assert not self.inside_no_sync_ctxt, \ - "It is illegal to call Engine.step() inside no_sync context manager" + assert ( + not self.inside_no_sync_ctxt + ), "It is illegal to call Engine.step() inside no_sync context manager" see_memory_usage("Engine before step", force=self.memory_breakdown()) # Check early because self.global_steps is incremented at some point here. # TODO: Delay self.global_steps increment until very end of this function. - flops_profiler_active = self.flops_profiler_enabled( - ) and self.global_steps == self.flops_profiler_profile_step() and self.global_rank == 0 + flops_profiler_active = ( + self.flops_profiler_enabled() + and self.global_steps == self.flops_profiler_profile_step() + and self.global_rank == 0 + ) self._start_timers(self.engine_timers.step_timers) - assert self.optimizer is not None and not isinstance(self.optimizer, DummyOptim), \ - "must provide optimizer during init in order to use step" + assert self.optimizer is not None and not isinstance( + self.optimizer, DummyOptim + ), "must provide optimizer during init in order to use step" report_progress = False @@ -2797,17 +3389,29 @@ def step(self, lr_kwargs=None): if self.checkpoint_engine.is_decoupled(): self._commit_decoupled_checkpoint() - if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0) - and self.quantizer.any_precision_switch()): + if ( + self.eigenvalue_enabled() + and ( + self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() + == 0 + ) + and self.quantizer.any_precision_switch() + ): log_dist("computing eigenvalue...", ranks=[0]) loss_scale = self._get_optimizer_loss_scale() or 1.0 - self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device, loss_scale) + self.block_eigenvalue = self.eigenvalue.compute_eigenvalue( + self.module, self.device, loss_scale + ) if self.progressive_layer_drop: self.progressive_layer_drop.update_state(self.global_steps) - if (self.eigenvalue_enabled() and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() - and self.quantizer.any_precision_switch()): + if ( + self.eigenvalue_enabled() + and not self.gas_boundary_ctr + % self.eigenvalue_gas_boundary_resolution() + and self.quantizer.any_precision_switch() + ): self._take_model_step(lr_kwargs, self.block_eigenvalue) else: self._take_model_step(lr_kwargs) @@ -2817,7 +3421,10 @@ def step(self, lr_kwargs=None): if self.zenflow: self._zenflow_step(lr_kwargs) - self.tput_timer.stop(global_step=self.is_gradient_accumulation_boundary(), report_speed=report_progress) + self.tput_timer.stop( + global_step=self.is_gradient_accumulation_boundary(), + report_speed=report_progress, + ) self._stop_timers(self.engine_timers.step_timers) @@ -2825,25 +3432,38 @@ def step(self, lr_kwargs=None): if self.monitor.enabled: if self.is_gradient_accumulation_boundary(): if self.global_rank == 0: - self.summary_events = [("Train/Samples/lr", 
self.get_lr()[0], self.global_samples)] - - loss_scale = self._get_optimizer_loss_scale() if self.fp16_enabled() else None + self.summary_events = [ + ("Train/Samples/lr", self.get_lr()[0], self.global_samples) + ] + + loss_scale = ( + self._get_optimizer_loss_scale() + if self.fp16_enabled() + else None + ) if loss_scale is not None: - self.summary_events.append(( - "Train/Samples/loss_scale", - loss_scale, - self.global_samples, - )) - - if (self.eigenvalue_enabled() - and not self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution()): + self.summary_events.append( + ( + "Train/Samples/loss_scale", + loss_scale, + self.global_samples, + ) + ) + + if ( + self.eigenvalue_enabled() + and not self.gas_boundary_ctr + % self.eigenvalue_gas_boundary_resolution() + ): ev_values = self.block_eigenvalue.values() for i in range(len(ev_values)): - self.summary_events.append(( - f"Train/Eigenvalues/ModelBlockParam_{i}", - self.ev_values[i][0], - self.global_samples, - )) + self.summary_events.append( + ( + f"Train/Eigenvalues/ModelBlockParam_{i}", + self.ev_values[i][0], + self.global_samples, + ) + ) self.monitor.write_events(self.summary_events) # Check flops profiling @@ -2861,7 +3481,9 @@ def step(self, lr_kwargs=None): ) self.flops_profiler.end_profile() - if self.autotuning_enabled() and self.global_steps == (self.autotuning_end_profile_step() + 1): + if self.autotuning_enabled() and self.global_steps == ( + self.autotuning_end_profile_step() + 1 + ): self._autotuning_exit() if self.wall_clock_breakdown(): @@ -2869,7 +3491,10 @@ def step(self, lr_kwargs=None): self._update_wall_clock_timers() # Log micro timing and reset - self.timers.log(names=self.engine_timers.micro_timers, memory_breakdown=self.memory_breakdown()) + self.timers.log( + names=self.engine_timers.micro_timers, + memory_breakdown=self.memory_breakdown(), + ) if self.wall_clock_breakdown() or self.flops_profiler_enabled(): # Log global timing and reset @@ -2891,9 +3516,11 @@ def _start_timers(self, timer_names): self.timers(name).start() def _stop_timers(self, timer_names): - record = self.is_gradient_accumulation_boundary() and \ - self.flops_profiler_enabled() and \ - (self.global_steps >= self.flops_profiler_profile_step()) + record = ( + self.is_gradient_accumulation_boundary() + and self.flops_profiler_enabled() + and (self.global_steps >= self.flops_profiler_profile_step()) + ) for name in timer_names: self.timers(name).stop(record=record) @@ -2904,31 +3531,37 @@ def _update_wall_clock_timers(self): def get_wall_clock_timers(self): r""" - Return a dict snapshot of the Engine's wall clock timers. + Return a dict snapshot of the Engine's wall clock timers. 
""" return self.engine_timers_cache def _autotuning_exit(self): if self.global_rank == 0: - msg = self.timers.get_mean([ - FORWARD_GLOBAL_TIMER, - BACKWARD_GLOBAL_TIMER, - STEP_GLOBAL_TIMER, - ], reset=False) + msg = self.timers.get_mean( + [ + FORWARD_GLOBAL_TIMER, + BACKWARD_GLOBAL_TIMER, + STEP_GLOBAL_TIMER, + ], + reset=False, + ) titer = 0.0 titer += msg[FORWARD_GLOBAL_TIMER] if FORWARD_GLOBAL_TIMER in msg else 0 titer += msg[BACKWARD_GLOBAL_TIMER] if BACKWARD_GLOBAL_TIMER in msg else 0 titer += msg[STEP_GLOBAL_TIMER] if STEP_GLOBAL_TIMER in msg else 0 titer *= self.gradient_accumulation_steps() msg["latency"] = titer - msg["FLOPS_per_gpu"] = self.flops * 1_000_000 * self.gradient_accumulation_steps() / titer - msg["throughput"] = self.train_batch_size() * 1_000_000 / \ - msg["latency"] + msg["FLOPS_per_gpu"] = ( + self.flops * 1_000_000 * self.gradient_accumulation_steps() / titer + ) + msg["throughput"] = self.train_batch_size() * 1_000_000 / msg["latency"] print_json_dist(msg, [0], path=self.autotuning_metric_path()) log_dist( f"Wrote metrics to {self.autotuning_metric_path()}, {os.path.abspath(self.autotuning_metric_path())}", - ranks=[0]) + ranks=[0], + ) import atexit + atexit.register(print, "Autotuning: done with running current ds config.") exit() @@ -3002,7 +3635,9 @@ def get_pld_theta(self): def _report_progress(self, step): lr = self.get_lr() mom = self.get_mom() - log_dist(f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", ranks=[0]) + log_dist( + f"step={step}, skipped={self.skipped_steps}, lr={lr}, mom={mom}", ranks=[0] + ) def allreduce_bucket(self, bucket, dp_group, dp_world_size=None): tensor = self.flatten(bucket) @@ -3021,12 +3656,17 @@ def allreduce_bucket(self, bucket, dp_group, dp_world_size=None): dist.all_reduce(tensor_to_allreduce, group=dp_group) if self.gradient_average: if self.gradient_predivide_factor() != dp_world_size: - tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dp_world_size) + tensor_to_allreduce.mul_( + self.gradient_predivide_factor() / dp_world_size + ) else: - tensor_to_allreduce.mul_(1. / dp_world_size) + tensor_to_allreduce.mul_(1.0 / dp_world_size) dist.all_reduce(tensor_to_allreduce, group=dp_group) - if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: + if ( + self.communication_data_type != tensor.dtype + and tensor is not tensor_to_allreduce + ): tensor.copy_(tensor_to_allreduce) return tensor @@ -3036,7 +3676,9 @@ def allreduce_and_copy(self, small_bucket, dp_group, dp_world_size=None): for buf, synced in zip(small_bucket, self.unflatten(allreduced, small_bucket)): buf.copy_(synced) - def allreduce_no_retain(self, bucket, dp_group, numel_per_bucket=500000000, dp_world_size=None): + def allreduce_no_retain( + self, bucket, dp_group, numel_per_bucket=500000000, dp_world_size=None + ): small_bucket = [] numel = 0 for tensor in bucket: @@ -3071,7 +3713,9 @@ def _get_gradients_for_reduction(self): # rank is reducing the same size. In some cases it may make # sense in the future to support the ability to average not # w.r.t. world size but with a different value. 
- param.grad = torch.zeros(param.size(), dtype=param.dtype, device=param.device) + param.grad = torch.zeros( + param.size(), dtype=param.dtype, device=param.device + ) grad_data = param.grad.data if param_name in self.sparse_tensor_module_names or grad_data.is_sparse: @@ -3086,25 +3730,33 @@ def _get_gradients_for_reduction(self): return non_expert_grads, expert_grads def _reduce_non_expert_gradients(self, grads, elements_per_buffer): - split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse(grads) + split_sparse_tensor_buckets, split_dense_tensor_buckets = ( + split_half_float_double_sparse(grads) + ) if self.pipeline_parallelism: dp_group = self.mpu.get_data_parallel_group() dp_world_size = dist.get_world_size(dp_group) else: dp_group = groups._get_sequence_data_parallel_group() - dp_world_size = dist.get_world_size(dp_group) / float(self.sequence_parallel_size) + dp_world_size = dist.get_world_size(dp_group) / float( + self.sequence_parallel_size + ) for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets): if sparse_bucket_tuple: bucket_type, sparse_bucket = sparse_bucket_tuple - self.sparse_allreduce_no_retain(sparse_bucket, dp_group=dp_group, dp_world_size=dp_world_size) + self.sparse_allreduce_no_retain( + sparse_bucket, dp_group=dp_group, dp_world_size=dp_world_size + ) for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets): if dense_bucket_tuple: bucket_type, dense_bucket = dense_bucket_tuple - self.allreduce_no_retain(dense_bucket, - dp_group=dp_group, - numel_per_bucket=elements_per_buffer, - dp_world_size=dp_world_size) + self.allreduce_no_retain( + dense_bucket, + dp_group=dp_group, + numel_per_bucket=elements_per_buffer, + dp_world_size=dp_world_size, + ) def _reduce_expert_gradients(self, expert_grads, elements_per_buffer): # to maintain the gradients value unaffected by ep_size setting, @@ -3112,32 +3764,41 @@ def _reduce_expert_gradients(self, expert_grads, elements_per_buffer): dp_world_size = dist.get_world_size(groups._get_data_parallel_group()) for ep_name, expert_grads_group in expert_grads.items(): ep_dp_group = groups._get_expert_data_parallel_group(ep_name) - split_sparse_tensor_buckets, split_dense_tensor_buckets = split_half_float_double_sparse( - expert_grads_group) + split_sparse_tensor_buckets, split_dense_tensor_buckets = ( + split_half_float_double_sparse(expert_grads_group) + ) for _, sparse_bucket_tuple in enumerate(split_sparse_tensor_buckets): if sparse_bucket_tuple: bucket_type, sparse_bucket = sparse_bucket_tuple - self.sparse_allreduce_no_retain(sparse_bucket, dp_group=ep_dp_group, dp_world_size=dp_world_size) + self.sparse_allreduce_no_retain( + sparse_bucket, dp_group=ep_dp_group, dp_world_size=dp_world_size + ) for _, dense_bucket_tuple in enumerate(split_dense_tensor_buckets): if dense_bucket_tuple: bucket_type, dense_bucket = dense_bucket_tuple # Separate between diff groups - self.allreduce_no_retain(dense_bucket, - dp_group=ep_dp_group, - numel_per_bucket=elements_per_buffer, - dp_world_size=dp_world_size) + self.allreduce_no_retain( + dense_bucket, + dp_group=ep_dp_group, + numel_per_bucket=elements_per_buffer, + dp_world_size=dp_world_size, + ) def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000): if grads is None: if hasattr(self.optimizer, "get_grads_for_reduction"): # This is currently for BF16 optimizer - non_expert_grads, expert_grads = self.optimizer.get_grads_for_reduction() + non_expert_grads, expert_grads = ( + 
self.optimizer.get_grads_for_reduction() + ) else: non_expert_grads, expert_grads = self._get_gradients_for_reduction() else: - assert not self.has_moe_layers, "attempting to reduce grads in unsupported way w.r.t. MoE" + assert ( + not self.has_moe_layers + ), "attempting to reduce grads in unsupported way w.r.t. MoE" non_expert_grads = grads self._reduce_non_expert_gradients(non_expert_grads, elements_per_buffer) @@ -3146,7 +3807,9 @@ def buffered_allreduce_fallback(self, grads=None, elements_per_buffer=500000000) self._reduce_expert_gradients(expert_grads, elements_per_buffer) def sparse_allreduce_no_retain(self, bucket, dp_group, dp_world_size=None): - allreduced_sparses = self.sparse_allreduce_bucket(bucket, dp_group, dp_world_size) + allreduced_sparses = self.sparse_allreduce_bucket( + bucket, dp_group, dp_world_size + ) # Densify sparse tensor and copy back to original location for tensor in allreduced_sparses: if tensor.is_sparse: @@ -3178,7 +3841,7 @@ def sparse_allreduce(self, sparse, dp_group, dp_world_size=None): if self.gradient_average: values.mul_(self.gradient_predivide_factor() / (dp_world_size)) else: - values.mul_(1. / (dp_world_size)) + values.mul_(1.0 / (dp_world_size)) indices_device_list = self.sparse_all_gather(indices, dp_group) values_device_list = self.sparse_all_gather(values, dp_group) @@ -3197,30 +3860,48 @@ def sparse_all_gather(self, value, dp_group): if value.dim() == 1: if fill_size > 0: value = torch.cat([value, value.new_empty(fill_size)]) - tensor_list = [value.new_empty(max_size) for _ in range(dist.get_world_size(group=dp_group))] + tensor_list = [ + value.new_empty(max_size) + for _ in range(dist.get_world_size(group=dp_group)) + ] else: if fill_size > 0: value = torch.cat([value, value.new_empty(fill_size, value.size()[1])]) tensor_list = [ - value.new_empty(max_size, - value.size()[1]) for _ in range(dist.get_world_size(group=dp_group)) + value.new_empty(max_size, value.size()[1]) + for _ in range(dist.get_world_size(group=dp_group)) ] dist.all_gather(tensor_list, value, group=dp_group) tensors = [] for dev_idx, t in enumerate(tensor_list): size = all_sizes[dev_idx][0] - tensors.append(t.index_select(0, torch.arange(size, dtype=torch.long, device=self.device))) + tensors.append( + t.index_select( + 0, torch.arange(size, dtype=torch.long, device=self.device) + ) + ) return tensors def all_gather_scalar(self, value, dp_group): - tensor_list = [value.new_zeros(value.size()) for _ in range(dist.get_world_size(group=dp_group))] + tensor_list = [ + value.new_zeros(value.size()) + for _ in range(dist.get_world_size(group=dp_group)) + ] dist.all_gather(tensor_list, value, group=dp_group) return tensor_list - def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclude_frozen_parameters=False): - sd = self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) + def module_state_dict( + self, + destination=None, + prefix="", + keep_vars=False, + exclude_frozen_parameters=False, + ): + sd = self.module.state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) # Remove frozen parameter weights from state_dict if specified if exclude_frozen_parameters: @@ -3233,19 +3914,26 @@ def module_state_dict(self, destination=None, prefix="", keep_vars=False, exclud return sd @staticmethod - def load_moe_state_dict(checkpoint_path, - tag, - state_dict, - old_moe_load, - model=None, - mpu=None, - num_experts=1, - checkpoint_engine=TorchCheckpointEngine()): + def load_moe_state_dict( + checkpoint_path, + tag, 
+ state_dict, + old_moe_load, + model=None, + mpu=None, + num_experts=1, + checkpoint_engine=TorchCheckpointEngine(), + ): if old_moe_load: - expp_rank = groups._get_expert_data_parallel_rank(groups._get_max_expert_size_name()) + expp_rank = groups._get_expert_data_parallel_rank( + groups._get_max_expert_size_name() + ) - num_local_experts = max(num_experts) // groups._get_expert_parallel_world_size( - groups._get_max_expert_size_name()) + num_local_experts = max( + num_experts + ) // groups._get_expert_parallel_world_size( + groups._get_max_expert_size_name() + ) for local_expert_id in range(num_local_experts): global_expert_id = expp_rank * num_local_experts + local_expert_id expert_state_dict = checkpoint_engine.load( @@ -3254,14 +3942,18 @@ def load_moe_state_dict(checkpoint_path, -1, # -1 means ignore layer_id global_expert_id, tag, - mpu), - map_location=torch.device('cpu')) + mpu, + ), + map_location=torch.device("cpu"), + ) # Updating global -> local expert ids - moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' + moe_str_prefix = ".deepspeed_moe.experts.deepspeed_experts." for key in list(expert_state_dict.keys()): - local_key = key.replace(f'{moe_str_prefix}{global_expert_id}', - f'{moe_str_prefix}{local_expert_id}') + local_key = key.replace( + f"{moe_str_prefix}{global_expert_id}", + f"{moe_str_prefix}{local_expert_id}", + ) expert_state_dict[local_key] = expert_state_dict.pop(key) state_dict.update(expert_state_dict) @@ -3274,37 +3966,49 @@ def load_moe_state_dict(checkpoint_path, expp_rank = groups._get_expert_parallel_rank(group_name) # loop all local_experts for local_expert_id in range(num_local_experts): - global_expert_id = expp_rank * num_local_experts + local_expert_id - expert_state_dict = checkpoint_engine.load(DeepSpeedEngine._get_expert_ckpt_name( - checkpoint_path, moe_layer_id, global_expert_id, tag, mpu), - map_location=torch.device('cpu')) + global_expert_id = ( + expp_rank * num_local_experts + local_expert_id + ) + expert_state_dict = checkpoint_engine.load( + DeepSpeedEngine._get_expert_ckpt_name( + checkpoint_path, + moe_layer_id, + global_expert_id, + tag, + mpu, + ), + map_location=torch.device("cpu"), + ) # print(expert_state_dict.keys()) # Updating global -> local expert ids - moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' + moe_str_prefix = ".deepspeed_moe.experts.deepspeed_experts." 
for key in list(expert_state_dict.keys()): - local_key = key.replace(f'{moe_str_prefix}{global_expert_id}', - f'{moe_str_prefix}{local_expert_id}') + local_key = key.replace( + f"{moe_str_prefix}{global_expert_id}", + f"{moe_str_prefix}{local_expert_id}", + ) expert_state_dict[local_key] = expert_state_dict.pop(key) state_dict.update(expert_state_dict) moe_layer_id += 1 - def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False): + def load_module_state_dict( + self, checkpoint, strict=True, custom_load_fn=None, fetch_z3_params=False + ): if fetch_z3_params: params_to_fetch = [ - p for p in self.module.parameters() - if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE + p + for p in self.module.parameters() + if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE ] else: params_to_fetch = [] with deepspeed.zero.GatheredParameters(params_to_fetch, modifier_rank=0): - module_state_dict = checkpoint['module'] + module_state_dict = checkpoint["module"] if custom_load_fn: custom_load_fn(src=module_state_dict, dst=self.module) else: - self.module.load_state_dict( - module_state_dict, # TODO - strict=strict) + self.module.load_state_dict(module_state_dict, strict=strict) # TODO if checkpoint.get(FROZEN_PARAM_FRAGMENTS, None) is not None: saved_frozen_params = checkpoint[FROZEN_PARAM_FRAGMENTS] @@ -3314,7 +4018,7 @@ def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, f if param not in self.param_names: raise ValueError(f"failed to find frozen {param} in named params") name = self.param_names[param] - if hasattr(param, 'ds_id'): + if hasattr(param, "ds_id"): param.ds_tensor.data.copy_(saved_frozen_params[name].data) else: param.data.copy_(saved_frozen_params[name].data) @@ -3322,7 +4026,9 @@ def load_module_state_dict(self, checkpoint, strict=True, custom_load_fn=None, f def _get_zero_ckpt_prefix(self, dp_rank, bf16_mode): return f'{"bf16_" if bf16_mode else ""}zero_pp_rank_{dp_rank}' - def _get_rank_zero_ckpt_name(self, checkpoints_path, tag, mp_rank, dp_rank, bf16_mode): + def _get_rank_zero_ckpt_name( + self, checkpoints_path, tag, mp_rank, dp_rank, bf16_mode + ): file_prefix = self._get_zero_ckpt_prefix(dp_rank, bf16_mode=bf16_mode) zero_ckpt_name = os.path.join( checkpoints_path, @@ -3335,9 +4041,13 @@ def _get_zero_ckpt_name(self, checkpoints_path, tag): mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() pp_rank = dist.get_rank(group=self.optimizer.dp_process_group) bf16_mode = self.bfloat16_enabled() - return self._get_rank_zero_ckpt_name(checkpoints_path, tag, mp_rank, pp_rank, bf16_mode) + return self._get_rank_zero_ckpt_name( + checkpoints_path, tag, mp_rank, pp_rank, bf16_mode + ) - def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None, pp_placeholder=None): + def _get_ckpt_name( + self, checkpoints_path, tag, mp_placeholder=None, pp_placeholder=None + ): if mp_placeholder is not None: mp_rank_str = mp_placeholder else: @@ -3366,8 +4076,11 @@ def _get_ckpt_name(self, checkpoints_path, tag, mp_placeholder=None, pp_placehol def _get_optimizer_ckpt_name(self, checkpoints_path, tag, expp_rank): mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() - ckpt_name = os.path.join(checkpoints_path, str(tag), - f'expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt') + ckpt_name = os.path.join( + checkpoints_path, + str(tag), + f"expp_rank_{expp_rank}_mp_rank_{mp_rank:02d}_optim_states.pt", + ) return ckpt_name @staticmethod @@ 
-3375,34 +4088,44 @@ def _get_expert_ckpt_name(checkpoints_path, layer_id, expert_id, tag, mpu=None): mp_rank = 0 if mpu is None else mpu.get_model_parallel_rank() if layer_id <= -1: # Used to support old checkpoint loading - ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag), - f'expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt') + ckpt_name = os.path.join( + checkpoints_path, + "" if tag is None else str(tag), + f"expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt", + ) else: # Used to support new checkpoint loading - ckpt_name = os.path.join(checkpoints_path, '' if tag is None else str(tag), - f'layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt') + ckpt_name = os.path.join( + checkpoints_path, + "" if tag is None else str(tag), + f"layer_{layer_id}_expert_{expert_id}_mp_rank_{mp_rank:02d}_model_states.pt", + ) return ckpt_name def _get_all_ckpt_names(self, checkpoints_path, tag): # It is required that (checkpoints_path, tag) are consistent among all ranks. - ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, - tag, - mp_placeholder="*", - pp_placeholder="0" if self.load_universal_checkpoint() else None) + ckpt_file_pattern = self._get_ckpt_name( + checkpoints_path, + tag, + mp_placeholder="*", + pp_placeholder="0" if self.load_universal_checkpoint() else None, + ) import glob ckpt_files = glob.glob(ckpt_file_pattern) ckpt_files.sort() return ckpt_files - def load_checkpoint(self, - load_dir, - tag=None, - load_module_strict=True, - load_optimizer_states=True, - load_lr_scheduler_states=True, - load_module_only=False, - custom_load_fn=None): + def load_checkpoint( + self, + load_dir, + tag=None, + load_module_strict=True, + load_optimizer_states=True, + load_lr_scheduler_states=True, + load_module_only=False, + custom_load_fn=None, + ): """ Load training checkpoint @@ -3428,14 +4151,18 @@ def load_checkpoint(self, """ if tag is None: - latest_tag = "latest_universal" if self.load_universal_checkpoint() else "latest" + latest_tag = ( + "latest_universal" if self.load_universal_checkpoint() else "latest" + ) latest_path = os.path.join(load_dir, latest_tag) if os.path.isfile(latest_path): with open(latest_path, "r") as fd: tag = fd.read().strip() else: if self.load_universal_checkpoint(): - raise ValueError(f'Invalid for universal checkpoint: {latest_path} does not exist') + raise ValueError( + f"Invalid for universal checkpoint: {latest_path} does not exist" + ) else: logger.warning( f"Unable to find latest file at {latest_path}, if trying to load latest " @@ -3447,18 +4174,24 @@ def load_checkpoint(self, # Prepare for checkpoint load by ensuring all parameters are partitioned self.optimizer.checkpoint_event_prologue() - load_path, client_states = self._load_checkpoint(load_dir, - tag, - load_module_strict=load_module_strict, - load_optimizer_states=load_optimizer_states, - load_lr_scheduler_states=load_lr_scheduler_states, - load_module_only=load_module_only, - custom_load_fn=custom_load_fn) + load_path, client_states = self._load_checkpoint( + load_dir, + tag, + load_module_strict=load_module_strict, + load_optimizer_states=load_optimizer_states, + load_lr_scheduler_states=load_lr_scheduler_states, + load_module_only=load_module_only, + custom_load_fn=custom_load_fn, + ) load_zero_checkpoint = load_path is not None and self.zero_optimization() if load_zero_checkpoint and not self.zero_nvme_offload_optimizer(): - if (load_optimizer_states and not load_module_only) or self.load_universal_checkpoint(): - success = 
self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states) + if ( + load_optimizer_states and not load_module_only + ) or self.load_universal_checkpoint(): + success = self._load_zero_checkpoint( + load_dir, tag, load_optimizer_states=load_optimizer_states + ) else: success = False if not success: @@ -3466,52 +4199,68 @@ def load_checkpoint(self, if self.zero_nvme_offload_optimizer(): from shutil import copytree, disk_usage - rank = self.local_rank if self.use_node_local_storage() else self.global_rank + + rank = ( + self.local_rank if self.use_node_local_storage() else self.global_rank + ) rank_dir = "rank" + dp_index_to_str(rank) offload_dir = self.optimizer.optimizer_swapper.swap_folder - offload_ckpt_dir = os.path.join(load_dir, tag, "offloaded_tensors", rank_dir) + offload_ckpt_dir = os.path.join( + load_dir, tag, "offloaded_tensors", rank_dir + ) _, _, free = disk_usage(offload_dir) logger.info( f"Copying NVMe offload checkpoint from {offload_ckpt_dir} to {offload_dir}, {free / 1e9:,.2f} GB free on target filesystem..." ) copytree(offload_ckpt_dir, offload_dir, dirs_exist_ok=True) _, _, free = disk_usage(offload_dir) - logger.info(f"Copying complete! {free / 1e9:,.2f} GB free on target filesystem") + logger.info( + f"Copying complete! {free / 1e9:,.2f} GB free on target filesystem" + ) self.optimizer.reset_swap_buffers() if self._optimizer_has_ckpt_event_epilogue(): self.optimizer.checkpoint_event_epilogue() - if self.load_universal_checkpoint() and not self.zero_optimization_partition_weights(): + if ( + self.load_universal_checkpoint() + and not self.zero_optimization_partition_weights() + ): self.optimizer.update_lp_params() return load_path, client_states - def _load_checkpoint(self, - load_dir, - tag, - load_module_strict=True, - load_optimizer_states=True, - load_lr_scheduler_states=True, - load_module_only=False, - custom_load_fn=None): + def _load_checkpoint( + self, + load_dir, + tag, + load_module_strict=True, + load_optimizer_states=True, + load_lr_scheduler_states=True, + load_module_only=False, + custom_load_fn=None, + ): from deepspeed.runtime.state_dict_factory import SDLoaderFactory ckpt_list = self._get_all_ckpt_names(load_dir, tag) - sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, checkpoint_engine=self.checkpoint_engine) + sd_loader = SDLoaderFactory.get_sd_loader( + ckpt_list, checkpoint_engine=self.checkpoint_engine + ) is_pipe_parallel = isinstance(self.module, PipelineModule) mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() - load_path, checkpoint, _ = sd_loader.load(self.mp_world_size, mp_rank, is_pipe_parallel=is_pipe_parallel) + load_path, checkpoint, _ = sd_loader.load( + self.mp_world_size, mp_rank, is_pipe_parallel=is_pipe_parallel + ) if checkpoint is None: return None, None fetch_z3_params = False if self.zero_optimization_partition_weights() and not load_optimizer_states: - checkpoint['module'] = get_fp32_state_dict_from_zero_checkpoint(load_dir) + checkpoint["module"] = get_fp32_state_dict_from_zero_checkpoint(load_dir) fetch_z3_params = True if is_pipe_parallel: @@ -3521,59 +4270,86 @@ def _load_checkpoint(self, if self.has_moe_layers: # print(checkpoint.keys()) old_moe_load = False - if not isinstance(checkpoint['num_experts'], list): + if not isinstance(checkpoint["num_experts"], list): old_moe_load = True - DeepSpeedEngine.load_moe_state_dict(load_dir, - tag, - state_dict=checkpoint['module'], - old_moe_load=old_moe_load, - model=self.module, - mpu=self.mpu, - num_experts=self.num_experts, - 
checkpoint_engine=self.checkpoint_engine) + DeepSpeedEngine.load_moe_state_dict( + load_dir, + tag, + state_dict=checkpoint["module"], + old_moe_load=old_moe_load, + model=self.module, + mpu=self.mpu, + num_experts=self.num_experts, + checkpoint_engine=self.checkpoint_engine, + ) if not self.load_universal_checkpoint(): - self.load_module_state_dict(checkpoint=checkpoint, - strict=load_module_strict, - custom_load_fn=custom_load_fn, - fetch_z3_params=fetch_z3_params) + self.load_module_state_dict( + checkpoint=checkpoint, + strict=load_module_strict, + custom_load_fn=custom_load_fn, + fetch_z3_params=fetch_z3_params, + ) - self.loaded_checkpoint_dp_world_size = checkpoint['dp_world_size'] + self.loaded_checkpoint_dp_world_size = checkpoint["dp_world_size"] optim_checkpoint = None if load_module_only: - deepspeed_states = ['module'] - if self.optimizer is not None and hasattr(self.optimizer, 'refresh_fp32_params'): + deepspeed_states = ["module"] + if self.optimizer is not None and hasattr( + self.optimizer, "refresh_fp32_params" + ): self.optimizer.refresh_fp32_params() else: has_zero_optimizer_state = self.zero_optimization() - if load_optimizer_states and self.optimizer is not None and not has_zero_optimizer_state: + if ( + load_optimizer_states + and self.optimizer is not None + and not has_zero_optimizer_state + ): if self.has_moe_layers: largest_group_name = groups._get_max_expert_size_name() expp_rank = groups._get_expert_parallel_rank(largest_group_name) - optim_load_path = self._get_optimizer_ckpt_name(load_dir, tag, expp_rank) - optim_checkpoint = self.checkpoint_engine.load(optim_load_path, map_location=torch.device('cpu')) + optim_load_path = self._get_optimizer_ckpt_name( + load_dir, tag, expp_rank + ) + optim_checkpoint = self.checkpoint_engine.load( + optim_load_path, map_location=torch.device("cpu") + ) else: optim_checkpoint = checkpoint if self.fp16_enabled() or self.bfloat16_enabled(): - self.optimizer.load_state_dict(optim_checkpoint['optimizer'], - load_optimizer_states=load_optimizer_states) + self.optimizer.load_state_dict( + optim_checkpoint["optimizer"], + load_optimizer_states=load_optimizer_states, + ) else: optim_checkpoint = checkpoint - self.optimizer.load_state_dict(optim_checkpoint['optimizer']) + self.optimizer.load_state_dict(optim_checkpoint["optimizer"]) if load_lr_scheduler_states and self.lr_scheduler is not None: - self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - - if self.random_ltd_enabled() and self.random_ltd_scheduler is not None and 'random_ltd' in checkpoint: - self.random_ltd_scheduler.load_state_dict(checkpoint['random_ltd']) - - if self.training_dataloader is not None and self.curriculum_learning_enabled( - ) and 'data_sampler' in checkpoint: - self.training_dataloader.data_sampler.load_state_dict(checkpoint['data_sampler']) + self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) + + if ( + self.random_ltd_enabled() + and self.random_ltd_scheduler is not None + and "random_ltd" in checkpoint + ): + self.random_ltd_scheduler.load_state_dict(checkpoint["random_ltd"]) + + if ( + self.training_dataloader is not None + and self.curriculum_learning_enabled() + and "data_sampler" in checkpoint + ): + self.training_dataloader.data_sampler.load_state_dict( + checkpoint["data_sampler"] + ) - def get_sparse_tensor_module_names(original_set, loaded_set, original_parameters, loaded_parameters): + def get_sparse_tensor_module_names( + original_set, loaded_set, original_parameters, loaded_parameters + ): result = set() for name in 
original_set: @@ -3583,14 +4359,16 @@ def get_sparse_tensor_module_names(original_set, loaded_set, original_parameters for name in loaded_set: if name in original_parameters: - result.add(name) # parameter exists in both configs and it was sparse + result.add( + name + ) # parameter exists in both configs and it was sparse return result - if 'sparse_tensor_module_names' in checkpoint: - sparse_tensor_module_names = checkpoint['sparse_tensor_module_names'] - elif 'csr_tensor_module_names' in checkpoint: - sparse_tensor_module_names = checkpoint['csr_tensor_module_names'] + if "sparse_tensor_module_names" in checkpoint: + sparse_tensor_module_names = checkpoint["sparse_tensor_module_names"] + elif "csr_tensor_module_names" in checkpoint: + sparse_tensor_module_names = checkpoint["csr_tensor_module_names"] else: sparse_tensor_module_names = None if sparse_tensor_module_names is not None: @@ -3598,28 +4376,43 @@ def get_sparse_tensor_module_names(original_set, loaded_set, original_parameters self.sparse_tensor_module_names = sparse_tensor_module_names else: self.sparse_tensor_module_names = get_sparse_tensor_module_names( - self.sparse_tensor_module_names, sparse_tensor_module_names, - dict(self.module.named_parameters()), checkpoint["module"]) + self.sparse_tensor_module_names, + sparse_tensor_module_names, + dict(self.module.named_parameters()), + checkpoint["module"], + ) - self.global_steps = checkpoint['global_steps'] - self.global_samples = checkpoint.get('global_samples', self.global_steps * self.train_batch_size()) - self.skipped_steps = checkpoint['skipped_steps'] - self.loaded_checkpoint_mp_world_size = checkpoint['mp_world_size'] + self.global_steps = checkpoint["global_steps"] + self.global_samples = checkpoint.get( + "global_samples", self.global_steps * self.train_batch_size() + ) + self.skipped_steps = checkpoint["skipped_steps"] + self.loaded_checkpoint_mp_world_size = checkpoint["mp_world_size"] deepspeed_states = [ - 'module', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'dp_world_size', - 'mp_world_size', 'data_sampler', 'random_ltd' + "module", + "sparse_tensor_module_names", + "skipped_steps", + "global_steps", + "dp_world_size", + "mp_world_size", + "data_sampler", + "random_ltd", ] client_state = {} if load_lr_scheduler_states: - deepspeed_states.append('lr_scheduler') + deepspeed_states.append("lr_scheduler") if load_optimizer_states: - deepspeed_states.append('optimizer') + deepspeed_states.append("optimizer") - client_state = {key: value for key, value in checkpoint.items() if key not in deepspeed_states} + client_state = { + key: value + for key, value in checkpoint.items() + if key not in deepspeed_states + } if optim_checkpoint is not None: - client_state['optimizer'] = optim_checkpoint['optimizer'] + client_state["optimizer"] = optim_checkpoint["optimizer"] return load_path, client_state @@ -3629,64 +4422,84 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True): # When use loading checkpoint serial, checkpoint loading start from local rank 0, # all other local rank would be paused, waiting for its rank-1 peer ready and its notification. 
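
# A minimal single-process sketch of the serialized loading described above
# (assumption: the matching dist.send fires on each rank once its own load
# finishes; that side is not shown in this hunk):
def serial_load(local_world_size):
    finished = []
    for local_rank in range(local_world_size):
        if local_rank != 0:
            # stands in for: dist.recv(tensor=load_serial, src=rank - 1)
            assert finished[-1] == local_rank - 1
        finished.append(local_rank)  # the actual checkpoint load happens here
        # stands in for: dist.send(tensor=load_serial, dst=rank + 1)
    return finished

assert serial_load(4) == [0, 1, 2, 3]
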
if self._config.zero_config.pipeline_loading_checkpoint: - assert self.zero_optimization_stage( - ) == ZeroStageEnum.weights, "Only stage3 support for pipeline checkpoint loading" + assert ( + self.zero_optimization_stage() == ZeroStageEnum.weights + ), "Only stage3 support for pipeline checkpoint loading" load_serial = torch.zeros(1).to(self.device) if dist.get_local_rank() != 0: dist.recv(tensor=load_serial, src=dist.get_rank() - 1) if self.load_universal_checkpoint(): zero_sd_list = None - checkpoint_folder = f'{os.path.join(load_dir, tag)}' + checkpoint_folder = f"{os.path.join(load_dir, tag)}" else: - if load_optimizer_states and self.seq_dp_world_size != self.loaded_checkpoint_dp_world_size: - raise ZeRORuntimeException("The checkpoint being loaded used a DP " \ - f"world size of {self.loaded_checkpoint_dp_world_size} but the " \ - f"current world size is {self.seq_dp_world_size}. Automatic adjustment " \ - "of ZeRO's optimizer state partitioning with a new world size is not " \ - "currently supported.") + if ( + load_optimizer_states + and self.seq_dp_world_size != self.loaded_checkpoint_dp_world_size + ): + raise ZeRORuntimeException( + "The checkpoint being loaded used a DP " + f"world size of {self.loaded_checkpoint_dp_world_size} but the " + f"current world size is {self.seq_dp_world_size}. Automatic adjustment " + "of ZeRO's optimizer state partitioning with a new world size is not " + "currently supported." + ) checkpoint_folder = None zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag) if zero_sd_list is None: return False param_shapes = self._get_zero_param_shapes() - self.optimizer.load_state_dict(state_dict_list=zero_sd_list, - load_optimizer_states=load_optimizer_states, - load_from_fp32_weights=self.zero_load_from_fp32_weights(), - checkpoint_folder=checkpoint_folder, - load_serial=load_serial, - param_shapes=param_shapes) + self.optimizer.load_state_dict( + state_dict_list=zero_sd_list, + load_optimizer_states=load_optimizer_states, + load_from_fp32_weights=self.zero_load_from_fp32_weights(), + checkpoint_folder=checkpoint_folder, + load_serial=load_serial, + param_shapes=param_shapes, + ) if self.load_universal_checkpoint(): - logger.info(f'loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}') + logger.info( + f"loaded universal zero checkpoints from {checkpoint_folder} for rank {self.global_rank}" + ) else: - logger.info(f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}") + logger.info( + f"loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}" + ) return True - def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size, bf16_mode): + def _get_mp_rank_zero_checkpoint_names( + self, load_dir, tag, mp_rank, dp_world_size, bf16_mode + ): zero_ckpt_names = [] for dp_rank in range(dp_world_size): - ckpt_name = self._get_rank_zero_ckpt_name(checkpoints_path=load_dir, - tag=tag, - mp_rank=mp_rank, - dp_rank=dp_rank, - bf16_mode=bf16_mode) + ckpt_name = self._get_rank_zero_ckpt_name( + checkpoints_path=load_dir, + tag=tag, + mp_rank=mp_rank, + dp_rank=dp_rank, + bf16_mode=bf16_mode, + ) zero_ckpt_names.append(ckpt_name) return zero_ckpt_names def _get_all_zero_checkpoint_names(self, load_dir, tag, bf16_mode): mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank() - zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names(load_dir=load_dir, - tag=tag, - mp_rank=mp_rank, - dp_world_size=self.loaded_checkpoint_dp_world_size, - 
bf16_mode=bf16_mode) + zero_ckpt_names = self._get_mp_rank_zero_checkpoint_names( + load_dir=load_dir, + tag=tag, + mp_rank=mp_rank, + dp_world_size=self.loaded_checkpoint_dp_world_size, + bf16_mode=bf16_mode, + ) for i, ckpt_name in enumerate(zero_ckpt_names): if not os.path.exists(ckpt_name): # transparently handle the old file pattern for optim_states if "optim_states.pt" in ckpt_name: - ckpt_name_try = ckpt_name.replace("_optim_states.pt", "optim_states.pt") + ckpt_name_try = ckpt_name.replace( + "_optim_states.pt", "optim_states.pt" + ) if os.path.exists(ckpt_name_try): zero_ckpt_names[i] = ckpt_name_try continue @@ -3700,28 +4513,37 @@ def _get_all_zero_checkpoint_state_dicts(self, zero_ckpt_names): if ckpt_name is None: _state = {OPTIMIZER_STATE_DICT: None} # Fully load state for current rank - elif self.zero_elastic_checkpoint() or dist.get_rank(group=self.optimizer.dp_process_group) == i: + elif ( + self.zero_elastic_checkpoint() + or dist.get_rank(group=self.optimizer.dp_process_group) == i + ): _state = self.checkpoint_engine.load( ckpt_name, - map_location='cpu', + map_location="cpu", ) else: _state = {OPTIMIZER_STATE_DICT: None} zero_sd_list.append(_state) zero_optimizer_sd = [sd[OPTIMIZER_STATE_DICT] for sd in zero_sd_list] - logger.info(f"successfully read {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}") + logger.info( + f"successfully read {len(zero_optimizer_sd)} ZeRO state_dicts for rank {self.global_rank}" + ) return zero_optimizer_sd def _get_all_zero_checkpoints(self, load_dir, tag): for bf16_mode in [self.bfloat16_enabled(), not self.bfloat16_enabled()]: - zero_ckpt_names = self._get_all_zero_checkpoint_names(load_dir, tag, bf16_mode) + zero_ckpt_names = self._get_all_zero_checkpoint_names( + load_dir, tag, bf16_mode + ) if zero_ckpt_names is not None: # Warn if loading checkpoint of different bit16 type if bf16_mode is not self.bfloat16_enabled(): checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16 engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16 - logger.warning(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine') + logger.warning( + f"Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine" + ) return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names) return None @@ -3735,15 +4557,24 @@ def _checkpoint_tag_validation(self, tag): dist.all_reduce(max_bhash, op=dist.ReduceOp.MAX) dist.all_reduce(min_bhash, op=dist.ReduceOp.MIN) valid = all(min_bhash == bhash) and all(max_bhash == bhash) - msg = (f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across " - "all ranks. Including rank unique information in checkpoint tag could cause issues when " - "restoring with different world sizes.") + msg = ( + f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across " + "all ranks. Including rank unique information in checkpoint tag could cause issues when " + "restoring with different world sizes." 
+ ) if self.checkpoint_tag_validation_fail(): assert valid, msg elif not valid: logger.warning(msg) - def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, exclude_frozen_parameters=False): + def save_checkpoint( + self, + save_dir, + tag=None, + client_state={}, + save_latest=True, + exclude_frozen_parameters=False, + ): """Save training checkpoint Arguments: @@ -3778,7 +4609,9 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, # Ensure tag is a string tag = str(tag) - commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=save_latest) + commit_info = CheckpointCommitInfo( + tag=tag, save_dir=save_dir, save_latest=save_latest + ) self.checkpoint_engine.create(commit_info) @@ -3788,10 +4621,12 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, if self.has_moe_layers: self.save_non_zero_checkpoint = False self._create_checkpoint_file(save_dir, tag, False) - self._save_moe_checkpoint(save_dir, - tag, - client_state=client_state, - exclude_frozen_parameters=exclude_frozen_parameters) + self._save_moe_checkpoint( + save_dir, + tag, + client_state=client_state, + exclude_frozen_parameters=exclude_frozen_parameters, + ) # We distribute the task of saving layer checkpoint files among # data parallel instances, so all procs should call _save_checkpoint. @@ -3799,10 +4634,12 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, # parallel rank 0 save the general model params. if not self.has_moe_layers: self._create_checkpoint_file(save_dir, tag, False) - self._save_checkpoint(save_dir, - tag, - client_state=client_state, - exclude_frozen_parameters=exclude_frozen_parameters) + self._save_checkpoint( + save_dir, + tag, + client_state=client_state, + exclude_frozen_parameters=exclude_frozen_parameters, + ) if self.save_zero_checkpoint: self._create_zero_checkpoint_files(save_dir, tag) @@ -3810,29 +4647,40 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, if self.zero_nvme_offload_optimizer(): from shutil import copytree, disk_usage + rank_dir = "rank" + dp_index_to_str(rank) offload_dir = self.optimizer.optimizer_swapper.swap_folder - offload_ckpt_dir = os.path.join(save_dir, tag, "offloaded_tensors", rank_dir) + offload_ckpt_dir = os.path.join( + save_dir, tag, "offloaded_tensors", rank_dir + ) _, _, free = disk_usage(save_dir) logger.info( f"Copying NVMe offload files from {offload_dir} to {offload_ckpt_dir}, {free / 1e9:,.2f} GB free on target filesystem..." ) - copytree(offload_dir, - offload_ckpt_dir, - ignore=lambda _, dir_list: list(filter(lambda x: 'gradient' in x, dir_list)), - dirs_exist_ok=False) + copytree( + offload_dir, + offload_ckpt_dir, + ignore=lambda _, dir_list: list( + filter(lambda x: "gradient" in x, dir_list) + ), + dirs_exist_ok=False, + ) _, _, free = disk_usage(save_dir) - logger.info(f"Copying complete! {free / 1e9:,.2f} GB free on target filesystem") + logger.info( + f"Copying complete! 
{free / 1e9:,.2f} GB free on target filesystem" + ) if self._optimizer_has_ckpt_event_epilogue(): self.optimizer.checkpoint_event_epilogue() # Save latest checkpoint tag if not self.checkpoint_engine.is_decoupled(): - commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=save_latest) + commit_info = CheckpointCommitInfo( + tag=tag, save_dir=save_dir, save_latest=save_latest + ) self.checkpoint_engine.commit(commit_info) if save_latest and self.global_rank == 0: - with open(os.path.join(save_dir, 'latest'), 'w') as fd: + with open(os.path.join(save_dir, "latest"), "w") as fd: fd.write(tag) dist.barrier() @@ -3840,8 +4688,9 @@ def save_checkpoint(self, save_dir, tag=None, client_state={}, save_latest=True, return True def _commit_decoupled_checkpoint(self): - assert self.checkpoint_engine.is_decoupled(), \ - f'{self.checkpoint_engine} is not a Decoupled Checkpoint Engine' + assert ( + self.checkpoint_engine.is_decoupled() + ), f"{self.checkpoint_engine} is not a Decoupled Checkpoint Engine" commit_info = self.checkpoint_engine.get_commit_info() if commit_info is None: @@ -3850,22 +4699,24 @@ def _commit_decoupled_checkpoint(self): self.checkpoint_engine.commit(commit_info) if self.global_rank == 0 and commit_info.save_latest: - with open(os.path.join(commit_info.save_dir, 'latest'), 'w') as fd: + with open(os.path.join(commit_info.save_dir, "latest"), "w") as fd: fd.write(commit_info.tag) dist.barrier() def _get_non_moe_state_dict(self, full_state_dict): """ - Get the state dict of the non-moe layers + Get the state dict of the non-moe layers """ for key in list(full_state_dict.keys()): - if 'expert' in key and 'moe.gate.wg.weight' not in key: + if "expert" in key and "moe.gate.wg.weight" not in key: full_state_dict.pop(key) return full_state_dict - def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False): + def _save_moe_checkpoint( + self, save_dir, tag, client_state={}, exclude_frozen_parameters=False + ): save_path = self._get_ckpt_name(save_dir, tag) # A hack to save the checkpointing directory. Pipeline parallelism overrides @@ -3889,9 +4740,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa # get all moe parameters moe_state_dict = {} for n, p in module.state_dict().items(): - if 'expert' in n and 'moe.gate.wg.weight' not in n: - moe_state_dict[n_module + '.' + n] = p - moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' + if "expert" in n and "moe.gate.wg.weight" not in n: + moe_state_dict[n_module + "." + n] = p + moe_str_prefix = ".deepspeed_moe.experts.deepspeed_experts." # print(moe_state_dict.keys()) # until now, everything is fine. 
So the bug happens at next few lines # Reorder the moe name rank, so that each checkpoint only has one expert experts_state_dict = defaultdict(dict) @@ -3900,14 +4751,17 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa local_expert_id = None if not m: - logger.warning(f'No expert found in key {key}.') + logger.warning(f"No expert found in key {key}.") else: local_expert_id = m.group(1) - global_expert_id = expp_rank * \ - num_local_experts + int(local_expert_id) - expert_key = key.replace(f'{moe_str_prefix}{local_expert_id}', - f'{moe_str_prefix}{global_expert_id}') + global_expert_id = expp_rank * num_local_experts + int( + local_expert_id + ) + expert_key = key.replace( + f"{moe_str_prefix}{local_expert_id}", + f"{moe_str_prefix}{global_expert_id}", + ) # truncating extra tensor (shared) storage truncated = moe_state_dict.pop(key).clone().detach() experts_state_dict[str(global_expert_id)][expert_key] = truncated @@ -3915,12 +4769,18 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa # let save the moe parameters for global_expert_id, expert_state_dict in experts_state_dict.items(): # save the moe parameters - moe_save_path = self._get_expert_ckpt_name(save_dir, moe_layer_id, global_expert_id, tag, self.mpu) + moe_save_path = self._get_expert_ckpt_name( + save_dir, moe_layer_id, global_expert_id, tag, self.mpu + ) if self.random_ltd_enabled(): - expert_state_dict = remove_random_ltd_state_dict(expert_state_dict) + expert_state_dict = remove_random_ltd_state_dict( + expert_state_dict + ) saveable_state_dict = expert_state_dict if self.checkpoint_engine.preserves_storage_sharing(): - saveable_state_dict = clone_tensors_for_torch_save(expert_state_dict) + saveable_state_dict = clone_tensors_for_torch_save( + expert_state_dict + ) self.checkpoint_engine.save(saveable_state_dict, moe_save_path) moe_layer_id += 1 @@ -3938,7 +4798,11 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa # Save optimizer states. They are different across each exp parallel rank. optimizer_state = { - 'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None + "optimizer": ( + self.optimizer.state_dict() + if self.optimizer and not self.zero_optimization() + else None + ) } # TODO: why use BufferedWriter not the path file_path = self._get_optimizer_ckpt_name(save_dir, tag, expp_rank) @@ -3954,43 +4818,51 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa # DeepSpeedEngine returns the state dict, where PipelineEngine saves the state dict and returns None. # We need to get the state dict, therefore, call to DeepSpeedEngine (base class for PipelineEngine) model_state_dict = self._get_non_moe_state_dict( - DeepSpeedEngine.module_state_dict(self, exclude_frozen_parameters=exclude_frozen_parameters)) + DeepSpeedEngine.module_state_dict( + self, exclude_frozen_parameters=exclude_frozen_parameters + ) + ) # TODO: update num experts info,.. 
in checkpoint state = { - 'module': - model_state_dict, - 'lr_scheduler': - self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, - 'data_sampler': - self.training_dataloader.data_sampler.state_dict() if - (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None, - 'random_ltd': - self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None, - 'sparse_tensor_module_names': - self.sparse_tensor_module_names, - 'skipped_steps': - self.skipped_steps, - 'global_steps': - self.global_steps, - 'global_samples': - self.global_samples, - 'dp_world_size': - self.seq_dp_world_size, - 'mp_world_size': - self.mp_world_size, - 'num_experts': - self.num_experts + "module": model_state_dict, + "lr_scheduler": ( + self.lr_scheduler.state_dict() + if self.lr_scheduler is not None + else None + ), + "data_sampler": ( + self.training_dataloader.data_sampler.state_dict() + if ( + self.training_dataloader is not None + and self.curriculum_learning_enabled() + ) + else None + ), + "random_ltd": ( + self.random_ltd_scheduler.state_dict() + if self.random_ltd_enabled() + else None + ), + "sparse_tensor_module_names": self.sparse_tensor_module_names, + "skipped_steps": self.skipped_steps, + "global_steps": self.global_steps, + "global_samples": self.global_samples, + "dp_world_size": self.seq_dp_world_size, + "mp_world_size": self.mp_world_size, + "num_experts": self.num_experts, } state.update(client_state) - logger.info(f'Saving model checkpoint: {save_path}') + logger.info(f"Saving model checkpoint: {save_path}") saveable_state_dict = state if self.checkpoint_engine.preserves_storage_sharing(): saveable_state_dict = clone_tensors_for_torch_save(state) self.checkpoint_engine.save(saveable_state_dict, save_path) def _create_checkpoint_file(self, save_dir, tag, zero_checkpoint): - name_function = (self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name) + name_function = ( + self._get_zero_ckpt_name if zero_checkpoint else self._get_ckpt_name + ) try: checkpoint_name = name_function(save_dir, tag) path = os.path.dirname(checkpoint_name) @@ -4010,48 +4882,89 @@ def _create_zero_checkpoint_files(self, save_dir, tag): return success - def _save_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_parameters=False): + def _save_checkpoint( + self, save_dir, tag, client_state={}, exclude_frozen_parameters=False + ): save_path = self._get_ckpt_name(save_dir, tag) zero_optimizer_state = self.zero_optimization() - save_frozen_param = self.zero_optimization_partition_gradients() and not exclude_frozen_parameters + save_frozen_param = ( + self.zero_optimization_partition_gradients() + and not exclude_frozen_parameters + ) # A hack to save the checkpointing directory. Pipeline parallelism overrides # module_state_dict() and uses this path to save the model. module_state_dict() # then instead just returns None. The module_state_dict() implementation in # PipelineEngine expects the save path to be set in self._curr_ckpt_path. 
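
# A runnable miniature of the hand-off described above (classes are
# hypothetical stand-ins for DeepSpeedEngine / PipelineEngine):
class _BaseEngine:
    _curr_ckpt_path = None

    def module_state_dict(self):
        return {"weight": 0}  # base engine hands back the state dict

class _PipelineEngine(_BaseEngine):
    def module_state_dict(self):
        # the pipeline override saves its shards under _curr_ckpt_path itself...
        assert self._curr_ckpt_path is not None
        return None  # ...and returns None instead of a dict

engine = _PipelineEngine()
engine._curr_ckpt_path = "checkpoints/global_step10"  # set just before the call
assert engine.module_state_dict() is None
engine._curr_ckpt_path = None  # and cleared right after, as below
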
self._curr_ckpt_path = os.path.join(save_dir, tag) - module = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters) + module = self.module_state_dict( + exclude_frozen_parameters=exclude_frozen_parameters + ) self._curr_ckpt_path = None - state = dict(module=module, - buffer_names=self._get_buffer_names(), - optimizer=self.optimizer.state_dict() if self.optimizer and not zero_optimizer_state else None, - param_shapes=self._get_zero_param_shapes() if self.optimizer and zero_optimizer_state else None, - frozen_param_shapes=self._get_zero_frozen_param_attributes(self._get_param_shape_func) - if save_frozen_param else None, - shared_params=self._get_shared_params() if self.optimizer and zero_optimizer_state else None, - frozen_param_fragments=self._get_zero_frozen_param_attributes(self._get_param_fragment_func) - if save_frozen_param else None, - lr_scheduler=self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None, - data_sampler=self.training_dataloader.data_sampler.state_dict() if - (self.training_dataloader is not None and self.curriculum_learning_enabled()) else None, - random_ltd=self.random_ltd_scheduler.state_dict() if self.random_ltd_enabled() else None, - sparse_tensor_module_names=self.sparse_tensor_module_names, - skipped_steps=self.skipped_steps, - global_steps=self.global_steps, - global_samples=self.global_samples, - dp_world_size=self.seq_dp_world_size, - mp_world_size=self.mp_world_size, - ds_config=self.config, - ds_version=version) + state = dict( + module=module, + buffer_names=self._get_buffer_names(), + optimizer=( + self.optimizer.state_dict() + if self.optimizer and not zero_optimizer_state + else None + ), + param_shapes=( + self._get_zero_param_shapes() + if self.optimizer and zero_optimizer_state + else None + ), + frozen_param_shapes=( + self._get_zero_frozen_param_attributes(self._get_param_shape_func) + if save_frozen_param + else None + ), + shared_params=( + self._get_shared_params() + if self.optimizer and zero_optimizer_state + else None + ), + frozen_param_fragments=( + self._get_zero_frozen_param_attributes(self._get_param_fragment_func) + if save_frozen_param + else None + ), + lr_scheduler=( + self.lr_scheduler.state_dict() + if self.lr_scheduler is not None + else None + ), + data_sampler=( + self.training_dataloader.data_sampler.state_dict() + if ( + self.training_dataloader is not None + and self.curriculum_learning_enabled() + ) + else None + ), + random_ltd=( + self.random_ltd_scheduler.state_dict() + if self.random_ltd_enabled() + else None + ), + sparse_tensor_module_names=self.sparse_tensor_module_names, + skipped_steps=self.skipped_steps, + global_steps=self.global_steps, + global_samples=self.global_samples, + dp_world_size=self.seq_dp_world_size, + mp_world_size=self.mp_world_size, + ds_config=self.config, + ds_version=version, + ) autotp_uc_info = getattr(self.module, UNIVERSAL_CHECKPOINT_INFO, None) if autotp_uc_info is not None: state[UNIVERSAL_CHECKPOINT_INFO] = autotp_uc_info state.update(client_state) - log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0]) + log_dist(message=f"Saving model checkpoint: {save_path}", ranks=[0]) if self.save_non_zero_checkpoint: self.checkpoint_engine.save(state_dict=state, path=save_path) @@ -4078,10 +4991,14 @@ def get_layer_named_buffers(module, prefix=""): return buffer_names def _get_param_shape_func(self, param): - return param.ds_shape if hasattr(param, 'ds_id') else param.shape + return param.ds_shape if hasattr(param, "ds_id") else param.shape def 
_get_param_fragment_func(self, param): - return param.ds_tensor.detach().cpu() if hasattr(param, 'ds_id') else param.detach().cpu() + return ( + param.ds_tensor.detach().cpu() + if hasattr(param, "ds_id") + else param.detach().cpu() + ) def _get_zero_frozen_param_attributes(self, attr_func): frozen_param_fragments = OrderedDict() @@ -4117,8 +5034,11 @@ def _get_zero_param_shapes(self): elif self.bfloat16_enabled() and hasattr(self.optimizer, "bf16_groups"): bit16_groups = self.optimizer.bf16_groups else: - bit16_groups = self.optimizer.bit16_groups if self.zero_optimization_stage( - ) == 2 else self.optimizer.fp16_groups + bit16_groups = ( + self.optimizer.bit16_groups + if self.zero_optimization_stage() == 2 + else self.optimizer.fp16_groups + ) for bit16_group in bit16_groups: param_shapes = OrderedDict() @@ -4147,8 +5067,9 @@ def _get_shared_params(self): shared_index = {} shared_params_by_full_name = {} - is_zero3_model = (self.zero_optimization_partition_weights() - and any(hasattr(param, "ds_id") for param in self.module.parameters())) + is_zero3_model = self.zero_optimization_partition_weights() and any( + hasattr(param, "ds_id") for param in self.module.parameters() + ) def get_layer_state_dict(module, prefix=""): # handle params @@ -4164,7 +5085,7 @@ def get_layer_state_dict(module, prefix=""): if param_id in shared_index: # shared weights - #print(f"`{key}` is shared with `{shared_index[param_id]}`") + # print(f"`{key}` is shared with `{shared_index[param_id]}`") shared_params_by_full_name[key] = shared_index[param_id] else: shared_index[param_id] = key @@ -4183,7 +5104,7 @@ def _copy_recovery_script(self, save_path): script = "zero_to_fp32.py" src = os.path.join(base_dir, "utils", script) dst = os.path.join(save_path, script) - #logger.info(f"creating recovery script {dst}") + # logger.info(f"creating recovery script {dst}") copyfile(src, dst) self._change_recovery_script_permissions(dst) @@ -4192,20 +5113,24 @@ def _change_recovery_script_permissions(self, dst): try: os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC) except (FileNotFoundError, PermissionError) as e: - #this message is used in unit test TestZeRONonDistributed + # this message is used in unit test TestZeRONonDistributed logger.info( - f'Warning: Could not change permissions for {dst} due to error: {e}. Continuing without changing permissions.' + f"Warning: Could not change permissions for {dst} due to error: {e}. Continuing without changing permissions." ) def _save_zero_checkpoint(self, save_path, tag): zero_checkpoint_name = self._get_zero_ckpt_name(save_path, tag) - zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(), ds_config=self.config, ds_version=version) + zero_sd = dict( + optimizer_state_dict=self.optimizer.state_dict(), + ds_config=self.config, + ds_version=version, + ) self.checkpoint_engine.save(zero_sd, zero_checkpoint_name) if self.global_rank == 0: self._copy_recovery_script(save_path) - ckpt_type = 'zero' if self.zero_optimization() else 'bf16_zero' - #logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}') + ckpt_type = "zero" if self.zero_optimization() else "bf16_zero" + # logger.info(f'{ckpt_type} checkpoint saved {zero_checkpoint_name}') def _replace_module_consolidated_state_dict(self): """ @@ -4217,17 +5142,19 @@ def _replace_module_consolidated_state_dict(self): Returns: OrderedDict: The consolidated state dictionary if the current process rank is 0, otherwise None. 
""" - #TODO: If we use both Zero3 and tensor parallel simultaneously + # TODO: If we use both Zero3 and tensor parallel simultaneously # we need to consolidate the gather mechanisms of both. state_dict = OrderedDict() if dist.get_rank() == 0 else None def get_layer_state_dict(module, prefix=""): - with GatherReplacedLayerParams(list(module.parameters(recurse=False)), module, enabled=True): + with GatherReplacedLayerParams( + list(module.parameters(recurse=False)), module, enabled=True + ): for name, param in module.named_parameters(recurse=False): if param is None: continue key = prefix + name - if (dist.get_rank() == 0): + if dist.get_rank() == 0: state_dict[key] = param.detach().cpu() # print(key,module, param.detach().cpu().shape) @@ -4250,8 +5177,10 @@ def _consolidated_16bit_state_dict(self, exclude_frozen_parameters=False): elif self.autotp_size() > 1: return self._replace_module_consolidated_state_dict() - raise ValueError("consolidated_16bit_state_dict is only applicable to cases where weights are partitioned, " - "including Zero Stage 3 and tensor parallelism.") + raise ValueError( + "consolidated_16bit_state_dict is only applicable to cases where weights are partitioned, " + "including Zero Stage 3 and tensor parallelism." + ) def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False): """ @@ -4274,12 +5203,16 @@ def _zero3_consolidated_16bit_state_dict(self, exclude_frozen_parameters=False): def get_layer_state_dict(module, prefix=""): # gather one layer at a time to be memory-efficient # must use modifier_rank=0 to release GPU memory after each layer gathered - #see_memory_usage("before GatheredParameters", force=True) - with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + # see_memory_usage("before GatheredParameters", force=True) + with deepspeed.zero.GatheredParameters( + list(module.parameters(recurse=False)), modifier_rank=0 + ): if dist.get_rank() == 0: # handle params for name, param in module.named_parameters(recurse=False): - if param is None or (exclude_frozen_parameters and not param.requires_grad): + if param is None or ( + exclude_frozen_parameters and not param.requires_grad + ): continue key = prefix + name # can't rely on param.data_ptr() as it will be reused as weights gets @@ -4287,18 +5220,21 @@ def get_layer_state_dict(module, prefix=""): # (and shared params will have the same param.ds_id) if param.ds_id in shared_params: # shared weights - #print(f"`{key}` is shared with `{shared_params[param.ds_id]}`") + # print(f"`{key}` is shared with `{shared_params[param.ds_id]}`") state_dict[key] = state_dict[shared_params[param.ds_id]] else: state_dict[key] = param.detach().cpu() shared_params[param.ds_id] = key - #print(f"param {param.ds_id} {param.shape} {key} ") + # print(f"param {param.ds_id} {param.shape} {key} ") # now buffers - not sure if need to take care of potentially shared weights here for name, buf in module.named_buffers(recurse=False): - if (buf is not None and name not in module._non_persistent_buffers_set): + if ( + buf is not None + and name not in module._non_persistent_buffers_set + ): state_dict[prefix + name] = buf.detach().cpu() - #see_memory_usage("after GatheredParameters", force=True) + # see_memory_usage("after GatheredParameters", force=True) for name, child in module.named_children(): if child is not None: @@ -4322,7 +5258,12 @@ def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"): compatibility""" return self.save_16bit_model(save_dir, save_filename) - 
def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin", exclude_frozen_parameters=False):
+    def save_16bit_model(
+        self,
+        save_dir,
+        save_filename="pytorch_model.bin",
+        exclude_frozen_parameters=False,
+    ):
         """
         Save 16bit model weights

@@ -4349,18 +5290,24 @@ def save_16bit_model(self, save_dir, save_filename="pytorch_model.bin", exclude_
             if self.zero_gather_16bit_weights_on_model_save():
                 # consolidation is expensive in time and memory and therefore isn't a default
                 state_dict = self._zero3_consolidated_16bit_state_dict(
-                    exclude_frozen_parameters=exclude_frozen_parameters)
+                    exclude_frozen_parameters=exclude_frozen_parameters
+                )
             else:
                 # the model will be bogus if not consolidated so don't confuse the user by saving it
                 logger.info(
-                    f"Did not save the model {path} because stage3_gather_16bit_weights_on_model_save is False")
+                    f"Did not save the model {path} because stage3_gather_16bit_weights_on_model_save is False"
+                )
                 return False
         else:
-            state_dict = self.module_state_dict(exclude_frozen_parameters=exclude_frozen_parameters)
+            state_dict = self.module_state_dict(
+                exclude_frozen_parameters=exclude_frozen_parameters
+            )

         tag = f"global_step{self.global_steps}"
         tag = str(tag)
-        commit_info = CheckpointCommitInfo(tag=tag, save_dir=save_dir, save_latest=False)
+        commit_info = CheckpointCommitInfo(
+            tag=tag, save_dir=save_dir, save_latest=False
+        )
         self.checkpoint_engine.create(commit_info)

         if dist.get_rank() == 0:
@@ -4376,11 +5323,11 @@ def empty_partition_cache(self):
         """
         Release GPU memory consumed by offloaded model parameters.
         """
-        if hasattr(self.optimizer, 'empty_partition_cache'):
+        if hasattr(self.optimizer, "empty_partition_cache"):
             self.optimizer.empty_partition_cache()
         gc.collect()
         get_accelerator().empty_cache()

     def get_autosp_backend(self, compile_kwargs):
         if self.compile_autosp() and self.zero_optimization_stage() not in [
             ZeroStageEnum.disabled, ZeroStageEnum.optimizer_states
         ]:
@@ -4446,6 +5402,7 @@ def compile(self,
                 compile_kwargs={},
                 schedule=None,
                 compiled_autograd_enabled=False) -> None:
+
         """Compile the module using the specified backend and kwargs.
         If a compiler_fn is set, it will be used instead of torch.compile().
         """
@@ -4458,10 +5415,74 @@ def compile(self,
         if self.is_compiled:
             return

-        if 'backend' in compile_kwargs:
-            logger.warning("The `backend` in `compile_kwargs` will be overridden. Use the `backend` argument instead.")
+        if "backend" in compile_kwargs:
+            logger.warning(
+                "The `backend` in `compile_kwargs` will be overridden. Use the `backend` argument instead."
+            )
+
+        logger.info(
+            f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}"
+        )

-        logger.info(f"Compiling deepcompile={self.is_deepcompile_enabled()} backend={backend}")
+
+        enable_deepcompile = self.is_deepcompile_enabled()
+        if (
+            enable_deepcompile
+            and self.zero_optimization_stage() != ZeroStageEnum.optimizer_states
+            and self.zero_optimization_stage() != ZeroStageEnum.weights
+            and self.zero_optimization_stage() != ZeroStageEnum.gradients
+        ):
+            logger.info(
+                f"Currently DeepCompile supports ZeRO stage 1, 2, or 3 only, but ZeRO stage is set to {self.zero_optimization_stage()}. Falling back to the torch compiler."
+ ) + enable_deepcompile = False + + if enable_deepcompile: + + if schedule is not None: + + def passes_name_to_fn(passes): + for p in passes: + assert callable(p) or p in opt_passes, f"Unknown pass {p}" + return [p if callable(p) else opt_passes[p] for p in passes] + + schedule = [ + (step, passes_name_to_fn(passes)) for step, passes in schedule + ] + + assert backend in [ + "inductor", + "eager", + ], f"Backend {backend} is not supported for DeepCompile." + + compile_config = self._config.compile_config + if ( + ( + "zero_optimization" in self.config + and "offload_optimizer" in self.config["zero_optimization"] + and "offload_param" in self.config["zero_optimization"] + ) + and self._config.zero_config.offload_param.device == "cpu" + and self._config.zero_config.offload_optimizer.device == "cpu" + ): + compile_config.offload_parameters = True + if self.zero_optimization_stage() == ZeroStageEnum.optimizer_states: + backend = init_z1( + self, backend, compile_config, compile_kwargs, schedule + ) + elif self.zero_optimization_stage() == ZeroStageEnum.gradients: + backend = init_z1( + self, backend, compile_config, compile_kwargs, schedule, use_z2=True + ) + elif self.zero_optimization_stage() == ZeroStageEnum.weights: + if required_torch_version(min_version=2.9): + raise RuntimeError( + "DeepCompile with ZeRO stage 3 is not currently supported on PyTorch >= 2.9. " + "Please use ZeRO stage 1 or 2 with DeepCompile, or disable DeepCompile for ZeRO stage 3." + ) + backend = init_z3( + self, backend, compile_config, compile_kwargs, schedule + ) resolved_backend = None if self.is_deepcompile_enabled(): @@ -4472,12 +5493,13 @@ def compile(self, # default to torch.compiler backend if deepspeed config validation fails backend = resolved_backend or backend + # Hook state must align with whether DeepCompile is active. self._set_deepcompile_active(is_deepspeed_compile_backend) # create new dict to avoid modifying original dict try: - self.module.compile(**{**compile_kwargs, 'backend': backend}) + self.module.compile(**{**compile_kwargs, "backend": backend}) except Exception: if is_deepspeed_compile_backend: # Restore default hooks if compilation fails before completing. @@ -4490,7 +5512,9 @@ def compile(self, if not self._deepcompile_active: self._is_compiled_autograd_enabled = compiled_autograd_enabled else: - logger.warning("Compiled autograd is not compatible with DeepCompile, disabling compiled autograd.") + logger.warning( + "Compiled autograd is not compatible with DeepCompile, disabling compiled autograd." + ) self._is_compiled_autograd_enabled = False def _set_deepcompile_active(self, active: bool) -> None: @@ -4515,6 +5539,7 @@ def _set_deepcompile_active(self, active: bool) -> None: def get_compile_time(self): from deepspeed.compile.backend import opt_pass_times + return opt_pass_times def register_compile_pass(self, pass_name: str, pass_fn: Callable) -> None: @@ -4530,11 +5555,13 @@ def is_deepcompile_active(self) -> bool: def is_compiled(self) -> bool: return self._is_compiled - def offload_states(self, - include: Container[OffloadStateTypeEnum] = None, - device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, - pin_memory: bool = True, - non_blocking: bool = False) -> None: + def offload_states( + self, + include: Container[OffloadStateTypeEnum] = None, + device: OffloadDeviceEnum = OffloadDeviceEnum.cpu, + pin_memory: bool = True, + non_blocking: bool = False, + ) -> None: """Offload the engine's states to the specified device. 
Arguments: @@ -4544,13 +5571,19 @@ def offload_states(self, non_blocking: Optional. Whether to offload the states asynchronously. """ opt_offload_config = self.zero_offload_optimizer() - assert opt_offload_config is None or opt_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded optimizer states." + assert ( + opt_offload_config is None + or opt_offload_config.device == OffloadDeviceEnum.none + ), "Moving states across devices is not supported for offloaded optimizer states." param_offload_config = self.zero_offload_param() - assert param_offload_config is None or param_offload_config.device == OffloadDeviceEnum.none, "Moving states across devices is not supported for offloaded parameters." + assert ( + param_offload_config is None + or param_offload_config.device == OffloadDeviceEnum.none + ), "Moving states across devices is not supported for offloaded parameters." assert not isinstance( - self.optimizer, - DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer." + self.optimizer, DeepSpeedZeRoOffload + ), "Moving states across devices is not supported without an optimizer." if device == OffloadDeviceEnum.none: logger.warning("No device specified for offloading states.") @@ -4559,7 +5592,12 @@ def offload_states(self, if device == OffloadDeviceEnum.nvme: raise ValueError("NVMe offload is not supported for offloading states.") - self.optimizer.offload_states(include=include, device=device, pin_memory=pin_memory, non_blocking=non_blocking) + self.optimizer.offload_states( + include=include, + device=device, + pin_memory=pin_memory, + non_blocking=non_blocking, + ) def reload_states(self, non_blocking: bool = False) -> None: """Reload the engine states to the original device. @@ -4568,7 +5606,7 @@ def reload_states(self, non_blocking: bool = False) -> None: non_blocking: Optional. Whether to offload the states asynchronously. """ assert not isinstance( - self.optimizer, - DeepSpeedZeRoOffload), "Moving states across devices is not supported without an optimizer." + self.optimizer, DeepSpeedZeRoOffload + ), "Moving states across devices is not supported without an optimizer." self.optimizer.reload_states(non_blocking=non_blocking)
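
For orientation, a hedged usage sketch of the offload/reload pair above (engine construction elided; `engine` is assumed to be an initialized DeepSpeedEngine running ZeRO without optimizer or parameter offload, as the assertions require):

from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum

engine.offload_states(
    include=[OffloadStateTypeEnum.optim_states],  # None offloads every supported state
    device=OffloadDeviceEnum.cpu,                 # OffloadDeviceEnum.nvme is rejected above
    pin_memory=True,
    non_blocking=True,
)
# ... run device-memory-hungry work while the states live on CPU ...
engine.reload_states(non_blocking=True)
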