diff --git a/README.md b/README.md index 18185bb2a33c..ce56b5915fdc 100755 --- a/README.md +++ b/README.md @@ -118,7 +118,7 @@ dynamically link them at runtime. ## Requirements * [PyTorch](https://pytorch.org/) must be installed _before_ installing DeepSpeed. -* For full feature support we recommend a version of PyTorch that is >= 1.9 and ideally the latest PyTorch stable release. +* For full feature support we recommend a version of PyTorch that is >= 2.0 and ideally the latest PyTorch stable release. * A CUDA or ROCm compiler such as [nvcc](https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#introduction) or [hipcc](https://github.com/ROCm-Developer-Tools/HIPCC) used to compile C++/CUDA/HIP extensions. * Specific GPUs we develop and test against are listed below, this doesn't mean your GPU will not work if it doesn't fall into this category it's just DeepSpeed is most well tested on the following: * NVIDIA: Pascal, Volta, Ampere, and Hopper architectures diff --git a/deepspeed/runtime/zero/linear.py b/deepspeed/runtime/zero/linear.py index 1227f4ee356b..7421fd10c5ef 100644 --- a/deepspeed/runtime/zero/linear.py +++ b/deepspeed/runtime/zero/linear.py @@ -56,15 +56,31 @@ def _get_autocast_decorators(): autocast_custom_fwd, autocast_custom_bwd = _get_autocast_decorators() +def _is_autocast_enabled(device_type): + try: + return torch.is_autocast_enabled(device_type) + except TypeError: + legacy_getter = getattr(torch, f'is_autocast_{device_type}_enabled', None) + if legacy_getter is not None: + return legacy_getter() + return torch.is_autocast_enabled() + + +def _get_autocast_dtype(device_type): + try: + return torch.get_autocast_dtype(device_type) + except TypeError: + legacy_getter = getattr(torch, f'get_autocast_{device_type}_dtype', None) + if legacy_getter is not None: + return legacy_getter() + return None + + class LinearFunctionForZeroStage3(torch.autograd.Function): - # Note that both forward and backward are @staticmethods @staticmethod - @autocast_custom_fwd # bias is an optional argument - def forward(ctx, input, weight, bias=None): - - ctx.save_for_backward(input, weight, bias) + def forward(input, weight, bias=None): if input.dim() == 2 and bias is not None: # fused op is marginally faster @@ -77,54 +93,45 @@ def forward(ctx, input, weight, bias=None): return ret + @staticmethod + def setup_context(ctx, inputs, output): + device_type = get_accelerator().device_name() + ctx._dtype = _get_autocast_dtype(device_type) + ctx._fwd_used_autocast = _is_autocast_enabled(device_type) + input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None + ctx.save_for_backward(input, weight, bias) + # This function has only a single output, so it gets only one gradient @staticmethod - @autocast_custom_bwd def backward(ctx, grad_output): - # This is a pattern that is very convenient - at the top of backward - # unpack saved_tensors and initialize all gradients w.r.t. inputs to - # None. Thanks to the fact that additional trailing Nones are - # ignored, the return statement is simple even when the function has - # optional inputs. - input, weight, bias = ctx.saved_tensors - - grad_input = grad_weight = grad_bias = None - - #print(f"backward shaped grad_output {grad_output.shape}, input {input.shape}, weight {weight.shape} and bias {bias.shape if bias is not None else None}") - # These needs_input_grad checks are optional and there only to - # improve efficiency. If you want to make your code simpler, you can - # skip them. Returning gradients for inputs that don't require it is - # not an error. - dim = grad_output.dim() - if ctx.needs_input_grad[0]: - #print(f"Computing grad input weight {weight.shape} grad_output {grad_output.shape}") - grad_input = grad_output.matmul(weight) - #print(f"Computed grad input {grad_input.shape}") - if ctx.needs_input_grad[1]: - #print("Computing grad weight") - if dim > 2: - grad_weight = grad_output.reshape(-1, - grad_output.shape[-1]).t().matmul(input.reshape(-1, input.shape[-1])) - else: - grad_weight = grad_output.t().matmul(input) - #print(f"Computed grad weight grad_weight {grad_weight.shape}") - if bias is not None and ctx.needs_input_grad[2]: - #print("Computing grad bias") - if dim > 2: - grad_bias = grad_output.sum([i for i in range(dim - 1)]) - else: - grad_bias = grad_output.sum(0) - #print("Done computing grad bias") - #print("needs bias") - #print(f"backward shaped grad_input {grad_input.shape}, grad_weight {grad_weight.shape}, grad_bias {grad_bias.shape if grad_bias is not None else None}") - return grad_input, grad_weight, grad_bias + # Match @custom_bwd semantics: always run backward under the same + # autocast state as forward — including explicitly disabling autocast + # when forward did not use it, to guard against outer autocast regions. + device_type = get_accelerator().device_name() + with torch.autocast(device_type=device_type, enabled=ctx._fwd_used_autocast, dtype=ctx._dtype): + input, weight, bias = ctx.saved_tensors + + grad_input = grad_weight = grad_bias = None + + dim = grad_output.dim() + if ctx.needs_input_grad[0]: + grad_input = grad_output.matmul(weight) + if ctx.needs_input_grad[1]: + if dim > 2: + grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul( + input.reshape(-1, input.shape[-1])) + else: + grad_weight = grad_output.t().matmul(input) + if bias is not None and ctx.needs_input_grad[2]: + if dim > 2: + grad_bias = grad_output.sum([i for i in range(dim - 1)]) + else: + grad_bias = grad_output.sum(0) + return grad_input, grad_weight, grad_bias def zero3_linear_wrap(input, weight, bias=None): - if bias is None: - return LinearFunctionForZeroStage3.apply(input, weight) - else: - return LinearFunctionForZeroStage3.apply(input, weight, bias) + return LinearFunctionForZeroStage3.apply(input, weight, bias) class LinearModuleForZeroStage3(Module): diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index c434ff738933..b42b3c8e263e 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -404,15 +404,16 @@ def _run_before_backward_function(sub_module): class PreBackwardFunctionForModule(torch.autograd.Function): @staticmethod - def forward(ctx, outputs): - # Capture `module` and _run_before_backward_function + def forward(outputs): + return outputs.detach() + + @staticmethod + def setup_context(ctx, inputs, output): ctx.module = module ctx.pre_backward_function = _run_before_backward_function if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"): ctx.module.applied_pre_backward_ref_cnt = 0 ctx.module.applied_pre_backward_ref_cnt += 1 - outputs = outputs.detach() - return outputs @staticmethod def backward(ctx, *args): @@ -434,9 +435,14 @@ def _run_after_backward_function(sub_module): class PostBackwardFunctionModule(torch.autograd.Function): @staticmethod - def forward(ctx, output): + def forward(output): + return output.detach() + + @staticmethod + def setup_context(ctx, inputs, output): + (output_in, ) = inputs ctx.module = module - if output.requires_grad: + if output_in.requires_grad: #TODO SOME TIMES post backward does not seem to be triggered debug in detail #Should only cause increase in memory not correctness issue #if output.grad_fn.__class__.__name__ == 'ViewBackward': @@ -447,8 +453,6 @@ def forward(ctx, output): # print(f"Before Forward: {ctx.module.__class__.__name__}") module.ds_grads_remaining += 1 ctx.post_backward_function = _run_after_backward_function - output = output.detach() - return output @staticmethod def backward(ctx, *args): diff --git a/requirements/requirements-readthedocs.txt b/requirements/requirements-readthedocs.txt index a48a47e4428d..aaac814354c4 100644 --- a/requirements/requirements-readthedocs.txt +++ b/requirements/requirements-readthedocs.txt @@ -7,5 +7,5 @@ py-cpuinfo pydantic>=2.0.0 recommonmark sphinx_rtd_theme -torch +torch>=2.0.0 tqdm diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 1af4c69c5807..1bbd21dd5e32 100755 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -7,5 +7,5 @@ packaging>=20.0 psutil py-cpuinfo pydantic>=2.0.0 -torch +torch>=2.0.0 tqdm diff --git a/tests/unit/v1/zero/test_zero_functorch_linear.py b/tests/unit/v1/zero/test_zero_functorch_linear.py new file mode 100644 index 000000000000..e56c214d997a --- /dev/null +++ b/tests/unit/v1/zero/test_zero_functorch_linear.py @@ -0,0 +1,203 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team +"""Regression: ZeRO-3 linear autograd.Function must work with torch.func transforms. + +ZeRO Stage 3 uses ``LinearFunctionForZeroStage3`` (via ``zero3_linear_wrap``) as +the memory-efficient linear path. After ``deepspeed.initialize``, global +``torch.nn.functional.linear`` is often the built-in again, so tests call +``zero3_linear_wrap`` directly-the same ``autograd.Function`` as when the patch +is active. Legacy ``forward(ctx, ...)`` + ``ctx.save_for_backward`` in forward +raises on strict functorch builds:: + + RuntimeError: In order to use an autograd.Function with functorch + transforms ... it must override the setup_context staticmethod. +""" + +import pytest +import torch +import torch.nn as nn + +import deepspeed +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero.linear import zero3_linear_wrap + +from unit.common import DistributedTest + + +def _zero3_functorch_config(): + config = { + "train_micro_batch_size_per_gpu": 1, + "gradient_accumulation_steps": 1, + "steps_per_print": 2147483647, + "zero_optimization": { + "stage": 3, + "stage3_param_persistence_threshold": 0, + }, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-3 + }, + }, + } + acc = get_accelerator() + if acc.is_bf16_supported(): + config["bf16"] = {"enabled": True} + elif acc.is_fp16_supported(): + config["fp16"] = {"enabled": True, "initial_scale_power": 8} + return config + + +class TestZeroFunctorchLinearRegression(DistributedTest): + """``torch.func.grad_and_value`` over ``zero3_linear_wrap`` / LinearFunctionForZeroStage3.""" + + world_size = 1 + + def test_grad_and_value_over_patched_functional_linear(self): + if not hasattr(torch, "func"): + pytest.skip("torch.func not available") + + model = nn.Linear(8, 8, bias=True) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + + device = engine.device + dtype = engine.module.weight.dtype + weight = torch.randn(8, 8, device=device, dtype=dtype, requires_grad=True) + inp = torch.randn(2, 8, device=device, dtype=dtype, requires_grad=True) + + with torch.enable_grad(): + probe = zero3_linear_wrap(inp, weight, None) + assert "LinearFunctionForZeroStage3" in type(probe.grad_fn).__name__ + + def loss_fn(w, x): + return zero3_linear_wrap(x, w, None).sum() + + grads, value = torch.func.grad_and_value(loss_fn, argnums=(0, 1))(weight, inp) + assert torch.isfinite(value) + assert grads[0] is not None and torch.isfinite(grads[0]).all() + assert grads[1] is not None and torch.isfinite(grads[1]).all() + + +class TestZeroLinearAutocast(DistributedTest): + """Verify autocast state is correctly propagated through forward and backward.""" + + world_size = 1 + + def _run_forward_backward(self, device, use_autocast, dtype=None): + """Run zero3_linear_wrap forward+backward, optionally inside autocast.""" + weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True) + inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True) + bias = torch.randn(4, device=device, dtype=torch.float32, requires_grad=True) + + if use_autocast: + with torch.amp.autocast(device_type=device.type, dtype=dtype): + out = zero3_linear_wrap(inp, weight, bias) + else: + out = zero3_linear_wrap(inp, weight, bias) + + loss = out.sum() + loss.backward() + return out, weight.grad, inp.grad, bias.grad + + def test_backward_without_autocast(self): + """Backward without autocast should produce float32 gradients.""" + model = nn.Linear(4, 4) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + device = engine.device + + out, w_grad, i_grad, b_grad = self._run_forward_backward(device, use_autocast=False) + assert out.dtype == torch.float32 + assert w_grad.dtype == torch.float32 + assert i_grad.dtype == torch.float32 + assert b_grad.dtype == torch.float32 + + def test_backward_with_autocast(self): + """Backward with autocast should produce float32 gradients (autocast only affects forward).""" + acc = get_accelerator() + if acc.is_bf16_supported(): + amp_dtype = torch.bfloat16 + elif acc.is_fp16_supported(): + amp_dtype = torch.float16 + else: + pytest.skip("No half-precision support") + + model = nn.Linear(4, 4) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + device = engine.device + + out, w_grad, i_grad, b_grad = self._run_forward_backward(device, use_autocast=True, dtype=amp_dtype) + # Forward output should be in reduced precision + assert out.dtype == amp_dtype + # Gradients accumulate in float32 (master weights) + assert w_grad.dtype == torch.float32 + assert i_grad.dtype == torch.float32 + assert b_grad.dtype == torch.float32 + + def test_no_autocast_leak_into_backward(self): + """When forward runs without autocast, an outer autocast during backward must not affect gradient dtype.""" + model = nn.Linear(4, 4) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + device = engine.device + + acc = get_accelerator() + if acc.is_bf16_supported(): + amp_dtype = torch.bfloat16 + elif acc.is_fp16_supported(): + amp_dtype = torch.float16 + else: + pytest.skip("No half-precision support") + + weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True) + inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True) + + # Forward WITHOUT autocast + out = zero3_linear_wrap(inp, weight, None) + assert out.dtype == torch.float32 + + # Backward WITH an outer autocast region -- should NOT affect gradient computation + # because setup_context captured _fwd_used_autocast=False + with torch.amp.autocast(device_type=device.type, dtype=amp_dtype): + out.sum().backward() + + assert weight.grad.dtype == torch.float32 + assert inp.grad.dtype == torch.float32 + + def test_setup_context_stores_autocast_attrs(self): + """setup_context must store _fwd_used_autocast and _dtype on ctx.""" + model = nn.Linear(4, 4) + engine, _, _, _ = deepspeed.initialize( + model=model, + config=_zero3_functorch_config(), + model_parameters=model.parameters(), + ) + device = engine.device + + weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True) + inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True) + + # Without autocast: setup_context must record that forward did not use autocast + out = zero3_linear_wrap(inp, weight, None) + grad_fn = out.grad_fn + assert hasattr(grad_fn, "_fwd_used_autocast") + assert grad_fn._fwd_used_autocast is False + assert hasattr(grad_fn, "_dtype") + out.sum().backward() + assert torch.isfinite(weight.grad).all()