From 43f9c6c6a37c067db433795c1962dda5e5be3e77 Mon Sep 17 00:00:00 2001 From: Yadir Batista Date: Wed, 15 Apr 2026 16:32:02 -0400 Subject: [PATCH] feat: CUDA 10.2 / C++14 compatibility for Jetson TX2 (compute 6.2) Minimal-diff approach to support GCC 9 + nvcc 10.2 with --expt-relaxed-constexpr: - Add compat-cuda10.cuh: bf16->fp16 polyfills (no hw bf16 on compute 6.2) - Guard cuda_bf16.h include with CUDART_VERSION >= 11000 - CMake: C++14 std, arch 62, --expt-relaxed-constexpr for CUDA < 11.0 - Replace std::is_same_v with std::is_same<>::value (C++14) - Convert fold expressions to C++14 equivalents - Convert structured bindings to explicit .first/.second - Guard cooperative_groups (cg::this_grid) behind CUDART_VERSION >= 11000 - Fix cudaStreamWaitEvent 2-arg calls (CUDA 10.2 requires 3rd flags param) - Replace __builtin_assume with GGML_CUDA_ASSUME macro - Fix static inline const/auto to constexpr with explicit types - Fix if-init statements to C++14 style 14 files changed (13 modified + 1 new), 160 insertions, 98 deletions Tested: build OK, CPU inference 0.7 t/s, GPU inference 7.3 t/s (gemma-4-E2B Q4_0) --- ggml/src/ggml-cuda/CMakeLists.txt | 15 +++++- ggml/src/ggml-cuda/binbcast.cu | 6 ++- ggml/src/ggml-cuda/common.cuh | 43 +++++++++++---- ggml/src/ggml-cuda/compat-cuda10.cuh | 57 ++++++++++++++++++++ ggml/src/ggml-cuda/convert.cuh | 14 ++--- ggml/src/ggml-cuda/fattn-common.cuh | 36 ++++++------- ggml/src/ggml-cuda/fattn-vec.cuh | 2 +- ggml/src/ggml-cuda/ggml-cuda.cu | 79 ++++++++++++++++------------ ggml/src/ggml-cuda/mma.cuh | 2 +- ggml/src/ggml-cuda/mmf.cuh | 22 ++++---- ggml/src/ggml-cuda/mmvf.cu | 12 ++--- ggml/src/ggml-cuda/rope.cu | 4 +- ggml/src/ggml-cuda/softmax.cu | 19 +++++-- ggml/src/ggml-cuda/vendors/cuda.h | 4 ++ 14 files changed, 217 insertions(+), 98 deletions(-) create mode 100644 ggml/src/ggml-cuda/compat-cuda10.cuh diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index b54d4a6b107..47cb5eb85e0 100644 --- 
a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -23,7 +23,10 @@ if (CUDAToolkit_FOUND) # # The default behavior for a non-native is to build virtual architectures as needed to cover all features needed # for best performance and to also build real architectures for the most commonly used GPUs. - if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") + if (CUDAToolkit_VERSION VERSION_LESS "11.0") + # CUDA 10.2: only compute 6.2 (Jetson TX2 / Pascal) + set(CMAKE_CUDA_ARCHITECTURES "62") + elseif (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6" AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") set(CMAKE_CUDA_ARCHITECTURES "native") else() if (CUDAToolkit_VERSION VERSION_LESS "13") @@ -57,6 +60,12 @@ if (CUDAToolkit_FOUND) enable_language(CUDA) + # CUDA 10.2 compat: force C++14 standard and enable if constexpr support + if (CUDAToolkit_VERSION VERSION_LESS "11.0") + set(CMAKE_CUDA_STANDARD 14) + set(CMAKE_CUDA_STANDARD_REQUIRED ON) + endif() + # TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit if (GGML_CUDA_CUB_3DOT2) include(FetchContent) @@ -195,6 +204,10 @@ if (CUDAToolkit_FOUND) set(CUDA_FLAGS -use_fast_math -extended-lambda) + if (CUDAToolkit_VERSION VERSION_LESS "11.0") + list(APPEND CUDA_FLAGS --expt-relaxed-constexpr) + endif() + if (GGML_CUDA_DEBUG) list(APPEND CUDA_FLAGS -lineinfo) add_compile_definitions(GGML_CUDA_DEBUG) diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index adb4d5f0cb9..993631b5193 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -77,7 +77,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, float result = src0_row ? 
(float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); + int _dummy[] = { (result = bin_op(result, (float)src1s[i_src1 + i10*s10]), 0)... }; + (void)_dummy; } else { result = bin_op(result, (float)src1[i_src1 + i10*s10]); } @@ -143,7 +144,8 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); + int _dummy[] = { (result = bin_op(result, (float)src1s[i_src1 + i10*s10]), 0)... }; + (void)_dummy; } else { result = bin_op(result, (float)src1[i_src1 + i10*s10]); } diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ad30ecd8fa5..b8286a3edbd 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -546,7 +546,20 @@ template struct block_reduce_policy; template -inline constexpr bool is_any = (std::is_same_v || ...); +struct is_any_impl; + +template +struct is_any_impl { + static constexpr bool value = false; +}; + +template +struct is_any_impl { + static constexpr bool value = std::is_same::value || is_any_impl::value; +}; + +template +inline constexpr bool is_any = is_any_impl::value; template inline constexpr bool ggml_cuda_dependent_false_v = false; @@ -561,13 +574,13 @@ template struct block_reduce_policy { } static __device__ T sentinel() { - if constexpr (std::is_same_v) { + if constexpr (std::is_same::value) { return 0.0f; - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same::value) { return make_float2(0.0f, 0.0f); - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same::value) { return make_half2(0.0f, 0.0f); - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same::value) { return 0; } else { static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block 
reduce sum"); @@ -585,9 +598,9 @@ template struct block_reduce_policy { } static __device__ T sentinel() { - if constexpr (std::is_same_v) { + if constexpr (std::is_same::value) { return -INFINITY; - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same::value) { return make_half2(-INFINITY, -INFINITY); } else { static_assert(ggml_cuda_dependent_false_v, "Unsupported type for block reduce max"); @@ -1252,7 +1265,9 @@ struct ggml_cuda_concurrent_event { const int64_t join_start = (int64_t) join_t->data; const int64_t join_end = join_start + ggml_nbytes(join_t); - for (const auto & [tensor, stream] : stream_mapping) { + for (const auto & _kv : stream_mapping) { + const auto & tensor = _kv.first; + const auto & stream = _kv.second; const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor; const int64_t t_start = (int64_t) t->data; const int64_t t_end = t_start + ggml_nbytes(t); @@ -1273,7 +1288,9 @@ struct ggml_cuda_concurrent_event { bool writes_overlap = false; bool dependent_srcs = false; - for (const auto & [tensor, stream] : stream_mapping) { + for (const auto & _kv : stream_mapping) { + const auto & tensor = _kv.first; + const auto & stream = _kv.second; const ggml_tensor * t = tensor->view_src ? 
tensor->view_src : tensor; const int64_t t_start = (int64_t) t->data; const int64_t t_end = t_start + ggml_nbytes(t); @@ -1379,7 +1396,9 @@ struct ggml_backend_cuda_context { // Check if any CUDA graph is enabled for this context (used by kernels that need to know // if graphs are in use without having access to the specific graph key) bool any_cuda_graph_enabled() const { - for (const auto & [key, graph] : cuda_graphs) { + for (const auto & _kv : cuda_graphs) { + const auto & key = _kv.first; + const auto & graph = _kv.second; if (graph && graph->is_enabled()) { return true; } @@ -1389,7 +1408,9 @@ struct ggml_backend_cuda_context { // Check if any CUDA graph has an instance for this context bool any_cuda_graph_has_instance() const { - for (const auto & [key, graph] : cuda_graphs) { + for (const auto & _kv : cuda_graphs) { + const auto & key = _kv.first; + const auto & graph = _kv.second; if (graph && graph->instance != nullptr) { return true; } diff --git a/ggml/src/ggml-cuda/compat-cuda10.cuh b/ggml/src/ggml-cuda/compat-cuda10.cuh new file mode 100644 index 00000000000..db9ec6b1542 --- /dev/null +++ b/ggml/src/ggml-cuda/compat-cuda10.cuh @@ -0,0 +1,57 @@ +#pragma once + +// Compatibility polyfills for CUDA 10.2 (Jetson TX2) +// bf16 types are mapped to fp16 since compute 6.2 has no hardware bf16 support. 
+ +#include + +// bf16 type polyfills (mapped to fp16 for compute 6.2) +typedef __half nv_bfloat16; + +struct nv_bfloat162 { + nv_bfloat16 x; + nv_bfloat16 y; +}; + +static __host__ __device__ __forceinline__ nv_bfloat16 __float2bfloat16(float f) { + return __float2half(f); +} + +static __host__ __device__ __forceinline__ float __bfloat162float(nv_bfloat16 h) { + return __half2float(h); +} + +static __host__ __device__ __forceinline__ nv_bfloat162 make_bfloat162(nv_bfloat16 a, nv_bfloat16 b) { + nv_bfloat162 r; + r.x = a; + r.y = b; + return r; +} + +static __host__ __device__ __forceinline__ nv_bfloat162 __float22bfloat162_rn(float2 f) { + return make_bfloat162(__float2bfloat16(f.x), __float2bfloat16(f.y)); +} + +static __host__ __device__ __forceinline__ float2 __bfloat1622float2(nv_bfloat162 h) { + return make_float2(__bfloat162float(h.x), __bfloat162float(h.y)); +} + +static __host__ __device__ __forceinline__ nv_bfloat16 __low2bfloat16(nv_bfloat162 h) { + return h.x; +} + +static __host__ __device__ __forceinline__ nv_bfloat16 __high2bfloat16(nv_bfloat162 h) { + return h.y; +} + +static __host__ __device__ __forceinline__ nv_bfloat162 __halves2bfloat162(nv_bfloat16 a, nv_bfloat16 b) { + nv_bfloat162 r; + r.x = a; + r.y = b; + return r; +} + +// CUDA_R_16BF cublas data type (not defined in CUDA 10.2) +#ifndef CUDA_R_16BF +#define CUDA_R_16BF CUDA_R_16F +#endif diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index f5d37c7b998..b5d646c2dca 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -33,15 +33,15 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type); template __host__ __device__ inline dst_t ggml_cuda_cast(src_t x) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same::value) { return x; - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same::value) { return __float2bfloat16(float(x)); - } else if constexpr(std::is_same_v) { + } else if 
constexpr(std::is_same::value) { return __bfloat162float(x); - } else if constexpr(std::is_same_v && std::is_same_v) { + } else if constexpr(std::is_same::value && std::is_same::value) { return __float22half2_rn(x); - } else if constexpr(std::is_same_v && std::is_same_v) { + } else if constexpr(std::is_same::value && std::is_same::value) { #ifdef GGML_USE_HIP return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); #else @@ -51,14 +51,14 @@ template return make_float2(__bfloat162float(x.x), __bfloat162float(x.y)); #endif // __CUDA_ARCH__ >= 800 #endif // GGML_USE_HIP - } else if constexpr(std::is_same_v && std::is_same_v) { + } else if constexpr(std::is_same::value && std::is_same::value) { // bypass compile error on cuda 12.0.1 #ifdef GGML_USE_HIP return __float22bfloat162_rn(x); #else return {x.x, x.y}; #endif // GGML_USE_HIP - } else if constexpr(std::is_same_v) { + } else if constexpr(std::is_same::value) { return int32_t(x); } else { return float(x); diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index beeb5238946..55a67a8b5eb 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -336,9 +336,9 @@ typedef void (*dequantize_V_t)(const void *, void *, const int64_t); template static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same::value) { ggml_cuda_memcpy_1(dst, (const half *) vx + i0); - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same::value) { static_assert(ne % 2 == 0, "bad ne"); __align__(16) half2 tmp[ne/2]; ggml_cuda_memcpy_1(tmp, (const half *) vx + i0); @@ -348,13 +348,13 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_ dst_f2[l] = __half22float2(tmp[l]); } } else { - static_assert(std::is_same_v, "unsupported type"); + 
static_assert(std::is_same::value, "unsupported type"); } } template static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { - static_assert(std::is_same_v, "BF16 V dequantization only supports float output"); + static_assert(std::is_same::value, "BF16 V dequantization only supports float output"); static_assert(ne % 2 == 0, "bad ne"); __align__(16) nv_bfloat162 tmp[ne/2]; ggml_cuda_memcpy_1(tmp, (const nv_bfloat16 *) vx + i0); @@ -383,7 +383,7 @@ static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict const int8_t * q8 = (const int8_t *) &q; #ifdef FP16_AVAILABLE - if constexpr (std::is_same_v) { + if constexpr (std::is_same::value) { const half2 d = __half2half2(x[ib].d); #pragma unroll @@ -392,7 +392,7 @@ static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict } } else #endif // FP16_AVAILABLE - if constexpr (std::is_same_v) { + if constexpr (std::is_same::value) { const float d = x[ib].d; #pragma unroll @@ -400,7 +400,7 @@ static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict ((float *) dst)[l] = d * q8[l]; } } else { - static_assert(std::is_same_v, "bad type"); + static_assert(std::is_same::value, "bad type"); } } @@ -421,7 +421,7 @@ static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict const int8_t * q8 = (const int8_t *) &q; #ifdef FP16_AVAILABLE - if constexpr (std::is_same_v) { + if constexpr (std::is_same::value) { const half2 dm = x[ib].dm; const half2 d = __half2half2( __low2half(dm)); const half2 m = __half2half2(__high2half(dm)); @@ -432,7 +432,7 @@ static __device__ __forceinline__ void dequantize_V_q4_1(const void * __restrict } } else #endif // FP16_AVAILABLE - if constexpr (std::is_same_v) { + if constexpr (std::is_same::value) { const float2 dm = __half22float2(x[ib].dm); #pragma unroll @@ -440,7 +440,7 @@ static __device__ __forceinline__ void dequantize_V_q4_1(const 
void * __restrict ((float *) dst)[l] = dm.x * q8[l] + dm.y; } } else { - static_assert(std::is_same_v, "bad type"); + static_assert(std::is_same::value, "bad type"); } } @@ -473,7 +473,7 @@ static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict const int8_t * q8 = (const int8_t *) &q; #ifdef FP16_AVAILABLE - if constexpr (std::is_same_v) { + if constexpr (std::is_same::value) { const half2 d = __half2half2(x[ib].d); #pragma unroll @@ -482,7 +482,7 @@ static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict } } else #endif // FP16_AVAILABLE - if constexpr (std::is_same_v) { + if constexpr (std::is_same::value) { const float d = x[ib].d; #pragma unroll @@ -490,7 +490,7 @@ static __device__ __forceinline__ void dequantize_V_q5_0(const void * __restrict ((float *) dst)[l] = d * q8[l]; } } else { - static_assert(std::is_same_v, "bad type"); + static_assert(std::is_same::value, "bad type"); } } @@ -521,7 +521,7 @@ static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict const int8_t * q8 = (const int8_t *) &q; #ifdef FP16_AVAILABLE - if constexpr (std::is_same_v) { + if constexpr (std::is_same::value) { const half2 dm = x[ib].dm; const half2 d = __half2half2( __low2half(dm)); const half2 m = __half2half2(__high2half(dm)); @@ -532,7 +532,7 @@ static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict } } else #endif // FP16_AVAILABLE - if constexpr (std::is_same_v) { + if constexpr (std::is_same::value) { const float2 dm = __half22float2(x[ib].dm); #pragma unroll @@ -540,7 +540,7 @@ static __device__ __forceinline__ void dequantize_V_q5_1(const void * __restrict ((float *) dst)[l] = dm.x * q8[l] + dm.y; } } else { - static_assert(std::is_same_v, "bad type"); + static_assert(std::is_same::value, "bad type"); } } @@ -573,7 +573,7 @@ static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict ((float *) dst)[l] = d * qs[l]; } } else { - 
static_assert(std::is_same_v, "unsupported type"); + static_assert(std::is_same::value, "unsupported type"); } } @@ -887,7 +887,7 @@ static __global__ void flash_attn_combine_results( dst += j_dst_unrolled * D; const int tid = threadIdx.x; - __builtin_assume(tid < D); + GGML_CUDA_ASSUME(tid < D); extern __shared__ float2 meta[]; for (int i = tid; i < 2*parallel_blocks; i += D) { diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index f0bd42a5761..24705f6fdc6 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -108,7 +108,7 @@ static __global__ void flash_attn_ext_vec( static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); constexpr int nwarps = nthreads / WARP_SIZE; const int tid = WARP_SIZE*threadIdx.y + threadIdx.x; - __builtin_assume(tid < nthreads); + GGML_CUDA_ASSUME(tid < nthreads); constexpr int ne_KQ = ncols*D; constexpr int ne_combine = nwarps*V_cols_per_iter*D; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 790f53cead7..279e5b96e7c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -1967,40 +1967,40 @@ struct batched_mul_mat_traits; template<> struct batched_mul_mat_traits { using cuda_type = float; - static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; - static inline const cudaDataType_t data_type = CUDA_R_32F; - static inline const ggml_type ggml_type_val = GGML_TYPE_F32; - static inline const float alpha = 1.0f; - static inline const float beta = 0.0f; - static inline const void* get_alpha() { static const float val = alpha; return &val; } - static inline const void* get_beta() { static const float val = beta; return &val; } - static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); } + static constexpr cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + static constexpr cudaDataType_t data_type = CUDA_R_32F; + static 
constexpr ggml_type ggml_type_val = GGML_TYPE_F32; + static constexpr float alpha = 1.0f; + static constexpr float beta = 0.0f; + static const void* get_alpha() { static const float val = alpha; return &val; } + static const void* get_beta() { static const float val = beta; return &val; } + static to_fp32_nc_cuda_t get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); } }; template<> struct batched_mul_mat_traits { using cuda_type = nv_bfloat16; - static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; - static inline const cudaDataType_t data_type = CUDA_R_16BF; - static inline const ggml_type ggml_type_val = GGML_TYPE_BF16; - static inline const float alpha = 1.0f; - static inline const float beta = 0.0f; - static inline const void* get_alpha() { static const float val = alpha; return &val; } - static inline const void* get_beta() { static const float val = beta; return &val; } - static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); } + static constexpr cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + static constexpr cudaDataType_t data_type = CUDA_R_16BF; + static constexpr ggml_type ggml_type_val = GGML_TYPE_BF16; + static constexpr float alpha = 1.0f; + static constexpr float beta = 0.0f; + static const void* get_alpha() { static const float val = alpha; return &val; } + static const void* get_beta() { static const float val = beta; return &val; } + static to_bf16_nc_cuda_t get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); } }; template<> struct batched_mul_mat_traits { using cuda_type = half; - static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; - static inline const cudaDataType_t data_type = CUDA_R_16F; - static inline const ggml_type ggml_type_val = GGML_TYPE_F16; - static inline const half alpha = 1.0; - static inline const half beta = 0.0; - static inline const void* get_alpha() { static const half 
val = alpha; return &val; } - static inline const void* get_beta() { static const half val = beta; return &val; } - static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); } + static constexpr cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F; + static constexpr cudaDataType_t data_type = CUDA_R_16F; + static constexpr ggml_type ggml_type_val = GGML_TYPE_F16; + static constexpr float alpha_f = 1.0f; // half not constexpr in C++14 + static constexpr float beta_f = 0.0f; // half not constexpr in C++14 + static const void* get_alpha() { static const half val = alpha_f; return &val; } + static const void* get_beta() { static const half val = beta_f; return &val; } + static to_fp16_nc_cuda_t get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); } }; template @@ -2264,8 +2264,11 @@ static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up, return false; } - if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) { - return false; + { + const bool swapped = ggml_get_op_params_i32(glu, 1); + if (swapped) { + return false; + } } const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) || @@ -3596,13 +3599,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name); - cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0 + cudaStream_t main_stream = cuda_ctx->stream(cuda_ctx->device, cuda_ctx->curr_stream_no); // this should be stream 0 GGML_ASSERT(cuda_ctx->curr_stream_no == 0); CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream)); for (int i = 1; i <= concurrent_event->n_streams; ++i) { cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i); - CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event)); + CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event, 0)); } } }; @@ 
-3615,7 +3618,9 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud if (stream_ctx.concurrent_events.size() > 0) { should_launch_concurrent_events = true; - for (const auto & [tensor, event] : stream_ctx.concurrent_events) { + for (const auto & _kv : stream_ctx.concurrent_events) { + const auto & tensor = _kv.first; + const auto & event = _kv.second; should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid(); } } @@ -3629,7 +3634,9 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud node_to_idx[cgraph->nodes[i]] = i; } - for (auto & [fork_node, event] : stream_ctx.concurrent_events) { + for (auto & _kv : stream_ctx.concurrent_events) { + auto & fork_node = _kv.first; + auto & event = _kv.second; // Find positions of all nodes from this event in the current graph std::vector positions; positions.reserve(event.original_order.size()); @@ -3686,7 +3693,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud // Wait on join events of forked streams in the main stream CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1], cuda_ctx->stream(cuda_ctx->device, i))); - CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1])); + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(cuda_ctx->device, cuda_ctx->curr_stream_no), concurrent_event->join_events[i - 1], 0)); } is_concurrent_event_active = false; @@ -4215,7 +4222,7 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; if (ggml_backend_is_cuda(backend)) { - CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), (cudaEvent_t)event->context, 0)); + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(cuda_ctx->device, cuda_ctx->curr_stream_no), (cudaEvent_t)event->context, 0)); } else { #if 0 // untested @@ -4314,7 +4321,9 @@ static void 
ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph // store {fork_idx, join_idx} std::vector> concurrent_node_ranges; - for (const auto & [root_node, count] : fan_out) { + for (const auto & _kv : fan_out) { + const auto & root_node = _kv.first; + const auto & count = _kv.second; if (count >= min_fan_out && count <= max_fan_out) { const int root_node_idx = node_indices[root_node]; @@ -4325,7 +4334,9 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph } bool is_part_of_event = false; - for (const auto & [start, end] : concurrent_node_ranges) { + for (const auto & _kv : concurrent_node_ranges) { + const auto & start = _kv.first; + const auto & end = _kv.second; if (root_node_idx >= start && root_node_idx <= end) { is_part_of_event = true; } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index c91dd2d9ad6..7915dd5a5f4 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -207,7 +207,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 16) { #if defined(RDNA3) - if constexpr (std::is_same_v || std::is_same_v) { + if constexpr (std::is_same::value || std::is_same::value) { // matrix C return 2 * l + (threadIdx.x / 16); } else { diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index c2a8d54c95a..5c6f005a37b 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -56,7 +56,7 @@ static __global__ void mul_mat_f( // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE) - if constexpr (!(std::is_same_v || std::is_same_v) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + if constexpr (!(std::is_same::value || std::is_same::value) || rows_per_block != MMF_ROWS_PER_BLOCK) 
{NO_DEVICE_CODE;} else {
         typedef tile<16, 8, T, get_input_data_layout()> tile_A;
         typedef tile<16, 8, T, get_input_data_layout()> tile_B;
         typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
@@ -67,7 +67,7 @@ static __global__ void mul_mat_f(
     typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
-    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+    if constexpr (!std::is_same<T, half2>::value || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
         typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
         typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
         typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
@@ -115,7 +115,7 @@ static __global__ void mul_mat_f(
     dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
 
     if constexpr (has_ids) {
-        constexpr int y_stride_scale = std::is_same_v<T, float> ? 1 : 2;
+        constexpr int y_stride_scale = std::is_same<T, float>::value ? 1 : 2;
         const int64_t col_offset = col_base;
         y   += col_offset * stride_col_y * y_stride_scale;
         dst += col_offset * stride_col_dst;
@@ -183,7 +183,7 @@ static __global__ void mul_mat_f(
 #pragma unroll
         for (int itB = 0; itB < ntB; ++itB) {
-            if constexpr (std::is_same_v<T, float>) {
+            if constexpr (std::is_same<T, float>::value) {
 #pragma unroll
                 for (int j0 = 0; j0 < tile_B::I; ++j0) {
                     const int j = j0 + itB*tile_B::I;
@@ -195,7 +195,7 @@ static __global__ void mul_mat_f(
                     tile_xy[j0*tile_k_padded + threadIdx.x] = valid ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0.0f;
                 }
             }
-        } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+        } else if constexpr (std::is_same<T, half2>::value || std::is_same<T, nv_bfloat162>::value) {
 #pragma unroll
             for (int j0 = 0; j0 < tile_B::I; ++j0) {
                 const int j = j0 + itB*tile_B::I;
@@ -210,7 +210,7 @@ static __global__ void mul_mat_f(
                 }
             }
         } else {
-            static_assert(std::is_same_v<T, void>, "unsupported type");
+            static_assert(std::is_same<T, void>::value, "unsupported type");
         }
 #pragma unroll
         for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
@@ -307,7 +307,7 @@ static __global__ void mul_mat_f_ids(
     // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
 #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
 #if defined(AMD_WMMA_AVAILABLE)
-    if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+    if constexpr (!(std::is_same<T, half2>::value || std::is_same<T, nv_bfloat162>::value) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
         typedef tile<16, 8, T, get_input_data_layout()> tile_A;
         typedef tile<16, 8, T, get_input_data_layout()> tile_B;
         typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
@@ -318,7 +318,7 @@ static __global__ void mul_mat_f_ids(
     typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
 #else
 #ifdef VOLTA_MMA_AVAILABLE
-    if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+    if constexpr (!std::is_same<T, half2>::value || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
         typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
         typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
         typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
@@ -392,7 +392,7 @@ static __global__ void mul_mat_f_ids(
             }
         }
 
-        if constexpr (std::is_same_v<T, float>) {
+        if constexpr (std::is_same<T, float>::value) {
             float vals_buf[2][tile_B::I];
             auto gather_tile = [&](int tile_idx_local, float *vals) {
 #pragma unroll
@@ -443,7 +443,7 @@ static __global__ void mul_mat_f_ids(
                 next_buf ^= 1;
             }
         }
-        } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+        } else if constexpr (std::is_same<T, half2>::value || std::is_same<T, nv_bfloat162>::value) {
             float2 vals_buf[2][tile_B::I];
             auto gather_tile = [&](int tile_idx_local, float2 *vals) {
 #pragma unroll
@@ -498,7 +498,7 @@ static __global__ void mul_mat_f_ids(
             }
         }
     } else {
-        static_assert(std::is_same_v<T, void>, "unsupported type");
+        static_assert(std::is_same<T, void>::value, "unsupported type");
     }
 }
 
diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu
index d9147202429..36a2142e687 100644
--- a/ggml/src/ggml-cuda/mmvf.cu
+++ b/ggml/src/ggml-cuda/mmvf.cu
@@ -120,7 +120,7 @@ static __global__ void mul_mat_vec_f(
         }
     }
 
-    if constexpr (std::is_same_v<T, float>) {
+    if constexpr (std::is_same<T, float>::value) {
         const float2 * x2 = (const float2 *) x;
         const float2 * gate_x2 = nullptr;
         if constexpr (has_fusion) {
@@ -152,7 +152,7 @@ static __global__ void mul_mat_vec_f(
                 }
             }
         }
-    } else if constexpr (std::is_same_v<T, half>) {
+    } else if constexpr (std::is_same<T, half>::value) {
         const half2 * x2 = (const half2 *) x;
         const half2 * gate_x2 = nullptr;
         if constexpr (has_fusion) {
@@ -161,7 +161,7 @@ static __global__ void mul_mat_vec_f(
             }
         }
 
-        if (std::is_same_v<type_acc, float>) {
+        if (std::is_same<type_acc, float>::value) {
             for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                 const float2 tmpx = __half22float2(x2[col2]);
                 float2 tmpx_gate = make_float2(0.0f, 0.0f);
@@ -227,7 +227,7 @@ static __global__ void mul_mat_vec_f(
             NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
         }
-    } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
+    } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
         //TODO: add support for ggml_cuda_mad for hip_bfloat162
 #if defined(GGML_USE_HIP)
         const int * x2 = (const int *) x;
@@ -295,7 +295,7 @@ static __global__ void mul_mat_vec_f(
     }
 #endif
     } else {
-        static_assert(std::is_same_v<T, void>, "unsupported type");
+        static_assert(std::is_same<T, void>::value, "unsupported type");
     }
 
 #pragma unroll
@@ -604,7 +604,7 @@ static void mul_mat_vec_f_cuda(
     const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
     const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) {
 
-    if constexpr(std::is_same_v<T, half>) {
+    if constexpr(std::is_same<T, half>::value) {
         if (prec == GGML_PREC_DEFAULT) {
             mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
                 (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu
index 45a49a5dc2a..5ce8fe5575d 100644
--- a/ggml/src/ggml-cuda/rope.cu
+++ b/ggml/src/ggml-cuda/rope.cu
@@ -84,10 +84,10 @@ static __global__ void rope_norm(const T * x,
     }
 
     const auto & store_coaelsced = [&](float x0, float x1) {
-        if constexpr (std::is_same_v<T, float>) {
+        if constexpr (std::is_same<T, float>::value) {
             float2 v = make_float2(x0, x1);
             ggml_cuda_memcpy_1<8>(dst + idst, &v);
-        } else if constexpr (std::is_same_v<T, half>) {
+        } else if constexpr (std::is_same<T, half>::value) {
             half2 v = make_half2(x0, x1);
             ggml_cuda_memcpy_1<4>(dst + idst, &v);
         }
diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu
index 285c0e9543a..c0a7716afdb 100644
--- a/ggml/src/ggml-cuda/softmax.cu
+++ b/ggml/src/ggml-cuda/softmax.cu
@@ -5,8 +5,10 @@
 #ifdef GGML_USE_HIP
 #include <hip/hip_cooperative_groups.h>
 #else
+#if CUDART_VERSION >= 11000
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
+#endif // CUDART_VERSION >= 11000
 #endif // GGML_USE_HIP
 
 #include <cmath>
@@ -138,6 +140,7 @@ static __global__ void soft_max_f32(
 }
 
 // TODO: Template to allow keeping ncols in registers if they fit
+#if CUDART_VERSION >= 11000
 static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
                                                                 float * __restrict__ dst,
                                                                 float * __restrict__ tmp_maxs,
@@ -243,6 +246,8 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __
     }
 }
 
+#endif // CUDART_VERSION >= 11000 (cooperative_groups single_row)
+
 #ifdef __clang__
 #pragma clang diagnostic pop
 #endif // __clang__
@@ -289,9 +294,10 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
         return false;
     };
 
-    // unary fold over launch_kernel
-    if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
-        return;
+    // unary fold over launch_kernel (C++14 compat: array init trick)
+    bool results[] = { launch_kernel(std::integral_constant<int, Ns>{})... };
+    for (bool r : results) {
+        if (r) return;
     }
 
     //default case
@@ -299,6 +305,7 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
     soft_max_f32<true><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
 }
 
+#if CUDART_VERSION >= 11000
 __launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
                                                                                        float * __restrict__ dst,
                                                                                        float * __restrict__ tmp_maxs,
@@ -315,6 +322,7 @@ __launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_paralleliz
                                                  tmp_sums, p);
     }
 }
+#endif // CUDART_VERSION >= 11000
 
 template <typename T>
 static void soft_max_f32_cuda(const float * x,
@@ -344,6 +352,7 @@ static void soft_max_f32_cuda(const float * x,
     // Parallelize across SMs for top-p/dist-sampling
     // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and
    // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.
+#if CUDART_VERSION >= 11000
     if (ggml_cuda_info().devices[id].supports_cooperative_launch && ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 &&
         mask == nullptr && sinks == nullptr && params.scale == 1.0f && params.max_bias == 0.0f) {
@@ -355,7 +364,9 @@ static void soft_max_f32_cuda(const float * x,
         CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols, dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
                                                dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
-    } else {
+    } else
+#endif // CUDART_VERSION >= 11000
+    {
         const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
         soft_max_f32<false>
             <<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h
index 323c9801934..636625d0439 100644
--- a/ggml/src/ggml-cuda/vendors/cuda.h
+++ b/ggml/src/ggml-cuda/vendors/cuda.h
@@ -3,7 +3,11 @@
 #include <cuda_runtime.h>
 #include <cuda.h>
 #include <cublas_v2.h>
+#if CUDART_VERSION >= 11000
 #include <cuda_bf16.h>
+#else
+#include "../compat-cuda10.cuh"
+#endif // CUDART_VERSION >= 11000
 #include <cuda_fp16.h>
 
 #ifdef GGML_USE_NCCL