Wire calculatePartialSums to native SIMD via Panama FFI downcall#651
Wire calculatePartialSums to native SIMD via Panama FFI downcall#651
Conversation
|
Before you submit for review:
If you did not complete any of these, then please explain below. |
* Replace icelake-server gcc target with skylake-avx512 in build script * Remove global mutable state: eliminate initialIndexRegister, indexIncrement, maskSeventhBit, maskEighthBit globals and their constructor initializer; move mask constants (maskSeventhBit, maskEighthBit) to local scope inside lookup_partial_sums * Add shared reduce_add_128_ps and reduce_add_256_ps helper functions using proper horizontal-add sequences instead of store-to-array loops * Remove redundant if (length >= N) guards in all SIMD kernels — the loop body already handles the zero-iteration case correctly * Replace store-to-aligned-array horizontal reduction pattern with the new helpers across all 128- and 256-bit dot product and euclidean distance functions * Remove preferred_size parameter from dot_product_f32 and euclidean_f32; always dispatch to AVX-512 when length >= 16 * Standardize inline annotations: replace __attribute__((always_inline)) inline with JV_FINLINE / JV_INLINE macros throughout
…zes 4,8 & 16 on AVX-512 Add SIMD fast paths in calculate_partial_sums_dot_f32_512 and calculate_partial_sums_euclidean_f32_512 for the two most common PQ subvector sizes: - size == 4: broadcast a 128-bit query fragment across all four 128-bit lanes of a ZMM register, load four consecutive centroids at once, and reduce each lane independently using two shuffle+add pairs. Produces 4 partial sums per loop iteration instead of 1. - size == 8: broadcast a 256-bit query fragment across both 256-bit halves of a ZMM register, load two consecutive centroids at once, and reduce across 128-bit lanes followed by within-lane shuffles. Produces 2 partial sums per loop iteration instead of 1. - size == 16: query and the centroid fit into a ZMM register, load the query into zmm and then loop over the centroids. Produces one partial sum per loop iteration, but prevents having to load the query multiple times. Both paths fall back to the default way of computing dot_product_f32 / euclidean_f32 in a loop for any tail elements or unsupported sizes.
70cd2fb to
de4ff79
Compare
jshook
left a comment
There was a problem hiding this comment.
I would like to see much more coverage of these with numerical tests. Are there some already which aren't seen here?
ashkrisk
left a comment
There was a problem hiding this comment.
Looks like an excellent set of optimizations. Left a few comments.
+1 to @jshook's comment about numerical tests. This PR touches almost every single function in the native supporting library, and it would be good to have a set of tests accompanying it, perhaps also in C.
| if [ "$(printf '%s\n' "$MIN_GCC_VERSION" "$CURRENT_GCC_VERSION" | sort -V | head -n1)" = "$MIN_GCC_VERSION" ]; then | ||
| rm -rf ../resources/libjvector.so | ||
| gcc -fPIC -O3 -march=icelake-server -c jvector_simd.c -o jvector_simd.o | ||
| gcc -fPIC -O3 -march=skylake-avx512 -c jvector_simd.c -o jvector_simd.o |
There was a problem hiding this comment.
Is there a strong reason to lower the target micro-architecture version?
| case 0: | ||
| calculate_partial_sums_euclidean_f32_512(codebook, codebookIndex, size, clusterCount, query, queryOffset, partialSums); | ||
| break; | ||
| case 1: |
There was a problem hiding this comment.
Can we use public enums here? Jextract should automatically make the enums available to the Java code as constants. Alternatively we could skip the parameter-based dispatch altogether and simply expose both versions of the function to Java code.
| __m512 vaMagnitude = _mm512_setzero_ps(); | ||
| int i = 0; | ||
| int limit = baseOffsetsLength - (baseOffsetsLength % 16); | ||
| const __m512i initialIndexRegister = _mm512_setr_epi32(-16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1); |
There was a problem hiding this comment.
It's good that this isn't a global variable anymore, but given that it's used in multiple places does it make sense to have it as a global constant?
| * void calculate_partial_sums_best_euclidean_f32_512(const float *codebook, int codebookBase, int size, int clusterCount, const float *query, int queryOffset, float *partialSums, float *partialBestDistances) | ||
| * } | ||
| */ | ||
| public static void calculate_partial_sums_best_euclidean_f32_512(MemorySegment codebook, int codebookBase, int size, int clusterCount, MemorySegment query, int queryOffset, MemorySegment partialSums, MemorySegment partialBestDistances) { |
There was a problem hiding this comment.
Looks like a lot of functions that are no longer in the public header are still declared here. Should fix itself on re-running jextract.
This change uses a native implementation of
calculatePartialSumsto accelerate PQ query scoring.On
ada002-100kwith FUSED_PQ (numPQsubspaces/M =96, JDK build 23.0.1+11-39), it delivers 2–3× higher QPS and 40–65% lower mean latency across common overquery settings. Index build time, disk usage, and heap usage show no meaningful regression. The optimization is isolated to the PQ path; non‑PQ queries are unaffected.Combined QPS and Latency Results (FUSED_PQ)
topK = 10
topK = 100
Summary of changes in this PR:
calculatePartialSumsin NativeVectorUtilSupport to a new Panama FFI downcall for the nativecalculate_partial_sums_f32_512SIMD implementation.