diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index 9e366676c..ad824f769 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -437,14 +437,15 @@ public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvi } public ImmutableGraphIndex build(RandomAccessVectorValues ravv) { - var vv = ravv.threadLocalSupplier(); int size = ravv.size(); - simdExecutor.submit(() -> { - IntStream.range(0, size).parallel().forEach(node -> { - addGraphNode(node, vv.get().getVector(node)); - }); - }).join(); + try (var vv = ravv.closeableThreadLocalSupplier()) { + simdExecutor.submit(() -> { + IntStream.range(0, size).parallel().forEach(node -> { + addGraphNode(node, vv.get().getVector(node)); + }); + }).join(); + } cleanup(); return graph; @@ -841,11 +842,7 @@ private NodeArray getConcurrentCandidates(int level, @Override public void close() throws IOException { - try { - searchers.close(); - } catch (Exception e) { - ExceptionUtils.throwIoException(e); - } + ExceptionUtils.closeAll(searchers, scoreProvider, naturalScratch, concurrentScratch); } private static class ExcludingBits implements Bits { @@ -1044,15 +1041,15 @@ public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in, parallelExecutor ); - var vv = newVectors.threadLocalSupplier(); - // parallel graph construction from the merge documents Ids - simdExecutor.submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> { - builder.addGraphNode(ord, vv.get().getVector(ord)); - })).join(); + try (var vv = newVectors.closeableThreadLocalSupplier()) { + simdExecutor.submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> { + builder.addGraphNode(ord, vv.get().getVector(ord)); + })).join(); + } builder.cleanup(); return builder.getGraph(); } } -} \ No newline at end of file +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java index eb8f6df24..4662713ba 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RandomAccessVectorValues.java @@ -25,20 +25,18 @@ package io.github.jbellis.jvector.graph; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; +import io.github.jbellis.jvector.util.CloseableSupplier; import io.github.jbellis.jvector.util.ExplicitThreadLocal; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; import java.util.function.Supplier; -import java.util.logging.Logger; /** * Provides random access to vectors by dense ordinal. This interface is used by graph-based * implementations of KNN search. */ public interface RandomAccessVectorValues { - Logger LOG = Logger.getLogger(RandomAccessVectorValues.class.getName()); - /** * Return the number of vector values. *

@@ -94,18 +92,24 @@ default void getVectorInto(int node, VectorFloat destinationVector, int offse RandomAccessVectorValues copy(); /** - * Returns a supplier of thread-local copies of the RAVV. + * Returns a closeable supplier of thread-local copies of the RAVV. */ - default Supplier threadLocalSupplier() { + default CloseableSupplier closeableThreadLocalSupplier() { if (!isValueShared()) { - return () -> this; + return CloseableSupplier.noOp(() -> this); } - if (this instanceof AutoCloseable) { - LOG.warning("RAVV is shared and implements AutoCloseable; threadLocalSupplier() may lead to leaks"); - } - var tl = ExplicitThreadLocal.withInitial(this::copy); - return tl::get; + return ExplicitThreadLocal.withInitial(this::copy); + } + + /** + * Returns a supplier of thread-local copies of the RAVV. + *

+ * The returned supplier implements {@link AutoCloseable}; callers that own the supplier's lifetime + * should prefer {@link #closeableThreadLocalSupplier()} so thread-local copies can be cleaned up. + */ + default Supplier threadLocalSupplier() { + return closeableThreadLocalSupplier(); } /** diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RemappedRandomAccessVectorValues.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RemappedRandomAccessVectorValues.java index a5ffcfa31..c67fdda03 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RemappedRandomAccessVectorValues.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/RemappedRandomAccessVectorValues.java @@ -20,9 +20,10 @@ import java.util.Arrays; -public class RemappedRandomAccessVectorValues implements RandomAccessVectorValues { +public class RemappedRandomAccessVectorValues implements RandomAccessVectorValues, AutoCloseable { private final RandomAccessVectorValues ravv; private final int[] graphToRavvOrdMap; + private final boolean ownsRavv; /** * Remaps a RAVV to a different set of ordinals. This is useful when the ordinals used by the graph @@ -33,8 +34,13 @@ public class RemappedRandomAccessVectorValues implements RandomAccessVectorValue * graphToRavvOrdMap[i] is the RAVV ordinal corresponding to graph ordinal i. */ public RemappedRandomAccessVectorValues(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap) { + this(ravv, graphToRavvOrdMap, false); + } + + private RemappedRandomAccessVectorValues(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, boolean ownsRavv) { this.ravv = ravv; this.graphToRavvOrdMap = graphToRavvOrdMap; + this.ownsRavv = ownsRavv; } @Override @@ -59,11 +65,19 @@ public boolean isValueShared() { @Override public RandomAccessVectorValues copy() { - return new RemappedRandomAccessVectorValues(ravv.copy(), Arrays.copyOf(graphToRavvOrdMap, graphToRavvOrdMap.length)); + var ravvCopy = ravv.copy(); + return new RemappedRandomAccessVectorValues(ravvCopy, Arrays.copyOf(graphToRavvOrdMap, graphToRavvOrdMap.length), ravvCopy != ravv); } @Override public void getVectorInto(int node, VectorFloat result, int offset) { ravv.getVectorInto(graphToRavvOrdMap[node], result, offset); } + + @Override + public void close() throws Exception { + if (ownsRavv && ravv instanceof AutoCloseable) { + ((AutoCloseable) ravv).close(); + } + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java index 1049069de..ebb37c6f4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java @@ -29,9 +29,13 @@ /** * Encapsulates comparing node distances for GraphIndexBuilder. */ -public interface BuildScoreProvider { +public interface BuildScoreProvider extends AutoCloseable { VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); + @Override + default void close() { + } + /** * @return true if the primary score functions used for construction are exact. This * is modestly redundant, but it saves having to allocate new Search/Diversity provider @@ -106,8 +110,8 @@ static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues rav static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) { // We need two sources of vectors in order to perform diversity check comparisons without // colliding. ThreadLocalSupplier makes this a no-op if the RAVV is actually un-shared. - var vectors = ravv.threadLocalSupplier(); - var vectorsCopy = ravv.threadLocalSupplier(); + var vectors = ravv.closeableThreadLocalSupplier(); + var vectorsCopy = ravv.closeableThreadLocalSupplier(); return new BuildScoreProvider() { @Override @@ -157,6 +161,15 @@ public ScoreFunction diversityScoreFunctionFor(int node1) { // don't use ESF.reranker, we need thread safety here return (ScoreFunction.ExactScoreFunction) node2 -> similarityFunction.compare(v, vc.getVector(node2)); } + + @Override + public void close() { + try { + vectors.close(); + } finally { + vectorsCopy.close(); + } + } }; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java index f0d660301..5ddbaee67 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java @@ -64,18 +64,20 @@ public CompressedVectors createCompressedVectors(Object[] compressedVectors) { @Override public CompressedVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) { - var ravvCopy = ravv.threadLocalSupplier(); - var cv = simdExecutor.submit(() -> IntStream.range(0, ravv.size()) - .parallel() - .mapToObj(i -> { - var localRavv = ravvCopy.get(); - VectorFloat v = localRavv.getVector(i); - return v == null - ? new long[compressedVectorSize() / Long.BYTES] - : encode(v); - }) - .toArray(long[][]::new)) - .join(); + final long[][] cv; + try (var ravvCopy = ravv.closeableThreadLocalSupplier()) { + cv = simdExecutor.submit(() -> IntStream.range(0, ravv.size()) + .parallel() + .mapToObj(i -> { + var localRavv = ravvCopy.get(); + VectorFloat v = localRavv.getVector(i); + return v == null + ? new long[compressedVectorSize() / Long.BYTES] + : encode(v); + }) + .toArray(long[][]::new)) + .join(); + } return new ImmutableBQVectors(this, cv); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java index aef0325b9..f71877f43 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java @@ -148,14 +148,16 @@ private NVQuantization(int[][] subvectorSizesAndOffsets, VectorFloat globalMe * @param nSubVectors number of subvectors */ public static NVQuantization compute(RandomAccessVectorValues ravv, int nSubVectors) { - var ravvCopy = ravv.threadLocalSupplier().get(); - var dim = ravvCopy.getVector(0).length(); - var globalMean = vectorTypeSupport.createFloatVector(dim); - for (int i = 0; i < ravvCopy.size(); i++) { - VectorUtil.addInPlace(globalMean, ravvCopy.getVector(i)); - } - VectorUtil.scale(globalMean, 1.0f / ravvCopy.size()); - return create(globalMean, nSubVectors); + try (var ravvSupplier = ravv.closeableThreadLocalSupplier()) { + var ravvCopy = ravvSupplier.get(); + var dim = ravvCopy.getVector(0).length(); + var globalMean = vectorTypeSupport.createFloatVector(dim); + for (int i = 0; i < ravvCopy.size(); i++) { + VectorUtil.addInPlace(globalMean, ravvCopy.getVector(i)); + } + VectorUtil.scale(globalMean, 1.0f / ravvCopy.size()); + return create(globalMean, nSubVectors); + } } /** @@ -180,17 +182,19 @@ public CompressedVectors createCompressedVectors(Object[] compressedVectors) { */ @Override public NVQVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) { - var ravvCopy = ravv.threadLocalSupplier(); - return new NVQVectors(this, - parallelExecutor.submit(() -> IntStream.range(0, ravv.size()) - .parallel() - .mapToObj(i -> { - var localRavv = ravvCopy.get(); - VectorFloat v = localRavv.getVector(i); - return encode(v); - }) - .toArray(QuantizedVector[]::new)) - .join()); + final QuantizedVector[] vectors; + try (var ravvCopy = ravv.closeableThreadLocalSupplier()) { + vectors = parallelExecutor.submit(() -> IntStream.range(0, ravv.size()) + .parallel() + .mapToObj(i -> { + var localRavv = ravvCopy.get(); + VectorFloat v = localRavv.getVector(i); + return encode(v); + }) + .toArray(QuantizedVector[]::new)) + .join(); + } + return new NVQVectors(this, vectors); } /** diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java index 538632da0..0c8fa4d88 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java @@ -131,20 +131,21 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect // Encode the vectors in parallel into the compressed data chunks // The changes are concurrent, but because they are coordinated and do not overlap, we can use parallel streams // and then we are guaranteed safe publication because we join the thread after completion. - var ravvCopy = ravv.threadLocalSupplier(); - simdExecutor.submit(() -> IntStream.range(0, vectorCount) - .parallel() - .forEach(ordinal -> { - // Retrieve the slice and mutate it. - var localRavv = ravvCopy.get(); - var slice = PQVectors.get(chunks, ordinal, layout.fullChunkVectors, pq.getSubspaceCount()); - var vector = localRavv.getVector(ordinalsMapping.applyAsInt(ordinal)); - if (vector != null) - pq.encodeTo(vector, slice); - else - slice.zero(); - })) - .join(); + try (var ravvCopy = ravv.closeableThreadLocalSupplier()) { + simdExecutor.submit(() -> IntStream.range(0, vectorCount) + .parallel() + .forEach(ordinal -> { + // Retrieve the slice and mutate it. + var localRavv = ravvCopy.get(); + var slice = PQVectors.get(chunks, ordinal, layout.fullChunkVectors, pq.getSubspaceCount()); + var vector = localRavv.getVector(ordinalsMapping.applyAsInt(ordinal)); + if (vector != null) + pq.encodeTo(vector, slice); + else + slice.zero(); + })) + .join(); + } return new ImmutablePQVectors(pq, chunks, vectorCount, layout.fullChunkVectors); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java index c84b7b955..5504f918c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java @@ -164,15 +164,16 @@ static List> extractTrainingVectors(RandomAccessVectorValues ravv ordinalStream = IntStream.of(ordinalArray); } - var ravvCopy = ravv.threadLocalSupplier(); - return parallelExecutor.submit(() -> ordinalStream.parallel() - .mapToObj(targetOrd -> { - var localRavv = ravvCopy.get(); - VectorFloat v = localRavv.getVector(targetOrd); - return localRavv.isValueShared() ? v.copy() : v; - }) - .collect(Collectors.toList())) - .join(); + try (var ravvCopy = ravv.closeableThreadLocalSupplier()) { + return parallelExecutor.submit(() -> ordinalStream.parallel() + .mapToObj(targetOrd -> { + var localRavv = ravvCopy.get(); + VectorFloat v = localRavv.getVector(targetOrd); + return localRavv.isValueShared() ? v.copy() : v; + }) + .collect(Collectors.toList())) + .join(); + } } /** diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/CloseableSupplier.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/CloseableSupplier.java new file mode 100644 index 000000000..7097082c9 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/CloseableSupplier.java @@ -0,0 +1,39 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.github.jbellis.jvector.util; + +import java.util.function.Supplier; + +/** + * A {@link Supplier} that owns resources associated with values it creates. + */ +public interface CloseableSupplier extends Supplier, AutoCloseable { + @Override + void close(); + + static CloseableSupplier noOp(Supplier supplier) { + return new CloseableSupplier<>() { + @Override + public T get() { + return supplier.get(); + } + + @Override + public void close() { + } + }; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExceptionUtils.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExceptionUtils.java index d5dc9c350..13ead8bff 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExceptionUtils.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExceptionUtils.java @@ -30,4 +30,27 @@ public static void throwIoException(Throwable t) throws IOException { throw new RuntimeException(t); } } + + public static void closeAll(AutoCloseable... closeables) throws IOException { + Throwable thrown = null; + for (AutoCloseable closeable : closeables) { + if (closeable == null) { + continue; + } + + try { + closeable.close(); + } catch (Throwable t) { + if (thrown == null) { + thrown = t; + } else { + thrown.addSuppressed(t); + } + } + } + + if (thrown != null) { + throwIoException(thrown); + } + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExplicitThreadLocal.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExplicitThreadLocal.java index 1833818bd..a50d5edd9 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExplicitThreadLocal.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/ExplicitThreadLocal.java @@ -37,7 +37,7 @@ *

* ExplicitThreadLocal is a drop-in replacement for ThreadLocal, and is used in the same way. */ -public abstract class ExplicitThreadLocal implements AutoCloseable { +public abstract class ExplicitThreadLocal implements CloseableSupplier { // thread id -> instance private final ConcurrentHashMap map = new ConcurrentHashMap<>(); @@ -58,13 +58,33 @@ public U get() { * Not threadsafe. */ @Override - public void close() throws Exception { - for (U value : map.values()) { - if (value instanceof AutoCloseable) { - ((AutoCloseable) value).close(); + public void close() { + Throwable thrown = null; + try { + for (U value : map.values()) { + if (value instanceof AutoCloseable) { + try { + ((AutoCloseable) value).close(); + } catch (Throwable t) { + if (thrown == null) { + thrown = t; + } else { + thrown.addSuppressed(t); + } + } + } } + } finally { + map.clear(); + } + + if (thrown instanceof RuntimeException) { + throw (RuntimeException) thrown; + } else if (thrown instanceof Error) { + throw (Error) thrown; + } else if (thrown != null) { + throw new RuntimeException(thrown); } - map.clear(); } public static ExplicitThreadLocal withInitial(Supplier initialValue) { diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java index ce9d62c1b..2976ed523 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java @@ -414,23 +414,24 @@ private static Map, ImmutableGraphIndex> buildOnDisk(List { - IntStream.range(0, floatVectors.size()).parallel().forEach(node -> { - writers.forEach((features, writer) -> { - try { - var stateMap = new EnumMap(FeatureId.class); - suppliers.get(features).forEach((featureId, supplier) -> { - stateMap.put(featureId, supplier.apply(node)); - }); - writer.writeInline(node, stateMap); - } catch (IOException e) { - throw new UncheckedIOException(e); - } + try (var vv = floatVectors.closeableThreadLocalSupplier()) { + PhysicalCoreExecutor.pool().submit(() -> { + IntStream.range(0, floatVectors.size()).parallel().forEach(node -> { + writers.forEach((features, writer) -> { + try { + var stateMap = new EnumMap(FeatureId.class); + suppliers.get(features).forEach((featureId, supplier) -> { + stateMap.put(featureId, supplier.apply(node)); + }); + writer.writeInline(node, stateMap); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }); + builder.addGraphNode(node, vv.get().getVector(node)); }); - builder.addGraphNode(node, vv.get().getVector(node)); - }); - }).join(); + }).join(); + } builder.cleanup(); // write the edge lists and close the writers diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentReadWriteDeletes.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentReadWriteDeletes.java index 12c263a6a..257949e80 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentReadWriteDeletes.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestConcurrentReadWriteDeletes.java @@ -77,38 +77,38 @@ public class TestConcurrentReadWriteDeletes extends RandomizedTest { @Test public void testConcurrentReadsWritesDeletes() throws ExecutionException, InterruptedException { - var vv = ravv.threadLocalSupplier(); - - testConcurrentOps(i -> { - var R = getRandom(); - if (R.nextDouble() < 0.2 || keysInserted.isEmpty()) - { - // In the future, we could improve this test by acquiring the lock earlier and executing other - writeLock.lock(); - try { - builder.addGraphNode(i, vv.get().getVector(i)); - liveNodes.set(i); - keysInserted.add(i); - } finally { - writeLock.unlock(); - } - } else if (R.nextDouble() < 0.1) { - var key = keysInserted.getRandom(); - if (!keysRemoved.contains(key)) { - liveNodes.flip(key); - keysRemoved.add(key); - } - } else { - var queryVector = randomVector(getRandom(), dimension); - SearchScoreProvider ssp = DefaultSearchScoreProvider.exact(queryVector, similarityFunction, ravv); + try (var vv = ravv.closeableThreadLocalSupplier()) { + testConcurrentOps(i -> { + var R = getRandom(); + if (R.nextDouble() < 0.2 || keysInserted.isEmpty()) + { + // In the future, we could improve this test by acquiring the lock earlier and executing other + writeLock.lock(); + try { + builder.addGraphNode(i, vv.get().getVector(i)); + liveNodes.set(i); + keysInserted.add(i); + } finally { + writeLock.unlock(); + } + } else if (R.nextDouble() < 0.1) { + var key = keysInserted.getRandom(); + if (!keysRemoved.contains(key)) { + liveNodes.flip(key); + keysRemoved.add(key); + } + } else { + var queryVector = randomVector(getRandom(), dimension); + SearchScoreProvider ssp = DefaultSearchScoreProvider.exact(queryVector, similarityFunction, ravv); - int topK = Math.min(1, keysInserted.size()); - int rerankK = Math.min(50, keysInserted.size()); + int topK = Math.min(1, keysInserted.size()); + int rerankK = Math.min(50, keysInserted.size()); - GraphSearcher searcher = new GraphSearcher(builder.getGraph()); - searcher.search(ssp, topK, rerankK, 0.f, 0.f, liveNodes); - } - }); + GraphSearcher searcher = new GraphSearcher(builder.getGraph()); + searcher.search(ssp, topK, rerankK, 0.f, 0.f, liveNodes); + } + }); + } } @FunctionalInterface diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestRandomAccessVectorValues.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestRandomAccessVectorValues.java new file mode 100644 index 000000000..b40a87111 --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestRandomAccessVectorValues.java @@ -0,0 +1,141 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.github.jbellis.jvector.graph; + +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.junit.Test; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertTrue; + +public class TestRandomAccessVectorValues { + private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); + + @Test + public void threadLocalSupplierClosesSharedCopies() throws Exception { + var ravv = new CloseTrackingRandomAccessVectorValues(); + var supplier = ravv.threadLocalSupplier(); + + assertTrue("thread-local suppliers should expose a close hook", supplier instanceof AutoCloseable); + assertNotSame(ravv, supplier.get()); + + ((AutoCloseable) supplier).close(); + + assertEquals(1, ravv.copyCount()); + assertEquals(1, ravv.copyCloseCount()); + assertEquals(0, ravv.originalCloseCount()); + } + + @Test + public void graphIndexBuilderClosesThreadLocalVectorCopies() throws IOException { + var ravv = new CloseTrackingRandomAccessVectorValues(); + + try (var builder = new GraphIndexBuilder(ravv, VectorSimilarityFunction.EUCLIDEAN, 2, 10, 1.0f, 1.0f, false)) { + builder.build(ravv); + } + + assertTrue("graph build should create thread-local RAVV copies", ravv.copyCount() > 0); + assertEquals("all thread-local copies should be closed", ravv.copyCount(), ravv.copyCloseCount()); + assertEquals(0, ravv.originalCloseCount()); + } + + private static class CloseTrackingRandomAccessVectorValues implements RandomAccessVectorValues, AutoCloseable { + private final List> vectors; + private final AtomicInteger copyCount; + private final AtomicInteger copyCloseCount; + private final AtomicInteger originalCloseCount; + private final boolean original; + + CloseTrackingRandomAccessVectorValues() { + this(List.of( + vts.createFloatVector(new float[] {0, 0}), + vts.createFloatVector(new float[] {1, 0}), + vts.createFloatVector(new float[] {2, 0}), + vts.createFloatVector(new float[] {3, 0})), + new AtomicInteger(), + new AtomicInteger(), + new AtomicInteger(), + true); + } + + private CloseTrackingRandomAccessVectorValues(List> vectors, + AtomicInteger copyCount, + AtomicInteger copyCloseCount, + AtomicInteger originalCloseCount, + boolean original) { + this.vectors = vectors; + this.copyCount = copyCount; + this.copyCloseCount = copyCloseCount; + this.originalCloseCount = originalCloseCount; + this.original = original; + } + + @Override + public int size() { + return vectors.size(); + } + + @Override + public int dimension() { + return 2; + } + + @Override + public VectorFloat getVector(int nodeId) { + return vectors.get(nodeId); + } + + @Override + public boolean isValueShared() { + return true; + } + + @Override + public RandomAccessVectorValues copy() { + copyCount.incrementAndGet(); + return new CloseTrackingRandomAccessVectorValues(vectors, copyCount, copyCloseCount, originalCloseCount, false); + } + + @Override + public void close() { + if (original) { + originalCloseCount.incrementAndGet(); + } else { + copyCloseCount.incrementAndGet(); + } + } + + int copyCount() { + return copyCount.get(); + } + + int copyCloseCount() { + return copyCloseCount.get(); + } + + int originalCloseCount() { + return originalCloseCount.get(); + } + } +}