Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
33db7c4
fix: fix LinearFunctionForZeroStage3 to support torch.func transforms
roycho96 Mar 21, 2026
39b1755
fix: always pass bias arg in zero3_linear_wrap to avoid setup_context…
roycho96 Mar 21, 2026
6df37af
fix: remove @autocast_custom_fwd from forward, move autocast state to…
roycho96 Mar 22, 2026
c0b9694
fix(zero3): replace custom_bwd with explicit autocast for functorch-s…
zhangj1an Mar 22, 2026
5e83d05
fix(zero): use setup_context for offload pre/post backward Functions
zhangj1an Mar 22, 2026
7483701
Merge branch 'master' into fix/support-func-torch
zhangj1an Mar 24, 2026
a1e798d
run pre-commit checks
zhangj1an Mar 25, 2026
8762d00
update unit tests to reproduce main branch error
zhangj1an Mar 25, 2026
dd037da
add reproduce scripts
zhangj1an Mar 25, 2026
01ee5a6
Merge branch 'master' into fix/support-func-torch
zhangj1an Mar 25, 2026
f69c1f1
update reproduce script
zhangj1an Mar 25, 2026
e58ac18
update reproduce script to skip repeated env setup
zhangj1an Mar 25, 2026
3121a7f
update reproduce script to remove duplicated code
zhangj1an Mar 25, 2026
60d20da
update reproduce script to print test env
zhangj1an Mar 25, 2026
bb245b2
drop PyTorch < 2.0 support and fix autocast backward in ZeRO linear
roycho96 Mar 29, 2026
04c456f
change PyTorch version in README
roycho96 Mar 29, 2026
703aad3
resolve conflict with master
tohtana Mar 29, 2026
8468149
Merge pull request #1 from tohtana/tohtana/pr7916-merge-master-resolve
zhangj1an Mar 30, 2026
e309a6f
remove repro scripts
zhangj1an Mar 30, 2026
e425569
update unit test
zhangj1an Mar 30, 2026
39f7e3c
drop support for pytorch<2.0
zhangj1an Mar 30, 2026
39c9a73
Merge branch 'master' into fix/support-func-torch
zhangj1an Mar 31, 2026
a5aa09a
Merge branch 'master' into fix/support-func-torch
roycho96 Apr 4, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ dynamically link them at runtime.

## Requirements
* [PyTorch](https://pytorch.org/) must be installed _before_ installing DeepSpeed.
* For full feature support we recommend a version of PyTorch that is >= 1.9 and ideally the latest PyTorch stable release.
* For full feature support we recommend a version of PyTorch that is >= 2.0 and ideally the latest PyTorch stable release.
* A CUDA or ROCm compiler such as [nvcc](https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/#introduction) or [hipcc](https://github.com/ROCm-Developer-Tools/HIPCC) used to compile C++/CUDA/HIP extensions.
* Specific GPUs we develop and test against are listed below; this doesn't mean your GPU will not work if it doesn't fall into this category — it's just that DeepSpeed is most well tested on the following:
* NVIDIA: Pascal, Volta, Ampere, and Hopper architectures
Expand Down
86 changes: 34 additions & 52 deletions deepspeed/runtime/zero/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#when implemented outside of torch.autograd.Function

import math
import functools

import torch
from torch import Tensor
Expand All @@ -32,19 +31,11 @@ def print_rank_0(message, debug=False, force=False):
print(message)


autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name())
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name())


class LinearFunctionForZeroStage3(torch.autograd.Function):

# Note that both forward and backward are @staticmethods
@staticmethod
@autocast_custom_fwd
# bias is an optional argument
def forward(ctx, input, weight, bias=None):

ctx.save_for_backward(input, weight, bias)
def forward(input, weight, bias=None):

if input.dim() == 2 and bias is not None:
# fused op is marginally faster
Expand All @@ -57,54 +48,45 @@ def forward(ctx, input, weight, bias=None):

return ret

@staticmethod
def setup_context(ctx, inputs, output):
    """Record autocast state and save tensors for backward.

    Using the functorch-compatible `forward(no ctx)` + `setup_context`
    split; the autocast snapshot lets backward replay the same autocast
    state that was active during forward (custom_fwd/custom_bwd semantics).
    """
    device_type = get_accelerator().device_name()
    # NOTE(review): torch.get_autocast_dtype and the device_type-arg form of
    # torch.is_autocast_enabled only exist on newer PyTorch (>= 2.4) — confirm
    # against the minimum version this branch claims to support (2.0).
    ctx._dtype = torch.get_autocast_dtype(device_type)
    ctx._fwd_used_autocast = torch.is_autocast_enabled(device_type)
    # bias is optional, so `inputs` may have 2 or 3 entries.
    input, weight, bias = inputs[0], inputs[1], inputs[2] if len(inputs) > 2 else None
    ctx.save_for_backward(input, weight, bias)

# This function has only a single output, so it gets only one gradient
@staticmethod
def backward(ctx, grad_output):
    """Compute gradients w.r.t. (input, weight, bias).

    Match @custom_bwd semantics: always run backward under the same
    autocast state as forward — including explicitly disabling autocast
    when forward did not use it, to guard against outer autocast regions.

    Returns a 3-tuple (grad_input, grad_weight, grad_bias); entries are
    None when the corresponding input does not require grad.
    """
    device_type = get_accelerator().device_name()
    with torch.amp.autocast(device_type=device_type, enabled=ctx._fwd_used_autocast, dtype=ctx._dtype):
        input, weight, bias = ctx.saved_tensors

        grad_input = grad_weight = grad_bias = None

        dim = grad_output.dim()
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)
        if ctx.needs_input_grad[1]:
            # Collapse leading batch dims so the matmul yields a 2-D weight grad.
            if dim > 2:
                grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().matmul(
                    input.reshape(-1, input.shape[-1]))
            else:
                grad_weight = grad_output.t().matmul(input)
        if bias is not None and ctx.needs_input_grad[2]:
            # Bias grad sums grad_output over every dim except the last.
            if dim > 2:
                grad_bias = grad_output.sum([i for i in range(dim - 1)])
            else:
                grad_bias = grad_output.sum(0)
    return grad_input, grad_weight, grad_bias


def zero3_linear_wrap(input, weight, bias=None):
    """ZeRO Stage-3 replacement for ``torch.nn.functional.linear``.

    Always passes ``bias`` (even when None) so that
    ``LinearFunctionForZeroStage3.setup_context`` receives a fixed
    3-element ``inputs`` tuple regardless of the call site; the previous
    two-branch form left an unreachable duplicate return and made the
    inputs tuple length vary.
    """
    return LinearFunctionForZeroStage3.apply(input, weight, bias)


class LinearModuleForZeroStage3(Module):
Expand Down
20 changes: 12 additions & 8 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,15 +404,16 @@ def _run_before_backward_function(sub_module):
class PreBackwardFunctionForModule(torch.autograd.Function):

@staticmethod
def forward(ctx, outputs):
# Capture `module` and _run_before_backward_function
def forward(outputs):
return outputs.detach()

@staticmethod
def setup_context(ctx, inputs, output):
ctx.module = module
ctx.pre_backward_function = _run_before_backward_function
if not hasattr(ctx.module, "applied_pre_backward_ref_cnt"):
ctx.module.applied_pre_backward_ref_cnt = 0
ctx.module.applied_pre_backward_ref_cnt += 1
outputs = outputs.detach()
return outputs

@staticmethod
def backward(ctx, *args):
Expand All @@ -434,9 +435,14 @@ def _run_after_backward_function(sub_module):
class PostBackwardFunctionModule(torch.autograd.Function):

@staticmethod
def forward(ctx, output):
def forward(output):
return output.detach()

@staticmethod
def setup_context(ctx, inputs, output):
(output_in, ) = inputs
ctx.module = module
if output.requires_grad:
if output_in.requires_grad:
#TODO SOME TIMES post backward does not seem to be triggered debug in detail
#Should only cause increase in memory not correctness issue
#if output.grad_fn.__class__.__name__ == 'ViewBackward':
Expand All @@ -447,8 +453,6 @@ def forward(ctx, output):
# print(f"Before Forward: {ctx.module.__class__.__name__}")
module.ds_grads_remaining += 1
ctx.post_backward_function = _run_after_backward_function
output = output.detach()
return output

@staticmethod
def backward(ctx, *args):
Expand Down
100 changes: 100 additions & 0 deletions scripts/repro_pr7916.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
#
# Repro: functorch over ZeRO-3 memory-efficient linear (LinearFunctionForZeroStage3).
#
# Legacy autograd.Function.forward(ctx, ...) + ctx.save_for_backward in that class
# triggers (PyTorch builds that enforce functorch custom-Function rules, e.g. 2.8+):
#
# RuntimeError: In order to use an autograd.Function with functorch transforms
# (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.
#
# Why we call zero3_linear_wrap() instead of torch.nn.functional.linear:
# After deepspeed.initialize(), the global ZeRO Init context has usually ended, so
# torch.nn.functional.linear is often restored to PyTorch's built-in. That means
# F.linear in a post-init script does NOT hit LinearFunctionForZeroStage3. The
# Stage-3 patch uses zero3_linear_wrap (see partition_parameters.py); it is the
# same autograd.Function — calling it here reliably reproduces the bug on unfixed
# trees and validates the fix on fixed trees.
#
# Regression coverage: tests/unit/v1/zero/test_zero_functorch_linear.py
#
# Run from the DeepSpeed repo root (single GPU), after scripts/setup_pr7916.sh (or manually):
# torchrun --standalone --nproc_per_node=1 scripts/repro_pr7916.py
#
# To test an unfixed DeepSpeed tree without importing another checkout by mistake,
# copy this file outside the repo (e.g. /tmp) and set PYTHONPATH to that tree:
#   cp scripts/repro_pr7916.py /tmp/ && cd /tmp && \
#   PYTHONPATH=/path/to/deepspeed-checkout torchrun --standalone --nproc_per_node=1 repro_pr7916.py
#
# Requires: PyTorch with torch.func and strict custom-Function checks (e.g. 2.8+),
# DeepSpeed ZeRO-3, CUDA (typical setup).

import torch
import torch.nn as nn

import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.linear import zero3_linear_wrap


def _assert_hits_zero3_linear(weight, inp):
    """Sanity check: we are exercising LinearFunctionForZeroStage3, not built-in linear."""
    # Run under grad mode so the output carries a grad_fn we can inspect.
    with torch.enable_grad():
        out = zero3_linear_wrap(inp, weight, None)
    grad_fn_name = type(out.grad_fn).__name__
    assert "LinearFunctionForZeroStage3" in grad_fn_name, (
        f"Expected LinearFunctionForZeroStage3 in grad_fn, got {grad_fn_name!r}. "
        "Repro would not test the intended autograd.Function.")


def main():
    """Reproduce torch.func.grad_and_value over the ZeRO-3 linear Function."""
    # Fail fast on PyTorch builds that lack the APIs this repro exercises.
    if not hasattr(torch, "func"):
        raise SystemExit("This repro requires torch.func (PyTorch 2.0+).")
    if not hasattr(torch.autograd.Function, "setup_context"):
        raise SystemExit("This repro requires autograd.Function.setup_context (PyTorch 2.0+).")

    deepspeed.init_distributed()
    accelerator = get_accelerator()
    device = f"{accelerator.device_name()}:{accelerator.current_device()}"

    model = nn.Linear(8, 8, bias=True).to(device)

    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 2147483647,
        "zero_optimization": {
            "stage": 3,
            "stage3_param_persistence_threshold": 0,
        },
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    }
    # Prefer bf16 when supported, then fp16; otherwise stay in fp32.
    if accelerator.is_bf16_supported():
        ds_config["bf16"] = {"enabled": True}
    elif accelerator.is_fp16_supported():
        ds_config["fp16"] = {"enabled": True, "initial_scale_power": 8}

    # The engine itself is not needed; initialize() activates the Stage-3 patch.
    deepspeed.initialize(
        model=model,
        config=ds_config,
        model_parameters=model.parameters(),
    )

    param_dtype = model.weight.dtype
    weight = torch.randn(8, 8, device=device, dtype=param_dtype, requires_grad=True)
    inp = torch.randn(2, 8, device=device, dtype=param_dtype, requires_grad=True)

    rank = deepspeed.comm.get_rank()
    if rank == 0:
        _assert_hits_zero3_linear(weight, inp)

    def loss_fn(w, x):
        # Same op as ZeRO-3's F.linear replacement when the patch is active.
        return zero3_linear_wrap(x, w, None).sum()

    torch.func.grad_and_value(loss_fn, argnums=(0, 1))(weight, inp)
    if rank == 0:
        print("repro: grad_and_value over zero3_linear_wrap (LinearFunctionForZeroStage3) OK.")


# Entry point when launched directly (e.g. via torchrun — see header for the command).
if __name__ == "__main__":
    main()
Loading
Loading