Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 30 additions & 19 deletions op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,14 +563,8 @@ def jit_load(self, verbose=True):
sources = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.sources()]
extra_include_paths = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.include_paths()]

# Torch will try and apply whatever CCs are in the arch list at compile time,
# we have already set the intended targets ourselves we know that will be
# needed at runtime. This prevents CC collisions such as multiple __half
# implementations. Stash arch list to reset after build.
torch_arch_list = None
if "TORCH_CUDA_ARCH_LIST" in os.environ:
torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
# Stash TORCH_CUDA_ARCH_LIST to restore after build.
torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")

nvcc_args = self.strip_empty_entries(self.nvcc_args())
cxx_args = self.strip_empty_entries(self.cxx_args())
Expand Down Expand Up @@ -603,9 +597,11 @@ def jit_load(self, verbose=True):
if verbose:
print(f"Time to load {self.name} op: {build_duration} seconds")

# Reset arch list so we are not silently removing it for other possible use cases
if torch_arch_list:
# Restore TORCH_CUDA_ARCH_LIST to its original state.
if torch_arch_list is not None:
os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list
elif "TORCH_CUDA_ARCH_LIST" in os.environ:
del os.environ["TORCH_CUDA_ARCH_LIST"]

__class__._loaded_ops[self.name] = op_module

Expand All @@ -618,18 +614,22 @@ def compute_capability_args(self, cross_compile_archs=None):
"""
Returns nvcc compute capability compile flags.

1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
2. If neither is set default compute capabilities will be used
3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX
1. Under ``jit_mode`` the visible-card architectures are detected,
``TORCH_CUDA_ARCH_LIST`` is set accordingly, and an **empty list**
is returned so that PyTorch generates the ``-gencode`` flags
itself (avoiding duplicates). See
https://github.com/deepspeedai/DeepSpeed/issues/7972
2. ``TORCH_CUDA_ARCH_LIST`` takes priority over ``cross_compile_archs``.
3. If neither is set, default compute capabilities will be used.

Format:

- `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
- ``TORCH_CUDA_ARCH_LIST`` may use ; or whitespace separators. Examples:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- ``TORCH_CUDA_ARCH_LIST`` may use ; or whitespace separators. Examples:
- `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:


TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ...
TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ...

- `cross_compile_archs` uses ; separator.
- ``cross_compile_archs`` uses ; separator.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- ``cross_compile_archs`` uses ; separator.
- `cross_compile_archs` uses ; separator.


"""
ccs = []
Expand Down Expand Up @@ -662,17 +662,28 @@ def compute_capability_args(self, cross_compile_archs=None):
raise RuntimeError(
f"Unable to load {self.name} op due to no compute capabilities remaining after filtering")

args = []
self.enable_bf16 = True
for cc in ccs:
if int(cc[0]) <= 7:
self.enable_bf16 = False
Comment on lines 665 to +668
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using any?

Suggested change
self.enable_bf16 = True
for cc in ccs:
if int(cc[0]) <= 7:
self.enable_bf16 = False
self.enable_bf16 = not any(int(cc[0]) <= 7 for cc in ccs)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is better this way.


# Keep TORCH_CUDA_ARCH_LIST in sync with the filtered arch list so
# PyTorch does not re-add archs that filter_ccs() removed.
arch_list = ";".join(f"{cc[0]}.{cc[1]}" for cc in ccs)
os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list

if self.jit_mode:
# Let PyTorch generate -gencode flags from the env var.
return []

# Non-JIT: return explicit flags per builder for extra_compile_args.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this cause duplicate, possibly wrong flags again? If TORCH_CUDA_ARCH_LIST is used for all extensions and one allows a CC another doesn't you'll still get it for both

I guess for non-JIT mode you'd want to set TORCH_CUDA_ARCH_LIST to a single CC out of the intersection of the CCs allowed by all extensions if that is the case to at least not add the wrong one if you can't avoid the duplication.

Or if more high level is possible:

  • Determine CCs to add (from GPU arch and/or TORCH_CUDA_ARCH_LIST)
  • Filter them per extension
  • if the intersection is equal to the full list set TORCH_CUDA_ARCH_LIST to all of them, else to one of them or error if empty
  • add extension specific flags for each non-filtered arch of that extension that is not in TORCH_CUDA_ARCH_LIST, this might be none

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I totally agree with your intersection-based approach, but implementing it would likely require changes to setup.py (splitting the builder loop into two passes), and I'm concerned the scope of that refactor could be too large for this PR.

args = []
for cc in ccs:
num = cc[0] + cc[1].split('+')[0]
args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
if cc[1].endswith('+PTX'):
args.append(f'-gencode=arch=compute_{num},code=compute_{num}')

if int(cc[0]) <= 7:
self.enable_bf16 = False

return args

def filter_ccs(self, ccs: List[str]):
Expand Down