Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -437,14 +437,15 @@ public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvi
}

public ImmutableGraphIndex build(RandomAccessVectorValues ravv) {
var vv = ravv.threadLocalSupplier();
int size = ravv.size();

simdExecutor.submit(() -> {
IntStream.range(0, size).parallel().forEach(node -> {
addGraphNode(node, vv.get().getVector(node));
});
}).join();
try (var vv = ravv.closeableThreadLocalSupplier()) {
simdExecutor.submit(() -> {
IntStream.range(0, size).parallel().forEach(node -> {
addGraphNode(node, vv.get().getVector(node));
});
}).join();
}

cleanup();
return graph;
Expand Down Expand Up @@ -841,11 +842,7 @@ private NodeArray getConcurrentCandidates(int level,

/**
 * Releases the resources held by this builder: {@code searchers}, {@code scoreProvider},
 * {@code naturalScratch} and {@code concurrentScratch}.
 * <p>
 * All four are handed to {@code ExceptionUtils.closeAll} in a single call. Closing
 * {@code searchers} separately beforehand would close it twice, and a failure rethrown from that
 * first close would prevent the remaining resources from being closed at all.
 *
 * @throws IOException declared by this method; presumably raised when a resource fails to close
 */
@Override
public void close() throws IOException {
    // NOTE(review): presumably closeAll attempts every argument even if an earlier one fails — confirm in ExceptionUtils.
    ExceptionUtils.closeAll(searchers, scoreProvider, naturalScratch, concurrentScratch);
}

private static class ExcludingBits implements Bits {
Expand Down Expand Up @@ -1044,15 +1041,15 @@ public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in,
parallelExecutor
);

var vv = newVectors.threadLocalSupplier();

// parallel graph construction from the merge documents Ids
simdExecutor.submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> {
builder.addGraphNode(ord, vv.get().getVector(ord));
})).join();
try (var vv = newVectors.closeableThreadLocalSupplier()) {
simdExecutor.submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> {
builder.addGraphNode(ord, vv.get().getVector(ord));
})).join();
}

builder.cleanup();
return builder.getGraph();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,18 @@
package io.github.jbellis.jvector.graph;

import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.util.CloseableSupplier;
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;

import java.util.function.Supplier;
import java.util.logging.Logger;

/**
* Provides random access to vectors by dense ordinal. This interface is used by graph-based
* implementations of KNN search.
*/
public interface RandomAccessVectorValues {
Logger LOG = Logger.getLogger(RandomAccessVectorValues.class.getName());

/**
* Return the number of vector values.
* <p>
Expand Down Expand Up @@ -94,18 +92,24 @@ default void getVectorInto(int node, VectorFloat<?> destinationVector, int offse
RandomAccessVectorValues copy();

/**
* Returns a supplier of thread-local copies of the RAVV.
* Returns a closeable supplier of thread-local copies of the RAVV.
*/
default Supplier<RandomAccessVectorValues> threadLocalSupplier() {
default CloseableSupplier<RandomAccessVectorValues> closeableThreadLocalSupplier() {
if (!isValueShared()) {
return () -> this;
return CloseableSupplier.noOp(() -> this);
}

if (this instanceof AutoCloseable) {
LOG.warning("RAVV is shared and implements AutoCloseable; threadLocalSupplier() may lead to leaks");
}
var tl = ExplicitThreadLocal.withInitial(this::copy);
return tl::get;
return ExplicitThreadLocal.withInitial(this::copy);
}

/**
 * Returns a supplier of thread-local copies of the RAVV.
 * <p>
 * The returned supplier implements {@link AutoCloseable}; callers that own the supplier's lifetime
 * should prefer {@link #closeableThreadLocalSupplier()} so thread-local copies can be cleaned up.
 *
 * @return the object produced by {@link #closeableThreadLocalSupplier()}, typed as a plain {@link Supplier}
 */
default Supplier<RandomAccessVectorValues> threadLocalSupplier() {
return closeableThreadLocalSupplier();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@

import java.util.Arrays;

public class RemappedRandomAccessVectorValues implements RandomAccessVectorValues {
public class RemappedRandomAccessVectorValues implements RandomAccessVectorValues, AutoCloseable {
private final RandomAccessVectorValues ravv;
private final int[] graphToRavvOrdMap;
private final boolean ownsRavv;

/**
* Remaps a RAVV to a different set of ordinals. This is useful when the ordinals used by the graph
Expand All @@ -33,8 +34,13 @@ public class RemappedRandomAccessVectorValues implements RandomAccessVectorValue
* graphToRavvOrdMap[i] is the RAVV ordinal corresponding to graph ordinal i.
*/
public RemappedRandomAccessVectorValues(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap) {
// ownsRavv is false here: the caller keeps ownership of ravv, so close() on this wrapper leaves it open
this(ravv, graphToRavvOrdMap, false);
}

/**
 * Internal constructor that additionally records whether this instance owns {@code ravv}.
 *
 * @param ravv the vectors being remapped
 * @param graphToRavvOrdMap graphToRavvOrdMap[i] is the RAVV ordinal corresponding to graph ordinal i
 * @param ownsRavv when true, {@link #close()} also closes {@code ravv} if it is {@link AutoCloseable}
 */
private RemappedRandomAccessVectorValues(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, boolean ownsRavv) {
this.ravv = ravv;
this.graphToRavvOrdMap = graphToRavvOrdMap;
this.ownsRavv = ownsRavv;
}

@Override
Expand All @@ -59,11 +65,19 @@ public boolean isValueShared() {

/**
 * Creates a copy of this remapping backed by {@code ravv.copy()} and a duplicate of the ordinal map.
 * <p>
 * The copy is marked as owning its RAVV only when {@code ravv.copy()} returned a different object,
 * so closing the copy never closes a RAVV that this instance is still using.
 *
 * @return a new {@code RemappedRandomAccessVectorValues} over the copied RAVV
 */
@Override
public RandomAccessVectorValues copy() {
    var ravvCopy = ravv.copy();
    var mapCopy = Arrays.copyOf(graphToRavvOrdMap, graphToRavvOrdMap.length);
    // identity comparison on purpose: copy() may hand back the very same instance
    boolean copyOwnsRavv = ravvCopy != ravv;
    return new RemappedRandomAccessVectorValues(ravvCopy, mapCopy, copyOwnsRavv);
}

/**
 * Copies the vector for graph ordinal {@code node} into {@code result} starting at {@code offset},
 * translating the graph ordinal to the wrapped RAVV's ordinal first.
 */
@Override
public void getVectorInto(int node, VectorFloat<?> result, int offset) {
    final RandomAccessVectorValues source = ravv;
    final int ravvOrdinal = graphToRavvOrdMap[node];
    source.getVectorInto(ravvOrdinal, result, offset);
}

/**
 * Closes the wrapped RAVV, but only when this instance owns it and it is {@link AutoCloseable};
 * otherwise this is a no-op.
 *
 * @throws Exception whatever the wrapped RAVV's {@code close()} throws
 */
@Override
public void close() throws Exception {
    if (!ownsRavv) {
        // the RAVV belongs to someone else; leave it open
        return;
    }
    if (!(ravv instanceof AutoCloseable)) {
        return;
    }
    AutoCloseable closeable = (AutoCloseable) ravv;
    closeable.close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@
/**
* Encapsulates comparing node distances for GraphIndexBuilder.
*/
public interface BuildScoreProvider {
public interface BuildScoreProvider extends AutoCloseable {
VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();

/**
 * Releases any resources held by this provider. The default implementation does nothing.
 * <p>
 * Declared without a {@code throws} clause, so closing a provider never requires handling a
 * checked exception.
 */
@Override
default void close() {
}

/**
* @return true if the primary score functions used for construction are exact. This
* is modestly redundant, but it saves having to allocate new Search/Diversity provider
Expand Down Expand Up @@ -106,8 +110,8 @@ static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues rav
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) {
// We need two sources of vectors in order to perform diversity check comparisons without
// colliding. ThreadLocalSupplier makes this a no-op if the RAVV is actually un-shared.
var vectors = ravv.threadLocalSupplier();
var vectorsCopy = ravv.threadLocalSupplier();
var vectors = ravv.closeableThreadLocalSupplier();
var vectorsCopy = ravv.closeableThreadLocalSupplier();

return new BuildScoreProvider() {
@Override
Expand Down Expand Up @@ -157,6 +161,15 @@ public ScoreFunction diversityScoreFunctionFor(int node1) {
// don't use ESF.reranker, we need thread safety here
return (ScoreFunction.ExactScoreFunction) node2 -> similarityFunction.compare(v, vc.getVector(node2));
}

/**
 * Closes both thread-local vector suppliers, {@code vectors} first and then {@code vectorsCopy}.
 * <p>
 * {@code vectorsCopy} is closed even if closing {@code vectors} fails. If both fail, the failure
 * from {@code vectors} is the one thrown and the failure from {@code vectorsCopy} is attached to
 * it as a suppressed exception, rather than replacing it.
 */
@Override
public void close() {
    // try-with-resources closes in reverse order of listing, so vectors is closed before vectorsCopy
    try (vectorsCopy; vectors) {
        // intentionally empty: the statement is used only for its close-time semantics
    }
}
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,20 @@ public CompressedVectors createCompressedVectors(Object[] compressedVectors) {

/**
 * Encodes every vector of {@code ravv} with a parallel stream submitted to {@code simdExecutor}.
 * <p>
 * A {@code null} vector is encoded as a zero-filled array of
 * {@code compressedVectorSize() / Long.BYTES} longs. The RAVV is read through a closeable
 * thread-local supplier that is closed as soon as encoding completes, and the encoding pass runs
 * exactly once.
 *
 * @param ravv the vectors to encode
 * @param simdExecutor the pool the encoding task is submitted to
 * @return the encoded vectors wrapped in an {@code ImmutableBQVectors}
 */
@Override
public CompressedVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
    final long[][] cv;
    try (var ravvCopy = ravv.closeableThreadLocalSupplier()) {
        cv = simdExecutor.submit(() -> IntStream.range(0, ravv.size())
                        .parallel()
                        .mapToObj(i -> {
                            var localRavv = ravvCopy.get();
                            VectorFloat<?> v = localRavv.getVector(i);
                            return v == null
                                    ? new long[compressedVectorSize() / Long.BYTES]
                                    : encode(v);
                        })
                        .toArray(long[][]::new))
                .join();
    }
    return new ImmutableBQVectors(this, cv);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,16 @@ private NVQuantization(int[][] subvectorSizesAndOffsets, VectorFloat<?> globalMe
* @param nSubVectors number of subvectors
*/
public static NVQuantization compute(RandomAccessVectorValues ravv, int nSubVectors) {
    // The supplier is closed when the try block exits, so it is released on every path
    // (including an exception while reading vectors), and the mean is accumulated only once.
    try (var ravvSupplier = ravv.closeableThreadLocalSupplier()) {
        var ravvCopy = ravvSupplier.get();
        // the dimension is taken from the first vector
        var dim = ravvCopy.getVector(0).length();
        var globalMean = vectorTypeSupport.createFloatVector(dim);
        // sum all vectors, then divide by the count to obtain the global mean
        for (int i = 0; i < ravvCopy.size(); i++) {
            VectorUtil.addInPlace(globalMean, ravvCopy.getVector(i));
        }
        VectorUtil.scale(globalMean, 1.0f / ravvCopy.size());
        return create(globalMean, nSubVectors);
    }
}

/**
Expand All @@ -180,17 +182,19 @@ public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
*/
@Override
public NVQVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) {
    final QuantizedVector[] vectors;
    // The RAVV is read through a closeable thread-local supplier; it is closed once the parallel
    // encoding task has been joined, before the result object is built.
    try (var ravvCopy = ravv.closeableThreadLocalSupplier()) {
        vectors = parallelExecutor.submit(() -> IntStream.range(0, ravv.size())
                        .parallel()
                        .mapToObj(i -> {
                            var localRavv = ravvCopy.get();
                            VectorFloat<?> v = localRavv.getVector(i);
                            return encode(v);
                        })
                        .toArray(QuantizedVector[]::new))
                .join();
    }
    return new NVQVectors(this, vectors);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,21 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect
// Encode the vectors in parallel into the compressed data chunks
// The changes are concurrent, but because they are coordinated and do not overlap, we can use parallel streams
// and then we are guaranteed safe publication because we join the thread after completion.
var ravvCopy = ravv.threadLocalSupplier();
simdExecutor.submit(() -> IntStream.range(0, vectorCount)
.parallel()
.forEach(ordinal -> {
// Retrieve the slice and mutate it.
var localRavv = ravvCopy.get();
var slice = PQVectors.get(chunks, ordinal, layout.fullChunkVectors, pq.getSubspaceCount());
var vector = localRavv.getVector(ordinalsMapping.applyAsInt(ordinal));
if (vector != null)
pq.encodeTo(vector, slice);
else
slice.zero();
}))
.join();
try (var ravvCopy = ravv.closeableThreadLocalSupplier()) {
simdExecutor.submit(() -> IntStream.range(0, vectorCount)
.parallel()
.forEach(ordinal -> {
// Retrieve the slice and mutate it.
var localRavv = ravvCopy.get();
var slice = PQVectors.get(chunks, ordinal, layout.fullChunkVectors, pq.getSubspaceCount());
var vector = localRavv.getVector(ordinalsMapping.applyAsInt(ordinal));
if (vector != null)
pq.encodeTo(vector, slice);
else
slice.zero();
}))
.join();
}

return new ImmutablePQVectors(pq, chunks, vectorCount, layout.fullChunkVectors);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,16 @@ static List<VectorFloat<?>> extractTrainingVectors(RandomAccessVectorValues ravv
ordinalStream = IntStream.of(ordinalArray);
}

var ravvCopy = ravv.threadLocalSupplier();
return parallelExecutor.submit(() -> ordinalStream.parallel()
.mapToObj(targetOrd -> {
var localRavv = ravvCopy.get();
VectorFloat<?> v = localRavv.getVector(targetOrd);
return localRavv.isValueShared() ? v.copy() : v;
})
.collect(Collectors.toList()))
.join();
try (var ravvCopy = ravv.closeableThreadLocalSupplier()) {
return parallelExecutor.submit(() -> ordinalStream.parallel()
.mapToObj(targetOrd -> {
var localRavv = ravvCopy.get();
VectorFloat<?> v = localRavv.getVector(targetOrd);
return localRavv.isValueShared() ? v.copy() : v;
})
.collect(Collectors.toList()))
.join();
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.github.jbellis.jvector.util;

import java.util.Objects;
import java.util.function.Supplier;

/**
 * A {@link Supplier} that owns resources associated with values it creates.
 * <p>
 * {@link #close()} is redeclared without a {@code throws} clause, so instances can be used in a
 * try-with-resources statement without handling checked exceptions.
 *
 * @param <T> the type of value supplied
 */
public interface CloseableSupplier<T> extends Supplier<T>, AutoCloseable {
    /**
     * Releases the resources owned by this supplier.
     */
    @Override
    void close();

    /**
     * Wraps a plain supplier as a {@code CloseableSupplier} whose {@link #close()} does nothing.
     *
     * @param supplier the supplier that {@link #get()} delegates to; must not be {@code null}
     * @param <T> the type of value supplied
     * @return a closeable view of {@code supplier} that owns no resources
     * @throws NullPointerException if {@code supplier} is {@code null}
     */
    static <T> CloseableSupplier<T> noOp(Supplier<T> supplier) {
        // fail at wrap time rather than at the first get(), far from the faulty call site
        Objects.requireNonNull(supplier, "supplier");
        return new CloseableSupplier<>() {
            @Override
            public T get() {
                return supplier.get();
            }

            @Override
            public void close() {
                // intentionally empty: nothing is owned
            }
        };
    }
}
Loading