diff --git a/Src/Base/AMReX_SIMD.H b/Src/Base/AMReX_SIMD.H index 9c89cdee55..bce96d563a 100644 --- a/Src/Base/AMReX_SIMD.H +++ b/Src/Base/AMReX_SIMD.H @@ -28,11 +28,111 @@ namespace amrex::simd # if __cplusplus >= 202002L using vir::cvt; # endif + + /** Vectorized ternary operator: select(mask, true_val, false_val) + * + * Selects elements from true_val where mask is true and from false_val + * where mask is false. Analogous to (mask ? true_val : false_val) for + * scalars. + * + * Note: both true_val and false_val are eagerly evaluated (function + * arguments). To guard against operations like division by zero, + * sanitize inputs before the operation rather than relying on + * conditional selection. + * + * Example: + * ```cpp + * template + * T compute (T const& a, T const& b) + * { + * auto safe_b = amrex::simd::stdx::select(b != T(0), b, T(1)); + * return amrex::simd::stdx::select(b != T(0), a / safe_b, T(0)); + * } + * ``` + * + * @see C++26 std::simd select + * + * @todo Remove when SIMD provider (vir-simd / C++26) provides select. + * https://github.com/mattkretz/vir-simd/issues/49 + */ + template + AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE + vir::stdx::simd select ( + typename vir::stdx::simd::mask_type const& mask, + vir::stdx::simd const& true_val, + vir::stdx::simd const& false_val) + { + vir::stdx::simd result = false_val; + where(mask, result) = true_val; + return result; + } #else // fallback implementations for functions that are commonly used in portable code paths + /** True if the boolean value is true (scalar identity fallback for simd any_of) + * + * Example: + * ```cpp + * // Works for both simd_mask and scalar bool: + * auto mask = a > b; + * if (amrex::simd::stdx::any_of(mask)) { ... } + * ``` + * + * @see https://en.cppreference.com/w/cpp/experimental/simd/any_of + */ AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE bool any_of (bool const v) { return v; } + + /// \cond DOXYGEN_IGNORE + namespace detail { + template + struct where_expression { + bool mask; + T* value; + + AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE + where_expression& operator= (T const& new_val) + { + if (mask) { *value = new_val; } + return *this; + } + }; + } + /// \endcond + + /** Masked assignment expression (scalar fallback for simd where) + * + * Returns an expression object whose assignment operator conditionally + * updates value only when mask is true. + * + * Example: + * ```cpp + * // Works for both simd and scalar T: + * auto mask = b > T(0); + * T result = T(0); + * amrex::simd::stdx::where(mask, result) = a / b; + * ``` + * + * @see https://en.cppreference.com/w/cpp/experimental/simd/where + */ + template + AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE + detail::where_expression where (bool const mask, T& value) + { + return {mask, &value}; + } + + /** Vectorized ternary operator (scalar fallback for simd select) + * + * @see select in the AMREX_USE_SIMD path above + * @see C++26 std::simd select + */ + template + AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE + T select (bool const mask, T const& true_val, T const& false_val) + { + return mask ? true_val : false_val; + } #endif } diff --git a/Tests/SIMD/main.cpp b/Tests/SIMD/main.cpp index 1d75f8b7c3..cfd1dd5d51 100644 --- a/Tests/SIMD/main.cpp +++ b/Tests/SIMD/main.cpp @@ -425,6 +425,50 @@ int main (int argc, char* argv[]) << (err == 0 ? "PASSED" : "FAILED") << "\n"; } + // ================================================================ + // Test 14: any_of, where, select — portable single-source + // Uses SIMDParticleReal<>, which is a SIMD vector when AMREX_USE_SIMD=ON + // and a plain scalar when OFF. The same code path exercises + // both the real SIMD and the scalar fallback implementations. + // ================================================================ + { + using PReal_t = simd::SIMDParticleReal<>; + + // safe reciprocal: 1/b where b != 0, else 0 + auto b = PReal_t(2); + auto mask = b != PReal_t(0); + auto safe_b = simd::stdx::select(mask, b, PReal_t(1)); + auto recip = simd::stdx::select(mask, + PReal_t(1) / safe_b, + PReal_t(0)); + + // any_of: at least one lane should be nonzero + AMREX_ALWAYS_ASSERT(simd::stdx::any_of(mask)); + + // where: masked assignment + auto acc = PReal_t(0); + simd::stdx::where(mask, acc) = recip; + + // verify: b=2 everywhere → recip=0.5, acc=0.5 + int err = 0; + auto check = [&] (ParticleReal got, ParticleReal expected) { + if (std::abs(got - expected) > ParticleReal(1.e-10)) { ++err; } + }; +#ifdef AMREX_USE_SIMD + for (int lane = 0; lane < static_cast(PReal_t::size()); ++lane) { + check(recip[lane], ParticleReal(0.5)); + check(acc[lane], ParticleReal(0.5)); + } +#else + check(recip, ParticleReal(0.5)); + check(acc, ParticleReal(0.5)); +#endif + + nerrors += err; + Print() << "any_of + where + select (portable): " + << (err == 0 ? "PASSED" : "FAILED") << "\n"; + } + // ================================================================ // Final report // ================================================================