diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
index 9e366676c..ad824f769 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
@@ -437,14 +437,15 @@ public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvi
}
public ImmutableGraphIndex build(RandomAccessVectorValues ravv) {
- var vv = ravv.threadLocalSupplier();
int size = ravv.size();
- simdExecutor.submit(() -> {
- IntStream.range(0, size).parallel().forEach(node -> {
- addGraphNode(node, vv.get().getVector(node));
- });
- }).join();
+ try (var vv = ravv.closeableThreadLocalSupplier()) {
+ simdExecutor.submit(() -> {
+ IntStream.range(0, size).parallel().forEach(node -> {
+ addGraphNode(node, vv.get().getVector(node));
+ });
+ }).join();
+ }
cleanup();
return graph;
@@ -841,11 +842,7 @@ private NodeArray getConcurrentCandidates(int level,
@Override
public void close() throws IOException {
- try {
- searchers.close();
- } catch (Exception e) {
- ExceptionUtils.throwIoException(e);
- }
+ ExceptionUtils.closeAll(searchers, scoreProvider, naturalScratch, concurrentScratch);
}
private static class ExcludingBits implements Bits {
@@ -1044,15 +1041,15 @@ public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in,
parallelExecutor
);
- var vv = newVectors.threadLocalSupplier();
-
// parallel graph construction from the merge documents Ids
- simdExecutor.submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> {
- builder.addGraphNode(ord, vv.get().getVector(ord));
- })).join();
+ try (var vv = newVectors.closeableThreadLocalSupplier()) {
+ simdExecutor.submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> {
+ builder.addGraphNode(ord, vv.get().getVector(ord));
+ })).join();
+ }
builder.cleanup();
return builder.getGraph();
}
}
-}
\ No newline at end of file
+}
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java
index eb8f6df24..4662713ba 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java
@@ -25,20 +25,18 @@
package io.github.jbellis.jvector.graph;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
+import io.github.jbellis.jvector.util.CloseableSupplier;
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import java.util.function.Supplier;
-import java.util.logging.Logger;
/**
* Provides random access to vectors by dense ordinal. This interface is used by graph-based
* implementations of KNN search.
*/
public interface RandomAccessVectorValues {
- Logger LOG = Logger.getLogger(RandomAccessVectorValues.class.getName());
-
/**
* Return the number of vector values.
*
@@ -94,18 +92,24 @@ default void getVectorInto(int node, VectorFloat> destinationVector, int offse
RandomAccessVectorValues copy();
/**
- * Returns a supplier of thread-local copies of the RAVV.
+ * Returns a closeable supplier of thread-local copies of the RAVV.
*/
- default Supplier threadLocalSupplier() {
+ default CloseableSupplier closeableThreadLocalSupplier() {
if (!isValueShared()) {
- return () -> this;
+ return CloseableSupplier.noOp(() -> this);
}
- if (this instanceof AutoCloseable) {
- LOG.warning("RAVV is shared and implements AutoCloseable; threadLocalSupplier() may lead to leaks");
- }
- var tl = ExplicitThreadLocal.withInitial(this::copy);
- return tl::get;
+ return ExplicitThreadLocal.withInitial(this::copy);
+ }
+
+ /**
+ * Returns a supplier of thread-local copies of the RAVV.
+ *
+ * The returned supplier implements {@link AutoCloseable}; callers that own the supplier's lifetime
+ * should prefer {@link #closeableThreadLocalSupplier()} so thread-local copies can be cleaned up.
+ */
+ default Supplier threadLocalSupplier() {
+ return closeableThreadLocalSupplier();
}
/**
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RemappedRandomAccessVectorValues.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RemappedRandomAccessVectorValues.java
index a5ffcfa31..c67fdda03 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RemappedRandomAccessVectorValues.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RemappedRandomAccessVectorValues.java
@@ -20,9 +20,10 @@
import java.util.Arrays;
-public class RemappedRandomAccessVectorValues implements RandomAccessVectorValues {
+public class RemappedRandomAccessVectorValues implements RandomAccessVectorValues, AutoCloseable {
private final RandomAccessVectorValues ravv;
private final int[] graphToRavvOrdMap;
+ private final boolean ownsRavv;
/**
* Remaps a RAVV to a different set of ordinals. This is useful when the ordinals used by the graph
@@ -33,8 +34,13 @@ public class RemappedRandomAccessVectorValues implements RandomAccessVectorValue
* graphToRavvOrdMap[i] is the RAVV ordinal corresponding to graph ordinal i.
*/
public RemappedRandomAccessVectorValues(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap) {
+ this(ravv, graphToRavvOrdMap, false);
+ }
+
+ private RemappedRandomAccessVectorValues(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, boolean ownsRavv) {
this.ravv = ravv;
this.graphToRavvOrdMap = graphToRavvOrdMap;
+ this.ownsRavv = ownsRavv;
}
@Override
@@ -59,11 +65,19 @@ public boolean isValueShared() {
@Override
public RandomAccessVectorValues copy() {
- return new RemappedRandomAccessVectorValues(ravv.copy(), Arrays.copyOf(graphToRavvOrdMap, graphToRavvOrdMap.length));
+ var ravvCopy = ravv.copy();
+ return new RemappedRandomAccessVectorValues(ravvCopy, Arrays.copyOf(graphToRavvOrdMap, graphToRavvOrdMap.length), ravvCopy != ravv);
}
@Override
public void getVectorInto(int node, VectorFloat> result, int offset) {
ravv.getVectorInto(graphToRavvOrdMap[node], result, offset);
}
+
+ @Override
+ public void close() throws Exception {
+ if (ownsRavv && ravv instanceof AutoCloseable) {
+ ((AutoCloseable) ravv).close();
+ }
+ }
}
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java
index 1049069de..ebb37c6f4 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java
@@ -29,9 +29,13 @@
/**
* Encapsulates comparing node distances for GraphIndexBuilder.
*/
-public interface BuildScoreProvider {
+public interface BuildScoreProvider extends AutoCloseable {
VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
+ @Override
+ default void close() {
+ }
+
/**
* @return true if the primary score functions used for construction are exact. This
* is modestly redundant, but it saves having to allocate new Search/Diversity provider
@@ -106,8 +110,8 @@ static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues rav
static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) {
// We need two sources of vectors in order to perform diversity check comparisons without
// colliding. ThreadLocalSupplier makes this a no-op if the RAVV is actually un-shared.
- var vectors = ravv.threadLocalSupplier();
- var vectorsCopy = ravv.threadLocalSupplier();
+ var vectors = ravv.closeableThreadLocalSupplier();
+ var vectorsCopy = ravv.closeableThreadLocalSupplier();
return new BuildScoreProvider() {
@Override
@@ -157,6 +161,15 @@ public ScoreFunction diversityScoreFunctionFor(int node1) {
// don't use ESF.reranker, we need thread safety here
return (ScoreFunction.ExactScoreFunction) node2 -> similarityFunction.compare(v, vc.getVector(node2));
}
+
+ @Override
+ public void close() {
+ try {
+ vectors.close();
+ } finally {
+ vectorsCopy.close();
+ }
+ }
};
}
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java
index f0d660301..5ddbaee67 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java
@@ -64,18 +64,20 @@ public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
@Override
public CompressedVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
- var ravvCopy = ravv.threadLocalSupplier();
- var cv = simdExecutor.submit(() -> IntStream.range(0, ravv.size())
- .parallel()
- .mapToObj(i -> {
- var localRavv = ravvCopy.get();
- VectorFloat> v = localRavv.getVector(i);
- return v == null
- ? new long[compressedVectorSize() / Long.BYTES]
- : encode(v);
- })
- .toArray(long[][]::new))
- .join();
+ final long[][] cv;
+ try (var ravvCopy = ravv.closeableThreadLocalSupplier()) {
+ cv = simdExecutor.submit(() -> IntStream.range(0, ravv.size())
+ .parallel()
+ .mapToObj(i -> {
+ var localRavv = ravvCopy.get();
+ VectorFloat> v = localRavv.getVector(i);
+ return v == null
+ ? new long[compressedVectorSize() / Long.BYTES]
+ : encode(v);
+ })
+ .toArray(long[][]::new))
+ .join();
+ }
return new ImmutableBQVectors(this, cv);
}
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java
index aef0325b9..f71877f43 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java
@@ -148,14 +148,16 @@ private NVQuantization(int[][] subvectorSizesAndOffsets, VectorFloat> globalMe
* @param nSubVectors number of subvectors
*/
public static NVQuantization compute(RandomAccessVectorValues ravv, int nSubVectors) {
- var ravvCopy = ravv.threadLocalSupplier().get();
- var dim = ravvCopy.getVector(0).length();
- var globalMean = vectorTypeSupport.createFloatVector(dim);
- for (int i = 0; i < ravvCopy.size(); i++) {
- VectorUtil.addInPlace(globalMean, ravvCopy.getVector(i));
- }
- VectorUtil.scale(globalMean, 1.0f / ravvCopy.size());
- return create(globalMean, nSubVectors);
+ try (var ravvSupplier = ravv.closeableThreadLocalSupplier()) {
+ var ravvCopy = ravvSupplier.get();
+ var dim = ravvCopy.getVector(0).length();
+ var globalMean = vectorTypeSupport.createFloatVector(dim);
+ for (int i = 0; i < ravvCopy.size(); i++) {
+ VectorUtil.addInPlace(globalMean, ravvCopy.getVector(i));
+ }
+ VectorUtil.scale(globalMean, 1.0f / ravvCopy.size());
+ return create(globalMean, nSubVectors);
+ }
}
/**
@@ -180,17 +182,19 @@ public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
*/
@Override
public NVQVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) {
- var ravvCopy = ravv.threadLocalSupplier();
- return new NVQVectors(this,
- parallelExecutor.submit(() -> IntStream.range(0, ravv.size())
- .parallel()
- .mapToObj(i -> {
- var localRavv = ravvCopy.get();
- VectorFloat> v = localRavv.getVector(i);
- return encode(v);
- })
- .toArray(QuantizedVector[]::new))
- .join());
+ final QuantizedVector[] vectors;
+ try (var ravvCopy = ravv.closeableThreadLocalSupplier()) {
+ vectors = parallelExecutor.submit(() -> IntStream.range(0, ravv.size())
+ .parallel()
+ .mapToObj(i -> {
+ var localRavv = ravvCopy.get();
+ VectorFloat> v = localRavv.getVector(i);
+ return encode(v);
+ })
+ .toArray(QuantizedVector[]::new))
+ .join();
+ }
+ return new NVQVectors(this, vectors);
}
/**
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java
index 538632da0..0c8fa4d88 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java
@@ -131,20 +131,21 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect
// Encode the vectors in parallel into the compressed data chunks
// The changes are concurrent, but because they are coordinated and do not overlap, we can use parallel streams
// and then we are guaranteed safe publication because we join the thread after completion.
- var ravvCopy = ravv.threadLocalSupplier();
- simdExecutor.submit(() -> IntStream.range(0, vectorCount)
- .parallel()
- .forEach(ordinal -> {
- // Retrieve the slice and mutate it.
- var localRavv = ravvCopy.get();
- var slice = PQVectors.get(chunks, ordinal, layout.fullChunkVectors, pq.getSubspaceCount());
- var vector = localRavv.getVector(ordinalsMapping.applyAsInt(ordinal));
- if (vector != null)
- pq.encodeTo(vector, slice);
- else
- slice.zero();
- }))
- .join();
+ try (var ravvCopy = ravv.closeableThreadLocalSupplier()) {
+ simdExecutor.submit(() -> IntStream.range(0, vectorCount)
+ .parallel()
+ .forEach(ordinal -> {
+ // Retrieve the slice and mutate it.
+ var localRavv = ravvCopy.get();
+ var slice = PQVectors.get(chunks, ordinal, layout.fullChunkVectors, pq.getSubspaceCount());
+ var vector = localRavv.getVector(ordinalsMapping.applyAsInt(ordinal));
+ if (vector != null)
+ pq.encodeTo(vector, slice);
+ else
+ slice.zero();
+ }))
+ .join();
+ }
return new ImmutablePQVectors(pq, chunks, vectorCount, layout.fullChunkVectors);
}
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java
index c84b7b955..5504f918c 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java
@@ -164,15 +164,16 @@ static List> extractTrainingVectors(RandomAccessVectorValues ravv
ordinalStream = IntStream.of(ordinalArray);
}
- var ravvCopy = ravv.threadLocalSupplier();
- return parallelExecutor.submit(() -> ordinalStream.parallel()
- .mapToObj(targetOrd -> {
- var localRavv = ravvCopy.get();
- VectorFloat> v = localRavv.getVector(targetOrd);
- return localRavv.isValueShared() ? v.copy() : v;
- })
- .collect(Collectors.toList()))
- .join();
+ try (var ravvCopy = ravv.closeableThreadLocalSupplier()) {
+ return parallelExecutor.submit(() -> ordinalStream.parallel()
+ .mapToObj(targetOrd -> {
+ var localRavv = ravvCopy.get();
+ VectorFloat> v = localRavv.getVector(targetOrd);
+ return localRavv.isValueShared() ? v.copy() : v;
+ })
+ .collect(Collectors.toList()))
+ .join();
+ }
}
/**
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/CloseableSupplier.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/CloseableSupplier.java
new file mode 100644
index 000000000..7097082c9
--- /dev/null
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/CloseableSupplier.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.github.jbellis.jvector.util;
+
+import java.util.function.Supplier;
+
+/**
+ * A {@link Supplier} that owns resources associated with values it creates.
+ */
+public interface CloseableSupplier extends Supplier, AutoCloseable {
+ @Override
+ void close();
+
+ static CloseableSupplier noOp(Supplier supplier) {
+ return new CloseableSupplier<>() {
+ @Override
+ public T get() {
+ return supplier.get();
+ }
+
+ @Override
+ public void close() {
+ }
+ };
+ }
+}
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExceptionUtils.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExceptionUtils.java
index d5dc9c350..13ead8bff 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExceptionUtils.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExceptionUtils.java
@@ -30,4 +30,27 @@ public static void throwIoException(Throwable t) throws IOException {
throw new RuntimeException(t);
}
}
+
+ public static void closeAll(AutoCloseable... closeables) throws IOException {
+ Throwable thrown = null;
+ for (AutoCloseable closeable : closeables) {
+ if (closeable == null) {
+ continue;
+ }
+
+ try {
+ closeable.close();
+ } catch (Throwable t) {
+ if (thrown == null) {
+ thrown = t;
+ } else {
+ thrown.addSuppressed(t);
+ }
+ }
+ }
+
+ if (thrown != null) {
+ throwIoException(thrown);
+ }
+ }
}
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExplicitThreadLocal.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExplicitThreadLocal.java
index 1833818bd..a50d5edd9 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExplicitThreadLocal.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExplicitThreadLocal.java
@@ -37,7 +37,7 @@
*
* ExplicitThreadLocal is a drop-in replacement for ThreadLocal, and is used in the same way.
*/
-public abstract class ExplicitThreadLocal implements AutoCloseable {
+public abstract class ExplicitThreadLocal implements CloseableSupplier {
// thread id -> instance
private final ConcurrentHashMap map = new ConcurrentHashMap<>();
@@ -58,13 +58,33 @@ public U get() {
* Not threadsafe.
*/
@Override
- public void close() throws Exception {
- for (U value : map.values()) {
- if (value instanceof AutoCloseable) {
- ((AutoCloseable) value).close();
+ public void close() {
+ Throwable thrown = null;
+ try {
+ for (U value : map.values()) {
+ if (value instanceof AutoCloseable) {
+ try {
+ ((AutoCloseable) value).close();
+ } catch (Throwable t) {
+ if (thrown == null) {
+ thrown = t;
+ } else {
+ thrown.addSuppressed(t);
+ }
+ }
+ }
}
+ } finally {
+ map.clear();
+ }
+
+ if (thrown instanceof RuntimeException) {
+ throw (RuntimeException) thrown;
+ } else if (thrown instanceof Error) {
+ throw (Error) thrown;
+ } else if (thrown != null) {
+ throw new RuntimeException(thrown);
}
- map.clear();
}
public static ExplicitThreadLocal withInitial(Supplier initialValue) {
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java
index ce9d62c1b..2976ed523 100644
--- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java
@@ -414,23 +414,24 @@ private static Map, ImmutableGraphIndex> buildOnDisk(List exten
// build the graph incrementally
long startTime = System.nanoTime();
- var vv = floatVectors.threadLocalSupplier();
- PhysicalCoreExecutor.pool().submit(() -> {
- IntStream.range(0, floatVectors.size()).parallel().forEach(node -> {
- writers.forEach((features, writer) -> {
- try {
- var stateMap = new EnumMap(FeatureId.class);
- suppliers.get(features).forEach((featureId, supplier) -> {
- stateMap.put(featureId, supplier.apply(node));
- });
- writer.writeInline(node, stateMap);
- } catch (IOException e) {
- throw new UncheckedIOException(e);
- }
+ try (var vv = floatVectors.closeableThreadLocalSupplier()) {
+ PhysicalCoreExecutor.pool().submit(() -> {
+ IntStream.range(0, floatVectors.size()).parallel().forEach(node -> {
+ writers.forEach((features, writer) -> {
+ try {
+ var stateMap = new EnumMap(FeatureId.class);
+ suppliers.get(features).forEach((featureId, supplier) -> {
+ stateMap.put(featureId, supplier.apply(node));
+ });
+ writer.writeInline(node, stateMap);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ });
+ builder.addGraphNode(node, vv.get().getVector(node));
});
- builder.addGraphNode(node, vv.get().getVector(node));
- });
- }).join();
+ }).join();
+ }
builder.cleanup();
// write the edge lists and close the writers
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentReadWriteDeletes.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentReadWriteDeletes.java
index 12c263a6a..257949e80 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentReadWriteDeletes.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentReadWriteDeletes.java
@@ -77,38 +77,38 @@ public class TestConcurrentReadWriteDeletes extends RandomizedTest {
@Test
public void testConcurrentReadsWritesDeletes() throws ExecutionException, InterruptedException {
- var vv = ravv.threadLocalSupplier();
-
- testConcurrentOps(i -> {
- var R = getRandom();
- if (R.nextDouble() < 0.2 || keysInserted.isEmpty())
- {
- // In the future, we could improve this test by acquiring the lock earlier and executing other
- writeLock.lock();
- try {
- builder.addGraphNode(i, vv.get().getVector(i));
- liveNodes.set(i);
- keysInserted.add(i);
- } finally {
- writeLock.unlock();
- }
- } else if (R.nextDouble() < 0.1) {
- var key = keysInserted.getRandom();
- if (!keysRemoved.contains(key)) {
- liveNodes.flip(key);
- keysRemoved.add(key);
- }
- } else {
- var queryVector = randomVector(getRandom(), dimension);
- SearchScoreProvider ssp = DefaultSearchScoreProvider.exact(queryVector, similarityFunction, ravv);
+ try (var vv = ravv.closeableThreadLocalSupplier()) {
+ testConcurrentOps(i -> {
+ var R = getRandom();
+ if (R.nextDouble() < 0.2 || keysInserted.isEmpty())
+ {
+ // In the future, we could improve this test by acquiring the lock earlier and executing other
+ writeLock.lock();
+ try {
+ builder.addGraphNode(i, vv.get().getVector(i));
+ liveNodes.set(i);
+ keysInserted.add(i);
+ } finally {
+ writeLock.unlock();
+ }
+ } else if (R.nextDouble() < 0.1) {
+ var key = keysInserted.getRandom();
+ if (!keysRemoved.contains(key)) {
+ liveNodes.flip(key);
+ keysRemoved.add(key);
+ }
+ } else {
+ var queryVector = randomVector(getRandom(), dimension);
+ SearchScoreProvider ssp = DefaultSearchScoreProvider.exact(queryVector, similarityFunction, ravv);
- int topK = Math.min(1, keysInserted.size());
- int rerankK = Math.min(50, keysInserted.size());
+ int topK = Math.min(1, keysInserted.size());
+ int rerankK = Math.min(50, keysInserted.size());
- GraphSearcher searcher = new GraphSearcher(builder.getGraph());
- searcher.search(ssp, topK, rerankK, 0.f, 0.f, liveNodes);
- }
- });
+ GraphSearcher searcher = new GraphSearcher(builder.getGraph());
+ searcher.search(ssp, topK, rerankK, 0.f, 0.f, liveNodes);
+ }
+ });
+ }
}
@FunctionalInterface
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestRandomAccessVectorValues.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestRandomAccessVectorValues.java
new file mode 100644
index 000000000..b40a87111
--- /dev/null
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestRandomAccessVectorValues.java
@@ -0,0 +1,141 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.github.jbellis.jvector.graph;
+
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+import io.github.jbellis.jvector.vector.VectorizationProvider;
+import io.github.jbellis.jvector.vector.types.VectorFloat;
+import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotSame;
+import static org.junit.Assert.assertTrue;
+
+public class TestRandomAccessVectorValues {
+ private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
+
+ @Test
+ public void threadLocalSupplierClosesSharedCopies() throws Exception {
+ var ravv = new CloseTrackingRandomAccessVectorValues();
+ var supplier = ravv.threadLocalSupplier();
+
+ assertTrue("thread-local suppliers should expose a close hook", supplier instanceof AutoCloseable);
+ assertNotSame(ravv, supplier.get());
+
+ ((AutoCloseable) supplier).close();
+
+ assertEquals(1, ravv.copyCount());
+ assertEquals(1, ravv.copyCloseCount());
+ assertEquals(0, ravv.originalCloseCount());
+ }
+
+ @Test
+ public void graphIndexBuilderClosesThreadLocalVectorCopies() throws IOException {
+ var ravv = new CloseTrackingRandomAccessVectorValues();
+
+ try (var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.EUCLIDEAN, 2, 10, 1.0f, 1.0f, false)) {
+ builder.build(ravv);
+ }
+
+ assertTrue("graph build should create thread-local RAVV copies", ravv.copyCount() > 0);
+ assertEquals("all thread-local copies should be closed", ravv.copyCount(), ravv.copyCloseCount());
+ assertEquals(0, ravv.originalCloseCount());
+ }
+
+ private static class CloseTrackingRandomAccessVectorValues implements RandomAccessVectorValues, AutoCloseable {
+ private final List> vectors;
+ private final AtomicInteger copyCount;
+ private final AtomicInteger copyCloseCount;
+ private final AtomicInteger originalCloseCount;
+ private final boolean original;
+
+ CloseTrackingRandomAccessVectorValues() {
+ this(List.of(
+ vts.createFloatVector(new float[] {0, 0}),
+ vts.createFloatVector(new float[] {1, 0}),
+ vts.createFloatVector(new float[] {2, 0}),
+ vts.createFloatVector(new float[] {3, 0})),
+ new AtomicInteger(),
+ new AtomicInteger(),
+ new AtomicInteger(),
+ true);
+ }
+
+ private CloseTrackingRandomAccessVectorValues(List> vectors,
+ AtomicInteger copyCount,
+ AtomicInteger copyCloseCount,
+ AtomicInteger originalCloseCount,
+ boolean original) {
+ this.vectors = vectors;
+ this.copyCount = copyCount;
+ this.copyCloseCount = copyCloseCount;
+ this.originalCloseCount = originalCloseCount;
+ this.original = original;
+ }
+
+ @Override
+ public int size() {
+ return vectors.size();
+ }
+
+ @Override
+ public int dimension() {
+ return 2;
+ }
+
+ @Override
+ public VectorFloat> getVector(int nodeId) {
+ return vectors.get(nodeId);
+ }
+
+ @Override
+ public boolean isValueShared() {
+ return true;
+ }
+
+ @Override
+ public RandomAccessVectorValues copy() {
+ copyCount.incrementAndGet();
+ return new CloseTrackingRandomAccessVectorValues(vectors, copyCount, copyCloseCount, originalCloseCount, false);
+ }
+
+ @Override
+ public void close() {
+ if (original) {
+ originalCloseCount.incrementAndGet();
+ } else {
+ copyCloseCount.incrementAndGet();
+ }
+ }
+
+ int copyCount() {
+ return copyCount.get();
+ }
+
+ int copyCloseCount() {
+ return copyCloseCount.get();
+ }
+
+ int originalCloseCount() {
+ return originalCloseCount.get();
+ }
+ }
+}