diff --git a/src/fsum.c b/src/fsum.c index 68e0114f..cc005f2f 100644 --- a/src/fsum.c +++ b/src/fsum.c @@ -1,6 +1,11 @@ #include "collapse_c.h" // #include +// Number of independent accumulators for loop unrolling: shortens the serial dependency chain, +// enabling compiler auto-vectorization (SIMD) even when OpenMP is unavailable. +// Keep this a multiple of 4: four 64-bit doubles fill one 256-bit SIMD register (AVX). +#define FSUM_N_ACC 4 + double fsum_double_impl(const double *restrict px, const int narm, const int l) { double sum; if(narm == 1) { @@ -17,9 +22,15 @@ double fsum_double_impl(const double *restrict px, const int narm, const int l) #pragma omp simd reduction(+:sum) for(int i = 0; i < l; ++i) sum += NISNAN(px[i]) ? px[i] : 0.0; } else { - // Should just be fast, don't stop for NA's - #pragma omp simd reduction(+:sum) - for(int i = 0; i < l; ++i) sum += px[i]; + // Multiple independent accumulators allow compiler auto-vectorization (SIMD) + // even without OpenMP, by eliminating the serial data dependency on sum. 
+ double sum_arr[FSUM_N_ACC] = {0}; + const int remainder = l % FSUM_N_ACC; + for(int i = 0; i < remainder; ++i) sum += px[i]; + for(int i = remainder; i < l; i += FSUM_N_ACC) { + for(int k = 0; k < FSUM_N_ACC; ++k) sum_arr[k] += px[i + k]; + } + for(int k = 0; k < FSUM_N_ACC; ++k) sum += sum_arr[k]; } } return sum; @@ -57,8 +68,14 @@ double fsum_double_omp_impl(const double *restrict px, const int narm, const int } else if(narm == 2) sum = 0.0; } else { sum = 0; - #pragma omp parallel for simd num_threads(nthreads) reduction(+:sum) - for(int i = 0; i < l; ++i) sum += px[i]; // Cannot have break statements in OpenMP for loop + double partial_sums[FSUM_N_ACC] = {0}; + const int remainder = l % FSUM_N_ACC; + for(int i = 0; i < remainder; ++i) sum += px[i]; + #pragma omp parallel for simd num_threads(nthreads) reduction(+:partial_sums[:FSUM_N_ACC]) + for(int i = remainder; i < l; i += FSUM_N_ACC) { + for(int k = 0; k < FSUM_N_ACC; ++k) partial_sums[k] += px[i + k]; + } + for(int k = 0; k < FSUM_N_ACC; ++k) sum += partial_sums[k]; } return sum; }