diff --git a/CMakeLists.txt b/CMakeLists.txt index 49154e0b0..3f9bf8e07 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -206,6 +206,7 @@ if(MSCCLPP_USE_CUDA) else() set(GPU_LIBRARIES CUDA::cudart CUDA::cuda_driver) endif() + list(APPEND GPU_LIBRARIES CUDA::nvml) else() set(CMAKE_HIP_STANDARD 17) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Wall -Wextra") diff --git a/examples/customized-collective-algorithm/customized_allgather.cu b/examples/customized-collective-algorithm/customized_allgather.cu index 02df36854..13802f80d 100644 --- a/examples/customized-collective-algorithm/customized_allgather.cu +++ b/examples/customized-collective-algorithm/customized_allgather.cu @@ -79,7 +79,7 @@ __global__ void __launch_bounds__(1024) struct Context { int rank; - int workSize; + int worldSize; int nRanksPerNode; std::vector registeredMemories; @@ -140,7 +140,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder { size_t inputSize, cudaStream_t stream) { auto algoCtx = std::static_pointer_cast(ctx); int rank = algoCtx->rank; - int worldSize = algoCtx->workSize; + int worldSize = algoCtx->worldSize; int nThreadsPerBlock = (worldSize - 1) * WARP_SIZE; allgather<<<1, nThreadsPerBlock, 0, stream>>>(algoCtx->portChannelDeviceHandles.get(), rank, inputSize); @@ -154,16 +154,16 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder { void* output, size_t inputSize, mscclpp::DataType dtype) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); // register memories mscclpp::RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputSize, mscclpp::Transport::CudaIpc); mscclpp::RegisteredMemory outputBufRegMem = - comm->registerMemory(output, inputSize * ctx->workSize, mscclpp::Transport::CudaIpc); + comm->registerMemory(output, inputSize * ctx->worldSize, mscclpp::Transport::CudaIpc); std::vector> remoteRegMemories; - for (int i = 0; i < ctx->workSize; i++) { + for (int i = 0; i < ctx->worldSize; i++) { if (i == ctx->rank) continue; comm->sendMemory(outputBufRegMem, i, 0); remoteRegMemories.push_back(comm->recvMemory(i, 0)); diff --git a/examples/torch-integration/customized_allgather.cu b/examples/torch-integration/customized_allgather.cu index 907b3adab..5ba2935fc 100644 --- a/examples/torch-integration/customized_allgather.cu +++ b/examples/torch-integration/customized_allgather.cu @@ -47,7 +47,7 @@ __global__ void __launch_bounds__(1024) struct Context { int rank; - int workSize; + int worldSize; int nRanksPerNode; std::vector registeredMemories; @@ -108,7 +108,7 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder { cudaStream_t stream) { auto algoCtx = std::static_pointer_cast(ctx); int rank = algoCtx->rank; - int worldSize = algoCtx->workSize; + int worldSize = algoCtx->worldSize; int nThreadsPerBlock = (worldSize - 1) * WARP_SIZE; allgather<<<1, nThreadsPerBlock, 0, stream>>>(algoCtx->portChannelDeviceHandles.get(), rank, inputBytes); @@ -122,16 +122,16 @@ class AllgatherAlgoBuilder : public mscclpp::AlgorithmBuilder { void* output, size_t inputBytes, mscclpp::DataType dtype) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); + ctx->worldSize = comm->bootstrap()->getNranks(); ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); // register memories mscclpp::RegisteredMemory inputBufRegMem = comm->registerMemory((void*)input, inputBytes, mscclpp::Transport::CudaIpc); mscclpp::RegisteredMemory outputBufRegMem = - comm->registerMemory(output, inputBytes * ctx->workSize, mscclpp::Transport::CudaIpc); + comm->registerMemory(output, inputBytes * ctx->worldSize, mscclpp::Transport::CudaIpc); std::vector> remoteRegMemories; - for (int i = 0; i < ctx->workSize; i++) { + for (int i = 0; i < ctx->worldSize; i++) { if (i == ctx->rank) continue; comm->sendMemory(outputBufRegMem, i, 0); remoteRegMemories.push_back(comm->recvMemory(i, 0)); diff --git a/examples/torch-integration/customized_comm_with_tuning.py b/examples/torch-integration/customized_comm_with_tuning.py index b96087c2e..cf475cdfc 100644 --- a/examples/torch-integration/customized_comm_with_tuning.py +++ b/examples/torch-integration/customized_comm_with_tuning.py @@ -1,12 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# torchrun --nnodes=1 --nproc_per_node=8 examples/torch-integration/customized_comm_with_tuning.py +# mpirun -np 8 python3 examples/torch-integration/customized_comm_with_tuning.py +# mpirun -np 16 --hostfile python3 examples/torch-integration/customized_comm_with_tuning.py import os -import ipaddress -import netifaces as ni +from mpi4py import MPI + import torch import mscclpp import mscclpp.ext @@ -35,17 +36,6 @@ def _load_algorithms(scratch: torch.Tensor, rank: int): ) -def _interfaces_for_ip(ip: str): - target = ipaddress.ip_address(ip) - for iface in ni.interfaces(): - addrs = ni.ifaddresses(iface) - if ni.AF_INET in addrs: - for link in addrs[ni.AF_INET]: - if "addr" in link and ipaddress.ip_address(link["addr"]) == target: - return iface - return None - - def _to_mscclpp_op(op) -> mscclpp.ReduceOp: if op == torch.distributed.ReduceOp.SUM: return mscclpp.ReduceOp.SUM @@ -68,8 +58,8 @@ class CustomizedComm: """Exposes all_reduce, all_gather, barrier with lazy per-size tuning.""" _TUNE_N_WARMUP = 5 - _TUNE_N_GRAPH_LAUNCHES = 10 - _TUNE_N_OPS_PER_GRAPH = 100 + _TUNE_N_GRAPH_LAUNCHES = 5 + _TUNE_N_OPS_PER_GRAPH = 20 _CANDIDATE_NBLOCKS = [4, 8, 16, 24, 32, 48, 56, 64, 128] _CANDIDATE_NTHREADS = [512, 768, 1024] _NBLOCKS_LIMIT = { @@ -79,11 +69,33 @@ class CustomizedComm: "default_allreduce_fullmesh": 64, "default_allgather_fullmesh2": 32, } + # (algo_name, min_size, max_size, predicate) + # Boundaries are inclusive on both ends. max_size=None means unbounded. + # predicate=None means always applicable; otherwise a callable taking `self`. + _AR_CANDIDATES_MNNVL = [ + ("default_allreduce_allpair_packet", 0, 128 << 10, None), + ("default_allreduce_nvls_packet", 0, 64 << 10, lambda c: c._nvls), + ("default_allreduce_packet", 128 << 10, 512 << 10, None), + ("default_allreduce_nvls_zero_copy", 512 << 10, None, lambda c: c._nvls and c.symmetric_memory), + ("default_allreduce_rsag_zero_copy", 512 << 10, None, None), + ("default_allreduce_rsag", 512 << 10, None, None), + ] + _AR_CANDIDATES_SINGLE = [ + ("default_allreduce_packet", 0, 4 << 20, None), + ("default_allreduce_allpair_packet", 0, 512 << 10, None), + ("default_allreduce_nvls_packet", 0, 512 << 10, lambda c: c._nvls), + ("default_allreduce_rsag_zero_copy", 512 << 10, None, None), + ("default_allreduce_nvls_zero_copy", 512 << 10, None, lambda c: c._nvls and c.symmetric_memory), + ("default_allreduce_fullmesh", 0, None, lambda c: torch.version.hip is not None), + ] def __init__(self, comm: mscclpp.CommGroup, symmetric_memory: bool = False): self.comm = comm self.rank = comm.my_rank self.world_size = comm.nranks + self.nranks_per_node = comm.nranks_per_node + self.ipc_domain_n_ranks = comm.ipc_domain_n_ranks + self.multi_host_mnnvl = self.ipc_domain_n_ranks >= self.world_size and self.world_size > self.nranks_per_node self.symmetric_memory = symmetric_memory self._nvls = mscclpp.is_nvls_supported() @@ -99,13 +111,12 @@ def __init__(self, comm: mscclpp.CommGroup, symmetric_memory: bool = False): self._time_buf = None def _algo(self, collective: str, name: str): - return self._algos.get((collective, name)) + return self._algos[(collective, name)] def _default_ar_config(self): """Fallback allreduce config for barrier / timing sync.""" - pkt = self._algo("allreduce", "default_allreduce_nvls_packet") - if self._nvls and pkt: - return (pkt, 0, 0) + if self._nvls: + return (self._algo("allreduce", "default_allreduce_nvls_packet"), 0, 0) return (self._algo("allreduce", "default_allreduce_packet"), 0, 0) # -- low-level execute -- @@ -165,33 +176,17 @@ def _ensure_tune_bufs(self): return self._tune_buf def _ar_candidates(self, size: int): - out = [] - if size <= 4 << 20: - a = self._algo("allreduce", "default_allreduce_nvls_packet") - if self._nvls and a: - out.append(a) - a = self._algo("allreduce", "default_allreduce_packet") - if a: - out.append(a) - a = self._algo("allreduce", "default_allreduce_allpair_packet") - if a: - out.append(a) - if size >= 512 << 10: - a = self._algo("allreduce", "default_allreduce_nvls_zero_copy") - if self._nvls and self.symmetric_memory and a: - out.append(a) - a = self._algo("allreduce", "default_allreduce_rsag_zero_copy") - if a: - out.append(a) - if torch.version.hip is not None: - a = self._algo("allreduce", "default_allreduce_fullmesh") - if a: - out.append(a) - return out + table = self._AR_CANDIDATES_MNNVL if self.multi_host_mnnvl else self._AR_CANDIDATES_SINGLE + return [ + self._algo("allreduce", name) + for name, lo, hi, pred in table + if size >= lo and (hi is None or size <= hi) and (pred is None or pred(self)) + ] def _ag_candidates(self): - a = self._algo("allgather", "default_allgather_fullmesh2") - return [a] if a else [] + if self.multi_host_mnnvl: + return [] + return [self._algo("allgather", "default_allgather_fullmesh2")] def _run_tune(self, collective, algo, buf, size, nb, nt): """Single tune invocation for either collective.""" @@ -207,7 +202,7 @@ def _run_tune(self, collective, algo, buf, size, nb, nt): stream=torch.cuda.current_stream().cuda_stream, nblocks=nb, nthreads_per_block=nt, - symmetric_memory=True, + symmetric_memory=self.symmetric_memory, ) else: total = size * self.world_size @@ -245,7 +240,7 @@ def _tune_size(self, collective: str, target_size: int): ret = run(algo, nb, nt) torch.cuda.synchronize() self._time_buf[0] = float(ret) - self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=True) + self._exec_ar(self._time_buf[:1], *self._default_ar_config(), sym=self.symmetric_memory) if self._time_buf[0].item() != 0: continue used.add(algo) @@ -274,7 +269,7 @@ def _tune_size(self, collective: str, target_size: int): # Cross-rank timing sync self._time_buf.fill_(elapsed) torch.cuda.current_stream().wait_stream(cs) - self._exec_ar(self._time_buf, *self._default_ar_config(), sym=True) + self._exec_ar(self._time_buf, *self._default_ar_config(), sym=self.symmetric_memory) avg = self._time_buf[self.rank].item() / self.world_size if avg < best_time: @@ -314,6 +309,8 @@ def all_reduce(self, tensor, op=torch.distributed.ReduceOp.SUM, stream=None, acc ) def all_gather(self, output_tensor, input_tensor, stream=None): + if self.multi_host_mnnvl: + raise RuntimeError("all_gather in this example currently supports only single-node runs") sz = _round_pow2(input_tensor.nbytes) if sz not in self._tune_cache["allgather"]: self._tune_size("allgather", sz) @@ -341,7 +338,7 @@ def _bench_sizes(low=5 * 1024, high=80 << 20): def benchmark_allreduce( - comm: CustomizedComm, dtype=torch.float16, accum_dtype=None, n_warmup=10, n_graph_launches=10, n_iter=100 + comm: CustomizedComm, dtype=torch.float16, accum_dtype=None, n_warmup=5, n_graph_launches=5, n_iter=50 ): sizes = _bench_sizes() if comm.rank == 0: @@ -382,7 +379,7 @@ def benchmark_allreduce( print(f"{nelems:<18} {size:<18} {ms*1000:<18.2f} {size/(ms*1e-3)/1e9:<18.2f}") -def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10, n_graph_launches=10, n_iter=100): +def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=5, n_graph_launches=5, n_iter=50): sizes = _bench_sizes() if comm.rank == 0: print(f"\n{'='*60}\nAllgather Benchmark\n{'='*60}") @@ -432,22 +429,11 @@ def benchmark_allgather(comm: CustomizedComm, dtype=torch.float16, n_warmup=10, def init_dist() -> mscclpp.CommGroup: - addr = os.environ.get("MSCCLPP_MASTER_ADDR") - if addr: - rank, world = int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"]) - port = os.environ["MSCCLPP_MASTER_PORT"] - iface = _interfaces_for_ip(addr) - if not iface: - raise ValueError(f"No interface for {addr}") - return mscclpp.CommGroup(interfaceIpPortTrio=f"{iface}:{addr}:{port}", rank=rank, size=world) - import torch.distributed as dist - - dist.init_process_group(backend="gloo") - return mscclpp.CommGroup(torch_group=dist.group.WORLD) + return mscclpp.CommGroup(mpi_comm=MPI.COMM_WORLD) def main(): - local = int(os.environ["LOCAL_RANK"]) + local = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED).Get_rank() torch.cuda.set_device(local) dtype_str = os.environ.get("DTYPE", "float16") @@ -455,18 +441,26 @@ def main(): accum_map = {"float32": mscclpp.DataType.float32, "float16": mscclpp.DataType.float16} accum_str = os.environ.get("ACCUM_DTYPE") accum_dtype = accum_map.get(accum_str) if accum_str else None + symmetric_memory = os.environ.get("SYMMETRIC_MEMORY", "1") == "1" comm_group = init_dist() - cc = CustomizedComm(comm_group) + cc = CustomizedComm(comm_group, symmetric_memory=symmetric_memory) - print(f"rank {local} starting benchmarks with dtype={dtype} accum_dtype={accum_dtype}...") + print( + f"rank {local} starting benchmarks with dtype={dtype} " + f"accum_dtype={accum_dtype} symmetric_memory={symmetric_memory}..." + ) benchmark_allreduce(cc, dtype=dtype, accum_dtype=accum_dtype) cc.barrier() torch.cuda.synchronize() - benchmark_allgather(cc, dtype=dtype) - cc.barrier() - torch.cuda.synchronize() + if cc.multi_host_mnnvl: + if cc.rank == 0: + print("Skipping allgather benchmark on multi-node: this example's allgather path is single-node only.") + else: + benchmark_allgather(cc, dtype=dtype) + cc.barrier() + torch.cuda.synchronize() cc.destroy() print(f"rank {local} completed successfully.") diff --git a/include/mscclpp/core.hpp b/include/mscclpp/core.hpp index 45b56bcc0..4c14f1eec 100644 --- a/include/mscclpp/core.hpp +++ b/include/mscclpp/core.hpp @@ -46,6 +46,10 @@ class Bootstrap { /// @return The total number of ranks per node. virtual int getNranksPerNode() const = 0; + /// Return the number of ranks in this rank's GPU IPC domain. + /// @return The number of ranks in the GPU IPC domain. + virtual int getNranksPerIpcDomain() const; + /// Send arbitrary data to another process. /// /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size, @@ -144,6 +148,9 @@ class TcpBootstrap : public Bootstrap { /// Return the total number of ranks per node. int getNranksPerNode() const override; + /// Return the number of ranks in this rank's GPU IPC domain. + int getNranksPerIpcDomain() const override; + /// Send arbitrary data to another process. /// /// Data sent via `send(senderBuff, size, receiverRank, tag)` can be received via `recv(receiverBuff, size, diff --git a/include/mscclpp/gpu.hpp b/include/mscclpp/gpu.hpp index b8d096e2b..b289bd4d3 100644 --- a/include/mscclpp/gpu.hpp +++ b/include/mscclpp/gpu.hpp @@ -31,6 +31,7 @@ using CUmemorytype = hipMemoryType; constexpr auto cudaErrorPeerAccessAlreadyEnabled = hipErrorPeerAccessAlreadyEnabled; constexpr auto cudaErrorContextIsDestroyed = hipErrorContextIsDestroyed; constexpr auto cudaErrorInvalidDevice = hipErrorInvalidDevice; +constexpr auto cudaErrorInvalidValue = hipErrorInvalidValue; constexpr auto cudaSuccess = hipSuccess; constexpr auto cudaErrorNotSupported = hipErrorNotSupported; constexpr auto cudaStreamNonBlocking = hipStreamNonBlocking; diff --git a/include/mscclpp/gpu_utils.hpp b/include/mscclpp/gpu_utils.hpp index b079e0fd9..ed5f9f63b 100644 --- a/include/mscclpp/gpu_utils.hpp +++ b/include/mscclpp/gpu_utils.hpp @@ -342,7 +342,8 @@ class GpuBuffer { MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId_)); #if (CUDA_NVLS_API_AVAILABLE) if (isNvlsSupported()) { - size_t gran = detail::getMulticastGranularity(nelems * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED); + // TODO: pass granularity from the caller instead of using the minimum granularity. + size_t gran = detail::getMulticastGranularity(nelems * sizeof(T), CU_MULTICAST_GRANULARITY_MINIMUM); bytes_ = (nelems * sizeof(T) + gran - 1) / gran * gran / sizeof(T) * sizeof(T); memory_ = detail::gpuCallocPhysicalShared(nelems, gran); return; diff --git a/include/mscclpp/switch_channel_device.hpp b/include/mscclpp/switch_channel_device.hpp index b52b65723..fcdd7fddb 100644 --- a/include/mscclpp/switch_channel_device.hpp +++ b/include/mscclpp/switch_channel_device.hpp @@ -37,7 +37,10 @@ struct SwitchChannelDeviceHandle { SwitchChannelDeviceHandle::multimemStore(val, reinterpret_cast(mcPtr) + index); } - template + /// Vectorized multimem load+reduce. The optional `AccumT` template parameter selects the + /// accumulator: when `AccumT == __half` and `VectorType` is an FP8 vector type, the + /// `.acc::f16` variant of the instruction is used. For all other types `AccumT` is ignored. + template MSCCLPP_DEVICE_INLINE static VectorType multimemLoadReduce(VectorType* ptr) { VectorType val; if constexpr (std::is_same_v) { @@ -80,32 +83,78 @@ struct SwitchChannelDeviceHandle { : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) : "l"(ptr) : "memory"); - } else if constexpr (std::is_same_v) { - asm("multimem.ld_reduce.relaxed.sys.global.add.e4m3x4 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory"); + } +#if (defined(__CUDA_ARCH_SPECIFIC__) || defined(__CUDA_ARCH_FAMILY_SPECIFIC__)) && (__CUDA_ARCH__ >= 1000) + else if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.e4m3x4 %0, [%1];" + : "=r"(val.words[0]) + : "l"(ptr) + : "memory"); + } else { + asm("multimem.ld_reduce.relaxed.sys.global.add.e4m3x4 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory"); + } } else if constexpr (std::is_same_v) { - asm("multimem.ld_reduce.relaxed.sys.global.add.v2.e4m3x4 {%0,%1}, [%2];" - : "=r"(val.words[0]), "=r"(val.words[1]) - : "l"(ptr) - : "memory"); + if constexpr (std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v2.e4m3x4 {%0,%1}, [%2];" + : "=r"(val.words[0]), "=r"(val.words[1]) + : "l"(ptr) + : "memory"); + } else { + asm("multimem.ld_reduce.relaxed.sys.global.add.v2.e4m3x4 {%0,%1}, [%2];" + : "=r"(val.words[0]), "=r"(val.words[1]) + : "l"(ptr) + : "memory"); + } } else if constexpr (std::is_same_v) { - asm("multimem.ld_reduce.relaxed.sys.global.add.v4.e4m3x4 {%0,%1,%2,%3}, [%4];" - : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) - : "l"(ptr) - : "memory"); + if constexpr (std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v4.e4m3x4 {%0,%1,%2,%3}, [%4];" + : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) + : "l"(ptr) + : "memory"); + } else { + asm("multimem.ld_reduce.relaxed.sys.global.add.v4.e4m3x4 {%0,%1,%2,%3}, [%4];" + : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) + : "l"(ptr) + : "memory"); + } } else if constexpr (std::is_same_v) { - asm("multimem.ld_reduce.relaxed.sys.global.add.e5m2x4 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory"); + if constexpr (std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.e5m2x4 %0, [%1];" + : "=r"(val.words[0]) + : "l"(ptr) + : "memory"); + } else { + asm("multimem.ld_reduce.relaxed.sys.global.add.e5m2x4 %0, [%1];" : "=r"(val.words[0]) : "l"(ptr) : "memory"); + } } else if constexpr (std::is_same_v) { - asm("multimem.ld_reduce.relaxed.sys.global.add.v2.e5m2x4 {%0,%1}, [%2];" - : "=r"(val.words[0]), "=r"(val.words[1]) - : "l"(ptr) - : "memory"); + if constexpr (std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v2.e5m2x4 {%0,%1}, [%2];" + : "=r"(val.words[0]), "=r"(val.words[1]) + : "l"(ptr) + : "memory"); + } else { + asm("multimem.ld_reduce.relaxed.sys.global.add.v2.e5m2x4 {%0,%1}, [%2];" + : "=r"(val.words[0]), "=r"(val.words[1]) + : "l"(ptr) + : "memory"); + } } else if constexpr (std::is_same_v) { - asm("multimem.ld_reduce.relaxed.sys.global.add.v4.e5m2x4 {%0,%1,%2,%3}, [%4];" - : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) - : "l"(ptr) - : "memory"); - } else { - static_assert(dependentFalse, "Not supported type"); + if constexpr (std::is_same_v) { + asm("multimem.ld_reduce.relaxed.sys.global.add.acc::f16.v4.e5m2x4 {%0,%1,%2,%3}, [%4];" + : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) + : "l"(ptr) + : "memory"); + } else { + asm("multimem.ld_reduce.relaxed.sys.global.add.v4.e5m2x4 {%0,%1,%2,%3}, [%4];" + : "=r"(val.words[0]), "=r"(val.words[1]), "=r"(val.words[2]), "=r"(val.words[3]) + : "l"(ptr) + : "memory"); + } + } +#endif + else { + static_assert(dependentFalse, "Unsupported vector type for multimemLoadReduce"); } return val; }; @@ -148,7 +197,9 @@ struct SwitchChannelDeviceHandle { asm volatile("multimem.st.relaxed.sys.global.v4.bf16x2 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.words[0]), "r"(val.words[1]), "r"(val.words[2]), "r"(val.words[3]) : "memory"); - } else if constexpr (std::is_same_v) { + } +#if (defined(__CUDA_ARCH_SPECIFIC__) || defined(__CUDA_ARCH_FAMILY_SPECIFIC__)) && (__CUDA_ARCH__ >= 1000) + else if constexpr (std::is_same_v) { asm volatile("multimem.st.relaxed.sys.global.e4m3x4 [%0], %1;" ::"l"(ptr), "r"(val.words[0]) : "memory"); } else if constexpr (std::is_same_v) { asm volatile("multimem.st.relaxed.sys.global.v2.e4m3x4 [%0], {%1,%2};" ::"l"(ptr), "r"(val.words[0]), @@ -168,8 +219,10 @@ struct SwitchChannelDeviceHandle { asm volatile("multimem.st.relaxed.sys.global.v4.e5m2x4 [%0], {%1,%2,%3,%4};" ::"l"(ptr), "r"(val.words[0]), "r"(val.words[1]), "r"(val.words[2]), "r"(val.words[3]) : "memory"); - } else { - static_assert(dependentFalse, "Not supported type"); + } +#endif + else { + static_assert(dependentFalse, "Unsupported vector type for multimemStore"); } }; @@ -194,7 +247,7 @@ struct SwitchChannelDeviceHandle { } else if constexpr (std::is_same_v && std::is_same_v) { asm volatile("multimem.red.relaxed.sys.global.add.f16x2 [%0], {%1};" ::"l"(ptr), "r"(val.x) : "memory"); } else { - static_assert(dependentFalse, "Not supported type"); + static_assert(dependentFalse, "Unsupported vector type for multimemStoreReduce"); } }; #endif // defined(MSCCLPP_DEVICE_CUDA) diff --git a/python/csrc/core_py.cpp b/python/csrc/core_py.cpp index a94f9863a..7e9af6c1f 100644 --- a/python/csrc/core_py.cpp +++ b/python/csrc/core_py.cpp @@ -56,6 +56,7 @@ void register_core(nb::module_& m) { .def("get_rank", &Bootstrap::getRank) .def("get_n_ranks", &Bootstrap::getNranks) .def("get_n_ranks_per_node", &Bootstrap::getNranksPerNode) + .def("get_n_ranks_per_ipc_domain", &Bootstrap::getNranksPerIpcDomain) .def( "send", [](Bootstrap* self, uintptr_t ptr, size_t size, int peer, int tag) { diff --git a/python/mscclpp/_core/comm.py b/python/mscclpp/_core/comm.py index d42349ddb..875e07f18 100644 --- a/python/mscclpp/_core/comm.py +++ b/python/mscclpp/_core/comm.py @@ -73,6 +73,7 @@ def __init__( self.my_rank = self.bootstrap.get_rank() self.nranks = self.bootstrap.get_n_ranks() self.nranks_per_node = self.bootstrap.get_n_ranks_per_node() + self.ipc_domain_n_ranks = self.bootstrap.get_n_ranks_per_ipc_domain() def barrier(self): self.bootstrap.barrier() diff --git a/src/core/bootstrap/bootstrap.cc b/src/core/bootstrap/bootstrap.cc index b3032e502..ffdd9c1cc 100644 --- a/src/core/bootstrap/bootstrap.cc +++ b/src/core/bootstrap/bootstrap.cc @@ -50,6 +50,8 @@ MSCCLPP_API_CPP void Bootstrap::groupBarrier(const std::vector& ranks) { } } +MSCCLPP_API_CPP int Bootstrap::getNranksPerIpcDomain() const { return getNranksPerNode(); } + MSCCLPP_API_CPP void Bootstrap::send(const std::vector& data, int peer, int tag) { size_t size = data.size(); send((void*)&size, sizeof(size_t), peer, tag); @@ -83,6 +85,7 @@ class TcpBootstrap::Impl { int getRank(); int getNranks(); int getNranksPerNode(); + int getNranksPerIpcDomain(); void allGather(void* allData, int size); void broadcast(void* data, int size, int root); void send(void* data, int size, int peer, int tag); @@ -95,6 +98,7 @@ class TcpBootstrap::Impl { int rank_; int nRanks_; int nRanksPerNode_; + int nRanksPerIpcDomain_; bool netInitialized; std::unique_ptr listenSockRoot_; std::unique_ptr listenSock_; @@ -148,6 +152,7 @@ TcpBootstrap::Impl::Impl(int rank, int nRanks) : rank_(rank), nRanks_(nRanks), nRanksPerNode_(0), + nRanksPerIpcDomain_(0), netInitialized(false), peerCommAddresses_(nRanks, SocketAddress()), barrierArr_(nRanks, 0), @@ -451,6 +456,24 @@ int TcpBootstrap::Impl::getNranksPerNode() { return nRanksPerNode_; } +int TcpBootstrap::Impl::getNranksPerIpcDomain() { + if (nRanksPerIpcDomain_ > 0) return nRanksPerIpcDomain_; + std::vector ipcDomainHashes(nRanks_); + ipcDomainHashes[rank_] = getIpcDomainHash(); + allGather(ipcDomainHashes.data(), sizeof(uint64_t)); + + int nRanksPerIpcDomain = 0; + for (int i = 0; i < nRanks_; ++i) { + if (ipcDomainHashes[i] == ipcDomainHashes[rank_]) { + ++nRanksPerIpcDomain; + } + } + INFO(MSCCLPP_INIT, "rank %d IPC domain fabric hash 0x%016llx nRanksPerIpcDomain %d", rank_, + static_cast(ipcDomainHashes[rank_]), nRanksPerIpcDomain); + nRanksPerIpcDomain_ = nRanksPerIpcDomain; + return nRanksPerIpcDomain_; +} + void TcpBootstrap::Impl::allGather(void* allData, int size) { char* data = static_cast(allData); int rank = rank_; @@ -592,6 +615,8 @@ MSCCLPP_API_CPP int TcpBootstrap::getNranks() const { return pimpl_->getNranks() MSCCLPP_API_CPP int TcpBootstrap::getNranksPerNode() const { return pimpl_->getNranksPerNode(); } +MSCCLPP_API_CPP int TcpBootstrap::getNranksPerIpcDomain() const { return pimpl_->getNranksPerIpcDomain(); } + MSCCLPP_API_CPP void TcpBootstrap::send(void* data, int size, int peer, int tag) { pimpl_->send(data, size, peer, tag); } diff --git a/src/core/executor/executor.cc b/src/core/executor/executor.cc index fcecc4ddf..15c6af4e6 100644 --- a/src/core/executor/executor.cc +++ b/src/core/executor/executor.cc @@ -389,6 +389,7 @@ struct Executor::Impl { nvlsConnection->bindAllocatedMemory((CUdeviceptr)bufferInfo.first, bufferInfo.second); context.nvlsChannels.push_back(switchChannel); } + this->comm->bootstrap()->barrier(); } void setupSemaphores(ExecutionContext& context, const ExecutionPlan& plan) { diff --git a/src/core/include/execution_kernel.hpp b/src/core/include/execution_kernel.hpp index cb808bc8c..e9095ada6 100644 --- a/src/core/include/execution_kernel.hpp +++ b/src/core/include/execution_kernel.hpp @@ -525,7 +525,15 @@ MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(const Operation& op, uint3 if constexpr (std::is_same_v) { assert(false && "MULTI_LOAD_REDUCE_STORE is not supported for uint8_t data type"); return; - } else { + } +#if defined(__FP8_TYPES_EXIST__) && \ + (!(defined(__CUDA_ARCH_SPECIFIC__) || defined(__CUDA_ARCH_FAMILY_SPECIFIC__)) || (__CUDA_ARCH__ < 1000)) + else if constexpr (std::is_same_v || std::is_same_v) { + assert(false && "FP8 MULTI_LOAD_REDUCE_STORE requires sm_100a or newer"); + return; + } +#endif + else { static_assert(sizeof(T) <= 8, "Only support type with size <= 8 bytes"); const uint32_t size = min(op.inputBufferSizes[0] - offset, unitSize); if (size <= 0) { diff --git a/src/core/include/utils_internal.hpp b/src/core/include/utils_internal.hpp index c5c67e26c..c6934194d 100644 --- a/src/core/include/utils_internal.hpp +++ b/src/core/include/utils_internal.hpp @@ -37,6 +37,7 @@ int64_t busIdToInt64(const std::string busId); uint64_t getHash(const char* string, int n); uint64_t getHostHash(); uint64_t getPidHash(); +uint64_t getIpcDomainHash(); void getRandomData(void* buffer, size_t bytes); struct netIf { diff --git a/src/core/utils_internal.cc b/src/core/utils_internal.cc index 8cc554301..adbf8e5b7 100644 --- a/src/core/utils_internal.cc +++ b/src/core/utils_internal.cc @@ -6,6 +6,10 @@ #include #include +#if defined(MSCCLPP_USE_CUDA) +#include +#endif + #include #include #include @@ -175,6 +179,67 @@ uint64_t getPidHash(void) { return *pidHash; } +#if defined(MSCCLPP_USE_CUDA) && defined(NVML_GPU_FABRIC_UUID_LEN) +namespace { + +class NvmlState { + public: + NvmlState() : initialized_(nvmlInit_v2() == NVML_SUCCESS) {} + + ~NvmlState() { + if (initialized_) { + (void)nvmlShutdown(); + } + } + + bool isInitialized() const { return initialized_; } + + private: + bool initialized_ = false; +}; + +uint64_t getFabricHash(const nvmlGpuFabricInfo_t& fabricInfo) { + char hashData[NVML_GPU_FABRIC_UUID_LEN + sizeof(fabricInfo.cliqueId)]; + std::memcpy(hashData, fabricInfo.clusterUuid, NVML_GPU_FABRIC_UUID_LEN); + std::memcpy(hashData + NVML_GPU_FABRIC_UUID_LEN, &fabricInfo.cliqueId, sizeof(fabricInfo.cliqueId)); + return getHash(hashData, sizeof(hashData)); +} + +bool tryGetNvmlIpcDomainHash(uint64_t& ipcDomainHash) { + // Use the current CUDA device; callers must set the rank's device before querying. + int deviceId; + char pciBusId[] = "00000000:00:00.0"; + if (cudaGetDevice(&deviceId) != cudaSuccess || + cudaDeviceGetPCIBusId(pciBusId, sizeof(pciBusId), deviceId) != cudaSuccess) { + return false; + } + + static NvmlState nvml; + nvmlDevice_t nvmlDevice; + nvmlGpuFabricInfo_t fabricInfo = {}; + if (!nvml.isInitialized() || nvmlDeviceGetHandleByPciBusId_v2(pciBusId, &nvmlDevice) != NVML_SUCCESS || + nvmlDeviceGetGpuFabricInfo(nvmlDevice, &fabricInfo) != NVML_SUCCESS || + fabricInfo.state != NVML_GPU_FABRIC_STATE_COMPLETED || fabricInfo.status != NVML_SUCCESS) { + return false; + } + + ipcDomainHash = getFabricHash(fabricInfo); + return true; +} + +} // namespace +#endif + +uint64_t getIpcDomainHash(void) { +#if defined(MSCCLPP_USE_CUDA) && defined(NVML_GPU_FABRIC_UUID_LEN) + uint64_t ipcDomainHash; + if (tryGetNvmlIpcDomainHash(ipcDomainHash)) { + return ipcDomainHash; + } +#endif + return getHostHash(); +} + int parseStringList(const char* string, netIf* ifList, int maxList) { if (!string) return 0; diff --git a/src/ext/collectives/allgather/allgather_fullmesh.cu b/src/ext/collectives/allgather/allgather_fullmesh.cu index d1b4e7315..49688f473 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh.cu @@ -16,8 +16,8 @@ constexpr int kMaxThreadsPerBlock = 1024; template __global__ void __launch_bounds__(1024, 1) allgatherFullmesh(void* buff, void* scratch, void* resultBuff, DeviceHandle* memoryChannels, - int rank, int nRanksPerNode, [[maybe_unused]] int worldSize, size_t nelems) { - const int nPeer = nRanksPerNode - 1; + int rank, int nRanksPerIpcDomain, [[maybe_unused]] int worldSize, size_t nelems) { + const int nPeer = nRanksPerIpcDomain - 1; const size_t chanOffset = nPeer * blockIdx.x; // assume (nelems * sizeof(T)) is divisible by 16 const size_t nInt4 = nelems * sizeof(int) / sizeof(int4); @@ -33,10 +33,11 @@ __global__ void __launch_bounds__(1024, 1) const size_t restNInt4 = nInt4 % nInt4PerChunk; const size_t scratchChunkRankOffset = nInt4PerChunk * rank; - __shared__ DeviceHandle channels[MAX_NRANKS_PER_NODE - 1]; + __shared__ DeviceHandle channels[MAX_IPC_DOMAIN_NRANKS - 1]; const int lid = threadIdx.x % WARP_SIZE; - if (lid < nPeer) { - channels[lid] = memoryChans[lid]; + // Peer count may exceed WARP_SIZE on MNNVL. + for (int i = lid; i < nPeer; i += WARP_SIZE) { + channels[i] = memoryChans[i]; } __syncwarp(); const int tid = threadIdx.x + blockIdx.x * blockDim.x; @@ -138,11 +139,11 @@ CommResult AllgatherFullmesh::allgatherKernelFunc(const std::shared_ptr ct if ((char*)input == (char*)output + rank * inputSize) { allgatherFullmesh<<>>( (void*)input, this->scratchBuffer_, (void*)output, ctx->memoryChannelDeviceHandles.get(), rank, - ctx->nRanksPerNode, ctx->workSize, nElem); + ctx->nRanksPerIpcDomain, ctx->worldSize, nElem); } else { allgatherFullmesh<<>>( (void*)input, this->scratchBuffer_, (void*)output, ctx->memoryChannelDeviceHandles.get(), rank, - ctx->nRanksPerNode, ctx->workSize, nElem); + ctx->nRanksPerIpcDomain, ctx->worldSize, nElem); } cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { @@ -156,8 +157,8 @@ std::shared_ptr AllgatherFullmesh::initAllgatherContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // setup semaphores ctx->memorySemaphores = setupMemorySemaphores(comm, this->conns_, kMaxBlocks); diff --git a/src/ext/collectives/allgather/allgather_fullmesh_2.cu b/src/ext/collectives/allgather/allgather_fullmesh_2.cu index 895818228..8436532e1 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh_2.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh_2.cu @@ -12,8 +12,8 @@ __device__ DeviceSyncer deviceSyncer; template __global__ void __launch_bounds__(1024, 1) allgatherFullmesh2(void* sendbuff, mscclpp::DeviceHandle* memoryChannels, - size_t channelOutOffset, size_t rank, [[maybe_unused]] size_t worldSize, size_t nRanksPerNode, - size_t nelemsPerGPU) { + size_t channelOutOffset, size_t rank, [[maybe_unused]] size_t worldSize, + size_t nRanksPerIpcDomain, size_t nelemsPerGPU) { const size_t tid = threadIdx.x + blockIdx.x * blockDim.x; const size_t lid = tid % WARP_SIZE; const size_t wid = tid / WARP_SIZE; @@ -24,7 +24,7 @@ __global__ void __launch_bounds__(1024, 1) return; } const size_t nWarp = nThread / WARP_SIZE; - const size_t nPeer = nRanksPerNode - 1; + const size_t nPeer = nRanksPerIpcDomain - 1; const size_t chanOffset = nPeer * blockIdx.x; auto memChans = memoryChannels + chanOffset; @@ -161,12 +161,12 @@ CommResult AllgatherFullmesh2::allgatherKernelFunc(const std::shared_ptr c size_t channelOutOffset = *static_cast(ctx->extras["channel_out_offset"].get()); if ((char*)input == (char*)output + rank * inputSize) { allgatherFullmesh2<<>>( - (void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->workSize, - ctx->nRanksPerNode, nElem); + (void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->worldSize, + ctx->nRanksPerIpcDomain, nElem); } else { allgatherFullmesh2<<>>( - (void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->workSize, - ctx->nRanksPerNode, nElem); + (void*)input, ctx->memoryChannelDeviceHandles.get(), channelOutOffset, ctx->rank, ctx->worldSize, + ctx->nRanksPerIpcDomain, nElem); } cudaError_t err = cudaGetLastError(); if (err != cudaSuccess) { @@ -180,8 +180,8 @@ std::shared_ptr AllgatherFullmesh2::initAllgatherContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // setup semaphores ctx->memorySemaphores = this->memorySemaphores_; diff --git a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu index 49058f59a..0e34be718 100644 --- a/src/ext/collectives/allreduce/allreduce_allpair_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_allpair_packet.cu @@ -14,14 +14,11 @@ namespace collective { template __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels, - size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerNode, + size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerIpcDomain, int worldSize, size_t nelems, uint32_t numScratchBuff, void* flags, uint32_t flagSize) { - // This version of allreduce only works for single nodes - if (worldSize != nRanksPerNode) return; - if (sizeof(T) == 2 || sizeof(T) == 1) nelems = (nelems * sizeof(T) + sizeof(T)) / sizeof(int); - const int nPeers = nRanksPerNode - 1; + const int nPeers = nRanksPerIpcDomain - 1; uint32_t flag = ((uint32_t*)flags)[blockIdx.x]; size_t scratchBaseOffset = (flag % numScratchBuff) ? (scratchBufferSize / numScratchBuff) : 0; @@ -71,25 +68,25 @@ __global__ void allreduceAllPairs(T* buff, T* scratch, T* resultBuff, DeviceHand } } -inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize, int worldSize) { - if (inputSize < worldSize * sizeof(int)) { - return {worldSize - 1, (worldSize - 1) * WARP_SIZE}; +inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize, int nRanksPerIpcDomain) { + if (inputSize < nRanksPerIpcDomain * sizeof(int)) { + return {nRanksPerIpcDomain - 1, (nRanksPerIpcDomain - 1) * WARP_SIZE}; } - return {(worldSize - 1) * 4, 512}; + return {(nRanksPerIpcDomain - 1) * 4, 512}; } template struct AllpairAdapter { static cudaError_t call(const void* buff, void* scratch, void* resultBuff, void* memoryChannels, void*, DeviceHandle*, DeviceHandle*, size_t channelInOffset, size_t, - size_t scratchBufferSize, int rank, int nRanksPerNode, int worldSize, size_t inputSize, + size_t scratchBufferSize, int rank, int nRanksPerIpcDomain, int worldSize, size_t inputSize, cudaStream_t stream, void* flags, uint32_t flagSize, uint32_t numScratchBuff, int nBlocks = 0, int nThreadsPerBlock = 0) { using ChannelType = DeviceHandle; const size_t nelems = inputSize / sizeof(T); allreduceAllPairs<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank, - nRanksPerNode, worldSize, nelems, numScratchBuff, flags, flagSize); + nRanksPerIpcDomain, worldSize, nelems, numScratchBuff, flags, flagSize); return cudaGetLastError(); } }; @@ -108,16 +105,22 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptr&, DataType accumDtype) { auto algoCtx = std::static_pointer_cast(ctx); + if (algoCtx->worldSize != algoCtx->nRanksPerIpcDomain) { + WARN(ALGO, + "AllreduceAllpairPacket requires worldSize to match nRanksPerIpcDomain, got worldSize=", algoCtx->worldSize, + ", nRanksPerIpcDomain=", algoCtx->nRanksPerIpcDomain); + return CommResult::CommInvalidArgument; + } std::pair blockAndThreadNum{nBlocks, nThreadsPerBlock}; if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { - blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, algoCtx->workSize); + blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, algoCtx->nRanksPerIpcDomain); } if (blockAndThreadNum.first > maxBlockNum_) { WARN(ALGO, "Requested block number ", blockAndThreadNum.first, " exceeds the maximum supported block number ", maxBlockNum_, "."); return CommResult::CommInvalidArgument; } - const int nPeers = algoCtx->nRanksPerNode - 1; + const int nPeers = algoCtx->nRanksPerIpcDomain - 1; // The kernel maps peer sends by warpId, so every peer needs a full warp. if (blockAndThreadNum.second % WARP_SIZE != 0 || blockAndThreadNum.second / WARP_SIZE < nPeers) { WARN(ALGO, @@ -138,8 +141,8 @@ CommResult AllreduceAllpairPacket::allreduceKernelFunc(const std::shared_ptrscratchBuffer_, output, algoCtx->memoryChannelDeviceHandles.get(), nullptr, nullptr, - nullptr, channelInOffset, 0, this->scratchBufferSize_, algoCtx->rank, algoCtx->nRanksPerNode, - algoCtx->workSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, + nullptr, channelInOffset, 0, this->scratchBufferSize_, algoCtx->rank, algoCtx->nRanksPerIpcDomain, + algoCtx->worldSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, this->nSegmentsForScratchBuffer_, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { WARN(ALGO, "AllreducePacket failed with error: ", cudaGetErrorString(error)); @@ -153,8 +156,8 @@ std::shared_ptr AllreduceAllpairPacket::initAllreduceContext(std::shared_p auto ctx = std::make_shared(); const int nChannelsPerConnection = maxBlockNum_; ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->memorySemaphores_; ctx->registeredMemories = this->registeredMemories_; ctx->registeredMemories.pop_back(); // remove the local memory from previous context diff --git a/src/ext/collectives/allreduce/allreduce_fullmesh.cu b/src/ext/collectives/allreduce/allreduce_fullmesh.cu index 24d2a31c2..eb8726245 100644 --- a/src/ext/collectives/allreduce/allreduce_fullmesh.cu +++ b/src/ext/collectives/allreduce/allreduce_fullmesh.cu @@ -9,12 +9,23 @@ namespace mscclpp { namespace collective { +namespace { +// Per-context cache of input-side MemoryChannels keyed by input pointer. +// Lifetime is tied to AlgorithmCtx, so entries are released when the ctx is +// evicted from the framework's context cache (avoids unbounded growth across +// allreduce calls that pass different input buffers). +using InputChannelsCache = + std::unordered_map, std::shared_ptr>>>; +constexpr const char* kInputChannelsExtraKey = "inputChannels"; +} // namespace + template __global__ void __launch_bounds__(512, 1) allreduceFullmesh(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels, DeviceHandle* memoryOutChannels, size_t channelOutDataOffset, int rank, - int nRanksPerNode, int worldSize, size_t nelems) { - const int nPeer = nRanksPerNode - 1; + int nRanksPerIpcDomain, int worldSize, size_t nelems) { + const int nPeer = nRanksPerIpcDomain - 1; const size_t chanOffset = nPeer * blockIdx.x; // assume (nelems * sizeof(T)) is divisible by (16 * worldSize) const size_t nInt4 = nelems * sizeof(T) / sizeof(int4); @@ -49,12 +60,13 @@ __global__ void __launch_bounds__(512, 1) const size_t blockOffset = nInt4PerChunk * blockIdx.x; const size_t scratchChunkRankOffset = chunkSizePerRank * rank; - __shared__ DeviceHandle channels[MAX_NRANKS_PER_NODE - 1]; - __shared__ DeviceHandle outChannels[MAX_NRANKS_PER_NODE - 1]; + __shared__ DeviceHandle channels[MAX_IPC_DOMAIN_NRANKS - 1]; + __shared__ DeviceHandle outChannels[MAX_IPC_DOMAIN_NRANKS - 1]; const int lid = threadIdx.x % WARP_SIZE; - if (lid < nPeer) { - channels[lid] = memoryChans[lid]; - outChannels[lid] = memoryOutChans[lid]; + // Peer count may exceed WARP_SIZE on MNNVL. + for (int i = lid; i < nPeer; i += WARP_SIZE) { + channels[i] = memoryChans[i]; + outChannels[i] = memoryOutChans[i]; } __syncwarp(); @@ -156,7 +168,7 @@ template struct AllreduceAllconnectAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* memoryOutChannels, DeviceHandle*, DeviceHandle*, size_t, - size_t channelOutDataOffset, size_t, int rank, int nRanksPerNode, int worldSize, + size_t channelOutDataOffset, size_t, int rank, int nRanksPerIpcDomain, int worldSize, size_t inputSize, cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) { using ChannelType = DeviceHandle; @@ -165,7 +177,7 @@ struct AllreduceAllconnectAdapter { if (nThreadsPerBlock == 0) nThreadsPerBlock = 512; allreduceFullmesh<<>>( (T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, (ChannelType*)memoryOutChannels, - channelOutDataOffset, rank, nRanksPerNode, worldSize, nelems); + channelOutDataOffset, rank, nRanksPerIpcDomain, worldSize, nelems); return cudaGetLastError(); } }; @@ -194,17 +206,17 @@ CommResult AllreduceFullmesh::allreduceKernelFunc( MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)output)); channelOutOffset = (char*)output - (char*)recvBasePtr; } - std::shared_ptr> inputChannelHandles; - if (this->memoryChannelsMap_.find(input) != this->memoryChannelsMap_.end()) { - inputChannelHandles = this->memoryChannelsMap_[input].second; - } else { + auto& inputChannelsCache = *static_cast(ctx->extras.at(kInputChannelsExtraKey).get()); + auto it = inputChannelsCache.find(input); + if (it == inputChannelsCache.end()) { RegisteredMemory localMemory = comm_->registerMemory(const_cast(input), inputSize, Transport::CudaIpc); std::vector channels = setupMemoryChannels(this->conns_, this->inputScratchSemaphores_, this->remoteScratchMemories_, localMemory, nChannelsPerConnection_); - this->memoryChannelsMap_[input] = std::make_pair(channels, setupMemoryChannelDeviceHandles(channels)); + auto handles = setupMemoryChannelDeviceHandles(channels); + it = inputChannelsCache.emplace(input, std::make_pair(std::move(channels), std::move(handles))).first; } - inputChannelHandles = this->memoryChannelsMap_[input].second; + std::shared_ptr> inputChannelHandles = it->second.second; AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { @@ -222,7 +234,7 @@ CommResult AllreduceFullmesh::allreduceKernelFunc( } cudaError_t error = allreduce(input, this->scratchBuffer_, output, inputChannelHandles.get(), ctx->memoryChannelDeviceHandles.get(), - nullptr, nullptr, 0, channelOutOffset, 0, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, + nullptr, nullptr, 0, channelOutOffset, 0, ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second); if (error != cudaSuccess) { WARN("AllreduceAllconnect failed with error: %s", cudaGetErrorString(error)); @@ -248,8 +260,8 @@ std::shared_ptr AllreduceFullmesh::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // setup semaphores ctx->memorySemaphores = this->outputSemaphores_; @@ -266,6 +278,7 @@ std::shared_ptr AllreduceFullmesh::initAllreduceContext(std::shared_ptrmemoryChannels = setupMemoryChannels(this->conns_, ctx->memorySemaphores, ctx->registeredMemories, localMemory, nChannelsPerConnection_); ctx->memoryChannelDeviceHandles = setupMemoryChannelDeviceHandles(ctx->memoryChannels); + ctx->extras.insert({kInputChannelsExtraKey, std::make_shared()}); return ctx; } diff --git a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu index 2d71cd638..1edbc0118 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_block_pipeline.cu @@ -20,15 +20,15 @@ __global__ void __launch_bounds__(1024, 1) [[maybe_unused]] DeviceHandle* memoryChannels, [[maybe_unused]] DeviceHandle* switchChannels, [[maybe_unused]] size_t size, [[maybe_unused]] size_t scratchBufferSize, - [[maybe_unused]] int rank, [[maybe_unused]] int nRanksPerNode) { + [[maybe_unused]] int rank, [[maybe_unused]] int nRanksPerIpcDomain) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 constexpr int alignment = 16; - int nPeers = nRanksPerNode - 1; - int nBlocksForCopy = nRanksPerNode * 2; - int nBlocksForReduce = nRanksPerNode; + int nPeers = nRanksPerIpcDomain - 1; + int nBlocksForCopy = nRanksPerIpcDomain * 2; + int nBlocksForReduce = nRanksPerIpcDomain; int copyReduceRatio = nBlocksForCopy / nBlocksForReduce; - size_t scratchSizePerRank = scratchBufferSize / nRanksPerNode; - size_t sizePerRank = size / nRanksPerNode; + size_t scratchSizePerRank = scratchBufferSize / nRanksPerIpcDomain; + size_t sizePerRank = size / nRanksPerIpcDomain; assert(sizePerRank % alignment == 0); uint32_t sizePerBlock = ((sizePerRank + (nBlocksForCopy - 1)) / nBlocksForCopy + alignment - 1) / alignment * alignment; @@ -68,7 +68,7 @@ __global__ void __launch_bounds__(1024, 1) deviceSemaphore[bid + 2 * nBlocksForCopy].acquire(); } __syncthreads(); - for (int i = 0; i < nRanksPerNode; i++) { + for (int i = 0; i < nRanksPerIpcDomain; i++) { size_t blockOffset = it * unitSize + bid * sizePerBlock + i * sizePerRank; uint32_t scratchOffset = scratchIt * unitSize + bid * scratchSizePerBlock + i * scratchSizePerRank; char* srcData = (char*)src + blockOffset; @@ -125,7 +125,7 @@ __global__ void __launch_bounds__(1024, 1) channels->wait(); } __syncthreads(); - for (int i = 0; i < nRanksPerNode; i++) { + for (int i = 0; i < nRanksPerIpcDomain; i++) { size_t blockOffset = it * unitSize + (bid - nBlocksForCopy - nBlocksForReduce) * sizePerBlock + i * sizePerRank; uint32_t scratchOffset = scratchIt * unitSize + (bid - nBlocksForCopy - nBlocksForReduce) * scratchSizePerBlock + @@ -150,7 +150,7 @@ template struct NvlsBlockPipelineAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*, DeviceHandle* nvlsChannels, DeviceHandle*, size_t, size_t, - size_t scratchBufferSize, int rank, int nRanksPerNode, int, size_t inputSize, + size_t scratchBufferSize, int rank, int nRanksPerIpcDomain, int, size_t inputSize, cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) { // uint8_t is not supported for NVLS (no hardware support for byte-level reduction) if constexpr (std::is_same_v) { @@ -166,9 +166,9 @@ struct NvlsBlockPipelineAdapter { #endif { using ChannelType = DeviceHandle; - allreduceNvlsBlockPipeline - <<>>(input, scratch, output, (ChannelType*)memoryChannels, - nvlsChannels, inputSize, scratchBufferSize, rank, nRanksPerNode); + allreduceNvlsBlockPipeline<<>>( + input, scratch, output, (ChannelType*)memoryChannels, nvlsChannels, inputSize, scratchBufferSize, rank, + nRanksPerIpcDomain); return cudaGetLastError(); } } @@ -176,7 +176,9 @@ struct NvlsBlockPipelineAdapter { void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr comm) { nSwitchChannels_ = 8; - int nBaseChannels = 64; + int nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); + // Per-peer channel allocation must hold up to 4 * nRanksPerIpcDomain entries (see kernel). + int nBaseChannels = std::max(64, 4 * nRanksPerIpcDomain); this->conns_ = setupConnections(comm); // setup semaphores std::vector> memorySemaphores = @@ -187,11 +189,10 @@ void AllreduceNvlsBlockPipeline::initialize(std::shared_ptr comm) this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels_); } -CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc(const std::shared_ptr ctx_void, const void* input, - void* output, size_t inputSize, DataType dtype, ReduceOp op, - cudaStream_t stream, int nBlocks, int nThreadsPerBlock, - const std::unordered_map& extras, - DataType accumDtype) { +CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc( + const std::shared_ptr ctx_void, const void* input, void* output, size_t inputSize, DataType dtype, + ReduceOp op, cudaStream_t stream, int nBlocks, int nThreadsPerBlock, + [[maybe_unused]] const std::unordered_map& extras, DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { @@ -200,11 +201,11 @@ CommResult AllreduceNvlsBlockPipeline::allreduceKernelFunc(const std::shared_ptr } std::pair blockAndThreadNum = {nBlocks, nThreadsPerBlock}; if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { - blockAndThreadNum = {ctx->nRanksPerNode * 5, 1024}; + blockAndThreadNum = {ctx->nRanksPerIpcDomain * 5, 1024}; } cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->memoryChannelsDeviceHandle_.get(), nullptr, ctx->switchChannelDeviceHandles.get(), nullptr, 0, 0, this->scratchBufferSize_, - ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream, nullptr, 0, 0, + ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { WARN("AllreduceNvlsBlockPipeline failed with error: %s", cudaGetErrorString(error)); @@ -221,12 +222,12 @@ std::shared_ptr AllreduceNvlsBlockPipeline::initAllreduceContext(std::shar void*, size_t, DataType) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // setup channels ctx->switchChannels = - setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_); + setupNvlsChannels(comm, this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_); ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels); return ctx; } diff --git a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu index a616485e1..98d9e1a39 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_packet.cu @@ -82,7 +82,7 @@ void AllreduceNvlsPacket::initialize(std::shared_ptr comm) { int nSwitchChannels = 1; this->nvlsConnections_ = setupNvlsConnections(comm, nvlsBufferSize_, nSwitchChannels); this->switchChannels_ = - setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels); + setupNvlsChannels(comm, this->nvlsConnections_, this->scratchBuffer_, this->scratchBufferSize_, nSwitchChannels); } AlgorithmCtxKey AllreduceNvlsPacket::generateAllreduceContextKey(const void*, void*, size_t, DataType, bool) { @@ -93,8 +93,8 @@ std::shared_ptr AllreduceNvlsPacket::initAllreduceContext(std::shared_ptr< size_t, DataType) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // setup channels ctx->switchChannels = this->switchChannels_; @@ -123,7 +123,7 @@ CommResult AllreduceNvlsPacket::allreduceKernelFunc(const std::shared_ptr } cudaError_t error = allreduce(input, this->scratchBuffer_, output, nullptr, nullptr, ctx->switchChannelDeviceHandles.get(), nullptr, - 0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream, + 0, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, 0, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { WARN(ALGO, "AllreduceNvlsPacket failed with error: ", cudaGetErrorString(error)); @@ -154,4 +154,4 @@ std::shared_ptr AllreduceNvlsPacket::build() { }); } } // namespace collective -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu index 3bb054dae..d4492ed5d 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_warp_pipeline.cu @@ -18,15 +18,15 @@ __global__ void __launch_bounds__(1024, 1) [[maybe_unused]] DeviceHandle* memoryChannels, [[maybe_unused]] DeviceHandle* multicast, [[maybe_unused]] size_t size, [[maybe_unused]] size_t scratchBufferSize, [[maybe_unused]] int rank, - [[maybe_unused]] int nRanksPerNode) { + [[maybe_unused]] int nRanksPerIpcDomain) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 constexpr int alignment = 16; - int nPeers = nRanksPerNode - 1; + int nPeers = nRanksPerIpcDomain - 1; int nBlocks = gridDim.x; int nBlocksPerNvlsConn = nBlocks / NUM_NVLS_CONNECTION; int bid = blockIdx.x; - size_t sizePerRank = size / nRanksPerNode; - size_t scratchSizePerRank = scratchBufferSize / nRanksPerNode; + size_t sizePerRank = size / nRanksPerIpcDomain; + size_t scratchSizePerRank = scratchBufferSize / nRanksPerIpcDomain; const size_t maxSizePerBlock = ((sizePerRank + nBlocks - 1) / nBlocks + alignment - 1) / alignment * alignment; size_t start = bid * maxSizePerBlock; size_t end = min(start + maxSizePerBlock, sizePerRank); @@ -53,19 +53,20 @@ __global__ void __launch_bounds__(1024, 1) lastIterSize = sizePerBlock % copyPerIter; } - const size_t chanOffset = (nRanksPerNode - 1) * blockIdx.x * 2; + const size_t chanOffset = (nRanksPerIpcDomain - 1) * blockIdx.x * 2; auto memoryChans = memoryChannels + chanOffset; - __shared__ DeviceHandle channels[(MAX_NRANKS_PER_NODE - 1) * 2]; + __shared__ DeviceHandle channels[(MAX_IPC_DOMAIN_NRANKS - 1) * 2]; const int lid = threadIdx.x % WARP_SIZE; - if (lid < nPeers * 2) { - channels[lid] = memoryChans[lid]; + // Peer count may exceed WARP_SIZE on MNNVL. + for (int i = lid; i < nPeers * 2; i += WARP_SIZE) { + channels[i] = memoryChans[i]; } __syncwarp(); for (int it = 0; it < nIter; it++) { const size_t iterSize = (it == nIter - 1) ? lastIterSize : copyPerIter; if (warpId < endCopyWid) { int tidInCopy = threadIdx.x; - for (int i = 0; i < nRanksPerNode; i++) { + for (int i = 0; i < nRanksPerIpcDomain; i++) { size_t offset = i * sizePerRank + maxSizePerBlock * bid + it * copyPerIter; size_t offsetScratch = i * scratchSizePerRank + scratchSizePerBlock * bid + (it * copyPerIter) % scratchSizePerBlock; @@ -96,7 +97,7 @@ __global__ void __launch_bounds__(1024, 1) channels[tidInRecvCopy + nPeers].wait(); } asm volatile("bar.sync %0, %1;" ::"r"(3), "r"((NRECV_COPY_WARPS)*WARP_SIZE) : "memory"); - for (int i = 0; i < nRanksPerNode; i++) { + for (int i = 0; i < nRanksPerIpcDomain; i++) { size_t offset = i * sizePerRank + maxSizePerBlock * bid + it * copyPerIter; size_t offsetScratch = i * scratchSizePerRank + scratchSizePerBlock * bid + (it * copyPerIter) % scratchSizePerBlock; @@ -113,7 +114,7 @@ template struct NvlsWarpPipelineAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void*, DeviceHandle* nvlsChannels, DeviceHandle*, size_t, size_t, - size_t scratchBufferSize, int rank, int nRanksPerNode, int, size_t inputSize, + size_t scratchBufferSize, int rank, int nRanksPerIpcDomain, int, size_t inputSize, cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) { // uint8_t is not supported for NVLS (no hardware support for byte-level reduction) if constexpr (std::is_same_v) { @@ -129,17 +130,19 @@ struct NvlsWarpPipelineAdapter { #endif { using ChannelType = DeviceHandle; - allreduceNvlsWarpPipeline - <<>>(input, scratch, output, (ChannelType*)memoryChannels, - nvlsChannels, inputSize, scratchBufferSize, rank, nRanksPerNode); + allreduceNvlsWarpPipeline<<>>( + input, scratch, output, (ChannelType*)memoryChannels, nvlsChannels, inputSize, scratchBufferSize, rank, + nRanksPerIpcDomain); return cudaGetLastError(); } } }; void AllreduceNvlsWarpPipeline::initialize(std::shared_ptr comm) { - nSwitchChannels_ = 8; - int nBaseChannels = 64; + nSwitchChannels_ = NUM_NVLS_CONNECTION; + int nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); + // Per-peer channel allocation must hold 2 * nBlocks entries; default nBlocks = 4 * nRanksPerIpcDomain. + int nBaseChannels = std::max(64, 8 * nRanksPerIpcDomain); this->conns_ = setupConnections(comm); // setup semaphores std::vector> memorySemaphores = @@ -162,11 +165,11 @@ CommResult AllreduceNvlsWarpPipeline::allreduceKernelFunc( } std::pair blockAndThreadNum = {nBlocks, nThreadsPerBlock}; if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { - blockAndThreadNum = {ctx->nRanksPerNode * 4, 1024}; + blockAndThreadNum = {ctx->nRanksPerIpcDomain * 4, 1024}; } cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->memoryChannelsDeviceHandle_.get(), nullptr, ctx->switchChannelDeviceHandles.get(), nullptr, 0, 0, this->scratchBufferSize_, - ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, stream, nullptr, 0, 0, + ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { WARN("AllreduceNvlsWarpPipeline failed with error: %s", cudaGetErrorString(error)); @@ -183,12 +186,12 @@ std::shared_ptr AllreduceNvlsWarpPipeline::initAllreduceContext(std::share void*, size_t, DataType) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); // setup channels ctx->switchChannels = - setupNvlsChannels(this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_); + setupNvlsChannels(comm, this->nvlsConnections_, this->scratchBuffer_, scratchBufferSize_, nSwitchChannels_); ctx->switchChannelDeviceHandles = setupNvlsChannelDeviceHandles(ctx->switchChannels); return ctx; } diff --git a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu index e7f2028fa..f76dd079b 100644 --- a/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_nvls_zero_copy.cu @@ -13,18 +13,18 @@ namespace collective { constexpr int MAX_NBLOCKS = 32; -template +template __global__ void __launch_bounds__(1024, 1) allreduceNvls([[maybe_unused]] mscclpp::DeviceHandle* memoryChannels, [[maybe_unused]] mscclpp::DeviceHandle* multicast, [[maybe_unused]] mscclpp::DeviceHandle* multicastOut, [[maybe_unused]] size_t channelInOffset, [[maybe_unused]] size_t channelOutOffset, - [[maybe_unused]] size_t size, [[maybe_unused]] int rank, [[maybe_unused]] int nRanksPerNode) { + [[maybe_unused]] size_t size, [[maybe_unused]] int rank, [[maybe_unused]] int nRanksPerIpcDomain) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 - int nPeers = nRanksPerNode - 1; + int nPeers = nRanksPerIpcDomain - 1; int nBlocks = gridDim.x; int bid = blockIdx.x; - size_t sizePerRank = size / nRanksPerNode; + size_t sizePerRank = size / nRanksPerIpcDomain; const size_t minAlign = 16; // Align sizePerBlock to 16 bytes to ensure aligned vector access in handleMultiLoadReduceStore size_t sizePerBlock = (sizePerRank + nBlocks - 1) / nBlocks; @@ -40,12 +40,13 @@ __global__ void __launch_bounds__(1024, 1) mscclpp::DeviceHandle* multicastPtr = multicast + bid; mscclpp::DeviceHandle* multicastOutPtr = multicastOut + bid; - const size_t chanOffset = (nRanksPerNode - 1) * blockIdx.x; + const size_t chanOffset = (nRanksPerIpcDomain - 1) * blockIdx.x; auto memoryChans = memoryChannels + chanOffset; - __shared__ mscclpp::DeviceHandle channels[MAX_NRANKS_PER_NODE - 1]; + __shared__ mscclpp::DeviceHandle channels[MAX_IPC_DOMAIN_NRANKS - 1]; const int lid = threadIdx.x % WARP_SIZE; - if (lid < nRanksPerNode - 1) { - channels[lid] = memoryChans[lid]; + // Peer count may exceed WARP_SIZE on MNNVL. + for (int i = lid; i < nRanksPerIpcDomain - 1; i += WARP_SIZE) { + channels[i] = memoryChans[i]; } __syncwarp(); if (threadIdx.x < nPeers) { @@ -56,8 +57,8 @@ __global__ void __launch_bounds__(1024, 1) T* src = (T*)multicastPtr->mcPtr; T* dst = (T*)multicastOutPtr->mcPtr; if (curBlockSize > 0) { - handleMultiLoadReduceStore(src, dst, blockOffset + channelInOffset, blockOffset + channelOutOffset, curBlockSize, - threadIdx.x, blockDim.x); + handleMultiLoadReduceStore(src, dst, blockOffset + channelInOffset, blockOffset + channelOutOffset, + curBlockSize, threadIdx.x, blockDim.x); } __syncthreads(); if (threadIdx.x < nPeers) { @@ -72,7 +73,7 @@ struct NvlsAdapter { static cudaError_t call(const void*, void*, void*, void* memoryChannels, void*, mscclpp::DeviceHandle* nvlsChannels, mscclpp::DeviceHandle* nvlsOutChannels, size_t channelInOffset, - size_t channelOutOffset, size_t, int rank, int nRanksPerNode, int, size_t inputSize, + size_t channelOutOffset, size_t, int rank, int nRanksPerIpcDomain, int, size_t inputSize, cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) { // uint8_t is not supported for NVLS (no hardware support for byte-level reduction) if constexpr (std::is_same_v) { @@ -80,17 +81,11 @@ struct NvlsAdapter { } else if constexpr (std::is_same_v) { // fp8_e4m3b15 is a software-only type with no hardware NVLS support. return cudaErrorNotSupported; - } else -#if (!defined(__CUDA_ARCH_SPECIFIC__) && !defined(__CUDA_ARCH_FAMILY_SPECIFIC__)) || (__CUDA_ARCH__ < 1000) - if constexpr (std::is_same_v || std::is_same_v) { - return cudaErrorNotSupported; - } else -#endif - { + } else { using ChannelType = DeviceHandle; - allreduceNvls<<>>((ChannelType*)memoryChannels, nvlsChannels, - nvlsOutChannels, channelInOffset, channelOutOffset, - inputSize, rank, nRanksPerNode); + allreduceNvls<<>>( + (ChannelType*)memoryChannels, nvlsChannels, nvlsOutChannels, channelInOffset, channelOutOffset, inputSize, + rank, nRanksPerIpcDomain); return cudaGetLastError(); } } @@ -124,6 +119,13 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr ctx_vo return CommResult::CommInvalidArgument; } auto ctx = std::static_pointer_cast(ctx_void); +#if defined(__FP8_TYPES_EXIST__) + bool isFp8Dtype = dtype == mscclpp::DataType::FLOAT8_E4M3FN || dtype == mscclpp::DataType::FLOAT8_E5M2; + if (isFp8Dtype && computeCapabilityMajor_ < 10) { + WARN("FP8 NVLS allreduce requires compute capability 10.x or newer."); + return CommResult::CommInvalidArgument; + } +#endif AllreduceFunc allreduce = dispatch(op, dtype, accumDtype); if (!allreduce) { WARN("Unsupported operation or data type for allreduce, dtype=%d", static_cast(dtype)); @@ -142,23 +144,27 @@ CommResult AllreduceNvls::allreduceKernelFunc(const std::shared_ptr ctx_vo } std::pair numBlocksAndThreads = {nBlocks, nThreadsPerBlock}; if (numBlocksAndThreads.first == 0 || numBlocksAndThreads.second == 0) { - numBlocksAndThreads = {::min(ctx->nRanksPerNode, MAX_NBLOCKS), 1024}; + numBlocksAndThreads = {::min(ctx->nRanksPerIpcDomain, MAX_NBLOCKS), 1024}; // For GB200 devices with MNNVLS (Multi-Node NVLink Sharp), scale the number of blocks inversely with // the number of GPUs. Empirically, 32 blocks works well for 4 GPUs and 16 for 8 GPUs, which // follows the formula 128 / nGPUs, clamped to [1, MAX_NBLOCKS]. if (computeCapabilityMajor_ == 10) { - numBlocksAndThreads.first = ::max(1, ::min(128 / ctx->workSize, MAX_NBLOCKS)); + numBlocksAndThreads.first = ::max(1, ::min(128 / ctx->worldSize, MAX_NBLOCKS)); } } if (numBlocksAndThreads.first > MAX_NBLOCKS) { WARN("Number of blocks exceeds maximum supported value of %d", MAX_NBLOCKS); return CommResult::CommInvalidArgument; } - cudaError_t error = - allreduce(nullptr, nullptr, nullptr, this->memoryChannelsDeviceHandle_.get(), nullptr, nvlsChannels, - nvlsOutChannels, channelInOffset, channelOutOffset, 0, ctx->rank, ctx->nRanksPerNode, ctx->workSize, - inputSize, stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second); + cudaError_t error = allreduce(nullptr, nullptr, nullptr, this->memoryChannelsDeviceHandle_.get(), nullptr, + nvlsChannels, nvlsOutChannels, channelInOffset, channelOutOffset, 0, ctx->rank, + ctx->nRanksPerIpcDomain, ctx->worldSize, inputSize, stream, nullptr, 0, 0, + numBlocksAndThreads.first, numBlocksAndThreads.second); if (error != cudaSuccess) { + if (error == cudaErrorNotSupported) { + WARN("AllreduceNvls does not support the requested data type."); + return CommResult::CommInvalidArgument; + } WARN("AllreduceNvls failed with error: %s", cudaGetErrorString(error)); return CommResult::CommUnhandledCudaError; } @@ -179,8 +185,8 @@ std::shared_ptr AllreduceNvls::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); size_t sendBytes, recvBytes; CUdeviceptr sendBasePtr, recvBasePtr; @@ -188,11 +194,12 @@ std::shared_ptr AllreduceNvls::initAllreduceContext(std::shared_ptrswitchChannels = setupNvlsChannels(this->nvlsConnections_, (void*)sendBasePtr, sendBytes, nSwitchChannels_); + ctx->switchChannels = + setupNvlsChannels(comm, this->nvlsConnections_, (void*)sendBasePtr, sendBytes, nSwitchChannels_); if (input != output) { auto nvlsOutConnections = this->nvlsOutConnections_; std::vector outChannels = - setupNvlsChannels(this->nvlsOutConnections_, (void*)recvBasePtr, recvBytes, nSwitchChannels_); + setupNvlsChannels(comm, this->nvlsOutConnections_, (void*)recvBasePtr, recvBytes, nSwitchChannels_); ctx->switchChannels.insert(ctx->switchChannels.end(), outChannels.begin(), outChannels.end()); } diff --git a/src/ext/collectives/allreduce/allreduce_packet.cu b/src/ext/collectives/allreduce/allreduce_packet.cu index 3c75a746d..8591c9834 100644 --- a/src/ext/collectives/allreduce/allreduce_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_packet.cu @@ -15,7 +15,7 @@ namespace collective { template __global__ void __launch_bounds__(1024, 1) allreducePacket(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle* memoryChannels, - size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerNode, int worldSize, + size_t channelDataOffset, size_t scratchBufferSize, int rank, int nRanksPerIpcDomain, int worldSize, size_t nelems, void* flags, uint32_t flagBufferSize, uint32_t numScratchBuff #if defined(ENABLE_NPKIT) , @@ -23,9 +23,6 @@ __global__ void __launch_bounds__(1024, 1) #else ) { #endif - // This version of allreduce only works for single nodes - if (worldSize != nRanksPerNode) return; - #if defined(ENABLE_NPKIT) extern __shared__ int4 NpkitSharedMem[]; NpKitEvent* event_buffer = (NpKitEvent*)((char*)NpkitSharedMem); @@ -56,7 +53,7 @@ __global__ void __launch_bounds__(1024, 1) else nelems = nelems / (sizeof(int) / sizeof(T)); - const int nPeers = nRanksPerNode - 1; + const int nPeers = nRanksPerIpcDomain - 1; const size_t nPkts = nelems / 2; uint32_t flag = ((uint32_t*)flags)[blockIdx.x]; @@ -81,10 +78,11 @@ __global__ void __launch_bounds__(1024, 1) uint2* dst = (uint2*)((char*)resultBuff + rank * nelemsPerRank * sizeof(int)); // Put channels into shared memory, read channel info from global memory is unexpectable slow. - __shared__ mscclpp::DeviceHandle channels[MAX_NRANKS_PER_NODE - 1]; + __shared__ mscclpp::DeviceHandle channels[MAX_IPC_DOMAIN_NRANKS - 1]; const int lid = tid % WARP_SIZE; - if (lid < nPeers) { - channels[lid] = memoryChannels[lid]; + // Peer count may exceed WARP_SIZE on MNNVL. + for (int i = lid; i < nPeers; i += WARP_SIZE) { + channels[i] = memoryChannels[i]; } __syncwarp(); // step 1: write to scratch buffer @@ -156,31 +154,32 @@ template struct PacketAdapter { static cudaError_t call(const void* buff, void* scratch, void* resultBuff, void* memoryChannels, void*, DeviceHandle*, DeviceHandle*, size_t channelInOffset, size_t, - size_t scratchBufferSize, int rank, int nRanksPerNode, int worldSize, size_t inputSize, + size_t scratchBufferSize, int rank, int nRanksPerIpcDomain, int worldSize, size_t inputSize, cudaStream_t stream, void* flags, uint32_t flagBufferSize, uint32_t numScratchBuff, int nBlocks = 0, int nThreadsPerBlock = 0) { using ChannelType = DeviceHandle; const size_t nelems = inputSize / sizeof(T); - // Optimize the number of blocks to be multiple of (worldSize - 1) - nBlocks = nBlocks / (worldSize - 1) * (worldSize - 1); + // Optimize the number of blocks to be multiple of the IPC-domain peer count. + const int nPeers = nRanksPerIpcDomain - 1; + nBlocks = nBlocks / nPeers * nPeers; #if defined(ENABLE_NPKIT) size_t sharedMemSize = sizeof(NpKitEvent) * NPKIT_SHM_NUM_EVENTS; allreducePacket<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank, - nRanksPerNode, worldSize, nelems, flags, flagBufferSize, numScratchBuff, NpKit::GetGpuEventCollectContexts(), - NpKit::GetCpuTimestamp()); + nRanksPerIpcDomain, worldSize, nelems, flags, flagBufferSize, numScratchBuff, + NpKit::GetGpuEventCollectContexts(), NpKit::GetCpuTimestamp()); #else allreducePacket<<>>( (T*)buff, (T*)scratch, (T*)resultBuff, (ChannelType*)memoryChannels, channelInOffset, scratchBufferSize, rank, - nRanksPerNode, worldSize, nelems, flags, flagBufferSize, numScratchBuff); + nRanksPerIpcDomain, worldSize, nelems, flags, flagBufferSize, numScratchBuff); #endif return cudaGetLastError(); } }; -inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize, int nRanksPerNode, int worldSize, +inline std::pair getDefaultBlockNumAndThreadNum(size_t inputSize, int nRanksPerIpcDomain, int worldSize, [[maybe_unused]] DataType dtype) { - int nBlocks = (nRanksPerNode - 1) * 4; + int nBlocks = (nRanksPerIpcDomain - 1) * 4; int nThreadsPerBlock = 1024; if (inputSize >= 32768) { nBlocks = (worldSize - 1) * 8; @@ -231,9 +230,19 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr ctx_ const std::unordered_map&, DataType accumDtype) { auto ctx = std::static_pointer_cast(ctx_void); + if (ctx->worldSize != ctx->nRanksPerIpcDomain) { + WARN(ALGO, "AllreducePacket requires worldSize to match nRanksPerIpcDomain, got worldSize=", ctx->worldSize, + ", nRanksPerIpcDomain=", ctx->nRanksPerIpcDomain); + return CommResult::CommInvalidArgument; + } std::pair blockAndThreadNum = {nBlocks, nThreadsPerBlock}; if (blockAndThreadNum.first == 0 || blockAndThreadNum.second == 0) { - blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, ctx->workSize, ctx->nRanksPerNode, dtype); + blockAndThreadNum = getDefaultBlockNumAndThreadNum(inputSize, ctx->nRanksPerIpcDomain, ctx->worldSize, dtype); + } else { + const int nPeers = ctx->nRanksPerIpcDomain - 1; + if (blockAndThreadNum.first < nPeers) { + return CommResult::CommInvalidArgument; + } } if (blockAndThreadNum.first > maxBlockNum_) { WARN(ALGO, "Requested block number ", blockAndThreadNum.first, " exceeds the maximum supported block number ", @@ -254,8 +263,8 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr ctx_ } cudaError_t error = allreduce(input, this->scratchBuffer_, output, ctx->memoryChannelDeviceHandles.get(), nullptr, nullptr, nullptr, - channelInOffset, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerNode, ctx->workSize, inputSize, - stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, this->nSegmentsForScratchBuffer_, + channelInOffset, 0, this->scratchBufferSize_, ctx->rank, ctx->nRanksPerIpcDomain, ctx->worldSize, + inputSize, stream, (void*)flagBuffer_, (uint32_t)flagBufferSize_, this->nSegmentsForScratchBuffer_, blockAndThreadNum.first, blockAndThreadNum.second); if (error != cudaSuccess) { WARN(ALGO, "AllreducePacket failed with error: ", cudaGetErrorString(error)); @@ -269,8 +278,8 @@ std::shared_ptr AllreducePacket::initAllreduceContext(std::shared_ptr(); const int nChannelsPerConnection = maxBlockNum_; ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->memorySemaphores_; ctx->registeredMemories = this->registeredMemories_; ctx->registeredMemories.pop_back(); // remove the local memory from previous context diff --git a/src/ext/collectives/allreduce/allreduce_rsag.cu b/src/ext/collectives/allreduce/allreduce_rsag.cu index db471b932..1f5d3e5d8 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag.cu @@ -31,18 +31,18 @@ namespace collective { template __global__ void __launch_bounds__(1024, 1) allreduceRsAg(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels, - DeviceHandle* switchChannels, void* remoteMemories, int rank, int nRanksPerNode, + DeviceHandle* switchChannels, void* remoteMemories, int rank, int nRanksPerIpcDomain, int worldSize, size_t nelems) { int blockId = blockIdx.x; - uint32_t nPeers = nRanksPerNode - 1; + uint32_t nPeers = nRanksPerIpcDomain - 1; assert((uintptr_t)buff % sizeof(int4) == 0); assert((uintptr_t)resultBuff % sizeof(int4) == 0); constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T); - uint32_t alignedNelems = ((nelems + nRanksPerNode - 1) / nRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 * - nelemsPerInt4 * nRanksPerNode; - uint32_t nelemsPerRank = alignedNelems / nRanksPerNode; + uint32_t alignedNelems = ((nelems + nRanksPerIpcDomain - 1) / nRanksPerIpcDomain + nelemsPerInt4 - 1) / + nelemsPerInt4 * nelemsPerInt4 * nRanksPerIpcDomain; + uint32_t nelemsPerRank = alignedNelems / nRanksPerIpcDomain; uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4; uint32_t lastInt4Index = nelems / nelemsPerInt4; uint32_t remainder = nelems % nelemsPerInt4; @@ -59,7 +59,7 @@ __global__ void __launch_bounds__(1024, 1) nInt4PerBlock += remainderForBlock; } if (nInt4PerBlock == 0) return; - uint32_t nInt4ForCopy = nInt4PerBlock * nRanksPerNode; + uint32_t nInt4ForCopy = nInt4PerBlock * nRanksPerIpcDomain; for (uint32_t idx = threadIdx.x; idx < nInt4ForCopy; idx += blockDim.x) { int rankIdx = idx / nInt4PerBlock; @@ -84,13 +84,13 @@ __global__ void __launch_bounds__(1024, 1) if (offset > lastInt4Index) continue; int4 tmp = scratch4[offset]; for (uint32_t i = 0; i < nPeers; i++) { - int rankIdx = (rank + i + 1) % nRanksPerNode; + int rankIdx = (rank + i + 1) % nRanksPerIpcDomain; int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1; int4 data = mscclpp::read(((void**)remoteMemories)[peerIdx], offset); tmp = calVector(data, tmp); } for (uint32_t i = 0; i < nPeers; i++) { - int rankIdx = (rank + i + 1) % nRanksPerNode; + int rankIdx = (rank + i + 1) % nRanksPerIpcDomain; int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1; mscclpp::write(((void**)remoteMemories)[peerIdx], offset, tmp); } @@ -127,8 +127,8 @@ template struct AllreduceRsAgAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories, DeviceHandle* switchChannel, DeviceHandle*, size_t, size_t, - size_t, int rank, int nRanksPerNode, int worldSize, size_t inputSize, cudaStream_t stream, - void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) { + size_t, int rank, int nRanksPerIpcDomain, int worldSize, size_t inputSize, + cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) { using ChannelType = DeviceHandle; size_t nelems = inputSize / sizeof(T); if (nBlocks == 0 || nThreadsPerBlock == 0) { @@ -137,7 +137,7 @@ struct AllreduceRsAgAdapter { } allreduceRsAg<<>>( (T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank, - nRanksPerNode, worldSize, nelems); + nRanksPerIpcDomain, worldSize, nelems); return cudaGetLastError(); } }; @@ -179,9 +179,13 @@ CommResult AllreduceRsAg::allreduceKernelFunc(const std::shared_ptr ctx, c return CommResult::CommInvalidArgument; } std::pair numBlocksAndThreads = {nBlocks, nThreadsPerBlock}; + if (numBlocksAndThreads.first > nChannelsPerConnection_) { + WARN(ALGO, "Block number ", numBlocksAndThreads.first, " exceeds the maximum limit ", nChannelsPerConnection_); + return CommResult::CommInvalidArgument; + } cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(), this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, 0, algoCtx->rank, - algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream, nullptr, 0, 0, + algoCtx->nRanksPerIpcDomain, algoCtx->worldSize, inputSize, stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second); if (error != cudaSuccess) { WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error)); @@ -198,8 +202,8 @@ std::shared_ptr AllreduceRsAg::initAllreduceContext(std::shared_ptr(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->scratchSemaphores_; ctx->registeredMemories = this->remoteScratchMemories_; diff --git a/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu b/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu index eabe3dc53..4b2434444 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_pipeline.cu @@ -86,15 +86,15 @@ template __global__ void __launch_bounds__(1024, 1) allreduceRsAgPipeline(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels, DeviceHandle* switchChannels, void* remoteMemories, int rank, - int nRanksPerNode, int worldSize, size_t nelems, size_t scratchSize, uint32_t nblocksForPut, - uint32_t nblocksForReduce, uint32_t nblocksForRecv) { + int nRanksPerIpcDomain, int worldSize, size_t nelems, size_t scratchSize, + uint32_t nblocksForPut, uint32_t nblocksForReduce, uint32_t nblocksForRecv) { uint32_t bid = blockIdx.x; constexpr uint32_t nStepsPerIter = 4; uint32_t nInt4 = (nelems * sizeof(T) + sizeof(int4) - 1) / sizeof(int4); uint32_t nInt4PerIter = nblocksForReduce * blockDim.x * nStepsPerIter; const uint32_t chunkSize = nInt4PerIter * worldSize; uint32_t nIters = (nInt4 + chunkSize - 1) / chunkSize; - uint32_t nPeers = nRanksPerNode - 1; + uint32_t nPeers = nRanksPerIpcDomain - 1; int4* scratch4 = reinterpret_cast((char*)scratch); const uint32_t scratchIterStride = 2 * chunkSize; // one for AS, one for AG const uint32_t pipelineDepth = scratchSize / sizeof(int4) / scratchIterStride; @@ -111,7 +111,7 @@ __global__ void __launch_bounds__(1024, 1) __syncthreads(); uint32_t threadIdInPut = bid * blockDim.x + threadIdx.x; for (uint32_t peer = 0; peer < nPeers; peer++) { - int remoteRankId = (rank + peer + 1) % nRanksPerNode; + int remoteRankId = (rank + peer + 1) % nRanksPerIpcDomain; int peerId = remoteRankId < rank ? remoteRankId : remoteRankId - 1; // Read chunk[remoteRankId] from local buff, write to peer's scratch[rank] (sender's slot) uint32_t srcOffset = iter * chunkSize + remoteRankId * nInt4PerIter; @@ -164,7 +164,7 @@ __global__ void __launch_bounds__(1024, 1) int4 tmp = loadVec(buff, myChunkOffset, nelems); // Add data from each peer's slot in scratch (peer sent their chunk[rank] to our scratch[peer]) for (uint32_t peer = 0; peer < nPeers; peer++) { - int remoteRankId = (rank + peer + 1) % nRanksPerNode; + int remoteRankId = (rank + peer + 1) % nRanksPerIpcDomain; uint32_t peerSlotOffset = baseOffset + remoteRankId * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut; int4 data = scratch4[peerSlotOffset]; @@ -175,7 +175,7 @@ __global__ void __launch_bounds__(1024, 1) uint32_t dstOffset = baseOffset + chunkSize + rank * nInt4PerIter + threadIdInPut + putStep * blockDim.x * nblocksForPut; for (uint32_t i = 0; i < nPeers; i++) { - int peerIdx = (rank + i + 1) % nRanksPerNode; + int peerIdx = (rank + i + 1) % nRanksPerIpcDomain; int index = peerIdx < rank ? peerIdx : peerIdx - 1; mscclpp::write(((void**)remoteMemories)[index], dstOffset, tmp); } @@ -203,7 +203,7 @@ __global__ void __launch_bounds__(1024, 1) __syncthreads(); // Copy other ranks' reduced chunks from scratch to result for (uint32_t peer = 0; peer < nPeers; peer++) { - int remoteRankId = (rank + peer + 1) % nRanksPerNode; + int remoteRankId = (rank + peer + 1) % nRanksPerIpcDomain; for (uint32_t step = 0; step < nStepsPerIter * REDUCE_COPY_RATIO; step++) { uint32_t offset = baseOffset + chunkSize + remoteRankId * nInt4PerIter + threadIdInRecv + step * blockDim.x * nblocksForRecv; @@ -224,7 +224,7 @@ template struct AllreduceRsAgPipelineAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories, DeviceHandle* switchChannel, DeviceHandle*, size_t, size_t, - size_t scratchSize, int rank, int nRanksPerNode, int worldSize, size_t inputSize, + size_t scratchSize, int rank, int nRanksPerIpcDomain, int worldSize, size_t inputSize, cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) { using ChannelType = DeviceHandle; size_t nelems = inputSize / sizeof(T); @@ -248,7 +248,7 @@ struct AllreduceRsAgPipelineAdapter { } allreduceRsAgPipeline<<>>( (T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank, - nRanksPerNode, worldSize, nelems, scratchSize, nblocksForPut, nblocksForReduce, nblocksForRecv); + nRanksPerIpcDomain, worldSize, nelems, scratchSize, nblocksForPut, nblocksForReduce, nblocksForRecv); return cudaGetLastError(); } }; @@ -288,8 +288,8 @@ CommResult AllreduceRsAgPipeline::allreduceKernelFunc( std::pair numBlocksAndThreads = {nBlocks, nThreadsPerBlock}; cudaError_t error = allreduce(input, this->scratchBuffer_, output, this->baseMemoryChannelHandles_.get(), this->remoteMemoryHandles_.get(), nullptr, nullptr, 0, 0, this->scratchBufferSize_, - algoCtx->rank, algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream, nullptr, 0, - 0, numBlocksAndThreads.first, numBlocksAndThreads.second); + algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->worldSize, inputSize, stream, + nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second); if (error != cudaSuccess) { WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error)); return CommResult::CommUnhandledCudaError; @@ -305,8 +305,8 @@ std::shared_ptr AllreduceRsAgPipeline::initAllreduceContext(std::shared_pt void*, size_t, DataType) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->scratchSemaphores_; ctx->registeredMemories = this->remoteScratchMemories_; diff --git a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu index f95ba7e33..e7ed0cabe 100644 --- a/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu +++ b/src/ext/collectives/allreduce/allreduce_rsag_zero_copy.cu @@ -35,10 +35,10 @@ __device__ mscclpp::DeviceSyncer globalSyncer; // // This approach requires registering both input and output buffers as remote // memories (2 * nPeers handles), but avoids scratch buffer allocation and -// the extra copy steps of the standard RSAG. The NRanksPerNode template +// the extra copy steps of the standard RSAG. The NRanks template // parameter enables compile-time unrolling of peer loops (supports 4 or 8). -template +template __global__ void __launch_bounds__(1024, 1) allreduceRsAgZeroCopy(T* buff, T* scratch, T* resultBuff, DeviceHandle* memoryChannels, DeviceHandle* switchChannels, void* remoteMemories, int rank, int worldSize, @@ -48,12 +48,12 @@ __global__ void __launch_bounds__(1024, 1) assert((uintptr_t)buff % sizeof(int4) == 0); assert((uintptr_t)resultBuff % sizeof(int4) == 0); - constexpr int NPeers = NRanksPerNode - 1; + constexpr int NPeers = NRanks - 1; constexpr uint32_t nelemsPerInt4 = sizeof(int4) / sizeof(T); - const uint32_t outputRemoteBufferOffset = NRanksPerNode - 1; - uint32_t alignedNelems = ((nelems + NRanksPerNode - 1) / NRanksPerNode + nelemsPerInt4 - 1) / nelemsPerInt4 * - nelemsPerInt4 * NRanksPerNode; - uint32_t nelemsPerRank = alignedNelems / NRanksPerNode; + constexpr uint32_t outputRemoteBufferOffset = NPeers; + uint32_t alignedNelems = + ((nelems + NRanks - 1) / NRanks + nelemsPerInt4 - 1) / nelemsPerInt4 * nelemsPerInt4 * NRanks; + uint32_t nelemsPerRank = alignedNelems / NRanks; uint32_t nInt4PerRank = nelemsPerRank / nelemsPerInt4; uint32_t nInt4Total = (nelems + nelemsPerInt4 - 1) / nelemsPerInt4; @@ -69,7 +69,7 @@ __global__ void __launch_bounds__(1024, 1) } if (nInt4PerBlock == 0) return; - if (threadIdx.x < NPeers) { + if ((int)threadIdx.x < NPeers) { memoryChannelsLocal[threadIdx.x].relaxedSignal(); memoryChannelsLocal[threadIdx.x].relaxedWait(); } @@ -86,18 +86,19 @@ __global__ void __launch_bounds__(1024, 1) int4 tmp_raw = buff4[offset]; #pragma unroll for (int i = 0; i < NPeers; i++) { - int rankIdx = (rank + i + 1) % NRanksPerNode; + int rankIdx = (rank + i + 1) % NRanks; int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1; data[i] = mscclpp::read(((void**)remoteMemories)[peerIdx], offset); } AccumVec acc = mscclpp::upcastVector(tmp_raw); +#pragma unroll for (int i = 0; i < NPeers; i++) { acc = mscclpp::calVectorAccum(acc, data[i]); } int4 tmp = mscclpp::downcastVector(acc); #pragma unroll for (int i = 0; i < NPeers; i++) { - int rankIdx = (rank + i + 1) % NRanksPerNode; + int rankIdx = (rank + i + 1) % NRanks; int peerIdx = rankIdx < rank ? rankIdx : rankIdx - 1; mscclpp::write(((void**)remoteMemories)[outputRemoteBufferOffset + peerIdx], offset, tmp); } @@ -105,7 +106,7 @@ __global__ void __launch_bounds__(1024, 1) } // Use device barrier gives better performance here. globalSyncer.sync(gridDim.x); - if (blockIdx.x == 0 && threadIdx.x < NPeers) { + if (blockIdx.x == 0 && (int)threadIdx.x < NPeers) { memoryChannelsLocal[threadIdx.x].signal(); memoryChannelsLocal[threadIdx.x].wait(); } @@ -115,8 +116,8 @@ template struct AllreduceRsAgZeroCopyAdapter { static cudaError_t call(const void* input, void* scratch, void* output, void* memoryChannels, void* remoteMemories, DeviceHandle* switchChannel, DeviceHandle*, size_t, size_t, - size_t, int rank, int nRanksPerNode, int worldSize, size_t inputSize, cudaStream_t stream, - void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) { + size_t, int rank, int nRanksPerIpcDomain, int worldSize, size_t inputSize, + cudaStream_t stream, void*, uint32_t, uint32_t, int nBlocks, int nThreadsPerBlock) { using ChannelType = DeviceHandle; size_t nelems = inputSize / sizeof(T); if (nBlocks == 0 || nThreadsPerBlock == 0) { @@ -126,16 +127,17 @@ struct AllreduceRsAgZeroCopyAdapter { nBlocks = 128; } } - if (nRanksPerNode == 4) { + if (nRanksPerIpcDomain == 4) { allreduceRsAgZeroCopy<4, OpType, T, AccumT> <<>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank, worldSize, nelems); - } else if (nRanksPerNode == 8) { + } else if (nRanksPerIpcDomain == 8) { allreduceRsAgZeroCopy<8, OpType, T, AccumT> <<>>((T*)input, (T*)scratch, (T*)output, (ChannelType*)memoryChannels, switchChannel, remoteMemories, rank, worldSize, nelems); } else { - THROW(ALGO, Error, ErrorCode::InvalidUsage, "Unsupported number of ranks per node: ", nRanksPerNode); + WARN(ALGO, "AllreduceRsAgZeroCopy only supports nRanksPerIpcDomain of 4 or 8, got: ", nRanksPerIpcDomain); + return cudaErrorInvalidValue; } return cudaGetLastError(); } @@ -164,11 +166,19 @@ CommResult AllreduceRsAgZeroCopy::allreduceKernelFunc(const std::shared_ptr numBlocksAndThreads = {nBlocks, nThreadsPerBlock}; + if (numBlocksAndThreads.first > nChannelsPerConnection_) { + WARN(ALGO, "Block number ", numBlocksAndThreads.first, " exceeds the maximum limit ", nChannelsPerConnection_); + return CommResult::CommInvalidArgument; + } cudaError_t error = allreduce(input, nullptr, output, this->baseMemoryChannelHandles_.get(), algoCtx->remoteMemoryHandles.get(), - nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerNode, algoCtx->workSize, inputSize, stream, - nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second); + nullptr, nullptr, 0, 0, 0, algoCtx->rank, algoCtx->nRanksPerIpcDomain, algoCtx->worldSize, inputSize, + stream, nullptr, 0, 0, numBlocksAndThreads.first, numBlocksAndThreads.second); if (error != cudaSuccess) { + if (error == cudaErrorInvalidValue) { + WARN(ALGO, "AllreduceRsAgZeroCopy received invalid launch arguments: ", cudaGetErrorString(error)); + return CommResult::CommInvalidArgument; + } WARN(ALGO, "Allreduce kernel launch failed with error: ", cudaGetErrorString(error)); return CommResult::CommUnhandledCudaError; } @@ -193,16 +203,14 @@ std::shared_ptr AllreduceRsAgZeroCopy::initAllreduceContext(std::shared_pt void* output, size_t size, DataType) { auto ctx = std::make_shared(); ctx->rank = comm->bootstrap()->getRank(); - ctx->workSize = comm->bootstrap()->getNranks(); - ctx->nRanksPerNode = comm->bootstrap()->getNranksPerNode(); + ctx->worldSize = comm->bootstrap()->getNranks(); + ctx->nRanksPerIpcDomain = comm->bootstrap()->getNranksPerIpcDomain(); ctx->memorySemaphores = this->semaphores_; // register input and output memories RegisteredMemory inputMemory = comm->registerMemory((void*)input, size, Transport::CudaIpc); RegisteredMemory outputMemory = comm->registerMemory(output, size, Transport::CudaIpc); - this->inputMemories_.push_back(inputMemory); - this->outputMemories_.push_back(outputMemory); auto remoteInputMemories = setupRemoteMemories(comm, ctx->rank, inputMemory); auto remoteOutputMemories = setupRemoteMemories(comm, ctx->rank, outputMemory); diff --git a/src/ext/collectives/collective_utils.cc b/src/ext/collectives/collective_utils.cc index 016c4a5cc..5d038afae 100644 --- a/src/ext/collectives/collective_utils.cc +++ b/src/ext/collectives/collective_utils.cc @@ -98,7 +98,8 @@ std::vector> setupNvlsConnections(std:: return nvlsConnections; } -std::vector setupNvlsChannels(std::vector> conns, +std::vector setupNvlsChannels(std::shared_ptr comm, + std::vector> conns, void* buffer, size_t bufferSize, int nSwitchChannels) { std::vector channels; @@ -107,6 +108,8 @@ std::vector setupNvlsChannels(std::vectorbindAllocatedMemory((CUdeviceptr)buffer, bufferSize); channels.push_back(switchChannel); } + // Synchronize to make sure all ranks have their NVLS channels set up before any rank starts using them. + comm->bootstrap()->barrier(); return channels; } @@ -153,4 +156,4 @@ std::shared_ptr> setupBaseMemo } // namespace collective -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp b/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp index 64f5ec544..bba82ee50 100644 --- a/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_allpair_packet.hpp @@ -4,6 +4,7 @@ #include #include "allreduce/common.hpp" +#include "collective_utils.hpp" namespace mscclpp { namespace collective { diff --git a/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp b/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp index a54352b3f..e0c63a3d3 100644 --- a/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_fullmesh.hpp @@ -30,8 +30,6 @@ class AllreduceFullmesh : public mscclpp::AlgorithmBuilder { std::vector> inputScratchSemaphores_; std::vector remoteScratchMemories_; RegisteredMemory localScratchMemory_; - std::unordered_map, std::shared_ptr>>> - memoryChannelsMap_; bool symmetricMemory_ = false; }; } // namespace collective diff --git a/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp b/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp index d53ea180b..c40bd2cda 100644 --- a/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_nvls_zero_copy.hpp @@ -41,4 +41,4 @@ class AllreduceNvls : public AlgorithmBuilder { } // namespace collective } // namespace mscclpp -#endif // MSCCLPP_ALLREDUCE_NVLS_ZERO_COPY_HPP_ \ No newline at end of file +#endif // MSCCLPP_ALLREDUCE_NVLS_ZERO_COPY_HPP_ diff --git a/src/ext/collectives/include/allreduce/allreduce_packet.hpp b/src/ext/collectives/include/allreduce/allreduce_packet.hpp index de7ca4719..771126c96 100644 --- a/src/ext/collectives/include/allreduce/allreduce_packet.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_packet.hpp @@ -29,7 +29,7 @@ class AllreducePacket : public AlgorithmBuilder { void* scratchBuffer_; size_t scratchBufferSize_; const int nSegmentsForScratchBuffer_ = 2; - const int maxBlockNum_ = 56; + const int maxBlockNum_ = 112; std::vector conns_; uintptr_t flagBuffer_; size_t flagBufferSize_; @@ -37,4 +37,4 @@ class AllreducePacket : public AlgorithmBuilder { std::vector registeredMemories_; }; } // namespace collective -} // namespace mscclpp \ No newline at end of file +} // namespace mscclpp diff --git a/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp b/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp index 05bf2ef3c..528d9708b 100644 --- a/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp +++ b/src/ext/collectives/include/allreduce/allreduce_rsag_zero_copy.hpp @@ -27,8 +27,6 @@ class AllreduceRsAgZeroCopy : public mscclpp::AlgorithmBuilder { int nChannelsPerConnection_; std::vector conns_; std::vector> semaphores_; - std::vector inputMemories_; - std::vector outputMemories_; std::vector baseChannels_; std::shared_ptr> baseMemoryChannelHandles_; diff --git a/src/ext/collectives/include/allreduce/common.hpp b/src/ext/collectives/include/allreduce/common.hpp index 93b18e262..5d593449c 100644 --- a/src/ext/collectives/include/allreduce/common.hpp +++ b/src/ext/collectives/include/allreduce/common.hpp @@ -36,36 +36,46 @@ MSCCLPP_DEVICE_INLINE constexpr std::size_t calcVectorSize() { } } -template +template MSCCLPP_DEVICE_INLINE void handleMultiLoadReduceStore(T* src, T* dst, size_t srcOffset, size_t dstOffset, size_t size, int tid, int nThreads) { - // nvls can only handle 4 bytes alignment - MSCCLPP_ASSERT_DEVICE(size % 4 == 0, "size must be 4 bytes aligned"); - constexpr size_t nElem = calcVectorSize(); - // For integer types, use 1-element vectors since multimem doesn't support vectorized integer operations - constexpr size_t vecSize = (std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v) - ? 1 - : nElem; - using vectorType = mscclpp::VectorType; - const size_t nVec = size / sizeof(vectorType); - const size_t srcOffset4 = srcOffset / sizeof(vectorType); - const size_t dstOffset4 = dstOffset / sizeof(vectorType); - vectorType* src4 = (vectorType*)src; - vectorType* dst4 = (vectorType*)dst; - for (size_t idx = tid; idx < nVec; idx += nThreads) { - auto val = mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce(src4 + srcOffset4 + idx); - mscclpp::SwitchChannelDeviceHandle::multimemStore(val, dst4 + dstOffset4 + idx); - } - // handle rest of data - size_t processed = nVec * sizeof(vectorType); - constexpr size_t nRestElem = 4 / sizeof(T); - using restVectorType = mscclpp::VectorType; - const size_t startIdx = (srcOffset + processed) / sizeof(restVectorType); - const size_t endIdx = (srcOffset + size) / sizeof(restVectorType); - for (size_t idx = tid + startIdx; idx < endIdx; idx += nThreads) { - auto val = mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce((restVectorType*)src + idx); - mscclpp::SwitchChannelDeviceHandle::multimemStore(val, (restVectorType*)dst + idx); +#if defined(__FP8_TYPES_EXIST__) && \ + (!(defined(__CUDA_ARCH_SPECIFIC__) || defined(__CUDA_ARCH_FAMILY_SPECIFIC__)) || (__CUDA_ARCH__ < 1000)) + if constexpr (std::is_same_v || std::is_same_v) { + assert(false && "FP8 NVLS multimem requires sm_100a or newer"); + return; + } else +#endif + { + // nvls can only handle 4 bytes alignment + MSCCLPP_ASSERT_DEVICE(size % 4 == 0, "size must be 4 bytes aligned"); + constexpr size_t nElem = calcVectorSize(); + // For integer types, use 1-element vectors since multimem doesn't support vectorized integer operations + constexpr size_t vecSize = (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + ? 1 + : nElem; + using vectorType = mscclpp::VectorType; + const size_t nVec = size / sizeof(vectorType); + const size_t srcOffset4 = srcOffset / sizeof(vectorType); + const size_t dstOffset4 = dstOffset / sizeof(vectorType); + vectorType* src4 = (vectorType*)src; + vectorType* dst4 = (vectorType*)dst; + for (size_t idx = tid; idx < nVec; idx += nThreads) { + auto val = mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce(src4 + srcOffset4 + idx); + mscclpp::SwitchChannelDeviceHandle::multimemStore(val, dst4 + dstOffset4 + idx); + } + // handle rest of data + size_t processed = nVec * sizeof(vectorType); + constexpr size_t nRestElem = 4 / sizeof(T); + using restVectorType = mscclpp::VectorType; + const size_t startIdx = (srcOffset + processed) / sizeof(restVectorType); + const size_t endIdx = (srcOffset + size) / sizeof(restVectorType); + for (size_t idx = tid + startIdx; idx < endIdx; idx += nThreads) { + auto val = + mscclpp::SwitchChannelDeviceHandle::multimemLoadReduce((restVectorType*)src + idx); + mscclpp::SwitchChannelDeviceHandle::multimemStore(val, (restVectorType*)dst + idx); + } } } #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 diff --git a/src/ext/collectives/include/collective_utils.hpp b/src/ext/collectives/include/collective_utils.hpp index f705a9d1d..c2bcd87e3 100644 --- a/src/ext/collectives/include/collective_utils.hpp +++ b/src/ext/collectives/include/collective_utils.hpp @@ -26,11 +26,18 @@ namespace mscclpp { namespace collective { constexpr int NUM_NVLS_CONNECTION = 8; -constexpr int NUM_SEMAPHORES = 64; +// Sized to cover MAX_IPC_DOMAIN_NRANKS-scale allreduce algos whose device-side +// semaphore indices grow as O(nRanksPerIpcDomain) (e.g. nvls_block_pipeline uses +// up to ~5 * nRanksPerIpcDomain entries). +constexpr int NUM_SEMAPHORES = 512; -constexpr int MAX_NRANKS_PER_NODE = 8; +// Upper bound on the number of NVLink-reachable ranks that participate in a +// single collective. Sized to cover Multi-Node NVLink (MNNVL) domains up to +// GB200 NVL72 (72 GPUs sharing one NVLink fabric). Drives compile-time sizing +// of shared-memory channel arrays in the allreduce/allgather kernels. +constexpr int MAX_IPC_DOMAIN_NRANKS = 72; -constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // double buffer * 35 thread-blocks * 8 ranks * 256KB = 70MB +constexpr int SCRATCH_SIZE = 2 * 1024 * 1024 * 70; // Two 70 MiB buffers for double-buffered packet scratch space. std::vector setupRemoteMemories(std::shared_ptr comm, int rank, RegisteredMemory localMemory); @@ -50,7 +57,8 @@ std::shared_ptr> setupMemoryChannelDeviceHandles( std::vector> setupNvlsConnections(std::shared_ptr comm, size_t size, int numConnections); -std::vector setupNvlsChannels(std::vector> conns, void* buffer, +std::vector setupNvlsChannels(std::shared_ptr comm, + std::vector> conns, void* buffer, size_t bufferSize, int nSwitchChannels); std::shared_ptr> setupNvlsChannelDeviceHandles( @@ -71,8 +79,9 @@ std::shared_ptr> setupBaseMemoryChannelDeviceHan class AlgorithmCtx { public: int rank; - int workSize; + int worldSize; int nRanksPerNode; + int nRanksPerIpcDomain; std::vector registeredMemories; std::vector memoryChannels; @@ -89,4 +98,4 @@ class AlgorithmCtx { } // namespace collective } // namespace mscclpp -#endif // MSCCLPP_EXT_COLLECTIVE_UTILS_HPP_ \ No newline at end of file +#endif // MSCCLPP_EXT_COLLECTIVE_UTILS_HPP_ diff --git a/test/mp_unit/bootstrap_tests.cc b/test/mp_unit/bootstrap_tests.cc index c28087a45..eb6985a8e 100644 --- a/test/mp_unit/bootstrap_tests.cc +++ b/test/mp_unit/bootstrap_tests.cc @@ -127,6 +127,7 @@ class MPIBootstrap : public mscclpp::Bootstrap { MPI_Comm_size(shmcomm, &shmrank); return shmrank; } + int getNranksPerIpcDomain() const override { return getNranksPerNode(); } void allGather(void* sendbuf, int size) override { MPI_Allgather(MPI_IN_PLACE, 0, MPI_BYTE, sendbuf, size, MPI_BYTE, MPI_COMM_WORLD); }