diff --git a/op_builder/builder.py b/op_builder/builder.py
index 308f1822a58f..b2e42d4bd3a8 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -563,14 +563,8 @@ def jit_load(self, verbose=True):
         sources = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.sources()]
         extra_include_paths = [os.path.abspath(self.deepspeed_src_path(path)) for path in self.include_paths()]
 
-        # Torch will try and apply whatever CCs are in the arch list at compile time,
-        # we have already set the intended targets ourselves we know that will be
-        # needed at runtime. This prevents CC collisions such as multiple __half
-        # implementations. Stash arch list to reset after build.
-        torch_arch_list = None
-        if "TORCH_CUDA_ARCH_LIST" in os.environ:
-            torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
-            os.environ["TORCH_CUDA_ARCH_LIST"] = ""
+        # Stash TORCH_CUDA_ARCH_LIST to restore after build.
+        torch_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST")
 
         nvcc_args = self.strip_empty_entries(self.nvcc_args())
         cxx_args = self.strip_empty_entries(self.cxx_args())
@@ -603,9 +597,11 @@ def jit_load(self, verbose=True):
         if verbose:
             print(f"Time to load {self.name} op: {build_duration} seconds")
 
-        # Reset arch list so we are not silently removing it for other possible use cases
-        if torch_arch_list:
+        # Restore TORCH_CUDA_ARCH_LIST to its original state.
+        if torch_arch_list is not None:
             os.environ["TORCH_CUDA_ARCH_LIST"] = torch_arch_list
+        elif "TORCH_CUDA_ARCH_LIST" in os.environ:
+            del os.environ["TORCH_CUDA_ARCH_LIST"]
 
         __class__._loaded_ops[self.name] = op_module
 
@@ -618,18 +614,22 @@ def compute_capability_args(self, cross_compile_archs=None):
         """
         Returns nvcc compute capability compile flags.
 
-        1. `TORCH_CUDA_ARCH_LIST` takes priority over `cross_compile_archs`.
-        2. If neither is set default compute capabilities will be used
-        3. Under `jit_mode` compute capabilities of all visible cards will be used plus PTX
+        1. Under ``jit_mode`` the visible-card architectures are detected,
+           ``TORCH_CUDA_ARCH_LIST`` is set accordingly, and an **empty list**
+           is returned so that PyTorch generates the ``-gencode`` flags
+           itself (avoiding duplicates). See
+           https://github.com/deepspeedai/DeepSpeed/issues/7972
+        2. ``TORCH_CUDA_ARCH_LIST`` takes priority over ``cross_compile_archs``.
+        3. If neither is set default compute capabilities will be used.
 
         Format:
 
-        - `TORCH_CUDA_ARCH_LIST` may use ; or whitespace separators. Examples:
+        - ``TORCH_CUDA_ARCH_LIST`` may use ; or whitespace separators. Examples:
 
         TORCH_CUDA_ARCH_LIST="6.1;7.5;8.6;9.0;10.0" pip install ...
         TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0+PTX" pip install ...
 
-        - `cross_compile_archs` uses ; separator.
+        - ``cross_compile_archs`` uses ; separator.
         """
         ccs = []
@@ -662,17 +662,28 @@ def compute_capability_args(self, cross_compile_archs=None):
 
            raise RuntimeError(
                f"Unable to load {self.name} op due to no compute capabilities remaining after filtering")
 
-        args = []
         self.enable_bf16 = True
+        for cc in ccs:
+            if int(cc[0]) <= 7:
+                self.enable_bf16 = False
+
+        # Keep TORCH_CUDA_ARCH_LIST in sync with the filtered arch list so
+        # PyTorch does not re-add archs that filter_ccs() removed.
+        arch_list = ";".join(f"{cc[0]}.{cc[1]}" for cc in ccs)
+        os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list
+
+        if self.jit_mode:
+            # Let PyTorch generate -gencode flags from the env var.
+            return []
+
+        # Non-JIT: return explicit flags per builder for extra_compile_args.
+        args = []
         for cc in ccs:
             num = cc[0] + cc[1].split('+')[0]
             args.append(f'-gencode=arch=compute_{num},code=sm_{num}')
             if cc[1].endswith('+PTX'):
                 args.append(f'-gencode=arch=compute_{num},code=compute_{num}')
-            if int(cc[0]) <= 7:
-                self.enable_bf16 = False
-
         return args
 
     def filter_ccs(self, ccs: List[str]):