-
Notifications
You must be signed in to change notification settings - Fork 4.8k
fix(op_builder): avoid duplicate/wrong -gencode flags #7974
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
e21277e
392d038
b5cdc89
1f32f02
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -563,14 +563,8 @@ def jit_load(self, verbose=True): | |||||||||||
| sources = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.sources()] | ||||||||||||
| extra_include_paths = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.include_paths()] | ||||||||||||
|
|
||||||||||||
| # Torch will try and apply whatever CCs are in the arch list at compile time, | ||||||||||||
| # we have already set the intended targets ourselves we know that will be | ||||||||||||
| # needed at runtime. This prevents CC collisions such as multiple __half | ||||||||||||
| # implementations. Stash arch list to reset after build. | ||||||||||||
| torch_arch_list = None | ||||||||||||
| if "TORCH_CUDA_ARCH_LIST" in os.environ: | ||||||||||||
| torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") | ||||||||||||
| os.environ["TORCH_CUDA_ARCH_LIST"] = "" | ||||||||||||
| # Stash TORCH_CUDA_ARCH_LIST to restore after build. | ||||||||||||
| torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST") | ||||||||||||
|
|
||||||||||||
| nvcc_args = self.strip_empty_entries(self.nvcc_args()) | ||||||||||||
| cxx_args = self.strip_empty_entries(self.cxx_args()) | ||||||||||||
|
|
@@ -603,9 +597,11 @@ def jit_load(self, verbose=True): | |||||||||||
| if verbose: | ||||||||||||
| print(f"Time to load {self.name} op: {build_duration} seconds") | ||||||||||||
|
|
||||||||||||
| # Reset arch list so we are not silently removing it for other possible use cases | ||||||||||||
| if torch_arch_list: | ||||||||||||
| # Restore TORCH_CUDA_ARCH_LIST to its original state. | ||||||||||||
| if torch_arch_list is not None: | ||||||||||||
| os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list | ||||||||||||
| elif "TORCH_CUDA_ARCH_LIST" in os.environ: | ||||||||||||
| del os.environ["TORCH_CUDA_ARCH_LIST"] | ||||||||||||
|
|
||||||||||||
| __class__._loaded_ops[self.name] = op_module | ||||||||||||
|
|
||||||||||||
|
|
@@ -618,18 +614,22 @@ def compute_capability_args(self, cross_compile_archs=None): | |||||||||||
| """ | ||||||||||||
| Returns nvcc compute capability compile flags. | ||||||||||||
|
|
||||||||||||
| 1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`. | ||||||||||||
| 2. If neither is set default compute capabilities will be used | ||||||||||||
| 3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX | ||||||||||||
| 1. Under ``jit_mode`` the visible-card architectures are detected, | ||||||||||||
| ``TORCH_CUDA_ARCH_LIST`` is set accordingly, and an **empty list** | ||||||||||||
| is returned so that PyTorch generates the ``-gencode`` flags | ||||||||||||
| itself (avoiding duplicates). See | ||||||||||||
| https://github.com/deepspeedai/DeepSpeed/issues/7972 | ||||||||||||
| 2. ``TORCH_CUDA_ARCH_LIST`` takes priority over ``cross_compile_archs``. | ||||||||||||
| 3. If neither is set default compute capabilities will be used. | ||||||||||||
|
|
||||||||||||
| Format: | ||||||||||||
|
|
||||||||||||
| - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples: | ||||||||||||
| - ``TORCH_CUDA_ARCH_LIST`` may use ; or whitespace separators. Examples: | ||||||||||||
|
|
||||||||||||
| TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ... | ||||||||||||
| TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ... | ||||||||||||
|
|
||||||||||||
| - `cross_compile_archs` uses ; separator. | ||||||||||||
| - ``cross_compile_archs`` uses ; separator. | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
|
||||||||||||
| """ | ||||||||||||
| ccs = [] | ||||||||||||
|
|
@@ -662,17 +662,28 @@ def compute_capability_args(self, cross_compile_archs=None): | |||||||||||
| raise RuntimeError( | ||||||||||||
| f"Unable to load {self.name} op due to no compute capabilities remaining after filtering") | ||||||||||||
|
|
||||||||||||
| args = [] | ||||||||||||
| self.enable_bf16 = True | ||||||||||||
| for cc in ccs: | ||||||||||||
| if int(cc[0]) <= 7: | ||||||||||||
| self.enable_bf16 = False | ||||||||||||
|
Comment on lines 665 to 668
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about using
Suggested change
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it is better this way. |
||||||||||||
|
|
||||||||||||
| # Keep TORCH_CUDA_ARCH_LIST in sync with the filtered arch list so | ||||||||||||
| # PyTorch does not re-add archs that filter_ccs() removed. | ||||||||||||
| arch_list = ";".join(f"{cc[0]}.{cc[1]}" for cc in ccs) | ||||||||||||
| os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list | ||||||||||||
|
|
||||||||||||
| if self.jit_mode: | ||||||||||||
| # Let PyTorch generate -gencode flags from the env var. | ||||||||||||
| return [] | ||||||||||||
|
|
||||||||||||
| # Non-JIT: return explicit flags per builder for extra_compile_args. | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Won't this cause duplicate, possibly wrong flags again? If I guess for non-JIT mode you'd want to set TORCH_CUDA_ARCH_LIST to a single CC out of the intersection of the CCs allowed by all extensions if that is the case to at least not add the wrong one if you can't avoid the duplication. Or if more high level is possible:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I totally agree with your intersection-based approach, but implementing it would likely require changes to setup.py (splitting the builder loop into two passes), and I'm concerned the scope of that refactor could be too large for this PR. |
||||||||||||
| args = [] | ||||||||||||
| for cc in ccs: | ||||||||||||
| num = cc[0] + cc[1].split('+')[0] | ||||||||||||
| args.append(f'-gencode=arch=compute_{num},code=sm_{num}') | ||||||||||||
| if cc[1].endswith('+PTX'): | ||||||||||||
| args.append(f'-gencode=arch=compute_{num},code=compute_{num}') | ||||||||||||
|
|
||||||||||||
| if int(cc[0]) <= 7: | ||||||||||||
| self.enable_bf16 = False | ||||||||||||
|
|
||||||||||||
| return args | ||||||||||||
|
|
||||||||||||
| def filter_ccs(self, ccs: List[str]): | ||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.