diff --git a/.azure-pipelines/templates/nccl-test.yml b/.azure-pipelines/templates/nccl-test.yml index 585f5b48f..550f5690a 100644 --- a/.azure-pipelines/templates/nccl-test.yml +++ b/.azure-pipelines/templates/nccl-test.yml @@ -74,6 +74,15 @@ steps: mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD=/root/mscclpp/build/lib/libmscclpp_nccl.so -x MSCCLPP_NCCL_SYMMETRIC_MEMORY=1 -x NCCL_DEBUG=WARN -x MSCCLPP_ENABLE_NCCL_FALLBACK=TRUE -x MSCCLPP_NCCL_LIB_PATH=/root/nccl/build/lib/libnccl.so -x MSCCLPP_FORCE_NCCL_FALLBACK_OPERATION="broadcast" /root/nccl-tests/build/broadcast_perf -b 1K -e 1G -f 2 -d half -G 20 -w 10 -n 20 mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD=/root/mscclpp/build/lib/libmscclpp_nccl.so -x MSCCLPP_NCCL_SYMMETRIC_MEMORY=1 -x NCCL_DEBUG=WARN -x MSCCLPP_ENABLE_NCCL_FALLBACK=TRUE -x MSCCLPP_NCCL_LIB_PATH=/root/nccl/build/lib/libnccl.so -x MSCCLPP_FORCE_NCCL_FALLBACK_OPERATION="allreduce" /root/nccl-tests/build/broadcast_perf -b 1K -e 1G -f 2 -d half -G 20 -w 10 -n 20 +- template: run-remote-task.yml + parameters: + name: PyBench + displayName: Run Collective Benchmarks + remoteScript: | + mpirun --allow-run-as-root -np 8 python3 -m mscclpp_benchmark.bench_collective --collective allreduce --dtype float8_e4m3b15 --accum-type float32 --autotune --symmetric-memory + mpirun --allow-run-as-root -np 8 python3 -m mscclpp_benchmark.bench_collective --collective allreduce --dtype float8_e4m3fn --accum-type float16 --autotune --symmetric-memory + mpirun --allow-run-as-root -np 8 python3 -m mscclpp_benchmark.bench_collective --collective allreduce --dtype float16 --symmetric-memory --autotune + - template: stop.yml parameters: subscription: ${{ parameters.subscription }} diff --git a/.azure-pipelines/templates/rccl-test.yml b/.azure-pipelines/templates/rccl-test.yml index 8e2471614..63788ac27 100644 --- a/.azure-pipelines/templates/rccl-test.yml +++ b/.azure-pipelines/templates/rccl-test.yml @@ -57,6 +57,15 @@ steps: mpirun -np 8 --bind-to numa --allow-run-as-root -x LD_PRELOAD=/root/mscclpp/build/lib/libmscclpp_nccl.so -x MSCCLPP_NCCL_SYMMETRIC_MEMORY=1 -x NCCL_DEBUG=WARN /root/rocm-systems/projects/rccl-tests/build/all_reduce_perf -b 1K -e 1G -f 2 -d half -G 20 -w 10 -n 20 mpirun -np 8 --bind-to numa --allow-run-as-root /root/rocm-systems/projects/rccl-tests/build/all_reduce_perf -b 1K -e 1G -f 2 -d half -G 20 -w 10 -n 20 +- template: run-remote-task.yml + parameters: + name: PyBench + displayName: Run Collective Benchmarks + remoteScript: | + mpirun --allow-run-as-root -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m mscclpp_benchmark.bench_collective --collective allreduce --dtype float8_e4m3b15 --accum-type float32 --autotune + mpirun --allow-run-as-root -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m mscclpp_benchmark.bench_collective --collective allreduce --dtype float8_e4m3fnuz --accum-type float32 --autotune + mpirun --allow-run-as-root -x GPU_MAX_HW_QUEUES=8 -np 8 python3 -m mscclpp_benchmark.bench_collective --collective allgather --dtype float8_e4m3b15 --autotune --buffer-mode out-of-place + - template: stop.yml parameters: subscription: ${{ parameters.subscription }} diff --git a/docs/quickstart.md b/docs/quickstart.md index 716fcf61c..320a2db78 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -110,12 +110,12 @@ $ CXX=/opt/rocm/bin/hipcc python -m pip install ".[rocm6]" ``` > **Note:** A platform extra (`cuda11`, `cuda12`, `cuda13`, or `rocm6`) is required to install CuPy. -> The CUDA extras install pre-built CuPy wheels. The `rocm6` extra installs CuPy from source, -> which requires ROCm and may take longer. Running `pip install .` without an extra will not install CuPy. +> The CUDA extras install pre-built CuPy wheels and CUDA Python bindings. The `rocm6` extra installs CuPy from source +> and HIP Python 6.x, which require ROCm and may take longer. Running `pip install .` without an extra will not install CuPy. Optional extras can be installed by specifying them in brackets. Available extras: -- **`cuda11`**, **`cuda12`**, **`cuda13`**: Install a pre-built CuPy package for your CUDA version. -- **`rocm6`**: Install CuPy from source for AMD ROCm platforms. +- **`cuda11`**, **`cuda12`**, **`cuda13`**: Install a pre-built CuPy package and CUDA Python bindings for your CUDA version. +- **`rocm6`**: Install CuPy from source and HIP Python 6.x for AMD ROCm platforms. - **`benchmark`**: Install benchmark dependencies (mpi4py, prettytable, netifaces, matplotlib). - **`test`**: Install test dependencies (pytest, mpi4py, netifaces). @@ -209,15 +209,37 @@ $ mpirun -np 16 -npernode 8 -hostfile hostfile ./bin/mp_unit_tests -ip_port 10.0 ## Performance Benchmark -### Python Benchmark +### Python Benchmark and Tuning -[Install the MSCCL++ Python package](#install-from-source-python-module) and run our Python AllReduce benchmark as follows. It requires MPI on the system. +[Install the MSCCL++ Python package](#install-from-source-python-module) and run the Python collective benchmark as follows. It requires MPI on the system. ```bash # Install with benchmark dependencies and the appropriate CUDA/ROCm extras. # Replace `cuda12` with your platform: cuda11, cuda12, cuda13, or rocm6. $ python3 -m pip install ".[cuda12,benchmark,test]" -$ mpirun -tag-output -np 8 python3 ./python/mscclpp_benchmark/allreduce_bench.py + +``` + +To autotune launch parameters and save a tuned config: + +```bash +$ PYTHONPATH=$PWD/python mpirun -np 8 --allow-run-as-root \ + python3 -m mscclpp_benchmark.bench_collective \ + --collective allreduce \ + --dtype float16 \ + --batch-sizes 1,2,4,8 \ + --autotune \ + --write-config /tmp/mscclpp_tuned_configs.json +``` + +Use the tuned config in a benchmark: + +```bash +$ PYTHONPATH=$PWD/python mpirun -np 8 --allow-run-as-root \ + python3 -m mscclpp_benchmark.bench_collective \ + --collective allreduce \ + --dtype float16 \ + --config-path /tmp/mscclpp_tuned_configs.json ``` (nccl-benchmark)= @@ -291,4 +313,3 @@ Version: 0.8.0.post1.dev0+gc632fee37.d20251007 mscclpp.version {'version': '0.8.0.post1.dev0+gc632fee37.d20251007', 'git_commit': 'g50382c567'} ``` - diff --git a/include/mscclpp/gpu_data_types.hpp b/include/mscclpp/gpu_data_types.hpp index 672434f97..4a16628e3 100644 --- a/include/mscclpp/gpu_data_types.hpp +++ b/include/mscclpp/gpu_data_types.hpp @@ -71,7 +71,7 @@ using __bfloat162 = __nv_bfloat162; /// Software float8 with 4 exponent bits, 3 mantissa bits, exponent bias = 15. /// Format (MSB first): [sign:1][exponent:4][mantissa:3] -/// No infinities, no NaN. Encode saturates to ±1.75 (0x7e/0xfe). +/// No infinities, no NaN. Encode saturates to ±1.875 (0x7f/0xff). /// Adapted from the Triton compiler's fp8e4b15 format. struct alignas(1) __fp8_e4m3b15 { uint8_t __x; @@ -103,7 +103,7 @@ struct alignas(1) __fp8_e4m3b15 { /// then convert fp16 → float32. static MSCCLPP_HOST_DEVICE_INLINE float toFloat(uint8_t bits) { // Branch-free decode: fp8 → fp16 → fp32, no special-case handling. - // Encode saturates to ±1.75, so 0x7f/0xff are never produced. + // Every byte maps to a finite value; encode saturates at ±1.875, so 0x7f/0xff decode to ±1.875. // Refer: // https://github.com/triton-lang/triton/blob/cf34004b8a67d290a962da166f5aa2fc66751326/python/triton/language/extra/cuda/utils.py#L34 uint16_t h = (uint16_t)bits << 8; // place fp8 in upper byte of fp16 @@ -132,10 +132,9 @@ struct alignas(1) __fp8_e4m3b15 { } cvt = {h_val}; uint16_t fp16_bits = cvt.u; - // Clamp abs to max encodable value: 1.75 → fp16 = 0x3F00. - // Matches Triton: encode saturates, 0x7f/0xff are never produced. + // Clamp abs to max encodable value: 1.875 → fp16 = 0x3F80 (largest byte 0x7f/0xff). uint16_t abs_fp16 = fp16_bits & 0x7FFFu; - if (abs_fp16 > 0x3F00u) abs_fp16 = 0x3F00u; + if (abs_fp16 > 0x3F80u) abs_fp16 = 0x3F80u; // Reconstruct with sign. uint16_t sign16 = fp16_bits & 0x8000u; @@ -852,27 +851,17 @@ MSCCLPP_DEVICE_INLINE f32x4 to(const f8_e5m2x4& v) { /// f32x2 -> f8_e4m3x2. /// HIP gfx942: float -> fp8 (via __builtin_amdgcn_cvt_pk_fp8_f32). -/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2). -/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise). +/// NVIDIA: float -> fp8 directly (via __nv_cvt_float2_to_fp8x2). On SM89+ this maps to a +/// single hardware round-to-nearest-even instruction; on older arch it falls back to a +/// software direct conversion. template <> MSCCLPP_DEVICE_INLINE f8_e4m3x2 to(const f32x2& v) { #if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) uint32_t packed = __builtin_amdgcn_cvt_pk_fp8_f32(v.data[0], v.data[1], 0, false); return bit_cast(static_cast<__hip_fp8x2_storage_t>(packed)); -#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900 - __half2_raw h2; - h2.x = bit_cast(__float2half_rn(v.data[0])); - h2.y = bit_cast(__float2half_rn(v.data[1])); - __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E4M3); - return bit_cast(fp8x2); #elif defined(MSCCLPP_DEVICE_CUDA) - __half_raw h0, h1; - h0.x = bit_cast(__float2half_rn(v.data[0])); - h1.x = bit_cast(__float2half_rn(v.data[1])); - f8_e4m3x2 result; - result.data[0] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E4M3)); - result.data[1] = bit_cast<__fp8_e4m3>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E4M3)); - return result; + __nv_fp8x2_storage_t fp8x2 = __nv_cvt_float2_to_fp8x2(make_float2(v.data[0], v.data[1]), __NV_SATFINITE, __NV_E4M3); + return bit_cast(fp8x2); #else f8_e4m3x2 result; result.data[0] = static_cast<__fp8_e4m3>(v.data[0]); @@ -909,27 +898,17 @@ MSCCLPP_DEVICE_INLINE f8_e4m3x4 to(const f32x4& v) { /// f32x2 -> f8_e5m2x2. /// HIP gfx942: float -> bf8 (via __builtin_amdgcn_cvt_pk_bf8_f32). -/// NVIDIA SM90+: float -> half -> fp8 (via __nv_cvt_halfraw2_to_fp8x2 with __NV_E5M2). -/// NVIDIA pre-SM90: float -> half -> fp8 (via __nv_cvt_halfraw_to_fp8, element-wise). +/// NVIDIA: float -> fp8 directly (via __nv_cvt_float2_to_fp8x2 with __NV_E5M2). On SM89+ this +/// maps to a single hardware round-to-nearest-even instruction; on older arch it falls back to a +/// software direct conversion. template <> MSCCLPP_DEVICE_INLINE f8_e5m2x2 to(const f32x2& v) { #if defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) uint32_t packed = __builtin_amdgcn_cvt_pk_bf8_f32(v.data[0], v.data[1], 0, false); return bit_cast(static_cast<__hip_fp8x2_storage_t>(packed)); -#elif defined(MSCCLPP_DEVICE_CUDA) && __CUDA_ARCH__ >= 900 - __half2_raw h2; - h2.x = bit_cast(__float2half_rn(v.data[0])); - h2.y = bit_cast(__float2half_rn(v.data[1])); - __nv_fp8x2_storage_t fp8x2 = __nv_cvt_halfraw2_to_fp8x2(h2, __NV_SATFINITE, __NV_E5M2); - return bit_cast(fp8x2); #elif defined(MSCCLPP_DEVICE_CUDA) - __half_raw h0, h1; - h0.x = bit_cast(__float2half_rn(v.data[0])); - h1.x = bit_cast(__float2half_rn(v.data[1])); - f8_e5m2x2 result; - result.data[0] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h0, __NV_SATFINITE, __NV_E5M2)); - result.data[1] = bit_cast<__fp8_e5m2>(__nv_cvt_halfraw_to_fp8(h1, __NV_SATFINITE, __NV_E5M2)); - return result; + __nv_fp8x2_storage_t fp8x2 = __nv_cvt_float2_to_fp8x2(make_float2(v.data[0], v.data[1]), __NV_SATFINITE, __NV_E5M2); + return bit_cast(fp8x2); #else f8_e5m2x2 result; result.data[0] = static_cast<__fp8_e5m2>(v.data[0]); @@ -1103,11 +1082,11 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to(const f16x2& v) { #if defined(MSCCLPP_DEVICE_CUDA) uint32_t in0; asm("mov.b32 %0, %1;" : "=r"(in0) : "r"(*reinterpret_cast(&v))); - // Clamp abs to max encodable e4m3b15 (0x3F00 = 1.75 in fp16). + // Clamp abs to max encodable e4m3b15 (0x3F80 = 1.875 in fp16). uint32_t lo = in0 & 0xFFFFu, hi = in0 >> 16; uint32_t alo = lo & 0x7FFFu, ahi = hi & 0x7FFFu; - alo = alo < 0x3F00u ? alo : 0x3F00u; - ahi = ahi < 0x3F00u ? ahi : 0x3F00u; + alo = alo < 0x3F80u ? alo : 0x3F80u; + ahi = ahi < 0x3F80u ? ahi : 0x3F80u; uint32_t a0 = alo | (ahi << 16); a0 = a0 * 2u + 0x00800080u; uint32_t b0 = a0 | (in0 & 0x80008000u); @@ -1118,7 +1097,7 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x2 to(const f16x2& v) { uint32_t in0 = v.words[0]; uint32_t abs0 = in0 & 0x7fff7fffu; uint32_t a0; - asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u)); + asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F803F80u)); a0 = a0 * 2u + 0x00800080u; uint32_t b0 = a0 | (in0 & 0x80008000u); uint16_t packed = (uint16_t)(((b0 >> 8) & 0xFFu) | ((b0 >> 16) & 0xFF00u)); @@ -1141,8 +1120,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to(const f16x4& v) { asm("mov.b32 %0, %1;" : "=r"(in1) : "r"(v.words[1])); uint32_t abs0 = in0 & 0x7fff7fffu; uint32_t abs1 = in1 & 0x7fff7fffu; - uint32_t a0 = __vminu2(abs0, 0x3F003F00u); - uint32_t a1 = __vminu2(abs1, 0x3F003F00u); + uint32_t a0 = __vminu2(abs0, 0x3F803F80u); + uint32_t a1 = __vminu2(abs1, 0x3F803F80u); a0 = a0 * 2u + 0x00800080u; a1 = a1 * 2u + 0x00800080u; uint32_t b0, b1; @@ -1155,8 +1134,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to(const f16x4& v) { uint32_t in0 = v.words[0], in1 = v.words[1]; uint32_t abs0 = in0 & 0x7fff7fffu, abs1 = in1 & 0x7fff7fffu; uint32_t a0, a1; - asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F003F00u)); - asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3F003F00u)); + asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a0) : "v"(abs0), "v"(0x3F803F80u)); + asm volatile("v_pk_min_u16 %0, %1, %2" : "=v"(a1) : "v"(abs1), "v"(0x3F803F80u)); a0 = a0 * 2u + 0x00800080u; a1 = a1 * 2u + 0x00800080u; uint32_t b0 = a0 | (in0 & 0x80008000u); @@ -1268,8 +1247,8 @@ MSCCLPP_DEVICE_INLINE f8_e4m3b15x4 to(const f32x4& v) { return to(h); #elif defined(MSCCLPP_DEVICE_HIP) && defined(__gfx942__) f16x4 h; - h.words[0] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[0], v.data[1])); - h.words[1] = __builtin_bit_cast(uint32_t, __builtin_amdgcn_cvt_pkrtz(v.data[2], v.data[3])); + h.words[0] = __builtin_bit_cast(uint32_t, __floats2half2_rn(v.data[0], v.data[1])); + h.words[1] = __builtin_bit_cast(uint32_t, __floats2half2_rn(v.data[2], v.data[3])); return to(h); #else f8_e4m3b15x4 result; diff --git a/pyproject.toml b/pyproject.toml index 0ea569cb4..b35b1b3a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,10 +21,22 @@ dependencies = [ ] [project.optional-dependencies] -cuda11 = ["cupy-cuda11x"] -cuda12 = ["cupy-cuda12x"] -cuda13 = ["cupy-cuda13x"] -rocm6 = ["cupy"] +cuda11 = [ + "cupy-cuda11x", + "cuda-bindings>=11.8,<12", +] +cuda12 = [ + "cupy-cuda12x", + "cuda-bindings>=12,<13", +] +cuda13 = [ + "cupy-cuda13x", + "cuda-bindings>=13,<14", +] +rocm6 = [ + "cupy", + "hip-python>=6,<7", +] benchmark = [ "mpi4py", "prettytable", diff --git a/python/mscclpp_benchmark/__init__.py b/python/mscclpp_benchmark/__init__.py index 1ee3f3bff..11e08c9bb 100644 --- a/python/mscclpp_benchmark/__init__.py +++ b/python/mscclpp_benchmark/__init__.py @@ -1,4 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .mscclpp_op import MscclppAllReduce1, MscclppAllReduce2, MscclppAllReduce3, MscclppAllReduce4, MscclppAllReduce5 +__all__ = [ + "MscclppAllReduce1", + "MscclppAllReduce2", + "MscclppAllReduce3", + "MscclppAllReduce4", + "MscclppAllReduce5", +] + + +def __getattr__(name): + if name in __all__: + from . import mscclpp_op + + return getattr(mscclpp_op, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/python/mscclpp_benchmark/bench_collective.py b/python/mscclpp_benchmark/bench_collective.py new file mode 100644 index 000000000..c526438da --- /dev/null +++ b/python/mscclpp_benchmark/bench_collective.py @@ -0,0 +1,645 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import argparse +from dataclasses import dataclass +from typing import Any + +import cupy as cp +from mpi4py import MPI + +_mscclpp_module = None + +from mscclpp_benchmark.comm import Comm +from mscclpp_benchmark.correctness import ( + CorrectnessStats, + check_correctness as _check_correctness, + fill_case_for_benchmark as _fill_case_for_benchmark, +) +from mscclpp_benchmark.gpu import capture_graph, init_runtime +from mscclpp_benchmark.tuner import OfflineTuner +from mscclpp_benchmark.tuning_config import HardwareProfile, TunedConfig, TunedConfigStore, normalize_sku + +_ALLREDUCE = "allreduce" +_ALLGATHER = "allgather" +_DEFAULT_BATCH_SIZES = ( + 1, + 2, + 3, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1280, + 1536, + 1792, + 2048, + 2560, + 3072, + 3584, + 4096, +) +_DEFAULT_CANDIDATE_NBLOCKS = (1, 4, 8, 16, 24, 32, 48, 56, 64) +_DEFAULT_CANDIDATE_NTHREADS = (256, 512, 768, 1024) + + +def _mscclpp(): + global _mscclpp_module + if _mscclpp_module is None: + import mscclpp + import mscclpp.ext + + _mscclpp_module = mscclpp + return _mscclpp_module + + +@dataclass(frozen=True) +class DTypeSpec: + name: str + cupy_dtype: Any + mscclpp_dtype: Any + accum_dtype: Any | None = None + fp8_format: str | None = None + + +@dataclass(frozen=True) +class CandidateSpec: + algorithm: str + min_message_size: int | None = None + max_message_size: int | None = None + max_nblocks: int | None = None + supported_skus: tuple[str, ...] | None = None + requires_nvls: bool = False + requires_symmetric_memory: bool = False + + +@dataclass +class BenchmarkCase: + collective: str + message_size: int + total_size: int + input: cp.ndarray + output: cp.ndarray + dtype_spec: DTypeSpec + symmetric_memory: bool = False + + +def _device_name() -> str: + props = cp.cuda.runtime.getDeviceProperties(cp.cuda.Device().id) + name = props.get("name", "UNKNOWN") + if isinstance(name, bytes): + return name.decode("utf-8") + return str(name) + + +def _detect_hardware_profile(scale: int) -> HardwareProfile: + return HardwareProfile(sku=normalize_sku(_device_name()), scale=scale) + + +def _parse_dtype(dtype_name: str) -> DTypeSpec: + mscclpp = _mscclpp() + normalized = dtype_name.strip().lower().replace("-", "_") + if normalized in {"float16", "fp16", "half"}: + return DTypeSpec("float16", cp.float16, mscclpp.DataType.float16) + if normalized in {"float32", "fp32", "float"}: + return DTypeSpec("float32", cp.float32, mscclpp.DataType.float32) + if normalized in {"int32", "i32"}: + return DTypeSpec("int32", cp.int32, mscclpp.DataType.int32) + if normalized in {"uint8", "u8"}: + return DTypeSpec("uint8", cp.uint8, mscclpp.DataType.uint8) + if normalized in {"float8_e4m3fn", "fp8_e4m3fn"}: + return DTypeSpec( + "float8_e4m3fn", + cp.uint8, + mscclpp.DataType.float8_e4m3fn, + accum_dtype=mscclpp.DataType.float16, + fp8_format="e4m3fn", + ) + if normalized in {"float8_e4m3fnuz", "fp8_e4m3fnuz"}: + return DTypeSpec( + "float8_e4m3fnuz", + cp.uint8, + mscclpp.DataType.float8_e4m3fnuz, + accum_dtype=mscclpp.DataType.float16, + fp8_format="e4m3fnuz", + ) + if normalized in {"float8_e4m3b15", "fp8_e4m3b15"}: + return DTypeSpec( + "float8_e4m3b15", + cp.uint8, + mscclpp.DataType.float8_e4m3b15, + accum_dtype=mscclpp.DataType.float32, + fp8_format="e4m3b15", + ) + raise ValueError( + f"Unsupported dtype {dtype_name!r}; use float16, float32, int32, uint8, " + "float8_e4m3fn, float8_e4m3fnuz, or float8_e4m3b15" + ) + + +def _with_accum_type(dtype_spec: DTypeSpec, accum_type: str | None) -> DTypeSpec: + if accum_type is None: + return dtype_spec + + mscclpp = _mscclpp() + normalized = accum_type.strip().lower().replace("-", "_") + if normalized in {"native", "same", "auto"}: + accum_dtype = dtype_spec.mscclpp_dtype + elif normalized in {"float16", "fp16", "half"}: + accum_dtype = mscclpp.DataType.float16 + elif normalized in {"float32", "fp32", "float"}: + accum_dtype = mscclpp.DataType.float32 + else: + raise ValueError(f"Unsupported accum type {accum_type!r}; use native, float16, or float32") + + return DTypeSpec( + name=dtype_spec.name, + cupy_dtype=dtype_spec.cupy_dtype, + mscclpp_dtype=dtype_spec.mscclpp_dtype, + accum_dtype=accum_dtype, + fp8_format=dtype_spec.fp8_format, + ) + + +def _human_size(size: int) -> str: + value = float(size) + for unit in ("B", "KiB", "MiB", "GiB", "TiB"): + if value < 1024.0 or unit == "TiB": + return f"{value:.1f} {unit}" + value /= 1024.0 + raise AssertionError("unreachable") + + +def _parse_int_list(raw: str | None, default: tuple[int, ...]) -> tuple[int, ...]: + if raw is None: + return default + values = tuple(sorted({int(item.strip()) for item in raw.split(",") if item.strip()})) + if not values or values[0] <= 0: + raise ValueError(f"Expected a comma-separated list of positive integers, got {raw!r}") + return values + + +def _candidate_specs(collective: str, *, symmetric_memory: bool = False) -> tuple[CandidateSpec, ...]: + if collective == _ALLGATHER: + return (CandidateSpec("default_allgather_fullmesh2", max_nblocks=64, supported_skus=("MI300X",)),) + if collective != _ALLREDUCE: + raise ValueError(f"Unsupported collective: {collective}") + candidates = ( + CandidateSpec( + "default_allreduce_nvls_packet", + max_message_size=512 * 1024, + max_nblocks=16, + supported_skus=("H100", "GB300"), + requires_nvls=True, + ), + CandidateSpec( + "default_allreduce_packet", + max_message_size=4 * 1024 * 1024, + max_nblocks=56, + ), + CandidateSpec( + "default_allreduce_allpair_packet", + max_message_size=4 * 1024 * 1024, + max_nblocks=56, + ), + CandidateSpec( + "default_allreduce_rsag_zero_copy", + min_message_size=512 * 1024 + 1, + ), + CandidateSpec( + "default_allreduce_fullmesh", + min_message_size=512 * 1024 + 1, + max_nblocks=64, + supported_skus=("MI300X",), + ), + ) + if symmetric_memory: + return ( + CandidateSpec( + "default_allreduce_nvls_zero_copy", + max_nblocks=32, + supported_skus=("H100", "GB300"), + requires_nvls=True, + requires_symmetric_memory=True, + ), + *candidates, + ) + return candidates + + +def _candidate_algorithms(comm: Comm, case: BenchmarkCase) -> list[tuple[Any, CandidateSpec]]: + available = comm.algorithms.get(case.collective, {}) + candidates: list[tuple[Any, CandidateSpec]] = [] + seen: set[str] = set() + symmetric_memory = case.symmetric_memory + profile = getattr(comm, "hardware_profile", None) + filtered_out = False + for candidate in _candidate_specs(case.collective, symmetric_memory=symmetric_memory): + if not _candidate_supports_profile(candidate, profile): + filtered_out = True + continue + if not _candidate_supports_message_size(candidate, case.message_size): + filtered_out = True + continue + if candidate.requires_nvls and not _mscclpp().is_nvls_supported(): + filtered_out = True + continue + if candidate.requires_symmetric_memory and not symmetric_memory: + filtered_out = True + continue + algorithm = available.get(candidate.algorithm) + if algorithm is None or algorithm.name in seen: + continue + seen.add(algorithm.name) + candidates.append((algorithm, candidate)) + if candidates: + return candidates + if filtered_out: + return [] + return [(algorithm, CandidateSpec(algorithm.name)) for algorithm in available.values()] + + +def _candidate_supports_profile(candidate: CandidateSpec, profile: HardwareProfile | None) -> bool: + if candidate.supported_skus is None: + return True + sku = None if profile is None else profile.sku + if not sku or sku == "UNKNOWN": + return True + return sku in candidate.supported_skus + + +def _candidate_supports_message_size(candidate: CandidateSpec, message_size: int) -> bool: + if candidate.min_message_size is not None and message_size < candidate.min_message_size: + return False + if candidate.max_message_size is not None and message_size > candidate.max_message_size: + return False + return True + + +def _make_case( + *, + collective: str, + nelems: int, + dtype_spec: DTypeSpec, + comm_group: Any, + buffer_mode: str, + symmetric_memory: bool = False, +) -> BenchmarkCase: + if buffer_mode not in ("in-place", "out-of-place"): + raise ValueError(f"Unsupported buffer mode: {buffer_mode}") + + if collective == _ALLREDUCE: + if buffer_mode == "in-place": + memory = _mscclpp().GpuBuffer(nelems, dtype=dtype_spec.cupy_dtype) + input_buffer = memory + output = memory + else: + input_buffer = _mscclpp().GpuBuffer(nelems, dtype=dtype_spec.cupy_dtype) + output = _mscclpp().GpuBuffer(nelems, dtype=dtype_spec.cupy_dtype) + return BenchmarkCase( + collective=collective, + message_size=input_buffer.nbytes, + total_size=output.nbytes, + input=input_buffer, + output=output, + dtype_spec=dtype_spec, + symmetric_memory=symmetric_memory, + ) + + if collective != _ALLGATHER: + raise ValueError(f"Unsupported collective: {collective}") + + if buffer_mode == "in-place": + output = _mscclpp().GpuBuffer(nelems * comm_group.nranks, dtype=dtype_spec.cupy_dtype) + start = comm_group.my_rank * nelems + input_buffer = output[start : start + nelems] + else: + input_buffer = _mscclpp().GpuBuffer(nelems, dtype=dtype_spec.cupy_dtype) + output = _mscclpp().GpuBuffer(nelems * comm_group.nranks, dtype=dtype_spec.cupy_dtype) + + return BenchmarkCase( + collective=collective, + message_size=input_buffer.nbytes, + total_size=output.nbytes, + input=input_buffer, + output=output, + dtype_spec=dtype_spec, + symmetric_memory=symmetric_memory, + ) + + +def _try_measure_case( + comm: Comm, + case: BenchmarkCase, + config: TunedConfig, + *, + n_warmup: int, + n_graph_launches: int, + n_ops_per_graph: int, +) -> float | None: + try: + return _measure_case( + comm, + case, + config, + n_warmup=n_warmup, + n_graph_launches=n_graph_launches, + n_ops_per_graph=n_ops_per_graph, + ) + except Exception as exc: + if comm.rank == 0: + print( + f"[skip] {config.algorithm} nb={config.nblocks} nt={config.nthreads} " + f"size={case.message_size}: {type(exc).__name__}: {exc}", + flush=True, + ) + return None + + +def _measure_case( + comm: Comm, + case: BenchmarkCase, + config: TunedConfig, + *, + n_warmup: int, + n_graph_launches: int, + n_ops_per_graph: int, +) -> float: + _fill_case_for_benchmark(case, comm.rank) + comm.comm_group.barrier() + if comm.run(case, config) != 0: + raise RuntimeError("algorithm returned non-zero status") + cp.cuda.runtime.deviceSynchronize() + comm.comm_group.barrier() + + stream = cp.cuda.Stream(non_blocking=True) + graph = None + + def capture_ops() -> None: + for _ in range(n_ops_per_graph): + ret = comm.run(case, config, stream) + if ret != 0: + raise RuntimeError("algorithm returned non-zero status during graph capture") + + try: + with stream: + graph = capture_graph(stream, capture_ops) + + for _ in range(n_warmup): + graph.launch(stream) + stream.synchronize() + comm.comm_group.barrier() + + start = cp.cuda.Event() + end = cp.cuda.Event() + start.record(stream) + for _ in range(n_graph_launches): + graph.launch(stream) + end.record(stream) + end.synchronize() + + elapsed_us = cp.cuda.get_elapsed_time(start, end) * 1000.0 / (n_graph_launches * n_ops_per_graph) + return float(MPI.COMM_WORLD.allreduce(elapsed_us, op=MPI.MAX)) + finally: + if graph is not None: + graph.close() + + +def _bandwidth_gbps(num_bytes: int, time_us: float) -> float: + return num_bytes / time_us / 1e3 + + +def _busbw_factor(collective: str, nranks: int) -> float: + if nranks <= 1: + return 1.0 + if collective == _ALLREDUCE: + return 2 * (nranks - 1) / nranks + if collective == _ALLGATHER: + return (nranks - 1) / nranks + raise ValueError(f"Unsupported collective: {collective}") + + +def _format_table(headers: list[str], rows: list[list[str]]) -> str: + widths = [len(header) for header in headers] + for row in rows: + widths = [max(width, len(cell)) for width, cell in zip(widths, row)] + header_line = " | ".join(header.ljust(width) for header, width in zip(headers, widths)) + sep_line = "-+-".join("-" * width for width in widths) + row_lines = [" | ".join(cell.ljust(width) for cell, width in zip(row, widths)) for row in rows] + return "\n".join([header_line, sep_line, *row_lines]) + + +def _format_stat(value: float | None) -> str: + if value is None: + return "-" + return f"{value:.6g}" + + +def _format_mismatches(stats: CorrectnessStats | None) -> str: + if stats is None or stats.total == 0: + return "-" + return f"{stats.mismatches}/{stats.total}" + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Benchmark MSCCL++ collectives without PyTorch dependencies") + parser.add_argument("--collective", choices=(_ALLREDUCE, _ALLGATHER), default=_ALLREDUCE) + parser.add_argument("--d-model", type=int, default=5120) + parser.add_argument("--dtype", default="float16") + parser.add_argument("--accum-type", help="Accumulation type for reductions: native, float16, or float32") + parser.add_argument("--batch-sizes", help="Comma-separated batch sizes; default uses the benchmark sweep") + parser.add_argument( + "--buffer-mode", + choices=("in-place", "out-of-place"), + default="in-place", + help="Buffer layout for the collective: in-place (input aliases output) or out-of-place (separate buffers)", + ) + parser.add_argument("--config-path", help="Optional MSCCL++ tuned config JSON") + parser.add_argument("--write-config", help="Write autotuned configs to this JSON path") + parser.add_argument("--autotune", action="store_true", help="Tune each benchmark size before timing it") + parser.add_argument("--skip-correctness", action="store_true") + parser.add_argument("--correctness-iters", type=int, default=1) + parser.add_argument("--scratch-buffer-size", type=int, default=1 << 27) + parser.add_argument("--warmup", type=int, default=5, help="Warmup graph replays before benchmark timing") + parser.add_argument("--graph-launches", type=int, default=10, help="Timed graph replays") + parser.add_argument("--iterations", type=int, default=100, help="Collective operations captured per CUDA graph") + parser.add_argument("--tune-warmup", type=int, default=2) + parser.add_argument("--tune-graph-launches", type=int, default=3) + parser.add_argument("--tune-iterations", type=int, default=20) + parser.add_argument("--candidate-nblocks", help="Comma-separated nblocks tuning candidates") + parser.add_argument("--candidate-nthreads", help="Comma-separated nthreads tuning candidates") + parser.add_argument("--symmetric-memory", action="store_true") + return parser + + +def _validate_args(args: argparse.Namespace) -> None: + for name in ( + "d_model", + "scratch_buffer_size", + "graph_launches", + "iterations", + "tune_graph_launches", + "tune_iterations", + "correctness_iters", + ): + if getattr(args, name) <= 0: + raise ValueError(f"--{name.replace('_', '-')} must be positive") + if args.warmup < 0 or args.tune_warmup < 0: + raise ValueError("warmup counts must be non-negative") + + +def main(argv: list[str] | None = None) -> None: + args = _build_parser().parse_args(argv) + _validate_args(args) + init_runtime() + + local_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL) + try: + visible_devices = cp.cuda.runtime.getDeviceCount() + if visible_devices <= 0: + raise RuntimeError("MSCCL++ benchmark requires at least one visible GPU") + cp.cuda.Device(local_comm.Get_rank() % visible_devices).use() + finally: + local_comm.Free() + + dtype_spec = _with_accum_type(_parse_dtype(args.dtype), args.accum_type) + batch_sizes = _parse_int_list(args.batch_sizes, _DEFAULT_BATCH_SIZES) + candidate_nblocks = _parse_int_list(args.candidate_nblocks, _DEFAULT_CANDIDATE_NBLOCKS) + candidate_nthreads = _parse_int_list(args.candidate_nthreads, _DEFAULT_CANDIDATE_NTHREADS) + + comm_group = _mscclpp().CommGroup(MPI.COMM_WORLD) + setattr(comm_group, "_mpi_comm", MPI.COMM_WORLD) + hardware_profile = _detect_hardware_profile(comm_group.nranks) + config_store = TunedConfigStore.load_path(args.config_path) if args.config_path else TunedConfigStore.empty() + comm = Comm( + comm_group, + config_store=config_store, + hardware_profile=hardware_profile, + scratch_buffer_size=args.scratch_buffer_size, + ) + tuner = OfflineTuner( + comm, + candidate_nblocks=candidate_nblocks, + candidate_nthreads=candidate_nthreads, + n_warmup=args.tune_warmup, + n_graph_launches=args.tune_graph_launches, + n_ops_per_graph=args.tune_iterations, + candidate_algorithms=_candidate_algorithms, + check_correctness=_check_correctness, + measure=_try_measure_case, + ) + + rows: list[list[str]] = [] + try: + if comm.rank == 0: + print( + f"MSCCL++ {args.collective} benchmark: profile={hardware_profile} dtype={dtype_spec.name} " + f"graph_launches={args.graph_launches} iterations={args.iterations}", + flush=True, + ) + + for batch_size in batch_sizes: + nelems = batch_size * args.d_model + case = _make_case( + collective=args.collective, + nelems=nelems, + dtype_spec=dtype_spec, + comm_group=comm_group, + buffer_mode=args.buffer_mode, + symmetric_memory=args.symmetric_memory, + ) + config = tuner.tune(case) if args.autotune else comm.resolve_config(case) + if config is None: + continue + if args.autotune: + config_store.upsert(hardware_profile, args.collective, case.message_size, config) + + correctness = "SKIP" + correctness_stats: CorrectnessStats | None = None + if not args.skip_correctness: + correctness_stats = _check_correctness(comm, case, config, niter=args.correctness_iters) + correctness = "PASS" if correctness_stats else "FAIL" + comm.reset(config) + if correctness != "PASS": + raise RuntimeError( + f"Correctness failed for batch_size={batch_size}, message_size={case.message_size}, " + f"config={config}" + ) + + time_us = _measure_case( + comm, + case, + config, + n_warmup=args.warmup, + n_graph_launches=args.graph_launches, + n_ops_per_graph=args.iterations, + ) + comm.reset(config) + + algbw = _bandwidth_gbps(case.total_size, time_us) + busbw = algbw * _busbw_factor(args.collective, comm_group.nranks) + rows.append( + [ + str(batch_size), + _human_size(case.message_size), + _human_size(case.total_size), + config.algorithm, + str(config.nblocks or "auto"), + str(config.nthreads or "auto"), + f"{time_us:.2f}", + f"{algbw:.2f}", + f"{busbw:.2f}", + correctness, + _format_stat(None if correctness_stats is None else correctness_stats.max_abs_diff), + _format_stat(None if correctness_stats is None else correctness_stats.mean_abs_diff), + _format_mismatches(correctness_stats), + ] + ) + if comm.rank == 0: + print(".", end="", flush=True) + + if args.write_config and comm.rank == 0: + config_store.write_path(args.write_config) + print(f"\nWrote tuned config to {args.write_config}", flush=True) + + if comm.rank == 0: + print( + "\n" + + _format_table( + [ + "batch", + "msg", + "total", + "algorithm", + "nblocks", + "nthreads", + "time_us", + "algBW_GB/s", + "busBW_GB/s", + "check", + "max_diff", + "mean_diff", + "mismatch", + ], + rows, + ), + flush=True, + ) + finally: + comm_group.barrier() + cp.cuda.runtime.deviceSynchronize() + comm.close() + + +if __name__ == "__main__": + main() diff --git a/python/mscclpp_benchmark/comm.py b/python/mscclpp_benchmark/comm.py new file mode 100644 index 000000000..23770ac23 --- /dev/null +++ b/python/mscclpp_benchmark/comm.py @@ -0,0 +1,409 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import logging +from typing import Any + +logger = logging.getLogger(__name__) +_ALLREDUCE_COLLECTIVE = "allreduce" +_ALLGATHER_COLLECTIVE = "allgather" +_mscclpp_module = None + +from mscclpp_benchmark.gpu import current_device, device_name, set_device +from mscclpp_benchmark.tuning_config import HardwareProfile, TunedConfig, TunedConfigStore, normalize_sku + + +def _mscclpp(): + global _mscclpp_module + if _mscclpp_module is None: + import mscclpp + import mscclpp.ext + + _mscclpp_module = mscclpp + return _mscclpp_module + + +class Buffer: + def __init__( + self, + nbytes: int | None = None, + *, + dtype: str | Any = "float16", + shape: tuple[int, ...] | None = None, + buffer: Any | None = None, + ) -> None: + self.dtype = dtype + self.element_size = _dtype_size(dtype) + if buffer is None: + if nbytes is None: + if shape is None: + raise ValueError("Either nbytes or shape is required") + nbytes = _numel(shape) * self.element_size + _ensure_device() + buffer = _mscclpp().RawGpuBuffer(int(nbytes)) + self.buffer = buffer + self.nbytes = int(buffer.bytes()) + self.shape = shape if shape is not None else (self.nbytes // self.element_size,) + + @property + def ndim(self) -> int: + return len(self.shape) + + @property + def size(self) -> int: + return _numel(self.shape) + + def data_ptr(self) -> int: + return int(self.buffer.data()) + + +class _AllReduceOp: + def __init__(self, comm: "Comm", x: Any, *, symmetric_memory: bool = False) -> None: + self._comm = comm + self._x = x + self._symmetric_memory = symmetric_memory + + def __call__(self, **_: Any) -> Any: + self._comm.run(self._x, symmetric_memory=self._symmetric_memory) + return self._x + + +class _AllGatherOp: + def __init__(self, comm: "Comm", x: Any, *, dim: int, y: Any | None = None, symmetric_memory: bool = False) -> None: + shape = _shape(x) + if len(shape) == 0: + raise ValueError("MSCCL++ allgather requires a non-scalar buffer") + if dim % len(shape) != 0: + raise NotImplementedError("Raw-buffer allgather currently supports only dim=0") + if y is None: + y_shape = (comm._scale() * shape[0], *shape[1:]) + y = Buffer(dtype=_dtype(x), shape=y_shape) + self._comm = comm + self._x = x + self.y = y + self._symmetric_memory = symmetric_memory + + def __call__(self, **_: Any) -> Any: + self._comm.run( + self._x, + collective=_ALLGATHER_COLLECTIVE, + output_tensor=self.y, + symmetric_memory=self._symmetric_memory, + ) + return self.y + + +class Comm: + """Runtime MSCCL++ wrapper that owns algorithm handles and execution without Torch/CuPy tensors.""" + + def __init__( + self, + comm_group: Any, + scratch_buffer_size: int = 1 << 27, + *, + config_store: "TunedConfigStore | None" = None, + hardware_profile: HardwareProfile | None = None, + ) -> None: + self._comm_group = comm_group + self._mpi_comm = getattr(comm_group, "_mpi_comm", None) + self._rank = comm_group.my_rank + self._closed = False + _ensure_device() + self._mscclpp = _mscclpp() + self._scratch_buffer = self._mscclpp.RawGpuBuffer(scratch_buffer_size) + self._config_store = TunedConfigStore.empty() if config_store is None else config_store + self._hardware_profile = ( + _detect_hardware_profile(scale=self._scale()) if hardware_profile is None else hardware_profile + ) + self._default_config_warning_keys: set[tuple[str, str, str, int]] = set() + + algorithms = self._mscclpp.ext.AlgorithmCollectionBuilder().build_default_algorithms( + scratch_buffer=self._scratch_buffer.data(), + scratch_buffer_size=self._scratch_buffer.bytes(), + rank=self._rank, + ) + self._algorithms_by_collective: dict[str, dict[str, Any]] = {} + for algorithm in algorithms: + self._algorithms_by_collective.setdefault(algorithm.collective, {})[algorithm.name] = algorithm + + @property + def comm_group(self) -> Any: + return self._comm_group + + @property + def rank(self) -> int: + return self._rank + + @property + def nranks(self) -> int: + return self._comm_group.nranks + + @property + def algorithms(self) -> dict[str, dict[str, Any]]: + return self._algorithms_by_collective + + @property + def hardware_profile(self) -> HardwareProfile: + return self._hardware_profile + + def make_allreduce(self, x: Any, *, symmetric_memory: bool = False) -> _AllReduceOp: + return _AllReduceOp(self, x, symmetric_memory=symmetric_memory) + + def make_allgather(self, x: Any, dim: int, y: Any | None = None, *, symmetric_memory: bool = False) -> _AllGatherOp: + return _AllGatherOp(self, x, dim=dim, y=y, symmetric_memory=symmetric_memory) + + def _scale(self) -> int: + if self._mpi_comm is not None: + return int(self._mpi_comm.Get_size()) + return 1 + + def resolve_config(self, case: Any, *, symmetric_memory: bool = False) -> TunedConfig: + dtype_override = getattr(getattr(case, "dtype_spec", None), "mscclpp_dtype", None) + accum_dtype = getattr(getattr(case, "dtype_spec", None), "accum_dtype", None) or dtype_override + symmetric_memory = symmetric_memory or bool(getattr(case, "symmetric_memory", False)) + return self._resolve_config( + case.collective, + case.input, + dtype_override=dtype_override, + accum_dtype=accum_dtype, + symmetric_memory=symmetric_memory, + ) + + def _resolve_config( + self, + collective: str, + buffer: Any, + *, + dtype_override: Any | None = None, + accum_dtype: Any | None = None, + symmetric_memory: bool = False, + ) -> TunedConfig: + tuned_config = self._config_store.select(self._hardware_profile, collective, _nbytes(buffer)) + if tuned_config is not None and tuned_config.algorithm in self._algorithms_by_collective.get(collective, {}): + return tuned_config + + if self._rank == 0: + dim = int(_shape(buffer)[1]) if len(_shape(buffer)) > 1 else 1 + warning_key = ( + collective, + str(dtype_override if dtype_override is not None else _dtype(buffer)), + str( + accum_dtype + if accum_dtype is not None + else dtype_override if dtype_override is not None else _dtype(buffer) + ), + dim, + ) + if warning_key not in self._default_config_warning_keys: + self._default_config_warning_keys.add(warning_key) + logger.warning( + "MSCCL++ default config: no tuning for collective=%s profile=%s dtype=%s accum=%s dim=%s; perf may be poor", + collective, + self._hardware_profile, + warning_key[1], + warning_key[2], + dim, + ) + return _default_tuned_config( + collective, + _nbytes(buffer), + self._algorithms_by_collective, + symmetric_memory=symmetric_memory, + ) + + def run( + self, + buffer: Any, + config: TunedConfig | None = None, + stream: Any | None = None, + *, + collective: str = _ALLREDUCE_COLLECTIVE, + output_tensor: Any | None = None, + dtype_override: Any | None = None, + accum_dtype: Any | None = None, + symmetric_memory: bool = False, + ) -> int: + if self._closed: + raise RuntimeError("Cannot use a closed MSCCL++ comm") + + raise_on_error = True + if hasattr(buffer, "input") and hasattr(buffer, "output") and hasattr(buffer, "dtype_spec"): + case = buffer + buffer = case.input + output_tensor = case.output + collective = case.collective + dtype_override = case.dtype_spec.mscclpp_dtype + accum_dtype = case.dtype_spec.accum_dtype or dtype_override + symmetric_memory = symmetric_memory or bool(getattr(case, "symmetric_memory", False)) + raise_on_error = False + + if collective not in self._algorithms_by_collective: + raise RuntimeError(f"No supported MSCCL++ {collective} algorithm is available") + + if config is None: + config = self._resolve_config( + collective, + buffer, + dtype_override=dtype_override, + accum_dtype=accum_dtype, + symmetric_memory=symmetric_memory, + ) + symmetric_memory = symmetric_memory or config.symmetric_memory + algorithm = self._algorithms_by_collective[collective][config.algorithm] + output = buffer if output_tensor is None else output_tensor + dtype = dtype_override if dtype_override is not None else _dtype_to_mscclpp(_dtype(buffer)) + accum = accum_dtype if accum_dtype is not None else dtype + ret = algorithm.execute( + comm=self._comm_group.communicator, + input_buffer=_data_ptr(buffer), + output_buffer=_data_ptr(output), + input_size=_nbytes(buffer), + output_size=_nbytes(output), + dtype=dtype, + op=self._mscclpp.ReduceOp.SUM if collective == _ALLREDUCE_COLLECTIVE else self._mscclpp.ReduceOp.NOP, + stream=_stream_ptr(stream), + nblocks=config.nblocks or 0, + nthreads_per_block=config.nthreads or 0, + symmetric_memory=symmetric_memory, + accum_dtype=accum, + ) + if ret != 0 and raise_on_error: + raise RuntimeError(f"MSCCL++ {collective} failed on rank {self._rank} with error code {ret}") + return ret + + def reset(self, config: TunedConfig | None = None) -> None: + if config is not None: + for algorithms_by_name in self._algorithms_by_collective.values(): + algorithm = algorithms_by_name.get(config.algorithm) + if algorithm is not None: + algorithm.reset() + return + for algorithms_by_name in self._algorithms_by_collective.values(): + for algorithm in algorithms_by_name.values(): + algorithm.reset() + + def close(self) -> None: + self.reset() + self._algorithms_by_collective = {} + self._scratch_buffer = None + self._closed = True + self._mscclpp.ext.AlgorithmCollectionBuilder.reset() + + +def _numel(shape: tuple[int, ...]) -> int: + out = 1 + for dim in shape: + out *= int(dim) + return out + + +def _dtype_size(dtype: Any) -> int: + dtype_name = _dtype_name(dtype) + if dtype_name in {"float16", "bfloat16"}: + return 2 + if dtype_name in {"float32", "int32", "uint32"}: + return 4 + if dtype_name in {"uint8", "float8_e4m3b15", "float8_e4m3fn", "float8_e4m3fnuz"}: + return 1 + raise ValueError(f"Unknown data type size for {dtype}") + + +def _dtype_name(dtype: Any) -> str: + if isinstance(dtype, str): + return dtype.strip().lower().replace("-", "_") + name = str(dtype).rsplit(".", 1)[-1] + return name.strip().lower().replace("-", "_") + + +def _dtype_to_mscclpp(dtype: Any) -> Any: + dtype_name = _dtype_name(dtype) + mapping = { + "float16": _mscclpp().DataType.float16, + "float32": _mscclpp().DataType.float32, + "int32": _mscclpp().DataType.int32, + "uint8": _mscclpp().DataType.uint8, + "float8_e4m3b15": _mscclpp().DataType.float8_e4m3b15, + "float8_e4m3fn": _mscclpp().DataType.float8_e4m3fn, + "float8_e4m3fnuz": _mscclpp().DataType.float8_e4m3fnuz, + } + try: + return mapping[dtype_name] + except KeyError as exc: + raise ValueError(f"Unknown data type: {dtype}") from exc + + +def _data_ptr(buffer: Any) -> int: + if hasattr(buffer, "data_ptr"): + data_ptr = buffer.data_ptr + return int(data_ptr() if callable(data_ptr) else data_ptr) + if hasattr(buffer, "data"): + data = buffer.data + if callable(data): + return int(data()) + if hasattr(data, "ptr"): + return int(data.ptr) + raise TypeError(f"Cannot get device pointer from {type(buffer)!r}") + + +def _stream_ptr(stream: Any | None) -> int: + if stream is None: + return 0 + return int(getattr(stream, "ptr", stream)) + + +def _nbytes(buffer: Any) -> int: + if hasattr(buffer, "nbytes"): + return int(buffer.nbytes) + if hasattr(buffer, "bytes"): + value = buffer.bytes + return int(value() if callable(value) else value) + raise TypeError(f"Cannot get byte size from {type(buffer)!r}") + + +def _shape(buffer: Any) -> tuple[int, ...]: + shape = getattr(buffer, "shape", None) + if shape is None: + return (_nbytes(buffer) // _dtype_size(_dtype(buffer)),) + return tuple(int(dim) for dim in shape) + + +def _dtype(buffer: Any) -> Any: + dtype = getattr(buffer, "dtype", None) + if dtype is None: + return "uint8" + return dtype + + +def _detect_hardware_profile(*, scale: int) -> HardwareProfile: + try: + sku = device_name() + except Exception: + sku = "UNKNOWN" + return HardwareProfile(sku=normalize_sku(sku), scale=scale) + + +def _ensure_device() -> None: + set_device(current_device()) + + +def _default_tuned_config( + collective: str, + message_size: int, + algorithms_by_collective: dict[str, dict[str, Any]], + *, + symmetric_memory: bool = False, +) -> TunedConfig: + if collective == _ALLGATHER_COLLECTIVE: + return TunedConfig("default_allgather_fullmesh2", symmetric_memory=symmetric_memory) + available = algorithms_by_collective.get(collective, {}) + if symmetric_memory and _mscclpp().is_nvls_supported() and "default_allreduce_nvls_zero_copy" in available: + return TunedConfig("default_allreduce_nvls_zero_copy", symmetric_memory=True) + if message_size <= 512 * 1024 and "default_allreduce_packet" in available: + return TunedConfig("default_allreduce_packet", symmetric_memory=symmetric_memory) + if "default_allreduce_rsag_zero_copy" in available: + return TunedConfig("default_allreduce_rsag_zero_copy", symmetric_memory=symmetric_memory) + if available: + return TunedConfig(next(iter(available)), symmetric_memory=symmetric_memory) + raise RuntimeError(f"No MSCCL++ algorithm is available for {collective}") diff --git a/python/mscclpp_benchmark/correctness.py b/python/mscclpp_benchmark/correctness.py new file mode 100644 index 000000000..0d9ab5c13 --- /dev/null +++ b/python/mscclpp_benchmark/correctness.py @@ -0,0 +1,402 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any + +import cupy as cp +from mpi4py import MPI + +_mscclpp_module = None + + +def _mscclpp(): + global _mscclpp_module + if _mscclpp_module is None: + import mscclpp + + _mscclpp_module = mscclpp + return _mscclpp_module + + +@dataclass(frozen=True) +class CorrectnessStats: + ok: bool + max_abs_diff: float = 0.0 + mean_abs_diff: float = 0.0 + mismatches: int = 0 + total: int = 0 + + def __bool__(self) -> bool: + return self.ok + + +def config_accum_dtype(case: Any) -> Any: + return case.dtype_spec.accum_dtype or case.dtype_spec.mscclpp_dtype + + +def fill_case_for_benchmark(case: Any, rank: int) -> None: + values = _benchmark_input_values(case, rank) + encoded = _encode_correctness_input(case, values) + if case.collective == "allreduce": + case.input[...] = encoded + return + case.output.fill(0) + case.input[...] = encoded + + +def check_correctness( + comm: Any, + case: Any, + config: Any, + *, + niter: int = 1, +) -> CorrectnessStats: + all_ok = True + local_max_abs_diff = 0.0 + local_sum_abs_diff = 0.0 + local_mismatches = 0 + local_total = 0 + for iteration in range(niter): + _fill_case_for_correctness(case, comm.rank, iteration) + comm.comm_group.barrier() + ret = comm.run(case, config) + cp.cuda.runtime.deviceSynchronize() + comm.comm_group.barrier() + if ret != 0: + all_ok = False + continue + + expected, stats_expected = _expected_outputs(case, comm.nranks, iteration) + iter_stats = _local_diff_stats(case, case.output, expected, comm.nranks, stats_expected=stats_expected) + local_ok = _compare_output(case, case.output, expected, comm.nranks) + all_ok = all_ok and local_ok + local_max_abs_diff = max(local_max_abs_diff, iter_stats.max_abs_diff) + local_sum_abs_diff += iter_stats.mean_abs_diff * iter_stats.total + local_mismatches += iter_stats.mismatches + local_total += iter_stats.total + + if not local_ok: + mismatch = _mismatch_mask(case, case.output, expected, comm.nranks) + print( + "not close: " + f"iter={iteration}, rank={comm.rank}, output={case.output[mismatch][0]}, " + f"expected={expected[mismatch][0]}, max_abs_diff={iter_stats.max_abs_diff:.6g}, " + f"mean_abs_diff={iter_stats.mean_abs_diff:.6g}, mismatches={iter_stats.mismatches}/{iter_stats.total}", + flush=True, + ) + + global_ok = bool(MPI.COMM_WORLD.allreduce(all_ok, op=MPI.LAND)) + global_max_abs_diff = float(MPI.COMM_WORLD.allreduce(local_max_abs_diff, op=MPI.MAX)) + global_sum_abs_diff = float(MPI.COMM_WORLD.allreduce(local_sum_abs_diff, op=MPI.SUM)) + global_mismatches = int(MPI.COMM_WORLD.allreduce(local_mismatches, op=MPI.SUM)) + global_total = int(MPI.COMM_WORLD.allreduce(local_total, op=MPI.SUM)) + global_mean_abs_diff = global_sum_abs_diff / global_total if global_total else 0.0 + return CorrectnessStats( + ok=global_ok, + max_abs_diff=global_max_abs_diff, + mean_abs_diff=global_mean_abs_diff, + mismatches=global_mismatches, + total=global_total, + ) + + +def _fill_case_for_correctness(case: Any, rank: int, iteration: int) -> None: + values = _correctness_input_values(case, rank, iteration) + encoded = _encode_correctness_input(case, values) + if case.collective == "allreduce": + case.input[...] = encoded + return + case.output.fill(0) + case.input[...] = encoded + + +def _correctness_input_values(case: Any, rank: int, iteration: int): + shape = case.input.shape + rng = cp.random.RandomState(_correctness_seed(rank, iteration)) + return _random_input_values(case, rng, shape) + + +def _benchmark_input_values(case: Any, rank: int): + rng = cp.random.RandomState(17_000_003 + rank) + return _random_input_values(case, rng, case.input.shape) + + +def _random_input_values(case: Any, rng, shape): + if case.dtype_spec.fp8_format is not None: + value_range = _fp8_correctness_input_range(case) + return rng.uniform(-value_range, value_range, size=shape).astype(cp.float32) + if case.dtype_spec.cupy_dtype == cp.int32: + return rng.randint(-1, 2, size=shape).astype(cp.int32) + if case.dtype_spec.cupy_dtype == cp.uint8: + return rng.randint(0, 2, size=shape).astype(cp.uint8) + return rng.uniform(-1.0, 1.0, size=shape).astype(cp.float32) + + +def _correctness_seed(rank: int, iteration: int) -> int: + return (iteration + 1) * 1_000_003 + rank + + +def _fp8_correctness_input_range(case: Any) -> float: + if case.collective != "allreduce": + return 1.0 + fp8_format = case.dtype_spec.fp8_format + if fp8_format is None: + return 1.0 + return min(1.0, _fp8_max_abs_value(fp8_format) / max(1, MPI.COMM_WORLD.size)) + + +def _encode_correctness_input(case: Any, values): + if case.dtype_spec.fp8_format is not None: + # FP8 buffers are stored as uint8 raw bytes, so a normal astype(uint8) cast would not produce FP8 bits. + return _encode_fp8_values(case.dtype_spec.fp8_format, values) + return values.astype(case.dtype_spec.cupy_dtype) + + +def _local_diff_stats(case: Any, output, expected, nranks: int, *, stats_expected=None) -> CorrectnessStats: + mismatch = _mismatch_mask(case, output, expected, nranks) + mismatches = int(cp.count_nonzero(mismatch).item()) + total = int(output.size) + if total == 0: + return CorrectnessStats(ok=mismatches == 0) + + output_values = _stats_values(case, output) + expected_values = _stats_values(case, expected) if stats_expected is None else stats_expected.astype(cp.float64) + abs_diff = cp.abs(output_values - expected_values) + return CorrectnessStats( + ok=mismatches == 0, + max_abs_diff=float(cp.max(abs_diff).item()), + mean_abs_diff=float(cp.mean(abs_diff).item()), + mismatches=mismatches, + total=total, + ) + + +def _stats_values(case: Any, values): + # Convert storage buffers into numeric values before computing max/mean diff. + if case.dtype_spec.fp8_format is not None: + return _decode_fp8_array(case.dtype_spec.fp8_format, values) + if cp.issubdtype(values.dtype, cp.floating): + return values.astype(cp.float64) + return values.astype(cp.int64) + + +def _expected_outputs(case: Any, nranks: int, iteration: int): + if case.collective == "allreduce": + encoded_inputs = _encoded_rank_inputs(case, nranks, iteration) + if case.dtype_spec.fp8_format is not None: + stats_expected = _expected_fp8_accum_values(case, encoded_inputs) + return _encode_reduced_output(case, stats_expected), stats_expected + return _encode_reduced_output(case, sum(values.astype(cp.float32) for values in encoded_inputs)), None + + expected = cp.empty_like(case.output) + chunk = case.input.size + for rank, values in enumerate(_encoded_rank_inputs(case, nranks, iteration)): + expected[rank * chunk : (rank + 1) * chunk] = values.reshape(-1) + return expected, None + + +def _encoded_rank_inputs(case: Any, nranks: int, iteration: int) -> list[Any]: + return [_encode_correctness_input(case, _correctness_input_values(case, rank, iteration)) for rank in range(nranks)] + + +def _expected_fp8_accum_values(case: Any, encoded_inputs: list[Any]): + fp8_format = case.dtype_spec.fp8_format + if fp8_format is None: + raise ValueError("FP8 format is required") + + accum_dtype = config_accum_dtype(case) + if accum_dtype == _mscclpp().DataType.float16: + acc = cp.zeros_like(_decode_fp8_array(fp8_format, encoded_inputs[0]), dtype=cp.float16) + for values in encoded_inputs: + acc = (acc + _decode_fp8_array(fp8_format, values).astype(cp.float16)).astype(cp.float16) + return acc.astype(cp.float32) + + if accum_dtype == _mscclpp().DataType.float32: + acc = cp.zeros_like(_decode_fp8_array(fp8_format, encoded_inputs[0]), dtype=cp.float32) + for values in encoded_inputs: + acc += _decode_fp8_array(fp8_format, values).astype(cp.float32) + return acc + + acc = encoded_inputs[0] + for values in encoded_inputs[1:]: + acc = _encode_fp8_values(fp8_format, _decode_fp8_array(fp8_format, acc) + _decode_fp8_array(fp8_format, values)) + return _decode_fp8_array(fp8_format, acc).astype(cp.float32) + + +def _encode_reduced_output(case: Any, values): + if case.dtype_spec.fp8_format is not None: + return _encode_fp8_values(case.dtype_spec.fp8_format, values) + return values.astype(case.output.dtype) + + +def _compare_output(case: Any, output, expected, nranks: int) -> bool: + return bool(cp.all(~_mismatch_mask(case, output, expected, nranks)).item()) + + +def _mismatch_mask(case: Any, output, expected, nranks: int): + tolerance = _comparison_tolerance(case, nranks) + if tolerance is None: + return output != expected + rtol, atol = tolerance + return ~cp.isclose(_stats_values(case, output), _stats_values(case, expected), rtol=rtol, atol=atol) + + +def _comparison_tolerance(case: Any, nranks: int) -> tuple[float, float] | None: + scale = max(1, nranks) if case.collective == "allreduce" else 1 + if case.dtype_spec.fp8_format is not None: + accum_dtype = config_accum_dtype(case) + if accum_dtype == _mscclpp().DataType.float32: + return None + atol = _max_fp8_spacing(case.dtype_spec.fp8_format, float(scale)) + if accum_dtype == _mscclpp().DataType.float16: + return (0.0, atol) + return (0.0, atol * 2) + if case.dtype_spec.cupy_dtype == cp.float16: + return (1.0e-2, 5.0e-4 * scale) + if case.dtype_spec.cupy_dtype == cp.float32: + return (1.0e-5 * scale, 1.0e-6 * scale) + return None + + +_FP8_TABLES: dict[str, list[tuple[int, float]]] = {} +_FP8_LOOKUP_CACHE: dict[str, tuple[Any, Any]] = {} +_FP8_SPACING_CACHE: dict[tuple[str, float], float] = {} + + +def _encode_fp8_values(fp8_format: str, values): + values = values.astype(cp.float32) + if fp8_format == "e4m3b15": + return _encode_e4m3b15_values(values) + + # Round each value to the nearest representable FP8 value (ties to even). + table_values, table_bytes = _fp8_lookup_arrays(fp8_format) + flat_values = values.ravel() + + # For each value find its two surrounding table entries: lower <= value <= upper. + upper = cp.clip(cp.searchsorted(table_values, flat_values), 1, table_values.size - 1) + lower = upper - 1 + + # Pick the closer neighbor; on an exact tie pick the one with an even byte. + dist_to_upper = table_values[upper] - flat_values + dist_to_lower = flat_values - table_values[lower] + upper_is_even = (table_bytes[upper] & cp.uint8(1)) == 0 + pick_upper = (dist_to_upper < dist_to_lower) | ((dist_to_upper == dist_to_lower) & upper_is_even) + + return cp.where(pick_upper, table_bytes[upper], table_bytes[lower]).reshape(values.shape) + + +def _fp8_lookup_arrays(fp8_format: str): + # Cache a sorted (value -> byte) table per format for fast nearest-value lookup. + if fp8_format in _FP8_LOOKUP_CACHE: + return _FP8_LOOKUP_CACHE[fp8_format] + + # Different bytes can decode to the same value (e.g. +0 and -0); keep one byte per value. + byte_for_value: dict[float, int] = {} + for byte, value in _FP8_TABLES.setdefault(fp8_format, _build_fp8_table(fp8_format)): + if value not in byte_for_value or byte < byte_for_value[value]: + byte_for_value[value] = byte + + table = sorted(byte_for_value.items()) + table_values = cp.asarray([value for value, _ in table], dtype=cp.float32) + table_bytes = cp.asarray([byte for _, byte in table], dtype=cp.uint8) + _FP8_LOOKUP_CACHE[fp8_format] = (table_values, table_bytes) + return _FP8_LOOKUP_CACHE[fp8_format] + + +def _max_fp8_spacing(fp8_format: str, max_abs_value: float) -> float: + cache_key = (fp8_format, max_abs_value) + if cache_key in _FP8_SPACING_CACHE: + return _FP8_SPACING_CACHE[cache_key] + + values = sorted( + { + value + for _, value in _FP8_TABLES.setdefault(fp8_format, _build_fp8_table(fp8_format)) + if abs(value) <= max_abs_value + } + ) + if len(values) < 2: + spacing = 0.0 + else: + spacing = max(right - left for left, right in zip(values, values[1:])) + _FP8_SPACING_CACHE[cache_key] = spacing + return spacing + + +def _fp8_max_abs_value(fp8_format: str) -> float: + return max(abs(value) for _, value in _FP8_TABLES.setdefault(fp8_format, _build_fp8_table(fp8_format))) + + +def _encode_e4m3b15_values(values): + # Mirrors the device e4m3b15 encode (gpu_data_types.hpp): clamp the fp16 intermediate + # to 0x3F80 (+/-1.875) so the max encodable byte is 0x7F/0xFF. + fp16_bits = values.astype(cp.float16).view(cp.uint16) + abs_fp16 = fp16_bits & cp.uint16(0x7FFF) + abs_fp16 = cp.minimum(abs_fp16, cp.uint16(0x3F80)).astype(cp.uint32) + sign16 = (fp16_bits & cp.uint16(0x8000)).astype(cp.uint32) + adjusted = abs_fp16 * cp.uint32(2) + cp.uint32(0x0080) + return (((sign16 | adjusted) >> cp.uint32(8)) & cp.uint32(0xFF)).astype(cp.uint8) + + +def _build_fp8_table(fp8_format: str) -> list[tuple[int, float]]: + table = [] + for byte in range(256): + value = _decode_fp8_scalar(fp8_format, byte) + if not math.isnan(value): + table.append((byte, value)) + return table + + +def _decode_fp8_scalar(fp8_format: str, byte: int) -> float: + if fp8_format == "e4m3fnuz" and byte == 0x80: + return float("nan") + sign = -1.0 if byte & 0x80 else 1.0 + return sign * _decode_fp8_positive(fp8_format, byte & 0x7F) + + +def _decode_fp8_positive(fp8_format: str, byte: int) -> float: + exp = (byte >> 3) & 0xF + mant = byte & 0x7 + if fp8_format == "e4m3fn" and exp == 0xF and mant == 0x7: + return float("nan") + if exp == 0 and mant == 0: + return 0.0 + if fp8_format == "e4m3fn": + return math.ldexp(mant / 8.0, -6) if exp == 0 else math.ldexp(1.0 + mant / 8.0, exp - 7) + if fp8_format == "e4m3fnuz": + return math.ldexp(mant / 8.0, -7) if exp == 0 else math.ldexp(1.0 + mant / 8.0, exp - 8) + if fp8_format == "e4m3b15": + return math.ldexp(mant / 8.0, -14) if exp == 0 else math.ldexp(1.0 + mant / 8.0, exp - 15) + raise ValueError(f"Unknown FP8 format: {fp8_format}") + + +def _decode_fp8_array(fp8_format: str, values): + bits = values.astype(cp.int32) + sign = (bits >> 7) & 1 + exp = (bits >> 3) & 0xF + mant = bits & 0x7 + + if fp8_format == "e4m3fn": + subnormal = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-6)) + normal = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), exp.astype(cp.int32) - 7) + decoded = cp.where(exp == 0, subnormal, normal) + decoded = cp.where((exp == 0xF) & (mant == 0x7), cp.nan, decoded) + elif fp8_format == "e4m3fnuz": + subnormal = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-7)) + normal = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), exp.astype(cp.int32) - 8) + decoded = cp.where(exp == 0, subnormal, normal) + elif fp8_format == "e4m3b15": + subnormal = cp.ldexp(mant.astype(cp.float32) / cp.float32(8.0), cp.int32(-14)) + normal = cp.ldexp(cp.float32(1.0) + mant.astype(cp.float32) / cp.float32(8.0), exp.astype(cp.int32) - 15) + decoded = cp.where(exp == 0, subnormal, normal) + else: + raise ValueError(f"Unknown FP8 format: {fp8_format}") + + result = cp.where(sign == 1, -decoded, decoded) + if fp8_format == "e4m3fnuz": + result = cp.where(bits == 0x80, cp.float32(float("nan")), result) + return result diff --git a/python/mscclpp_benchmark/gpu.py b/python/mscclpp_benchmark/gpu.py new file mode 100644 index 000000000..1309a504e --- /dev/null +++ b/python/mscclpp_benchmark/gpu.py @@ -0,0 +1,187 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +_API_NAMES = { + "get_device_count": ("hipGetDeviceCount", "cudaGetDeviceCount"), + "get_device": ("hipGetDevice", "cudaGetDevice"), + "get_device_properties": ("hipGetDeviceProperties", "cudaGetDeviceProperties"), + "set_device": ("hipSetDevice", "cudaSetDevice"), + "stream_begin_capture": ("hipStreamBeginCapture", "cudaStreamBeginCapture"), + "stream_end_capture": ("hipStreamEndCapture", "cudaStreamEndCapture"), + "graph_instantiate": ("hipGraphInstantiate", "cudaGraphInstantiate"), + "graph_launch": ("hipGraphLaunch", "cudaGraphLaunch"), + "graph_destroy": ("hipGraphDestroy", "cudaGraphDestroy"), + "graph_exec_destroy": ("hipGraphExecDestroy", "cudaGraphExecDestroy"), + "get_error_string": ("hipGetErrorString", "cudaGetErrorString"), +} + + +@dataclass(frozen=True) +class _Runtime: + name: str + success: Any + capture_mode_relaxed: Any + funcs: dict[str, Callable[..., Any] | None] + + @classmethod + def create(cls, name: str, module: Any, success: Any, capture_mode_relaxed: Any) -> "_Runtime": + index = 0 if name == "hip" else 1 + funcs = { + attr: (None if names[index] is None else getattr(module, names[index])) + for attr, names in _API_NAMES.items() + } + return cls(name=name, success=success, capture_mode_relaxed=capture_mode_relaxed, funcs=funcs) + + def call(self, name: str, *args: Any) -> tuple[Any, ...]: + fn = self.funcs[name] + if fn is None: + raise RuntimeError(f"{name} is not available for {self.name}") + result = fn(*args) + if not isinstance(result, tuple): + result = (result,) + self.check(result[0], name) + return result[1:] + + def check(self, error: Any, api: str) -> None: + if error == self.success: + return + result = self.funcs["get_error_string"](error) + if not isinstance(result, tuple): + result = (result,) + err, message = result + if err != self.success: + raise RuntimeError(f"{api} failed with error {int(error)}") + decoded = message.decode("utf-8") if isinstance(message, bytes) else str(message) + raise RuntimeError(f"{api} failed: {decoded} ({int(error)})") + + +def _load_runtime() -> _Runtime: + errors: list[str] = [] + + try: + from hip import hip + + runtime = _Runtime.create( + name="hip", + module=hip, + success=hip.hipError_t.hipSuccess, + capture_mode_relaxed=hip.hipStreamCaptureMode.hipStreamCaptureModeRelaxed, + ) + count = runtime.call("get_device_count")[0] + if count and count > 0: + return runtime + errors.append(f"hipGetDeviceCount returned count={count}") + except ImportError as exc: + errors.append(f"hip-python unavailable: {exc}") + + try: + from cuda.bindings import runtime as cuda_runtime + + runtime = _Runtime.create( + name="cuda", + module=cuda_runtime, + success=cuda_runtime.cudaError_t.cudaSuccess, + capture_mode_relaxed=cuda_runtime.cudaStreamCaptureMode.cudaStreamCaptureModeRelaxed, + ) + count = runtime.call("get_device_count")[0] + if count and count > 0: + return runtime + errors.append(f"cudaGetDeviceCount returned count={count}") + except ImportError as exc: + errors.append(f"cuda-bindings unavailable: {exc}") + + raise RuntimeError("No usable CUDA/HIP Python runtime found: " + "; ".join(errors)) + + +_RUNTIME = _load_runtime() + + +class Graph: + def __init__(self, graph_exec: Any) -> None: + self._graph_exec = graph_exec + + def launch(self, stream: Any) -> None: + _api("graph_launch")(self._graph_exec, _stream_ptr(stream)) + + def close(self) -> None: + if self._graph_exec is not None: + _api("graph_exec_destroy")(self._graph_exec) + self._graph_exec = None + + +def init_runtime() -> None: + return None + + +def capture_graph(stream: Any, capture_fn: Callable[[], None]) -> Graph: + _api("set_device")(current_device()) + stream_ptr = _stream_ptr(stream) + _api("stream_begin_capture")(stream_ptr, _RUNTIME.capture_mode_relaxed) + + graph = None + try: + capture_fn() + graph = _api("stream_end_capture")(stream_ptr)[0] + except Exception: + try: + _api("stream_end_capture")(stream_ptr) + except Exception: + pass + raise + + try: + graph_exec = _instantiate_graph(graph) + return Graph(graph_exec) + finally: + if graph is not None: + _api("graph_destroy")(graph) + + +def current_device() -> int: + return int(_api("get_device")()[0]) + + +def device_name(device_id: int | None = None) -> str: + if device_id is None: + device_id = current_device() + prop = _api("get_device_properties")(int(device_id))[0] + name = getattr(prop, "name", "UNKNOWN") + return name.decode("utf-8") if isinstance(name, bytes) else str(name) + + +def _stream_ptr(stream: Any) -> int: + return int(getattr(stream, "ptr", stream)) + + +def _instantiate_graph(graph: Any) -> Any: + if _RUNTIME.name == "hip": + return _api("graph_instantiate")(graph, None, 0)[0] + return _api("graph_instantiate")(graph, 0)[0] + + +def _api(name: str) -> Callable[..., tuple[Any, ...]]: + api = globals().get(name) + if api is None: + api = __getattr__(name) + return api + + +def _make_api(name: str) -> Callable[..., tuple[Any, ...]]: + def api(*args: Any) -> tuple[Any, ...]: + return _RUNTIME.call(name, *args) + + api.__name__ = name + return api + + +def __getattr__(name: str) -> Callable[..., tuple[Any, ...]]: + if name in _API_NAMES: + api = _make_api(name) + globals()[name] = api + return api + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/python/mscclpp_benchmark/tuner.py b/python/mscclpp_benchmark/tuner.py new file mode 100644 index 000000000..8df3259b9 --- /dev/null +++ b/python/mscclpp_benchmark/tuner.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Any, Callable, Iterable + +from mscclpp_benchmark.tuning_config import TunedConfig + + +class OfflineTuner: + def __init__( + self, + comm: Any, + *, + candidate_nblocks: Iterable[int], + candidate_nthreads: Iterable[int], + n_warmup: int, + n_graph_launches: int, + n_ops_per_graph: int, + candidate_algorithms: Callable[[Any, Any], list[tuple[Any, Any]]], + check_correctness: Callable[..., bool], + measure: Callable[..., float | None], + ) -> None: + self.comm = comm + self.candidate_nblocks = tuple(candidate_nblocks) + self.candidate_nthreads = tuple(candidate_nthreads) + self.n_warmup = n_warmup + self.n_graph_launches = n_graph_launches + self.n_ops_per_graph = n_ops_per_graph + self._candidate_algorithms = candidate_algorithms + self._check_correctness = check_correctness + self._measure = measure + + def tune(self, case: Any) -> TunedConfig | None: + best_config: TunedConfig | None = None + best_time_us = float("inf") + symmetric_memory = bool(getattr(case, "symmetric_memory", False)) + candidates = self._candidate_algorithms(self.comm, case) + if not candidates: + if self.comm.rank == 0: + print( + f"[skip] no supported tuning candidates for collective={case.collective} " + f"size={case.message_size}", + flush=True, + ) + return None + for algorithm, candidate_spec in candidates: + for nblocks in self.candidate_nblocks: + if candidate_spec.max_nblocks is not None and nblocks > candidate_spec.max_nblocks: + continue + for nthreads in self.candidate_nthreads: + config = TunedConfig( + algorithm=algorithm.name, + nblocks=nblocks, + nthreads=nthreads, + symmetric_memory=symmetric_memory, + ) + if not self._check_correctness(self.comm, case, config): + self.comm.reset(config) + continue + self.comm.reset(config) + time_us = self._measure( + self.comm, + case, + config, + n_warmup=self.n_warmup, + n_graph_launches=self.n_graph_launches, + n_ops_per_graph=self.n_ops_per_graph, + ) + self.comm.reset(config) + if time_us is None or time_us >= best_time_us: + continue + best_time_us = time_us + best_config = TunedConfig( + algorithm=algorithm.name, + nblocks=nblocks, + nthreads=nthreads, + symmetric_memory=symmetric_memory, + time_us=time_us, + ) + if best_config is None: + return self.comm.resolve_config(case) + return best_config diff --git a/python/mscclpp_benchmark/tuning_config.py b/python/mscclpp_benchmark/tuning_config.py new file mode 100644 index 000000000..2a914ec95 --- /dev/null +++ b/python/mscclpp_benchmark/tuning_config.py @@ -0,0 +1,242 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import json +import re +from bisect import bisect_left +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +_KNOWN_GPU_SKUS = ("GB300", "MI300X", "H100", "A100") + + +@dataclass(frozen=True) +class HardwareProfile: + sku: str | None = None + scale: int | None = None + + +@dataclass(frozen=True) +class TunedConfig: + algorithm: str + nblocks: int | None = None + nthreads: int | None = None + symmetric_memory: bool = False + time_us: float | None = None + + +@dataclass(order=True, frozen=True) +class TunedConfigBySize: + message_size: int + config: TunedConfig + + +class TunedConfigStore: + def __init__(self, profiles: dict[HardwareProfile, dict[str, list[TunedConfigBySize]]]) -> None: + self._profiles = profiles + + @classmethod + def empty(cls) -> "TunedConfigStore": + return cls({}) + + @classmethod + def load_path(cls, path: str | Path) -> "TunedConfigStore": + with Path(path).open("r", encoding="utf-8") as handle: + return cls.from_payload(json.load(handle)) + + @classmethod + def from_payload(cls, payload: Any) -> "TunedConfigStore": + if not isinstance(payload, dict): + raise ValueError("MSCCL++ tuned config must be a JSON object") + raw_profiles = payload.get("profiles") + if not isinstance(raw_profiles, list): + raise ValueError("MSCCL++ tuned config must contain a 'profiles' list") + profiles: dict[HardwareProfile, dict[str, list[TunedConfigBySize]]] = {} + for raw_profile in raw_profiles: + profile = _profile_from_payload(raw_profile) + profiles[profile] = _configs_by_collective_from_payload(raw_profile.get("collectives", {})) + return cls(profiles) + + def select(self, profile: HardwareProfile, collective: str, message_size: int) -> TunedConfig | None: + for _, configs_by_collective in _matching_profiles(self._profiles, profile): + config = _select_config(configs_by_collective, collective, message_size) + if config is not None: + return config + return None + + def upsert(self, profile: HardwareProfile, collective: str, message_size: int, config: TunedConfig) -> None: + configs = self._profiles.setdefault(profile, {}).setdefault(collective, []) + for index, existing in enumerate(configs): + if existing.message_size == message_size: + configs[index] = TunedConfigBySize(message_size, config) + break + else: + configs.append(TunedConfigBySize(message_size, config)) + configs.sort(key=lambda item: item.message_size) + + def write_path(self, path: str | Path) -> None: + profiles_payload: list[dict[str, Any]] = [] + for profile, configs_by_collective in sorted( + self._profiles.items(), + key=lambda item: (item[0].sku is None, item[0].sku or "", item[0].scale is None, item[0].scale or 0), + ): + collectives: dict[str, list[dict[str, Any]]] = {} + for collective, configs in sorted(configs_by_collective.items()): + collectives[collective] = [_config_entry_payload(item) for item in sorted(configs)] + profile_payload: dict[str, Any] = {} + if profile.sku is not None: + profile_payload["sku"] = profile.sku + if profile.scale is not None: + profile_payload["scale"] = profile.scale + profile_payload["collectives"] = collectives + profiles_payload.append(profile_payload) + + with Path(path).open("w", encoding="utf-8") as handle: + handle.write(_format_tuned_config_json({"version": 1, "profiles": profiles_payload})) + + +def normalize_sku(raw_sku: str) -> str: + upper_sku = raw_sku.upper() + for known_sku in _KNOWN_GPU_SKUS: + if known_sku in upper_sku: + return known_sku + normalized = re.sub(r"[^A-Z0-9]+", "_", upper_sku).strip("_") + return normalized or "UNKNOWN" + + +def _profile_from_payload(raw_profile: Any) -> HardwareProfile: + if not isinstance(raw_profile, dict): + raise ValueError(f"Invalid tuned config profile: {raw_profile!r}") + raw_sku = raw_profile.get("sku") + return HardwareProfile( + sku=None if raw_sku is None else normalize_sku(str(raw_sku)), + scale=_optional_positive_int(raw_profile.get("scale"), "scale"), + ) + + +def _matching_profiles( + profiles: dict[HardwareProfile, dict[str, list[TunedConfigBySize]]], + runtime_profile: HardwareProfile, +) -> list[tuple[int, dict[str, list[TunedConfigBySize]]]]: + matches: list[tuple[int, dict[str, list[TunedConfigBySize]]]] = [] + for profile, configs_by_collective in profiles.items(): + specificity = _profile_match_specificity(profile, runtime_profile) + if specificity is not None: + matches.append((specificity, configs_by_collective)) + return sorted(matches, key=lambda item: item[0], reverse=True) + + +def _profile_match_specificity(profile: HardwareProfile, runtime_profile: HardwareProfile) -> int | None: + specificity = 0 + if profile.sku is not None: + if profile.sku != runtime_profile.sku: + return None + specificity += 1 + if profile.scale is not None: + if profile.scale != runtime_profile.scale: + return None + specificity += 1 + return specificity + + +def _select_config( + configs_by_collective: dict[str, list[TunedConfigBySize]], collective: str, message_size: int +) -> TunedConfig | None: + configs = configs_by_collective.get(collective, []) + if not configs: + return None + sizes = [item.message_size for item in configs] + index = bisect_left(sizes, message_size) + if index == len(sizes): + return configs[-1].config + if sizes[index] == message_size or index == 0: + return configs[index].config + return configs[index - 1].config + + +def _configs_by_collective_from_payload(payload: Any) -> dict[str, list[TunedConfigBySize]]: + if not isinstance(payload, dict): + raise ValueError("MSCCL++ tuned config collectives must be an object") + + result: dict[str, list[TunedConfigBySize]] = {} + for collective, raw_entries in payload.items(): + if isinstance(raw_entries, dict): + raw_entries = raw_entries.get("configs", []) + if not isinstance(raw_entries, list): + continue + configs = [] + for raw_entry in raw_entries: + if not isinstance(raw_entry, dict): + raise ValueError(f"Invalid tuned config entry for {collective}: {raw_entry!r}") + configs.append( + TunedConfigBySize( + message_size=_parse_positive_int(raw_entry.get("message_size"), "message_size"), + config=TunedConfig( + algorithm=str(raw_entry["algorithm"]), + nblocks=_optional_int(raw_entry.get("nblocks")), + nthreads=_optional_int(raw_entry.get("nthreads")), + symmetric_memory=_optional_bool(raw_entry.get("symmetric_memory", False)), + time_us=_optional_float(raw_entry.get("time_us")), + ), + ) + ) + result[str(collective)] = sorted(configs) + return result + + +def _config_entry_payload(item: TunedConfigBySize) -> dict[str, Any]: + payload: dict[str, Any] = {"message_size": item.message_size, "algorithm": item.config.algorithm} + if item.config.nblocks is not None: + payload["nblocks"] = item.config.nblocks + if item.config.nthreads is not None: + payload["nthreads"] = item.config.nthreads + if item.config.symmetric_memory: + payload["symmetric_memory"] = item.config.symmetric_memory + if item.config.time_us is not None: + payload["time_us"] = item.config.time_us + return payload + + +def _format_tuned_config_json(payload: dict[str, Any]) -> str: + text = json.dumps(payload, indent=2) + pattern = re.compile( + r"(?m)^(?P +)\{\n" + r'(?P(?P=indent) "message_size": [^\n]+,?\n(?:(?P=indent) "[^"]+": [^\n]+,?\n)*)' + r"(?P=indent)\}(?P,?)$" + ) + + def compact(match: re.Match[str]) -> str: + body = " ".join(line.strip() for line in match.group("body").splitlines()) + return f"{match.group('indent')}{{{body}}}{match.group('comma')}" + + return pattern.sub(compact, text) + "\n" + + +def _optional_int(value: Any | None) -> int | None: + return None if value is None else int(value) + + +def _optional_float(value: Any | None) -> float | None: + return None if value is None else float(value) + + +def _optional_positive_int(value: Any | None, name: str) -> int | None: + return None if value is None else _parse_positive_int(value, name) + + +def _optional_bool(value: Any | None) -> bool | None: + if value is None: + return None + if isinstance(value, bool): + return value + raise ValueError(f"Expected boolean value, got {value!r}") + + +def _parse_positive_int(value: Any, name: str) -> int: + parsed = int(value) + if parsed <= 0: + raise ValueError(f"{name} must be positive, got {parsed}") + return parsed diff --git a/python/requirements_cuda11.txt b/python/requirements_cuda11.txt index a97860713..1f575f673 100644 --- a/python/requirements_cuda11.txt +++ b/python/requirements_cuda11.txt @@ -1,5 +1,6 @@ mpi4py cupy-cuda11x +cuda-bindings>=11.8,<12 prettytable netifaces pytest diff --git a/python/requirements_cuda12.txt b/python/requirements_cuda12.txt index 715727149..fcc59660a 100644 --- a/python/requirements_cuda12.txt +++ b/python/requirements_cuda12.txt @@ -1,5 +1,6 @@ mpi4py cupy-cuda12x +cuda-bindings>=12,<13 prettytable netifaces pytest diff --git a/python/requirements_cuda13.txt b/python/requirements_cuda13.txt index 95e99533a..19ad93d70 100644 --- a/python/requirements_cuda13.txt +++ b/python/requirements_cuda13.txt @@ -1,5 +1,6 @@ mpi4py cupy-cuda13x +cuda-bindings>=13,<14 prettytable netifaces pytest diff --git a/python/requirements_rocm6.txt b/python/requirements_rocm6.txt index 757d4e262..bcc22dfb7 100644 --- a/python/requirements_rocm6.txt +++ b/python/requirements_rocm6.txt @@ -7,4 +7,5 @@ numpy matplotlib sortedcontainers blake3 -pybind11 \ No newline at end of file +pybind11 +hip-python>=6,<7 \ No newline at end of file diff --git a/python/test/test_fp8_accum.py b/python/test/test_fp8_accum.py index ba33c085b..554e131a9 100644 --- a/python/test/test_fp8_accum.py +++ b/python/test/test_fp8_accum.py @@ -167,7 +167,7 @@ def float_to_e4m3fnuz(f32_array, chunk_size=65536): # --------------------------------------------------------------------------- -# FP8 E4M3B15 helpers (bias=15, encode saturates to ±1.75, no NaN) +# FP8 E4M3B15 helpers (bias=15, float source saturates to ±1.875, no NaN) # Matches Triton's fp8e4b15: all 256 bit patterns are finite. # --------------------------------------------------------------------------- @@ -193,7 +193,7 @@ def float_to_e4m3b15(f32_array, chunk_size=65536): """Encode a cupy float32 array to uint8 E4M3B15 bit patterns. Same lookup-table approach as float_to_e4m3fn. - Saturates to ±1.75 (0x7e/0xfe), matching Triton's fp8e4b15. + Saturates to ±1.875 (0x7f/0xff), matching the device float32 → e4m3b15 path. """ # Build lookup table of all 128 positive E4M3B15 values (0x00..0x7F) all_bytes = cp.arange(128, dtype=cp.uint8) @@ -203,7 +203,7 @@ def float_to_e4m3b15(f32_array, chunk_size=65536): values = f32_array.astype(cp.float32) signs = cp.signbit(values).astype(cp.uint8) absval = cp.abs(values) - absval = cp.clip(absval, cp.float32(0.0), cp.float32(1.75)) + absval = cp.clip(absval, cp.float32(0.0), cp.float32(1.875)) result = cp.zeros(absval.shape, dtype=cp.uint8) n = absval.size @@ -442,8 +442,8 @@ def test_fp8_e4m3b15_accum(mpi_group: MpiGroup, algo_name: str, size: int): bits_r = cp.asarray(rng_r.randint(0, 256, (size,)).astype(np.uint8)) ref_f32 += e4m3b15_to_float(bits_r) - # Clamp reference to e4m3b15 representable range - ref_f32 = cp.clip(ref_f32, -1.75, 1.75) + # Clamp reference to e4m3b15 representable range (float source saturates at ±1.875) + ref_f32 = cp.clip(ref_f32, -1.875, 1.875) # Compute errors abs_err = cp.abs(result_f32 - ref_f32) diff --git a/src/ext/collectives/allgather/allgather_fullmesh_2.cu b/src/ext/collectives/allgather/allgather_fullmesh_2.cu index 72a2be9d9..3500c0c46 100644 --- a/src/ext/collectives/allgather/allgather_fullmesh_2.cu +++ b/src/ext/collectives/allgather/allgather_fullmesh_2.cu @@ -8,7 +8,6 @@ namespace mscclpp { namespace collective { -__device__ DeviceSyncer deviceSyncer; template __global__ void __launch_bounds__(1024, 1) allgatherFullmesh2(void* sendbuff, mscclpp::DeviceHandle* memoryChannels, @@ -21,9 +20,7 @@ __global__ void __launch_bounds__(1024, 1) // Round down to multiple of peer count. const size_t nThread = (blockDim.x * gridDim.x) / WARP_SIZE / nPeer * nPeer * WARP_SIZE; - if (tid >= nThread) { - return; - } + bool isWorker = tid < nThread; const size_t nWarp = nThread / WARP_SIZE; const size_t chanOffset = nPeer * blockIdx.x; auto memChans = memoryChannels + chanOffset; @@ -34,76 +31,80 @@ __global__ void __launch_bounds__(1024, 1) } __syncthreads(); - const size_t bytesPerGPU = nelemsPerGPU * sizeof(int); - const size_t bytes = bytesPerGPU * nPeer; - size_t unitBytesPerThread; - if (bytes >= nThread * 64) { - unitBytesPerThread = 64; - } else { - unitBytesPerThread = 16; - } - const size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE; - const size_t unitBytes = unitBytesPerWarp * nWarp; - const size_t nLoop = bytes / unitBytes; - - if (nLoop > 0) { - // First loop unrolling - const size_t peerIdx = wid % nPeer; - const size_t offset = bytesPerGPU * rank + (wid / nPeer) * unitBytesPerWarp; - if constexpr (IsOutOfPlace) { - char* dst = reinterpret_cast(memChans[peerIdx].dst_); - char* src = reinterpret_cast(memChans[peerIdx].src_); - char* buff = reinterpret_cast(sendbuff); - const size_t offsetWithinRank = (wid / nPeer) * unitBytesPerWarp; - mscclpp::copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid, - WARP_SIZE); - mscclpp::copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid, - WARP_SIZE); + if (isWorker) { + const size_t bytesPerGPU = nelemsPerGPU * sizeof(int); + const size_t bytes = bytesPerGPU * nPeer; + size_t unitBytesPerThread; + if (bytes >= nThread * 64) { + unitBytesPerThread = 64; } else { - memChans[peerIdx].put<16, false>(offset + channelOutOffset, unitBytesPerWarp, lid, WARP_SIZE); + unitBytesPerThread = 16; } - } - - for (size_t i = 1; i < nLoop; ++i) { - const size_t gWid = wid + i * nWarp; - const size_t peerIdx = gWid % nPeer; - const size_t offset = bytesPerGPU * rank + (gWid / nPeer) * unitBytesPerWarp; - if constexpr (IsOutOfPlace) { - char* dst = reinterpret_cast(memChans[peerIdx].dst_); - char* src = reinterpret_cast(memChans[peerIdx].src_); - char* buff = reinterpret_cast(sendbuff); - const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp; - mscclpp::copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid, - WARP_SIZE); - mscclpp::copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid, - WARP_SIZE); - } else { - memChans[peerIdx].put<16, false>(offset + channelOutOffset, unitBytesPerWarp, lid, WARP_SIZE); + const size_t unitBytesPerWarp = unitBytesPerThread * WARP_SIZE; + const size_t unitBytes = unitBytesPerWarp * nWarp; + const size_t nLoop = bytes / unitBytes; + + if (nLoop > 0) { + // First loop unrolling + const size_t peerIdx = wid % nPeer; + const size_t offset = bytesPerGPU * rank + (wid / nPeer) * unitBytesPerWarp; + if constexpr (IsOutOfPlace) { + char* dst = reinterpret_cast(memChans[peerIdx].dst_); + char* src = reinterpret_cast(memChans[peerIdx].src_); + char* buff = reinterpret_cast(sendbuff); + const size_t offsetWithinRank = (wid / nPeer) * unitBytesPerWarp; + mscclpp::copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid, + WARP_SIZE); + mscclpp::copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid, + WARP_SIZE); + } else { + memChans[peerIdx].put<16, false>(offset + channelOutOffset, unitBytesPerWarp, lid, WARP_SIZE); + } } - } - if (bytes % unitBytes > 0) { - const size_t gWid = wid + nLoop * nWarp; - const size_t peerIdx = gWid % nPeer; - const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp; - const size_t offset = bytesPerGPU * rank + offsetWithinRank; - const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU) - ? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0) - : unitBytesPerWarp; - if (remainBytes > 0) { + for (size_t i = 1; i < nLoop; ++i) { + const size_t gWid = wid + i * nWarp; + const size_t peerIdx = gWid % nPeer; + const size_t offset = bytesPerGPU * rank + (gWid / nPeer) * unitBytesPerWarp; if constexpr (IsOutOfPlace) { char* dst = reinterpret_cast(memChans[peerIdx].dst_); char* src = reinterpret_cast(memChans[peerIdx].src_); char* buff = reinterpret_cast(sendbuff); - mscclpp::copy<16, true>(src + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid, WARP_SIZE); - mscclpp::copy<16, true>(dst + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid, WARP_SIZE); + const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp; + mscclpp::copy<16, false>(src + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid, + WARP_SIZE); + mscclpp::copy<16, false>(dst + offset + channelOutOffset, buff + offsetWithinRank, unitBytesPerWarp, lid, + WARP_SIZE); } else { - memChans[peerIdx].put<16, true>(offset + channelOutOffset, remainBytes, lid, WARP_SIZE); + memChans[peerIdx].put<16, false>(offset + channelOutOffset, unitBytesPerWarp, lid, WARP_SIZE); + } + } + + if (bytes % unitBytes > 0) { + const size_t gWid = wid + nLoop * nWarp; + const size_t peerIdx = gWid % nPeer; + const size_t offsetWithinRank = (gWid / nPeer) * unitBytesPerWarp; + const size_t offset = bytesPerGPU * rank + offsetWithinRank; + const size_t remainBytes = (offsetWithinRank + unitBytesPerWarp > bytesPerGPU) + ? ((bytesPerGPU > offsetWithinRank) ? (bytesPerGPU - offsetWithinRank) : 0) + : unitBytesPerWarp; + if (remainBytes > 0) { + if constexpr (IsOutOfPlace) { + char* dst = reinterpret_cast(memChans[peerIdx].dst_); + char* src = reinterpret_cast(memChans[peerIdx].src_); + char* buff = reinterpret_cast(sendbuff); + mscclpp::copy<16, true>(src + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid, + WARP_SIZE); + mscclpp::copy<16, true>(dst + offset + channelOutOffset, buff + offsetWithinRank, remainBytes, lid, + WARP_SIZE); + } else { + memChans[peerIdx].put<16, true>(offset + channelOutOffset, remainBytes, lid, WARP_SIZE); + } } } } - deviceSyncer.sync(gridDim.x); + __syncthreads(); if (threadIdx.x < nPeer) { memChans[threadIdx.x].signal(); diff --git a/src/ext/collectives/allreduce/allreduce_packet.cu b/src/ext/collectives/allreduce/allreduce_packet.cu index 3c75a746d..414c2b1fc 100644 --- a/src/ext/collectives/allreduce/allreduce_packet.cu +++ b/src/ext/collectives/allreduce/allreduce_packet.cu @@ -240,6 +240,13 @@ CommResult AllreducePacket::allreduceKernelFunc(const std::shared_ptr ctx_ maxBlockNum_, "."); return CommResult::CommInvalidArgument; } + const int nPeers = ctx->nRanksPerNode - 1; + if (blockAndThreadNum.first < nPeers) { + WARN(ALGO, + "AllreducePacket requires block number to be at least peer count, but got nBlocks=", blockAndThreadNum.first, + " and nPeers=", nPeers, "."); + return CommResult::CommInvalidArgument; + } size_t sendBytes; CUdeviceptr sendBasePtr; diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index a345effcb..0d075ac8d 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -14,5 +14,6 @@ target_sources(unit_tests PRIVATE utils_tests.cc utils_internal_tests.cc compile_tests.cu + gpu_data_types_tests.cu local_channel_tests.cu ) diff --git a/test/unit/gpu_data_types_tests.cu b/test/unit/gpu_data_types_tests.cu new file mode 100644 index 000000000..5f91c684e --- /dev/null +++ b/test/unit/gpu_data_types_tests.cu @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include +#include + +#include "../framework.hpp" + +namespace { + +constexpr int kConversionPaths = 3; + +template +std::array makeArray(Args... args) { + return {static_cast(args)...}; +} + +__device__ uint32_t floatToBitsDevice(float value) { + union { + float f; + uint32_t u; + } cvt = {value}; + return cvt.u; +} + +uint32_t floatToBitsHost(float value) { + uint32_t bits; + std::memcpy(&bits, &value, sizeof(bits)); + return bits; +} + +__global__ void kernelE4m3b15TypeConvert(const float* input, int encodeCases, const uint8_t* raw, int decodeCases, + uint8_t* encoded, uint32_t* decodedBits) { + if (threadIdx.x != 0 || blockIdx.x != 0) return; + + for (int offset = 0; offset < encodeCases; offset += 4) { + mscclpp::f32x4 inputX4; + for (int i = 0; i < 4; ++i) { + inputX4.data[i] = input[offset + i]; + } + + mscclpp::f8_e4m3b15x4 encodedX4 = mscclpp::to(inputX4); + for (int i = 0; i < 4; ++i) { + encoded[offset + i] = encodedX4.data[i].__x; + } + + for (int pair = 0; pair < 2; ++pair) { + mscclpp::f32x2 inputX2; + inputX2.data[0] = input[offset + pair * 2]; + inputX2.data[1] = input[offset + pair * 2 + 1]; + mscclpp::f8_e4m3b15x2 encodedX2 = mscclpp::to(inputX2); + encoded[encodeCases + offset + pair * 2] = encodedX2.data[0].__x; + encoded[encodeCases + offset + pair * 2 + 1] = encodedX2.data[1].__x; + } + } + + for (int i = 0; i < encodeCases; ++i) { + encoded[2 * encodeCases + i] = __fp8_e4m3b15(input[i]).__x; + } + + for (int offset = 0; offset < decodeCases; offset += 4) { + mscclpp::f8_e4m3b15x4 rawX4; + for (int i = 0; i < 4; ++i) { + rawX4.data[i] = __fp8_e4m3b15::fromRaw(raw[offset + i]); + } + + mscclpp::f32x4 decodedX4 = mscclpp::to(rawX4); + for (int i = 0; i < 4; ++i) { + decodedBits[offset + i] = floatToBitsDevice(decodedX4.data[i]); + } + + for (int pair = 0; pair < 2; ++pair) { + mscclpp::f8_e4m3b15x2 rawX2; + rawX2.data[0] = __fp8_e4m3b15::fromRaw(raw[offset + pair * 2]); + rawX2.data[1] = __fp8_e4m3b15::fromRaw(raw[offset + pair * 2 + 1]); + mscclpp::f32x2 decodedX2 = mscclpp::to(rawX2); + decodedBits[decodeCases + offset + pair * 2] = floatToBitsDevice(decodedX2.data[0]); + decodedBits[decodeCases + offset + pair * 2 + 1] = floatToBitsDevice(decodedX2.data[1]); + } + } + + for (int i = 0; i < decodeCases; ++i) { + decodedBits[2 * decodeCases + i] = floatToBitsDevice(float(__fp8_e4m3b15::fromRaw(raw[i]))); + } +} + +} // namespace + +TEST(GpuDataTypesTest, E4m3b15TypeConvert) { + const float inf = std::numeric_limits::infinity(); + const float nan = std::numeric_limits::quiet_NaN(); + const float maxFloat = std::numeric_limits::max(); + + // Each input value maps to the byte at the same index in expectedEncoded. The fp8_e4m3b15 format has no + // NaN/Inf encoding, so NaN, Inf, and overflow inputs saturate to +/-1.875 (max byte 0x7f/0xff). + const auto input = makeArray(0.0f, -0.0f, // +/-0 + 0x1.0p-19f, -0x1.0p-19f, // +/-2^-19: underflows to signed 0 + 0x1.0p-18f, -0x1.0p-18f, // +/-2^-18: rounds to min subnormal + 0x1.0p-17f, -0x1.0p-17f, // +/-2^-17: min subnormal + 0x1.0p-14f, -0x1.0p-14f, // +/-2^-14: min normal + 0x1.0fcp-2f, -0x1.0fcp-2f, // Boundary rounds down in magnitude + 0x1.0fep-2f, -0x1.0fep-2f, // Boundary rounds up in magnitude + 0x1.cfep-2f, -0x1.cfep-2f, // Boundary rounds to +/-0.46875 + 0x1.cp0f, -0x1.cp0f, // +/-1.75: max finite + 2.0f, -2.0f, // Overflow saturation + inf, -inf, // +/-Inf saturation + nan, -maxFloat); // NaN / large negative saturation + + const auto expectedEncoded = makeArray(0x00, 0x80, // +/-0 + 0x00, 0x80, // Underflow to signed zero + 0x01, 0x81, // Round to min signed subnormal + 0x01, 0x81, // Min signed subnormal + 0x08, 0x88, // Min signed normal + 0x68, 0xe8, // Boundary rounds to +/-0.25 + 0x69, 0xe9, // Boundary rounds to +/-0.28125 + 0x6f, 0xef, // Boundary rounds to +/-0.46875 + 0x7e, 0xfe, // Max finite at fp16 grid (1.75) + 0x7f, 0xff, // Overflow saturation (1.875) + 0x7f, 0xff, // Inf saturation (1.875) + 0x7f, 0xff); // NaN / large negative saturation (1.875) + + // Raw bytes to decode, with expectedDecoded giving the exact float value at the same index. + const auto raw = makeArray(0x00, 0x80, // +/-0 + 0x01, 0x81, // +/-2^-17: min subnormal + 0x08, 0x88, // +/-2^-14: min normal + 0x68, 0xe8, // +/-0.25 + 0x69, 0xe9, // +/-0.28125 + 0x7e, 0xfe); // +/-1.75: max finite + const auto expectedDecoded = makeArray(0.0f, -0.0f, // +/-0 + 0x1.0p-17f, -0x1.0p-17f, // +/-2^-17: min subnormal + 0x1.0p-14f, -0x1.0p-14f, // +/-2^-14: min normal + 0x1.0p-2f, -0x1.0p-2f, // +/-0.25 + 0x1.2p-2f, -0x1.2p-2f, // +/-0.28125 + 0x1.cp0f, -0x1.cp0f); // +/-1.75: max finite + + ASSERT_EQ(input.size(), expectedEncoded.size()); + ASSERT_EQ(raw.size(), expectedDecoded.size()); + ASSERT_EQ(input.size() % 4, size_t(0)); + ASSERT_EQ(raw.size() % 4, size_t(0)); + + auto inputDev = mscclpp::detail::gpuCallocShared(input.size()); + auto rawDev = mscclpp::detail::gpuCallocShared(raw.size()); + auto encodedDev = mscclpp::detail::gpuCallocShared(input.size() * kConversionPaths); + auto decodedBitsDev = mscclpp::detail::gpuCallocShared(raw.size() * kConversionPaths); + + mscclpp::gpuMemcpy(inputDev.get(), input.data(), input.size(), cudaMemcpyHostToDevice); + mscclpp::gpuMemcpy(rawDev.get(), raw.data(), raw.size(), cudaMemcpyHostToDevice); + + kernelE4m3b15TypeConvert<<<1, 1>>>(inputDev.get(), static_cast(input.size()), rawDev.get(), + static_cast(raw.size()), encodedDev.get(), decodedBitsDev.get()); + MSCCLPP_CUDATHROW(cudaGetLastError()); + MSCCLPP_CUDATHROW(cudaDeviceSynchronize()); + + std::vector encoded(input.size() * kConversionPaths); + std::vector decodedBits(raw.size() * kConversionPaths); + mscclpp::gpuMemcpy(encoded.data(), encodedDev.get(), encoded.size(), cudaMemcpyDeviceToHost); + mscclpp::gpuMemcpy(decodedBits.data(), decodedBitsDev.get(), decodedBits.size(), cudaMemcpyDeviceToHost); + + for (int path = 0; path < kConversionPaths; ++path) { + for (size_t i = 0; i < input.size(); ++i) { + EXPECT_EQ(static_cast(encoded[path * input.size() + i]), static_cast(expectedEncoded[i])); + } + } + + for (int path = 0; path < kConversionPaths; ++path) { + for (size_t i = 0; i < raw.size(); ++i) { + EXPECT_EQ(decodedBits[path * raw.size() + i], floatToBitsHost(expectedDecoded[i])); + } + } +}