diff --git a/CMakeLists.txt b/CMakeLists.txt index ef8b785a5..841a32276 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,7 @@ option(MSCCLPP_BUILD_TESTS "Build tests" OFF) option(MSCCLPP_BUILD_PYTHON_BINDINGS "Build Python bindings" ON) option(MSCCLPP_BUILD_EXT_NCCL "Build NCCL interfaces" ON) option(MSCCLPP_BUILD_EXT_COLLECTIVES "Build collective algorithms" ON) +option(MSCCLPP_BUILD_EXT_TORCHCOMMS "Build TorchComms MSCCL++ backend" OFF) option(MSCCLPP_USE_CUDA "Use NVIDIA/CUDA." OFF) option(MSCCLPP_USE_ROCM "Use AMD/ROCm." OFF) option(MSCCLPP_USE_IB "Use InfiniBand." ON) @@ -272,3 +273,8 @@ endif() if(MSCCLPP_BUILD_PYTHON_BINDINGS) add_subdirectory(python) endif() + +# TorchComms MSCCL++ backend +if(MSCCLPP_BUILD_EXT_TORCHCOMMS) + add_subdirectory(python/mscclpp_torchcomms) +endif() diff --git a/docs/index.rst b/docs/index.rst index 23b444021..9a0267db0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -10,6 +10,7 @@ You can find the followings from this documentation. - **Overview:** An overview of MSCCL++ and its features. :doc:`🔗 ` - **Quick Start:** A guide to build, install, and run MSCCL++. :doc:`🔗 ` +- **TorchComms:** Using MSCCL++ as a TorchComms backend for PyTorch training. :doc:`🔗 ` - **MSCCL++ DSL:** A guide to get started with the MSCCL++ DSL. :doc:`🔗 ` - **Tutorials:** A step-by-step guide for GPU communication using MSCCL++. :doc:`🔗 ` - **Programming Guide:** Advanced topics and best practices for using MSCCL++. :doc:`🔗 ` @@ -22,6 +23,7 @@ You can find the followings from this documentation. overview quickstart + torchcomms dsl tutorials programming_guide diff --git a/docs/quickstart.md b/docs/quickstart.md index 83a08d6aa..95cc2d546 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -95,6 +95,7 @@ There are a few optional CMake options you can set: - `-DMSCCLPP_BUILD_PYTHON_BINDINGS=OFF`: Don't build the Python module. - `-DMSCCLPP_BUILD_TESTS=OFF`: Don't build the tests. - `-DMSCCLPP_BUILD_APPS_NCCL=OFF`: Don't build the NCCL API. +- `-DMSCCLPP_BUILD_EXT_TORCHCOMMS=ON`: Build [TorchComms](https://github.com/meta-pytorch/torchcomms) support for MSCCL++ (off by default). Requires PyTorch and pybind11. ``` (install-from-source-python-module)= @@ -226,6 +227,26 @@ export LD_LIBRARY_PATH=$MSCCLPP_INSTALL_DIR:$LD_LIBRARY_PATH torchrun --nnodes=1 --nproc_per_node=8 your_script.py ``` +(torchcomms-support)= +### TorchComms Support + +MSCCL++ integrates with [TorchComms](https://github.com/meta-pytorch/torchcomms), enabling PyTorch users to use MSCCL++ collectives through the TorchComms API. This is the recommended way to use MSCCL++ in PyTorch training for mixed-backend setups (e.g., MSCCL++ for allreduce, NCCL for broadcast/barrier). + +```bash +$ python -m pip install ./python/mscclpp_torchcomms +``` + +```python +import torchcomms +import mscclpp_torchcomms # auto-registers the backend + +comm = torchcomms.new_comm("mscclpp", device, name="my_comm") +comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False) +comm.finalize() +``` + +See [TorchComms Integration](torchcomms.md) for full documentation including architecture, algorithm selection, user-defined algorithms, testing, benchmarks, and troubleshooting. + ## Version Tracking The MSCCL++ Python package includes comprehensive version tracking that captures git repository information at build time. This feature allows users to identify the exact source code version of their installed package. diff --git a/docs/torchcomms.md b/docs/torchcomms.md new file mode 100644 index 000000000..da86476f1 --- /dev/null +++ b/docs/torchcomms.md @@ -0,0 +1,318 @@ +(torchcomms)= +# TorchComms Integration + +MSCCL++ integrates with [TorchComms](https://github.com/meta-pytorch/torchcomms), enabling PyTorch users to use MSCCL++ collectives through a standard API. This is the recommended way to use MSCCL++ in PyTorch training — particularly for mixed-backend setups where you want MSCCL++ for the hot-path collectives (allreduce, allgather) and NCCL/RCCL for everything else. + +```python +import torch +import torchcomms +import mscclpp_torchcomms # auto-registers the backend + +comm = torchcomms.new_comm("mscclpp", torch.device(f"cuda:{local_rank}"), name="grad_sync") +comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False) +comm.finalize() +``` + +## Why TorchComms + +MSCCL++ provides GPU-driven collectives that are faster than NCCL for many workloads (especially allreduce on NVSwitch/H100 systems), but using them directly requires custom CUDA kernels and manual connection setup. The existing NCCL compatibility shim (`LD_PRELOAD`) works but prevents mixed-backend usage and masks MSCCL++'s identity. + +TorchComms solves this: + +- **Mixed-backend training**: Use MSCCL++ for gradient allreduce (~90% of communication time) and NCCL for broadcast, barrier, send/recv — no code changes. +- **Clean integration**: Training frameworks using TorchComms (torchtitan, FSDP2, etc.) swap in MSCCL++ with one line. +- **Proper identity**: MSCCL++ appears as its own backend, not masquerading as NCCL. This matters for debugging, profiling, and configuration. +- **Automatic algorithm selection**: The backend automatically selects the best algorithm (NVLS warp pipeline, packet, fullmesh, RS+AG, etc.) based on message size, topology, and hardware. + +## Installation + +### Prerequisites + +| Dependency | Tested Version | Notes | +|---|---|---| +| PyTorch | 2.10.0+cu128 | Other versions with TorchComms support should work | +| torchcomms | 0.2.0 | `pip install --pre torchcomms` | +| pybind11 | 3.0.2 | Build dependency | +| glog | (any recent) | Build dependency | + +**GPU support:** Tested on NVIDIA GPUs with CUDA 12.8. AMD ROCm GPUs are supported at the build level (MSCCL++ uses a CUDA/HIP translation layer), but the TorchComms backend has not been validated on ROCm yet. + +### pip install (recommended) + +```bash +$ python -m pip install ./python/mscclpp_torchcomms +``` + +This builds and installs the `mscclpp-torchcomms` package. The backend `.so` is automatically discovered — no environment variable needed. + +### CMake build + +For development or integration into an existing build: + +```bash +$ mkdir -p build && cd build +$ cmake -DCMAKE_BUILD_TYPE=Release -DMSCCLPP_BUILD_EXT_TORCHCOMMS=ON .. +$ make -j$(nproc) +$ cd .. +``` + +When using the CMake build path (without pip install), set the environment variable so TorchComms can discover the backend: + +```bash +$ export TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP=$PWD/build/lib/_comms_mscclpp.cpython-*.so +``` + +## Usage + +```bash +$ torchrun --nproc_per_node=8 your_script.py +``` + +```python +import torch +import torchcomms +import mscclpp_torchcomms # auto-registers the backend .so path + +local_rank = int(os.environ["LOCAL_RANK"]) +device = torch.device(f"cuda:{local_rank}") + +# Create an MSCCL++ communicator +comm = torchcomms.new_comm("mscclpp", device, name="my_comm") + +# AllReduce — MSCCL++ automatically selects the best algorithm +tensor = torch.randn(1024 * 1024, device=device) +comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False) + +# AllGather +input_chunk = torch.randn(1024, device=device) +output = torch.empty(1024 * world_size, device=device) +comm.all_gather_single(output, input_chunk, False) + +# Cleanup +comm.finalize() +``` + +Alternatively, if you prefer not to add the `mscclpp_torchcomms` import, set the environment variable directly: + +```bash +$ export TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP=/path/to/_comms_mscclpp.cpython-*.so +``` + +### Mixed-Backend Training + +Use MSCCL++ for high-performance collectives and NCCL for everything else: + +```python +import torch +import torchcomms +import mscclpp_torchcomms + +local_rank = int(os.environ["LOCAL_RANK"]) +device = torch.device(f"cuda:{local_rank}") + +# MSCCL++ for gradient sync (hot path) +mscclpp_comm = torchcomms.new_comm("mscclpp", device, name="grad_sync") + +# NCCL for everything else (broadcast, barrier, etc.) +nccl_comm = torchcomms.new_comm("nccl", device, name="control") + +for epoch in range(num_epochs): + loss = model(data) + loss.backward() + + # Fast gradient allreduce via MSCCL++ + for param in model.parameters(): + mscclpp_comm.all_reduce(param.grad, torchcomms.ReduceOp.SUM, False) + + optimizer.step() + +mscclpp_comm.finalize() +nccl_comm.finalize() +``` + +## Architecture + +### What Happens When You Create a Communicator + +When `torchcomms.new_comm("mscclpp", device)` is called, TorchComms dlopen's the `_comms_mscclpp.*.so` module and invokes `init()`, which: + +1. **Bootstrap** — discovers rank/world_size from the torchrun environment, exchanges a `UniqueId` through `c10d::Store` (rank 0 generates, others read), creates the MSCCL++ `Communicator` with a `TcpBootstrap`. +2. **Scratch buffer** — allocates 128MB via `GpuBuffer` (`cuMemMap`) for native algorithms that need intermediate storage. +3. **Executor** — creates the DSL plan executor (used by DSL algorithms, ignored by native ones). +4. **Algorithm collection** — calls `AlgorithmCollectionBuilder::buildDefaultAlgorithms()` which registers native algorithms + DSL plans, then wires up the topology-aware algorithm selector. +5. **Event pool** — pre-allocates a pool of 256 reusable CUDA events for async work tracking. + +### What Happens When You Call a Collective + +``` +comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False) + │ + ▾ +TorchCommMSCCLPP::all_reduce() + │ validates reduce op (SUM, MIN) + │ checks tensor is contiguous + │ + ▾ +TorchCommMSCCLPP::executeCollective("allreduce", ...) + │ + │ 1. Builds a CollectiveRequest with world_size, nRanksPerNode, + │ rank, buffer pointers, message size, stream, dtype + │ + │ 2. Calls algorithmCollection_.selectAlgorithm(request) + │ → considers message size, NVLS support, compute capability, + │ symmetric memory, CUDA graph capture mode + │ → returns the best algorithm + │ + │ 3. Creates TorchWorkMSCCLPP handle, records start GPU event + │ + │ 4. Calls algo->execute(...) + │ → native algorithms launch a CUDA kernel directly + │ → DSL algorithms use the executor to interpret a JSON plan + │ + │ 5. Records end GPU event, returns the work handle + │ + ▾ +TorchWorkMSCCLPP (returned to caller) + │ wait() → cudaStreamWaitEvent on caller's stream (GPU-side, no CPU block) + │ checkStatus() → polls GPU events for completion/timeout +``` + +### Component Diagram + +``` +torchcomms.new_comm("mscclpp", device) + │ + ▾ +TorchCommMSCCLPPPy.cpp ← pybind11 module + dynamic loader interface + │ + ▾ +TorchCommMSCCLPP.cpp/hpp ← backend class (init, finalize, collective dispatch) + │ + ├── TorchCommMSCCLPPBootstrap ← rank discovery via c10d::Store + ├── TorchWorkMSCCLPP ← GPU event pool + async work tracking + │ + ▾ +AlgorithmCollection::selectAlgorithm() ← MSCCL++ algorithm selection + │ + ▾ +Algorithm::execute() ← GPU kernel launch (native or DSL) +``` + +## Supported Collectives + +| Collective | Status | Algorithms | Notes | +|---|---|---|---| +| AllReduce | Supported | allpair_packet, nvls_packet, packet, nvls_zero_copy, nvls_warp_pipeline, nvls_block_pipeline, fullmesh, rsag, rsag_pipeline, rsag_zero_copy | SUM, MIN. Auto-selected by message size + topology. | +| AllGather | Supported | fullmesh, fullmesh2 | Auto-selected by message size. | +| ReduceScatter | Supported (with custom algorithm) | — | No default algorithms ship. Requires registering a DSL or native algorithm via `AlgorithmCollectionBuilder`. | +| AllToAll | Supported (with custom algorithm) | — | No default algorithms ship. Requires registering a DSL or native algorithm via `AlgorithmCollectionBuilder`. | +| Broadcast | Not supported | — | Use a separate NCCL/RCCL communicator. | +| Reduce | Not supported | — | Use a separate NCCL/RCCL communicator. | +| Send/Recv | Not supported | — | Use a separate NCCL/RCCL communicator. | +| Barrier | Not supported | — | Use a separate NCCL/RCCL communicator. | +| Scatter/Gather | Not supported | — | Use a separate NCCL/RCCL communicator. | + +Unsupported collectives throw a `RuntimeError` with an explicit message naming the operation and suggesting the caller use a separate NCCL/RCCL communicator. + +## Algorithm Selection + +The backend uses the same topology-aware algorithm selector as the NCCL compatibility extension. Selection considers: + +- **Message size**: Small messages (â‰Ī1MB) use packet-based algorithms for lower latency. Large messages use non-packet algorithms for higher bandwidth. +- **NVLS support**: On NVSwitch-connected systems (H100, etc.), NVLS algorithms (warp pipeline, block pipeline) are preferred for large allreduce. +- **Compute capability**: Some algorithms require SM 9.0+ (Hopper). +- **Buffer allocation**: Zero-copy NVLS algorithms require `cuMemMap`-allocated buffers. +- **CUDA graph capture**: Some algorithms are compatible with CUDA graph capture mode. + +The selector picks the best algorithm automatically. Users do not need to configure algorithm selection for default usage. + +## User-Defined Algorithms + +Custom algorithms (DSL or native) can be registered via the `AlgorithmCollectionBuilder` singleton **before** creating the TorchComms communicator. The backend picks them up during `init()`. + +```python +import mscclpp +from mscclpp.language.collectives import AllReduce +import torchcomms +import mscclpp_torchcomms + +# 1. Configure algorithms on the builder singleton +builder = mscclpp.AlgorithmCollectionBuilder() +spec = mscclpp.AlgoSpec(name="my_allreduce", collective=AllReduce(8, 1, True)) +algo = mscclpp.compile(algo=my_allreduce_fn, algo_spec=spec, rank=rank) +builder.add_algorithm_builder(algo) +builder.set_algorithm_selector(my_selector) + +# 2. Create comm — init() picks up everything from the builder +comm = torchcomms.new_comm("mscclpp", device, name="custom") + +# 3. Collectives use the configured algorithms automatically +comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False) +``` + +## Environment Variables + +| Variable | Description | +|---|---| +| `TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP` | Path to the `_comms_mscclpp.*.so` module. **Automatically set** when `mscclpp-torchcomms` is pip-installed. Only needed for CMake-only builds. | + +## Testing + +All tests are launched via `torchrun`: + +```bash +# Collective correctness (allreduce, allgather, reducescatter) +$ torchrun --nproc_per_node=2 test/torchcomms/test_correctness.py --all + +# With size/dtype sweep (exercises both packet and non-packet algorithm paths) +$ torchrun --nproc_per_node=2 test/torchcomms/test_correctness.py --all --sweep + +# Message size sweep (1 to 32MB) +$ torchrun --nproc_per_node=2 test/torchcomms/test_sizes.py + +# Error handling (unsupported ops, invalid reduce ops) +$ torchrun --nproc_per_node=2 test/torchcomms/test_error_handling.py + +# Simulated training loop +$ torchrun --nproc_per_node=2 test/torchcomms/test_training_loop.py + +# User-defined algorithm registration +$ torchrun --nproc_per_node=2 test/torchcomms/test_user_algorithms.py +``` + +## Benchmarks + +```bash +$ torchrun --nproc_per_node=8 test/torchcomms/bench_torchcomms.py --collective allreduce --warmup 100 --iters 200 +$ torchrun --nproc_per_node=8 test/torchcomms/bench_torchcomms.py --collective allgather --warmup 100 --iters 200 +``` + +Generate a report from benchmark output: + +```bash +$ python test/torchcomms/bench_report.py --input bench_results/torchcomms_raw.json +``` + +## Limitations + +- **Single-tensor variants only.** MSCCL++'s `Algorithm::execute()` operates on contiguous buffers, so the backend implements `all_gather_single` and `reduce_scatter_single` but not the tensor-list variants. The tensor-list variants throw with guidance to use the single-tensor variant. +- **Contiguous tensors required.** All input and output tensors must be contiguous. Non-contiguous tensors raise a `RuntimeError`. +- **Unsupported collectives throw at runtime.** Broadcast, reduce, send/recv, barrier, scatter, and gather throw a `RuntimeError` with guidance to use NCCL/RCCL. + +## Troubleshooting + +### "Backend mscclpp specified, but TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP not set" + +The test or script is not importing `mscclpp_torchcomms`. Add `import mscclpp_torchcomms` before `torchcomms.new_comm()`, or set the environment variable manually if not using pip install. + +### "Requested fd not found, size of fdSet_ is 0" + +The scratch buffer was allocated with `cudaMalloc` instead of `GpuBuffer` (`cuMemMap`). This means POSIX file descriptors were not registered in the unix socket server for cross-rank IPC sharing. This is a build issue — ensure the backend is built against the correct MSCCL++ version. + +### "No algorithm registered for collective X" + +The algorithm selector found no matching algorithm for the given collective, message size, and topology. For ReduceScatter and AllToAll, you need to register a DSL algorithm via `AlgorithmCollectionBuilder` before creating the MSCCL++ communicator or use a different communicator. + +### CUDA device mismatch errors + +The backend uses `CudaDeviceGuard` to restore the CUDA device after `init()`. If you see device mismatch errors, ensure the `device` argument to `torchcomms.new_comm()` matches `LOCAL_RANK`. diff --git a/python/mscclpp_torchcomms/CMakeLists.txt b/python/mscclpp_torchcomms/CMakeLists.txt new file mode 100644 index 000000000..6c113150e --- /dev/null +++ b/python/mscclpp_torchcomms/CMakeLists.txt @@ -0,0 +1,111 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +include(FetchContent) + +# Fetch torchcomms headers (header-only dependency — we only need the interface headers) +FetchContent_Declare(torchcomms + GIT_REPOSITORY https://github.com/meta-pytorch/torchcomms.git + GIT_TAG v0.2.0-rc2 +) +FetchContent_GetProperties(torchcomms) +if(NOT torchcomms_POPULATED) + FetchContent_Populate(torchcomms) +endif() + +# Find PyTorch (provides Torch libraries and Python development headers) +find_package(Torch REQUIRED) +find_package(Python 3.8 COMPONENTS Interpreter Development.Module REQUIRED) + +# Locate pybind11 via Python package +execute_process( + COMMAND "${Python_EXECUTABLE}" -c "import pybind11; print(pybind11.get_cmake_dir())" + OUTPUT_VARIABLE PYBIND11_CMAKE_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE PYBIND11_FIND_RESULT +) +if(PYBIND11_FIND_RESULT EQUAL 0 AND PYBIND11_CMAKE_DIR) + list(APPEND CMAKE_PREFIX_PATH "${PYBIND11_CMAKE_DIR}") +endif() +find_package(pybind11 REQUIRED) + +# Gather our C++ sources +file(GLOB_RECURSE TORCHCOMM_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp) + +# Torchcomms framework sources we need to compile in directly. +# Our module inherits from TorchWork, TorchCommBackend, and registers with +# TorchCommFactory — these symbols must be in our .so since torchcomms doesn't +# export them from a shared lib we can link against. +set(TORCHCOMMS_FRAMEWORK_SOURCES + ${torchcomms_SOURCE_DIR}/comms/torchcomms/TorchWork.cpp + ${torchcomms_SOURCE_DIR}/comms/torchcomms/TorchCommFactory.cpp + ${torchcomms_SOURCE_DIR}/comms/torchcomms/TorchCommOptions.cpp + ${torchcomms_SOURCE_DIR}/comms/torchcomms/TorchCommTypes.cpp + ${torchcomms_SOURCE_DIR}/comms/torchcomms/utils/Utils.cpp + ${torchcomms_SOURCE_DIR}/comms/torchcomms/utils/StoreManager.cpp +) + +# MSCCL++ algorithm selector (same one used by the NCCL extension) +set(MSCCLPP_ALGO_SELECTOR_SOURCES + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ext/nccl/algorithm_selector.cc +) + +# Build pybind11 module +pybind11_add_module(_comms_mscclpp ${TORCHCOMM_SOURCES} ${TORCHCOMMS_FRAMEWORK_SOURCES} ${MSCCLPP_ALGO_SELECTOR_SOURCES}) + +# Find glog (required by torchcomms framework sources via Logging.hpp). +# Add the conda/venv prefix to CMAKE_PREFIX_PATH so find_package can locate +# glog's CMake config installed alongside the Python environment. +get_filename_component(_PYTHON_ENV_PREFIX "${Python_EXECUTABLE}" DIRECTORY) +get_filename_component(_PYTHON_ENV_PREFIX "${_PYTHON_ENV_PREFIX}" DIRECTORY) +list(APPEND CMAKE_PREFIX_PATH "${_PYTHON_ENV_PREFIX}") +find_package(glog REQUIRED) + +target_include_directories(_comms_mscclpp SYSTEM PRIVATE + # torchcomms headers: resolves #include + ${torchcomms_SOURCE_DIR} + ${GPU_INCLUDE_DIRS} +) +# MSCCL++ internal headers (for algorithm_selector.hpp and debug.h) +target_include_directories(_comms_mscclpp PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ext/nccl + ${CMAKE_CURRENT_SOURCE_DIR}/../../src/core/include +) + +target_link_libraries(_comms_mscclpp PRIVATE + # MUST use the shared library (not mscclpp_static) to avoid dual-singleton: + # mscclpp_collectives.so links against libmscclpp.so, so if we statically + # link mscclpp into our module, there are two copies of singletons like + # UnixSocketServer::instance(). The bootstrap starts server #1 (static), + # but the collectives code registers fds in server #2 (shared), causing + # "Requested fd not found, size of fdSet_ is 0" crashes. + mscclpp + mscclpp_collectives + ${TORCH_LIBRARIES} + ${GPU_LIBRARIES} + glog::glog +) + +# Propagate USE_ROCM define for mscclpp/gpu.hpp portability +target_compile_definitions(_comms_mscclpp PRIVATE + $<$:USE_ROCM> +) + +target_compile_features(_comms_mscclpp PRIVATE cxx_std_17) + +# Set the torch_python library path for linking +set(TORCH_PYTHON_LIB "${TORCH_INSTALL_PREFIX}/lib/libtorch_python.so") +if(EXISTS "${TORCH_PYTHON_LIB}") + target_link_libraries(_comms_mscclpp PRIVATE "${TORCH_PYTHON_LIB}") +endif() + +# Install the module into the package directory for pip/scikit-build-core +install(TARGETS _comms_mscclpp LIBRARY DESTINATION mscclpp_torchcomms COMPONENT torchcomm) + +# Copy built module to source tree for easy import during development +add_custom_target(torchcomm_lib_copy ALL + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/_comms_mscclpp*.so + ${CMAKE_CURRENT_SOURCE_DIR} + DEPENDS _comms_mscclpp +) diff --git a/python/mscclpp_torchcomms/__init__.py b/python/mscclpp_torchcomms/__init__.py new file mode 100644 index 000000000..ba9ceb070 --- /dev/null +++ b/python/mscclpp_torchcomms/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import glob +import os + +# Auto-register the MSCCL++ TorchComms backend .so path. +# TorchComms discovers backends via TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP. +# When this package is pip-installed, the .so lives alongside this __init__.py. +if "TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP" not in os.environ: + _pkg_dir = os.path.dirname(os.path.abspath(__file__)) + _candidates = glob.glob(os.path.join(_pkg_dir, "_comms_mscclpp*.so")) + if _candidates: + os.environ["TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP"] = _candidates[0] + + +def get_lib_path(): + """Return the path to the _comms_mscclpp shared library, or None if not found.""" + return os.environ.get("TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP") diff --git a/python/mscclpp_torchcomms/csrc/NcclFallback.cpp b/python/mscclpp_torchcomms/csrc/NcclFallback.cpp new file mode 100644 index 000000000..8b82e0c63 --- /dev/null +++ b/python/mscclpp_torchcomms/csrc/NcclFallback.cpp @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include "NcclFallback.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace torch::comms { + +// --- NCCL C-API ABI mirror --- +// +// We use dlsym for runtime binding (no link-time NCCL dependency), but include +// nccl.h so enum/type names stay source-of-truth (no hardcoded numeric values). +namespace { +bool torchcommTraceEnabled() { + const char* value = std::getenv("MSCCLPP_TORCHCOMMS_TRACE"); + return value != nullptr && value[0] != '\0' && std::string(value) != "0"; +} + +// Function pointer types (signatures from nccl.h). +using GetUniqueIdFn = ncclResult_t (*)(ncclUniqueId*); +using CommInitRankFn = ncclResult_t (*)(ncclComm_t*, int, ncclUniqueId, int); +using CommDestroyFn = ncclResult_t (*)(ncclComm_t); +using ReduceScatterFn = ncclResult_t (*)(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, + cudaStream_t); +using BroadcastFn = ncclResult_t (*)(const void*, void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t); +using AllReduceFn = ncclResult_t (*)(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, ncclComm_t, cudaStream_t); +using ReduceFn = + ncclResult_t (*)(const void*, void*, size_t, ncclDataType_t, ncclRedOp_t, int, ncclComm_t, cudaStream_t); +using SendFn = ncclResult_t (*)(const void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t); +using RecvFn = ncclResult_t (*)(void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t); +using GroupFn = ncclResult_t (*)(); + +ncclDataType_t torchDtypeToNccl(at::ScalarType dtype) { + switch (dtype) { + case at::kFloat: + return ncclFloat32; + case at::kHalf: + return ncclFloat16; + case at::kBFloat16: + return ncclBfloat16; + case at::kInt: + return ncclInt32; + case at::kUInt32: + return ncclUint32; + default: + throw std::runtime_error("[NcclFallback] unsupported dtype " + std::string(at::toString(dtype))); + } +} + +ncclRedOp_t torchReduceOpToNccl(const ReduceOp& op) { + using RedOpType = ReduceOp::RedOpType; + switch (op.type()) { + case RedOpType::SUM: + case RedOpType::PREMUL_SUM: // caller has already applied the scaling + return ncclSum; + case RedOpType::AVG: + return ncclAvg; + case RedOpType::MIN: + return ncclMin; + case RedOpType::MAX: + return ncclMax; + case RedOpType::PRODUCT: + return ncclProd; + default: + throw std::runtime_error("[NcclFallback] unsupported reduce op type " + + std::to_string(static_cast(op.type()))); + } +} +} // namespace + +// --- Lifecycle --- + +std::unique_ptr NcclFallback::tryCreate(const std::shared_ptr& comm, int rank, + int worldSize) { + std::unique_ptr fb(new NcclFallback()); + + // Search candidates for libnccl. MSCCLPP_NCCL_LIB_PATH matches the + // existing src/ext/nccl behavior; bare "libnccl.so.2" picks up PyTorch's + // bundled copy via the Python rpath. + std::vector candidates; + if (const char* envPath = std::getenv("MSCCLPP_NCCL_LIB_PATH"); envPath && envPath[0]) { + candidates.emplace_back(envPath); + } + candidates.emplace_back("libnccl.so.2"); + candidates.emplace_back("libnccl.so"); + + for (const auto& path : candidates) { + fb->dlHandle_ = dlopen(path.c_str(), RTLD_LAZY | RTLD_NODELETE); + if (fb->dlHandle_) break; + } + if (!fb->dlHandle_) { + if (rank == 0) { + const char* err = dlerror(); + std::cerr << "[NcclFallback] could not dlopen libnccl.so.2; fallback disabled. dlerror=" << (err ? err : "(null)") + << std::endl; + } + return nullptr; + } + + auto sym = [&](const char* name) -> void* { + void* p = dlsym(fb->dlHandle_, name); + if (!p && rank == 0) { + std::cerr << "[NcclFallback] dlsym(" << name << ") failed: " << dlerror() << std::endl; + } + return p; + }; + + fb->getUniqueIdFn_ = sym("ncclGetUniqueId"); + fb->commInitRankFn_ = sym("ncclCommInitRank"); + fb->commDestroyFn_ = sym("ncclCommDestroy"); + fb->reduceScatterFn_ = sym("ncclReduceScatter"); + fb->broadcastFn_ = sym("ncclBroadcast"); + fb->allReduceFn_ = sym("ncclAllReduce"); + fb->reduceFn_ = sym("ncclReduce"); + fb->sendFn_ = sym("ncclSend"); + fb->recvFn_ = sym("ncclRecv"); + fb->groupStartFn_ = sym("ncclGroupStart"); + fb->groupEndFn_ = sym("ncclGroupEnd"); + if (!fb->getUniqueIdFn_ || !fb->commInitRankFn_ || !fb->commDestroyFn_ || !fb->reduceScatterFn_ || + !fb->broadcastFn_ || !fb->allReduceFn_ || !fb->reduceFn_ || !fb->sendFn_ || !fb->recvFn_ || + !fb->groupStartFn_ || !fb->groupEndFn_) { + return nullptr; // dtor cleans up dlHandle_ + } + fb->worldSize_ = worldSize; + + // Distribute the ncclUniqueId via the MSCCL++ bootstrap. The base Bootstrap + // interface exposes allGather() but not broadcast(), so we use allGather: + // rank 0 fills its own slot, others zero theirs, then everyone reads slot 0. + std::vector allIds(worldSize); + if (rank == 0) { + int rc = reinterpret_cast(fb->getUniqueIdFn_)(&allIds[0]); + if (rc != 0) { + std::cerr << "[NcclFallback] ncclGetUniqueId failed: rc=" << rc << std::endl; + return nullptr; + } + } else { + std::memset(&allIds[rank], 0, sizeof(ncclUniqueId)); + } + comm->bootstrap()->allGather(allIds.data(), sizeof(ncclUniqueId)); + + int rc = reinterpret_cast(fb->commInitRankFn_)(reinterpret_cast(&fb->ncclComm_), + worldSize, allIds[0], rank); + if (rc != 0) { + std::cerr << "[NcclFallback] ncclCommInitRank failed: rc=" << rc << std::endl; + return nullptr; + } + + // Persistent 4-byte device buffer for barrier-as-allreduce. This is local, + // never shared cross-rank, so plain cudaMalloc/cudaFree is appropriate. + MSCCLPP_CUDATHROW(cudaMalloc(&fb->barrierBuf_, sizeof(int))); + MSCCLPP_CUDATHROW(cudaMemset(fb->barrierBuf_, 0, sizeof(int))); + + if (rank == 0) { + std::cerr << "[NcclFallback] enabled (libnccl.so.2 dlopened)." << std::endl; + } + return fb; +} + +NcclFallback::~NcclFallback() { + if (barrierBuf_) cudaFree(barrierBuf_); + if (ncclComm_ && commDestroyFn_) { + reinterpret_cast(commDestroyFn_)(reinterpret_cast(ncclComm_)); + } + if (dlHandle_) dlclose(dlHandle_); +} + +// --- Dispatchers --- +// Keep fallback narrowly scoped to collectives that currently need it. +// Unsupported collectives remain explicit in TorchCommMSCCLPP to preserve +// TorchComm API semantics and clear error messaging. + +// Tag-and-call helpers. NCCL_TRACE compiles to a runtime-gated cerr line; +// NCCL_CALL invokes a dlsym'd function pointer and throws on nonzero rc. +#define NCCL_TRACE(tag, fields) \ + do { \ + if (torchcommTraceEnabled()) std::cerr << "[NcclFallback] " tag " -> NCCL " << fields << std::endl; \ + } while (0) +#define NCCL_CALL(label, fn_t, fn_ptr, ...) \ + do { \ + int _rc = reinterpret_cast(fn_ptr)(__VA_ARGS__); \ + if (_rc != 0) throw std::runtime_error("[NcclFallback] " label " rc=" + std::to_string(_rc)); \ + } while (0) + +void NcclFallback::reduceScatter(const void* sendbuf, void* recvbuf, size_t recvCount, at::ScalarType dtype, + const ReduceOp& op, cudaStream_t stream) { + NCCL_TRACE("reduce_scatter", "recvCount=" << recvCount << " dtype=" << static_cast(torchDtypeToNccl(dtype)) + << " op=" << static_cast(torchReduceOpToNccl(op))); + NCCL_CALL("ncclReduceScatter", ReduceScatterFn, reduceScatterFn_, sendbuf, recvbuf, recvCount, + torchDtypeToNccl(dtype), torchReduceOpToNccl(op), reinterpret_cast(ncclComm_), stream); +} + +void NcclFallback::broadcast(const void* sendbuf, void* recvbuf, size_t count, at::ScalarType dtype, int root, + cudaStream_t stream) { + NCCL_TRACE("broadcast", + "count=" << count << " dtype=" << static_cast(torchDtypeToNccl(dtype)) << " root=" << root); + NCCL_CALL("ncclBroadcast", BroadcastFn, broadcastFn_, sendbuf, recvbuf, count, torchDtypeToNccl(dtype), root, + reinterpret_cast(ncclComm_), stream); +} + +void NcclFallback::barrier(cudaStream_t stream) { + NCCL_TRACE("barrier", "(allreduce on persistent 4-byte buffer)"); + NCCL_CALL("ncclAllReduce(barrier)", AllReduceFn, allReduceFn_, barrierBuf_, barrierBuf_, 1, ncclInt32, ncclSum, + reinterpret_cast(ncclComm_), stream); +} + +void NcclFallback::allReduce(const void* sendbuf, void* recvbuf, size_t count, at::ScalarType dtype, + const ReduceOp& op, cudaStream_t stream) { + NCCL_TRACE("all_reduce", "count=" << count << " dtype=" << static_cast(torchDtypeToNccl(dtype)) + << " op=" << static_cast(torchReduceOpToNccl(op))); + NCCL_CALL("ncclAllReduce", AllReduceFn, allReduceFn_, sendbuf, recvbuf, count, torchDtypeToNccl(dtype), + torchReduceOpToNccl(op), reinterpret_cast(ncclComm_), stream); +} + +void NcclFallback::reduce(const void* sendbuf, void* recvbuf, size_t count, at::ScalarType dtype, const ReduceOp& op, + int root, cudaStream_t stream) { + NCCL_TRACE("reduce", "count=" << count << " root=" << root); + NCCL_CALL("ncclReduce", ReduceFn, reduceFn_, sendbuf, recvbuf, count, torchDtypeToNccl(dtype), + torchReduceOpToNccl(op), root, reinterpret_cast(ncclComm_), stream); +} + +void NcclFallback::send(const void* sendbuf, size_t count, at::ScalarType dtype, int peer, cudaStream_t stream) { + NCCL_TRACE("send", "peer=" << peer << " count=" << count); + NCCL_CALL("ncclSend", SendFn, sendFn_, sendbuf, count, torchDtypeToNccl(dtype), peer, + reinterpret_cast(ncclComm_), stream); +} + +void NcclFallback::recv(void* recvbuf, size_t count, at::ScalarType dtype, int peer, cudaStream_t stream) { + NCCL_TRACE("recv", "peer=" << peer << " count=" << count); + NCCL_CALL("ncclRecv", RecvFn, recvFn_, recvbuf, count, torchDtypeToNccl(dtype), peer, + reinterpret_cast(ncclComm_), stream); +} + +void NcclFallback::allToAllV(const void* sendbuf, void* recvbuf, const std::vector& sendCounts, + const std::vector& recvCounts, const std::vector& sendOffsets, + const std::vector& recvOffsets, at::ScalarType dtype, cudaStream_t stream) { + if (sendCounts.size() != static_cast(worldSize_) || + recvCounts.size() != static_cast(worldSize_) || + sendOffsets.size() != static_cast(worldSize_) || + recvOffsets.size() != static_cast(worldSize_)) { + throw std::runtime_error("[NcclFallback] all_to_all_v counts/offsets must be length worldSize"); + } + const ncclDataType_t ncclDtype = torchDtypeToNccl(dtype); + const size_t elemBytes = c10::elementSize(dtype); + const auto* sendBytes = static_cast(sendbuf); + auto* recvBytes = static_cast(recvbuf); + ncclComm_t nccl = reinterpret_cast(ncclComm_); + + NCCL_TRACE("all_to_all_v", "group send/recv worldSize=" << worldSize_); + NCCL_CALL("ncclGroupStart", GroupFn, groupStartFn_); + for (int peer = 0; peer < worldSize_; ++peer) { + if (sendCounts[peer] > 0) { + NCCL_CALL("ncclSend", SendFn, sendFn_, sendBytes + sendOffsets[peer] * elemBytes, sendCounts[peer], ncclDtype, + peer, nccl, stream); + } + if (recvCounts[peer] > 0) { + NCCL_CALL("ncclRecv", RecvFn, recvFn_, recvBytes + recvOffsets[peer] * elemBytes, recvCounts[peer], ncclDtype, + peer, nccl, stream); + } + } + NCCL_CALL("ncclGroupEnd", GroupFn, groupEndFn_); +} + +#undef NCCL_CALL +#undef NCCL_TRACE + +} // namespace torch::comms diff --git a/python/mscclpp_torchcomms/csrc/NcclFallback.hpp b/python/mscclpp_torchcomms/csrc/NcclFallback.hpp new file mode 100644 index 000000000..f7c3140e3 --- /dev/null +++ b/python/mscclpp_torchcomms/csrc/NcclFallback.hpp @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace torch::comms { + +/// dlopen-based NCCL fallback for collectives MSCCL++ does not natively +/// implement (reduce_scatter, broadcast, barrier on certain configs). +/// +/// The backend creates one of these in init() via tryCreate(). If libnccl +/// can't be found or its symbols can't be resolved, tryCreate() returns +/// nullptr and the backend operates without a fallback (unsupported +/// collectives then throw). +/// +/// All ABI/dlsym ugliness lives in NcclFallback.cpp; the backend never +/// touches NCCL types or symbols directly. +class NcclFallback { + public: + /// Try to dlopen libnccl.so.2, resolve the symbols we need, and create a + /// parallel NCCL communicator. Returns nullptr (without throwing) if any + /// step fails — callers should treat that as "fallback unavailable". + /// + /// Search order for the shared library: + /// 1. $MSCCLPP_NCCL_LIB_PATH (matches src/ext/nccl/nccl.cc behavior) + /// 2. libnccl.so.2 (rtld-resolved; finds PyTorch's bundled NCCL) + /// 3. libnccl.so + static std::unique_ptr tryCreate(const std::shared_ptr& comm, int rank, + int worldSize); + + ~NcclFallback(); + + NcclFallback(const NcclFallback&) = delete; + NcclFallback& operator=(const NcclFallback&) = delete; + + /// reduce_scatter: input is the full buffer, output is rank's chunk. + /// recvCount is the number of elements in the per-rank output chunk. + void reduceScatter(const void* sendbuf, void* recvbuf, size_t recvCount, at::ScalarType dtype, const ReduceOp& op, + cudaStream_t stream); + + /// broadcast from `root` to all ranks. count is element count. + void broadcast(const void* sendbuf, void* recvbuf, size_t count, at::ScalarType dtype, int root, cudaStream_t stream); + + /// barrier emulated as a 1-element ncclAllReduce on a persistent device int. + void barrier(cudaStream_t stream); + + /// all_reduce for ops MSCCL++ does not implement natively (e.g. MAX, PRODUCT). + void allReduce(const void* sendbuf, void* recvbuf, size_t count, at::ScalarType dtype, const ReduceOp& op, + cudaStream_t stream); + + /// reduce: like all_reduce but only `root` receives the result. + void reduce(const void* sendbuf, void* recvbuf, size_t count, at::ScalarType dtype, const ReduceOp& op, int root, + cudaStream_t stream); + + /// Point-to-point send/recv (count is element count). + void send(const void* sendbuf, size_t count, at::ScalarType dtype, int peer, cudaStream_t stream); + void recv(void* recvbuf, size_t count, at::ScalarType dtype, int peer, cudaStream_t stream); + + /// Variadic all-to-all via ncclGroupStart/End loop of ncclSend/ncclRecv. + /// Counts and offsets are in elements; vectors are length worldSize. + void allToAllV(const void* sendbuf, void* recvbuf, const std::vector& sendCounts, + const std::vector& recvCounts, const std::vector& sendOffsets, + const std::vector& recvOffsets, at::ScalarType dtype, cudaStream_t stream); + + private: + NcclFallback() = default; + + // Opaque to the header — concrete state lives in NcclFallback.cpp. + void* dlHandle_ = nullptr; + void* ncclComm_ = nullptr; + int worldSize_ = 0; + void* getUniqueIdFn_ = nullptr; + void* commInitRankFn_ = nullptr; + void* commDestroyFn_ = nullptr; + void* reduceScatterFn_ = nullptr; + void* broadcastFn_ = nullptr; + void* allReduceFn_ = nullptr; + void* reduceFn_ = nullptr; + void* sendFn_ = nullptr; + void* recvFn_ = nullptr; + void* groupStartFn_ = nullptr; + void* groupEndFn_ = nullptr; + void* barrierBuf_ = nullptr; // persistent 4-byte device buffer for barrier() +}; + +} // namespace torch::comms diff --git a/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPP.cpp b/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPP.cpp new file mode 100644 index 000000000..cd56e5b1b --- /dev/null +++ b/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPP.cpp @@ -0,0 +1,485 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include "TorchCommMSCCLPP.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "NcclFallback.hpp" +#include "TorchCommMSCCLPPBootstrap.hpp" +#include "algorithm_selector.hpp" // shared with src/ext/nccl + +namespace torch::comms { + +// --- Helpers --- + +mscclpp::DataType TorchCommMSCCLPP::torchDtypeToMscclpp(at::ScalarType dtype) { + switch (dtype) { + case at::kFloat: + return mscclpp::DataType::FLOAT32; + case at::kHalf: + return mscclpp::DataType::FLOAT16; + case at::kBFloat16: + return mscclpp::DataType::BFLOAT16; + case at::kInt: + return mscclpp::DataType::INT32; + case at::kUInt32: + return mscclpp::DataType::UINT32; + default: + throw std::runtime_error("[TorchCommMSCCLPP] unsupported dtype: " + std::string(at::toString(dtype))); + } +} + +std::optional TorchCommMSCCLPP::torchReduceOpToMscclpp(const ReduceOp& op) { + // MSCCL++ kernels only implement SUM and MIN. Everything else routes to + // NcclFallback, which uses NCCL's correct implementation. + // + // PREMUL_SUM is mapped to SUM because it's defined as "caller has already + // multiplied each input by the scale factor" — so by the time the kernel + // runs, it really is just a SUM. (FSDP2 uses this for grad sync.) + // + // AVG is intentionally NOT mapped to SUM here. AVG = SUM/world_size, and + // MSCCL++ has no native divide step; mapping AVG → SUM would silently + // produce N× the correct result. NCCL implements AVG correctly via + // ncclAvg, so we let it route there. + switch (op.type()) { + case ReduceOp::RedOpType::SUM: + case ReduceOp::RedOpType::PREMUL_SUM: + return mscclpp::SUM; + case ReduceOp::RedOpType::MIN: + return mscclpp::MIN; + default: + return std::nullopt; // AVG, MAX, PRODUCT, BAND/BOR/BXOR -> NcclFallback + } +} + +// Async ops use the dedicated internal stream; sync ops use the caller's +// current PyTorch CUDA stream so the launch is ordered inline with their work. +cudaStream_t TorchCommMSCCLPP::getOperationStream(bool async_op) const { + return async_op ? internal_stream_ : at::cuda::getCurrentCUDAStream(device_.index()).stream(); +} + +void TorchCommMSCCLPP::checkInitialized() const { + if (!initialized_) throw std::runtime_error("[TorchCommMSCCLPP] not initialized; call init() first"); +} + +// Wraps a fallback dispatch in start/end GPU events on `stream`. Throws if +// libnccl could not be dlopen'd at init time (fallback unavailable). +template +c10::intrusive_ptr TorchCommMSCCLPP::ncclFallback(const char* op, cudaStream_t stream, + std::chrono::milliseconds timeout, Fn&& body) { + if (!nccl_) { + throw std::runtime_error(std::string("[TorchCommMSCCLPP] ") + op + + " requires NCCL fallback (libnccl.so.2 not found)"); + } + auto work = c10::make_intrusive(stream, device_.index(), timeout, event_pool_); + work->recordStart(); + std::forward(body)(); + work->recordEnd(); + return work; +} + +std::shared_ptr TorchCommMSCCLPP::selectAlgorithm( + const std::unordered_map>>& algoMapByCollective, + const mscclpp::CollectiveRequest& request) { + // Hardware capabilities are detected once on first call (per process). + static const bool isNvlsSupported = mscclpp::isNvlsSupported(); + static const std::pair computeCapability = []() { + int dev = 0; + cudaGetDevice(&dev); + int major = 0, minor = 0; + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, dev); + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, dev); + return std::make_pair(major, minor); + }(); + + auto collectiveIt = algoMapByCollective.find(request.collective); + if (collectiveIt == algoMapByCollective.end()) return nullptr; + const auto& algoMap = collectiveIt->second; + + const bool isCuMemMap = mscclpp::isCuMemMapAllocated(const_cast(request.inputBuffer)) && + mscclpp::isCuMemMapAllocated(request.outputBuffer); + cudaStreamCaptureStatus capture = cudaStreamCaptureStatusNone; + cudaStreamIsCapturing(request.stream, &capture); + mscclpp::nccl::AlgorithmSelectorConfig config{ + .symmetricMemory = false, + .nvlsSupported = isNvlsSupported, + .isCuMemMapAllocated = isCuMemMap, + .inCaptureMode = (capture == cudaStreamCaptureStatusActive), + .computeCapability = computeCapability, + .ncclDlopenSharedLib = false, + }; + + // 1. DSL execution plans + for (const auto& [name, algo] : algoMap) { + (void)name; + if (algo->type() == mscclpp::AlgorithmType::DSL) { + auto dslAlgo = std::dynamic_pointer_cast(algo); + if (dslAlgo && mscclpp::nccl::matchExecutionPlan(dslAlgo, request)) return algo; + } + } + + // 2. Topology-aware native selectors + if (request.nRanksPerNode != request.worldSize) { + return mscclpp::nccl::selectMultiNodeAlgorithm(algoMap, request, config); + } + if (request.collective == "allgather") return mscclpp::nccl::selectSingleNodeAllgather(algoMap, request, config); + if (request.collective == "allreduce") return mscclpp::nccl::selectSingleNodeAllreduce(algoMap, request, config); + return nullptr; +} + +// --- Lifecycle --- + +TorchCommMSCCLPP::TorchCommMSCCLPP() = default; + +TorchCommMSCCLPP::~TorchCommMSCCLPP() { + if (initialized_) { + try { + finalize(); + } catch (...) { + } + } +} + +void TorchCommMSCCLPP::init(at::Device device, const std::string& name, const CommOptions& options) { + if (initialized_) throw std::runtime_error("[TorchCommMSCCLPP] already initialized"); + + device_ = device; + name_ = name; + options_ = options; + + // Bootstrap + communicator + auto bootstrap = std::make_unique(options.store, device, options.timeout); + rank_ = bootstrap->getRank(); + size_ = bootstrap->getSize(); + comm_ = bootstrap->createCommunicator(name, options); + + mscclpp::CudaDeviceGuard deviceGuard(device_.index()); + nRanksPerNode_ = comm_->bootstrap()->getNranksPerNode(); + + MSCCLPP_CUDATHROW(cudaStreamCreateWithFlags(&internal_stream_, cudaStreamNonBlocking)); + + // Scratch buffer must use GpuBuffer (cuMemMap) so its POSIX fd is registered + // in the unix socket server; plain cudaMalloc causes "Requested fd not found" + // crashes during cross-rank IPC sharing. + scratchBuffer_ = mscclpp::GpuBuffer(kScratchBufferSize).memory(); + executor_ = std::make_shared(comm_, scratchBuffer_); + + auto [flagBuf, flagSize] = mscclpp::getFlagBuffer(); + flagBuffer_ = flagBuf; + flagBufferSize_ = flagSize; + + // Install the topology-aware fallback selector. Body is out-of-line in + // selectAlgorithm() to keep init() readable. + // + // TODO: This selector duplicates src/ext/nccl/nccl.cc::algoSelector. The + // shared policy (DSL match → topology-aware native selectors) should be + // promoted into a single helper in src/ext/nccl/algorithm_selector.{hpp,cc} + // (e.g. `defaultFallbackSelector`) and reused by both backends. Kept + // duplicated here for now to scope this PR to python/mscclpp_torchcomms/. + auto builder = mscclpp::collective::AlgorithmCollectionBuilder::getInstance(); + builder->setFallbackAlgorithmSelector(&TorchCommMSCCLPP::selectAlgorithm); + algorithmCollection_ = + builder->buildDefaultAlgorithms(reinterpret_cast(scratchBuffer_.get()), kScratchBufferSize, + reinterpret_cast(flagBuffer_.get()), flagBufferSize_, rank_); + + event_pool_ = std::make_shared(256); + nccl_ = NcclFallback::tryCreate(comm_, rank_, size_); + + initialized_ = true; +} + +void TorchCommMSCCLPP::finalize() { + if (!initialized_) return; + + // Drain our streams while NVLink memory is alive, then bootstrap-barrier so + // no rank tears down its communicator while another's kernel is still + // polling its NVLink memory. + if (internal_stream_) MSCCLPP_CUDATHROW(cudaStreamSynchronize(internal_stream_)); + MSCCLPP_CUDATHROW(cudaStreamSynchronize(at::cuda::getCurrentCUDAStream(device_.index()).stream())); + comm_->bootstrap()->barrier(); + + nccl_.reset(); + executor_.reset(); + event_pool_.reset(); + if (internal_stream_) { + MSCCLPP_CUDATHROW(cudaStreamDestroy(internal_stream_)); + internal_stream_ = nullptr; + } + scratchBuffer_.reset(); + flagBuffer_.reset(); + comm_.reset(); + initialized_ = false; +} + +// --- Metadata --- + +int TorchCommMSCCLPP::getRank() const { return rank_; } +int TorchCommMSCCLPP::getSize() const { return size_; } +std::string_view TorchCommMSCCLPP::getBackendName() const { return kBackendName; } +std::string_view TorchCommMSCCLPP::getCommName() const { return name_; } +const CommOptions& TorchCommMSCCLPP::getOptions() const { return options_; } +const at::Device& TorchCommMSCCLPP::getDevice() const { return device_; } + +// --- Collective dispatch --- +// +// All collectives funnel through executeCollective() (when MSCCL++ has a +// native algorithm) or ncclFallback() (when it doesn't). Each method below +// has one of three shapes: +// +// 1. NATIVE-ONLY — body is just `return executeCollective(...);` +// MSCCL++ has algorithms for the entire collective. +// Examples: all_gather_single, all_to_all_single +// +// 2. FALLBACK-ONLY — body is just `return ncclFallback(...);` +// No MSCCL++ native algorithm exists. +// To migrate to native once MSCCL++ adds support: +// replace `return ncclFallback(...)` with +// `return executeCollective(...)`. +// Examples: broadcast, barrier, send, recv, reduce, +// all_to_all_v_single +// +// 3. NATIVE+FALLBACK — try MSCCL++ first, fall back to NCCL when no +// native algorithm matches the request (e.g. unusual +// reduce op, message size with no plan). +// To remove the fallback once MSCCL++ covers the gap: +// delete the `ncclFallback(...)` call and any guards +// that gated it. +// Examples: all_reduce, reduce_scatter_single + +c10::intrusive_ptr TorchCommMSCCLPP::runAlgorithm(const std::shared_ptr& algo, + const void* sendbuf, void* recvbuf, size_t sendBytes, + size_t recvBytes, mscclpp::DataType dtype, + mscclpp::ReduceOp reduceOp, cudaStream_t stream, + std::chrono::milliseconds timeout) { + auto work = c10::make_intrusive(stream, device_.index(), timeout, event_pool_); + work->recordStart(); + algo->execute(comm_, sendbuf, recvbuf, sendBytes, recvBytes, dtype, reduceOp, stream, executor_); + work->recordEnd(); + return work; +} + +c10::intrusive_ptr TorchCommMSCCLPP::executeCollective(const std::string& collective, const void* sendbuf, + void* recvbuf, size_t sendBytes, size_t recvBytes, + mscclpp::DataType dtype, mscclpp::ReduceOp reduceOp, + bool async_op, std::chrono::milliseconds timeout) { + cudaStream_t stream = getOperationStream(async_op); + mscclpp::CollectiveRequest request{size_, nRanksPerNode_, rank_, sendbuf, recvbuf, + sendBytes, stream, collective, dtype, {}}; + + auto algo = algorithmCollection_.selectAlgorithm(request); + if (!algo) { + throw std::runtime_error("[TorchCommMSCCLPP] no algorithm registered for '" + collective + + "' size=" + std::to_string(sendBytes)); + } + return runAlgorithm(algo, sendbuf, recvbuf, sendBytes, recvBytes, dtype, reduceOp, stream, timeout); +} + +c10::intrusive_ptr TorchCommMSCCLPP::all_reduce(at::Tensor& tensor, const ReduceOp& op, bool async_op, + const AllReduceOptions& options) { + checkInitialized(); + TORCH_CHECK(tensor.is_contiguous(), "[TorchCommMSCCLPP] all_reduce requires contiguous tensor"); + if (auto mscclppOp = torchReduceOpToMscclpp(op)) { + return executeCollective("allreduce", tensor.data_ptr(), tensor.data_ptr(), tensor.nbytes(), tensor.nbytes(), + torchDtypeToMscclpp(tensor.scalar_type()), *mscclppOp, async_op, options.timeout); + } + cudaStream_t stream = getOperationStream(async_op); + return ncclFallback("all_reduce", stream, options.timeout, [&] { + nccl_->allReduce(tensor.data_ptr(), tensor.data_ptr(), tensor.numel(), tensor.scalar_type(), op, stream); + }); +} + +c10::intrusive_ptr TorchCommMSCCLPP::all_gather_single(at::Tensor& output, const at::Tensor& input, + bool async_op, + const AllGatherSingleOptions& options) { + checkInitialized(); + TORCH_CHECK(input.is_contiguous() && output.is_contiguous(), + "[TorchCommMSCCLPP] all_gather_single requires contiguous tensors"); + return executeCollective("allgather", input.data_ptr(), output.data_ptr(), input.nbytes(), output.nbytes(), + torchDtypeToMscclpp(input.scalar_type()), mscclpp::NOP, async_op, options.timeout); +} + +c10::intrusive_ptr TorchCommMSCCLPP::all_to_all_single(at::Tensor& output, const at::Tensor& input, + bool async_op, const AllToAllSingleOptions& options) { + checkInitialized(); + TORCH_CHECK(input.is_contiguous() && output.is_contiguous(), + "[TorchCommMSCCLPP] all_to_all_single requires contiguous tensors"); + return executeCollective("alltoall", input.data_ptr(), output.data_ptr(), input.nbytes(), output.nbytes(), + torchDtypeToMscclpp(input.scalar_type()), mscclpp::NOP, async_op, options.timeout); +} + +// reduce_scatter: try MSCCL++ first; fall back to NCCL if no native algorithm. +c10::intrusive_ptr TorchCommMSCCLPP::reduce_scatter_single(at::Tensor& output, const at::Tensor& input, + const ReduceOp& op, bool async_op, + const ReduceScatterSingleOptions& options) { + checkInitialized(); + TORCH_CHECK(input.is_contiguous() && output.is_contiguous(), + "[TorchCommMSCCLPP] reduce_scatter_single requires contiguous tensors"); + + const auto dtype = input.scalar_type(); + const auto mscclppDtype = torchDtypeToMscclpp(dtype); + cudaStream_t stream = getOperationStream(async_op); + + mscclpp::CollectiveRequest request{size_, + nRanksPerNode_, + rank_, + input.data_ptr(), + output.data_ptr(), + static_cast(input.nbytes()), + stream, + "reducescatter", + mscclppDtype, + {}}; + auto mscclppOp = torchReduceOpToMscclpp(op); + auto algo = mscclppOp ? algorithmCollection_.selectAlgorithm(request) : nullptr; + if (algo) { + return runAlgorithm(algo, input.data_ptr(), output.data_ptr(), input.nbytes(), output.nbytes(), mscclppDtype, + *mscclppOp, stream, options.timeout); + } + return ncclFallback("reduce_scatter_single", stream, options.timeout, [&] { + nccl_->reduceScatter(input.data_ptr(), output.data_ptr(), output.numel(), dtype, op, stream); + }); +} + +c10::intrusive_ptr TorchCommMSCCLPP::broadcast(at::Tensor& tensor, int root, bool async_op, + const BroadcastOptions& options) { + checkInitialized(); + TORCH_CHECK(tensor.is_contiguous(), "[TorchCommMSCCLPP] broadcast requires contiguous tensor"); + cudaStream_t stream = getOperationStream(async_op); + return ncclFallback("broadcast", stream, options.timeout, [&] { + nccl_->broadcast(tensor.data_ptr(), tensor.data_ptr(), tensor.numel(), tensor.scalar_type(), root, stream); + }); +} + +c10::intrusive_ptr TorchCommMSCCLPP::barrier(bool async_op, const BarrierOptions& options) { + checkInitialized(); + cudaStream_t stream = getOperationStream(async_op); + return ncclFallback("barrier", stream, options.timeout, [&] { nccl_->barrier(stream); }); +} + +// --- Unsupported operations --- +// +// Tensor-list collective variants, _v collectives, scatter/gather, +// batch_op_issue, and split() aren't part of MSCCL++'s scope and have no +// clean NCCL one-call equivalent (most would need group send/recv stitching +// or sub-communicator machinery). Each throws with guidance to use a +// separate NCCL/RCCL TorchComm. + +c10::intrusive_ptr TorchCommMSCCLPP::send(const at::Tensor& tensor, int peer, bool async_op, + const SendOptions& options) { + checkInitialized(); + TORCH_CHECK(tensor.is_contiguous(), "[TorchCommMSCCLPP] send requires contiguous tensor"); + cudaStream_t stream = getOperationStream(async_op); + return ncclFallback("send", stream, options.timeout, [&] { + nccl_->send(tensor.data_ptr(), tensor.numel(), tensor.scalar_type(), peer, stream); + }); +} +c10::intrusive_ptr TorchCommMSCCLPP::recv(at::Tensor& tensor, int peer, bool async_op, + const RecvOptions& options) { + checkInitialized(); + TORCH_CHECK(tensor.is_contiguous(), "[TorchCommMSCCLPP] recv requires contiguous tensor"); + cudaStream_t stream = getOperationStream(async_op); + return ncclFallback("recv", stream, options.timeout, [&] { + nccl_->recv(tensor.data_ptr(), tensor.numel(), tensor.scalar_type(), peer, stream); + }); +} + +c10::intrusive_ptr TorchCommMSCCLPP::reduce(const at::Tensor& tensor, int root, const ReduceOp& op, + bool async_op, const ReduceOptions& options) { + checkInitialized(); + TORCH_CHECK(tensor.is_contiguous(), "[TorchCommMSCCLPP] reduce requires contiguous tensor"); + cudaStream_t stream = getOperationStream(async_op); + return ncclFallback("reduce", stream, options.timeout, [&] { + nccl_->reduce(tensor.data_ptr(), tensor.data_ptr(), tensor.numel(), tensor.scalar_type(), op, root, stream); + }); +} + +c10::intrusive_ptr TorchCommMSCCLPP::all_to_all_v_single(at::Tensor& output, const at::Tensor& input, + const std::vector& outputSplitSizes, + const std::vector& inputSplitSizes, + bool async_op, + const AllToAllvSingleOptions& options) { + checkInitialized(); + TORCH_CHECK(input.is_contiguous() && output.is_contiguous(), + "[TorchCommMSCCLPP] all_to_all_v_single requires contiguous tensors"); + TORCH_CHECK(static_cast(inputSplitSizes.size()) == size_ && static_cast(outputSplitSizes.size()) == size_, + "[TorchCommMSCCLPP] all_to_all_v_single: split-size vectors must have length world_size"); + std::vector sendOffsets(size_, 0), recvOffsets(size_, 0); + for (int i = 1; i < size_; ++i) { + sendOffsets[i] = sendOffsets[i - 1] + inputSplitSizes[i - 1]; + recvOffsets[i] = recvOffsets[i - 1] + outputSplitSizes[i - 1]; + } + cudaStream_t stream = getOperationStream(async_op); + return ncclFallback("all_to_all_v_single", stream, options.timeout, [&] { + nccl_->allToAllV(input.data_ptr(), output.data_ptr(), inputSplitSizes, outputSplitSizes, sendOffsets, + recvOffsets, input.scalar_type(), stream); + }); +} + +#define MSCCLPP_UNSUPPORTED(op, msg) \ + throw std::runtime_error("[TorchCommMSCCLPP] " op " is not supported. " msg \ + " Use a separate NCCL/RCCL TorchComm for this operation.") + +// One-liner stub for unsupported collectives that return a TorchWork handle. +#define UNSUPPORTED_OP(method, signature, label, msg) \ + c10::intrusive_ptr TorchCommMSCCLPP::method signature { \ + MSCCLPP_UNSUPPORTED(label, msg); \ + } + +c10::intrusive_ptr TorchCommMSCCLPP::batch_op_issue(const std::vector&, bool, + const BatchP2POptions&) { + MSCCLPP_UNSUPPORTED("batch_op_issue()", ""); +} + +UNSUPPORTED_OP(all_gather, + (const std::vector&, const at::Tensor&, bool, const AllGatherOptions&), + "all_gather() (tensor-list variant)", "Use all_gather_single() instead.") +UNSUPPORTED_OP(all_gather_v, + (const std::vector&, const at::Tensor&, bool, const AllGatherOptions&), + "all_gather_v()", "") +UNSUPPORTED_OP(reduce_scatter, + (at::Tensor&, const std::vector&, const ReduceOp&, bool, const ReduceScatterOptions&), + "reduce_scatter() (tensor-list variant)", "Use reduce_scatter_single() instead.") +UNSUPPORTED_OP(reduce_scatter_v, + (at::Tensor&, const std::vector&, const ReduceOp&, bool, const ReduceScatterOptions&), + "reduce_scatter_v()", "") +UNSUPPORTED_OP(all_to_all, + (const std::vector&, const std::vector&, bool, const AllToAllOptions&), + "all_to_all() (tensor-list variant)", "Use all_to_all_single() instead.") +UNSUPPORTED_OP(scatter, + (at::Tensor&, const std::vector&, int, bool, const ScatterOptions&), + "scatter()", "") +UNSUPPORTED_OP(gather, + (const std::vector&, const at::Tensor&, int, bool, const GatherOptions&), + "gather()", "") + +std::shared_ptr TorchCommMSCCLPP::split(const std::vector&, const std::string&, + const CommOptions&) { + MSCCLPP_UNSUPPORTED("split()", ""); +} + +#undef UNSUPPORTED_OP +#undef MSCCLPP_UNSUPPORTED + +// --- Factory registration --- + +namespace { +struct Registration { + Registration() { + TorchCommFactory::get().register_backend("mscclpp", []() { return std::make_shared(); }); + } +}; +static const Registration registration{}; +} // namespace + +} // namespace torch::comms diff --git a/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPP.hpp b/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPP.hpp new file mode 100644 index 000000000..31465adb5 --- /dev/null +++ b/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPP.hpp @@ -0,0 +1,210 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "TorchWorkMSCCLPP.hpp" + +namespace torch::comms { + +/// MSCCL++ backend for TorchComms. +/// +/// This is a thin adapter that maps TorchCommBackend collective operations to +/// MSCCL++'s AlgorithmCollection. Algorithm selection (native vs DSL, which +/// variant for a given message size / topology) is handled entirely by MSCCL++ +/// via AlgorithmCollection::selectAlgorithm(). The backend just builds a +/// CollectiveRequest and calls algo->execute(). +/// +/// Supported collectives: all_reduce, all_gather_single, reduce_scatter_single, +/// all_to_all_single. All others throw with guidance to use NCCL/RCCL. +/// +/// Lifecycle: +/// 1. TorchCommFactory creates an instance via the registered "mscclpp" factory +/// 2. Caller invokes init() with device, name, and CommOptions (including c10d::Store) +/// 3. init() bootstraps rank discovery, creates the MSCCL++ Communicator, builds +/// the AlgorithmCollection with all default native + DSL algorithms +/// 4. Collectives are dispatched through executeCollective() +/// 5. finalize() syncs streams, runs a bootstrap barrier, and tears down in reverse order +class TorchCommMSCCLPP : public TorchCommBackend, public std::enable_shared_from_this { + public: + static constexpr std::string_view kBackendName = "mscclpp"; + + TorchCommMSCCLPP(); + ~TorchCommMSCCLPP() override; + + TorchCommMSCCLPP(const TorchCommMSCCLPP&) = delete; + TorchCommMSCCLPP& operator=(const TorchCommMSCCLPP&) = delete; + + // Lifecycle + void init(at::Device device, const std::string& name, const CommOptions& options = {}) override; + void finalize() override; + + // Metadata + int getRank() const override; + int getSize() const override; + std::string_view getBackendName() const override; + std::string_view getCommName() const override; + const CommOptions& getOptions() const override; + const at::Device& getDevice() const override; + + // Point-to-point (unsupported) + c10::intrusive_ptr send(const at::Tensor& tensor, int dst, bool async_op, + const SendOptions& options = {}) override; + c10::intrusive_ptr recv(at::Tensor& tensor, int src, bool async_op, + const RecvOptions& options = {}) override; + c10::intrusive_ptr batch_op_issue(const std::vector& ops, bool async_op, + const BatchP2POptions& options = {}) override; + + // Collectives + c10::intrusive_ptr broadcast(at::Tensor& tensor, int root, bool async_op, + const BroadcastOptions& options = {}) override; + c10::intrusive_ptr all_reduce(at::Tensor& tensor, const ReduceOp& op, bool async_op, + const AllReduceOptions& options = {}) override; + c10::intrusive_ptr reduce(const at::Tensor& tensor, int root, const ReduceOp& op, bool async_op, + const ReduceOptions& options = {}) override; + c10::intrusive_ptr all_gather(const std::vector& tensor_list, const at::Tensor& tensor, + bool async_op, const AllGatherOptions& options = {}) override; + c10::intrusive_ptr all_gather_v(const std::vector& tensor_list, const at::Tensor& tensor, + bool async_op, const AllGatherOptions& options = {}) override; + c10::intrusive_ptr all_gather_single(at::Tensor& output, const at::Tensor& input, bool async_op, + const AllGatherSingleOptions& options = {}) override; + c10::intrusive_ptr reduce_scatter(at::Tensor& output, const std::vector& input_list, + const ReduceOp& op, bool async_op, + const ReduceScatterOptions& options = {}) override; + c10::intrusive_ptr reduce_scatter_v(at::Tensor& output, const std::vector& input_list, + const ReduceOp& op, bool async_op, + const ReduceScatterOptions& options = {}) override; + c10::intrusive_ptr reduce_scatter_single(at::Tensor& output, const at::Tensor& input, const ReduceOp& op, + bool async_op, + const ReduceScatterSingleOptions& options = {}) override; + c10::intrusive_ptr all_to_all_single(at::Tensor& output, const at::Tensor& input, bool async_op, + const AllToAllSingleOptions& options = {}) override; + c10::intrusive_ptr all_to_all_v_single(at::Tensor& output, const at::Tensor& input, + const std::vector& output_split_sizes, + const std::vector& input_split_sizes, bool async_op, + const AllToAllvSingleOptions& options = {}) override; + c10::intrusive_ptr all_to_all(const std::vector& output_tensor_list, + const std::vector& input_tensor_list, bool async_op, + const AllToAllOptions& options = {}) override; + c10::intrusive_ptr barrier(bool async_op, const BarrierOptions& options = {}) override; + c10::intrusive_ptr scatter(at::Tensor& output_tensor, const std::vector& input_tensor_list, + int root, bool async_op, const ScatterOptions& options = {}) override; + c10::intrusive_ptr gather(const std::vector& output_tensor_list, + const at::Tensor& input_tensor, int root, bool async_op, + const GatherOptions& options = {}) override; + + // Communicator management (unsupported) + std::shared_ptr split(const std::vector& ranks, const std::string& name, + const CommOptions& options = {}) override; + + private: + void checkInitialized() const; + + /// Map PyTorch scalar type to MSCCL++ DataType. + static mscclpp::DataType torchDtypeToMscclpp(at::ScalarType dtype); + + /// Map TorchComms ReduceOp to MSCCL++ ReduceOp. Returns std::nullopt for + /// ops MSCCL++ kernels do not implement (MAX, PRODUCT, ...); callers route + /// those to NcclFallback. Single source of truth for "does MSCCL++ handle + /// this op natively?". + static std::optional torchReduceOpToMscclpp(const ReduceOp& op); + + /// Get the appropriate stream for an operation. + cudaStream_t getOperationStream(bool async_op) const; + + /// Selector lambda body installed on AlgorithmCollectionBuilder. Defined + /// out-of-line to keep init() readable. + static std::shared_ptr selectAlgorithm( + const std::unordered_map>>& algoMapByCollective, + const mscclpp::CollectiveRequest& request); + + /// Wrap a NCCL fallback dispatch body in start/end GPU events on `stream` + /// and return a TorchWorkMSCCLPP handle. Throws if libnccl was not + /// dlopen'd at init time. Used by every collective that may go through + /// the fallback (broadcast, barrier, send, recv, reduce, ...). + template + c10::intrusive_ptr ncclFallback(const char* op, cudaStream_t stream, std::chrono::milliseconds timeout, + Fn&& body); + + /// Execute a resolved MSCCL++ algorithm wrapped in start/end GPU events. + /// Shared by executeCollective() and reduce_scatter_single()'s native + /// branch (which has to call selectAlgorithm() itself to gate on reduce-op + /// support). + c10::intrusive_ptr runAlgorithm(const std::shared_ptr& algo, const void* sendbuf, + void* recvbuf, size_t sendBytes, size_t recvBytes, mscclpp::DataType dtype, + mscclpp::ReduceOp reduceOp, cudaStream_t stream, + std::chrono::milliseconds timeout); + + /// Central dispatch for all supported collectives. + /// + /// Builds a CollectiveRequest from the arguments, asks AlgorithmCollection to + /// select the best algorithm (native or DSL), creates a TorchWorkMSCCLPP handle + /// with start/end GPU events, executes the algorithm, and returns the work handle. + /// The caller's stream waits on the end event when work->wait() is called. + c10::intrusive_ptr executeCollective(const std::string& collective, const void* sendbuf, void* recvbuf, + size_t sendBytes, size_t recvBytes, mscclpp::DataType dtype, + mscclpp::ReduceOp reduceOp, bool async_op, + std::chrono::milliseconds timeout); + + bool initialized_ = false; + at::Device device_{at::kCUDA}; + std::string name_; + CommOptions options_; + int rank_ = 0; + int size_ = 1; + int nRanksPerNode_ = 1; // cached from bootstrap; used in CollectiveRequest for algorithm selection + + /// MSCCL++ communicator — owns the bootstrap, context, and all registered connections. + std::shared_ptr comm_; + + /// Executor for DSL-based algorithms. Native algorithms ignore this, but DSL + /// algorithms need it to interpret JSON execution plans. Always passed to + /// algo->execute() so the backend doesn't need to distinguish algorithm types. + std::shared_ptr executor_; + + /// Registry of all available algorithms (native + DSL). Built once in init() + /// via AlgorithmCollectionBuilder::buildDefaultAlgorithms(). selectAlgorithm() + /// picks the best algorithm for a given collective + message size + topology. + mscclpp::AlgorithmCollection algorithmCollection_; + + /// Dedicated stream for async collective launches. Sync ops use the caller's + /// current PyTorch CUDA stream instead, so the kernel is inline with their work. + cudaStream_t internal_stream_ = nullptr; + + /// Reusable GPU event pool shared across all TorchWorkMSCCLPP handles from + /// this communicator. Avoids cudaEventCreate/Destroy overhead per collective. + std::shared_ptr event_pool_; + + /// GPU scratch memory used by native algorithms (e.g., allreduce RS+AG pipeline) + /// for intermediate results. 128MB is the default size matching MSCCL++ conventions. + /// Allocated via GpuBuffer (cuMemMap) so POSIX file descriptors are registered + /// in the unix socket server for cross-rank IPC sharing. + std::shared_ptr scratchBuffer_; + static constexpr size_t kScratchBufferSize = 1 << 27; // 128MB + + /// Flag buffer shared pointer — must be kept alive for the lifetime of the + /// communicator since AlgorithmCollection references it. + std::shared_ptr flagBuffer_; + size_t flagBufferSize_ = 0; + + /// dlopen-based NCCL fallback for collectives MSCCL++ doesn't natively + /// implement (reduce_scatter, broadcast, barrier on certain configs). Null + /// if libnccl couldn't be loaded — those collectives then throw. + std::unique_ptr nccl_; +}; + +} // namespace torch::comms diff --git a/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPPBootstrap.cpp b/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPPBootstrap.cpp new file mode 100644 index 000000000..0442d0866 --- /dev/null +++ b/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPPBootstrap.cpp @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include "TorchCommMSCCLPPBootstrap.hpp" + +#include +#include +#include +#include +#include + +namespace torch::comms { + +// Static counter ensures unique store keys when multiple communicators are +// created with the same name in the same process (e.g., separate comm groups). +std::atomic TorchCommMSCCLPPBootstrap::counter_{0}; + +// Discovers rank and world size from torchrun/torchelastic environment variables +// (RANK, WORLD_SIZE, LOCAL_RANK). query_ranksize() is a torchcomms utility. +TorchCommMSCCLPPBootstrap::TorchCommMSCCLPPBootstrap(c10::intrusive_ptr store, c10::Device device, + std::chrono::milliseconds timeout) + : store_(std::move(store)), device_(device), timeout_(timeout) { + auto [rank, size] = query_ranksize(); + rank_ = rank; + size_ = size; +} + +TorchCommMSCCLPPBootstrap::~TorchCommMSCCLPPBootstrap() noexcept = default; + +mscclpp::UniqueId TorchCommMSCCLPPBootstrap::exchangeUniqueId(const std::string& name) { + // Single-process: no coordination needed + if (size_ == 1) { + return mscclpp::TcpBootstrap::createUniqueId(); + } + + // Multi-process without a caller-supplied store: fall back to createPrefixStore + if (!store_) { + store_ = createPrefixStore("mscclpp", timeout_); + } + + std::string key = "mscclpp_uniqueid_" + name + "_" + std::to_string(counter_++); + + mscclpp::UniqueId unique_id; + + if (rank_ == 0) { + unique_id = mscclpp::TcpBootstrap::createUniqueId(); + std::vector vec(unique_id.begin(), unique_id.end()); + store_->set(key, vec); + } else { + store_->wait({key}, timeout_); + auto vec = store_->get(key); + if (vec.size() != sizeof(mscclpp::UniqueId)) { + throw std::runtime_error("[TorchCommMSCCLPPBootstrap] Invalid UniqueId size: expected " + + std::to_string(sizeof(mscclpp::UniqueId)) + ", got " + std::to_string(vec.size())); + } + std::copy(vec.begin(), vec.end(), unique_id.begin()); + } + + return unique_id; +} + +std::shared_ptr TorchCommMSCCLPPBootstrap::createCommunicator(const std::string& name, + const CommOptions& /*options*/) { + mscclpp::UniqueId unique_id = exchangeUniqueId(name); + + auto bootstrap = std::make_shared(rank_, size_); + // Single-process (size==1): skip TCP initialization since there are no peers + // to connect to. The bootstrap object is still needed by the Communicator + // constructor, but it doesn't need an active TCP server. + if (size_ > 1) { + int64_t timeout_sec = std::max(int64_t{1}, std::chrono::duration_cast(timeout_).count()); + bootstrap->initialize(unique_id, timeout_sec); + } + + return std::make_shared(bootstrap); +} + +} // namespace torch::comms diff --git a/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPPBootstrap.hpp b/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPPBootstrap.hpp new file mode 100644 index 000000000..598e3baba --- /dev/null +++ b/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPPBootstrap.hpp @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace torch::comms { + +/// Handles MSCCL++ bootstrap initialization using c10d::Store. +/// +/// Follows the TorchCommNCCLBootstrap pattern: +/// - Rank 0 generates a UniqueId via TcpBootstrap::createUniqueId() +/// - Rank 0 writes raw bytes to the store +/// - All other ranks wait on the store key and read the UniqueId +/// - All ranks call TcpBootstrap::initialize(uniqueId) with the same ID +class TorchCommMSCCLPPBootstrap { + public: + TorchCommMSCCLPPBootstrap(c10::intrusive_ptr store, c10::Device device, + std::chrono::milliseconds timeout); + + ~TorchCommMSCCLPPBootstrap() noexcept; + + TorchCommMSCCLPPBootstrap(const TorchCommMSCCLPPBootstrap&) = delete; + TorchCommMSCCLPPBootstrap& operator=(const TorchCommMSCCLPPBootstrap&) = delete; + + /// Create and initialize the MSCCL++ communicator. + std::shared_ptr createCommunicator(const std::string& name, const CommOptions& options = {}); + + int getRank() const { return rank_; } + int getSize() const { return size_; } + + private: + /// Exchange UniqueId via c10d::Store (rank 0 generates, others read). + mscclpp::UniqueId exchangeUniqueId(const std::string& name); + + c10::intrusive_ptr store_; + c10::Device device_; + std::chrono::milliseconds timeout_; + int rank_; + int size_; + + static std::atomic counter_; +}; + +} // namespace torch::comms diff --git a/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPPPy.cpp b/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPPPy.cpp new file mode 100644 index 000000000..f81ece78d --- /dev/null +++ b/python/mscclpp_torchcomms/csrc/TorchCommMSCCLPPPy.cpp @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +// Pybind11 module exposing TorchCommMSCCLPP to Python. +// +// This is intentionally minimal — TorchCommMSCCLPP is used via the TorchCommBackend +// C++ interface (polymorphic dispatch). The Python binding just makes the class +// constructible so torchcomms' Python layer can instantiate it. All collective +// methods are called through the base class interface, not through Python bindings +// of individual methods. +// +// The extern "C" create_dynamic_loader_mscclpp() function is the dynamic loader +// interface required by torchcomms v0.2.0's TorchCommFactory. When torchcomms +// dlopen's our .so, it looks up this symbol to get function pointers for +// creating/destroying backend instances and checking ABI version compatibility. +// +// User-defined algorithms: configure via mscclpp.AlgorithmCollectionBuilder +// BEFORE creating the TorchComms communicator. The backend picks up whatever +// is registered on the builder singleton during init(). + +#include +#include +#include +#include + +#include "TorchCommMSCCLPP.hpp" + +namespace py = pybind11; +using namespace torch::comms; + +// --- Dynamic loader interface for torchcomms TorchCommFactory --- +// +// TorchComms discovers backends by dlopen'ing the .so pointed to by +// TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP, then calling: +// dlsym(handle, "create_dynamic_loader_mscclpp") +// The function name encodes the backend name ("mscclpp"). The returned +// DynamicLoaderInterface provides function pointers for creating/destroying +// backend instances and checking ABI version compatibility. + +static TorchCommBackend* new_comm_impl() { return new TorchCommMSCCLPP(); } + +static void destroy_comm_impl(TorchCommBackend* comm) { delete comm; } + +static const char* get_supported_version_impl() { return TORCHCOMM_BACKEND_ABI_VERSION; } + +extern "C" __attribute__((visibility("default"))) DynamicLoaderInterface create_dynamic_loader_mscclpp() { + return DynamicLoaderInterface{ + .new_comm = new_comm_impl, + .destroy_comm = destroy_comm_impl, + .get_supported_version = get_supported_version_impl, + }; +} + +// --- Pybind11 module --- + +PYBIND11_MODULE(_comms_mscclpp, m) { + m.doc() = "MSCCL++ backend for TorchComm"; + + py::class_>(m, "TorchCommMSCCLPP").def(py::init<>()); +} diff --git a/python/mscclpp_torchcomms/csrc/TorchWorkMSCCLPP.cpp b/python/mscclpp_torchcomms/csrc/TorchWorkMSCCLPP.cpp new file mode 100644 index 000000000..5affe5266 --- /dev/null +++ b/python/mscclpp_torchcomms/csrc/TorchWorkMSCCLPP.cpp @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include "TorchWorkMSCCLPP.hpp" + +#include + +#include +#include + +namespace torch::comms { + +// --- MscclppGpuEventPool --- +// +// CUDA/HIP event allocation is expensive (~5-10us per cudaEventCreate). +// Since every collective call needs 2 events (start + end), and training +// loops run thousands of collectives, we pool and reuse events to avoid +// that overhead. The pool is thread-safe and shared across all +// TorchWorkMSCCLPP instances from the same communicator. + +MscclppGpuEventPool::MscclppGpuEventPool(size_t max_size) : max_size_(max_size) {} + +MscclppGpuEventPool::~MscclppGpuEventPool() { + std::lock_guard lock(mutex_); + for (auto event : available_) { + cudaEventDestroy(event); + } + available_.clear(); +} + +// Returns a recycled event if one is available, otherwise allocates a new one. +// Events use cudaEventDisableTiming because we only need them for stream +// synchronization (cudaStreamWaitEvent), not for measuring elapsed time. +cudaEvent_t MscclppGpuEventPool::acquire() { + std::lock_guard lock(mutex_); + if (!available_.empty()) { + cudaEvent_t event = available_.back(); + available_.pop_back(); + return event; + } + cudaEvent_t event; + MSCCLPP_CUDATHROW(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); + return event; +} + +// Returns an event to the pool for reuse. If the pool is already at capacity, +// the event is destroyed instead to avoid unbounded memory growth. +void MscclppGpuEventPool::release(cudaEvent_t event) { + std::lock_guard lock(mutex_); + if (available_.size() < max_size_) { + available_.push_back(event); + } else { + cudaEventDestroy(event); + } +} + +// --- TorchWorkMSCCLPP --- +// +// Every TorchComms collective must return a TorchWork handle so the caller +// can track completion. TorchWorkMSCCLPP uses GPU events to do this without +// CPU blocking: +// +// 1. Before launching the collective kernel: recordStart() records start_event_ +// 2. After launching: recordEnd() records end_event_ +// 3. wait(): makes the caller's PyTorch stream wait on end_event_ via +// cudaStreamWaitEvent — this is purely GPU-side stream ordering, +// the CPU returns immediately +// 4. checkStatus(): polls events for completion/timeout detection +// +// This matches the TorchWorkNCCL pattern in the torchcomms NCCL backend. + +// Acquires two events from the pool: one for start, one for end. +// Events are returned to the pool in the destructor. +TorchWorkMSCCLPP::TorchWorkMSCCLPP(cudaStream_t op_stream, int device_index, std::chrono::milliseconds timeout_ms, + std::shared_ptr event_pool) + : op_stream_(op_stream), device_index_(device_index), timeout_ms_(timeout_ms), event_pool_(std::move(event_pool)) { + start_event_ = event_pool_->acquire(); + end_event_ = event_pool_->acquire(); +} + +TorchWorkMSCCLPP::~TorchWorkMSCCLPP() { + event_pool_->release(start_event_); + event_pool_->release(end_event_); +} + +// Records a GPU event on the operation stream BEFORE the collective kernel +// is launched. Used by checkStatus() to detect when the GPU actually starts +// executing (as opposed to sitting in the stream queue). +void TorchWorkMSCCLPP::recordStart() { MSCCLPP_CUDATHROW(cudaEventRecord(start_event_, op_stream_)); } + +// Records a GPU event on the operation stream AFTER the collective kernel +// is launched. wait() and checkStatus() use this event to determine when +// the collective has finished. +void TorchWorkMSCCLPP::recordEnd() { MSCCLPP_CUDATHROW(cudaEventRecord(end_event_, op_stream_)); } + +// Polls GPU events without blocking. Tracks a two-phase state machine: +// NOT_STARTED -> INPROGRESS (start_event_ done) -> COMPLETED (end_event_ done) +// Also enforces timeout: if end_event_ hasn't fired within timeout_ms_ after +// start_event_ fired, the status moves to TIMEDOUT. +TorchWork::WorkStatus TorchWorkMSCCLPP::checkStatus() { + if (status() == WorkStatus::COMPLETED || status() == WorkStatus::ERROR || status() == WorkStatus::TIMEDOUT) { + return status(); + } + + // Step 1: query start event to establish when the GPU began executing + if (!start_completed_time_.has_value()) { + cudaError_t start_status = cudaEventQuery(start_event_); + if (start_status == cudaSuccess) { + start_completed_time_ = std::chrono::steady_clock::now(); + setStatus(WorkStatus::INPROGRESS); + } else if (start_status != cudaErrorNotReady) { + setStatus(WorkStatus::ERROR); + return status(); + } + } + if (status() == WorkStatus::NOT_STARTED || status() == WorkStatus::ERROR) { + return status(); + } + + // Step 2: start event done — now query end event + cudaError_t end_status = cudaEventQuery(end_event_); + if (end_status == cudaSuccess) { + setStatus(WorkStatus::COMPLETED); + } else if (end_status == cudaErrorNotReady) { + auto elapsed = std::chrono::duration_cast(std::chrono::steady_clock::now() - + start_completed_time_.value()); + if (elapsed > timeout_ms_) { + setStatus(WorkStatus::TIMEDOUT); + } + } else { + setStatus(WorkStatus::ERROR); + } + + return status(); +} + +// Called by the user (or TorchComm wrapper) to synchronize on the collective. +// This does NOT block the CPU — it inserts a dependency edge on the GPU: +// the caller's current PyTorch CUDA stream will wait for end_event_ before +// executing any subsequent kernels. This is the same pattern NCCL uses. +void TorchWorkMSCCLPP::wait() { + WorkStatus current = checkStatus(); + if (current == WorkStatus::COMPLETED || current == WorkStatus::ERROR || current == WorkStatus::TIMEDOUT) { + return; + } + + // GPU-side wait: make the caller's current stream wait on end_event_. + // No CPU blocking — just stream ordering. + cudaStream_t current_stream = at::cuda::getCurrentCUDAStream(device_index_).stream(); + MSCCLPP_CUDATHROW(cudaStreamWaitEvent(current_stream, end_event_, 0)); + setStatus(WorkStatus::COMPLETED); +} + +} // namespace torch::comms diff --git a/python/mscclpp_torchcomms/csrc/TorchWorkMSCCLPP.hpp b/python/mscclpp_torchcomms/csrc/TorchWorkMSCCLPP.hpp new file mode 100644 index 000000000..f770c198a --- /dev/null +++ b/python/mscclpp_torchcomms/csrc/TorchWorkMSCCLPP.hpp @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace torch::comms { + +/// GPU event pool — reuses CUDA/HIP events to avoid alloc/free overhead. +/// +/// Thread-safe. Owned by TorchCommMSCCLPP and borrowed by TorchWorkMSCCLPP. +/// Events are created with cudaEventDisableTiming (no timing overhead) since +/// we only use them for stream synchronization. +class MscclppGpuEventPool { + public: + explicit MscclppGpuEventPool(size_t max_size = 256); + ~MscclppGpuEventPool(); + + MscclppGpuEventPool(const MscclppGpuEventPool&) = delete; + MscclppGpuEventPool& operator=(const MscclppGpuEventPool&) = delete; + + /// Acquire an event from the pool (or allocate a new one if empty). + cudaEvent_t acquire(); + + /// Return an event to the pool. If the pool is full, destroys the event. + void release(cudaEvent_t event); + + private: + std::vector available_; + std::mutex mutex_; + size_t max_size_; +}; + +/// GPU event-based async work handle for MSCCL++ operations. +/// +/// Follows TorchWorkNCCL pattern: +/// - recordStart() / recordEnd() bracket the MSCCL++ executor call +/// - wait() issues cudaStreamWaitEvent on the caller's current stream +/// (GPU-side, no CPU blocking) +/// - checkStatus() polls events and enforces timeout +class TorchWorkMSCCLPP : public TorchWork { + public: + TorchWorkMSCCLPP(cudaStream_t op_stream, int device_index, std::chrono::milliseconds timeout_ms, + std::shared_ptr event_pool); + ~TorchWorkMSCCLPP() override; + + TorchWorkMSCCLPP(const TorchWorkMSCCLPP&) = delete; + TorchWorkMSCCLPP& operator=(const TorchWorkMSCCLPP&) = delete; + + void wait() override; + std::chrono::milliseconds getTimeout() const override { return timeout_ms_; } + + /// Record start event on op_stream_ before launching the collective. + void recordStart(); + + /// Record end event on op_stream_ after launching the collective. + void recordEnd(); + + private: + WorkStatus checkStatus(); + + cudaEvent_t start_event_; + cudaEvent_t end_event_; + cudaStream_t op_stream_; // not owned + int device_index_; + std::chrono::milliseconds timeout_ms_; + std::shared_ptr event_pool_; + std::optional start_completed_time_; +}; + +} // namespace torch::comms diff --git a/python/mscclpp_torchcomms/pyproject.toml b/python/mscclpp_torchcomms/pyproject.toml new file mode 100644 index 000000000..ff21371a3 --- /dev/null +++ b/python/mscclpp_torchcomms/pyproject.toml @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +[build-system] +requires = [ + "scikit-build-core>=0.10.0", + "setuptools-scm[toml]>=8", + "pybind11", +] +build-backend = "scikit_build_core.build" + +[project] +name = "mscclpp-torchcomms" +dynamic = ["version"] +description = "TorchComms backend for MSCCL++" +requires-python = ">=3.8" +dependencies = [ + "mscclpp", + "torch", +] + +[tool.setuptools_scm] +root = "../.." +write_to = "python/mscclpp/_version.py" +version_scheme = "no-guess-dev" + +[tool.scikit-build] +cmake.version = ">=3.25.0" +cmake.build-type = "Release" +cmake.source-dir = "../.." +cmake.targets = ["_comms_mscclpp"] +build-dir = "build/{wheel_tag}" +metadata.version.provider = "scikit_build_core.metadata.setuptools_scm" +# Only install the torchcomm component — skip mscclpp headers/libs from the root CMake. +install.components = ["torchcomm"] + +[tool.scikit-build.wheel] +packages = ["python/mscclpp_torchcomms"] +install-dir = "mscclpp_torchcomms" +license-files = ["../../LICENSE"] +exclude = ["mscclpp_torchcomms/*.cpp", "mscclpp_torchcomms/csrc/*"] + +[tool.scikit-build.cmake.define] +MSCCLPP_BUILD_PYTHON_BINDINGS = "OFF" +MSCCLPP_BUILD_TESTS = "OFF" +MSCCLPP_BUILD_EXT_TORCHCOMMS = "ON" diff --git a/python/mscclpp_torchcomms/requirements_cuda12.txt b/python/mscclpp_torchcomms/requirements_cuda12.txt new file mode 100644 index 000000000..4adc20061 --- /dev/null +++ b/python/mscclpp_torchcomms/requirements_cuda12.txt @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +# Requirements for the TorchComms MSCCL++ backend (optional). +# Install with: pip install -r python/mscclpp_torchcomms/requirements_cuda12.txt + +torch>=2.0.0 +pybind11 +torchcomms>=0.2.0 diff --git a/test/torchcomms/bench_report.py b/test/torchcomms/bench_report.py new file mode 100644 index 000000000..8c94e911c --- /dev/null +++ b/test/torchcomms/bench_report.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Generate performance report and figures from TorchComms benchmark results. + +Reads TorchComms allreduce results and produces: + - report.txt: formatted performance table + - latency.png: latency vs message size (log-log) + - bandwidth.png: bus bandwidth vs message size + +Not meant to be run directly — called by run_benchmarks.sh. +""" + +import argparse +import json +import os +import sys + +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + + +def format_size(nbytes): + if nbytes >= 1 << 20: + return f"{nbytes / (1 << 20):.0f}MB" + elif nbytes >= 1 << 10: + return f"{nbytes / (1 << 10):.0f}KB" + return f"{nbytes}B" + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--torchcomms-json", required=True) + parser.add_argument("--nproc", type=int, required=True) + parser.add_argument("--outdir", required=True) + args = parser.parse_args() + + with open(args.torchcomms_json) as f: + tc_data = json.load(f) + + tc_results = {entry["size"]: entry for entry in tc_data} + tc_sizes = sorted(tc_results.keys()) + + if not tc_sizes: + print("ERROR: No results found.", file=sys.stderr) + sys.exit(1) + + # --- Algorithm region spans --- + algo_regions = [] + prev_algo = tc_results[tc_sizes[0]].get("algorithm", "") + region_start = tc_sizes[0] + for i in range(1, len(tc_sizes)): + cur_algo = tc_results[tc_sizes[i]].get("algorithm", "") + if cur_algo != prev_algo: + algo_regions.append((region_start, tc_sizes[i - 1], prev_algo)) + region_start = tc_sizes[i] + prev_algo = cur_algo + algo_regions.append((region_start, tc_sizes[-1], prev_algo)) + + algo_colors = { + "allpair_packet": "#E3F2FD", + "nvls_packet": "#FFF3E0", + "packet": "#E8F5E9", + "nvls_warp_pipeline": "#F3E5F5", + "nvls_block_pipeline": "#FFF9C4", + } + + def add_algo_regions(ax, ymax): + for xmin, xmax, algo in algo_regions: + color = algo_colors.get(algo, "#F5F5F5") + ax.axvspan(xmin * 0.7, xmax * 1.4, alpha=0.3, color=color, zorder=0) + label_x = (xmin * xmax) ** 0.5 + ax.text( + label_x, + ymax * 0.85, + algo.replace("_", "\n"), + fontsize=7, + ha="center", + va="top", + style="italic", + color="#555555", + ) + + # --- Report --- + lines = [] + lines.append(f"MSCCL++ AllReduce via TorchComms — {args.nproc}x NVIDIA H100 80GB (NVSwitch)") + lines.append("") + lines.append(f"{'Size':<10} {'Time(us)':<12} {'AlgBW(GB/s)':<14} {'BusBW(GB/s)':<14} {'Algorithm':<30}") + lines.append("-" * 84) + + for size in tc_sizes: + r = tc_results[size] + lines.append( + f"{format_size(size):<10} " + f"{r['time_us']:<12.1f} " + f"{r.get('algbw_gbps', 0):<14.1f} " + f"{r['busbw_gbps']:<14.1f} " + f"{r.get('algorithm', ''):<30}" + ) + + report_path = os.path.join(args.outdir, "report.txt") + with open(report_path, "w") as f: + f.write("\n".join(lines) + "\n") + + # --- Latency figure --- + tc_times = [tc_results[s]["time_us"] for s in tc_sizes] + + fig, ax = plt.subplots(figsize=(14, 7)) + ax.plot(tc_sizes, tc_times, "o-", linewidth=2.5, markersize=7, label="MSCCL++ via TorchComms", color="#2196F3") + ax.set_xscale("log", base=2) + ax.set_yscale("log") + ax.set_xlabel("Message Size", fontsize=12) + ax.set_ylabel("Latency (Ξs)", fontsize=12) + ax.set_title( + f"MSCCL++ AllReduce Latency — {args.nproc}x NVIDIA H100 80GB (single-node, NVSwitch)", + fontsize=13, + ) + ax.legend(fontsize=11, loc="upper left") + ax.grid(True, alpha=0.3) + add_algo_regions(ax, max(tc_times)) + + tick_sizes = [ + s + for s in tc_sizes + if s + in [ + 1024, + 4096, + 16384, + 65536, + 262144, + 1048576, + 4 * 1024 * 1024, + 16 * 1024 * 1024, + 64 * 1024 * 1024, + 128 * 1024 * 1024, + ] + ] + if tick_sizes: + ax.set_xticks(tick_sizes) + ax.set_xticklabels([format_size(s) for s in tick_sizes], rotation=45) + + plt.tight_layout() + plt.savefig(os.path.join(args.outdir, "latency.png"), dpi=150) + plt.close() + + # --- Bandwidth figure --- + tc_algbws = [tc_results[s].get("algbw_gbps", 0) for s in tc_sizes] + tc_busbws = [tc_results[s]["busbw_gbps"] for s in tc_sizes] + + fig, ax = plt.subplots(figsize=(14, 7)) + ax.plot(tc_sizes, tc_busbws, "o-", linewidth=2.5, markersize=7, label="Bus Bandwidth", color="#2196F3") + ax.plot(tc_sizes, tc_algbws, "s--", linewidth=2, markersize=6, label="Algorithm Bandwidth", color="#FF9800") + ax.set_xscale("log", base=2) + ax.set_xlabel("Message Size", fontsize=12) + ax.set_ylabel("Bandwidth (GB/s)", fontsize=12) + ax.set_title( + f"MSCCL++ AllReduce Bandwidth — {args.nproc}x NVIDIA H100 80GB (single-node, NVSwitch)", + fontsize=13, + ) + ax.legend(fontsize=11, loc="upper left") + ax.grid(True, alpha=0.3) + add_algo_regions(ax, max(max(tc_algbws), max(tc_busbws))) + + if tick_sizes: + ax.set_xticks(tick_sizes) + ax.set_xticklabels([format_size(s) for s in tick_sizes], rotation=45) + + plt.tight_layout() + plt.savefig(os.path.join(args.outdir, "bandwidth.png"), dpi=150) + plt.close() + + print(f"Report: {report_path}") + print(f"Figures: {args.outdir}/latency.png, {args.outdir}/bandwidth.png") + + +if __name__ == "__main__": + main() diff --git a/test/torchcomms/bench_torchcomms.py b/test/torchcomms/bench_torchcomms.py new file mode 100644 index 000000000..20b5d42d6 --- /dev/null +++ b/test/torchcomms/bench_torchcomms.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""MSCCL++ collective benchmark via TorchComms. + +Measures collective latency and bandwidth through the torchcomms path +(torchcomms.new_comm("mscclpp") → TorchCommMSCCLPP → executeCollective). + +Supported collectives and their native MSCCL++ algorithms (H100, single-node): + + AllReduce: + <=16KB allpair_packet + 16KB-32KB nvls_packet + 32KB-1MB packet + 1MB-16MB nvls_warp_pipeline + >=16MB nvls_block_pipeline + + AllGather: + <=32MB fullmesh2 + >32MB fullmesh + +Run: + torchrun --nproc_per_node=8 test/torchcomms/bench_torchcomms.py --collective allreduce + torchrun --nproc_per_node=8 test/torchcomms/bench_torchcomms.py --collective allgather + torchrun --nproc_per_node=8 test/torchcomms/bench_torchcomms.py --collective allreduce --warmup 100 --iters 500 + torchrun --nproc_per_node=8 test/torchcomms/bench_torchcomms.py --collective allreduce --dtype fp16 +""" + +import argparse +import json +import os +import sys + +import torch +import torchcomms + +import mscclpp_torchcomms # noqa: F401 — auto-registers backend .so path + + +def sync_cuda(): + torch.cuda.synchronize() + + +def cuda_timed(fn, warmup, iters): + """Time fn() using CUDA events. Returns average microseconds.""" + for _ in range(warmup): + fn() + sync_cuda() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + sync_cuda() + start.record() + for _ in range(iters): + fn() + end.record() + sync_cuda() + + return (start.elapsed_time(end) * 1000.0) / iters + + +def format_size(nbytes): + if nbytes >= 1 << 20: + return f"{nbytes / (1 << 20):.0f}MB" + elif nbytes >= 1 << 10: + return f"{nbytes / (1 << 10):.0f}KB" + return f"{nbytes}B" + + +# --- Curated size tables per collective --- +# Each entry: (nbytes, expected_algorithm_name) + +# FIXME: why are we hardcoding the algorithms, should the +# selector be handling this? +ALLREDUCE_SIZES = [ + (1024, "allpair_packet"), + (4096, "allpair_packet"), + (16384, "allpair_packet"), + (24576, "nvls_packet"), + (32768, "nvls_packet"), + (65536, "packet"), + (262144, "packet"), + (524288, "packet"), + (1048576, "packet"), + (2 * 1024 * 1024, "nvls_warp_pipeline"), + (4 * 1024 * 1024, "nvls_warp_pipeline"), + (8 * 1024 * 1024, "nvls_warp_pipeline"), + (16 * 1024 * 1024, "nvls_block_pipeline"), + (32 * 1024 * 1024, "nvls_block_pipeline"), + (64 * 1024 * 1024, "nvls_block_pipeline"), + (128 * 1024 * 1024, "nvls_block_pipeline"), + (256 * 1024 * 1024, "nvls_block_pipeline"), + (512 * 1024 * 1024, "nvls_block_pipeline"), + (1024 * 1024 * 1024, "nvls_block_pipeline"), + (2048 * 1024 * 1024, "nvls_block_pipeline"), +] + +ALLGATHER_SIZES = [ + (1024, "fullmesh2"), + (4096, "fullmesh2"), + (16384, "fullmesh2"), + (65536, "fullmesh2"), + (262144, "fullmesh2"), + (1048576, "fullmesh2"), + (4 * 1024 * 1024, "fullmesh2"), + (8 * 1024 * 1024, "fullmesh2"), + (16 * 1024 * 1024, "fullmesh2"), + (32 * 1024 * 1024, "fullmesh2"), + (64 * 1024 * 1024, "fullmesh"), + (128 * 1024 * 1024, "fullmesh"), + (256 * 1024 * 1024, "fullmesh"), + (512 * 1024 * 1024, "fullmesh"), + (1024 * 1024 * 1024, "fullmesh"), +] + +COLLECTIVE_SIZES = { + "allreduce": ALLREDUCE_SIZES, + "allgather": ALLGATHER_SIZES, +} + + +def busbw_factor(collective, world_size): + """Bus bandwidth correction factor.""" + n = world_size + if collective == "allreduce": + return 2.0 * (n - 1) / n + elif collective == "allgather": + return (n - 1.0) / n + return 1.0 + + +def bench_allreduce(comm, tensor, warmup, iters): + return cuda_timed( + lambda t=tensor: comm.all_reduce(t, torchcomms.ReduceOp.SUM, False), + warmup, + iters, + ) + + +def bench_allgather(comm, input_tensor, output_tensor, warmup, iters): + return cuda_timed( + lambda i=input_tensor, o=output_tensor: comm.all_gather_single(o, i, False), + warmup, + iters, + ) + + +def main(): + parser = argparse.ArgumentParser(description="MSCCL++ TorchComms collective benchmark") + parser.add_argument("--collective", type=str, required=True, choices=list(COLLECTIVE_SIZES.keys())) + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--iters", type=int, default=200) + parser.add_argument("--dtype", type=str, default="fp32") + parser.add_argument("--json-output", type=str, default=None, help="Write results to JSON file") + args = parser.parse_args() + + dtype_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + dtype = dtype_map[args.dtype.lower()] + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + comm = torchcomms.new_comm("mscclpp", device, name="bench") + + element_size = torch.tensor([], dtype=dtype).element_size() + bw_factor = busbw_factor(args.collective, world_size) + sizes = COLLECTIVE_SIZES[args.collective] + + if rank == 0: + gpu_name = torch.cuda.get_device_name(device) + cc = torch.cuda.get_device_capability(device) + print(f"MSCCL++ TorchComms {args.collective.upper()} Benchmark") + print(f" GPU: {gpu_name} (CC {cc[0]}.{cc[1]})") + print(f" GPUs: {world_size}, dtype: {args.dtype}") + print(f" Warmup: {args.warmup}, Iterations: {args.iters}") + print() + print(f"{'Size':<10} {'Time(us)':<12} {'AlgBW(GB/s)':<14} {'BusBW(GB/s)':<14} {'Algorithm':<25}") + print("-" * 80) + + results = [] + + for nbytes, algo_name in sizes: + nelem = max(1, nbytes // element_size) + if args.collective == "allgather": + # nelem is the TOTAL output size; ensure divisible by world_size + nelem = ((nelem + world_size - 1) // world_size) * world_size + actual_bytes = nelem * element_size + + # Run the appropriate collective + time_us = None + try: + if args.collective == "allreduce": + tensor = torch.full((nelem,), float(rank + 1), device=device, dtype=dtype) + time_us = bench_allreduce(comm, tensor, args.warmup, args.iters) + elif args.collective == "allgather": + chunk_size = nelem // world_size + input_tensor = torch.full((chunk_size,), float(rank), device=device, dtype=dtype) + output_tensor = torch.empty(nelem, device=device, dtype=dtype) + time_us = bench_allgather(comm, input_tensor, output_tensor, args.warmup, args.iters) + except RuntimeError as e: + if "No algorithm" not in str(e): + raise + + if time_us is not None and time_us > 0: + alg_bw = (actual_bytes / time_us) / 1000.0 + bus_bw = alg_bw * bw_factor + else: + alg_bw = 0 + bus_bw = 0 + + results.append( + { + "collective": args.collective, + "size": actual_bytes, + "time_us": time_us, + "algbw_gbps": alg_bw, + "busbw_gbps": bus_bw, + "algorithm": algo_name, + } + ) + + if rank == 0: + if time_us is not None: + print( + f"{format_size(actual_bytes):<10} {time_us:<12.1f} {alg_bw:<14.1f} {bus_bw:<14.1f} {algo_name:<25}" + ) + else: + print(f"{format_size(actual_bytes):<10} {'N/A':<12} {'N/A':<14} {'N/A':<14} {algo_name:<25}") + + comm.finalize() + + if rank == 0: + if args.json_output: + with open(args.json_output, "w") as f: + json.dump(results, f, indent=2) + print() + + +if __name__ == "__main__": + main() diff --git a/test/torchcomms/run_benchmarks.sh b/test/torchcomms/run_benchmarks.sh new file mode 100755 index 000000000..baa1c38c8 --- /dev/null +++ b/test/torchcomms/run_benchmarks.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +# Benchmark MSCCL++ allreduce via TorchComms and generate report + figures. +# +# Output: +# bench_results/torchcomms_raw.json — raw benchmark data +# bench_results/report.txt — formatted table +# bench_results/latency.png — latency vs message size +# bench_results/bandwidth.png — bus bandwidth vs message size +# +# Prerequisites: +# - TorchComms backend built: ./build_torchcomm.sh +# - Conda env activated with torchcomms, matplotlib +# - TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP set (build_torchcomm.sh prints this) +# +# Usage: +# ./test/torchcomms/run_benchmarks.sh +# ./test/torchcomms/run_benchmarks.sh --nproc 2 +# ./test/torchcomms/run_benchmarks.sh --iters 500 --warmup 50 + +set -uo pipefail + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)" +OUT_DIR="${REPO_ROOT}/bench_results" + +NPROC=8 +WARMUP=20 +ITERS=200 +DTYPE="fp32" + +while [[ $# -gt 0 ]]; do + case "$1" in + --nproc) NPROC="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --iters) ITERS="$2"; shift 2 ;; + --dtype) DTYPE="$2"; shift 2 ;; + --outdir) OUT_DIR="$2"; shift 2 ;; + *) echo "Unknown arg: $1"; exit 1 ;; + esac +done + +mkdir -p "${OUT_DIR}" + +# Find the TorchComms backend .so +SO_FILE=$(find "${REPO_ROOT}/build-torchcomm/lib" -name "_comms_mscclpp*.so" 2>/dev/null | head -1) +if [[ -z "${SO_FILE}" ]]; then + echo "ERROR: _comms_mscclpp .so not found. Run ./build_torchcomm.sh first." + exit 1 +fi +export TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP="${SO_FILE}" + +TORCHCOMMS_JSON="${OUT_DIR}/torchcomms_raw.json" + +echo "=== MSCCL++ AllReduce Benchmark ===" +echo " GPUs: ${NPROC}" +echo " Warmup: ${WARMUP}" +echo " Iters: ${ITERS}" +echo " Dtype: ${DTYPE}" +echo " Output: ${OUT_DIR}/" +echo "" + +# Run TorchComms allreduce benchmark +echo "Benchmarking MSCCL++ via TorchComms..." +torchrun --nproc_per_node="${NPROC}" "${SCRIPT_DIR}/bench_torchcomms.py" \ + --warmup "${WARMUP}" --iters "${ITERS}" --dtype "${DTYPE}" \ + --json-output "${TORCHCOMMS_JSON}" \ + 2>/dev/null +echo "" + +# Generate report and figures +echo "Generating report and figures..." +python3 "${SCRIPT_DIR}/bench_report.py" \ + --torchcomms-json "${TORCHCOMMS_JSON}" \ + --nproc "${NPROC}" \ + --outdir "${OUT_DIR}" + +echo "" +echo "=== Report ===" +cat "${OUT_DIR}/report.txt" +echo "" +echo "Figures: ${OUT_DIR}/latency.png, ${OUT_DIR}/bandwidth.png" diff --git a/test/torchcomms/test_correctness.py b/test/torchcomms/test_correctness.py new file mode 100644 index 000000000..dc33eeffd --- /dev/null +++ b/test/torchcomms/test_correctness.py @@ -0,0 +1,266 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Collective correctness tests for the MSCCL++ TorchComms backend. + +These tests verify that collectives dispatched through torchcomms.new_comm("mscclpp", ...) +produce correct results. Each test creates an MSCCL++ communicator, runs the collective, +checks the output against a reference, and finalizes. + +When --sweep is used, tests run across multiple message sizes and dtypes to exercise +both packet (<=1MB) and non-packet (>1MB) algorithm paths. + +Prerequisites: + - torchcomms >= 0.2.0 installed (pip install --pre torchcomms) + - MSCCL++ built with -DMSCCLPP_BUILD_EXT_TORCHCOMMS=ON + - mscclpp-torchcomms installed (python -m pip install ./python/mscclpp_torchcomms) + +Run examples: + torchrun --nproc_per_node=2 test/torchcomms/test_correctness.py --collective allreduce + torchrun --nproc_per_node=2 test/torchcomms/test_correctness.py --collective allreduce --nelem 4194304 --dtype fp16 + torchrun --nproc_per_node=2 test/torchcomms/test_correctness.py --all + torchrun --nproc_per_node=2 test/torchcomms/test_correctness.py --all --sweep +""" + +import argparse +import os +import sys + +import torch +import torchcomms + +import mscclpp_torchcomms # noqa: F401 — auto-registers backend .so path + +# Size sweep: covers packet path (<=1MB), boundary, and non-packet path (>1MB) +SWEEP_NELEMS = [1, 64, 1024, 16384, 262144, 1048576, 4194304] +SWEEP_DTYPES = [torch.float32, torch.float16, torch.bfloat16] + + +def parse_dtype(name: str) -> torch.dtype: + name = name.lower() + if name in {"fp32", "float", "float32"}: + return torch.float32 + if name in {"fp16", "half", "float16"}: + return torch.float16 + if name in {"bf16", "bfloat16"}: + return torch.bfloat16 + raise ValueError(f"Unsupported dtype: {name}") + + +def tolerances(dtype: torch.dtype): + if dtype in (torch.float16, torch.bfloat16): + return 5e-3, 1e-3 + return 1e-4, 1e-5 + + +def get_env(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + return rank, world_size, local_rank + + +def make_comm(device, name="test"): + """Create an MSCCL++ communicator via torchcomms.""" + return torchcomms.new_comm("mscclpp", device, name=name) + + +def test_allreduce(comm, rank, world_size, device, nelem, dtype): + """AllReduce SUM: each rank fills tensor with (rank+1), result should be sum(1..N).""" + tensor = torch.full((nelem,), float(rank + 1), device=device, dtype=dtype) + expected_val = world_size * (world_size + 1) / 2.0 + expected = torch.full((nelem,), expected_val, device=device, dtype=dtype) + + comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False) + torch.cuda.synchronize() + + atol, rtol = tolerances(dtype) + if not torch.allclose(tensor, expected, atol=atol, rtol=rtol): + max_diff = (tensor - expected).abs().max().item() + raise AssertionError( + f"[rank {rank}] allreduce FAILED: max_diff={max_diff}, " + f"expected={expected_val}, got sample={tensor[0].item()}" + ) + + +def test_allreduce_inplace(comm, rank, world_size, device, nelem, dtype): + """AllReduce SUM in-place: verify the same buffer is both input and output.""" + tensor = torch.full((nelem,), float(rank + 1), device=device, dtype=dtype) + original_ptr = tensor.data_ptr() + expected_val = world_size * (world_size + 1) / 2.0 + + comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False) + torch.cuda.synchronize() + + if tensor.data_ptr() != original_ptr: + raise AssertionError(f"[rank {rank}] allreduce in-place: buffer address changed") + + atol, rtol = tolerances(dtype) + expected = torch.full((nelem,), expected_val, device=device, dtype=dtype) + if not torch.allclose(tensor, expected, atol=atol, rtol=rtol): + max_diff = (tensor - expected).abs().max().item() + raise AssertionError(f"[rank {rank}] allreduce in-place FAILED: max_diff={max_diff}") + + +def test_allreduce_repeated(comm, rank, world_size, device, nelem, dtype): + """AllReduce SUM repeated on the same buffer: catches stale context/semaphore bugs.""" + tensor = torch.empty((nelem,), device=device, dtype=dtype) + for i in range(5): + tensor.fill_(float(rank + 1) * (i + 1)) + comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False) + torch.cuda.synchronize() + + expected_val = (i + 1) * world_size * (world_size + 1) / 2.0 + atol, rtol = tolerances(dtype) + expected = torch.full((nelem,), expected_val, device=device, dtype=dtype) + if not torch.allclose(tensor, expected, atol=atol, rtol=rtol): + max_diff = (tensor - expected).abs().max().item() + raise AssertionError(f"[rank {rank}] allreduce repeated iter {i} FAILED: max_diff={max_diff}") + + +def test_allgather(comm, rank, world_size, device, nelem, dtype): + """AllGatherSingle: each rank contributes input, output has all ranks concatenated.""" + input_tensor = torch.full((nelem,), float(rank), device=device, dtype=dtype) + output_tensor = torch.empty(nelem * world_size, device=device, dtype=dtype) + + comm.all_gather_single(output_tensor, input_tensor, False) + torch.cuda.synchronize() + + for r in range(world_size): + chunk = output_tensor[r * nelem : (r + 1) * nelem] + expected = torch.full((nelem,), float(r), device=device, dtype=dtype) + if not torch.equal(chunk, expected): + max_diff = (chunk - expected).abs().max().item() + raise AssertionError(f"[rank {rank}] allgather FAILED at chunk {r}: max_diff={max_diff}") + + +def test_reducescatter(comm, rank, world_size, device, nelem, dtype): + """ReduceScatterSingle: SUM-reduce then scatter so each rank gets its chunk.""" + input_tensor = torch.full((nelem * world_size,), float(rank + 1), device=device, dtype=dtype) + output_tensor = torch.empty(nelem, device=device, dtype=dtype) + + comm.reduce_scatter_single(output_tensor, input_tensor, torchcomms.ReduceOp.SUM, False) + torch.cuda.synchronize() + + expected_val = world_size * (world_size + 1) / 2.0 + expected = torch.full((nelem,), expected_val, device=device, dtype=dtype) + + atol, rtol = tolerances(dtype) + if not torch.allclose(output_tensor, expected, atol=atol, rtol=rtol): + max_diff = (output_tensor - expected).abs().max().item() + raise AssertionError( + f"[rank {rank}] reducescatter FAILED: max_diff={max_diff}, " + f"expected={expected_val}, got sample={output_tensor[0].item()}" + ) + + +# Maps collective name -> list of (test_func, label) tuples +COLLECTIVE_TESTS = { + "allreduce": [ + (test_allreduce, "allreduce"), + (test_allreduce_inplace, "allreduce_inplace"), + (test_allreduce_repeated, "allreduce_repeated"), + ], + "allgather": [ + (test_allgather, "allgather"), + ], + "reducescatter": [ + (test_reducescatter, "reducescatter"), + ], +} + + +def run_single(comm, rank, world_size, device, collectives, nelem, dtype): + """Run specified collectives with a single nelem/dtype combination.""" + failed = [] + skipped = [] + + for coll_name in collectives: + for test_func, label in COLLECTIVE_TESTS[coll_name]: + try: + test_func(comm, rank, world_size, device, nelem, dtype) + if rank == 0: + print(f" {label} {dtype} nelem={nelem}: PASSED") + except RuntimeError as e: + err_msg = str(e) + if "No algorithm registered" in err_msg or "No algorithm" in err_msg: + skipped.append(label) + if rank == 0: + print(f" {label} {dtype} nelem={nelem}: SKIPPED (no algorithm)") + else: + failed.append((label, err_msg)) + if rank == 0: + print(f" {label} {dtype} nelem={nelem}: FAILED - {err_msg}") + except Exception as e: + failed.append((label, str(e))) + if rank == 0: + print(f" {label} {dtype} nelem={nelem}: FAILED - {e}") + + return failed, skipped + + +def run_sweep(comm, rank, world_size, device, collectives): + """Run collectives across multiple sizes and dtypes.""" + all_failed = [] + all_skipped = [] + total = 0 + + for dtype in SWEEP_DTYPES: + for nelem in SWEEP_NELEMS: + total += len(collectives) + failed, skipped = run_single(comm, rank, world_size, device, collectives, nelem, dtype) + all_failed.extend(failed) + all_skipped.extend(skipped) + + return all_failed, all_skipped, total + + +def main(): + parser = argparse.ArgumentParser(description="TorchComms MSCCL++ correctness tests") + parser.add_argument( + "--collective", type=str, choices=list(COLLECTIVE_TESTS.keys()), help="Which collective to test" + ) + parser.add_argument("--all", action="store_true", help="Run all collective tests") + parser.add_argument("--sweep", action="store_true", help="Sweep across multiple sizes and dtypes") + parser.add_argument("--nelem", type=int, default=1048576, help="Number of elements (default: 1M)") + parser.add_argument("--dtype", type=str, default="fp32", help="Data type (fp32, fp16, bf16)") + args = parser.parse_args() + + if not args.collective and not args.all: + parser.error("Specify --collective or --all") + + rank, world_size, local_rank = get_env() + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + collectives = list(COLLECTIVE_TESTS.keys()) if args.all else [args.collective] + + if rank == 0: + print(f"=== TorchComms MSCCL++ Correctness Tests ===") + print(f" world_size={world_size}, sweep={args.sweep}") + if not args.sweep: + print(f" nelem={args.nelem}, dtype={args.dtype}") + + comm = make_comm(device, name="correctness_test") + + if args.sweep: + failed, skipped, total = run_sweep(comm, rank, world_size, device, collectives) + else: + dtype = parse_dtype(args.dtype) + failed, skipped = run_single(comm, rank, world_size, device, collectives, args.nelem, dtype) + + comm.finalize() + + if rank == 0: + if failed: + print(f"\n=== {len(failed)} test(s) FAILED ===") + for name, err in failed: + print(f" {name}: {err}") + sys.exit(1) + else: + skip_msg = f" ({len(skipped)} skipped)" if skipped else "" + print(f"\n=== All tests PASSED{skip_msg} ===") + + +if __name__ == "__main__": + main() diff --git a/test/torchcomms/test_error_handling.py b/test/torchcomms/test_error_handling.py new file mode 100644 index 000000000..9e425525d --- /dev/null +++ b/test/torchcomms/test_error_handling.py @@ -0,0 +1,151 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Error handling tests for the MSCCL++ TorchComms backend. + +Verifies that unsupported operations, invalid arguments, and lifecycle errors +produce clear error messages rather than crashes or hangs. + +Prerequisites: + - torchcomms >= 0.2.0 installed + - MSCCL++ built with -DMSCCLPP_BUILD_EXT_TORCHCOMMS=ON + - TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP env var set + +Run: + torchrun --nproc_per_node=2 test/torchcomms/test_error_handling.py +""" + +import os +import sys +import traceback + +import torch +import torchcomms + +import mscclpp_torchcomms # noqa: F401 — auto-registers backend .so path + + +def get_env(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + return rank, world_size, local_rank + + +def expect_error(name, callable_fn, expected_substring=None): + """Run callable_fn and verify it raises RuntimeError with expected message.""" + try: + callable_fn() + return False, f"Expected RuntimeError but no exception raised" + except RuntimeError as e: + msg = str(e) + if expected_substring and expected_substring not in msg: + return False, f"Expected '{expected_substring}' in error, got: {msg}" + return True, msg + except Exception as e: + return False, f"Expected RuntimeError, got {type(e).__name__}: {e}" + + +def test_unsupported_ops(comm, device): + """Verify unsupported collectives raise clear errors.""" + results = [] + tensor = torch.ones(1024, device=device, dtype=torch.float32) + tensor_list = [torch.ones(1024, device=device) for _ in range(2)] + + # broadcast + ok, msg = expect_error("broadcast", lambda: comm.broadcast(tensor, 0, False), "not supported") + results.append(("broadcast", ok, msg)) + + # send + ok, msg = expect_error("send", lambda: comm.send(tensor, 0, False), "not supported") + results.append(("send", ok, msg)) + + # recv + ok, msg = expect_error("recv", lambda: comm.recv(tensor, 0, False), "not supported") + results.append(("recv", ok, msg)) + + # barrier + ok, msg = expect_error("barrier", lambda: comm.barrier(False), "not supported") + results.append(("barrier", ok, msg)) + + return results + + +def test_unsupported_reduce_op(comm, device): + """Verify unsupported reduce ops raise clear errors.""" + results = [] + tensor = torch.ones(1024, device=device, dtype=torch.float32) + + # PRODUCT not supported + for op_name in ["PRODUCT", "MAX"]: + op = getattr(torchcomms.ReduceOp, op_name, None) + if op is not None: + ok, msg = expect_error( + f"allreduce with {op_name}", + lambda op=op: comm.all_reduce(tensor.clone(), op, False), + "does not support", + ) + results.append((f"allreduce_{op_name}", ok, msg)) + + return results + + +def test_metadata(comm, rank, world_size): + """Verify metadata accessors return correct values.""" + results = [] + + if comm.get_rank() != rank: + results.append(("get_rank", False, f"Expected {rank}, got {comm.get_rank()}")) + else: + results.append(("get_rank", True, "")) + + if comm.get_size() != world_size: + results.append(("get_size", False, f"Expected {world_size}, got {comm.get_size()}")) + else: + results.append(("get_size", True, "")) + + backend_name = comm.get_backend() + if backend_name != "mscclpp": + results.append(("get_backend", False, f"Expected 'mscclpp', got '{backend_name}'")) + else: + results.append(("get_backend", True, "")) + + return results + + +def main(): + rank, world_size, local_rank = get_env() + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + if rank == 0: + print("=== TorchComms MSCCL++ Error Handling Tests ===") + print(f" world_size={world_size}") + + comm = torchcomms.new_comm("mscclpp", device, name="error_test") + + all_results = [] + all_results.extend(test_unsupported_ops(comm, device)) + all_results.extend(test_unsupported_reduce_op(comm, device)) + all_results.extend(test_metadata(comm, rank, world_size)) + + comm.finalize() + + if rank == 0: + passed = sum(1 for _, ok, _ in all_results if ok) + failed = [(name, msg) for name, ok, msg in all_results if not ok] + + for name, ok, msg in all_results: + status = "PASSED" if ok else "FAILED" + detail = f" - {msg}" if not ok else "" + print(f" {name}: {status}{detail}") + + if failed: + print(f"\n=== {len(failed)} FAILED, {passed} passed ===") + sys.exit(1) + else: + print(f"\n=== All {passed} tests PASSED ===") + + +if __name__ == "__main__": + main() diff --git a/test/torchcomms/test_fsdp2.py b/test/torchcomms/test_fsdp2.py new file mode 100644 index 000000000..a15de853b --- /dev/null +++ b/test/torchcomms/test_fsdp2.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""FSDP2 training test for the MSCCL++ TorchComms backend. + +Verifies that MSCCL++ works as the communication backend for FSDP2 training +via TorchComms' DeviceMesh integration. Creates a small transformer-like model, +wraps it with fully_shard(), and runs a training loop comparing FSDP2 results +against a non-sharded reference model. + +Prerequisites: + - torchcomms >= 0.2.0 installed + - mscclpp-torchcomms installed (python -m pip install ./python/mscclpp_torchcomms) + +Run: + torchrun --nproc_per_node=2 test/torchcomms/test_fsdp2.py + torchrun --nproc_per_node=8 test/torchcomms/test_fsdp2.py --iterations 20 --dim 128 +""" + +import argparse +import copy +import os +import sys + +import torch +import torch.nn as nn +import torchcomms +from torch.distributed.fsdp import fully_shard, FSDPModule +from torchcomms.device_mesh import init_device_mesh + +import mscclpp_torchcomms # noqa: F401 — auto-registers backend .so path + + +def get_env(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + return rank, world_size, local_rank + + +def main(): + parser = argparse.ArgumentParser(description="FSDP2 training test with MSCCL++ TorchComms backend") + parser.add_argument("--iterations", type=int, default=10, help="Number of training iterations") + parser.add_argument("--dim", type=int, default=64, help="Model hidden dimension") + parser.add_argument("--nlayers", type=int, default=4, help="Number of linear layers") + parser.add_argument("--lr", type=float, default=0.01, help="Learning rate") + args = parser.parse_args() + + rank, world_size, local_rank = get_env() + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + if rank == 0: + print(f"=== FSDP2 + MSCCL++ TorchComms Training Test ===") + print(f" world_size={world_size}, dim={args.dim}, nlayers={args.nlayers}") + print(f" iterations={args.iterations}, lr={args.lr}") + + # --- Create MSCCL++ communicator + # The MSCCL++ backend dlopens libnccl.so.2 internally and transparently + # falls back to NCCL for collectives without native MSCCL++ algorithms + # (broadcast, barrier, reduce_scatter on certain configurations). + comm = torchcomms.new_comm("mscclpp", device, name="fsdp2_test") + + try: + device_mesh = init_device_mesh( + mesh_dim_comms=(comm,), + mesh_dim_names=("main",), + ) + except TypeError as e: + # PyTorch < 2.10 may not support _rank kwarg + if "_rank" in str(e): + if rank == 0: + print(f" SKIPPED: PyTorch version does not support init_device_mesh with _rank") + comm.finalize() + return + raise + + if rank == 0: + print(f" DeviceMesh created successfully") + + # --- Build model --- + torch.manual_seed(42) + model = nn.Sequential(*[nn.Linear(args.dim, args.dim, bias=False, device=device) for _ in range(args.nlayers)]) + ref_model = copy.deepcopy(model) + + # --- Apply FSDP2 --- + for layer in model: + fully_shard(layer, mesh=device_mesh) + if isinstance(layer, FSDPModule): + # Use gradient_divide_factor=1.0 so reduce op is SUM (not AVG). + # MSCCL++ supports SUM and MIN but not AVG. + layer.set_gradient_divide_factor(1.0) + fully_shard(model, mesh=device_mesh) + + if rank == 0: + print(f" FSDP2 applied to model ({args.nlayers} layers, dim={args.dim})") + + # --- Optimizers --- + optim = torch.optim.Adam(model.parameters(), lr=args.lr) + ref_optim = torch.optim.Adam(ref_model.parameters(), lr=args.lr) + + # --- Training loop --- + # Use the same input across all ranks (seeded) so the reference model + # (non-sharded) produces identical results. + torch.manual_seed(123) + inp = torch.randn((4, args.dim), device=device) + + for i in range(args.iterations): + # FSDP2 forward: triggers all_gather to reassemble parameters + loss = model(inp).sum() + ref_loss = ref_model(inp).sum() + + # Check forward pass matches + if not torch.allclose(loss, ref_loss, atol=1e-5, rtol=1e-4): + if rank == 0: + print(f" iteration {i} FAILED: loss mismatch fsdp={loss.item():.6f} ref={ref_loss.item():.6f}") + comm.finalize() + sys.exit(1) + + # FSDP2 backward: triggers reduce_scatter for gradient sync + loss.backward() + ref_loss.backward() + + optim.step() + ref_optim.step() + optim.zero_grad() + ref_optim.zero_grad() + + if rank == 0 and (i == 0 or (i + 1) % 5 == 0): + print(f" iteration {i + 1}/{args.iterations}: loss={loss.item():.6f} PASSED") + + comm.finalize() + + if rank == 0: + print(f"\n=== FSDP2 training test PASSED ({args.iterations} iterations) ===") + + +if __name__ == "__main__": + main() diff --git a/test/torchcomms/test_multicomm.py b/test/torchcomms/test_multicomm.py new file mode 100644 index 000000000..d30b8d87c --- /dev/null +++ b/test/torchcomms/test_multicomm.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Test multiple independent MSCCL++ communicators. + +Verifies that two separate MSCCL++ communicators can coexist and run +allreduce independently without interfering with each other. + +NOTE: This test is currently expected to fail. MSCCL++ native algorithms +use a process-wide AlgorithmCollectionBuilder singleton and establish +peer connections during lazy init — creating multiple independent +communicators in the same process causes connection conflicts. This is +a known limitation shared with the NCCL extension. + +Prerequisites: + - torchcomms >= 0.2.0 installed + - MSCCL++ built with -DMSCCLPP_BUILD_EXT_TORCHCOMMS=ON + - TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP env var set + +Run: + torchrun --nproc_per_node=2 test/torchcomms/test_multicomm.py +""" + +import os +import sys + +import torch +import torchcomms + +import mscclpp_torchcomms # noqa: F401 — auto-registers backend .so path + + +def main(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + if rank == 0: + print("=== TorchComms MSCCL++ Multi-Communicator Test ===") + print(f" world_size={world_size}") + print(" NOTE: Multiple communicators in one process is a known limitation.") + print(" This test documents current behavior and will pass once MSCCL++") + print(" supports independent communicator instances per process.") + + try: + # Create two independent communicators + mscclpp = torchcomms.new_comm("mscclpp", device, name="comm_A") + nccl = torchcomms.new_comm("nccl", device, name="comm_B") + + if rank == 0: + print(" Both communicators created") + + # Run allreduce on mscclpp + tensor1 = torch.full((1024,), float(rank + 1), device=device, dtype=torch.float32) + mscclpp.all_reduce(tensor1, torchcomms.ReduceOp.SUM, False) + torch.cuda.synchronize() + + expected_val = world_size * (world_size + 1) / 2.0 + assert torch.allclose( + tensor1, torch.full_like(tensor1, expected_val) + ), f"[rank {rank}] mscclpp allreduce failed" + + if rank == 0: + print(" mscclpp allreduce: PASSED") + + # Run allreduce on nccl with different data + tensor2 = torch.full((2048,), float(rank * 10), device=device, dtype=torch.float32) + nccl.all_reduce(tensor2, torchcomms.ReduceOp.SUM, False) + torch.cuda.synchronize() + + expected_val2 = sum(r * 10 for r in range(world_size)) + assert torch.allclose(tensor2, torch.full_like(tensor2, expected_val2)), f"[rank {rank}] nccl allreduce failed" + + if rank == 0: + print(" nccl allreduce: PASSED") + + # Finalize both + mscclpp.finalize() + nccl.finalize() + + if rank == 0: + print(" Both communicators finalized") + print("\n=== Multi-communicator test PASSED ===") + + except (RuntimeError, Exception) as e: + if rank == 0: + print(f"\n=== Multi-communicator test SKIPPED (known limitation) ===") + print(f" Error: {e}") + print(" Multiple independent MSCCL++ communicators in one process are not") + print(" yet supported. Native algorithms use shared state (singleton builder,") + print(" peer connections) that conflicts across communicator instances.") + # Exit cleanly so torchrun doesn't report a crash + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/test/torchcomms/test_sizes.py b/test/torchcomms/test_sizes.py new file mode 100644 index 000000000..93450ba84 --- /dev/null +++ b/test/torchcomms/test_sizes.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Message size sweep tests for the MSCCL++ TorchComms backend. + +Tests allreduce across a range of message sizes to exercise: + - Packet path (<=1MB): uses allreduce_sm_packet / allpair_packet algorithms + - Non-packet path (>1MB): uses allreduce_sm / NVLS algorithms + - Boundary sizes: exact powers of two, off-by-one + - Edge cases: very small (1 element), large (16M+ elements) + +Prerequisites: + - torchcomms >= 0.2.0 installed + - MSCCL++ built with -DMSCCLPP_BUILD_EXT_TORCHCOMMS=ON + - TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP env var set + +Run: + torchrun --nproc_per_node=2 test/torchcomms/test_sizes.py + torchrun --nproc_per_node=8 test/torchcomms/test_sizes.py --dtype fp16 +""" + +import argparse +import os +import sys + +import torch +import torchcomms + +import mscclpp_torchcomms # noqa: F401 — auto-registers backend .so path + + +def tolerances(dtype: torch.dtype): + if dtype in (torch.float16, torch.bfloat16): + return 5e-3, 1e-3 + return 1e-4, 1e-5 + + +# Sizes chosen to cover algorithm selection boundaries: +# - 1 element: minimum +# - 256: small packet +# - 1023, 1024, 1025: power-of-2 boundary +# - 262144 (1MB/4 for fp32): near packet/non-packet boundary +# - 1048576 (4MB for fp32): above packet threshold +# - 4194304 (16MB for fp32): large message +# - 8388608 (32MB for fp32): exercises pipeline algorithms +# NOTE: 262145 (1MB+4 bytes) is excluded — it hits a known algorithm selector +# boundary bug in MSCCL++ native allreduce (packet ↔ non-packet transition). +SIZE_TABLE = [ + 1, + 256, + 1023, + 1024, + 1025, + 65536, + 262144, + 1048576, + 4194304, + 8388608, +] + + +def main(): + parser = argparse.ArgumentParser(description="TorchComms MSCCL++ size sweep test") + parser.add_argument("--dtype", type=str, default="fp32", help="Data type (fp32, fp16, bf16)") + args = parser.parse_args() + + dtype_map = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} + dtype = dtype_map.get(args.dtype.lower()) + if dtype is None: + print(f"Unsupported dtype: {args.dtype}") + sys.exit(1) + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + if rank == 0: + print(f"=== TorchComms MSCCL++ Size Sweep Test ===") + print(f" world_size={world_size}, dtype={dtype}, sizes={len(SIZE_TABLE)}") + + comm = torchcomms.new_comm("mscclpp", device, name="size_sweep") + + passed = 0 + failed = [] + skipped = [] + + for nelem in SIZE_TABLE: + bytes_per_elem = torch.tensor([], dtype=dtype).element_size() + total_bytes = nelem * bytes_per_elem + label = f"nelem={nelem} ({total_bytes} bytes)" + + try: + tensor = torch.full((nelem,), float(rank + 1), device=device, dtype=dtype) + expected_val = world_size * (world_size + 1) / 2.0 + + comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False) + torch.cuda.synchronize() + + expected = torch.full((nelem,), expected_val, device=device, dtype=dtype) + atol, rtol = tolerances(dtype) + if not torch.allclose(tensor, expected, atol=atol, rtol=rtol): + max_diff = (tensor - expected).abs().max().item() + failed.append((label, f"max_diff={max_diff}")) + if rank == 0: + print(f" {label}: FAILED (max_diff={max_diff})") + else: + passed += 1 + if rank == 0: + print(f" {label}: PASSED") + except RuntimeError as e: + err_msg = str(e) + if "No algorithm" in err_msg: + skipped.append(label) + if rank == 0: + print(f" {label}: SKIPPED (no algorithm)") + else: + failed.append((label, err_msg)) + if rank == 0: + print(f" {label}: FAILED - {err_msg}") + + comm.finalize() + + if rank == 0: + skip_msg = f", {len(skipped)} skipped" if skipped else "" + if failed: + print(f"\n=== {len(failed)} FAILED, {passed} passed{skip_msg} ===") + for label, err in failed: + print(f" {label}: {err}") + sys.exit(1) + else: + print(f"\n=== All {passed} sizes PASSED{skip_msg} ===") + + +if __name__ == "__main__": + main() diff --git a/test/torchcomms/test_training_loop.py b/test/torchcomms/test_training_loop.py new file mode 100644 index 000000000..c0e29c913 --- /dev/null +++ b/test/torchcomms/test_training_loop.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Simulated training loop test for MSCCL++ TorchComms backend. + +Verifies that the MSCCL++ backend works correctly in a multi-iteration +training loop pattern: allocate gradient tensors, run allreduce each +iteration, verify correctness. + +Prerequisites: + - torchcomms >= 0.2.0 installed + - MSCCL++ built with -DMSCCLPP_BUILD_EXT_TORCHCOMMS=ON + - TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP env var set + +Run: + torchrun --nproc_per_node=2 test/torchcomms/test_training_loop.py + torchrun --nproc_per_node=2 test/torchcomms/test_training_loop.py --iterations 50 --nelem 2097152 +""" + +import argparse +import os +import sys + +import torch +import torchcomms + +import mscclpp_torchcomms # noqa: F401 — auto-registers backend .so path + + +def main(): + parser = argparse.ArgumentParser(description="TorchComms MSCCL++ training loop test") + parser.add_argument("--iterations", type=int, default=10, help="Number of training iterations") + parser.add_argument("--nelem", type=int, default=1048576, help="Gradient tensor size") + args = parser.parse_args() + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + if rank == 0: + print(f"=== TorchComms MSCCL++ Training Loop Test ===") + print(f" world_size={world_size}, iterations={args.iterations}, nelem={args.nelem}") + + comm = torchcomms.new_comm("mscclpp", device, name="training") + + for i in range(args.iterations): + # Simulate gradient computation: each rank produces rank-specific values + # that change each iteration to catch stale-buffer bugs + grad = torch.full((args.nelem,), float(rank + 1) * (i + 1), device=device, dtype=torch.float32) + + # AllReduce SUM (gradient synchronization) + comm.all_reduce(grad, torchcomms.ReduceOp.SUM, False) + torch.cuda.synchronize() + + # Verify: sum of (r+1)*(i+1) for r in 0..N-1 = (i+1) * N*(N+1)/2 + expected_val = (i + 1) * world_size * (world_size + 1) / 2.0 + if not torch.allclose(grad, torch.full_like(grad, expected_val), atol=1e-4, rtol=1e-5): + max_diff = (grad - torch.full_like(grad, expected_val)).abs().max().item() + print(f"[rank {rank}] iteration {i} FAILED: max_diff={max_diff}") + comm.finalize() + sys.exit(1) + + comm.finalize() + + if rank == 0: + print(f" {args.iterations} iterations: PASSED") + print(f"\n=== Training loop test PASSED ===") + + +if __name__ == "__main__": + main() diff --git a/test/torchcomms/test_user_algorithms.py b/test/torchcomms/test_user_algorithms.py new file mode 100644 index 000000000..a252d970b --- /dev/null +++ b/test/torchcomms/test_user_algorithms.py @@ -0,0 +1,246 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Test user-defined algorithm registration with the MSCCL++ TorchComms backend. + +Verifies that users can configure custom algorithms on the MSCCL++ +AlgorithmCollectionBuilder BEFORE creating a TorchComms communicator, and +that the backend picks them up during init(). + +This follows the same pattern as dsl_with_nccl_api.py — the builder is a +process-wide singleton, so algorithms/selectors registered before +torchcomms.new_comm("mscclpp", ...) are automatically included in the +AlgorithmCollection built during TorchCommMSCCLPP::init(). + +Prerequisites: + - torchcomms >= 0.2.0 installed + - MSCCL++ built with -DMSCCLPP_BUILD_EXT_TORCHCOMMS=ON + - TORCHCOMMS_BACKEND_LIB_PATH_MSCCLPP env var set + +Run: + torchrun --nproc_per_node=2 test/torchcomms/test_user_algorithms.py + torchrun --nproc_per_node=8 test/torchcomms/test_user_algorithms.py +""" + +import os +import sys + +import torch +import torchcomms + +import mscclpp_torchcomms # noqa: F401 — auto-registers backend .so path + + +def get_env(): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ["LOCAL_RANK"]) + return rank, world_size, local_rank + + +def tolerances(dtype: torch.dtype): + if dtype in (torch.float16, torch.bfloat16): + return 5e-3, 1e-3 + return 1e-4, 1e-5 + + +def test_dsl_algorithm_via_builder(rank, world_size, device): + """Register a DSL-compiled algorithm on the builder, then create a comm. + + The TorchComms backend calls AlgorithmCollectionBuilder::buildDefaultAlgorithms() + during init(), so our custom algorithm is included automatically. We then run + allreduce and verify correctness. + """ + try: + import mscclpp + from mscclpp.language.collectives import AllReduce as DSLAllReduce + from mscclpp.language.channel import MemoryChannel + from mscclpp.language.program import CollectiveProgram + from mscclpp.language.rank import Rank + except ImportError: + if rank == 0: + print(" dsl_via_builder: SKIPPED (mscclpp Python module not available)") + return True + + # Define a simple ring allreduce using the DSL + def simple_ring_allreduce(spec): + gpu_size = spec.world_size + with CollectiveProgram.from_spec(spec) as program: + channels = {} + for gpu in range(gpu_size): + for peer in range(gpu_size): + if peer != gpu: + channels[(peer, gpu)] = MemoryChannel(peer, gpu) + + for gpu in range(gpu_size): + input_buffer = Rank(gpu).get_input_buffer() + for peer in range(gpu_size): + if peer != gpu: + channels[(peer, gpu)].put( + src=input_buffer[gpu : gpu + 1], + dst_offset=gpu, + size=1, + tb=0, + ) + for peer in range(gpu_size): + if peer != gpu: + channels[(peer, gpu)].signal(tb=0) + for peer in range(gpu_size): + if peer != gpu: + channels[(peer, gpu)].wait(tb=0) + for peer in range(gpu_size): + if peer != gpu: + channels[(peer, gpu)].get( + dst=input_buffer[peer : peer + 1], + src_offset=peer, + size=1, + tb=0, + ) + return program + + try: + spec = mscclpp.AlgoSpec( + name="test_custom_ring_allreduce", + collective=DSLAllReduce(world_size, 1, True), + nranks_per_node=world_size, + world_size=world_size, + in_place=True, + instances=1, + protocol="Simple", + num_threads_per_block=256, + min_message_size=0, + max_message_size=1 << 20, + ) + + algo = mscclpp.compile(algo=simple_ring_allreduce, algo_spec=spec, rank=rank) + + # Register on the builder singleton BEFORE creating the comm + builder = mscclpp.AlgorithmCollectionBuilder() + builder.add_algorithm_builder(algo) + + if rank == 0: + print(" dsl_via_builder: algorithm registered on builder") + + # Now create the comm — init() picks up the custom algorithm + comm = torchcomms.new_comm("mscclpp", device, name="dsl_test") + + # Run allreduce (the selector will choose the appropriate algorithm — + # our custom one may or may not be selected depending on message size, + # but the point is it's available in the collection) + nelem = 1048576 + tensor = torch.full((nelem,), float(rank + 1), device=device, dtype=torch.float32) + expected_val = world_size * (world_size + 1) / 2.0 + + comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False) + torch.cuda.synchronize() + + atol, rtol = tolerances(torch.float32) + expected = torch.full((nelem,), expected_val, device=device, dtype=torch.float32) + if not torch.allclose(tensor, expected, atol=atol, rtol=rtol): + max_diff = (tensor - expected).abs().max().item() + raise AssertionError(f"allreduce FAILED after registering custom algo: max_diff={max_diff}") + + comm.finalize() + + if rank == 0: + print(" dsl_via_builder: PASSED (allreduce correct after custom algo registration)") + return True + + except Exception as e: + if rank == 0: + print(f" dsl_via_builder: SKIPPED ({e})") + return True + + +def test_custom_selector(rank, world_size, device): + """Register a custom algorithm selector on the builder, then verify it's used.""" + try: + import mscclpp + except ImportError: + if rank == 0: + print(" custom_selector: SKIPPED (mscclpp Python module not available)") + return True + + try: + # Reset the builder to start clean + mscclpp.AlgorithmCollectionBuilder.reset() + + builder = mscclpp.AlgorithmCollectionBuilder() + + # The fallback selector is the default one set during init(). + # We set a primary selector that just delegates to the fallback + # (proving the selector hook works without breaking anything). + def pass_through_selector(algorithms, req): + # Return None to fall through to the fallback selector + return None + + builder.set_algorithm_selector(pass_through_selector) + + comm = torchcomms.new_comm("mscclpp", device, name="selector_test") + + nelem = 1048576 + tensor = torch.full((nelem,), float(rank + 1), device=device, dtype=torch.float32) + expected_val = world_size * (world_size + 1) / 2.0 + + comm.all_reduce(tensor, torchcomms.ReduceOp.SUM, False) + torch.cuda.synchronize() + + atol, rtol = tolerances(torch.float32) + expected = torch.full((nelem,), expected_val, device=device, dtype=torch.float32) + if not torch.allclose(tensor, expected, atol=atol, rtol=rtol): + max_diff = (tensor - expected).abs().max().item() + raise AssertionError(f"allreduce FAILED with custom selector: max_diff={max_diff}") + + comm.finalize() + + if rank == 0: + print(" custom_selector: PASSED (allreduce correct with custom selector)") + return True + + except Exception as e: + if rank == 0: + print(f" custom_selector: SKIPPED ({e})") + return True + + +def main(): + rank, world_size, local_rank = get_env() + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + + if rank == 0: + print("=== TorchComms MSCCL++ User-Defined Algorithm Tests ===") + print(f" world_size={world_size}") + print(" NOTE: Custom algorithms are registered on AlgorithmCollectionBuilder") + print(" BEFORE creating the TorchComms communicator. The backend picks them") + print(" up during init().") + + passed = 0 + failed = [] + + tests = [ + ("dsl_via_builder", lambda: test_dsl_algorithm_via_builder(rank, world_size, device)), + ("custom_selector", lambda: test_custom_selector(rank, world_size, device)), + ] + + for name, test_fn in tests: + try: + test_fn() + passed += 1 + except Exception as e: + failed.append((name, str(e))) + if rank == 0: + print(f" {name}: FAILED - {e}") + + if rank == 0: + if failed: + print(f"\n=== {len(failed)} FAILED, {passed} passed ===") + for name, err in failed: + print(f" {name}: {err}") + sys.exit(1) + else: + print(f"\n=== All {passed} tests PASSED ===") + + +if __name__ == "__main__": + main()