fix: add setup_context for torch.func compatibility #7916
Open
roycho96 wants to merge 23 commits into deepspeedai:master from roycho96:fix/support-func-torch
Changes shown are from 16 of the 23 commits.

Commits (23):
33db7c4 fix: fix LinearFunctionForZeroStage3 to support torch.func transforms (roycho96)
39b1755 fix: always pass bias arg in zero3_linear_wrap to avoid setup_context… (roycho96)
6df37af fix: remove @autocast_custom_fwd from forward, move autocast state to… (roycho96)
c0b9694 fix(zero3): replace custom_bwd with explicit autocast for functorch-s… (zhangj1an)
5e83d05 fix(zero): use setup_context for offload pre/post backward Functions (zhangj1an)
7483701 Merge branch 'master' into fix/support-func-torch (zhangj1an)
a1e798d run pre-commit checks (zhangj1an)
8762d00 update unit tests to reproduce main branch error (zhangj1an)
dd037da add reproduce scripts (zhangj1an)
01ee5a6 Merge branch 'master' into fix/support-func-torch (zhangj1an)
f69c1f1 update reproduce script (zhangj1an)
e58ac18 update reproduce script to skip repeated env setup (zhangj1an)
3121a7f update reproduce script to remove duplicated code (zhangj1an)
60d20da update reproduce script to print test env (zhangj1an)
bb245b2 drop PyTorch < 2.0 support and fix autocast backward in ZeRO linear (roycho96)
04c456f change PyTorch version in README (roycho96)
703aad3 resolve conflict with master (tohtana)
8468149 Merge pull request #1 from tohtana/tohtana/pr7916-merge-master-resolve (zhangj1an)
e309a6f remove repro scripts (zhangj1an)
e425569 update unit test (zhangj1an)
39f7e3c drop support for pytorch<2.0 (zhangj1an)
39c9a73 Merge branch 'master' into fix/support-func-torch (zhangj1an)
a5aa09a Merge branch 'master' into fix/support-func-torch (roycho96)
New file (repro script, added by this PR):

```python
#!/usr/bin/env python3
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
#
# Repro: functorch over ZeRO-3 memory-efficient linear (LinearFunctionForZeroStage3).
#
# Legacy autograd.Function.forward(ctx, ...) + ctx.save_for_backward in that class
# triggers (PyTorch builds that enforce functorch custom-Function rules, e.g. 2.8+):
#
#   RuntimeError: In order to use an autograd.Function with functorch transforms
#   (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.
#
# Why we call zero3_linear_wrap() instead of torch.nn.functional.linear:
# After deepspeed.initialize(), the global ZeRO Init context has usually ended, so
# torch.nn.functional.linear is often restored to PyTorch's built-in. That means
# F.linear in a post-init script does NOT hit LinearFunctionForZeroStage3. The
# Stage-3 patch uses zero3_linear_wrap (see partition_parameters.py); it is the
# same autograd.Function, so calling it here reliably reproduces the bug on unfixed
# trees and validates the fix on fixed trees.
#
# Regression coverage: tests/unit/v1/zero/test_zero_functorch_linear.py
#
# Run from the DeepSpeed repo root (single GPU), after scripts/setup_pr7916.sh (or manually):
#   torchrun --standalone --nproc_per_node=1 scripts/repro_pr7916.py
#
# To test an unfixed DeepSpeed tree without importing another checkout by mistake,
# copy this file outside the repo (e.g. /tmp) and set PYTHONPATH to that tree:
#   cp scripts/repro_zero3_functorch_linear.py /tmp/ && cd /tmp && \
#   PYTHONPATH=/path/to/deepspeed-checkout torchrun --standalone --nproc_per_node=1 repro_zero3_functorch_linear.py
#
# Requires: PyTorch with torch.func and strict custom-Function checks (e.g. 2.8+),
# DeepSpeed ZeRO-3, CUDA (typical setup).

import torch
import torch.nn as nn

import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed.runtime.zero.linear import zero3_linear_wrap


def _assert_hits_zero3_linear(weight, inp):
    """Sanity check: we are exercising LinearFunctionForZeroStage3, not built-in linear."""
    with torch.enable_grad():
        y = zero3_linear_wrap(inp, weight, None)
    name = type(y.grad_fn).__name__
    assert "LinearFunctionForZeroStage3" in name, (
        f"Expected LinearFunctionForZeroStage3 in grad_fn, got {name!r}. "
        "Repro would not test the intended autograd.Function.")


def main():
    if not hasattr(torch, "func"):
        raise SystemExit("This repro requires torch.func (PyTorch 2.0+).")
    if not hasattr(torch.autograd.Function, "setup_context"):
        raise SystemExit("This repro requires autograd.Function.setup_context (PyTorch 2.0+).")

    deepspeed.init_distributed()
    acc = get_accelerator()
    device = acc.device_name() + ":" + str(acc.current_device())

    model = nn.Linear(8, 8, bias=True).to(device)

    config = {
        "train_micro_batch_size_per_gpu": 1,
        "steps_per_print": 2147483647,
        "zero_optimization": {
            "stage": 3,
            "stage3_param_persistence_threshold": 0,
        },
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    }
    if acc.is_bf16_supported():
        config["bf16"] = {"enabled": True}
    elif acc.is_fp16_supported():
        config["fp16"] = {"enabled": True, "initial_scale_power": 8}

    _, _, _, _ = deepspeed.initialize(
        model=model,
        config=config,
        model_parameters=model.parameters(),
    )

    weight = torch.randn(8, 8, device=device, dtype=model.weight.dtype, requires_grad=True)
    inp = torch.randn(2, 8, device=device, dtype=model.weight.dtype, requires_grad=True)

    if deepspeed.comm.get_rank() == 0:
        _assert_hits_zero3_linear(weight, inp)

    def loss_fn(w, x):
        # Same op as ZeRO-3's F.linear replacement when the patch is active.
        return zero3_linear_wrap(x, w, None).sum()

    torch.func.grad_and_value(loss_fn, argnums=(0, 1))(weight, inp)
    if deepspeed.comm.get_rank() == 0:
        print("repro: grad_and_value over zero3_linear_wrap (LinearFunctionForZeroStage3) OK.")


if __name__ == "__main__":
    main()
```
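The pattern the PR adopts can be shown in isolation: under torch.func transforms, a custom autograd.Function must define `forward` without a `ctx` argument and save state in a separate `setup_context` staticmethod. Below is a minimal self-contained sketch of that pattern, not DeepSpeed's actual `LinearFunctionForZeroStage3`; the names `MyLinear` and `loss_fn` are illustrative.

```python
import torch

class MyLinear(torch.autograd.Function):
    # New-style Function: forward takes no ctx. State needed by backward is
    # saved in setup_context, which is what grad/vmap/jacrev etc. require.
    @staticmethod
    def forward(input, weight):
        return input @ weight.t()

    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight = inputs
        ctx.save_for_backward(input, weight)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output @ weight        # d(loss)/d(input)
        grad_weight = grad_output.t() @ input    # d(loss)/d(weight)
        return grad_input, grad_weight

def loss_fn(w, x):
    return MyLinear.apply(x, w).sum()

w = torch.randn(4, 3)
x = torch.randn(2, 3)
# With the legacy forward(ctx, ...) style, this call raises the RuntimeError
# quoted in the repro script; with setup_context it works.
grads = torch.func.grad(loss_fn, argnums=(0, 1))(w, x)
```

For `loss = sum(x @ w.T)`, the gradient with respect to `w` is the column sums of `x` broadcast over rows, and vice versa for `x`, which makes the backward formulas above easy to check by hand.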