Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
33db7c4
fix: fix LinearFunctionForZeroStage3 to support torch.func transforms
roycho96 Mar 21, 2026
39b1755
fix: always pass bias arg in zero3_linear_wrap to avoid setup_context…
roycho96 Mar 21, 2026
6df37af
fix: remove @autocast_custom_fwd from forward, move autocast state to…
roycho96 Mar 22, 2026
c0b9694
fix(zero3): replace custom_bwd with explicit autocast for functorch-s…
zhangj1an Mar 22, 2026
5e83d05
fix(zero): use setup_context for offload pre/post backward Functions
zhangj1an Mar 22, 2026
7483701
Merge branch 'master' into fix/support-func-torch
zhangj1an Mar 24, 2026
a1e798d
run pre-commit checks
zhangj1an Mar 25, 2026
8762d00
update unit tests to reproduce main branch error
zhangj1an Mar 25, 2026
dd037da
add reproduce scripts
zhangj1an Mar 25, 2026
01ee5a6
Merge branch 'master' into fix/support-func-torch
zhangj1an Mar 25, 2026
f69c1f1
update reproduce script
zhangj1an Mar 25, 2026
e58ac18
update reproduce script to skip repeated env setup
zhangj1an Mar 25, 2026
3121a7f
update reproduce script to remove duplicated code
zhangj1an Mar 25, 2026
60d20da
update reproduce script to print test env
zhangj1an Mar 25, 2026
bb245b2
drop PyTorch < 2.0 support and fix autocast backward in ZeRO linear
roycho96 Mar 29, 2026
04c456f
change PyTorch version in README
roycho96 Mar 29, 2026
703aad3
resolve conflict with master
tohtana Mar 29, 2026
8468149
Merge pull request #1 from tohtana/tohtana/pr7916-merge-master-resolve
zhangj1an Mar 30, 2026
e309a6f
remove repro scripts
zhangj1an Mar 30, 2026
e425569
update unit test
zhangj1an Mar 30, 2026
39f7e3c
drop support for pytorch<2.0
zhangj1an Mar 30, 2026
39c9a73
Merge branch 'master' into fix/support-func-torch
zhangj1an Mar 31, 2026
a5aa09a
Merge branch 'master' into fix/support-func-torch
roycho96 Apr 4, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ dynamically link them at runtime.

## Requirements
* [PyTorch](https://pytorch.org/) must be installed _before_ installing DeepSpeed.
* For full feature support we recommend a version of PyTorch that is >= 1.9 and ideally the latest PyTorch stable release.
* For full feature support we recommend a version of PyTorch that is >= 2.0 and ideally the latest PyTorch stable release.
* A CUDA or ROCm compiler such as [nvcc](https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#introduction) or [hipcc](https://github.com/ROCm-Developer-Tools/HIPCC) used to compile C++/CUDA/HIP extensions.
* Specific GPUs we develop and test against are listed below, this doesn't mean your GPU will not work if it doesn't fall into this category it's just DeepSpeed is most well tested on the following:
* NVIDIA: Pascal, Volta, Ampere, and Hopper architectures
Expand Down
101 changes: 54 additions & 47 deletions deepspeed/runtime/zero/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,31 @@ def _get_autocast_decorators():
autocast_custom_fwd, autocast_custom_bwd = _get_autocast_decorators()


def _is_autocast_enabled(device_type):
try:
return torch.is_autocast_enabled(device_type)
except TypeError:
legacy_getter = getattr(torch, f'is_autocast_{device_type}_enabled', None)
if legacy_getter is not None:
return legacy_getter()
return torch.is_autocast_enabled()


def _get_autocast_dtype(device_type):
try:
return torch.get_autocast_dtype(device_type)
except TypeError:
legacy_getter = getattr(torch, f'get_autocast_{device_type}_dtype', None)
if legacy_getter is not None:
return legacy_getter()
return None


class LinearFunctionForZeroStage3(torch.autograd.Function):

# Note that both forward and backward are @staticmethods
@staticmethod
@autocast_custom_fwd
# bias is an optional argument
def forward(ctx, input, weight, bias=None):

ctx.save_for_backward(input, weight, bias)
def forward(input, weight, bias=None):

if input.dim() == 2 and bias is not None:
# fused op is marginally faster
Expand All @@ -77,54 +93,45 @@ def forward(ctx, input, weight, bias=None):

return ret

@staticmethod
def setup_context(ctx, inputs, output):
    """Capture state for backward (torch.func-compatible replacement for doing ctx work in forward).

    Records the ambient autocast enabled-flag and dtype so ``backward`` can
    re-enter the same autocast state (mirroring ``@custom_bwd`` semantics),
    then saves the tensors needed for gradient computation.
    """
    device_type = get_accelerator().device_name()
    # Snapshot autocast state at forward time; backward replays it.
    ctx._dtype = _get_autocast_dtype(device_type)
    ctx._fwd_used_autocast = _is_autocast_enabled(device_type)
    # zero3_linear_wrap always passes bias, but guard against a direct
    # two-argument apply() just in case.
    input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None
    ctx.save_for_backward(input, weight, bias)

# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
    """Compute gradients for (input, weight, bias).

    The scraped span contained both the superseded ``@autocast_custom_bwd``
    body and the new explicit-autocast body concatenated; only the new
    implementation is kept here.

    Returns a triple ``(grad_input, grad_weight, grad_bias)``; entries are
    ``None`` when the corresponding input does not require grad. Trailing
    ``None``s are ignored by autograd, so this also covers the bias-less case.
    """
    # Match @custom_bwd semantics: always run backward under the same
    # autocast state as forward — including explicitly disabling autocast
    # when forward did not use it, to guard against outer autocast regions.
    device_type = get_accelerator().device_name()
    with torch.autocast(device_type=device_type, enabled=ctx._fwd_used_autocast, dtype=ctx._dtype):
        input, weight, bias = ctx.saved_tensors

        grad_input = grad_weight = grad_bias = None

        dim = grad_output.dim()
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)
        if ctx.needs_input_grad[1]:
            # For >2-D activations, flatten all leading dims before the matmul.
            if dim > 2:
                grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
                    input.reshape(-1, input.shape[-1]))
            else:
                grad_weight = grad_output.t().matmul(input)
        if bias is not None and ctx.needs_input_grad[2]:
            # Bias gradient is the sum over every dim except the last.
            if dim > 2:
                grad_bias = grad_output.sum([i for i in range(dim - 1)])
            else:
                grad_bias = grad_output.sum(0)
    return grad_input, grad_weight, grad_bias


def zero3_linear_wrap(input, weight, bias=None):
    """Functional entry point for the ZeRO-3 memory-efficient linear path.

    Always forwards ``bias`` (even when ``None``) so ``setup_context`` sees a
    fixed-arity ``inputs`` tuple. The scraped span contained the superseded
    two-branch dispatch plus the new unconditional call; only the new form is
    kept, since the old branch made ``inputs`` arity vary between calls.
    """
    return LinearFunctionForZeroStage3.apply(input, weight, bias)


class LinearModuleForZeroStage3(Module):
Expand Down
20 changes: 12 additions & 8 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,15 +404,16 @@ def _run_before_backward_function(sub_module):
class PreBackwardFunctionForModule(torch.autograd.Function):

@staticmethod
def forward(outputs):
    """Detach ``outputs`` from the current graph.

    New-style autograd.Function forward (no ``ctx`` parameter) so torch.func
    transforms can trace it; all bookkeeping lives in ``setup_context``. The
    scraped span contained both the old ``forward(ctx, outputs)`` signature and
    this one; only the new signature is kept.
    """
    return outputs.detach()

@staticmethod
def setup_context(ctx, inputs, output):
    """Record the module and pre-backward hook for use in ``backward``.

    The scraped span had trailing diff residue from the superseded forward
    (``outputs = outputs.detach(); return outputs``) fused onto the end;
    ``setup_context`` must not detach or return anything, so that residue is
    dropped.
    """
    # Capture `module` and _run_before_backward_function from the enclosing
    # closure onto ctx so backward() can reach them.
    ctx.module = module
    ctx.pre_backward_function = _run_before_backward_function
    # Count how many autograd branches will re-enter this module's backward,
    # so the pre-backward hook can fire once per branch.
    if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
        ctx.module.applied_pre_backward_ref_cnt = 0
    ctx.module.applied_pre_backward_ref_cnt += 1

@staticmethod
def backward(ctx, *args):
Expand All @@ -434,9 +435,14 @@ def _run_after_backward_function(sub_module):
class PostBackwardFunctionModule(torch.autograd.Function):

@staticmethod
def forward(output):
    """Detach ``output`` from the current graph.

    New-style autograd.Function forward (no ``ctx`` parameter) for torch.func
    compatibility; bookkeeping happens in ``setup_context``. The scraped span
    contained both the old ``forward(ctx, output)`` signature and this one;
    only the new signature is kept.
    """
    return output.detach()

@staticmethod
def setup_context(ctx, inputs, output):
(output_in, ) = inputs
ctx.module = module
if output.requires_grad:
if output_in.requires_grad:
#TODO SOME TIMES post backward does not seem to be triggered debug in detail
#Should only cause increase in memory not correctness issue
#if output.grad_fn.__class__.__name__ == 'ViewBackward':
Expand All @@ -447,8 +453,6 @@ def forward(ctx, output):
# print(f"Before Forward: {ctx.module.__class__.__name__}")
module.ds_grads_remaining += 1
ctx.post_backward_function = _run_after_backward_function
output = output.detach()
return output

@staticmethod
def backward(ctx, *args):
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-readthedocs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ py-cpuinfo
pydantic>=2.0.0
recommonmark
sphinx_rtd_theme
torch
torch>=2.0.0
tqdm
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ packaging>=20.0
psutil
py-cpuinfo
pydantic>=2.0.0
torch
torch>=2.0.0
tqdm
203 changes: 203 additions & 0 deletions tests/unit/v1/zero/test_zero_functorch_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""Regression: ZeRO-3 linear autograd.Function must work with torch.func transforms.

ZeRO Stage 3 uses ``LinearFunctionForZeroStage3`` (via ``zero3_linear_wrap``) as
the memory-efficient linear path. After ``deepspeed.initialize``, global
``torch.nn.functional.linear`` is often the built-in again, so tests call
``zero3_linear_wrap`` directly — the same ``autograd.Function`` as when the patch
is active. Legacy ``forward(ctx, ...)`` + ``ctx.save_for_backward`` in forward
raises on strict functorch builds::

RuntimeError: In order to use an autograd.Function with functorch
transforms ... it must override the setup_context staticmethod.
"""

import pytest
import torch
import torch.nn as nn

import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.linear import zero3_linear_wrap

from unit.common import DistributedTest


def _zero3_functorch_config():
    """Build a minimal ZeRO stage-3 DeepSpeed config for the functorch tests.

    Enables bf16 when the accelerator supports it, else fp16, else neither.
    """
    cfg = {
        "train_micro_batch_size_per_gpu": 1,
        "gradient_accumulation_steps": 1,
        "steps_per_print": 2147483647,
        "zero_optimization": {
            "stage": 3,
            "stage3_param_persistence_threshold": 0,
        },
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1e-3
            },
        },
    }
    accelerator = get_accelerator()
    if accelerator.is_bf16_supported():
        cfg["bf16"] = {"enabled": True}
    elif accelerator.is_fp16_supported():
        cfg["fp16"] = {"enabled": True, "initial_scale_power": 8}
    return cfg


class TestZeroFunctorchLinearRegression(DistributedTest):
    """``torch.func.grad_and_value`` over ``zero3_linear_wrap`` / LinearFunctionForZeroStage3."""

    # Single-rank distributed test; DistributedTest handles process launch.
    world_size = 1

    def test_grad_and_value_over_patched_functional_linear(self):
        # torch.func is the functorch entry point; skip on builds without it.
        if not hasattr(torch, "func"):
            pytest.skip("torch.func not available")

        model = nn.Linear(8, 8, bias=True)
        # Initialize a ZeRO-3 engine so the test runs in the same environment
        # the zero3 linear patch targets.
        engine, _, _, _ = deepspeed.initialize(
            model=model,
            config=_zero3_functorch_config(),
            model_parameters=model.parameters(),
        )

        device = engine.device
        dtype = engine.module.weight.dtype
        weight = torch.randn(8, 8, device=device, dtype=dtype, requires_grad=True)
        inp = torch.randn(2, 8, device=device, dtype=dtype, requires_grad=True)

        # Sanity check: the wrapper really routes through the custom Function.
        with torch.enable_grad():
            probe = zero3_linear_wrap(inp, weight, None)
            assert "LinearFunctionForZeroStage3" in type(probe.grad_fn).__name__

        def loss_fn(w, x):
            return zero3_linear_wrap(x, w, None).sum()

        # Without a setup_context staticmethod this raises RuntimeError
        # ("In order to use an autograd.Function with functorch transforms ...").
        grads, value = torch.func.grad_and_value(loss_fn, argnums=(0, 1))(weight, inp)
        assert torch.isfinite(value)
        assert grads[0] is not None and torch.isfinite(grads[0]).all()
        assert grads[1] is not None and torch.isfinite(grads[1]).all()


class TestZeroLinearAutocast(DistributedTest):
    """Verify autocast state is correctly propagated through forward and backward."""

    # Single-rank distributed test; DistributedTest handles process launch.
    world_size = 1

    def _run_forward_backward(self, device, use_autocast, dtype=None):
        """Run zero3_linear_wrap forward+backward, optionally inside autocast."""
        weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True)
        inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True)
        bias = torch.randn(4, device=device, dtype=torch.float32, requires_grad=True)

        if use_autocast:
            with torch.amp.autocast(device_type=device.type, dtype=dtype):
                out = zero3_linear_wrap(inp, weight, bias)
        else:
            out = zero3_linear_wrap(inp, weight, bias)

        # Note: backward always runs OUTSIDE the autocast region, so any
        # autocast effect on gradients must come from state captured at
        # forward time.
        loss = out.sum()
        loss.backward()
        return out, weight.grad, inp.grad, bias.grad

    def test_backward_without_autocast(self):
        """Backward without autocast should produce float32 gradients."""
        model = nn.Linear(4, 4)
        engine, _, _, _ = deepspeed.initialize(
            model=model,
            config=_zero3_functorch_config(),
            model_parameters=model.parameters(),
        )
        device = engine.device

        out, w_grad, i_grad, b_grad = self._run_forward_backward(device, use_autocast=False)
        # No autocast anywhere: everything stays in the leaf dtype (fp32).
        assert out.dtype == torch.float32
        assert w_grad.dtype == torch.float32
        assert i_grad.dtype == torch.float32
        assert b_grad.dtype == torch.float32

    def test_backward_with_autocast(self):
        """Backward with autocast should produce float32 gradients (autocast only affects forward)."""
        # Pick whichever reduced precision the accelerator supports.
        acc = get_accelerator()
        if acc.is_bf16_supported():
            amp_dtype = torch.bfloat16
        elif acc.is_fp16_supported():
            amp_dtype = torch.float16
        else:
            pytest.skip("No half-precision support")

        model = nn.Linear(4, 4)
        engine, _, _, _ = deepspeed.initialize(
            model=model,
            config=_zero3_functorch_config(),
            model_parameters=model.parameters(),
        )
        device = engine.device

        out, w_grad, i_grad, b_grad = self._run_forward_backward(device, use_autocast=True, dtype=amp_dtype)
        # Forward output should be in reduced precision
        assert out.dtype == amp_dtype
        # Gradients accumulate in float32 (master weights)
        assert w_grad.dtype == torch.float32
        assert i_grad.dtype == torch.float32
        assert b_grad.dtype == torch.float32

    def test_no_autocast_leak_into_backward(self):
        """When forward runs without autocast, an outer autocast during backward must not affect gradient dtype."""
        model = nn.Linear(4, 4)
        engine, _, _, _ = deepspeed.initialize(
            model=model,
            config=_zero3_functorch_config(),
            model_parameters=model.parameters(),
        )
        device = engine.device

        acc = get_accelerator()
        if acc.is_bf16_supported():
            amp_dtype = torch.bfloat16
        elif acc.is_fp16_supported():
            amp_dtype = torch.float16
        else:
            pytest.skip("No half-precision support")

        weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True)
        inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True)

        # Forward WITHOUT autocast
        out = zero3_linear_wrap(inp, weight, None)
        assert out.dtype == torch.float32

        # Backward WITH an outer autocast region -- should NOT affect gradient computation
        # because setup_context captured _fwd_used_autocast=False
        with torch.amp.autocast(device_type=device.type, dtype=amp_dtype):
            out.sum().backward()

        assert weight.grad.dtype == torch.float32
        assert inp.grad.dtype == torch.float32

    def test_setup_context_stores_autocast_attrs(self):
        """setup_context must store _fwd_used_autocast and _dtype on ctx."""
        model = nn.Linear(4, 4)
        engine, _, _, _ = deepspeed.initialize(
            model=model,
            config=_zero3_functorch_config(),
            model_parameters=model.parameters(),
        )
        device = engine.device

        weight = torch.randn(4, 4, device=device, dtype=torch.float32, requires_grad=True)
        inp = torch.randn(2, 4, device=device, dtype=torch.float32, requires_grad=True)

        # Without autocast: setup_context must record that forward did not use autocast
        out = zero3_linear_wrap(inp, weight, None)
        # grad_fn is the autograd node for the custom Function; ctx attributes
        # set in setup_context are visible on it.
        grad_fn = out.grad_fn
        assert hasattr(grad_fn, "_fwd_used_autocast")
        assert grad_fn._fwd_used_autocast is False
        assert hasattr(grad_fn, "_dtype")
        out.sum().backward()
        assert torch.isfinite(weight.grad).all()
Loading