From 5e03f435b2aa5a5cf61f5084111311c3ceb7fb07 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 7 Apr 2026 04:31:32 +0000 Subject: [PATCH 1/4] Add optimal AVX2/SSE reduction and refactor native SIMD code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Replace icelake-server gcc target with skylake-avx512 in build script * Remove global mutable state: eliminate initialIndexRegister, indexIncrement, maskSeventhBit, maskEighthBit globals and their constructor initializer; move mask constants (maskSeventhBit, maskEighthBit) to local scope inside lookup_partial_sums * Add shared reduce_add_128_ps and reduce_add_256_ps helper functions using proper horizontal-add sequences instead of store-to-array loops * Remove redundant if (length >= N) guards in all SIMD kernels — the loop body already handles the zero-iteration case correctly * Replace store-to-aligned-array horizontal reduction pattern with the new helpers across all 128- and 256-bit dot product and euclidean distance functions * Remove preferred_size parameter from dot_product_f32 and euclidean_f32; always dispatch to AVX-512 when length >= 16 * Standardize inline annotations: replace __attribute__((always_inline)) inline with JV_FINLINE / JV_INLINE macros throughout --- .../src/main/c/jextract_vector_simd.sh | 4 +- jvector-native/src/main/c/jvector_simd.c | 416 +++++++++--------- jvector-native/src/main/c/jvector_simd.h | 16 +- 3 files changed, 212 insertions(+), 224 deletions(-) diff --git a/jvector-native/src/main/c/jextract_vector_simd.sh b/jvector-native/src/main/c/jextract_vector_simd.sh index d44d375dd..45767e6e6 100755 --- a/jvector-native/src/main/c/jextract_vector_simd.sh +++ b/jvector-native/src/main/c/jextract_vector_simd.sh @@ -49,7 +49,7 @@ CURRENT_GCC_VERSION=$(gcc -dumpversion) # Check if the current GCC version is greater than or equal to the minimum required version if [ "$(printf '%s\n' "$MIN_GCC_VERSION" "$CURRENT_GCC_VERSION" | sort -V | head -n1)" = "$MIN_GCC_VERSION" ]; then rm -rf ../resources/libjvector.so - gcc -fPIC -O3 -march=icelake-server -c jvector_simd.c -o jvector_simd.o + gcc -fPIC -O3 -march=skylake-avx512 -c jvector_simd.c -o jvector_simd.o gcc -fPIC -O3 -march=x86-64 -c jvector_simd_check.c -o jvector_simd_check.o gcc -shared -o ../resources/libjvector.so jvector_simd_check.o jvector_simd.o @@ -77,4 +77,4 @@ jextract \ jvector_simd.h # Set critical linker option with heap-based segments for all generated methods -sed -i 's/DESC)/DESC, Linker.Option.critical(true))/g' ../java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java \ No newline at end of file +sed -i 's/DESC)/DESC, Linker.Option.critical(true))/g' ../java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index d9c909c0f..faaba3393 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -19,23 +19,27 @@ #include #include "jvector_simd.h" -__m512i initialIndexRegister; -__m512i indexIncrement; -__m512i maskSeventhBit; -__m512i maskEighthBit; - -__attribute__((constructor)) -void initialize_constants() { - if (check_compatibility()) { - initialIndexRegister = _mm512_setr_epi32(-16, -15, -14, -13, -12, -11, -10, -9, - -8, -7, -6, -5, -4, -3, -2, -1); - indexIncrement = _mm512_set1_epi32(16); - maskSeventhBit = _mm512_set1_epi16(0x0040); - maskEighthBit = _mm512_set1_epi16(0x0080); - } + +JV_FINLINE float reduce_add_256_ps(__m256 v) { + __m128 lo = _mm256_castps256_ps128(v); + __m128 hi = _mm256_extractf128_ps(v, 1); + __m128 sum128 = _mm_add_ps(lo, hi); + __m128 shuf = _mm_movehdup_ps(sum128); + __m128 sums = _mm_add_ps(sum128, shuf); + shuf = _mm_movehl_ps(shuf, sums); + sums = _mm_add_ss(sums, shuf); + return _mm_cvtss_f32(sums); +} + +JV_FINLINE float reduce_add_128_ps(__m128 v) { + __m128 shuf = _mm_movehdup_ps(v); + __m128 sums = _mm_add_ps(v, shuf); + shuf = _mm_movehl_ps(shuf, sums); + sums = _mm_add_ss(sums, shuf); + return _mm_cvtss_f32(sums); } -float dot_product_f32_64(const float* a, int aoffset, const float* b, int boffset) { +JV_FINLINE float dot_product_f32_64(const float* a, int aoffset, const float* b, int boffset) { __m128 va = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(a + aoffset))); __m128 vb = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(b + boffset))); @@ -47,7 +51,7 @@ float dot_product_f32_64(const float* a, int aoffset, const float* b, int boffse return result[0] + result[1]; } -float dot_product_f32_128(const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float dot_product_f32_128(const float* a, int aoffset, const float* b, int boffset, int length) { float dot = 0.0; int ao = aoffset; int bo = boffset; @@ -55,26 +59,17 @@ float dot_product_f32_128(const float* a, int aoffset, const float* b, int boffs int blim = boffset + length; int simd_length = length - (length % 4); - if (length >= 4) { - __m128 sum = _mm_setzero_ps(); - - for(; ao < aoffset + simd_length; ao += 4, bo += 4) { - // Load float32 - __m128 va = _mm_loadu_ps(a + ao); - __m128 vb = _mm_loadu_ps(b + bo); - - // Multiply and accumulate - sum = _mm_fmadd_ps(va, vb, sum); - } + __m128 sum = _mm_setzero_ps(); - // Horizontal sum of the vector to get dot product - __attribute__((aligned(16))) float result[4]; - _mm_store_ps(result, sum); + for(; ao < aoffset + simd_length; ao += 4, bo += 4) { + // Load float32 + __m128 va = _mm_loadu_ps(a + ao); + __m128 vb = _mm_loadu_ps(b + bo); - for(int i = 0; i < 4; ++i) { - dot += result[i]; - } + // Multiply and accumulate + sum = _mm_fmadd_ps(va, vb, sum); } + dot = reduce_add_128_ps(sum); for (; ao < alim && bo < blim; ao++, bo++) { dot += a[ao] * b[bo]; @@ -83,7 +78,7 @@ float dot_product_f32_128(const float* a, int aoffset, const float* b, int boffs return dot; } -float dot_product_f32_256(const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float dot_product_f32_256(const float* a, int aoffset, const float* b, int boffset, int length) { float dot = 0.0; int ao = aoffset; int bo = boffset; @@ -91,27 +86,20 @@ float dot_product_f32_256(const float* a, int aoffset, const float* b, int boffs int blim = boffset + length; int simd_length = length - (length % 8); - if (length >= 8) { - __m256 sum = _mm256_setzero_ps(); + __m256 sum = _mm256_setzero_ps(); - for(; ao < aoffset + simd_length; ao += 8, bo += 8) { - // Load float32 - __m256 va = _mm256_loadu_ps(a + ao); - __m256 vb = _mm256_loadu_ps(b + bo); + for(; ao < aoffset + simd_length; ao += 8, bo += 8) { + // Load float32 + __m256 va = _mm256_loadu_ps(a + ao); + __m256 vb = _mm256_loadu_ps(b + bo); - // Multiply and accumulate - sum = _mm256_fmadd_ps(va, vb, sum); - } - - // Horizontal sum of the vector to get dot product - __attribute__((aligned(32))) float result[8]; - _mm256_store_ps(result, sum); - - for(int i = 0; i < 8; ++i) { - dot += result[i]; - } + // Multiply and accumulate + sum = _mm256_fmadd_ps(va, vb, sum); } + // Horizontal sum of the vector to get dot product + dot = reduce_add_256_ps(sum); + for (; ao < alim && bo < blim; ao++, bo++) { dot += a[ao] * b[bo]; } @@ -119,7 +107,7 @@ float dot_product_f32_256(const float* a, int aoffset, const float* b, int boffs return dot; } -float dot_product_f32_512(const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float dot_product_f32_512(const float* a, int aoffset, const float* b, int boffset, int length) { float dot = 0.0; int ao = aoffset; int bo = boffset; @@ -127,21 +115,19 @@ float dot_product_f32_512(const float* a, int aoffset, const float* b, int boffs int blim = boffset + length; int simd_length = length - (length % 16); - if (length >= 16) { - __m512 sum = _mm512_setzero_ps(); - for(; ao < aoffset + simd_length; ao += 16, bo += 16) { - // Load float32 - __m512 va = _mm512_loadu_ps(a + ao); - __m512 vb = _mm512_loadu_ps(b + bo); - - // Multiply and accumulate - sum = _mm512_fmadd_ps(va, vb, sum); - } + __m512 sum = _mm512_setzero_ps(); + for(; ao < aoffset + simd_length; ao += 16, bo += 16) { + // Load float32 + __m512 va = _mm512_loadu_ps(a + ao); + __m512 vb = _mm512_loadu_ps(b + bo); - // Horizontal sum of the vector to get dot product - dot = _mm512_reduce_add_ps(sum); + // Multiply and accumulate + sum = _mm512_fmadd_ps(va, vb, sum); } + // Horizontal sum of the vector to get dot product + dot = _mm512_reduce_add_ps(sum); + for (; ao < alim && bo < blim; ao++, bo++) { dot += a[ao] * b[bo]; } @@ -149,18 +135,18 @@ float dot_product_f32_512(const float* a, int aoffset, const float* b, int boffs return dot; } -float dot_product_f32(int preferred_size, const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float dot_product_f32(const float* a, int aoffset, const float* b, int boffset, int length) { if (length == 2) return dot_product_f32_64(a, aoffset, b, boffset); if (length <= 7) return dot_product_f32_128(a, aoffset, b, boffset, length); - return (preferred_size == 512 && length >= 16) + return (length >= 16) ? dot_product_f32_512(a, aoffset, b, boffset, length) : dot_product_f32_256(a, aoffset, b, boffset, length); } -float euclidean_f32_64(const float* a, int aoffset, const float* b, int boffset) { +JV_FINLINE float euclidean_f32_64(const float* a, int aoffset, const float* b, int boffset) { __m128 va = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(a + aoffset))); __m128 vb = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(b + boffset))); __m128 r = _mm_sub_ps(va, vb); @@ -172,7 +158,7 @@ float euclidean_f32_64(const float* a, int aoffset, const float* b, int boffset) return result[0] + result[1]; } -float euclidean_f32_128(const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float euclidean_f32_128(const float* a, int aoffset, const float* b, int boffset, int length) { float squareDistance = 0.0; int ao = aoffset; int bo = boffset; @@ -180,27 +166,20 @@ float euclidean_f32_128(const float* a, int aoffset, const float* b, int boffset int blim = boffset + length; int simd_length = length - (length % 4); - if (length >= 4) { - __m128 sum = _mm_setzero_ps(); - - for(; ao < aoffset + simd_length; ao += 4, bo += 4) { - // Load float32 - __m128 va = _mm_loadu_ps(a + ao); - __m128 vb = _mm_loadu_ps(b + bo); - __m128 diff = _mm_sub_ps(va, vb); - // Multiply and accumulate - sum = _mm_fmadd_ps(diff, diff, sum); - } - - // Horizontal sum of the vector to get dot product - __attribute__((aligned(16))) float result[4]; - _mm_store_ps(result, sum); - - for(int i = 0; i < 4; ++i) { - squareDistance += result[i]; - } + __m128 sum = _mm_setzero_ps(); + + for(; ao < aoffset + simd_length; ao += 4, bo += 4) { + // Load float32 + __m128 va = _mm_loadu_ps(a + ao); + __m128 vb = _mm_loadu_ps(b + bo); + __m128 diff = _mm_sub_ps(va, vb); + // Multiply and accumulate + sum = _mm_fmadd_ps(diff, diff, sum); } + // Horizontal sum of the vector to get dot product + squareDistance = reduce_add_128_ps(sum); + for (; ao < alim && bo < blim; ao++, bo++) { float diff = a[ao] - b[bo]; squareDistance += diff * diff; @@ -209,7 +188,7 @@ float euclidean_f32_128(const float* a, int aoffset, const float* b, int boffset return squareDistance; } -float euclidean_f32_256(const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float euclidean_f32_256(const float* a, int aoffset, const float* b, int boffset, int length) { float squareDistance = 0.0; int ao = aoffset; int bo = boffset; @@ -217,27 +196,20 @@ float euclidean_f32_256(const float* a, int aoffset, const float* b, int boffset int blim = boffset + length; int simd_length = length - (length % 8); - if (length >= 8) { - __m256 sum = _mm256_setzero_ps(); - - for(; ao < aoffset + simd_length; ao += 8, bo += 8) { - // Load float32 - __m256 va = _mm256_loadu_ps(a + ao); - __m256 vb = _mm256_loadu_ps(b + bo); - __m256 diff = _mm256_sub_ps(va, vb); + __m256 sum = _mm256_setzero_ps(); - // Multiply and accumulate - sum = _mm256_fmadd_ps(diff, diff, sum); - } + for(; ao < aoffset + simd_length; ao += 8, bo += 8) { + // Load float32 + __m256 va = _mm256_loadu_ps(a + ao); + __m256 vb = _mm256_loadu_ps(b + bo); + __m256 diff = _mm256_sub_ps(va, vb); - __attribute__((aligned(32))) float result[8]; - _mm256_store_ps(result, sum); - - for(int i = 0; i < 8; ++i) { - squareDistance += result[i]; - } + // Multiply and accumulate + sum = _mm256_fmadd_ps(diff, diff, sum); } + squareDistance = reduce_add_256_ps(sum); + for (; ao < alim && bo < blim; ao++, bo++) { float diff = a[ao] - b[bo]; squareDistance += diff * diff; @@ -246,7 +218,7 @@ float euclidean_f32_256(const float* a, int aoffset, const float* b, int boffset return squareDistance; } -float euclidean_f32_512(const float* a, int aoffset, const float* b, int boffset, int length) { +JV_FINLINE float euclidean_f32_512(const float* a, int aoffset, const float* b, int boffset, int length) { float squareDistance = 0.0; int ao = aoffset; int bo = boffset; @@ -254,22 +226,20 @@ float euclidean_f32_512(const float* a, int aoffset, const float* b, int boffset int blim = boffset + length; int simd_length = length - (length % 16); - if (length >= 16) { - __m512 sum = _mm512_setzero_ps(); - for(; ao < aoffset + simd_length; ao += 16, bo += 16) { - // Load float32 - __m512 va = _mm512_loadu_ps(a + ao); - __m512 vb = _mm512_loadu_ps(b + bo); - __m512 diff = _mm512_sub_ps(va, vb); - - // Multiply and accumulate - sum = _mm512_fmadd_ps(diff, diff, sum); - } - - // Horizontal sum of the vector to get dot product - squareDistance = _mm512_reduce_add_ps(sum); + __m512 sum = _mm512_setzero_ps(); + for(; ao < aoffset + simd_length; ao += 16, bo += 16) { + // Load float32 + __m512 va = _mm512_loadu_ps(a + ao); + __m512 vb = _mm512_loadu_ps(b + bo); + __m512 diff = _mm512_sub_ps(va, vb); + + // Multiply and accumulate + sum = _mm512_fmadd_ps(diff, diff, sum); } + // Horizontal sum of the vector to get dot product + squareDistance = _mm512_reduce_add_ps(sum); + for (; ao < alim && bo < blim; ao++, bo++) { float diff = a[ao] - b[bo]; squareDistance += diff * diff; @@ -278,103 +248,28 @@ float euclidean_f32_512(const float* a, int aoffset, const float* b, int boffset return squareDistance; } -float euclidean_f32(int preferred_size, const float* a, int aoffset, const float* b, int boffset, int length) { +JV_INLINE float euclidean_f32(const float* a, int aoffset, const float* b, int boffset, int length) { if (length == 2) return euclidean_f32_64(a, aoffset, b, boffset); if (length <= 7) return euclidean_f32_128(a, aoffset, b, boffset, length); - return (preferred_size == 512 && length >= 16) + return (length >= 16) ? euclidean_f32_512(a, aoffset, b, boffset, length) : euclidean_f32_256(a, aoffset, b, boffset, length); } -float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) { - __m512 sum = _mm512_setzero_ps(); - int i = 0; - int limit = baseOffsetsLength - (baseOffsetsLength % 16); - __m512i indexRegister = initialIndexRegister; - __m512i dataBaseVec = _mm512_set1_epi32(dataBase); - baseOffsets = baseOffsets + baseOffsetsOffset; - - for (; i < limit; i += 16) { - __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i)); - __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw); - // we have base offsets int, which we need to scale to index into data. - // first, we want to initialize a vector with the lane number added as an index - indexRegister = _mm512_add_epi32(indexRegister, indexIncrement); - // then we want to multiply by dataBase - __m512i scale = _mm512_mullo_epi32(indexRegister, dataBaseVec); - // then we want to add the base offsets - __m512i convOffsets = _mm512_add_epi32(scale, baseOffsetsInt); - - __m512 partials = _mm512_i32gather_ps(convOffsets, data, 4); - sum = _mm512_add_ps(sum, partials); - } - - float res = _mm512_reduce_add_ps(sum); - for (; i < baseOffsetsLength; i++) { - res += data[dataBase * i + baseOffsets[i]]; - } - - return res; -} - -float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) { - __m512 sum = _mm512_setzero_ps(); - __m512 vaMagnitude = _mm512_setzero_ps(); - int i = 0; - int limit = baseOffsetsLength - (baseOffsetsLength % 16); - __m512i indexRegister = initialIndexRegister; - __m512i scale = _mm512_set1_epi32(clusterCount); - baseOffsets = baseOffsets + baseOffsetsOffset; - - - for (; i < limit; i += 16) { - // Load and convert baseOffsets to integers - __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i)); - __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw); - - indexRegister = _mm512_add_epi32(indexRegister, indexIncrement); - // Scale the baseOffsets by the cluster count - __m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale); - - // Calculate the final convOffsets by adding the scaled indexes and the base offsets - __m512i convOffsets = _mm512_add_epi32(scaledOffsets, baseOffsetsInt); - - // Gather and sum values for partial sums and a magnitude - __m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4); - sum = _mm512_add_ps(sum, partialSumVals); - - __m512 aMagnitudeVals = _mm512_i32gather_ps(convOffsets, aMagnitude, 4); - vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitudeVals); - } - - // Reduce sums - float sumResult = _mm512_reduce_add_ps(sum); - float aMagnitudeResult = _mm512_reduce_add_ps(vaMagnitude); - - // Handle the remaining elements - for (; i < baseOffsetsLength; i++) { - int offset = clusterCount * i + baseOffsets[i]; - sumResult += partialSums[offset]; - aMagnitudeResult += aMagnitude[offset]; - } - - return sumResult / sqrtf(aMagnitudeResult * bMagnitude); -} - -void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { +JV_INLINE void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { int codebookBase = codebookIndex * clusterCount; for (int i = 0; i < clusterCount; i++) { - partialSums[codebookBase + i] = dot_product_f32(512, codebook, i * size, query, queryOffset, size); + partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size); } } -void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { +JV_INLINE void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { int codebookBase = codebookIndex * clusterCount; for (int i = 0; i < clusterCount; i++) { - partialSums[codebookBase + i] = euclidean_f32(512, codebook, i * size, query, queryOffset, size); + partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size); } } @@ -395,7 +290,7 @@ void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codeboo */ -__attribute__((always_inline)) inline __m512i lookup_partial_sums(__m512i shuffle, const char* quantizedPartials, int i) { +JV_FINLINE __m512i lookup_partial_sums(__m512i shuffle, const char* quantizedPartials, int i) { __m512i partialsVecA = _mm512_loadu_epi16(quantizedPartials + i * 512); __m512i partialsVecB = _mm512_loadu_epi16(quantizedPartials + i * 512 + 64); __m512i partialsVecC = _mm512_loadu_epi16(quantizedPartials + i * 512 + 128); @@ -410,6 +305,8 @@ __attribute__((always_inline)) inline __m512i lookup_partial_sums(__m512i shuffl __m512i partialsVecEF = _mm512_permutex2var_epi16(partialsVecE, shuffle, partialsVecF); __m512i partialsVecGH = _mm512_permutex2var_epi16(partialsVecG, shuffle, partialsVecH); + const __m512i maskSeventhBit = _mm512_set1_epi16(0x0040); + const __m512i maskEighthBit = _mm512_set1_epi16(0x0080); __mmask32 maskSeven = _mm512_test_epi16_mask(shuffle, maskSeventhBit); __mmask32 maskEight = _mm512_test_epi16_mask(shuffle, maskEighthBit); __m512i partialsVecABCD = _mm512_mask_blend_epi16(maskSeven, partialsVecAB, partialsVecCD); @@ -420,7 +317,7 @@ __attribute__((always_inline)) inline __m512i lookup_partial_sums(__m512i shuffl } // dequantize a 256-bit vector containing 16 unsigned 16-bit integers into a 512-bit vector containing 16 32-bit floats -__attribute__((always_inline)) inline __m512 dequantize(__m256i quantizedVec, float delta, float base) { +JV_FINLINE __m512 dequantize(__m256i quantizedVec, float delta, float base) { __m512i quantizedVecWidened = _mm512_cvtepu16_epi32(quantizedVec); __m512 floatVec = _mm512_cvtepi32_ps(quantizedVecWidened); __m512 deltaVec = _mm512_set1_ps(delta); @@ -429,7 +326,7 @@ __attribute__((always_inline)) inline __m512 dequantize(__m256i quantizedVec, fl return dequantizedVec; } -void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results) { +JV_INLINE void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results) { __m512i sum = _mm512_setzero_epi32(); for (int i = 0; i < codebookCount; i++) { @@ -454,7 +351,7 @@ void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int _mm512_storeu_ps(results + 16, resultsRight); } -void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float best, float* results) { +JV_INLINE void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float best, float* results) { __m512i sum = _mm512_setzero_epi32(); for (int i = 0; i < codebookCount; i++) { @@ -478,7 +375,7 @@ void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb _mm512_storeu_ps(results + 16, resultsRight); } -void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results) { +JV_INLINE void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results) { __m512i sum = _mm512_setzero_epi32(); __m512i magnitude = _mm512_setzero_epi32(); @@ -520,11 +417,11 @@ void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int co } // Partial sum calculations that also record best distances, as this is necessary for Fused ADC quantization -void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances) { +JV_INLINE void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances) { float best = -INFINITY; int codebookBase = codebookIndex * clusterCount; for (int i = 0; i < clusterCount; i++) { - float val = dot_product_f32(512, codebook, i * size, query, queryOffset, size); + float val = dot_product_f32(codebook, i * size, query, queryOffset, size); partialSums[codebookBase + i] = val; if (val > best) { best = val; @@ -533,15 +430,112 @@ void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebook partialBestDistances[codebookIndex] = best; } -void calculate_partial_sums_best_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances) { +JV_INLINE void calculate_partial_sums_best_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances) { float best = INFINITY; int codebookBase = codebookIndex * clusterCount; for (int i = 0; i < clusterCount; i++) { - float val = euclidean_f32(512, codebook, i * size, query, queryOffset, size); + float val = euclidean_f32(codebook, i * size, query, queryOffset, size); partialSums[codebookBase + i] = val; if (val < best) { best = val; } } partialBestDistances[codebookIndex] = best; -} \ No newline at end of file +} + +/* List API's exposed to JAVA via FFI here: Do not mark them static or online, + * as they need to be visible to the dynamic linker and we may want to + * benchmark them individually in C. + */ + +float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength) { + __m512 sum = _mm512_setzero_ps(); + int i = 0; + int limit = baseOffsetsLength - (baseOffsetsLength % 16); + const __m512i initialIndexRegister = _mm512_setr_epi32(-16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1); + const __m512i indexIncrement = _mm512_set1_epi32(16); + __m512i indexRegister = initialIndexRegister; + __m512i dataBaseVec = _mm512_set1_epi32(dataBase); + baseOffsets = baseOffsets + baseOffsetsOffset; + + for (; i < limit; i += 16) { + __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i)); + __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw); + // we have base offsets int, which we need to scale to index into data. + // first, we want to initialize a vector with the lane number added as an index + indexRegister = _mm512_add_epi32(indexRegister, indexIncrement); + // then we want to multiply by dataBase + __m512i scale = _mm512_mullo_epi32(indexRegister, dataBaseVec); + // then we want to add the base offsets + __m512i convOffsets = _mm512_add_epi32(scale, baseOffsetsInt); + + __m512 partials = _mm512_i32gather_ps(convOffsets, data, 4); + sum = _mm512_add_ps(sum, partials); + } + + float res = _mm512_reduce_add_ps(sum); + for (; i < baseOffsetsLength; i++) { + res += data[dataBase * i + baseOffsets[i]]; + } + + return res; +} + +float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) { + __m512 sum = _mm512_setzero_ps(); + __m512 vaMagnitude = _mm512_setzero_ps(); + int i = 0; + int limit = baseOffsetsLength - (baseOffsetsLength % 16); + const __m512i initialIndexRegister = _mm512_setr_epi32(-16, -15, -14, -13, -12, -11, -10, -9, -8, -7, -6, -5, -4, -3, -2, -1); + const __m512i indexIncrement = _mm512_set1_epi32(16); + __m512i indexRegister = initialIndexRegister; + __m512i scale = _mm512_set1_epi32(clusterCount); + baseOffsets = baseOffsets + baseOffsetsOffset; + + + for (; i < limit; i += 16) { + // Load and convert baseOffsets to integers + __m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i)); + __m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw); + + indexRegister = _mm512_add_epi32(indexRegister, indexIncrement); + // Scale the baseOffsets by the cluster count + __m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale); + + // Calculate the final convOffsets by adding the scaled indexes and the base offsets + __m512i convOffsets = _mm512_add_epi32(scaledOffsets, baseOffsetsInt); + + // Gather and sum values for partial sums and a magnitude + __m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4); + sum = _mm512_add_ps(sum, partialSumVals); + + __m512 aMagnitudeVals = _mm512_i32gather_ps(convOffsets, aMagnitude, 4); + vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitudeVals); + } + + // Reduce sums + float sumResult = _mm512_reduce_add_ps(sum); + float aMagnitudeResult = _mm512_reduce_add_ps(vaMagnitude); + + // Handle the remaining elements + for (; i < baseOffsetsLength; i++) { + int offset = clusterCount * i + baseOffsets[i]; + sumResult += partialSums[offset]; + aMagnitudeResult += aMagnitude[offset]; + } + + return sumResult / sqrtf(aMagnitudeResult * bMagnitude); +} + +void calculate_partial_sums_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, int similarityFunction, float* partialSums) { + switch (similarityFunction) { + case 0: + calculate_partial_sums_euclidean_f32_512(codebook, codebookIndex, size, clusterCount, query, queryOffset, partialSums); + break; + case 1: + calculate_partial_sums_dot_f32_512(codebook, codebookIndex, size, clusterCount, query, queryOffset, partialSums); + break; + default: + break; + } +} diff --git a/jvector-native/src/main/c/jvector_simd.h b/jvector-native/src/main/c/jvector_simd.h index 55f1a46c1..f39e26f88 100644 --- a/jvector-native/src/main/c/jvector_simd.h +++ b/jvector-native/src/main/c/jvector_simd.h @@ -19,19 +19,13 @@ #ifndef VECTOR_SIMD_DOT_H #define VECTOR_SIMD_DOT_H +#define JV_INLINE static inline +#define JV_FINLINE static inline __attribute__((always_inline)) // check CPU support bool check_compatibility(void); -//F32 -float dot_product_f32(int preferred_size, const float* a, int aoffset, const float* b, int boffset, int length); -float euclidean_f32(int preferred_size, const float* a, int aoffset, const float* b, int boffset, int length); -void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results); -void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results); -void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results); +// APIs exposed to Java via FFI float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength); float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsOffset, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude); -void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); -void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums); -void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances); -void calculate_partial_sums_best_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances); -#endif \ No newline at end of file +void calculate_partial_sums_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, int similarityFunction, float* partialSums); +#endif From d07337165e3f4e2b3aad4631cbe9c07358da34dd Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 7 Apr 2026 04:31:43 +0000 Subject: [PATCH 2/4] Wire calculatePartialSums to native SIMD via Panama FFI downcall --- .../vector/NativeVectorUtilSupport.java | 13 ++++++++ .../jvector/vector/cnative/NativeSimdOps.java | 32 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java index 48cd7d66e..bf65b181a 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java @@ -104,4 +104,17 @@ public float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffse // encoded is a pointer into a PQ chunk - we need to index into it by encodedOffset and provide encodedLength to the native code return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encodedOffset, encodedLength, clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude); } + + @Override + public void calculatePartialSums(VectorFloat codebook, int codebookIndex, int size, int clusterCount, VectorFloat query, int queryOffset, VectorSimilarityFunction vsf, VectorFloat partialSums) { + var nativeCodebook = ((MemorySegmentVectorFloat) codebook).get(); + var nativeQuery = ((MemorySegmentVectorFloat) query).get(); + var nativePartialSums = ((MemorySegmentVectorFloat) partialSums).get(); + int similarityFunction = switch (vsf) { + case EUCLIDEAN -> 0; + case DOT_PRODUCT -> 1; + default -> throw new UnsupportedOperationException("Unsupported similarity function " + vsf); + }; + NativeSimdOps.calculate_partial_sums_f32_512(nativeCodebook, codebookIndex, size, clusterCount, nativeQuery, queryOffset, similarityFunction, nativePartialSums); + } } diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java index 014bdf4b0..5b31c7e02 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/cnative/NativeSimdOps.java @@ -847,4 +847,36 @@ public static void calculate_partial_sums_best_euclidean_f32_512(MemorySegment c throw new AssertionError("should not reach here", ex$); } } + + private static class calculate_partial_sums_f32_512 { + public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid( + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER, + NativeSimdOps.C_INT, + NativeSimdOps.C_INT, + NativeSimdOps.C_POINTER + ); + public static final MemorySegment ADDR = NativeSimdOps.findOrThrow("calculate_partial_sums_f32_512"); + public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(ADDR, DESC, Linker.Option.critical(true)); + } + + /** + * {@snippet lang=c : + * void calculate_partial_sums_f32_512(const float *codebook, int codebookIndex, int size, int clusterCount, const float *query, int queryOffset, int similarityFunction, float *partialSums) + * } + */ + public static void calculate_partial_sums_f32_512(MemorySegment codebook, int codebookIndex, int size, int clusterCount, MemorySegment query, int queryOffset, int similarityFunction, MemorySegment partialSums) { + var mh$ = calculate_partial_sums_f32_512.HANDLE; + try { + if (TRACE_DOWNCALLS) { + traceDowncall("calculate_partial_sums_f32_512", codebook, codebookIndex, size, clusterCount, query, queryOffset, similarityFunction, partialSums); + } + mh$.invokeExact(codebook, codebookIndex, size, clusterCount, query, queryOffset, similarityFunction, partialSums); + } catch (Throwable ex$) { + throw new AssertionError("should not reach here", ex$); + } + } } \ No newline at end of file From de4ff79bbea3023797256d0cd78a239c2eb2d051 Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Tue, 7 Apr 2026 04:29:34 +0000 Subject: [PATCH 3/4] perf: Optimize PQ distance look up table calculation for subvector sizes 4,8 & 16 on AVX-512 Add SIMD fast paths in calculate_partial_sums_dot_f32_512 and calculate_partial_sums_euclidean_f32_512 for the two most common PQ subvector sizes: - size == 4: broadcast a 128-bit query fragment across all four 128-bit lanes of a ZMM register, load four consecutive centroids at once, and reduce each lane independently using two shuffle+add pairs. Produces 4 partial sums per loop iteration instead of 1. - size == 8: broadcast a 256-bit query fragment across both 256-bit halves of a ZMM register, load two consecutive centroids at once, and reduce across 128-bit lanes followed by within-lane shuffles. Produces 2 partial sums per loop iteration instead of 1. - size == 16: query and the centroid fit into a ZMM register, load the query into zmm and then loop over the centroids. Produces one partial sum per loop iteration, but prevents having to load the query multiple times. Both paths fall back to the default way of computing dot_product_f32 / euclidean_f32 in a loop for any tail elements or unsupported sizes. --- jvector-native/src/main/c/jvector_simd.c | 151 ++++++++++++++++++++++- 1 file changed, 147 insertions(+), 4 deletions(-) diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index faaba3393..05aa36a53 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -261,15 +261,158 @@ JV_INLINE float euclidean_f32(const float* a, int aoffset, const float* b, int b JV_INLINE void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { int codebookBase = codebookIndex * clusterCount; - for (int i = 0; i < clusterCount; i++) { - partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size); + float tempdat[16]; + if (size == 4) { + int i = 0; + // use a zmm register to calculate 4 partial sums in parallel: + __m128 q = _mm_loadu_ps(query + queryOffset); + __m512 qq = _mm512_broadcast_f32x4(q); // broadcast 128-bit query to all 4 lanes + for (; i + 4 <= clusterCount; i += 4) { + // load four consecutive centroids from the codebook into zmm + __m512 c = _mm512_loadu_ps(codebook + i * size); + __m512 sum = _mm512_fmadd_ps(c, qq, _mm512_setzero_ps()); + // horizontal reduce: within each 128-bit lane independently + // Step 1: swap neighboring elements within 128-bit lanes + __m512 temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); + sum = _mm512_add_ps(sum, temp); + // Step 2: swap 32-bit pairs within 128-bit lanes + temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 0, 3, 2)); + sum = _mm512_add_ps(sum, temp); + // extract results from position 0 of each 128-bit lane + _mm512_storeu_ps(tempdat, sum); + partialSums[codebookBase + i] = tempdat[0]; + partialSums[codebookBase + i + 1] = tempdat[4]; + partialSums[codebookBase + i + 2] = tempdat[8]; + partialSums[codebookBase + i + 3] = tempdat[12]; + } + for (; i < clusterCount; i++) { + partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size); + } + } + else if (size == 8) { + int i = 0; + // use a zmm register to calculate 2 partial sums in parallel: + __m256 q = _mm256_loadu_ps(query + queryOffset); + __m512 qq = _mm512_broadcast_f32x8(q); // 8 cycles, but have to do it just once outside the loop + for (; i + 2 <= clusterCount; i += 2) { + // load two consecutive centroids from the codebook into zmm + __m512 c1 = _mm512_loadu_ps(codebook + i * size); + __m512 sum = _mm512_fmadd_ps(c1, qq, _mm512_setzero_ps()); + // horizontal reduce: per 256 bit lanes + // Step 1: swap neighbouring 128 bits and add to sum across lanes + __m512 temp = _mm512_shuffle_f32x4(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); // swap 128-bit lanes + sum = _mm512_add_ps(sum, temp); + // Step 2: Shuffle and add to sum within lanes + temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 0, 3, 2)); + sum = _mm512_add_ps(sum, temp); + // step 3: shuffle neighboring lanes: + temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); + sum = _mm512_add_ps(sum, temp); + // extract results: may be there is a better way? + // Store is cheap and loading them should happen from the store buffers, so this may be faster than shuffling and extracting: + // Although its tempting, avoid using vcompress (a high latency instruction) + //_mm512_mask_compressstoreu_ps(ans, 0x8080, sum); + _mm512_storeu_ps(tempdat, sum); + partialSums[codebookBase + i] = tempdat[0]; + partialSums[codebookBase + i + 1] = tempdat[8]; + } + for (; i < clusterCount; i++) { + partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size); + } + } + else if (size == 16) { + int i = 0; + __m512 qq = _mm512_loadu_ps(query + queryOffset); + for (; i < clusterCount; i += 1) { + __m512 c1 = _mm512_loadu_ps(codebook + i * size); + __m512 sum = _mm512_fmadd_ps(qq, c1, _mm512_setzero_ps()); + partialSums[codebookBase + i] = _mm512_reduce_add_ps(sum); + } + } + else { + for (int i = 0; i < clusterCount; i++) { + partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size); + } } } JV_INLINE void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { int codebookBase = codebookIndex * clusterCount; - for (int i = 0; i < clusterCount; i++) { - partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size); + float tempdat[16]; + if (size == 4) { + int i = 0; + // use a zmm register to calculate 4 partial sums in parallel: + __m128 q = _mm_loadu_ps(query + queryOffset); + __m512 qq = _mm512_broadcast_f32x4(q); // broadcast 128-bit query to all 4 lanes + for (; i + 4 <= clusterCount; i += 4) { + // load four consecutive centroids from the codebook into zmm + __m512 c = _mm512_loadu_ps(codebook + i * size); + __m512 diff = _mm512_sub_ps(c, qq); + __m512 sum = _mm512_fmadd_ps(diff, diff, _mm512_setzero_ps()); + // horizontal reduce: within each 128-bit lane independently + // Step 1: swap neighboring elements within 128-bit lanes + __m512 temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); + sum = _mm512_add_ps(sum, temp); + // Step 2: swap 32-bit pairs within 128-bit lanes + temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 0, 3, 2)); + sum = _mm512_add_ps(sum, temp); + // extract results from position 0 of each 128-bit lane + _mm512_storeu_ps(tempdat, sum); + partialSums[codebookBase + i] = tempdat[0]; + partialSums[codebookBase + i + 1] = tempdat[4]; + partialSums[codebookBase + i + 2] = tempdat[8]; + partialSums[codebookBase + i + 3] = tempdat[12]; + } + for (; i < clusterCount; i++) { + partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size); + } + } + else if (size == 8) { + int i = 0; + // use a zmm register to calculate 2 partial sums in parallel: + __m256 q = _mm256_loadu_ps(query + queryOffset); + __m512 qq = _mm512_broadcast_f32x8(q); // 8 cycles, but have to do it just once outside the loop + for (; i + 2 <= clusterCount; i += 2) { + // load two consecutive centroids from the codebook into zmm + __m512 c1 = _mm512_loadu_ps(codebook + i * size); + __m512 diff = _mm512_sub_ps(c1, qq); + __m512 sum = _mm512_fmadd_ps(diff, diff, _mm512_setzero_ps()); + // horizontal reduce: per 256 bit lanes + // Step 1: swap neighbouring 128 bits and add to sum across lanes + __m512 temp = _mm512_shuffle_f32x4(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); // swap 128-bit lanes + sum = _mm512_add_ps(sum, temp); + // Step 2: Shuffle and add to sum within lanes + temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(1, 0, 3, 2)); + sum = _mm512_add_ps(sum, temp); + // step 3: shuffle neighboring lanes: + temp = _mm512_shuffle_ps(sum, sum, _MM_SHUFFLE(2, 3, 0, 1)); + sum = _mm512_add_ps(sum, temp); + // extract results: may be there is a better way? + // Store is cheap and loading them should happen from the store buffers, so this may be faster than shuffling and extracting: + // Although its tempting, avoid using vcompress (a high latency instruction) + //_mm512_mask_compressstoreu_ps(ans, 0x8080, sum); + _mm512_storeu_ps(tempdat, sum); + partialSums[codebookBase + i] = tempdat[0]; + partialSums[codebookBase + i + 1] = tempdat[8]; + } + for (; i < clusterCount; i++) { + partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size); + } + } + else if (size == 16) { + int i = 0; + __m512 qq = _mm512_loadu_ps(query + queryOffset); + for (; i < clusterCount; i += 1) { + __m512 c1 = _mm512_loadu_ps(codebook + i * size); + __m512 diff = _mm512_sub_ps(c1, qq); + __m512 sum = _mm512_fmadd_ps(diff, diff, _mm512_setzero_ps()); + partialSums[codebookBase + i] = _mm512_reduce_add_ps(sum); + } + } + else { + for (int i = 0; i < clusterCount; i++) { + partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size); + } } } From 159c21f50dc19d90bcc56b74a96fdfa8cbb6642d Mon Sep 17 00:00:00 2001 From: Raghuveer Devulapalli Date: Wed, 8 Apr 2026 05:54:27 +0000 Subject: [PATCH 4/4] Optimize size == 2 --- jvector-native/src/main/c/jvector_simd.c | 62 +++++++++++++++++++++++- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/jvector-native/src/main/c/jvector_simd.c b/jvector-native/src/main/c/jvector_simd.c index 05aa36a53..a46fab62e 100644 --- a/jvector-native/src/main/c/jvector_simd.c +++ b/jvector-native/src/main/c/jvector_simd.c @@ -262,7 +262,36 @@ JV_INLINE float euclidean_f32(const float* a, int aoffset, const float* b, int b JV_INLINE void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { int codebookBase = codebookIndex * clusterCount; float tempdat[16]; - if (size == 4) { + if (size == 2) { + int i = 0; + // use a zmm register to calculate 8 partial sums in parallel: + __m128 q_lo = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(query + queryOffset))); + __m512 qq = _mm512_broadcast_f32x2(q_lo); // broadcast 2 query floats to all 8 x 64-bit positions + for (; i + 8 <= clusterCount; i += 8) { + // load eight consecutive centroids (16 floats) from the codebook into zmm + __m512 c = _mm512_loadu_ps(codebook + i * size); + __m512 prod = _mm512_mul_ps(c, qq); + // horizontal reduce: sum the two products within each 64-bit centroid slot + // shuffle swaps pairs within each 128-bit lane: [a,b,c,d] -> [b,a,d,c] + __m512 temp = _mm512_shuffle_ps(prod, prod, _MM_SHUFFLE(2, 3, 0, 1)); + __m512 sum = _mm512_add_ps(prod, temp); + // results sit at even positions (0,2,4,6,8,10,12,14) + // resgular store and load seem to be better tha vcompress or vpermutex2var for extracting the results + _mm512_storeu_ps(tempdat, sum); + partialSums[codebookBase + i] = tempdat[0]; + partialSums[codebookBase + i + 1] = tempdat[2]; + partialSums[codebookBase + i + 2] = tempdat[4]; + partialSums[codebookBase + i + 3] = tempdat[6]; + partialSums[codebookBase + i + 4] = tempdat[8]; + partialSums[codebookBase + i + 5] = tempdat[10]; + partialSums[codebookBase + i + 6] = tempdat[12]; + partialSums[codebookBase + i + 7] = tempdat[14]; + } + for (; i < clusterCount; i++) { + partialSums[codebookBase + i] = dot_product_f32(codebook, i * size, query, queryOffset, size); + } + } + else if (size == 4) { int i = 0; // use a zmm register to calculate 4 partial sums in parallel: __m128 q = _mm_loadu_ps(query + queryOffset); @@ -339,7 +368,36 @@ JV_INLINE void calculate_partial_sums_dot_f32_512(const float* codebook, int cod JV_INLINE void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) { int codebookBase = codebookIndex * clusterCount; float tempdat[16]; - if (size == 4) { + if (size == 2) { + int i = 0; + // use a zmm register to calculate 8 partial sums in parallel: + __m128 q_lo = _mm_castsi128_ps(_mm_loadl_epi64((__m128i *)(query + queryOffset))); + __m512 qq = _mm512_broadcast_f32x2(q_lo); // broadcast 2 query floats to all 8 x 64-bit positions + for (; i + 8 <= clusterCount; i += 8) { + // load eight consecutive centroids (16 floats) from the codebook into zmm + __m512 c = _mm512_loadu_ps(codebook + i * size); + __m512 diff = _mm512_sub_ps(c, qq); + __m512 sq = _mm512_mul_ps(diff, diff); + // horizontal reduce: sum the two squared diffs within each 64-bit centroid slot + // shuffle swaps pairs within each 128-bit lane: [a,b,c,d] -> [b,a,d,c] + __m512 temp = _mm512_shuffle_ps(sq, sq, _MM_SHUFFLE(2, 3, 0, 1)); + __m512 sum = _mm512_add_ps(sq, temp); + // results sit at even positions (0,2,4,6,8,10,12,14) + _mm512_storeu_ps(tempdat, sum); + partialSums[codebookBase + i] = tempdat[0]; + partialSums[codebookBase + i + 1] = tempdat[2]; + partialSums[codebookBase + i + 2] = tempdat[4]; + partialSums[codebookBase + i + 3] = tempdat[6]; + partialSums[codebookBase + i + 4] = tempdat[8]; + partialSums[codebookBase + i + 5] = tempdat[10]; + partialSums[codebookBase + i + 6] = tempdat[12]; + partialSums[codebookBase + i + 7] = tempdat[14]; + } + for (; i < clusterCount; i++) { + partialSums[codebookBase + i] = euclidean_f32(codebook, i * size, query, queryOffset, size); + } + } + else if (size == 4) { int i = 0; // use a zmm register to calculate 4 partial sums in parallel: __m128 q = _mm_loadu_ps(query + queryOffset);